Commit 327b2446 authored by Michael Carilli

Fixing imagenet main.py and main_reducer.py to save and load master params

parent b7025fc9
@@ -139,19 +139,25 @@ def main():
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)

-    # optionally resume from a checkpoint
+    # Optionally resume from a checkpoint
     if args.resume:
-        if os.path.isfile(args.resume):
-            print("=> loading checkpoint '{}'".format(args.resume))
-            checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
-            args.start_epoch = checkpoint['epoch']
-            best_prec1 = checkpoint['best_prec1']
-            model.load_state_dict(checkpoint['state_dict'])
-            optimizer.load_state_dict(checkpoint['optimizer'])
-            print("=> loaded checkpoint '{}' (epoch {})"
-                  .format(args.resume, checkpoint['epoch']))
-        else:
-            print("=> no checkpoint found at '{}'".format(args.resume))
+        # Use a local scope to avoid dangling references
+        def resume():
+            if os.path.isfile(args.resume):
+                print("=> loading checkpoint '{}'".format(args.resume))
+                checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
+                args.start_epoch = checkpoint['epoch']
+                best_prec1 = checkpoint['best_prec1']
+                model.load_state_dict(checkpoint['state_dict'])
+                saved_master_params = checkpoint['master_params']
+                for master, saved in zip(master_params, saved_master_params):
+                    master.data.copy_(saved.data)
+                optimizer.load_state_dict(checkpoint['optimizer'])
+                print("=> loaded checkpoint '{}' (epoch {})"
+                      .format(args.resume, checkpoint['epoch']))
+            else:
+                print("=> no checkpoint found at '{}'".format(args.resume))
+        resume()

     # Data loading code
     traindir = os.path.join(args.data, 'train')

@@ -219,6 +225,7 @@ def main():
             'state_dict': model.state_dict(),
             'best_prec1': best_prec1,
             'optimizer' : optimizer.state_dict(),
+            'master_params': master_params,
         }, is_best)

 class data_prefetcher():
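For readers unfamiliar with mixed-precision checkpointing: model.state_dict() here only holds the (possibly FP16) model weights, so the diff above also stores the flat list of FP32 master parameters in the checkpoint and copies it back element-wise on resume. Below is a minimal, self-contained sketch of that save/restore pattern; build_master_params and the checkpoint filename are illustrative stand-ins, since this diff does not show how master_params is actually constructed in the script.

import torch

def build_master_params(model):
    # Illustrative stand-in: FP32 "master" copies of the (possibly FP16) model parameters.
    return [p.detach().clone().float() for p in model.parameters()]

model = torch.nn.Linear(10, 10).half()
master_params = build_master_params(model)

# Save side: mirrors the 'master_params' entry added to the checkpoint dict above.
torch.save({'state_dict': model.state_dict(),
            'master_params': master_params}, 'checkpoint.pth.tar')

# Load side: mirrors the copy loop added inside resume() above.
checkpoint = torch.load('checkpoint.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
for master, saved in zip(master_params, checkpoint['master_params']):
    master.data.copy_(saved.data)

Copying into the existing master tensors, rather than rebinding master_params to the loaded list, keeps any references the optimizer already holds to those tensors valid, which is presumably why the diff copies element-wise instead of reassigning.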
@@ -144,19 +144,23 @@ def main():
                               static_loss_scale=args.static_loss_scale,
                               dynamic_loss_scale=args.dynamic_loss_scale)

-    # optionally resume from a checkpoint
+    # Optionally resume from a checkpoint
     if args.resume:
-        if os.path.isfile(args.resume):
-            print("=> loading checkpoint '{}'".format(args.resume))
-            checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
-            args.start_epoch = checkpoint['epoch']
-            best_prec1 = checkpoint['best_prec1']
-            model.load_state_dict(checkpoint['state_dict'])
-            optimizer.load_state_dict(checkpoint['optimizer'])
-            print("=> loaded checkpoint '{}' (epoch {})"
-                  .format(args.resume, checkpoint['epoch']))
-        else:
-            print("=> no checkpoint found at '{}'".format(args.resume))
+        # Use a local scope to avoid dangling references
+        def resume():
+            if os.path.isfile(args.resume):
+                print("=> loading checkpoint '{}'".format(args.resume))
+                checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
+                args.start_epoch = checkpoint['epoch']
+                best_prec1 = checkpoint['best_prec1']
+                model.load_state_dict(checkpoint['state_dict'])
+                # An FP16_Optimizer instance's state dict internally stashes the master params.
+                optimizer.load_state_dict(checkpoint['optimizer'])
+                print("=> loaded checkpoint '{}' (epoch {})"
+                      .format(args.resume, checkpoint['epoch']))
+            else:
+                print("=> no checkpoint found at '{}'".format(args.resume))
+        resume()

     # Data loading code
     traindir = os.path.join(args.data, 'train')
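In this second variant no explicit master-parameter handling is needed because, as the added comment says, an FP16_Optimizer state dict already carries the master params. The following is a minimal sketch of that flow, assuming FP16_Optimizer comes from apex.fp16_utils and is constructed roughly as the surrounding code suggests; treat the exact constructor arguments here as assumptions, not a quote of this script.

import torch
from apex.fp16_utils import FP16_Optimizer  # assumed import path, matching older apex examples

model = torch.nn.Linear(10, 10).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)

# Saving: optimizer.state_dict() internally includes the FP32 master params,
# so the checkpoint dict needs no separate 'master_params' entry.
torch.save({'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()}, 'checkpoint.pth.tar')

# Resuming: loading the optimizer state restores the master params as well,
# which is why this file's resume() has no manual copy loop.
checkpoint = torch.load('checkpoint.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])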
@@ -139,19 +139,25 @@ def main():
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)

-    # optionally resume from a checkpoint
+    # Optionally resume from a checkpoint
     if args.resume:
-        if os.path.isfile(args.resume):
-            print("=> loading checkpoint '{}'".format(args.resume))
-            checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
-            args.start_epoch = checkpoint['epoch']
-            best_prec1 = checkpoint['best_prec1']
-            model.load_state_dict(checkpoint['state_dict'])
-            optimizer.load_state_dict(checkpoint['optimizer'])
-            print("=> loaded checkpoint '{}' (epoch {})"
-                  .format(args.resume, checkpoint['epoch']))
-        else:
-            print("=> no checkpoint found at '{}'".format(args.resume))
+        # Use a local scope to avoid dangling references
+        def resume():
+            if os.path.isfile(args.resume):
+                print("=> loading checkpoint '{}'".format(args.resume))
+                checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
+                args.start_epoch = checkpoint['epoch']
+                best_prec1 = checkpoint['best_prec1']
+                model.load_state_dict(checkpoint['state_dict'])
+                saved_master_params = checkpoint['master_params']
+                for master, saved in zip(master_params, saved_master_params):
+                    master.data.copy_(saved.data)
+                optimizer.load_state_dict(checkpoint['optimizer'])
+                print("=> loaded checkpoint '{}' (epoch {})"
+                      .format(args.resume, checkpoint['epoch']))
+            else:
+                print("=> no checkpoint found at '{}'".format(args.resume))
+        resume()

     # Data loading code
     traindir = os.path.join(args.data, 'train')

@@ -219,6 +225,7 @@ def main():
             'state_dict': model.state_dict(),
             'best_prec1': best_prec1,
             'optimizer' : optimizer.state_dict(),
+            'master_params': master_params,
         }, is_best)

 class data_prefetcher():
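One more detail shared by all three resume() implementations: torch.load is given a map_location callable, so every saved storage, including the master params, is deserialized directly onto this process's GPU rather than onto whichever device it was saved from. A stripped-down illustration, assuming a single visible GPU (device 0) and the checkpoint file from the sketches above:

import torch

gpu = 0  # assumption: one visible device; the real scripts use args.gpu
checkpoint = torch.load('checkpoint.pth.tar',
                        map_location=lambda storage, loc: storage.cuda(gpu))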