"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "965e98dc54c4853b34b3a49810f24856a9fcdba8"
Unverified Commit e27d9308 authored by Zach Mueller's avatar Zach Mueller Committed by GitHub
Browse files

Raise relevent err when wrong type is passed in as the accelerator_config (#29997)

* Raise relevent err

* Use type instead
parent 0eaef0c7
...@@ -1815,6 +1815,13 @@ class TrainingArguments: ...@@ -1815,6 +1815,13 @@ class TrainingArguments:
self.accelerator_config = AcceleratorConfig() self.accelerator_config = AcceleratorConfig()
elif isinstance(self.accelerator_config, dict): elif isinstance(self.accelerator_config, dict):
self.accelerator_config = AcceleratorConfig(**self.accelerator_config) self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
# Check that a user didn't pass in the class instantiator
# such as `accelerator_config = AcceleratorConfig`
elif isinstance(self.accelerator_config, type):
raise NotImplementedError(
"Tried passing in a callable to `accelerator_config`, but this is not supported. "
"Please pass in a fully constructed `AcceleratorConfig` object instead."
)
else: else:
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config) self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
if self.dispatch_batches is not None: if self.dispatch_batches is not None:
......
...@@ -3104,6 +3104,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): ...@@ -3104,6 +3104,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset) trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception)) self.assertTrue("The `AcceleratorConfig`'s `num_steps` is set but" in str(context.exception))
def test_accelerator_config_not_instantiated(self):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertRaises(NotImplementedError) as context:
_ = RegressionTrainingArguments(
output_dir=tmp_dir,
accelerator_config=AcceleratorConfig,
)
self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))
# Now test with a custom subclass
@dataclasses.dataclass
class CustomAcceleratorConfig(AcceleratorConfig):
pass
@dataclasses.dataclass
class CustomTrainingArguments(TrainingArguments):
accelerator_config: dict = dataclasses.field(
default=CustomAcceleratorConfig,
)
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertRaises(NotImplementedError) as context:
_ = CustomTrainingArguments(
output_dir=tmp_dir,
)
self.assertTrue("Tried passing in a callable to `accelerator_config`" in str(context.exception))
@require_torch @require_torch
@is_staging_test @is_staging_test
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment