OpenDAS / Megatron-LM · Commits

Commit 41ad0c5d
authored Mar 24, 2022 by Lawrence McAfee
training; loss not yet checked.
parent fe3cfd86
Showing 1 changed file with 1 addition and 30 deletions

megatron/optimizer/distrib_optimizer.py (+1, -30)
@@ -559,15 +559,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         grads.
         '''

-        # >>>
-        # print_seq([
-        #     tp(b.data)
-        #     for m in self.models
-        #     for b in m._grad_buffers.values()
-        # ])
-        # print_seq("hi.")
-        # <<<
         # All-reduce embedding grads.
         timers('backward-embedding-all-reduce').start()
         self.allreduce_embedding_grads(args)
@@ -593,16 +584,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 group = data_parallel_group,
             )
-        # >>>
-        # print_seq("hi.")
-        # <<<
         timers('backward-params-all-reduce').stop()

     def gather_model_params(self, args, timers):
+        raise Exception("hi.")
         timers('backward-params-all-gather').start()

         data_parallel_rank = mpu.get_data_parallel_rank()
@@ -756,14 +741,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             # param.main_grad.detach().copy_(param)

     def _copy_main_params_to_model_params(self):
-        # >>>
-        # print_seq([
-        #     "grad = %s." % tp(p.grad)
-        #     for g in self.optimizer.param_groups
-        #     for p in g["params"]
-        # ])
-        # <<<
         def copy_group_params(shard_main_groups, full_model_groups):
             for shard_main_group, full_model_group in zip(shard_main_groups, full_model_groups):
@@ -777,14 +754,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                     full_model_grad = full_model_param.main_grad
                     shard_model_grad = full_model_grad.view(-1) \
                         [param_range.start:param_range.end]
                     shard_main_param.grad = shard_model_grad.float()
-                    # print_seq([ "%s / %d, [%d] %s" % (
-                    #     k, i, len(g), ", ".join(str(p.nelement()) for p in g),
-                    # ) for k, gs in [
-                    #     ("model", self.full_float16_groups),
-                    #     ("main", self.shard_fp32_from_float16_groups),
-                    # ] for i, g in enumerate(gs)])
                     shard_model_grad.data.copy_(shard_main_param)

         copy_group_params(self.shard_fp32_from_float16_groups,
                           self.full_float16_groups)
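Note on the surviving copy path in the last hunk: each rank's fp32 main-param shard is written back into a slice of a flat fp16 buffer via a 1-D view. The sketch below shows that view-and-slice pattern in isolation; the buffer size, the start/end range, and the standalone tensors are made up for illustration and are not taken from the commit or the Megatron-LM code.

import torch

# Hypothetical stand-ins for the objects in the diff: a flat fp16 buffer
# (playing the role of full_model_grad) and an fp32 shard owned by this rank
# (playing the role of shard_main_param).
full_model_grad = torch.zeros(1024, dtype=torch.float16)
start, end = 256, 512
shard_main_param = torch.randn(end - start, dtype=torch.float32)

# View the flat buffer as 1-D and slice out this rank's range. The slice
# aliases the buffer (no copy), mirroring
#   shard_model_grad = full_model_grad.view(-1)[param_range.start:param_range.end]
shard_model_grad = full_model_grad.view(-1)[start:end]

# Copy the fp32 values into the fp16 slice in place (dtype cast happens here),
# mirroring shard_model_grad.data.copy_(shard_main_param).
shard_model_grad.data.copy_(shard_main_param)

# Because the slice aliases the buffer, the write is visible in the full buffer.
assert torch.allclose(full_model_grad[start:end].float(), shard_main_param, atol=1e-2)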