"tests/vscode:/vscode.git/clone" did not exist on "04f4bd54ea3126185ced2ffdf26f608dcd1db30e"
Commit f4108909 authored by Alexei Baevski, committed by Myle Ott
Browse files

Build the optimizer only once; otherwise it leaks CUDA memory.

parent 92050ef2
......@@ -38,8 +38,7 @@ class Trainer(object):
self.model = model.cuda()
self.criterion = criterion.cuda()
# initialize optimizer and LR scheduler
self._build_optimizer()
self.optimizer = None
# initialize meters
self.meters = OrderedDict()
......@@ -96,6 +95,10 @@ class Trainer(object):
def train_step(self, sample, update_params=True):
"""Do forward, backward and parameter update."""
if self.optimizer is None:
# initialize optimizer and LR scheduler if hasn't been loaded from the checkpoint
self._build_optimizer()
sample = self._prepare_sample(sample, volatile=False)
# forward and backward pass
......
......@@ -310,9 +310,13 @@ def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
for fn in checkpoints:
if os.path.exists(fn):
os.remove(fn)
trainer.save_checkpoint(checkpoints[0], extra_state)
for fn in checkpoints[1:]:
os.symlink(os.path.basename(checkpoints[0]), fn)
if not end_of_epoch and args.keep_interval_updates > 0:
for cp in checkpoints:
trainer.save_checkpoint(cp, extra_state)
else:
trainer.save_checkpoint(checkpoints[0], extra_state)
for fn in checkpoints[1:]:
os.symlink(os.path.basename(checkpoints[0]), fn)
if not end_of_epoch and args.keep_interval_updates > 0:
# remove old checkpoints; checkpoints are sorted in descending order
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment