OpenDAS / Megatron-LM · Commits

Commit ff4c4273, authored Feb 15, 2022 by Lawrence McAfee
Parent: bf64c85c

    reduce scatter working

Showing 1 changed file with 24 additions and 18 deletions:
megatron/optimizer/optimizer.py (+24, -18)
megatron/optimizer/optimizer.py
@@ -1307,30 +1307,36 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):

         # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
         assert args.use_contiguous_buffers_in_local_ddp

+        data_parallel_rank = mpu.get_data_parallel_rank()
+        data_parallel_group = mpu.get_data_parallel_group()
+
         for model_index, model in enumerate(self.models):
             for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
                 world_shards = gbuf_shard["world_all"]
-                pax(0, {
-                    "model_index" : model_index,
-                    "model" : model,
-                    "dtype" : str(dtype),
-                    "gbuf_shard" : gbuf_shard,
-                    "world_shards" : world_shards,
-                })
-
-        world_sizes = []
-        for r in self.world_shard_infos:
-            # world_sizes.append(sum(g["size"] for g in r))
-            world_sizes.append([g["size"] for g in r["groups"]])
-        # grad_refs ...
-        pax(0, {"world_sizes": world_sizes})
-
-        # for world_grads = []
-        # for world_shard_info_group
-        # x ?
+                gbuf = model._grad_buffers[dtype]
+                gbuf_views = []
+                for shard in world_shards:
+                    gbuf_views.append(gbuf.data[shard.start:shard.end])
+                torch.distributed.reduce_scatter(
+                    gbuf_views[data_parallel_rank],
+                    gbuf_views,
+                    group = data_parallel_group,
+                )
+                # pax(0, {
+                #     "model_index" : model_index,
+                #     "model" : model,
+                #     "dtype" : str(dtype),
+                #     "gbuf_shard" : gbuf_shard,
+                #     "world_shards" : world_shards,
+                #     "gbuf_views" : gbuf_views,
+                # })
+
+        # >>>
+        torch.distributed.barrier()
         raise Exception("hi.")
+        # <<<

         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
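The commit replaces the `pax(...)` debug dumps with an actual `torch.distributed.reduce_scatter` over per-rank views of the contiguous gradient buffer: every data-parallel rank contributes all shards and receives the fully reduced values for only the shard it owns. The sketch below reproduces that pattern outside Megatron-LM. It is a minimal illustration, assuming the NCCL backend with one GPU per rank, and it uses invented shard arithmetic in place of the commit's `model_gbuf_shards` / `world_shards` bookkeeping; launch with, e.g., `torchrun --nproc_per_node=2 reduce_scatter_demo.py`.

```python
import os
import torch
import torch.distributed as dist


def main():
    # torchrun provides RANK / WORLD_SIZE / LOCAL_RANK and the rendezvous env.
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
    torch.cuda.set_device(device)

    # One contiguous "grad buffer". Its length is a multiple of world_size
    # because every view passed to reduce_scatter must match the output size.
    shard_numel = 4
    gbuf = torch.full((shard_numel * world_size,), float(rank + 1),
                      device=device)

    # Equal (start, end) slices, standing in for the commit's
    # gbuf.data[shard.start:shard.end] views.
    gbuf_views = [gbuf[r * shard_numel:(r + 1) * shard_numel]
                  for r in range(world_size)]

    # Sum-reduce shard `rank` across all ranks; the result lands in this
    # rank's own view, in place within the flat buffer.
    dist.reduce_scatter(gbuf_views[rank], gbuf_views)

    # Every element of the owned shard is now 1 + 2 + ... + world_size.
    expected = float(world_size * (world_size + 1) // 2)
    assert torch.all(gbuf_views[rank] == expected)
    print(f"rank {rank} owns elements "
          f"[{rank * shard_numel}, {(rank + 1) * shard_numel}) = {expected}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Note that `reduce_scatter` requires every tensor in the input list to have the same shape as the output tensor, so equal-sized shards of a single flat buffer (padded to divide evenly across data-parallel ranks) are a precondition for this pattern; that fits the commit's temporary assertion that `use_contiguous_buffers_in_local_ddp` is enabled.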