Commit 48f105d9 authored by Michael Carilli's avatar Michael Carilli
Browse files

Only save and load master params if training with FP16

parent 327b2446
...@@ -149,6 +149,7 @@ def main(): ...@@ -149,6 +149,7 @@ def main():
args.start_epoch = checkpoint['epoch'] args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1'] best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict']) model.load_state_dict(checkpoint['state_dict'])
if args.fp16:
saved_master_params = checkpoint['master_params'] saved_master_params = checkpoint['master_params']
for master, saved in zip(master_params, saved_master_params): for master, saved in zip(master_params, saved_master_params):
master.data.copy_(saved.data) master.data.copy_(saved.data)
...@@ -219,14 +220,19 @@ def main(): ...@@ -219,14 +220,19 @@ def main():
if args.local_rank == 0: if args.local_rank == 0:
is_best = prec1 > best_prec1 is_best = prec1 > best_prec1
best_prec1 = max(prec1, best_prec1) best_prec1 = max(prec1, best_prec1)
save_checkpoint({ # Use local scope to avoid dangling references
def create_and_save_checkpoint():
checkpoint_dict = {
'epoch': epoch + 1, 'epoch': epoch + 1,
'arch': args.arch, 'arch': args.arch,
'state_dict': model.state_dict(), 'state_dict': model.state_dict(),
'best_prec1': best_prec1, 'best_prec1': best_prec1,
'optimizer' : optimizer.state_dict(), 'optimizer' : optimizer.state_dict(),
'master_params': master_params, }
}, is_best) if args.fp16:
checkpoint_dict['master_params'] = master_params
save_checkpoint(checkpoint_dict, is_best)
create_and_save_checkpoint()
class data_prefetcher(): class data_prefetcher():
def __init__(self, loader): def __init__(self, loader):
......
...@@ -149,6 +149,7 @@ def main(): ...@@ -149,6 +149,7 @@ def main():
args.start_epoch = checkpoint['epoch'] args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1'] best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict']) model.load_state_dict(checkpoint['state_dict'])
if args.fp16:
saved_master_params = checkpoint['master_params'] saved_master_params = checkpoint['master_params']
for master, saved in zip(master_params, saved_master_params): for master, saved in zip(master_params, saved_master_params):
master.data.copy_(saved.data) master.data.copy_(saved.data)
...@@ -219,14 +220,19 @@ def main(): ...@@ -219,14 +220,19 @@ def main():
if args.local_rank == 0: if args.local_rank == 0:
is_best = prec1 > best_prec1 is_best = prec1 > best_prec1
best_prec1 = max(prec1, best_prec1) best_prec1 = max(prec1, best_prec1)
save_checkpoint({ # Use local scope to avoid dangling references
def create_and_save_checkpoint():
checkpoint_dict = {
'epoch': epoch + 1, 'epoch': epoch + 1,
'arch': args.arch, 'arch': args.arch,
'state_dict': model.state_dict(), 'state_dict': model.state_dict(),
'best_prec1': best_prec1, 'best_prec1': best_prec1,
'optimizer' : optimizer.state_dict(), 'optimizer' : optimizer.state_dict(),
'master_params': master_params, }
}, is_best) if args.fp16:
checkpoint_dict['master_params'] = master_params
save_checkpoint(checkpoint_dict, is_best)
create_and_save_checkpoint()
class data_prefetcher(): class data_prefetcher():
def __init__(self, loader): def __init__(self, loader):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment