"examples/distillation/vscode:/vscode.git/clone" did not exist on "fea921d38265fad7d92b952f152a2aac314c3207"
Unverified Commit 4df5b9b4 authored by Yu Chin Fabian Lim's avatar Yu Chin Fabian Lim Committed by GitHub
Browse files

Allow GradientAccumulationPlugin to be configured from AcceleratorConfig (#29589)



* add gradient_accumulation_kwargs to AcceleratorConfig

* add suggestions from @muellerzr to docstrings, new behavior and tests

* Documentation suggestions from @muellerz
Co-authored-by: default avatarZach Mueller <muellerzr@gmail.com>

* addressed @muellerzr comments regarding tests and test utils

* moved accelerate version to top of file.

* @muellerzr's variable fix
Co-authored-by: default avatarZach Mueller <muellerzr@gmail.com>

* address @amyeroberts. fix tests and docstrings

* address @amyeroberts additional suggestions

---------
Co-authored-by: default avatarYu Chin Fabian Lim <flim@sg.ibm.com>
Co-authored-by: default avatarZach Mueller <muellerzr@gmail.com>
parent a2a7f716
......@@ -52,6 +52,7 @@ from .integrations import (
)
from .integrations.deepspeed import is_deepspeed_available
from .utils import (
ACCELERATE_MIN_VERSION,
is_accelerate_available,
is_apex_available,
is_aqlm_available,
......@@ -365,11 +366,13 @@ def require_nltk(test_case):
return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case)
def require_accelerate(test_case):
def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
"""
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
"""
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
return unittest.skipUnless(
is_accelerate_available(min_version), f"test requires accelerate version >= {min_version}"
)(test_case)
def require_fsdp(test_case, min_version: str = "1.12.0"):
......
......@@ -4324,8 +4324,23 @@ class Trainer:
self.repo.git_push()
def create_accelerator_and_postprocess(self):
grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
grad_acc_kwargs = {}
if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs
# check if num_steps is attempted to be passed in gradient_accumulation_kwargs
if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1:
# raise because we do not know which setting is intended.
raise ValueError(
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
)
elif "num_steps" not in grad_acc_kwargs:
# take the gradient_accumulation_steps setting from TrainingArguments.
grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps
grad_acc_kwargs["sync_with_dataloader"] = False
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
accelerator_config = self.args.accelerator_config.to_dict()
......@@ -4337,6 +4352,8 @@ class Trainer:
even_batches=accelerator_config.pop("even_batches"),
use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
)
# this would have been updated above, no need for it anymore
accelerator_config.pop("gradient_accumulation_kwargs")
args = {
"deepspeed_plugin": self.args.deepspeed_plugin,
"gradient_accumulation_plugin": gradient_accumulation_plugin,
......
......@@ -1185,6 +1185,15 @@ class AcceleratorConfig:
training results are fully reproducable using a different sampling technique. While seed-to-seed results
may differ, on average the differences are neglible when using multiple different seeds to compare. Should
also be ran with [`~utils.set_seed`] for the best results.
gradient_accumulation_kwargs (`dict`, *optional*):
Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`].
Any of the following (optional) keys are acceptable:
num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if
the latter is set to 1, otherwise an exception will be raised.
adjust_scheduler (`bool`): Whether to adjust the scheduler steps to account for [`~.TrainingArguments.gradient_accumulation_steps`].
The [`accelerate.utils.GradientAccumulationPlugin`] default is `True`.
sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch.
The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`.
"""
......@@ -1223,6 +1232,19 @@ class AcceleratorConfig:
"multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
},
)
gradient_accumulation_kwargs: Optional[Dict] = field(
default=None,
metadata={
"help": "Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`]. "
"Any of the following (optional) keys are acceptable: "
" num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if "
" the latter is set to 1, otherwise an exception will be raised. "
" adjust_scheduler (`bool`): Whether to adjust the scheduler steps to account for [`~.TrainingArguments.gradient_accumulation_steps`]. "
" The [`accelerate.utils.GradientAccumulationPlugin`] default is `True`. "
" sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch. "
" The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`."
},
)
@classmethod
def from_json_file(cls, json_file):
......
......@@ -805,9 +805,7 @@ def is_protobuf_available():
def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
if min_version is not None:
return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
return _accelerate_available
def is_fsdp_available(min_version: str = FSDP_MIN_VERSION):
......
......@@ -24,6 +24,7 @@ import subprocess
import sys
import tempfile
import unittest
from functools import partial
from itertools import product
from pathlib import Path
from typing import Dict, List
......@@ -92,6 +93,7 @@ from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
is_accelerate_available,
is_apex_available,
is_bitsandbytes_available,
is_safetensors_available,
......@@ -127,6 +129,9 @@ if is_torch_available():
if is_safetensors_available():
import safetensors.torch
# for version specific tests in TrainerIntegrationTest
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
......@@ -2877,6 +2882,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(trainer.accelerator.even_batches, True)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
# gradient accumulation kwargs configures gradient_state
self.assertNotIn("sync_each_batch", trainer.accelerator.gradient_state.plugin_kwargs)
def test_accelerator_config_from_dict(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
......@@ -2885,15 +2894,19 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
# Leaves all options as something *not* basic
args = RegressionTrainingArguments(
output_dir=tmp_dir,
accelerator_config={
accelerator_config = {
"split_batches": True,
"dispatch_batches": True,
"even_batches": False,
"use_seedable_sampler": True,
},
}
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
# Leaves all options as something *not* basic
args = RegressionTrainingArguments(
output_dir=tmp_dir,
accelerator_config=accelerator_config,
)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
......@@ -2901,6 +2914,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
def test_accelerator_config_from_yaml(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
......@@ -2913,6 +2929,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
"even_batches": False,
"use_seedable_sampler": False,
}
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
json.dump(accelerator_config, f)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
......@@ -2926,11 +2944,18 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
def test_accelerator_config_from_dataclass(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
accelerator_config = AcceleratorConfig(
split_batches=True, dispatch_batches=True, even_batches=False, use_seedable_sampler=False
split_batches=True,
dispatch_batches=True,
even_batches=False,
use_seedable_sampler=False,
)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
......@@ -2943,6 +2968,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertEqual(trainer.accelerator.even_batches, False)
self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
@require_accelerate_version_min_0_28
def test_accelerate_config_from_dataclass_grad_accum(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
grad_acc_kwargs = {
"num_steps": 10,
"adjust_scheduler": False,
"sync_with_dataloader": False,
"sync_each_batch": True,
}
accelerator_config = AcceleratorConfig(
split_batches=True,
dispatch_batches=True,
even_batches=False,
use_seedable_sampler=False,
gradient_accumulation_kwargs=grad_acc_kwargs,
)
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
with tempfile.TemporaryDirectory() as tmp_dir:
args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
def test_accelerator_config_from_partial(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
......@@ -3014,6 +3068,44 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)
@require_accelerate_version_min_0_28
def test_accelerator_config_from_dict_grad_accum_num_steps(self):
with tempfile.TemporaryDirectory() as tmp_dir:
config = RegressionModelConfig(a=1.5, b=2.5)
model = RegressionPreTrainedModel(config)
eval_dataset = SampleIterableDataset()
# case - TrainingArguments.gradient_accumulation_steps == 1
# - gradient_accumulation_kwargs['num_steps] == 1
# results in grad accum set to 1
args = RegressionTrainingArguments(
output_dir=tmp_dir,
gradient_accumulation_steps=1,
accelerator_config={
"gradient_accumulation_kwargs": {
"num_steps": 1,
}
},
)
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 1)
# case - TrainingArguments.gradient_accumulation_steps > 1
# - gradient_accumulation_kwargs['num_steps] specified
# results in exception raised
args = RegressionTrainingArguments(
output_dir=tmp_dir,
gradient_accumulation_steps=2,
accelerator_config={
"gradient_accumulation_kwargs": {
"num_steps": 10,
}
},
)
with self.assertRaises(Exception) as context:
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception))
@require_torch
@is_staging_test
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment