Unverified commit 1a7c117d, authored by Zach Mueller and committed by GitHub

Fix deprecated arg issue (#29372)

* Fix deprecated arg issue

* Trainer check too

* Check for dict or dataclass

* Simplify, make config always AcceleratorConfig

* Upstream to Trainer
parent cec77334
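
Context for the change: before this commit, passing only the deprecated `split_batches`/`dispatch_batches` flags left `TrainingArguments.__post_init__` doing dict-style assignment on an `AcceleratorConfig` instance (a dataclass), which raises a `TypeError`. The fix normalizes `accelerator_config` to an `AcceleratorConfig` up front and sets attributes instead. A minimal, self-contained sketch of the failure mode, using a stand-in dataclass rather than the real `AcceleratorConfig`:

```python
# Stand-in for transformers.trainer_pt_utils.AcceleratorConfig, for illustration only.
from dataclasses import dataclass
from typing import Optional

@dataclass
class StandInAcceleratorConfig:
    split_batches: bool = False
    dispatch_batches: Optional[bool] = None

cfg = StandInAcceleratorConfig()
try:
    cfg["split_batches"] = True  # old code path: dict-style assignment on a dataclass
except TypeError as err:
    print(f"old path fails: {err}")
cfg.split_batches = True  # new code path: plain attribute assignment
print(cfg)
```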
src/transformers/trainer.py
@@ -80,7 +80,6 @@ from .trainer_callback import (
     TrainerState,
 )
 from .trainer_pt_utils import (
-    AcceleratorConfig,
     DistributedTensorGatherer,
     IterableDatasetShard,
     LabelSmoother,
@@ -4116,21 +4115,10 @@ class Trainer:
         gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)

         # create accelerator object
-        accelerator_kwargs = {}
-        if self.args.accelerator_config is not None:
-            accelerator_kwargs = self.args.accelerator_config
-            # dict and AcceleratorConfigs are parseable, json files are not
-            if isinstance(accelerator_kwargs, AcceleratorConfig):
-                accelerator_kwargs = accelerator_kwargs.to_dict()
-            elif isinstance(accelerator_kwargs, dict):
-                # Some values may need to go through non-accelerate aligned defaults
-                # and we need to run the `__post_init__` to set them
-                accelerator_kwargs = AcceleratorConfig(**accelerator_kwargs).to_dict()
-
         self.accelerator = Accelerator(
             deepspeed_plugin=self.args.deepspeed_plugin,
             gradient_accumulation_plugin=gradient_accumulation_plugin,
-            **accelerator_kwargs,
+            **self.args.accelerator_config.to_dict(),
         )
         # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
         self.gather_function = self.accelerator.gather_for_metrics
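
Because `TrainingArguments.__post_init__` now guarantees an `AcceleratorConfig`, the `Trainer` can splat `self.args.accelerator_config.to_dict()` straight into `Accelerator` without the type-sniffing block removed above. A hedged usage sketch (assumes a transformers version with `AcceleratorConfig` support and accelerate installed; the `output_dir` value is illustrative):

```python
from transformers import TrainingArguments
from transformers.trainer_pt_utils import AcceleratorConfig

args = TrainingArguments(output_dir="tmp_out", accelerator_config={"split_batches": True})
assert isinstance(args.accelerator_config, AcceleratorConfig)
# These are the keyword arguments the Trainer now passes on to Accelerator(...).
print(args.accelerator_config.to_dict())
```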
src/transformers/training_args.py
@@ -1737,9 +1737,11 @@ class TrainingArguments:
             os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")

         if is_accelerate_available():
-            if not isinstance(self.accelerator_config, (AcceleratorConfig, dict)):
+            if not isinstance(self.accelerator_config, (AcceleratorConfig)):
                 if self.accelerator_config is None:
                     self.accelerator_config = AcceleratorConfig()
+                elif isinstance(self.accelerator_config, dict):
+                    self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
                 else:
                     self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
             if self.dispatch_batches is not None:
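
The branch above is the whole normalization story: `None` becomes a default config, a dict is expanded through the constructor (so its defaults and `__post_init__` still apply), and a string is treated as a JSON file path; anything else must already be an `AcceleratorConfig`. A simplified, self-contained mirror of that logic, with `StandInAcceleratorConfig` and its `from_json_file` as illustrative stand-ins rather than the real API:

```python
import json
from dataclasses import dataclass
from typing import Optional, Union

@dataclass
class StandInAcceleratorConfig:
    split_batches: bool = False
    dispatch_batches: Optional[bool] = None

    @classmethod
    def from_json_file(cls, path: str) -> "StandInAcceleratorConfig":
        with open(path) as f:
            return cls(**json.load(f))

def normalize(config: Union[None, dict, str, StandInAcceleratorConfig]) -> StandInAcceleratorConfig:
    if isinstance(config, StandInAcceleratorConfig):
        return config  # already normalized
    if config is None:
        return StandInAcceleratorConfig()  # all defaults
    if isinstance(config, dict):
        return StandInAcceleratorConfig(**config)  # run through constructor defaults
    return StandInAcceleratorConfig.from_json_file(config)  # treat a string as a JSON path

print(normalize({"split_batches": True}))
print(normalize(None))
```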
@@ -1748,7 +1750,7 @@
                     " `--accelerator_config {'dispatch_batches':VALUE} instead",
                     FutureWarning,
                 )
-                self.accelerator_config["dispatch_batches"] = self.dispatch_batches
+                self.accelerator_config.dispatch_batches = self.dispatch_batches

             if self.split_batches is not None:
                 warnings.warn(
@@ -1756,7 +1758,7 @@
                     " `--accelerator_config {'split_batches':VALUE} instead",
                     FutureWarning,
                 )
-                self.accelerator_config["split_batches"] = self.split_batches
+                self.accelerator_config.split_batches = self.split_batches

         if self.tpu_metrics_debug:
             warnings.warn(
tests/trainer/test_trainer.py
@@ -2633,6 +2633,20 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
         self.assertEqual(trainer.accelerator.even_batches, False)
         self.assertEqual(trainer.accelerator.dispatch_batches, None)

+    def test_accelerator_config_only_deprecated_args(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self.assertWarns(FutureWarning) as cm:
+                args = RegressionTrainingArguments(
+                    output_dir=tmp_dir,
+                    split_batches=True,
+                )
+                self.assertIn("split_batches", str(cm.warnings[0].message))
+                config = RegressionModelConfig(a=1.5, b=2.5)
+                model = RegressionPreTrainedModel(config)
+                eval_dataset = SampleIterableDataset()
+                trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
+                self.assertEqual(trainer.accelerator.split_batches, True)
+

 @require_torch
 @is_staging_test
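
Outside the test harness, the same deprecated-only path can be exercised directly. A hedged sketch (assumes a transformers build containing this fix with accelerate installed; `output_dir` is illustrative, and the exact warning wording may differ):

```python
import warnings
from transformers import TrainingArguments

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    args = TrainingArguments(output_dir="tmp_out", split_batches=True)

# The FutureWarning still fires, but the value now lands on the config object
# instead of crashing on dict-style assignment.
assert any("split_batches" in str(w.message) for w in caught)
assert args.accelerator_config.split_batches is True
```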