Unverified Commit e1da89cc authored by Sylvain Gugger, committed by GitHub

Fix reproducibility in Training for PyTorch 1.11 (#16209)

parent e5101c2e
@@ -1354,9 +1354,18 @@ class Trainer:
         # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
         if not args.ignore_data_skip:
             for epoch in range(epochs_trained):
-                # We just need to begin an iteration to create the randomization of the sampler.
-                for _ in train_dataloader:
-                    break
+                is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
+                    train_dataloader.sampler, RandomSampler
+                )
+                if version.parse(torch.__version__) < version.parse("1.11") or not is_random_sampler:
+                    # We just need to begin an iteration to create the randomization of the sampler.
+                    # That was before PyTorch 1.11 however...
+                    for _ in train_dataloader:
+                        break
+                else:
+                    # Otherwise we need to call the whooooole sampler cause there is some random operation added
+                    # AT THE VERY END!
+                    _ = list(train_dataloader.sampler)
 
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
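
For context, here is a minimal, self-contained sketch of the fast-forward logic this commit introduces. It is not part of the commit: the TensorDataset, the seed, and the epochs_trained value are hypothetical stand-ins for the Trainer's own state, and only public torch.utils.data APIs are used.

    import torch
    from packaging import version
    from torch.utils.data import DataLoader, RandomSampler, TensorDataset

    # Hypothetical toy setup standing in for the Trainer's train_dataloader.
    dataset = TensorDataset(torch.arange(100).unsqueeze(1))
    train_dataloader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=8)

    epochs_trained = 3  # hypothetical: epochs already completed before resuming
    torch.manual_seed(42)

    for epoch in range(epochs_trained):
        is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
            train_dataloader.sampler, RandomSampler
        )
        if version.parse(torch.__version__) < version.parse("1.11") or not is_random_sampler:
            # Pre-1.11, RandomSampler drew its whole permutation as soon as
            # iteration began, so one step advanced the RNG like a full epoch.
            for _ in train_dataloader:
                break
        else:
            # From 1.11 on, the sampler performs a further random operation at
            # the very end of iteration (per the commit comment), so it has to
            # be consumed in full to leave the RNG where a completed epoch would.
            _ = list(train_dataloader.sampler)

    # Resuming at epoch `epochs_trained` now reproduces the shuffling order the
    # run would have had without the interruption.

The version gate exists because the cheap trick (starting an iteration and breaking) only advances the global RNG like a full epoch when the sampler draws its permutation eagerly; once an extra random draw happens at the end of iteration, only exhausting the sampler matches a completed epoch's RNG state.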