Commit 48f105d9 authored by Michael Carilli
Browse files

Only save and load master params if training with FP16

parent 327b2446
......@@ -149,9 +149,10 @@ def main():
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
saved_master_params = checkpoint['master_params']
for master, saved in zip(master_params, saved_master_params):
master.data.copy_(saved.data)
if args.fp16:
saved_master_params = checkpoint['master_params']
for master, saved in zip(master_params, saved_master_params):
master.data.copy_(saved.data)
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
......@@ -219,14 +220,19 @@ def main():
if args.local_rank == 0:
is_best = prec1 > best_prec1
best_prec1 = max(prec1, best_prec1)
save_checkpoint({
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
'optimizer' : optimizer.state_dict(),
'master_params': master_params,
}, is_best)
# Use local scope to avoid dangling references
def create_and_save_checkpoint():
checkpoint_dict = {
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
'optimizer' : optimizer.state_dict(),
}
if args.fp16:
checkpoint_dict['master_params'] = master_params
save_checkpoint(checkpoint_dict, is_best)
create_and_save_checkpoint()
class data_prefetcher():
def __init__(self, loader):
......
......@@ -149,9 +149,10 @@ def main():
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
saved_master_params = checkpoint['master_params']
for master, saved in zip(master_params, saved_master_params):
master.data.copy_(saved.data)
if args.fp16:
saved_master_params = checkpoint['master_params']
for master, saved in zip(master_params, saved_master_params):
master.data.copy_(saved.data)
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
......@@ -219,14 +220,19 @@ def main():
if args.local_rank == 0:
is_best = prec1 > best_prec1
best_prec1 = max(prec1, best_prec1)
save_checkpoint({
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
'optimizer' : optimizer.state_dict(),
'master_params': master_params,
}, is_best)
# Use local scope to avoid dangling references
def create_and_save_checkpoint():
checkpoint_dict = {
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
'optimizer' : optimizer.state_dict(),
}
if args.fp16:
checkpoint_dict['master_params'] = master_params
save_checkpoint(checkpoint_dict, is_best)
create_and_save_checkpoint()
class data_prefetcher():
def __init__(self, loader):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment