Unverified Commit cb298978 authored by Sangbum Daniel Choi, committed by GitHub

add gather_use_object arguments (#31514)



* add gather_use_object arguments

* fix name and pass the CI test for Seq2SeqTrainer

* make style

* make it to functools

* fix typo

* add accelerate version:

* adding warning

* Update src/transformers/trainer.py
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* make style

* Update src/transformers/training_args.py

* check function move to initial part

* add test for eval_use_gather_object

---------
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
parent 82a1fc72
src/transformers/trainer.py
@@ -4605,6 +4605,11 @@ class Trainer:
         # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
         self.gather_function = self.accelerator.gather_for_metrics
+        if "use_gather_object" in inspect.signature(self.gather_function).parameters.keys():
+            self.gather_function = functools.partial(
+                self.gather_function, use_gather_object=self.args.eval_use_gather_object
+            )
+
         # deepspeed and accelerate flags covering both trainer args and accelerate launcher
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
...
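The trainer.py hunk above feature-detects the installed Accelerate: it only binds `use_gather_object` when `gather_for_metrics` actually accepts that keyword, so older Accelerate versions keep working unchanged. A minimal, self-contained sketch of the same `inspect.signature` + `functools.partial` pattern (the `gather_for_metrics` stand-in below is a hypothetical placeholder, not the real Accelerate method):

```python
import functools
import inspect

# Hypothetical stand-in for accelerator.gather_for_metrics; in recent
# Accelerate releases the real method accepts `use_gather_object`.
def gather_for_metrics(data, use_gather_object=False):
    return data

gather_function = gather_for_metrics
# Bind the keyword only if the callable's signature actually accepts it.
if "use_gather_object" in inspect.signature(gather_function).parameters:
    gather_function = functools.partial(gather_function, use_gather_object=True)

gather_function([1, 2, 3])  # forwards use_gather_object=True automatically
```

Using `functools.partial` here (per the "make it to functools" commit) rather than a `lambda` also keeps the bound argument introspectable via `gather_function.keywords`.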
src/transformers/training_args.py
@@ -773,8 +773,11 @@ class TrainingArguments:
             that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
             summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
-        eval_on_start(`bool`, *optional*, defaults to `False`):
+        eval_on_start (`bool`, *optional*, defaults to `False`):
             Whether to perform a evaluation step (sanity check) before the training to ensure the validation steps works correctly.
+        eval_use_gather_object (`bool`, *optional*, defaults to `False`):
+            Whether to recursively gather objects in a nested list/tuple/dictionary of objects from all devices.
     """

     framework = "pt"
@@ -1465,6 +1468,13 @@ class TrainingArguments:
         },
     )
+    eval_use_gather_object: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether to recursively gather objects in a nested list/tuple/dictionary of objects from all devices."
+        },
+    )
+
     def __post_init__(self):
         # Parse in args that could be `dict` sent in from the CLI as a string
         for field in _VALID_DICT_FIELDS:
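With the field declared, the flag is set like any other `TrainingArguments` option, either in code or as `--eval_use_gather_object` on the CLI. A hedged usage sketch (the output directory is a placeholder):

```python
from transformers import TrainingArguments

# Opt-in flag added by this commit; defaults to False. Enabling it lets
# evaluation gather nested non-tensor objects from all devices, and it
# requires accelerate >= 0.30.0 (checked in __post_init__, next hunk).
args = TrainingArguments(
    output_dir="./out",  # placeholder
    eval_use_gather_object=True,
)
```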
@@ -1992,6 +2002,12 @@ class TrainingArguments:
                 FutureWarning,
             )
+        if self.eval_use_gather_object and not is_accelerate_available("0.30.0"):
+            raise ValueError(
+                "`--eval_use_gather_object` requires `accelerate>=0.30.0`. "
+                "Please update your version of `accelerate`."
+            )
+
     def __str__(self):
         self_as_dict = asdict(self)
...
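The guard above delegates the version comparison to transformers' `is_accelerate_available(min_version)` helper. For illustration, a rough standalone equivalent built only on `importlib.metadata` and `packaging` (the `accelerate_at_least` helper is hypothetical, not part of transformers):

```python
from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse


def accelerate_at_least(min_version: str) -> bool:
    """Return True if accelerate is installed at or above min_version."""
    try:
        return parse(version("accelerate")) >= parse(min_version)
    except PackageNotFoundError:
        return False


if not accelerate_at_least("0.30.0"):
    raise ValueError("`--eval_use_gather_object` requires `accelerate>=0.30.0`.")
```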
@@ -132,6 +132,7 @@ if is_torch_available():
     # for version specific tests in TrainerIntegrationTest
     require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
+    require_accelerate_version_min_0_30 = partial(require_accelerate, min_version="0.30")
     GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")

 if is_accelerate_available():
     from accelerate import Accelerator
@@ -3565,6 +3566,17 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertIn("torch_dtype", args_dict)
         self.assertEqual(args_dict["torch_dtype"], dtype)

+    @require_accelerate_version_min_0_30
+    def test_eval_use_gather_object(self):
+        train_dataset = RegressionDataset()
+        eval_dataset = RegressionDataset()
+        model = RegressionDictModel()
+        args = TrainingArguments("./regression", report_to="none", eval_use_gather_object=True)
+        trainer = Trainer(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
+        trainer.train()
+        _ = trainer.evaluate()
+        _ = trainer.predict(eval_dataset)
+
 @require_torch
 @is_staging_test
...
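The test exercises the full train/evaluate/predict loop with the flag enabled. At runtime, the difference is in what `gather_for_metrics` will accept: with `use_gather_object=True` it can collect arbitrary nested Python objects across processes, not just tensors. A hedged sketch of that underlying Accelerate behavior, assuming `accelerate >= 0.30.0` and a multi-process launch (e.g. via `accelerate launch`):

```python
from accelerate import Accelerator

accelerator = Accelerator()

# A nested, non-tensor structure that a plain tensor all-gather cannot handle.
preds = {"rank": accelerator.process_index, "note": "non-tensor output"}

# With use_gather_object=True, each process's object is collected into a
# list-like result available on every process.
gathered = accelerator.gather_for_metrics(preds, use_gather_object=True)
if accelerator.is_main_process:
    print(gathered)
```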