Unverified Commit 5b85add7 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[trainer] fix bug in grad accum with multiple epochs (#22098)

* [trainer] fix bug in grad accum

* comment out debug

* fix one-off

* rename counter
parent 1c801d65
...@@ -1831,6 +1831,7 @@ class Trainer: ...@@ -1831,6 +1831,7 @@ class Trainer:
# AT THE VERY END! # AT THE VERY END!
_ = list(train_dataloader.sampler) _ = list(train_dataloader.sampler)
total_batched_samples = 0
for epoch in range(epochs_trained, num_train_epochs): for epoch in range(epochs_trained, num_train_epochs):
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch) train_dataloader.sampler.set_epoch(epoch)
...@@ -1867,6 +1868,7 @@ class Trainer: ...@@ -1867,6 +1868,7 @@ class Trainer:
step = -1 step = -1
for step, inputs in enumerate(epoch_iterator): for step, inputs in enumerate(epoch_iterator):
total_batched_samples += 1
if rng_to_sync: if rng_to_sync:
self._load_rng_state(resume_from_checkpoint) self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False rng_to_sync = False
...@@ -1887,7 +1889,7 @@ class Trainer: ...@@ -1887,7 +1889,7 @@ class Trainer:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control) self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
if ( if (
((step + 1) % args.gradient_accumulation_steps != 0) (total_batched_samples % args.gradient_accumulation_steps != 0)
and args.local_rank != -1 and args.local_rank != -1
and args._no_sync_in_gradient_accumulation and args._no_sync_in_gradient_accumulation
): ):
...@@ -1913,7 +1915,7 @@ class Trainer: ...@@ -1913,7 +1915,7 @@ class Trainer:
if self.deepspeed: if self.deepspeed:
self.deepspeed.step() self.deepspeed.step()
if (step + 1) % args.gradient_accumulation_steps == 0 or ( if total_batched_samples % args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps # last step in epoch but step is always smaller than gradient_accumulation_steps
steps_in_epoch <= args.gradient_accumulation_steps steps_in_epoch <= args.gradient_accumulation_steps
and (step + 1) == steps_in_epoch and (step + 1) == steps_in_epoch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment