Commit 784994a5 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

checkpointing

parent 00ad22e6
...@@ -49,11 +49,11 @@ def ClassificationTrainValidate(model, dataset, p): ...@@ -49,11 +49,11 @@ def ClassificationTrainValidate(model, dataset, p):
'weightDecay': p['weightDecay'], 'weightDecay': p['weightDecay'],
'epoch': 1 'epoch': 1
} }
if os.path.isfile('epoch.pth'): if p['checkPoint'] and os.path.isfile('epoch.pth'):
optimState['epoch'] = torch.load('epoch.pth') + 1 optimState['epoch'] = torch.load('epoch.pth') + 1
print('Restarting at epoch ' + print('Restarting at epoch ' +
str(optimState['epoch']) + str(optimState['epoch']) +
' from model.pickle ..') ' from model.pth ..')
model = torch.load('model.pth') model = torch.load('model.pth')
print(p) print(p)
......
...@@ -35,4 +35,4 @@ print('input spatial size', spatial_size) ...@@ -35,4 +35,4 @@ print('input spatial size', spatial_size)
dataset = getIterators(spatial_size, 63, 3) dataset = getIterators(spatial_size, 63, 3)
scn.ClassificationTrainValidate( scn.ClassificationTrainValidate(
model, dataset, model, dataset,
{'nEpochs': 100, 'initial_LR': 0.1, 'LR_decay': 0.05, 'weightDecay': 1e-4}) {'nEpochs': 100, 'initial_LR': 0.1, 'LR_decay': 0.05, 'weightDecay': 1e-4, 'checkPoint': False})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment