OpenDAS / Megatron-LM

Commit 0528bf77, authored Mar 09, 2022 by Lawrence McAfee
more cleanup
parent 94a90215
Showing 1 changed file with 2 additions and 27 deletions

megatron/optimizer/distrib_optimizer.py  (+2, -27)
@@ -306,11 +306,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        #     main_params.extend(main_group["params"])

        # ** using contiguous buffer; don't set_to_none **
        _zero_grad_group_helper(model_params, set_to_none = False) # set_to_none)
        _zero_grad_group_helper(model_params, set_to_none = False)
        # _zero_grad_group_helper(params, set_to_none = False)
        # pax(0, {"model_params": model_params})

    # def get_model_grad_buffer_dp_views(self):
    #     # >>>
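Because the .grad tensors are views into a contiguous buffer, they have to be zeroed in place rather than replaced with None, hence set_to_none = False. For reference, a sketch of what _zero_grad_group_helper typically looks like in Megatron-LM's megatron/optimizer/optimizer.py (illustrative only, not part of this diff):

    # Sketch of the helper this hunk calls (adapted from Megatron-LM's
    # optimizer.py; treat as illustrative, not the exact code of this commit).
    def _zero_grad_group_helper(group, set_to_none):
        """Zero out the gradients of a group of parameters.

        With a contiguous grad buffer the .grad tensors are views into that
        buffer, so they must be zeroed in place (set_to_none=False) instead
        of being dropped and reallocated.
        """
        for param in group:
            if param.grad is not None:
                if set_to_none:
                    param.grad = None
                else:
                    if param.grad.grad_fn is not None:
                        param.grad.detach_()
                    else:
                        param.grad.requires_grad_(False)
                    param.grad.zero_()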
@@ -344,13 +342,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    #     return gbuf_view_items
    def get_model_grad_buffer_dp_views(self):

        # >>>
        # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
        args = get_args()
        assert args.use_contiguous_buffers_in_local_ddp
        # <<<

        # data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Grad buffer views.
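The sharding in the next hunk relies on the grad buffer's padded length being divisible by the data-parallel world size, so every rank gets an equal shard. A minimal, self-contained sketch of that padding arithmetic (pad_to_multiple is a hypothetical helper for illustration, not a function in this file):

    # Hypothetical helper: round a buffer length up to a multiple of the
    # data-parallel world size so it splits into equal per-rank shards.
    def pad_to_multiple(numel, data_parallel_world_size):
        remainder = numel % data_parallel_world_size
        if remainder == 0:
            return numel
        return numel + (data_parallel_world_size - remainder)

    assert pad_to_multiple(10, 4) == 12  # each of 4 ranks owns a 3-element shard
    assert pad_to_multiple(12, 4) == 12  # already divisible, no padding needed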
@@ -358,27 +349,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        for model_index, model in enumerate(self.models):
            for dtype, gbuf in model._grad_buffers.items():

                # gbuf_size = gbuf.numel_padded
                assert gbuf.numel_padded % data_parallel_world_size == 0
                shard_size = int(gbuf.numel_padded / data_parallel_world_size)
                # pax(0, {
                #     "numel" : gbuf.numel,
                #     "numel_padded" : gbuf.numel_padded,
                #     "shard_size / f" : gbuf.numel_padded/data_parallel_world_size,
                #     "shard_size / i" : shard_size,
                # })
                gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
                              for r in range(data_parallel_world_size)]
                gbuf_view_items.append((model_index, dtype, gbuf_views))
                # pax(0, {
                #     "gbuf_view_items" : gbuf_view_items,
                #     **{
                #         "views / %d" % i : item[2]
                #         for i, item in enumerate(gbuf_view_items)
                #     },
                # })

        return gbuf_view_items

    def reduce_model_grads(self, args, timers):
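For context, a standalone sketch (plain PyTorch, illustrative sizes) of the view construction above: each view aliases an equal, non-overlapping slice of the padded 1-D buffer, so no data is copied.

    import torch

    data_parallel_world_size = 4
    gbuf_data = torch.zeros(12)                      # padded buffer, 12 % 4 == 0
    shard_size = gbuf_data.numel() // data_parallel_world_size

    gbuf_views = [gbuf_data[r * shard_size:(r + 1) * shard_size]
                  for r in range(data_parallel_world_size)]

    assert all(v.numel() == shard_size for v in gbuf_views)
    gbuf_views[2].fill_(1.0)                         # writes through to gbuf_data
    assert gbuf_data[2 * shard_size:3 * shard_size].eq(1.0).all()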
@@ -417,9 +393,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group = mpu.get_data_parallel_group()
        gbuf_view_items = self.get_model_grad_buffer_dp_views()

        # All-gather updated main params.
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        for model_index, dtype, gbuf_views in gbuf_view_items:
            torch.distributed.all_gather(gbuf_views,
            ...
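The all_gather call is cut off at the hunk boundary, so its exact arguments are not shown. A hedged sketch of the pattern this sets up, with the rank/group arguments assumed rather than taken from this diff: each data-parallel rank contributes its own updated shard, and the collective fills every view so all ranks end up with the full buffer.

    import torch

    def all_gather_grad_buffer_views(gbuf_views, data_parallel_rank,
                                     data_parallel_group):
        # gbuf_views[data_parallel_rank] is the shard this rank owns/updated;
        # after the call, gbuf_views[r] holds rank r's shard on every rank.
        # (Argument choices here are assumptions for illustration.)
        torch.distributed.all_gather(gbuf_views,
                                     gbuf_views[data_parallel_rank],
                                     group=data_parallel_group)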