Unverified Commit e27d9308 authored by Zach Mueller's avatar Zach Mueller Committed by GitHub
Browse files

Raise relevent err when wrong type is passed in as the accelerator_config (#29997)

* Raise relevent err

* Use type instead
parent 0eaef0c7
...@@ -1815,6 +1815,13 @@ class TrainingArguments: ...@@ -1815,6 +1815,13 @@ class TrainingArguments:
self.accelerator_config = AcceleratorConfig() self.accelerator_config = AcceleratorConfig()
elif isinstance(self.accelerator_config, dict): elif isinstance(self.accelerator_config, dict):
self.accelerator_config = AcceleratorConfig(**self.accelerator_config) self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
# Check that a user didn't pass in the class instantiator
# such as `accelerator_config = AcceleratorConfig`
elif isinstance(self.accelerator_config, type):
raise NotImplementedError(
"Tried passing in a callable to `accelerator_config`, but this is not supported. "
"Please pass in a fully constructed `AcceleratorConfig` object instead."
)
else: else:
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config) self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
if self.dispatch_batches is not None: if self.dispatch_batches is not None:
......
...@@ -3104,6 +3104,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -3104,6 +3104,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset) trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception)) self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception))
def test_accelerator_config_not_instantiated(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertRaises(NotImplementedError) as context:
_ = RegressionTrainingArguments(
output_dir=tmp_dir,
accelerator_config=AcceleratorConfig,
)
self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))
# Now test with a custom subclass
@dataclasses.dataclass
class CustomAcceleratorConfig(AcceleratorConfig):
pass
@dataclasses.dataclass
class CustomTrainingArguments(TrainingArguments):
accelerator_config: dict = dataclasses.field(
default=CustomAcceleratorConfig,
)
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertRaises(NotImplementedError) as context:
_ = CustomTrainingArguments(
output_dir=tmp_dir,
)
self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))
@require_torch @require_torch
@is_staging_test @is_staging_test
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment