OpenDAS / Megatron-LM · Commits

Commit fe3cfd86, authored Mar 24, 2022 by Lawrence McAfee
working: copy grads.
Parent: 5e29905f
Showing 2 changed files with 77 additions and 87 deletions:

  megatron/optimizer/distrib_optimizer.py   +77  -70
  megatron/optimizer/optimizer.py            +0  -17
megatron/optimizer/distrib_optimizer.py
...
@@ -277,13 +277,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
gbuf_range = model_gbuf_ranges[model_index][dtype]
param_range = gbuf_range["param_map"][model_param]["param"]

# >>>
assert param_range.size > 0
# <<<

# fp16, bf16 params.
if model_param.type() in ['torch.cuda.HalfTensor',
                          'torch.cuda.BFloat16Tensor']:

    # Clone model -> main.
    shard_model_param = \
        model_param.detach()[param_range.start:param_range.end]
    shard_model_param = model_param.detach().view(-1) \
        [param_range.start:param_range.end]
    shard_main_param = shard_model_param.clone().float()
    mpu.copy_tensor_model_parallel_attributes(
        shard_model_param, model_param)
...
@@ -293,6 +297,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    shard_model_param.shared = model_param.shared
    shard_main_param.shared = model_param.shared

    # >>>
    assert shard_main_param.nelement() > 0, \
        "param_range = %s." % param_range
    # <<<

    # Add to group.
    full_float16_params_this_group.append(model_param)
    shard_float16_params_this_group.append(shard_model_param)
...
@@ -300,8 +309,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# fp32 params.
elif param.type() == 'torch.cuda.FloatTensor':
    shard_model_param = \
        model_param[param_range.start:param_range.end]
    shard_model_param = model_param.view(-1) \
        [param_range.start:param_range.end]
    full_fp32_params_this_group.append(model_param)
    shard_fp32_params_this_group.append(shard_model_param)
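Both the float16 and fp32 branches above switch from slicing the parameter tensor directly to slicing a flattened view, since a shard's (start, end) range indexes elements of the flattened buffer rather than rows of the weight. A minimal stand-alone PyTorch sketch (illustrative names only, not Megatron code) of the difference:

import torch

# "p" is a stand-in for a contiguous fp16 model parameter.
p = torch.arange(12.).view(3, 4).half()

row_slice = p[1:2]             # shape (1, 4): slices rows, not a flat element range
flat_slice = p.view(-1)[5:9]   # shape (4,): elements 5..8 of the flattened buffer

# The flat slice is a view into the same storage (no copy) ...
assert flat_slice.data_ptr() == p.data_ptr() + 5 * p.element_size()
# ... while the fp32 main shard is an explicit copy, as with shard_main_param above.
main_shard = flat_slice.clone().float()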
...
@@ -661,35 +670,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# <<<

# >>>
# def _copy_model_grads_to_main_grads(self):
#     for group_index, group_range in enumerate(self.opt_group_ranges):
#         for model_param, main_range in group_range["param_map"].items():
#             # Model range.
#             # model_index, dtype = self.param_gbuf_map[model_param]
#             # model_range = self.model_gbuf_ranges \
#             #     [model_index][dtype]["param_map"][model_param]["gbuf_world"]
#             model_range = self.get_model_param_range_map(model_param)["gbuf_world"]
#             assert main_range.size == model_range.size
#             # Copy from DDP's contiguous buffer to main shard's grad.
#             model_grad = self.models[model_index]._grad_buffers[dtype].data
#             main_grad = self.get_main_grad(group_index)
#             # Copy sub-range within tensor.
#             model_view = model_grad[model_range.start:model_range.end]
#             main_view = main_grad[main_range.start:main_range.end]
#             main_view.detach().copy_(model_view)
# def _copy_model_grads_to_main_grads(self):
#     super()._copy_model_grads_to_main_grads()
#     raise Exception("check main param '.grad'.")
#     for group in self.optimizer.param_groups:
#         for param in group["params"]:
#             param.grad =
def _copy_model_grads_to_main_grads(self):
    # >>>
...
@@ -708,38 +688,22 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                param_range_map = self.get_model_param_range_map(full_model_param)
                param_range = param_range_map["param"]
                assert param_range.size == shard_main_param.nelement()

                full_model_grad = full_model_param.main_grad
                shard_model_grad = \
                    full_model_grad[param_range.start:param_range.end]
                shard_model_grad = full_model_grad.view(-1) \
                    [param_range.start:param_range.end]
                shard_main_param.grad = shard_model_grad.float()

                # >>>
                if full_model_param.nelement() != shard_main_param.nelement():
                    pax(0, {
                        "param_range_map" : param_range_map,
                        "param_range" : param_range,
                        "full_model_param" : tp(full_model_param),
                        "full_model_grad" : tp(full_model_grad),
                        "shard_model_grad" : tp(shard_model_grad),
                        "shard_main_grad" : tp(shard_main_param.grad),
                        "shard_main_param" : tp(shard_main_param),
                    })
                # <<<

    # print_seq([ "%s / %d, [%d] %s" % (
    #     k, i, len(g), ", ".join(str(p.nelement()) for p in g),
    # ) for k, gs in [
    #     ("model", self.full_float16_groups),
    #     ("main", self.shard_fp32_from_float16_groups),
    # ] for i, g in enumerate(gs)])
    # print_seq("float16 groups: %d [%s], %d [%s]." % (
    #     len(self.full_float16_groups),
    #     # ",".join(str(len(g)) for g in self.full_float16_groups),
    #     ",".join(str(tuple(p.shape)) for gs in self.full_float16_groups for g in gs for p in g),
    #     len(self.shard_fp32_from_float16_groups),
    #     ",".join(str(len(g)) for g in self.shard_fp32_from_float16_groups),
    # ))
    gs = self.full_float16_groups
    pax(0, {
        **{"gs / %d" % i : len(g) for i, g in enumerate(gs)},
    })

    copy_group_grads(self.full_float16_groups,
                     self.shard_fp32_from_float16_groups)
    print_seq("hi.")
    copy_group_grads(self.full_fp32_groups,
                     self.shard_fp32_groups)
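As context for the copy_group_grads pattern above, here is a minimal stand-alone PyTorch sketch (illustrative names, not the Megatron API) of copying one shard's grad from a model param's fp16 main_grad buffer into the fp32 main param's .grad:

import torch

def copy_shard_grad(full_model_param, shard_main_param, start, end):
    # full_model_param.main_grad: fp16/bf16 grad living in a contiguous buffer.
    full_model_grad = full_model_param.main_grad
    # Slice this shard's element range out of the flattened grad ...
    shard_model_grad = full_model_grad.view(-1)[start:end]
    # ... and hand the optimizer an fp32 copy as the main param's .grad.
    shard_main_param.grad = shard_model_grad.float()

# Stand-in tensors (plain attribute assignment on tensors is allowed, which
# is what the .main_grad attribute relies on):
full_param = torch.zeros(3, 4, dtype=torch.float16)
full_param.main_grad = torch.randn(3, 4).half()
shard_main = torch.zeros(4, dtype=torch.float32, requires_grad=True)
copy_shard_grad(full_param, shard_main, 5, 9)
assert shard_main.grad.dtype == torch.float32 and shard_main.grad.numel() == 4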
...
@@ -750,7 +714,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# for p in g["params"]
# ])
# <<<
# <<<
# >>>
...
@@ -778,17 +741,61 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# def _copy_main_params_to_model_params(self):
#     super()._copy_main_params_to_model_params()
#     raise Exception("check main param '.grad'.")
# def _copy_main_params_to_model_params(self):
#     raise Exception("hi.")
#     # This only needs to be done for the float16 group.
#     for model_group, main_group in zip(self.float16_groups,
#                                        self.fp32_from_float16_groups):
#         for model_param, main_param in zip(model_group, main_group):
#             model_param.main_grad.detach().copy_(main_param)
#     # For fp32 grads, we need to reset the grads to main grad.
#     for group in self.fp32_groups:
#         for param in group:
#             param.main_grad.detach().copy_(param)
def _copy_main_params_to_model_params(self):

    raise Exception("hi.")

    # This only needs to be done for the float16 group.
    for model_group, main_group in zip(self.float16_groups,
                                       self.fp32_from_float16_groups):
        for model_param, main_param in zip(model_group, main_group):
            model_param.main_grad.detach().copy_(main_param)
    # >>>
    # print_seq([
    #     "grad = %s." % tp(p.grad)
    #     for g in self.optimizer.param_groups
    #     for p in g["params"]
    #     ])
    # <<<
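The detach().copy_() calls above lean on a general PyTorch fact, shown here with stand-in tensors rather than Megatron objects: Tensor.copy_ casts across dtypes, so values can be copied between fp32 main tensors and fp16/bf16 model-side buffers in place.

import torch

main_param = torch.randn(4, dtype=torch.float32)   # fp32 "main" value
model_buf = torch.zeros(4, dtype=torch.float16)    # fp16 model-side buffer
model_buf.detach().copy_(main_param)               # copy_ downcasts fp32 -> fp16 in place
assert model_buf.dtype == torch.float16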
def copy_group_params(shard_main_groups, full_model_groups):
    for shard_main_group, full_model_group in zip(shard_main_groups,
                                                  full_model_groups):
        for shard_main_param, full_model_param in zip(shard_main_group,
                                                      full_model_group):

            param_range_map = self.get_model_param_range_map(full_model_param)
            param_range = param_range_map["param"]
            assert param_range.size == shard_main_param.nelement()

            full_model_grad = full_model_param.main_grad
            shard_model_grad = full_model_grad.view(-1) \
                [param_range.start:param_range.end]
            shard_main_param.grad = shard_model_grad.float()

# For fp32 grads, we need to reset the grads to main grad.
for group in self.fp32_groups:
    for param in group:
        param.main_grad.detach().copy_(param)
# print_seq([ "%s / %d, [%d] %s" % (
# k, i, len(g), ", ".join(str(p.nelement()) for p in g),
# ) for k, gs in [
# ("model", self.full_float16_groups),
# ("main", self.shard_fp32_from_float16_groups),
# ] for i, g in enumerate(gs)])
copy_group_params(self.shard_fp32_from_float16_groups,
                  self.full_float16_groups)
copy_group_params(self.shard_fp32_groups,
                  self.full_fp32_groups)
# >>>
# print_seq([
# "grad = %s." % tp(p.grad)
# for g in self.optimizer.param_groups
# for p in g["params"]
# ])
# <<<
# <<<
megatron/optimizer/optimizer.py

...
@@ -327,22 +327,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
    self._copy_model_params_to_main_params()
# >>>
# def zero_grad(self, set_to_none=True):
#     """We only need to zero the model related parameters, i.e.,
#     float16_groups & fp32_from_fp32_groups. We additionally zero
#     fp32_from_float16_groups as a memory optimization to reduce
#     fragmentation; in the case of set_to_none==True, the space
#     used by this field can be safely deallocated at this point."""
#     for group in self.float16_groups:
#         _zero_grad_group_helper(group, set_to_none)
#     for group in self.fp32_from_float16_groups:
#         _zero_grad_group_helper(group, set_to_none)
#     for group in self.fp32_from_fp32_groups:
#         _zero_grad_group_helper(group, set_to_none)
# <<<
def _unscale_main_grads_and_check_for_nan(self):

    # Collect main grads.
...
@@ -431,7 +415,6 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
    self._copy_model_grads_to_main_grads()
    timers('optimizer-copy-to-main-grad').stop()

    print_seq("hi.")

    # Do unscale, check for inf, and update grad scaler only for
    # the case that grad scaler is provided.
    if self.grad_scaler:
...
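For readers unfamiliar with the grad-scaler branch referenced above, a minimal stand-alone sketch (generic mixed-precision logic with made-up names, not Megatron's implementation) of what unscale-and-check-for-nan amounts to:

import torch

def unscale_and_check_for_nan(main_grads, loss_scale):
    # Undo the loss scaling applied before backward(), and flag inf/nan
    # so the surrounding step can be skipped and the scale reduced.
    inv_scale = 1.0 / loss_scale
    found_inf = False
    for grad in main_grads:
        grad.mul_(inv_scale)
        if not torch.isfinite(grad).all():
            found_inf = True
    return found_inf

grads = [torch.randn(4) * 2.0 ** 12, torch.randn(3, 3) * 2.0 ** 12]
skip_step = unscale_and_check_for_nan(grads, loss_scale=2.0 ** 12)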