Commit f4108909 authored by Alexei Baevski's avatar Alexei Baevski Committed by Myle Ott
Browse files

build optimizer only once, otherwise it leaks cuda memory

parent 92050ef2
...@@ -38,8 +38,7 @@ class Trainer(object): ...@@ -38,8 +38,7 @@ class Trainer(object):
self.model = model.cuda() self.model = model.cuda()
self.criterion = criterion.cuda() self.criterion = criterion.cuda()
# initialize optimizer and LR scheduler self.optimizer = None
self._build_optimizer()
# initialize meters # initialize meters
self.meters = OrderedDict() self.meters = OrderedDict()
...@@ -96,6 +95,10 @@ class Trainer(object): ...@@ -96,6 +95,10 @@ class Trainer(object):
def train_step(self, sample, update_params=True): def train_step(self, sample, update_params=True):
"""Do forward, backward and parameter update.""" """Do forward, backward and parameter update."""
if self.optimizer is None:
# initialize optimizer and LR scheduler if hasn't been loaded from the checkpoint
self._build_optimizer()
sample = self._prepare_sample(sample, volatile=False) sample = self._prepare_sample(sample, volatile=False)
# forward and backward pass # forward and backward pass
......
...@@ -310,9 +310,13 @@ def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss): ...@@ -310,9 +310,13 @@ def save_checkpoint(args, trainer, epoch, end_of_epoch, val_loss):
for fn in checkpoints: for fn in checkpoints:
if os.path.exists(fn): if os.path.exists(fn):
os.remove(fn) os.remove(fn)
trainer.save_checkpoint(checkpoints[0], extra_state) if not end_of_epoch and args.keep_interval_updates > 0:
for fn in checkpoints[1:]: for cp in checkpoints:
os.symlink(os.path.basename(checkpoints[0]), fn) trainer.save_checkpoint(cp, extra_state)
else:
trainer.save_checkpoint(checkpoints[0], extra_state)
for fn in checkpoints[1:]:
os.symlink(os.path.basename(checkpoints[0]), fn)
if not end_of_epoch and args.keep_interval_updates > 0: if not end_of_epoch and args.keep_interval_updates > 0:
# remove old checkpoints; checkpoints are sorted in descending order # remove old checkpoints; checkpoints are sorted in descending order
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment