Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
a495871b
Commit
a495871b
authored
Dec 29, 2020
by
mohammad
Browse files
added clip grads
parent
b77d9062
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
114 additions
and
0 deletions
+114
-0
megatron/optimizer/clip_grads.py
megatron/optimizer/clip_grads.py
+114
-0
No files found.
megatron/optimizer/clip_grads.py
0 → 100644
View file @
a495871b
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gradient clipping."""
import
torch
from
torch._six
import
inf
from
apex.multi_tensor_apply
import
multi_tensor_applier
import
amp_C
from
megatron
import
mpu
def
clip_grad_norm_fp32
(
parameters
,
max_norm
,
norm_type
=
2
):
"""Clips gradient norm of an iterable of parameters whose gradients
are in fp32.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if
isinstance
(
parameters
,
torch
.
Tensor
):
parameters
=
[
parameters
]
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
grads
=
[]
grads_for_norm
=
[]
for
param
in
parameters
:
# Make sure the grads are in fp32
assert
param
.
grad
.
type
()
==
'torch.cuda.FloatTensor'
grad_not_none
=
param
.
grad
is
not
None
is_not_shared
=
not
hasattr
(
param
,
'shared'
)
or
not
param
.
shared
is_not_tp_duplicate
=
param
.
tensor_model_parallel
or
\
(
mpu
.
get_tensor_model_parallel_rank
()
==
0
)
grad
=
param
.
grad
.
detach
()
if
grad_not_none
:
grads
.
append
(
grad
)
if
grad_not_none
and
is_not_shared
and
is_not_tp_duplicate
:
grads_for_norm
.
append
(
grad
)
# Norm parameters.
max_norm
=
float
(
max_norm
)
norm_type
=
float
(
norm_type
)
total_norm
=
0.0
# Calculate norm.
if
norm_type
==
inf
:
total_norm
=
max
(
grad
.
abs
().
max
()
for
grad
in
grads_for_norm
)
total_norm_cuda
=
torch
.
cuda
.
FloatTensor
([
float
(
total_norm
)])
# Take max across all model-parallel GPUs.
torch
.
distributed
.
all_reduce
(
total_norm_cuda
,
op
=
torch
.
distributed
.
ReduceOp
.
MAX
,
group
=
mpu
.
get_model_parallel_group
())
total_norm
=
total_norm_cuda
[
0
].
item
()
else
:
if
norm_type
==
2.0
:
dummy_overflow_buf
=
torch
.
cuda
.
IntTensor
([
0
])
grad_norm
,
_
=
multi_tensor_applier
(
amp_C
.
multi_tensor_l2norm
,
dummy_overflow_buf
,
[
grads_for_norm
],
False
# no per-parameter norm
)
total_norm
=
grad_norm
**
norm_type
else
:
for
grad
in
grads_for_norm
:
grad_norm
=
torch
.
norm
(
grad
,
norm_type
)
total_norm
+=
grad_norm
**
norm_type
# Sum across all model-parallel GPUs.
torch
.
distributed
.
all_reduce
(
total_norm
,
op
=
torch
.
distributed
.
ReduceOp
.
SUM
,
group
=
mpu
.
get_model_parallel_group
())
total_norm
=
total_norm
.
item
()
**
(
1.0
/
norm_type
)
# Scale.
clip_coeff
=
max_norm
/
(
total_norm
+
1.0e-6
)
if
clip_coeff
<
1.0
:
dummy_overflow_buf
=
torch
.
cuda
.
IntTensor
([
0
])
multi_tensor_applier
(
amp_C
.
multi_tensor_scale
,
dummy_overflow_buf
,
[
grads
,
grads
],
clip_coeff
)
return
total_norm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment