Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
77d6941e
Unverified
Commit
77d6941e
authored
Dec 17, 2020
by
Sylvain Gugger
Committed by
GitHub
Dec 17, 2020
Browse files
Fix gradient clipping for Sharded DDP (#9168)
* Fix gradient clipping for Sharded DDP * Fix typos in comments
parent
1aca3d6a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
8 deletions
+17
-8
src/transformers/trainer.py
src/transformers/trainer.py
+17
-8
No files found.
src/transformers/trainer.py
View file @
77d6941e
...
...
@@ -804,14 +804,23 @@ class Trainer:
steps_in_epoch
<=
self
.
args
.
gradient_accumulation_steps
and
(
step
+
1
)
==
steps_in_epoch
):
if
self
.
use_amp
:
self
.
scaler
.
unscale_
(
self
.
optimizer
)
torch
.
nn
.
utils
.
clip_grad_norm_
(
model
.
parameters
(),
self
.
args
.
max_grad_norm
)
elif
self
.
use_apex
:
torch
.
nn
.
utils
.
clip_grad_norm_
(
amp
.
master_params
(
self
.
optimizer
),
self
.
args
.
max_grad_norm
)
else
:
torch
.
nn
.
utils
.
clip_grad_norm_
(
model
.
parameters
(),
self
.
args
.
max_grad_norm
)
# Gradient clipping
if
self
.
args
.
max_grad_norm
is
not
None
and
self
.
args
.
max_grad_norm
>
0
:
if
self
.
use_amp
:
# AMP: gradients need unscaling
self
.
scaler
.
unscale_
(
self
.
optimizer
)
if
hasattr
(
self
.
optimizer
,
"clip_grad_norm"
):
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
self
.
optimizer
.
clip_grad_norm
(
self
.
args
.
max_grad_norm
)
else
:
# Revert to normal clipping otherwise, handling Apex or full precision
torch
.
nn
.
utils
.
clip_grad_norm_
(
amp
.
master_params
(
self
.
optimizer
)
if
self
.
use_apex
else
model
.
parameters
(),
self
.
args
.
max_grad_norm
,
)
# Optimizer step
if
is_torch_tpu_available
():
xm
.
optimizer_step
(
self
.
optimizer
)
elif
self
.
use_amp
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment