chenpangpang / transformers · Commits · b2e4b091
Unverified commit b2e4b091, authored Jul 30, 2022 by Sourab Mangrulkar; committed by GitHub, Jul 30, 2022.

fix FSDP ShardedGradScaler (#18358)

renaming it

Parent: 51227e26
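The rename matters because `trainer.py` also uses a fairscale class named `ShardedGradScaler` on the sharded-DDP path; importing PyTorch FSDP's same-named class under the alias `FSDPShardedGradScaler` keeps the two bindings distinct. A minimal sketch of the shadowing problem the alias avoids (the two classes below are illustrative stand-ins, not the real scalers):

```python
# Stand-ins for two libraries that both export a class named ShardedGradScaler.
class FairscaleScaler:
    source = "fairscale"

class TorchFsdpScaler:
    source = "torch.distributed.fsdp"

# Importing both under one name rebinds it: the second import silently wins,
# and code that expected the fairscale class gets the FSDP one.
ShardedGradScaler = FairscaleScaler   # from fairscale ... import ShardedGradScaler
ShardedGradScaler = TorchFsdpScaler   # from torch ...     import ShardedGradScaler

# With an "as" alias, both remain addressable under distinct names:
ShardedGradScaler = FairscaleScaler
FSDPShardedGradScaler = TorchFsdpScaler  # ... import ShardedGradScaler as FSDPShardedGradScaler
print(ShardedGradScaler.source, FSDPShardedGradScaler.source)
```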
Changes: 1 changed file with 6 additions and 2 deletions (+6, -2)

src/transformers/trainer.py (+6, -2)
src/transformers/trainer.py

@@ -565,9 +565,11 @@ class Trainer:
                     self.scaler = ShardedGradScaler()
                 elif self.fsdp is not None:
                     if self.amp_dtype == torch.float16:
-                        from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+                        from torch.distributed.fsdp.sharded_grad_scaler import (
+                            ShardedGradScaler as FSDPShardedGradScaler,
+                        )
 
-                        self.scaler = ShardedGradScaler()
+                        self.scaler = FSDPShardedGradScaler()
                     else:
                         self.do_grad_scaling = False
                         self.use_cuda_amp = False
@@ -1366,6 +1368,8 @@ class Trainer:
                 transformer_cls_to_wrap = get_module_class_from_name(
                     model, self.args.fsdp_transformer_layer_cls_to_wrap
                 )
+                if transformer_cls_to_wrap is None:
+                    raise Exception("Could not find the transformer layer class to wrap in the model.")
                 auto_wrap_policy = functools.partial(
                     transformer_auto_wrap_policy,
                     # Transformer layer class to wrap
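The second hunk adds a guard so a misconfigured `fsdp_transformer_layer_cls_to_wrap` fails loudly instead of passing `None` into the auto-wrap policy. A rough sketch of that behavior, using a simplified stand-in for `get_module_class_from_name` (the real transformers helper searches the model's submodules in the same spirit):

```python
def get_module_class_from_name(module, name):
    """Simplified stand-in: search module and its children for a class named `name`."""
    if module.__class__.__name__ == name:
        return module.__class__
    for child in module.children():
        found = get_module_class_from_name(child, name)
        if found is not None:
            return found
    return None

# Toy model tree standing in for an actual nn.Module hierarchy.
class DecoderLayer:
    def children(self):
        return []

class ToyModel:
    def children(self):
        return [DecoderLayer()]

model = ToyModel()

# Correctly configured layer class name: the class is found.
assert get_module_class_from_name(model, "DecoderLayer") is DecoderLayer

# The added guard: a wrong class name now raises immediately instead of
# letting None reach functools.partial(transformer_auto_wrap_policy, ...).
try:
    transformer_cls_to_wrap = get_module_class_from_name(model, "NoSuchLayer")
    if transformer_cls_to_wrap is None:
        raise Exception("Could not find the transformer layer class to wrap in the model.")
except Exception as err:
    print(f"configuration error: {err}")
```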