Commit 784994a5 authored by Benjamin Thomas Graham's avatar Benjamin Thomas Graham
Browse files

checkpointing

parent 00ad22e6
......@@ -49,11 +49,11 @@ def ClassificationTrainValidate(model, dataset, p):
'weightDecay': p['weightDecay'],
'epoch': 1
}
if os.path.isfile('epoch.pth'):
if p['checkPoint'] and os.path.isfile('epoch.pth'):
optimState['epoch'] = torch.load('epoch.pth') + 1
print('Restarting at epoch ' +
str(optimState['epoch']) +
' from model.pickle ..')
' from model.pth ..')
model = torch.load('model.pth')
print(p)
......
......@@ -35,4 +35,4 @@ print('input spatial size', spatial_size)
dataset = getIterators(spatial_size, 63, 3)
scn.ClassificationTrainValidate(
model, dataset,
{'nEpochs': 100, 'initial_LR': 0.1, 'LR_decay': 0.05, 'weightDecay': 1e-4})
{'nEpochs': 100, 'initial_LR': 0.1, 'LR_decay': 0.05, 'weightDecay': 1e-4, 'checkPoint': False})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment