OpenDAS / Megatron-LM

Commit a9b1fc0a
authored Mar 09, 2022 by Lawrence McAfee
parent 9546d8f0

    finished cleaning clip_grads.py

Showing 1 changed file with 6 additions and 35 deletions.

megatron/optimizer/clip_grads.py (+6, -35)
@@ -21,18 +21,10 @@ from torch._six import inf

 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

-# >>>
-# from megatron import mpu
-# <<<
 from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
-# >>>
-from lutil import pax, tp
-
-DEBUG_ITERATION = 1
-# <<<


 def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
                         model_parallel_group=None, ITERATION=None):
@@ -56,10 +48,6 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
         Total norm of the parameters (viewed as a single vector).
     """

-    # >>>
-    # raise Exception("currently debugging ... don't call me.")
-    # <<<
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
@@ -119,17 +107,9 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
                 total_norm += grad_norm ** norm_type

     # Sum across all model-parallel GPUs.
-    # >>>
-    from megatron import get_args
-    args = get_args()
-    if args.use_distributed_optimizer:
-        torch.distributed.all_reduce(total_norm,
-                                     op=torch.distributed.ReduceOp.SUM)
-    else:
-        torch.distributed.all_reduce(total_norm,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=model_parallel_group)
-    # <<<
+    torch.distributed.all_reduce(total_norm,
+                                 op=torch.distributed.ReduceOp.SUM,
+                                 group=model_parallel_group)
     total_norm = total_norm.item() ** (1.0 / norm_type)

     # Scale.
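The hunk above keeps a single reduction over model_parallel_group: each rank raises its local gradient norm to the p-th power, the powers are summed across ranks, and the 1/p-th root of the total recovers the norm of the full parameter vector. Below is a minimal single-process sketch of that identity in plain PyTorch; it does not use torch.distributed, is not part of Megatron-LM, and the shard sizes are made up for illustration.

# Single-process sketch (not Megatron code) of the identity behind the
# reduction kept above: each model-parallel rank contributes the p-th power
# of its local gradient norm, the contributions are summed (the all_reduce
# with op=SUM over model_parallel_group), and the 1/p-th root of that sum
# equals the norm of the full, unsharded gradient.
import torch

norm_type = 2.0

# Stand-ins for the gradient shards held by three model-parallel ranks.
shards = [torch.randn(5), torch.randn(7), torch.randn(3)]

# Per-rank contribution before the reduction: ||g_local||_p ** p.
local_sums = [torch.linalg.vector_norm(s, ord=norm_type) ** norm_type
              for s in shards]

# The all_reduce(op=SUM) step, done by hand on one process.
total_norm = torch.stack(local_sums).sum()

# total_norm = total_norm.item() ** (1.0 / norm_type), as in the diff.
total_norm = total_norm.item() ** (1.0 / norm_type)

# Matches the norm of the concatenated (unsharded) gradient up to float error.
reference = torch.linalg.vector_norm(torch.cat(shards), ord=norm_type).item()
print(total_norm, reference)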
@@ -164,18 +144,9 @@ def count_zeros_fp32(parameters, model_parallel_group):
             total_num_zeros = num_zeros + total_num_zeros

     # Sum across all model-parallel GPUs.
-    # >>>
-    from megatron import get_args
-    args = get_args()
-    if args.use_distributed_optimizer:
-        torch.distributed.all_reduce(total_num_zeros,
-                                     op=torch.distributed.ReduceOp.SUM)
-        # pax({"total_num_zeros": total_num_zeros.item()})
-    else:
-        torch.distributed.all_reduce(total_num_zeros,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=model_parallel_group)
-    # <<<
+    torch.distributed.all_reduce(total_num_zeros,
+                                 op=torch.distributed.ReduceOp.SUM,
+                                 group=model_parallel_group)
     total_num_zeros = total_num_zeros.item()