Commit 343819f9 authored by Alexei Baevski's avatar Alexei Baevski Committed by Myle Ott
Browse files

don't send dummy batch when reloading from checkpoint

also don't crash if param does not receive grads
parent b9956a6a
...@@ -271,9 +271,11 @@ class Trainer(object): ...@@ -271,9 +271,11 @@ class Trainer(object):
if not p.requires_grad: if not p.requires_grad:
continue continue
if p.grad is None: if p.grad is None:
raise RuntimeError('Model parameter did not receive gradient: ' + name + '. ' print('WARNING: model parameter did not receive gradient: ' + name + '. '
'Use the param in the forward pass or set requires_grad=False') 'Check that you\'re using the param in the forward pass or set requires_grad=False')
grads.append(p.grad.data) grads.append(p.new_zeros(p.shape))
else:
grads.append(p.grad.data)
return grads return grads
def _get_flat_grads(self, out=None): def _get_flat_grads(self, out=None):
......
...@@ -71,11 +71,10 @@ def main(args): ...@@ -71,11 +71,10 @@ def main(args):
) )
# Load the latest checkpoint if one is available # Load the latest checkpoint if one is available
load_checkpoint(args, trainer, epoch_itr) if not load_checkpoint(args, trainer, epoch_itr):
# Send a dummy batch to warm the caching allocator
# Send a dummy batch to warm the caching allocator dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions) trainer.dummy_train_step(dummy_batch)
trainer.dummy_train_step(dummy_batch)
# Train until the learning rate gets too small # Train until the learning rate gets too small
max_epoch = args.max_epoch or math.inf max_epoch = args.max_epoch or math.inf
...@@ -319,6 +318,8 @@ def load_checkpoint(args, trainer, epoch_itr): ...@@ -319,6 +318,8 @@ def load_checkpoint(args, trainer, epoch_itr):
trainer.lr_step_update(trainer.get_num_updates()) trainer.lr_step_update(trainer.get_num_updates())
if 'best' in extra_state: if 'best' in extra_state:
save_checkpoint.best = extra_state['best'] save_checkpoint.best = extra_state['best']
return True
return False
def load_dataset_splits(task, splits): def load_dataset_splits(task, splits):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment