Commit 6e83649f authored by Jared Casper's avatar Jared Casper Committed by Deepak Narayanan
Browse files

Quick fix for pipeline tasks to get learning rate correct

parent 25c07e14
......@@ -129,10 +129,14 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
args.num_workers, not args.keep_last)
valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
# Now that we've built the data loaders, set args.micro_batch_size to
# the actual batch size the model will see for this dataset
# Now that we've built the data loaders, set batch_size arguments
# to the actual batch size the model will see for this dataset.
# This is necessary so pipeline transfers know what size they are
# and the LR schedule, which is based on samples seen, gets set
# correctly.
if hasattr(train_dataset, 'sample_multiplier'):
args.micro_batch_size *= train_dataset.sample_multiplier
args.global_batch_size *= train_dataset.sample_multiplier
return train_dataloader, valid_dataloader
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment