Unverified Commit 5b85add7 authored by Stas Bekman, committed by GitHub

[trainer] fix bug in grad accum with multiple epochs (#22098)

* [trainer] fix bug in grad accum

* comment out debug

* fix off-by-one

* rename counter
parent 1c801d65
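
The patch below replaces the per-epoch `(step + 1)` modulo with a global `total_batched_samples` counter. As a minimal sketch of why the per-epoch counter misbehaves across epochs, consider gradient accumulation over 4 steps with 10 batches per epoch (the constants and function names here are illustrative, not the Trainer's internals):

# A minimal sketch of the bug, assuming an accumulation window of 4 and an
# epoch of 10 batches (names and numbers are illustrative, not Trainer internals).
GRAD_ACCUM_STEPS = 4
STEPS_IN_EPOCH = 10  # deliberately not divisible by GRAD_ACCUM_STEPS
NUM_EPOCHS = 2

def window_sizes_per_epoch_counter():
    # Old behavior: the modulo is taken on the per-epoch step, so batches
    # left over at the end of an epoch leak into the next epoch's first update.
    windows, accumulated = [], 0
    for _epoch in range(NUM_EPOCHS):
        for step in range(STEPS_IN_EPOCH):
            accumulated += 1
            if (step + 1) % GRAD_ACCUM_STEPS == 0:
                windows.append(accumulated)  # optimizer.step() would fire here
                accumulated = 0
    return windows

def window_sizes_global_counter():
    # New behavior: the modulo is taken on a counter that never resets,
    # so every update sees exactly GRAD_ACCUM_STEPS micro-batches.
    windows, accumulated, total_batched_samples = [], 0, 0
    for _epoch in range(NUM_EPOCHS):
        for _step in range(STEPS_IN_EPOCH):
            total_batched_samples += 1
            accumulated += 1
            if total_batched_samples % GRAD_ACCUM_STEPS == 0:
                windows.append(accumulated)
                accumulated = 0
    return windows

print(window_sizes_per_epoch_counter())  # [4, 4, 6, 4] -- one oversized update
print(window_sizes_global_counter())     # [4, 4, 4, 4, 4] -- uniform updates

Under the old scheme the two leftover batches at the end of each epoch never trigger an update on their own; their gradients silently carry over into the first, oversized update of the next epoch.
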
@@ -1831,6 +1831,7 @@ class Trainer:
                     # AT THE VERY END!
                     _ = list(train_dataloader.sampler)
 
+        total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                 train_dataloader.sampler.set_epoch(epoch)
@@ -1867,6 +1868,7 @@
 
             step = -1
             for step, inputs in enumerate(epoch_iterator):
+                total_batched_samples += 1
                 if rng_to_sync:
                     self._load_rng_state(resume_from_checkpoint)
                     rng_to_sync = False
@@ -1887,7 +1889,7 @@
                 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
 
                 if (
-                    ((step + 1) % args.gradient_accumulation_steps != 0)
+                    (total_batched_samples % args.gradient_accumulation_steps != 0)
                     and args.local_rank != -1
                     and args._no_sync_in_gradient_accumulation
                 ):
@@ -1913,7 +1915,7 @@
                 if self.deepspeed:
                     self.deepspeed.step()
 
-                if (step + 1) % args.gradient_accumulation_steps == 0 or (
+                if total_batched_samples % args.gradient_accumulation_steps == 0 or (
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
                     steps_in_epoch <= args.gradient_accumulation_steps
                     and (step + 1) == steps_in_epoch
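
With the global counter, an accumulation window may now span an epoch boundary: the tail batches of one epoch and the head batches of the next combine into a single optimizer step of exactly `args.gradient_accumulation_steps` micro-batches. The unchanged `or` clause above still forces an update at the end of an epoch whose total number of batches is smaller than the accumulation window.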