OpenDAS / Megatron-LM

Commit b0a3fdfe authored Dec 29, 2020 by mohammad

slight refactor of clip grads

parent 6191ff59

Showing 1 changed file with 11 additions and 6 deletions

megatron/optimizer/optimizer.py  +11  -6
megatron/optimizer/optimizer.py
@@ -68,14 +68,19 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
     # - grad should not be none
     # - parameter should not be shared
     # - should not be a replica due to tensor model parallelism
-    filtered_parameters = []
+    params_with_grads = []
+    params_for_norm = []
     for param in parameters:
+        # Make sure the grads are in fp32
+        assert param.grad.type() == 'torch.cuda.FloatTensor'
         grad_not_none = param.grad is not None
         is_not_shared = not hasattr(param, 'shared') or not param.shared
         is_not_tp_duplicate = param.tensor_model_parallel or \
             (mpu.get_tensor_model_parallel_rank() == 0)
+        if grad_not_none:
+            params_with_grads.append(param)
         if grad_not_none and is_not_shared and is_not_tp_duplicate:
-            filtered_parameters.append(param)
+            params_for_norm.append(param)

     # Norm parameters.
     max_norm = float(max_norm)
@@ -85,7 +90,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
     # Calculate norm.
     if norm_type == inf:
         total_norm = max(param.grad.detach().abs().max()
-                         for param in filtered_parameters)
+                         for param in params_for_norm)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         # Take max across all model-parallel GPUs.
         torch.distributed.all_reduce(total_norm_cuda,
@@ -94,7 +99,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
         total_norm = total_norm_cuda[0].item()
     else:
         total_norm = 0
-        for param in filtered_parameters:
+        for param in params_for_norm:
             param_norm = torch.norm(param.grad.detach(), norm_type)
             total_norm += param_norm.item() ** norm_type
         # Sum across all model-parallel GPUs.
@@ -105,9 +110,9 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
         total_norm = total_norm_cuda[0].item() ** (1. / norm_type)

     # Scale.
-    clip_coef = max_norm / (total_norm + 1e-6)
+    clip_coef = max_norm / (total_norm + 1.0e-6)
     if clip_coef < 1.0:
-        for param in parameters:
+        for param in params_with_grads:
             param.grad.detach().mul_(clip_coef)

     return total_norm
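
For reference, below is a minimal standalone sketch of the clipping pattern the refactored code arrives at, with the model-parallel pieces (the mpu all-reduce and the tensor_model_parallel check) stubbed out and only the finite-norm path shown. The function name clip_grad_norm_sketch and the simplified duplicate check are assumptions for illustration, not part of the commit. It makes the split visible: params_for_norm feeds the norm computation, while params_with_grads is what actually gets scaled.

import torch


def clip_grad_norm_sketch(parameters, max_norm, norm_type=2):
    """Sketch of the two-list clipping pattern (finite norm_type only)."""
    params_with_grads = []
    params_for_norm = []
    for param in parameters:
        if param.grad is None:
            continue
        # Every parameter with a grad gets scaled ...
        params_with_grads.append(param)
        # ... but only non-shared, non-duplicated parameters count toward the norm.
        is_not_shared = not getattr(param, 'shared', False)
        is_not_tp_duplicate = True  # assumption: no tensor model parallelism here
        if is_not_shared and is_not_tp_duplicate:
            params_for_norm.append(param)

    # Accumulate the p-norm over the filtered parameters.
    max_norm = float(max_norm)
    total_norm = 0.0
    for param in params_for_norm:
        param_norm = torch.norm(param.grad.detach(), norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1.0 / norm_type)

    # Scale every gradient in place if the norm exceeds max_norm.
    clip_coef = max_norm / (total_norm + 1.0e-6)
    if clip_coef < 1.0:
        for param in params_with_grads:
            param.grad.detach().mul_(clip_coef)
    return total_norm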