OpenDAS / Megatron-LM · Commits

Commit 6191ff59
Authored Dec 27, 2020 by Mohammad
Parent: b84d7a90

    found a bug in l2 norm calculation

Showing 1 changed file with 3 additions and 4 deletions: megatron/optimizer/optimizer.py (+3, -4)
megatron/optimizer/optimizer.py
@@ -76,7 +76,6 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
             (mpu.get_tensor_model_parallel_rank() == 0)
         if grad_not_none and is_not_shared and is_not_tp_duplicate:
             filtered_parameters.append(param)
-    parameters = filtered_parameters
 
     # Norm parameters.
     max_norm = float(max_norm)
@@ -86,7 +85,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
     # Calculate norm.
     if norm_type == inf:
         total_norm = max(param.grad.detach().abs().max()
-                         for param in parameters)
+                         for param in filtered_parameters)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         # Take max across all model-parallel GPUs.
         torch.distributed.all_reduce(total_norm_cuda,
@@ -95,7 +94,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
         total_norm = total_norm_cuda[0].item()
     else:
-        for param in parameters:
+        for param in filtered_parameters:
             param_norm = torch.norm(param.grad.detach(), norm_type)
             total_norm += param_norm.item() ** norm_type
         # Sum across all model-parallel GPUs.
@@ -107,7 +106,7 @@ def _clip_grad_norm(parameters, max_norm, norm_type=2):
     # Scale.
     clip_coef = max_norm / (total_norm + 1e-6)
-    if clip_coef < 1:
+    if clip_coef < 1.0:
         for param in parameters:
             param.grad.detach().mul_(clip_coef)
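What the diff changes, in short: previously `parameters` was rebound to `filtered_parameters`, so the final scaling loop only touched the filtered subset, and gradients of parameters excluded from the norm (shared or tensor-model-parallel duplicates) were never multiplied by the clip coefficient. After the change, the L2 norm is still accumulated over `filtered_parameters` only, while the scaling loop runs over the full `parameters` list. The sketch below is a minimal, single-GPU illustration of that pattern and is not the Megatron-LM source: `clip_grad_norm_sketch` is a hypothetical name, the filter is reduced to a grad-is-not-None check, and the model-parallel all-reduce of the norm is omitted.

    import torch

    def clip_grad_norm_sketch(parameters, max_norm, norm_type=2.0):
        # Hypothetical single-GPU stand-in for Megatron-LM's _clip_grad_norm.
        parameters = list(parameters)
        # Simplified filter: the real code also drops shared parameters and
        # tensor-model-parallel duplicates before computing the norm.
        filtered_parameters = [p for p in parameters if p.grad is not None]

        max_norm = float(max_norm)
        total_norm = 0.0
        # Norm is accumulated over the filtered subset only (as in the diff).
        for param in filtered_parameters:
            param_norm = torch.norm(param.grad.detach(), norm_type)
            total_norm += param_norm.item() ** norm_type
        total_norm = total_norm ** (1.0 / norm_type)

        # The clip coefficient is applied to every gradient, not just the
        # filtered ones, so parameters excluded from the norm are scaled
        # consistently with the rest.
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1.0:
            for param in parameters:
                if param.grad is not None:
                    param.grad.detach().mul_(clip_coef)
        return total_norm

    # Example usage on a hypothetical model:
    model = torch.nn.Linear(4, 2)
    model(torch.randn(3, 4)).sum().backward()
    clip_grad_norm_sketch(model.parameters(), max_norm=1.0)

The same split applies to the inf-norm branch of the real code: the max is taken over the filtered list, all-reduced across model-parallel GPUs, and the resulting coefficient is then applied to every gradient.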