Commit cd4822f1 authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Make sure dataloader state is the same after checkpoint is loaded

parent c671de3e
......@@ -870,8 +870,8 @@ def build_train_valid_test_data_iterators(
# Shift the start iterations.
if train_dataloader is not None:
train_dataloader.batch_sampler.start_iter = args.iteration % \
len(train_dataloader)
train_dataloader.batch_sampler.start_iter = \
(args.iteration * args.num_microbatches_in_minibatch) % len(train_dataloader)
print_rank_0('setting training data start iteration to {}'.
format(train_dataloader.batch_sampler.start_iter))
if valid_dataloader is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment