OpenDAS / Megatron-LM / Commits

Commit 9546d8f0
Authored Mar 09, 2022 by Lawrence McAfee
Parent e46230dc

    passing 'model_parallel_group' to clip_grads, count_zeros
Showing 4 changed files with 46 additions and 112 deletions (+46 -112):

    megatron/optimizer/__init__.py             +13  -29
    megatron/optimizer/clip_grads.py           +13  -70
    megatron/optimizer/distrib_optimizer.py     +7   -1
    megatron/optimizer/optimizer.py            +13  -12
megatron/optimizer/__init__.py (view file @ 9546d8f0)

@@ -91,18 +91,6 @@ def get_megatron_optimizer(model,
                                      scale_lr_cond,
                                      lr_mult)

-    # >>>
-    # params = [ p for m in model for p in m.parameters() ]
-    # pax(0, {
-    #     "params" : [ (p.tensor_model_parallel, tp(p)) for p in params ],
-    # })
-    # <<<
-
-    # >>>
-    # if args.use_distributed_optimizer:
-    #     optimizer = DistributedFusedAdam(param_groups)
-    # elif args.optimizer == 'adam':
-    # <<<
     if args.optimizer == 'adam':
         optimizer = Adam(param_groups,
                          lr=args.lr,
@@ -123,7 +111,7 @@ def get_megatron_optimizer(model,
     if args.DDP_impl == 'local':
         params_have_main_grad = True

-    if args.fp16 or args.bf16:
+    if args.fp16 or args.bf16 or args.use_distributed_optimizer:

         # Grad scaler:
         #    if loss-scale is provided, instantiate the constant scaler.
@@ -148,10 +136,10 @@ def get_megatron_optimizer(model,

         # Megatron optimizer.
         # >>>
-        opt_ty = Float16DistributedOptimizer \
-            if args.use_distributed_optimizer \
-            else Float16OptimizerWithFloat16Params
-        opt = opt_ty(optimizer,
+        opt_ty = DistributedOptimizer \
+            if args.use_distributed_optimizer else \
+            Float16OptimizerWithFloat16Params
+        return opt_ty(optimizer,
                       args.clip_grad,
                       args.log_num_zeros_in_grad,
                       params_have_main_grad,
@@ -159,20 +147,16 @@ def get_megatron_optimizer(model,
                       args.bf16,
                       grad_scaler,
                       model)
-        # >>>
-        # opt.debug_main_param_sum(0, "after init")
-        # opt.debug_main_grad_sum(0, "after init")
-        # <<<
-        return opt
         # <<<

     # FP32.
     # >>>
-    opt_ty = Float32DistributedOptimizer \
-        if args.use_distributed_optimizer \
-        else Float32Optimizer
-    return opt_ty(optimizer, args.clip_grad,
-                  args.log_num_zeros_in_grad,
-                  params_have_main_grad,
-                  args.use_contiguous_buffers_in_local_ddp)
+    # opt_ty = Float32DistributedOptimizer \
+    #     if args.use_distributed_optimizer \
+    #     else Float32Optimizer
+    # return opt_ty(optimizer, args.clip_grad,
     # <<<
+    return Float32Optimizer(optimizer, args.clip_grad,
+                            args.log_num_zeros_in_grad,
+                            params_have_main_grad,
+                            args.use_contiguous_buffers_in_local_ddp)
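Read together, the __init__.py hunks do two things: the mixed-precision wrapper path is now also taken when the distributed optimizer is enabled, and a single conditional expression picks DistributedOptimizer or Float16OptimizerWithFloat16Params and returns it directly. Below is a self-contained sketch of that selection logic using stub classes with made-up names; the real wrappers take many more arguments (clip_grad, grad_scaler, model, ...):

from types import SimpleNamespace

# Stub wrapper classes; the real ones live in megatron/optimizer.
class Float16OptimizerWithFloat16ParamsStub:
    def __init__(self, inner): self.inner = inner

class DistributedOptimizerStub:
    def __init__(self, inner): self.inner = inner

class Float32OptimizerStub:
    def __init__(self, inner): self.inner = inner

def select_wrapper(args, optimizer):
    # Mixed-precision path: fp16, bf16, or the distributed optimizer.
    if args.fp16 or args.bf16 or args.use_distributed_optimizer:
        opt_ty = DistributedOptimizerStub \
            if args.use_distributed_optimizer else \
            Float16OptimizerWithFloat16ParamsStub
        return opt_ty(optimizer)
    # FP32 path.
    return Float32OptimizerStub(optimizer)

args = SimpleNamespace(fp16=False, bf16=False, use_distributed_optimizer=True)
print(type(select_wrapper(args, optimizer="adam-stub")).__name__)  # DistributedOptimizerStub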
megatron/optimizer/clip_grads.py (view file @ 9546d8f0)

@@ -21,7 +21,9 @@ from torch._six import inf
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

-from megatron import mpu
+# >>>
+# from megatron import mpu
+# <<<
 from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
@@ -31,7 +33,9 @@ from lutil import pax, tp
 DEBUG_ITERATION = 1
 # <<<

-def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
+def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
+                        model_parallel_group=None, ITERATION=None):

     """Clips gradient norm of an iterable of parameters whose gradients
     are in fp32.
@@ -45,13 +49,15 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
         max_norm (float or int): max norm of the gradients
         norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
             infinity norm.
+        model_parallel_group (group): due to the nature of the distributed
+            optimizer, this is passed as an argument.

     Returns:
         Total norm of the parameters (viewed as a single vector).
     """

     # >>>
-    raise Exception("currently debugging ... don't call me.")
+    # raise Exception("currently debugging ... don't call me.")
     # <<<

     if isinstance(parameters, torch.Tensor):
@@ -75,26 +81,6 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
                 grads.append(grad)
             if grad_not_none and is_not_shared and is_not_tp_duplicate:
                 grads_for_norm.append(grad)
-            # >>>
-            # else:
-            #     pax(1, {
-            #         "grad_not_none" : grad_not_none,
-            #         "is_not_shared" : is_not_shared,
-            #         "is_not_tp_duplicate" : is_not_tp_duplicate,
-            #     })
-            # <<<
-
-    # >>>
-    # if ITERATION == DEBUG_ITERATION:
-    #     pax(0, {
-    #         "[LOC]" : "[** BEFORE CALC NORM **]",
-    #         "[ITERATION]" : ITERATION,
-    #         "max_norm" : max_norm,
-    #         "parameters" : parameters,
-    #         # "grads" : grads,
-    #         "grads_for_norm" : grads_for_norm,
-    #     })
-    # <<<

     # Norm parameters.
     max_norm = float(max_norm)
@@ -108,7 +94,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
         # Take max across all model-parallel GPUs.
         torch.distributed.all_reduce(total_norm_cuda,
                                      op=torch.distributed.ReduceOp.MAX,
-                                     group=mpu.get_model_parallel_group())
+                                     group=model_parallel_group)
         total_norm = total_norm_cuda[0].item()

     else:
@@ -117,13 +103,6 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
             # Use apex's multi-tensor applier for efficiency reasons.
             # Multi-tensor applier takes a function and a list of list
             # and performs the operation on that list all in one kernel.
-            # >>>
-            # pax(1, {
-            #     # "fn" : amp_C.multi_tensor_l2norm,
-            #     "dummy_overflow_buf" : tp(dummy_overflow_buf),
-            #     "grads_for_norm" : grads_for_norm,
-            # })
-            # <<<
             grad_norm, _ = multi_tensor_applier(
                 amp_C.multi_tensor_l2norm,
                 dummy_overflow_buf,
@@ -139,18 +118,6 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
                 grad_norm = torch.norm(grad, norm_type)
                 total_norm += grad_norm ** norm_type
-    # >>>
-    # if ITERATION == DEBUG_ITERATION:
-    #     pax(0, {
-    #         "[LOC]" : "[** CALC NORM **]",
-    #         "[ITERATION]" : ITERATION,
-    #         "max_norm" : max_norm,
-    #         "norm_type" : norm_type,
-    #         "grad_norm" : tp(grad_norm),
-    #         "total_norm" : tp(total_norm),
-    #     })
-    # <<<

     # Sum across all model-parallel GPUs.
     # >>>
     from megatron import get_args
@@ -161,22 +128,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
     else:
         torch.distributed.all_reduce(total_norm,
                                      op=torch.distributed.ReduceOp.SUM,
-                                     group=mpu.get_model_parallel_group())
+                                     group=model_parallel_group)
     # <<<
     total_norm = total_norm.item() ** (1.0 / norm_type)
-    # >>>
-    # if ITERATION == DEBUG_ITERATION:
-    #     pax(0, {
-    #         "[LOC]" : "[** AFTER REDUCE. **]",
-    #         "[ITERATION]" : ITERATION,
-    #         "max_norm" : max_norm,
-    #         "norm_type" : norm_type,
-    #         "grad_norm" : grad_norm.item(),
-    #         "total_norm" : total_norm,
-    #     })
-    # <<<

     # Scale.
     clip_coeff = max_norm / (total_norm + 1.0e-6)
     if clip_coeff < 1.0:
@@ -186,22 +141,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
                              [grads, grads],
                              clip_coeff)
-    # >>>
-    # # from pygit2 import Repository
-    # if ITERATION == DEBUG_ITERATION:
-    #     pax(1, {
-    #         "[LOC]" : "[** CLIP / FINAL **]",
-    #         "[ITERATION]" : ITERATION,
-    #         "grads" : grads,
-    #         "clip_coeff" : tp(clip_coeff),
-    #         # "repo" : Repository('.').head.shorthand,
-    #     })
-    # <<<

     return total_norm


-def count_zeros_fp32(parameters):
+def count_zeros_fp32(parameters, model_parallel_group):

     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
@@ -231,7 +174,7 @@ def count_zeros_fp32(parameters):
     else:
         torch.distributed.all_reduce(total_num_zeros,
                                      op=torch.distributed.ReduceOp.SUM,
-                                     group=mpu.get_model_parallel_group())
+                                     group=model_parallel_group)
     # <<<
     total_num_zeros = total_num_zeros.item()
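The substantive change in clip_grads.py is that clip_grad_norm_fp32 and count_zeros_fp32 no longer look up the group via megatron.mpu; the caller passes model_parallel_group in, and the all_reduce calls use it directly. A minimal, runnable sketch of that pattern follows, with a hypothetical helper name and a single-rank gloo group so it runs on one CPU process; it is not the Megatron implementation, which also uses apex's multi_tensor_applier:

import os
import torch
import torch.distributed as dist

def clip_grad_norm_fp32_sketch(parameters, max_norm, norm_type=2,
                               model_parallel_group=None):
    # Hypothetical stand-in for clip_grad_norm_fp32: the caller supplies the group.
    grads = [p.grad for p in parameters if p.grad is not None]
    total_norm = torch.zeros(1)
    for g in grads:
        total_norm += torch.norm(g, norm_type) ** norm_type
    # Sum across all model-parallel ranks; no mpu.get_model_parallel_group() here.
    dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=model_parallel_group)
    total_norm = total_norm.item() ** (1.0 / norm_type)
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        for g in grads:
            g.mul_(clip_coeff)
    return total_norm

if __name__ == "__main__":
    # Single-process "world" so the sketch is runnable anywhere.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    p = torch.nn.Parameter(torch.ones(4))
    p.grad = torch.full((4,), 10.0)
    norm = clip_grad_norm_fp32_sketch([p], max_norm=1.0,
                                      model_parallel_group=dist.group.WORLD)
    print(norm, p.grad)  # norm == 20.0; grad rescaled to roughly unit norm
    dist.destroy_process_group()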
megatron/optimizer/distrib_optimizer.py (view file @ 9546d8f0)

@@ -17,8 +17,13 @@

 import math

+import torch
+
 from megatron import get_args
+from megatron import get_timers
+from megatron import mpu
+from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper

 # >>>
 from lutil import pax, tp
@@ -40,7 +45,8 @@ class Shard:

 # class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
 # class Float16DistributedOptimizer(MegatronOptimizer):
 # class Float16DistributedOptimizer(BaseFloat16Optimizer):
-class DistributedOptimizer(MegatronOptimizer):
+# class DistributedOptimizer(MegatronOptimizer):
+class DistributedOptimizer(MixedPrecisionOptimizer):

     @classmethod
     def get_model_gbuf_param_shard_map(cls, model, dtype, gbuf_world_shard):
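The distrib_optimizer.py change is structural: DistributedOptimizer now inherits from MixedPrecisionOptimizer rather than MegatronOptimizer, so it reuses the shared mixed-precision machinery and only overrides what differs, which per the docstring added in optimizer.py is the model-parallel group it reduces over. A toy sketch of that override pattern, with illustrative class names rather than the Megatron classes:

from abc import ABC, abstractmethod

class MegatronOptimizerSketch(ABC):
    # Toy stand-in for MegatronOptimizer: provides the default group accessor.
    def get_model_parallel_group(self):
        return "default model-parallel group"

    @abstractmethod
    def step(self):
        ...

class MixedPrecisionOptimizerSketch(MegatronOptimizerSketch):
    # Toy stand-in for MixedPrecisionOptimizer: shared fp16/bf16 step logic.
    def step(self):
        # The real step unscales and clips grads over get_model_parallel_group();
        # here we only report which group the subclass would hand us.
        return "stepping, reducing over: " + self.get_model_parallel_group()

class DistributedOptimizerSketch(MixedPrecisionOptimizerSketch):
    # Toy stand-in for DistributedOptimizer: inherits step(), overrides the group.
    def get_model_parallel_group(self):
        return "distributed-optimizer group"

print(MixedPrecisionOptimizerSketch().step())
print(DistributedOptimizerSketch().step())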
megatron/optimizer/optimizer.py (view file @ 9546d8f0)

@@ -98,14 +98,23 @@ class MegatronOptimizer(ABC):
         return params

+    def get_model_parallel_group(self):
+        '''Default returned here, but the distributed optimizer overrides this.'''
+        return mpu.get_model_parallel_group()
+
     def clip_grad_norm(self, clip_grad, ITERATION):
         params = self.get_parameters()
-        return clip_grad_norm_fp32(params, clip_grad, ITERATION=ITERATION)
+        return clip_grad_norm_fp32(
+            params, clip_grad,
+            model_parallel_group=self.get_model_parallel_group(),
+            ITERATION=ITERATION)

     def count_zeros(self):
         params = self.get_parameters()
-        return count_zeros_fp32(params)
+        return count_zeros_fp32(
+            params, model_parallel_group=self.get_model_parallel_group())

     @abstractmethod
@@ -171,7 +180,7 @@ class MegatronOptimizer(ABC):
     def step(self):
         pass

-    def gather_params(self):
+    def gather_params(self, ITERATION):
         pass

     def reduce_grads(self, model):
@@ -282,10 +291,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         self._scale_one = torch.cuda.FloatTensor([1.0])

-    @abstractmethod
-    def get_model_parallel_group(self, state_dict):
-        pass
-
     def get_loss_scale(self):
         if self.grad_scaler is None:
             return self._scale_one
@@ -296,7 +301,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         self._copy_model_params_to_main_params()

-    def _unscale_main_grads_and_check_for_nan(self, group):
+    def _unscale_main_grads_and_check_for_nan(self):

         # Collect main grads.
         main_grads = self._collect_main_grad_data_for_unscaling()
@@ -528,10 +533,6 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
         self.optimizer.load_state_dict(self.optimizer.state_dict())

-    def get_model_parallel_group(self):
-        return mpu.get_model_parallel_group()
-
     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
         float16_groups & fp32_from_fp32_groups. We additionally zero
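In optimizer.py the pattern is: MegatronOptimizer grows a default get_model_parallel_group(), and clip_grad_norm / count_zeros thread its result into the fp32 helpers instead of letting those helpers import mpu themselves. A condensed, runnable sketch of that plumbing follows, with a stubbed helper and toy class names; the real helper is clip_grad_norm_fp32 in megatron/optimizer/clip_grads.py:

import torch

def clip_grad_norm_fp32_stub(params, clip_grad, model_parallel_group=None):
    # Stub helper: ignores clip_grad, computes a local norm, and echoes back the
    # group it was handed, purely to show how the group is threaded through.
    grads = [p.grad for p in params if p.grad is not None]
    total_norm = torch.norm(torch.stack([torch.norm(g) for g in grads])).item()
    return total_norm, model_parallel_group

class MegatronOptimizerSketch:
    # Toy wrapper mirroring the new MegatronOptimizer methods.
    def __init__(self, params):
        self._params = params

    def get_parameters(self):
        return self._params

    def get_model_parallel_group(self):
        # Default returned here; a distributed-optimizer subclass would override it.
        return None  # None means "default process group" to torch.distributed

    def clip_grad_norm(self, clip_grad):
        params = self.get_parameters()
        return clip_grad_norm_fp32_stub(
            params, clip_grad,
            model_parallel_group=self.get_model_parallel_group())

p = torch.nn.Parameter(torch.ones(3))
p.grad = torch.ones(3)
print(MegatronOptimizerSketch([p]).clip_grad_norm(clip_grad=1.0))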