OpenDAS / Megatron-LM / Commits

Commit 5e29905f, authored Mar 24, 2022 by Lawrence McAfee
    stalling in copy_grads().
Parent: 2c3cb9fc

Showing 2 changed files with 91 additions and 53 deletions:
    megatron/optimizer/distrib_optimizer.py   +74  -41
    megatron/optimizer/optimizer.py           +17  -12
megatron/optimizer/distrib_optimizer.py
...
@@ -447,8 +447,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    # >>>
    # def get_main_grads_for_grad_norm(self):
    #     return self.main_grad_views_for_grad_norm
    # def get_main_grads_for_grad_norm(self):
    #     raise Exception("....... use 'super' .......")
    def get_main_grads_for_grad_norm(self):
        raise Exception("does 'super' work?")
        # grads_for_norm = super().get_main_grads_for_grad_norm()
        # if torch.distributed.get_rank() == 1:
        #     print_seq([ tp(g) for g in grads_for_norm ])
        # return grads_for_norm
    # <<<
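For reference, the grads-for-norm collection that the commented-out super() call would fall back to usually just walks the main parameter groups and keeps the gradients that should count toward the global norm. The sketch below is a hypothetical illustration of that pattern, not code from this repository; the main_param_groups argument and the 'shared' flag are assumptions.

def collect_grads_for_norm(main_param_groups):
    # Hypothetical sketch: gather main-grad tensors for the global grad-norm
    # computation, skipping params without grads and params flagged as
    # replicated so they are not counted twice.
    grads_for_norm = []
    for group in main_param_groups:
        for param in group:
            if param.grad is None:
                continue
            if getattr(param, 'shared', False):  # assumed replication flag
                continue
            grads_for_norm.append(param.grad.detach())
    return grads_for_norm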
...
@@ -493,6 +497,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    def load_state_dict(self, state_dict):
        raise Exception("hi.")

    # >>>
    # def zero_grad(self, set_to_none=True):
    #     # Collect model params.
...
@@ -505,7 +510,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    #         _zero_grad_group_helper(model_params, set_to_none = False)
    # def zero_grad(self, set_to_none=True):
    #     raise Exception("does 'super' work?")
    # >>>
    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_groups. We additionally zero
...
@@ -515,6 +519,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        for groups in (
                self.full_float16_groups,
                self.full_fp32_groups,
                self.shard_float16_groups, # grad empty/unused here?
                self.shard_fp32_groups,
                self.shard_fp32_from_float16_groups):
            for group in groups:
                _zero_grad_group_helper(group, set_to_none)
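The _zero_grad_group_helper called here is defined in megatron/optimizer/optimizer.py (its docstring appears in the second file below), but its body is not shown in this diff. A sketch consistent with that docstring and with the semantics of torch.optim.Optimizer.zero_grad would look roughly like this:

def _zero_grad_group_helper(group, set_to_none):
    """Zero out the gradient for a group of parameters."""
    # Sketch only: either drop the .grad tensor (freeing its memory until the
    # next backward pass allocates a fresh one) or zero it in place.
    for param in group:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()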
...
@@ -550,6 +556,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        #     for m in self.models
        #     for b in m._grad_buffers.values()
        # ])
        # print_seq("hi.")
        # <<<

        # All-reduce embedding grads.
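For context, the "All-reduce embedding grads." step that follows this context line keeps the word-embedding weights on the first and last pipeline stages in sync by summing their gradients over the embedding process group. The sketch below is a hedged illustration of that pattern only; embedding_group and the way the weight is looked up are assumptions, and the gradient is taken from .main_grad as elsewhere in this diff.

import torch

def allreduce_embedding_grads_sketch(word_embedding_weight, embedding_group):
    # Hedged sketch: the first and last pipeline stages hold copies of the
    # word-embedding weight, so their gradients are summed across the
    # embedding process group before the optimizer step.
    grad = getattr(word_embedding_weight, 'main_grad', word_embedding_weight.grad)
    torch.distributed.all_reduce(grad, group=embedding_group)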
...
@@ -577,6 +584,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                group = data_parallel_group,
            )

        # >>>
        # print_seq("hi.")
        # <<<
        timers('backward-params-all-reduce').stop()

    def gather_model_params(self, args, timers):
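The tail of this hunk closes a gradient reduction over the data-parallel group, bracketed by the 'backward-params-all-reduce' timer. A hedged sketch of that overall pattern, assuming one contiguous grad buffer per model chunk and the Megatron-style timers callable used above:

import torch

def allreduce_data_parallel_grads_sketch(grad_buffers, data_parallel_group, timers):
    # Hedged sketch: average gradients across data-parallel ranks by
    # all-reducing each contiguous grad buffer.
    timers('backward-params-all-reduce').start()
    world_size = torch.distributed.get_world_size(group=data_parallel_group)
    for buf in grad_buffers:
        buf /= world_size
        torch.distributed.all_reduce(
            buf,
            group = data_parallel_group,
        )
    timers('backward-params-all-reduce').stop()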
...
@@ -610,9 +621,20 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        timers('backward-params-all-gather').stop()

    # >>>
    # def _collect_main_grad_data_for_unscaling(self):
    #     return [ g.data for g in self.get_main_grads() ]
    def _collect_main_grad_data_for_unscaling(self):
        raise Exception("hi.")
        main_grad_data = [
            param.grad.data
            for group in self.optimizer.param_groups
            for param in group["params"]
        ]
        # print_seq([ tp(g) for g in main_grad_data ])
        return main_grad_data
    # <<<

    # >>>
    # def _copy_model_params_to_main_params(self):
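_collect_main_grad_data_for_unscaling returns a flat list of the fp32 main-grad tensors so the mixed-precision step can unscale them all at once and watch for inf/nan. A hedged sketch of how a caller typically consumes that list; torch._amp_foreach_non_finite_check_and_unscale_ is an internal PyTorch API, and inv_scale is assumed to be a one-element CUDA tensor holding 1/loss_scale.

import torch

def unscale_and_check_for_nan_sketch(main_grads, inv_scale):
    # Hedged sketch: multiply every collected grad by 1/scale in place and
    # record whether any inf/nan value was seen (GPU-only).
    found_inf = torch.cuda.FloatTensor([0.0])
    torch._amp_foreach_non_finite_check_and_unscale_(main_grads, found_inf, inv_scale)
    return found_inf.item() > 0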
...
@@ -678,44 +700,55 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        # ])
        # <<<

        # This only needs to be done for the float16 group.
        def copy_group_grads(full_model_groups, shard_main_groups):
            for full_model_group, shard_main_group in zip(
                    full_model_groups, shard_main_groups):
                for full_model_param, shard_main_param in zip(
                        full_model_group, shard_main_group):

                    param_range_map = self.get_model_param_range_map(full_model_param)
                    param_range = param_range_map["param"]

                    full_model_grad = full_model_param.main_grad
                    shard_model_grad = \
                        full_model_grad[param_range.start:param_range.end]
                    shard_main_param.grad = shard_model_grad.float()

                    # >>>
                    if full_model_param.nelement() != shard_main_param.nelement():
                        pax(0, {
                            "param_range_map" : param_range_map,
                            "param_range" : param_range,
                            "full_model_param" : tp(full_model_param),
                            "full_model_grad" : tp(full_model_grad),
                            "shard_model_grad" : tp(shard_model_grad),
                            "shard_main_grad" : tp(shard_main_param.grad),
                            "shard_main_param" : tp(shard_main_param),
                        })
                    # <<<

        # print_seq("float16 groups: %d [%s], %d [%s]." % (
        #     len(self.full_float16_groups),
        #     # ",".join(str(len(g)) for g in self.full_float16_groups),
        #     ",".join(str(tuple(p.shape)) for gs in self.full_float16_groups for g in gs for p in g),
        #     len(self.shard_fp32_from_float16_groups),
        #     ",".join(str(len(g)) for g in self.shard_fp32_from_float16_groups),
        # ))
        gs = self.full_float16_groups
        pax(0, {
            **{"gs / %d" % i : len(g) for i, g in enumerate(gs)},
        })

        # For fp32 grads, we need to reset the grads to main grad.
        for group in self.fp32_groups:
            for param in group:
                param.grad = param.main_grad

        copy_group_grads(self.full_float16_groups, self.shard_fp32_from_float16_groups)
        print_seq("hi.")
        copy_group_grads(self.full_fp32_groups, self.shard_fp32_groups)

        # >>>
        print_seq([
            "grad = %s." % tp(p.grad)
            for g in self.optimizer.param_groups
            for p in g["params"]
        ])
        # print_seq([
        #     "grad = %s." % tp(p.grad)
        #     for g in self.optimizer.param_groups
        #     for p in g["params"]
        # ])
        # <<<
        # <<<
...
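The core of copy_group_grads above is the range slicing: each rank's fp32 "main" shard takes the [start, end) slice of the corresponding flattened model gradient and casts it up from fp16/bf16. Below is a minimal stand-alone sketch of that single step, with plain nn.Parameter objects standing in for model and shard params and hypothetical sizes; it is an illustration, not the optimizer's actual code path.

import torch

def copy_shard_grad_sketch(full_model_param, shard_main_param, start, end):
    # Hedged sketch of the per-parameter work in copy_group_grads: take this
    # rank's [start:end) slice of the flattened model gradient and use its
    # fp32 copy as the gradient of the shard's main parameter.
    full_model_grad = full_model_param.main_grad.flatten()
    shard_model_grad = full_model_grad[start:end]
    shard_main_param.grad = shard_model_grad.float()
    # Sanity check: the shard grad must match the shard param's size.
    assert shard_main_param.grad.nelement() == shard_main_param.nelement()

# Stand-in usage (8 model elements, this rank owning elements 2..6):
full = torch.nn.Parameter(torch.zeros(8, dtype=torch.float16))
full.main_grad = torch.arange(8.).half()  # grads accumulate in .main_grad, as in the diff
shard = torch.nn.Parameter(torch.zeros(4, dtype=torch.float32))
copy_shard_grad_sketch(full, shard, start=2, end=6)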
megatron/optimizer/optimizer.py
...
@@ -33,6 +33,10 @@ from megatron.utils import unwrap_model
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32

# >>>
from lutil import pax, tp, print_seq
# <<<

def _zero_grad_group_helper(group, set_to_none):
    """Zero out the gradient for a group of parameters.
...
@@ -427,6 +431,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()
        print_seq("hi.")

        # Do unscale, check for inf, and update grad scaler only for
        # the case that grad scaler is provided.
        if self.grad_scaler:
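The branch that begins at "if self.grad_scaler:" decides whether the inner optimizer step runs at all. A hedged sketch of that control flow, assuming a Megatron-style grad scaler with an update(found_inf) method and a found_inf_flag produced by the unscale/inf-check step:

def maybe_step_sketch(optimizer, grad_scaler, found_inf_flag):
    # Hedged sketch: with a grad scaler present, feed it the inf/nan result so
    # it can adjust the loss scale, and skip the parameter update on overflow;
    # without a scaler, just step.
    if grad_scaler is not None:
        grad_scaler.update(found_inf_flag)
        if found_inf_flag:
            return False  # step skipped this iteration
    optimizer.step()
    return True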
...
@@ -569,18 +574,18 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
    # >>>
    # def zero_grad(self, set_to_none=True):
    #     """We only need to zero the model related parameters, i.e.,
    #     float16_groups & fp32_from_fp32_groups. We additionally zero
    #     fp32_from_float16_groups as a memory optimization to reduce
    #     fragmentation; in the case of set_to_none==True, the space
    #     used by this field can be safely deallocated at this point."""
    #     for group in self.float16_groups:
    #         _zero_grad_group_helper(group, set_to_none)
    #     for group in self.fp32_from_float16_groups:
    #         _zero_grad_group_helper(group, set_to_none)
    #     for group in self.fp32_from_fp32_groups:
    #         _zero_grad_group_helper(group, set_to_none)
    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_from_fp32_groups. We additionally zero
        fp32_from_float16_groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by this field can be safely deallocated at this point."""
        for group in self.float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)
    # <<<
...