Unverified Commit 4df5b9b4 authored by Yu Chin Fabian Lim, committed by GitHub

Allow GradientAccumulationPlugin to be configured from AcceleratorConfig (#29589)



* add gradient_accumulation_kwargs to AcceleratorConfig

* add suggestions from @muellerzr to docstrings, new behavior and tests

* Documentation suggestions from @muellerz
Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* addressed @muellerzr comments regarding tests and test utils

* moved accelerate version to top of file.

* @muellerzr's variable fix
Co-authored-by: Zach Mueller <muellerzr@gmail.com>

* address @amyeroberts. fix tests and docstrings

* address @amyeroberts additional suggestions

---------
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Co-authored-by: Zach Mueller <muellerzr@gmail.com>
parent a2a7f716
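
For context: with this change, gradient accumulation for `Trainer` can be driven through the `accelerator_config` argument of `TrainingArguments`, which now accepts a `gradient_accumulation_kwargs` entry that is forwarded to `accelerate`'s `GradientAccumulationPlugin`. A minimal usage sketch, assuming accelerate >= 0.28.0 is installed (the output directory is a placeholder, not part of this commit):

from transformers import TrainingArguments

# Leave gradient_accumulation_steps at its default of 1 and configure
# accumulation through the accelerate plugin instead. Setting both a
# num_steps here and gradient_accumulation_steps > 1 raises a ValueError.
args = TrainingArguments(
    output_dir="out",  # placeholder
    accelerator_config={
        "gradient_accumulation_kwargs": {
            "num_steps": 4,
            "adjust_scheduler": True,
            "sync_each_batch": True,
        }
    },
)
# `args` is then passed to Trainer(...) as usual.

The diff below implements this in five places: the `require_accelerate` test decorator, `Trainer.create_accelerator_and_postprocess`, the `AcceleratorConfig` dataclass, `is_accelerate_available`, and the trainer tests.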
@@ -52,6 +52,7 @@ from .integrations import (
 )
 from .integrations.deepspeed import is_deepspeed_available
 from .utils import (
+    ACCELERATE_MIN_VERSION,
     is_accelerate_available,
     is_apex_available,
     is_aqlm_available,
@@ -365,11 +366,13 @@ def require_nltk(test_case):
     return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case)
 
 
-def require_accelerate(test_case):
+def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
     """
     Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
     """
-    return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
+    return unittest.skipUnless(
+        is_accelerate_available(min_version), f"test requires accelerate version >= {min_version}"
+    )(test_case)
 
 
 def require_fsdp(test_case, min_version: str = "1.12.0"):
...
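
With the new `min_version` parameter, test modules can build version-specific skip decorators via `functools.partial`, as the test changes further down in this diff do. A small sketch (the decorated test is hypothetical):

from functools import partial

from transformers.testing_utils import require_accelerate

# Skip the test unless accelerate >= 0.28.0 is installed.
require_accelerate_0_28 = partial(require_accelerate, min_version="0.28.0")


@require_accelerate_0_28
def test_grad_accum_kwargs_feature():  # hypothetical test name
    ...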
@@ -4324,8 +4324,23 @@ class Trainer:
             self.repo.git_push()
 
     def create_accelerator_and_postprocess(self):
-        grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
+        grad_acc_kwargs = {}
+        if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
+            grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs
+
+        # check if num_steps is attempted to be passed in gradient_accumulation_kwargs
+        if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1:
+            # raise because we do not know which setting is intended.
+            raise ValueError(
+                "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
+                "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
+            )
+        elif "num_steps" not in grad_acc_kwargs:
+            # take the gradient_accumulation_steps setting from TrainingArguments.
+            grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps
+
         grad_acc_kwargs["sync_with_dataloader"] = False
         gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
 
         accelerator_config = self.args.accelerator_config.to_dict()
@@ -4337,6 +4352,8 @@ class Trainer:
                 even_batches=accelerator_config.pop("even_batches"),
                 use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
             )
+        # this would have been updated above, no need for it anymore
+        accelerator_config.pop("gradient_accumulation_kwargs")
         args = {
             "deepspeed_plugin": self.args.deepspeed_plugin,
             "gradient_accumulation_plugin": gradient_accumulation_plugin,
...
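
The resolution logic above can be restated on its own as follows; this is a standalone sketch, not code from the commit, with the two arguments standing in for `AcceleratorConfig.gradient_accumulation_kwargs` and `TrainingArguments.gradient_accumulation_steps`:

def resolve_grad_accum_kwargs(gradient_accumulation_kwargs, gradient_accumulation_steps):
    # Start from the kwargs supplied through AcceleratorConfig, if any.
    grad_acc_kwargs = dict(gradient_accumulation_kwargs or {})
    if "num_steps" in grad_acc_kwargs and gradient_accumulation_steps > 1:
        # Both sources specify a step count: ambiguous, so refuse to guess.
        raise ValueError("Set num_steps in only one of the two places.")
    if "num_steps" not in grad_acc_kwargs:
        # Fall back to the TrainingArguments value (its default is 1).
        grad_acc_kwargs["num_steps"] = gradient_accumulation_steps
    # The Trainer always disables sync_with_dataloader.
    grad_acc_kwargs["sync_with_dataloader"] = False
    return grad_acc_kwargs

In short, `num_steps` from `AcceleratorConfig` wins only when `gradient_accumulation_steps` keeps its default of 1; otherwise the conflict is reported instead of silently picking one value.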
@@ -1185,6 +1185,15 @@ class AcceleratorConfig:
             training results are fully reproducable using a different sampling technique. While seed-to-seed results
             may differ, on average the differences are neglible when using multiple different seeds to compare. Should
             also be ran with [`~utils.set_seed`] for the best results.
+        gradient_accumulation_kwargs (`dict`, *optional*):
+            Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`].
+            Any of the following (optional) keys are acceptable:
+              num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if
+                the latter is set to 1, otherwise an exception will be raised.
+              adjust_scheduler (`bool`): Whether to adjust the scheduler steps to account for [`~.TrainingArguments.gradient_accumulation_steps`].
+                The [`accelerate.utils.GradientAccumulationPlugin`] default is `True`.
+              sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch.
+                The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`.
     """
@@ -1223,6 +1232,19 @@ class AcceleratorConfig:
             "multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
         },
     )
+    gradient_accumulation_kwargs: Optional[Dict] = field(
+        default=None,
+        metadata={
+            "help": "Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`]. "
+            "Any of the following (optional) keys are acceptable: "
+            "  num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if "
+            "  the latter is set to 1, otherwise an exception will be raised. "
+            "  adjust_scheduler (`bool`): Whether to adjust the scheduler steps to account for [`~.TrainingArguments.gradient_accumulation_steps`]. "
+            "  The [`accelerate.utils.GradientAccumulationPlugin`] default is `True`. "
+            "  sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch. "
+            "  The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`."
+        },
+    )
 
     @classmethod
     def from_json_file(cls, json_file):
...
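
The same configuration can also be built as a dataclass, mirroring the new dataclass test added further down; the import path is an assumption (it may differ across transformers versions) and is not part of this diff:

# Sketch only: import path assumed, adjust to where AcceleratorConfig is exposed.
from transformers.trainer_pt_utils import AcceleratorConfig

accelerator_config = AcceleratorConfig(
    gradient_accumulation_kwargs={
        "num_steps": 10,  # used when TrainingArguments keeps the default of 1
        "adjust_scheduler": False,
        "sync_each_batch": True,
    },
)
# Pass it through TrainingArguments(accelerator_config=accelerator_config, ...).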
@@ -805,9 +805,7 @@ def is_protobuf_available():
 def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION):
-    if min_version is not None:
-        return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
-    return _accelerate_available
+    return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version)
 
 
 def is_fsdp_available(min_version: str = FSDP_MIN_VERSION):
...
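
After this simplification, `is_accelerate_available` always enforces a minimum version (defaulting to `ACCELERATE_MIN_VERSION`), so callers can gate version-specific behavior, as the trainer and test changes in this diff do:

from transformers.utils import is_accelerate_available

# True only if accelerate is installed and meets the requested minimum version.
if is_accelerate_available("0.28.0"):
    # gradient_accumulation_kwargs support can be relied on here
    pass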
@@ -24,6 +24,7 @@ import subprocess
 import sys
 import tempfile
 import unittest
+from functools import partial
 from itertools import product
 from pathlib import Path
 from typing import Dict, List
@@ -92,6 +93,7 @@ from transformers.utils import (
     SAFE_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
+    is_accelerate_available,
     is_apex_available,
     is_bitsandbytes_available,
     is_safetensors_available,
@@ -127,6 +129,9 @@ if is_torch_available():
 if is_safetensors_available():
     import safetensors.torch
 
+# for version specific tests in TrainerIntegrationTest
+require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
+GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
 
 PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
@@ -2877,6 +2882,10 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertEqual(trainer.accelerator.even_batches, True)
         self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
 
+        if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
+            # gradient accumulation kwargs configures gradient_state
+            self.assertNotIn("sync_each_batch", trainer.accelerator.gradient_state.plugin_kwargs)
+
     def test_accelerator_config_from_dict(self):
         # Checks that accelerator kwargs can be passed through
         # and the accelerator is initialized respectively
@@ -2885,15 +2894,19 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             model = RegressionPreTrainedModel(config)
             eval_dataset = SampleIterableDataset()
 
-            # Leaves all options as something *not* basic
-            args = RegressionTrainingArguments(
-                output_dir=tmp_dir,
-                accelerator_config={
-                    "split_batches": True,
-                    "dispatch_batches": True,
-                    "even_batches": False,
-                    "use_seedable_sampler": True,
-                },
+            accelerator_config = {
+                "split_batches": True,
+                "dispatch_batches": True,
+                "even_batches": False,
+                "use_seedable_sampler": True,
+            }
+            if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
+                accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
+
+            # Leaves all options as something *not* basic
+            args = RegressionTrainingArguments(
+                output_dir=tmp_dir,
+                accelerator_config=accelerator_config,
             )
             trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
             self.assertEqual(trainer.accelerator.split_batches, True)
@@ -2901,6 +2914,9 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             self.assertEqual(trainer.accelerator.even_batches, False)
             self.assertEqual(trainer.accelerator.use_seedable_sampler, True)
 
+            if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
+                self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
+
     def test_accelerator_config_from_yaml(self):
         # Checks that accelerator kwargs can be passed through
         # and the accelerator is initialized respectively
@@ -2913,6 +2929,8 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
                 "even_batches": False,
                 "use_seedable_sampler": False,
             }
+            if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
+                accelerator_config["gradient_accumulation_kwargs"] = {"sync_each_batch": True}
             json.dump(accelerator_config, f)
 
         config = RegressionModelConfig(a=1.5, b=2.5)
         model = RegressionPreTrainedModel(config)
@@ -2926,11 +2944,18 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertEqual(trainer.accelerator.even_batches, False)
         self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
 
+        if GRAD_ACCUM_KWARGS_VERSION_AVAILABLE:
+            self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
+
     def test_accelerator_config_from_dataclass(self):
         # Checks that accelerator kwargs can be passed through
         # and the accelerator is initialized respectively
-        accelerator_config = AcceleratorConfig(
-            split_batches=True, dispatch_batches=True, even_batches=False, use_seedable_sampler=False
-        )
+        accelerator_config = AcceleratorConfig(
+            split_batches=True,
+            dispatch_batches=True,
+            even_batches=False,
+            use_seedable_sampler=False,
+        )
         config = RegressionModelConfig(a=1.5, b=2.5)
         model = RegressionPreTrainedModel(config)
@@ -2943,6 +2968,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertEqual(trainer.accelerator.even_batches, False)
         self.assertEqual(trainer.accelerator.use_seedable_sampler, False)
 
+    @require_accelerate_version_min_0_28
+    def test_accelerate_config_from_dataclass_grad_accum(self):
+        # Checks that accelerator kwargs can be passed through
+        # and the accelerator is initialized respectively
+        grad_acc_kwargs = {
+            "num_steps": 10,
+            "adjust_scheduler": False,
+            "sync_with_dataloader": False,
+            "sync_each_batch": True,
+        }
+        accelerator_config = AcceleratorConfig(
+            split_batches=True,
+            dispatch_batches=True,
+            even_batches=False,
+            use_seedable_sampler=False,
+            gradient_accumulation_kwargs=grad_acc_kwargs,
+        )
+        config = RegressionModelConfig(a=1.5, b=2.5)
+        model = RegressionPreTrainedModel(config)
+        eval_dataset = SampleIterableDataset()
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config)
+            trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
+        self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 10)
+        self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["adjust_scheduler"], False)
+        self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_with_dataloader"], False)
+        self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["sync_each_batch"], True)
+
     def test_accelerator_config_from_partial(self):
         # Checks that accelerator kwargs can be passed through
         # and the accelerator is initialized respectively
@@ -3014,6 +3068,44 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
             trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
             self.assertEqual(trainer.accelerator.split_batches, True)
 
+    @require_accelerate_version_min_0_28
+    def test_accelerator_config_from_dict_grad_accum_num_steps(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            config = RegressionModelConfig(a=1.5, b=2.5)
+            model = RegressionPreTrainedModel(config)
+            eval_dataset = SampleIterableDataset()
+
+            # case - TrainingArguments.gradient_accumulation_steps == 1
+            #      - gradient_accumulation_kwargs['num_steps'] == 1
+            # results in grad accum set to 1
+            args = RegressionTrainingArguments(
+                output_dir=tmp_dir,
+                gradient_accumulation_steps=1,
+                accelerator_config={
+                    "gradient_accumulation_kwargs": {
+                        "num_steps": 1,
+                    }
+                },
+            )
+            trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
+            self.assertEqual(trainer.accelerator.gradient_state.plugin_kwargs["num_steps"], 1)
+
+            # case - TrainingArguments.gradient_accumulation_steps > 1
+            #      - gradient_accumulation_kwargs['num_steps'] specified
+            # results in exception raised
+            args = RegressionTrainingArguments(
+                output_dir=tmp_dir,
+                gradient_accumulation_steps=2,
+                accelerator_config={
+                    "gradient_accumulation_kwargs": {
+                        "num_steps": 10,
+                    }
+                },
+            )
+            with self.assertRaises(Exception) as context:
+                trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
+            self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception))
+
 @require_torch
 @is_staging_test
...