Commit 343819f9 authored by Alexei Baevski, committed by Myle Ott

don't send dummy batch when reloading from checkpoint

also don't crash if a param does not receive grads
parent b9956a6a
@@ -271,9 +271,11 @@ class Trainer(object):
             if not p.requires_grad:
                 continue
             if p.grad is None:
-                raise RuntimeError('Model parameter did not receive gradient: ' + name + '. '
-                                   'Use the param in the forward pass or set requires_grad=False')
-            grads.append(p.grad.data)
+                print('WARNING: model parameter did not receive gradient: ' + name + '. '
+                      'Check that you\'re using the param in the forward pass or set requires_grad=False')
+                grads.append(p.new_zeros(p.shape))
+            else:
+                grads.append(p.grad.data)
         return grads

     def _get_flat_grads(self, out=None):
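For context on the hunk above: a parameter that is registered on the model but never used in the forward pass ends up with p.grad still None after backward(), which previously raised a RuntimeError when gradients were collected. The standalone sketch below is not part of this commit; the ToyModel class and its unused branch are made up for illustration. It reproduces that situation and mirrors the patched fallback of warning and appending a zero tensor so the gradient list stays aligned with the parameter list.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Toy model: `unused` is registered but never touched in forward(),
    so its parameters receive no gradient during backward()."""
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 1)
        self.unused = nn.Linear(4, 1)  # hypothetical dead branch

    def forward(self, x):
        return self.used(x)

model = ToyModel()
model(torch.randn(2, 4)).sum().backward()

# Mirror the patched behaviour: warn and fall back to a zero gradient instead of raising.
grads = []
for name, p in model.named_parameters():
    if not p.requires_grad:
        continue
    if p.grad is None:
        print('WARNING: model parameter did not receive gradient: ' + name)
        grads.append(p.new_zeros(p.shape))  # zero-filled placeholder, same shape as the param
    else:
        grads.append(p.grad.data)

print([tuple(g.shape) for g in grads])  # all four parameters are represented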
@@ -71,11 +71,10 @@ def main(args):
     )

     # Load the latest checkpoint if one is available
-    load_checkpoint(args, trainer, epoch_itr)
-
-    # Send a dummy batch to warm the caching allocator
-    dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
-    trainer.dummy_train_step(dummy_batch)
+    if not load_checkpoint(args, trainer, epoch_itr):
+        # Send a dummy batch to warm the caching allocator
+        dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
+        trainer.dummy_train_step(dummy_batch)

     # Train until the learning rate gets too small
     max_epoch = args.max_epoch or math.inf

@@ -319,6 +318,8 @@ def load_checkpoint(args, trainer, epoch_itr):
             trainer.lr_step_update(trainer.get_num_updates())
             if 'best' in extra_state:
                 save_checkpoint.best = extra_state['best']
+        return True
+    return False


 def load_dataset_splits(task, splits):
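A hedged sketch of the control flow established by the two hunks above: load_checkpoint now reports whether a checkpoint was actually restored, and main() only performs the caching-allocator warm-up with a dummy batch on a fresh start. The helper names below (load_checkpoint_sketch, maybe_warm_allocator, restore_fn, warm_fn) are placeholders for illustration, not fairseq APIs.

import os

def load_checkpoint_sketch(checkpoint_path, restore_fn):
    # Mirrors the new boolean contract: True when a checkpoint file was found
    # and restored, False when training starts from scratch.
    if os.path.isfile(checkpoint_path):
        restore_fn(checkpoint_path)
        return True
    return False

def maybe_warm_allocator(loaded, warm_fn):
    # Only send the dummy warm-up batch on a fresh run; when resuming from a
    # checkpoint the dummy step is skipped.
    if not loaded:
        warm_fn()

# Hypothetical usage; 'checkpoints/checkpoint_last.pt' is a placeholder path.
loaded = load_checkpoint_sketch('checkpoints/checkpoint_last.pt',
                                restore_fn=lambda path: print('restored', path))
maybe_warm_allocator(loaded, warm_fn=lambda: print('dummy train step'))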