"vscode:/vscode.git/clone" did not exist on "fa84540e98a6af309c3007f64def5011db775a70"
Unverified Commit c036c814 authored by Sourab Mangrulkar, committed by GitHub

fix the grad_acc issue at epoch boundaries (#24415)



* fix the grad_acc issue at epoch boundaries
Co-authored-by: Zach Mueller <7831895+muellerzr@users.noreply.github.com>

* add contributors.

Co-authored-by: sumpster

* address comments

---------
Co-authored-by: Zach Mueller <7831895+muellerzr@users.noreply.github.com>
parent 468aed39
@@ -196,7 +196,7 @@ if is_peft_available():
 if is_accelerate_available():
     from accelerate import Accelerator, skip_first_batches
     from accelerate import __version__ as accelerate_version
-    from accelerate.utils import DistributedDataParallelKwargs
+    from accelerate.utils import DistributedDataParallelKwargs, GradientAccumulationPlugin

     if version.parse(accelerate_version) > version.parse("0.20.3"):
         from accelerate.utils import (
@@ -1806,14 +1806,23 @@ class Trainer:
                 self.current_flos += float(self.floating_point_ops(inputs))

                 # should this be under the accumulate context manager?
                 # the `or` condition of `steps_in_epoch <= args.gradient_accumulation_steps` is not covered
                 # in accelerate
-                if total_batched_samples % args.gradient_accumulation_steps == 0 or (
+                is_last_step_and_steps_less_than_grad_acc = (
+                    steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
+                )
+                if (
+                    total_batched_samples % args.gradient_accumulation_steps == 0
+                    or
                     # last step in epoch but step is always smaller than gradient_accumulation_steps
-                    steps_in_epoch <= args.gradient_accumulation_steps
-                    and (step + 1) == steps_in_epoch
+                    is_last_step_and_steps_less_than_grad_acc
                 ):
+                    # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
+                    # in accelerate. So, explicitly enable sync gradients to True in that case.
+                    if is_last_step_and_steps_less_than_grad_acc or (
+                        version.parse(accelerate_version) <= version.parse("0.20.3")
+                    ):
+                        self.accelerator.gradient_state._set_sync_gradients(True)
                     # Gradient clipping
                     if args.max_grad_norm is not None and args.max_grad_norm > 0:
                         # deepspeed does its own clipping
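The inner `if` added above is the substantive part of this hunk: when an epoch contains fewer batches than `gradient_accumulation_steps`, the modulo check never fires inside that epoch, and (per the diff's own comment) this `or` branch is not covered by accelerate, so the Trainer now forces `sync_gradients` to `True` on the last step of such an epoch. The following is a standalone sketch of that boundary arithmetic with made-up numbers, not code from this diff:

# Illustrative only: shows the epoch-boundary case this commit targets.
gradient_accumulation_steps = 8
steps_in_epoch = 3  # a short epoch: fewer batches than the accumulation window

total_batched_samples = 0
for step in range(steps_in_epoch):
    total_batched_samples += 1
    hits_modulo_boundary = total_batched_samples % gradient_accumulation_steps == 0
    is_last_step_and_steps_less_than_grad_acc = (
        steps_in_epoch <= gradient_accumulation_steps and (step + 1) == steps_in_epoch
    )
    print(step, hits_modulo_boundary, is_last_step_and_steps_less_than_grad_acc)

# The modulo boundary is never reached within the epoch (3 % 8, ...), but the
# last step is flagged, which is where the patched Trainer calls
# self.accelerator.gradient_state._set_sync_gradients(True).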
@@ -3815,10 +3824,14 @@ class Trainer:
             self.repo.git_push()

     def create_accelerator_and_postprocess(self):
+        grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
+        if version.parse(accelerate_version) > version.parse("0.20.3"):
+            grad_acc_kwargs["sync_with_dataloader"] = False
+        gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
+
         # create accelerator object
         self.accelerator = Accelerator(
-            deepspeed_plugin=self.args.deepspeed_plugin,
-            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+            deepspeed_plugin=self.args.deepspeed_plugin, gradient_accumulation_plugin=gradient_accumulation_plugin
         )

         # deepspeed and accelerate flags covering both trainer args and accelerate launcher
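This last hunk complements the training-loop change: instead of passing `gradient_accumulation_steps` directly to `Accelerator`, the Trainer now builds a `GradientAccumulationPlugin` and, on accelerate > 0.20.3, sets `sync_with_dataloader=False`, so accelerate no longer triggers a gradient sync on its own at the end of the dataloader; the loop above handles that epoch-boundary sync explicitly. A minimal sketch of the same plugin setup outside of `Trainer`, assuming accelerate > 0.20.3 is installed and with an illustrative step count standing in for `self.args.gradient_accumulation_steps`:

from accelerate import Accelerator
from accelerate.utils import GradientAccumulationPlugin

# Mirror of create_accelerator_and_postprocess with a hard-coded step count.
grad_acc_kwargs = {"num_steps": 4, "sync_with_dataloader": False}
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)

# The plugin's num_steps drives accelerator.accumulate(); with
# sync_with_dataloader=False the caller decides when to sync at epoch ends.
accelerator = Accelerator(gradient_accumulation_plugin=gradient_accumulation_plugin)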