Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e27d9308
Unverified
Commit
e27d9308
authored
Apr 16, 2024
by
Zach Mueller
Committed by
GitHub
Apr 16, 2024
Browse files
Raise relevent err when wrong type is passed in as the accelerator_config (#29997)
* Raise relevent err * Use type instead
parent
0eaef0c7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
0 deletions
+36
-0
src/transformers/training_args.py
src/transformers/training_args.py
+7
-0
tests/trainer/test_trainer.py
tests/trainer/test_trainer.py
+29
-0
No files found.
src/transformers/training_args.py
View file @
e27d9308
...
@@ -1815,6 +1815,13 @@ class TrainingArguments:
...
@@ -1815,6 +1815,13 @@ class TrainingArguments:
self
.
accelerator_config
=
AcceleratorConfig
()
self
.
accelerator_config
=
AcceleratorConfig
()
elif
isinstance
(
self
.
accelerator_config
,
dict
):
elif
isinstance
(
self
.
accelerator_config
,
dict
):
self
.
accelerator_config
=
AcceleratorConfig
(
**
self
.
accelerator_config
)
self
.
accelerator_config
=
AcceleratorConfig
(
**
self
.
accelerator_config
)
# Check that a user didn't pass in the class instantiator
# such as `accelerator_config = AcceleratorConfig`
elif
isinstance
(
self
.
accelerator_config
,
type
):
raise
NotImplementedError
(
"Tried passing in a callable to `accelerator_config`, but this is not supported. "
"Please pass in a fully constructed `AcceleratorConfig` object instead."
)
else
:
else
:
self
.
accelerator_config
=
AcceleratorConfig
.
from_json_file
(
self
.
accelerator_config
)
self
.
accelerator_config
=
AcceleratorConfig
.
from_json_file
(
self
.
accelerator_config
)
if
self
.
dispatch_batches
is
not
None
:
if
self
.
dispatch_batches
is
not
None
:
...
...
tests/trainer/test_trainer.py
View file @
e27d9308
...
@@ -3104,6 +3104,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
...
@@ -3104,6 +3104,35 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
trainer
=
Trainer
(
model
=
model
,
args
=
args
,
eval_dataset
=
eval_dataset
)
trainer
=
Trainer
(
model
=
model
,
args
=
args
,
eval_dataset
=
eval_dataset
)
self
.
assertTrue
(
"The `AcceleratorConfig`'s `num_steps` is set but"
in
str
(
context
.
exception
))
self
.
assertTrue
(
"The `AcceleratorConfig`'s `num_steps` is set but"
in
str
(
context
.
exception
))
def
test_accelerator_config_not_instantiated
(
self
):
# Checks that accelerator kwargs can be passed through
# and the accelerator is initialized respectively
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
with
self
.
assertRaises
(
NotImplementedError
)
as
context
:
_
=
RegressionTrainingArguments
(
output_dir
=
tmp_dir
,
accelerator_config
=
AcceleratorConfig
,
)
self
.
assertTrue
(
"Tried passing in a callable to `accelerator_config`"
in
str
(
context
.
exception
))
# Now test with a custom subclass
@
dataclasses
.
dataclass
class
CustomAcceleratorConfig
(
AcceleratorConfig
):
pass
@
dataclasses
.
dataclass
class
CustomTrainingArguments
(
TrainingArguments
):
accelerator_config
:
dict
=
dataclasses
.
field
(
default
=
CustomAcceleratorConfig
,
)
with
tempfile
.
TemporaryDirectory
()
as
tmp_dir
:
with
self
.
assertRaises
(
NotImplementedError
)
as
context
:
_
=
CustomTrainingArguments
(
output_dir
=
tmp_dir
,
)
self
.
assertTrue
(
"Tried passing in a callable to `accelerator_config`"
in
str
(
context
.
exception
))
@
require_torch
@
require_torch
@
is_staging_test
@
is_staging_test
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment