Commit 327b2446 authored by Michael Carilli

Fixing imagenet main.py and main_reducer.py to save and load master params

parent b7025fc9
@@ -139,19 +139,25 @@ def main():
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
 
-    # optionally resume from a checkpoint
+    # Optionally resume from a checkpoint
     if args.resume:
-        if os.path.isfile(args.resume):
-            print("=> loading checkpoint '{}'".format(args.resume))
-            checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
-            args.start_epoch = checkpoint['epoch']
-            best_prec1 = checkpoint['best_prec1']
-            model.load_state_dict(checkpoint['state_dict'])
-            optimizer.load_state_dict(checkpoint['optimizer'])
-            print("=> loaded checkpoint '{}' (epoch {})"
-                  .format(args.resume, checkpoint['epoch']))
-        else:
-            print("=> no checkpoint found at '{}'".format(args.resume))
+        # Use a local scope to avoid dangling references
+        def resume():
+            if os.path.isfile(args.resume):
+                print("=> loading checkpoint '{}'".format(args.resume))
+                checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
+                args.start_epoch = checkpoint['epoch']
+                best_prec1 = checkpoint['best_prec1']
+                model.load_state_dict(checkpoint['state_dict'])
+                saved_master_params = checkpoint['master_params']
+                for master, saved in zip(master_params, saved_master_params):
+                    master.data.copy_(saved.data)
+                optimizer.load_state_dict(checkpoint['optimizer'])
+                print("=> loaded checkpoint '{}' (epoch {})"
+                      .format(args.resume, checkpoint['epoch']))
+            else:
+                print("=> no checkpoint found at '{}'".format(args.resume))
+        resume()
 
     # Data loading code
     traindir = os.path.join(args.data, 'train')
@@ -219,6 +225,7 @@ def main():
             'state_dict': model.state_dict(),
             'best_prec1': best_prec1,
             'optimizer' : optimizer.state_dict(),
+            'master_params': master_params,
         }, is_best)
 
 class data_prefetcher():
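For context on the hunk above: in this script the FP32 master params are plain tensors held outside both model.state_dict() and the optimizer's own state, so they have to be written into the checkpoint and restored by hand. Below is a minimal, self-contained sketch of that pattern; the model, learning rate, and checkpoint path are illustrative stand-ins, not the example's exact setup.

import torch

# FP16 model weights plus a separate list of FP32 "master" copies that the
# optimizer actually updates (illustrative stand-in for the example's setup).
model = torch.nn.Linear(10, 10).half()
master_params = [torch.nn.Parameter(p.detach().float()) for p in model.parameters()]
optimizer = torch.optim.SGD(master_params, lr=0.1)

# Save: the master copies live outside model.state_dict() and (for plain SGD)
# outside optimizer.state_dict(), so the checkpoint carries them explicitly.
torch.save({'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'master_params': master_params}, 'checkpoint.pth.tar')

# Resume: copy the saved values into the existing master tensors in place, so
# the optimizer keeps pointing at the same parameter objects it was built with.
checkpoint = torch.load('checkpoint.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
for master, saved in zip(master_params, checkpoint['master_params']):
    master.data.copy_(saved.data)
optimizer.load_state_dict(checkpoint['optimizer'])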
@@ -144,19 +144,23 @@ def main():
                                static_loss_scale=args.static_loss_scale,
                                dynamic_loss_scale=args.dynamic_loss_scale)
 
-    # optionally resume from a checkpoint
+    # Optionally resume from a checkpoint
     if args.resume:
-        if os.path.isfile(args.resume):
-            print("=> loading checkpoint '{}'".format(args.resume))
-            checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
-            args.start_epoch = checkpoint['epoch']
-            best_prec1 = checkpoint['best_prec1']
-            model.load_state_dict(checkpoint['state_dict'])
-            optimizer.load_state_dict(checkpoint['optimizer'])
-            print("=> loaded checkpoint '{}' (epoch {})"
-                  .format(args.resume, checkpoint['epoch']))
-        else:
-            print("=> no checkpoint found at '{}'".format(args.resume))
+        # Use a local scope to avoid dangling references
+        def resume():
+            if os.path.isfile(args.resume):
+                print("=> loading checkpoint '{}'".format(args.resume))
+                checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
+                args.start_epoch = checkpoint['epoch']
+                best_prec1 = checkpoint['best_prec1']
+                model.load_state_dict(checkpoint['state_dict'])
+                # An FP16_Optimizer instance's state dict internally stashes the master params.
+                optimizer.load_state_dict(checkpoint['optimizer'])
+                print("=> loaded checkpoint '{}' (epoch {})"
+                      .format(args.resume, checkpoint['epoch']))
+            else:
+                print("=> no checkpoint found at '{}'".format(args.resume))
+        resume()
 
     # Data loading code
     traindir = os.path.join(args.data, 'train')
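This second file takes a different route: its optimizer is wrapped in apex's FP16_Optimizer (hence the static_loss_scale / dynamic_loss_scale arguments in the context above), and that wrapper owns the FP32 master copies itself. Its state_dict() / load_state_dict() already round-trip them, which is why no separate 'master_params' checkpoint entry is added here. A rough sketch of that case, assuming apex is installed and a GPU is available; the model, loss-scale value, and checkpoint path are illustrative.

import torch
from apex.fp16_utils import FP16_Optimizer  # assumes the apex package is installed

model = torch.nn.Linear(10, 10).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# The wrapper creates and owns FP32 master copies of the FP16 model params.
optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)

# No 'master_params' key needed: the wrapper's state dict stashes the masters.
torch.save({'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()}, 'checkpoint.pth.tar')

checkpoint = torch.load('checkpoint.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])  # restores the master params too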
@@ -139,19 +139,25 @@ def main():
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
 
-    # optionally resume from a checkpoint
+    # Optionally resume from a checkpoint
     if args.resume:
-        if os.path.isfile(args.resume):
-            print("=> loading checkpoint '{}'".format(args.resume))
-            checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
-            args.start_epoch = checkpoint['epoch']
-            best_prec1 = checkpoint['best_prec1']
-            model.load_state_dict(checkpoint['state_dict'])
-            optimizer.load_state_dict(checkpoint['optimizer'])
-            print("=> loaded checkpoint '{}' (epoch {})"
-                  .format(args.resume, checkpoint['epoch']))
-        else:
-            print("=> no checkpoint found at '{}'".format(args.resume))
+        # Use a local scope to avoid dangling references
+        def resume():
+            if os.path.isfile(args.resume):
+                print("=> loading checkpoint '{}'".format(args.resume))
+                checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
+                args.start_epoch = checkpoint['epoch']
+                best_prec1 = checkpoint['best_prec1']
+                model.load_state_dict(checkpoint['state_dict'])
+                saved_master_params = checkpoint['master_params']
+                for master, saved in zip(master_params, saved_master_params):
+                    master.data.copy_(saved.data)
+                optimizer.load_state_dict(checkpoint['optimizer'])
+                print("=> loaded checkpoint '{}' (epoch {})"
+                      .format(args.resume, checkpoint['epoch']))
+            else:
+                print("=> no checkpoint found at '{}'".format(args.resume))
+        resume()
 
     # Data loading code
     traindir = os.path.join(args.data, 'train')
@@ -219,6 +225,7 @@ def main():
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
+            'master_params': master_params,
         }, is_best)
 
 class data_prefetcher():
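One note on the restore loop repeated in the first and last hunks: the saved values are copied into the existing master tensors with .data.copy_() rather than rebinding master_params to the loaded list, because the optimizer was constructed with references to those exact tensors. A tiny illustration of that design choice, with hypothetical names:

import torch

# The optimizer holds references to the original master tensors, so the load
# must overwrite their storage in place rather than rebind the Python name.
w = torch.nn.Parameter(torch.zeros(3))
opt = torch.optim.SGD([w], lr=0.1)

saved = torch.ones(3)  # stands in for one entry of checkpoint['master_params']

# w = torch.nn.Parameter(saved)  # would leave opt updating the old, stale tensor
w.data.copy_(saved)              # in-place copy: opt still sees the restored values

assert opt.param_groups[0]['params'][0] is w  # same object the optimizer was built with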