Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
9a010310
Commit
9a010310
authored
Dec 29, 2020
by
mohammad
Browse files
add multi-tensor-apply to clip grad
parent
345f5d0d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
6 deletions
+17
-6
megatron/optimizer/optimizer.py
megatron/optimizer/optimizer.py
+17
-6
No files found.
megatron/optimizer/optimizer.py
View file @
9a010310
...
...
@@ -99,15 +99,26 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
total_norm
=
total_norm_cuda
[
0
].
item
()
else
:
for
grad
in
grads_for_norm
:
grad_norm
=
torch
.
norm
(
grad
,
norm_type
)
total_norm
+=
grad_norm
.
item
()
**
norm_type
if
norm_type
==
2.0
:
dummy_overflow_buf
=
torch
.
cuda
.
IntTensor
([
0
])
grad_norm
,
_
=
multi_tensor_applier
(
amp_C
.
multi_tensor_l2norm
,
dummy_overflow_buf
,
[
grads_for_norm
],
False
# no per-parameter norm
)
total_norm
=
grad_norm
**
norm_type
else
:
for
grad
in
grads_for_norm
:
grad_norm
=
torch
.
norm
(
grad
,
norm_type
)
total_norm
+=
grad_norm
**
norm_type
# Sum across all model-parallel GPUs.
total_norm_cuda
=
torch
.
cuda
.
FloatTensor
([
float
(
total_norm
)])
torch
.
distributed
.
all_reduce
(
total_norm_cuda
,
torch
.
distributed
.
all_reduce
(
total_norm
,
op
=
torch
.
distributed
.
ReduceOp
.
SUM
,
group
=
mpu
.
get_model_parallel_group
())
total_norm
=
total_norm
_cuda
[
0
]
.
item
()
**
(
1.0
/
norm_type
)
total_norm
=
total_norm
.
item
()
**
(
1.0
/
norm_type
)
# Scale.
clip_coeff
=
max_norm
/
(
total_norm
+
1.0e-6
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment