Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
636acc75
Unverified
Commit
636acc75
authored
Aug 18, 2023
by
Sourab Mangrulkar
Committed by
GitHub
Aug 18, 2023
Browse files
fix z3 init when using accelerate launcher (#25589)
parent
8d2f953f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
9 deletions
+11
-9
src/transformers/training_args.py
src/transformers/training_args.py
+11
-9
No files found.
src/transformers/training_args.py
View file @
636acc75
...
@@ -1467,6 +1467,15 @@ class TrainingArguments:
...
@@ -1467,6 +1467,15 @@ class TrainingArguments:
torch
.
backends
.
cudnn
.
allow_tf32
=
False
torch
.
backends
.
cudnn
.
allow_tf32
=
False
# no need to assert on else
# no need to assert on else
# if training args is specified, it will override the one specified in the accelerate config
if
self
.
half_precision_backend
!=
"apex"
and
len
(
self
.
sharded_ddp
)
==
0
:
mixed_precision_dtype
=
os
.
environ
.
get
(
"ACCELERATE_MIXED_PRECISION"
,
"no"
)
if
self
.
fp16
:
mixed_precision_dtype
=
"fp16"
elif
self
.
bf16
:
mixed_precision_dtype
=
"bf16"
os
.
environ
[
"ACCELERATE_MIXED_PRECISION"
]
=
mixed_precision_dtype
if
self
.
report_to
is
None
:
if
self
.
report_to
is
None
:
logger
.
info
(
logger
.
info
(
"The default value for the training argument `--report_to` will change in v5 (from all installed "
"The default value for the training argument `--report_to` will change in v5 (from all installed "
...
@@ -1655,6 +1664,8 @@ class TrainingArguments:
...
@@ -1655,6 +1664,8 @@ class TrainingArguments:
from
accelerate.utils
import
DeepSpeedPlugin
from
accelerate.utils
import
DeepSpeedPlugin
self
.
deepspeed_plugin
=
DeepSpeedPlugin
()
self
.
deepspeed_plugin
=
DeepSpeedPlugin
()
mixed_precision
=
os
.
environ
.
get
(
"ACCELERATE_MIXED_PRECISION"
,
"no"
)
self
.
deepspeed_plugin
.
set_mixed_precision
(
mixed_precision
)
self
.
deepspeed_plugin
.
set_deepspeed_weakref
()
self
.
deepspeed_plugin
.
set_deepspeed_weakref
()
if
self
.
push_to_hub_token
is
not
None
:
if
self
.
push_to_hub_token
is
not
None
:
...
@@ -1692,15 +1703,6 @@ class TrainingArguments:
...
@@ -1692,15 +1703,6 @@ class TrainingArguments:
FutureWarning
,
FutureWarning
,
)
)
# if training args is specified, it will override the one specified in the accelerate config
if
self
.
half_precision_backend
!=
"apex"
and
len
(
self
.
sharded_ddp
)
==
0
:
mixed_precision_dtype
=
os
.
environ
.
get
(
"ACCELERATE_MIXED_PRECISION"
,
"no"
)
if
self
.
fp16
:
mixed_precision_dtype
=
"fp16"
elif
self
.
bf16
:
mixed_precision_dtype
=
"bf16"
os
.
environ
[
"ACCELERATE_MIXED_PRECISION"
]
=
mixed_precision_dtype
# Finally set the `TrainingArguments` to be immutable
# Finally set the `TrainingArguments` to be immutable
self
.
_frozen
=
True
self
.
_frozen
=
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment