OpenDAS / Megatron-LM

Commit fb5b2b36 authored Apr 20, 2020 by Mohammad

fixed a bug in l2 grad clip

parent eb0a8bf0
Showing 1 changed file with 13 additions and 9 deletions

megatron/mpu/grads.py  (+13, -9)
...
@@ -32,16 +32,20 @@ def l2_grad_clipper(parameters, max_norm):
     """Efficient L2 norm gradient clipping."""
     overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
     # Make sure we have an iterable.
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
+    mp_rank_is_zero = (get_model_parallel_rank() == 0)
     # Filter parameters with gradients.
-    parameters = list(filter(lambda p: (p.grad is not None) and
-                             (p.model_parallel or mp_rank_is_zero),
-                             parameters))
+    parameters_with_grads = list(filter(
+        lambda p: p.grad is not None, parameters))
+    # Filter parameters for norm calculations.
+    parameters_for_norm = list(filter(
+        lambda p: p.model_parallel or mp_rank_is_zero,
+        parameters_with_grads))
     # Calculate L2 norm.
     norm, _ = multi_tensor_applier(
         amp_C.multi_tensor_l2norm,
         overflow_buf,
-        [parameters],
+        [parameters_for_norm],
         False  # no per-parameter norm
     )
     # Sum across all model parallel GPUs.
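This hunk replaces the single filtered list with two: parameters_with_grads holds every parameter that actually has a gradient, while parameters_for_norm keeps only model-parallel (partitioned) parameters plus, on model-parallel rank 0, the duplicated ones, so the norm summed across the model-parallel group counts each parameter exactly once. A minimal standalone sketch of that filtering step, with the rank passed in as a plain argument instead of the mpu helper used in the diff (the function name and the getattr default are illustrative assumptions):

import torch


def split_for_clipping(parameters, model_parallel_rank):
    """Sketch of the filtering step above; names are illustrative."""
    # Make sure we have an iterable, as in the diff.
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    mp_rank_is_zero = (model_parallel_rank == 0)
    # Every parameter that has a gradient must eventually be scaled.
    parameters_with_grads = [p for p in parameters if p.grad is not None]
    # Only partitioned parameters (plus, on rank 0, the duplicated ones)
    # feed the norm, so the model-parallel all-reduce counts each
    # parameter exactly once.
    parameters_for_norm = [
        p for p in parameters_with_grads
        if getattr(p, 'model_parallel', False) or mp_rank_is_zero]
    return parameters_with_grads, parameters_for_norm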
...
@@ -50,10 +54,10 @@ def l2_grad_clipper(parameters, max_norm):
                                  op=torch.distributed.ReduceOp.SUM,
                                  group=get_model_parallel_group())
     total_norm = norm_2.item() ** 0.5
     # Scale to get max_norm.
-    clip_coef = max_norm / (total_norm + 1e-6)
-    grads = [p.grad for p in parameters]
-    if clip_coef < 1:
+    clip_coef = float(max_norm) / (total_norm + 1.0e-6)
+    grads = [p.grad for p in parameters_with_grads]
+    if clip_coef < 1.0:
         multi_tensor_applier(
             amp_C.multi_tensor_scale,
             overflow_buf,
...
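Combined with the hunk above, the corrected flow computes the squared norm over parameters_for_norm, all-reduces it across the model-parallel group, and then scales every gradient in parameters_with_grads, whereas the previous version reused one filtered list both for the norm and for the grads it scaled. Below is a rough, standalone sketch of that flow in plain PyTorch, replacing apex's multi_tensor_applier fused kernels with ordinary tensor ops; the function name, the explicit group argument, and the returned total_norm (the usual clip-grad convention) are assumptions for illustration, not the file's exact code:

import torch


def l2_grad_clip_sketch(parameters_with_grads, parameters_for_norm,
                        max_norm, model_parallel_group=None):
    """Rough equivalent of the clipping flow, without apex fused kernels."""
    device = parameters_with_grads[0].grad.device
    # Local sum of squared L2 norms over the de-duplicated subset.
    norm_2 = torch.zeros(1, device=device)
    for p in parameters_for_norm:
        norm_2 += p.grad.float().norm() ** 2
    # Sum across all model parallel GPUs, as in the diff.
    if model_parallel_group is not None:
        torch.distributed.all_reduce(norm_2,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=model_parallel_group)
    total_norm = norm_2.item() ** 0.5
    # Scale to get max_norm.
    clip_coef = float(max_norm) / (total_norm + 1.0e-6)
    if clip_coef < 1.0:
        # Scale every gradient, not only the subset used for the norm.
        for p in parameters_with_grads:
            p.grad.mul_(clip_coef)
    return total_norm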