Unverified Commit 90412401 authored by Zach Mueller, committed by GitHub

Bring back `set_epoch` for Accelerate-based dataloaders (#26850)



* Working tests!

* Fix sampler

* Fix

* Update src/transformers/trainer.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fix check

* Clean

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 3c269240
@@ -200,6 +200,11 @@ if is_accelerate_available():
             save_fsdp_model,
             save_fsdp_optimizer,
         )
+    DATA_SAMPLERS = [RandomSampler]
+    if version.parse(accelerate_version) > version.parse("0.23.0"):
+        from accelerate.data_loader import SeedableRandomSampler
+
+        DATA_SAMPLERS += [SeedableRandomSampler]
 
 if is_deepspeed_available():
     from accelerate.utils import DeepSpeedSchedulerWrapper
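
Note (illustrative, not part of the commit): a minimal runnable sketch of the version-gating pattern this hunk introduces. `SeedableRandomSampler` only exists in newer Accelerate releases, so both the import and its registration in `DATA_SAMPLERS` are guarded; the `try/except` fallback below is an assumption standing in for the surrounding `is_accelerate_available()` check.

```python
# Sketch of the gated-import pattern; assumes torch and packaging are
# installed, while accelerate may be absent or older than 0.23.0.
from packaging import version
from torch.utils.data import RandomSampler

DATA_SAMPLERS = [RandomSampler]
try:
    from accelerate import __version__ as accelerate_version

    if version.parse(accelerate_version) > version.parse("0.23.0"):
        from accelerate.data_loader import SeedableRandomSampler

        DATA_SAMPLERS += [SeedableRandomSampler]
except ImportError:
    pass  # no accelerate available: DATA_SAMPLERS keeps only RandomSampler

print([cls.__name__ for cls in DATA_SAMPLERS])
```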
@@ -1738,7 +1743,10 @@ class Trainer:
             if not args.ignore_data_skip:
                 for epoch in range(epochs_trained):
                     sampler = get_dataloader_sampler(train_dataloader)
-                    is_random_sampler = isinstance(sampler, RandomSampler)
+                    sampler_kinds = [RandomSampler]
+                    if version.parse(accelerate_version) > version.parse("0.23.0"):
+                        sampler_kinds.append(SeedableRandomSampler)
+                    is_random_sampler = isinstance(sampler, tuple(sampler_kinds))
                     if is_torch_less_than_1_11 or not is_random_sampler:
                         # We just need to begin an iteration to create the randomization of the sampler.
                         for _ in train_dataloader:
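
Note (illustrative, not from the commit): the hunk above widens the resume-time check from one sampler class to a list, relying on `isinstance` accepting a tuple of types. A self-contained sketch; the toy dataset and loader below are assumptions for demonstration only.

```python
# Demonstrates the tuple-of-types isinstance check used in the hunk above.
import torch
from packaging import version
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

sampler_kinds = [RandomSampler]
try:
    from accelerate import __version__ as accelerate_version

    if version.parse(accelerate_version) > version.parse("0.23.0"):
        from accelerate.data_loader import SeedableRandomSampler

        sampler_kinds.append(SeedableRandomSampler)
except ImportError:
    pass  # older or missing accelerate: only RandomSampler is checked

dataset = TensorDataset(torch.arange(8))
loader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=2)

# isinstance takes a tuple, so one check covers every available sampler kind.
is_random_sampler = isinstance(loader.sampler, tuple(sampler_kinds))
print(is_random_sampler)  # True
```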
@@ -1752,6 +1760,8 @@
         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_iterator = train_dataloader
+            if hasattr(epoch_iterator, "set_epoch"):
+                epoch_iterator.set_epoch(epoch)
 
             # Reset the past mems state at the beginning of each epoch if necessary.
             if args.past_index >= 0:
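
Note (illustrative, not from the commit): Accelerate's prepared dataloaders expose `set_epoch`, which is why the `hasattr` guard above suffices. The toy sampler below is hypothetical, but it shows the behavior the call enables: folding the epoch into the seed yields a different yet reproducible shuffle per epoch.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return idx


class EpochSeededSampler(Sampler):
    """Hypothetical stand-in for a seedable sampler with set_epoch."""

    def __init__(self, data_source, seed=0):
        self.data_source = data_source
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        # Called once per epoch by the training loop, as in the hunk above.
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)  # epoch-dependent, reproducible
        yield from torch.randperm(len(self.data_source), generator=g).tolist()

    def __len__(self):
        return len(self.data_source)


dataset = ToyDataset()
sampler = EpochSeededSampler(dataset)
loader = DataLoader(dataset, sampler=sampler, batch_size=4)

for epoch in range(2):
    # Mirrors the Trainer change; here set_epoch lives on the sampler itself.
    if hasattr(sampler, "set_epoch"):
        sampler.set_epoch(epoch)
    print(epoch, [batch.tolist() for batch in loader])
```

Without the `set_epoch` call, every epoch would replay the same permutation, which is the kind of repeated-shuffle behavior this commit guards against.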