OpenDAS / Megatron-LM

Commit bf64c85c
Authored Feb 15, 2022 by Lawrence McAfee
Parent: eaa0c1df

Commit message: added 'all' world shards to gbuf map

1 changed file with 18 additions and 8 deletions:

megatron/optimizer/optimizer.py (+18, -8)
...
@@ -977,9 +977,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         gbuf_size = grad_buffer.numel
         max_gbuf_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size))
-        gbuf_world_start = data_parallel_rank * max_gbuf_shard_size
-        gbuf_world_end = min(gbuf_size, gbuf_world_start + max_gbuf_shard_size)
-        gbuf_world_shard = Shard(gbuf_world_start, gbuf_world_end)
+        gbuf_world_all_shards = []
+        for r in range(data_parallel_world_size):
+            gbuf_world_start = r * max_gbuf_shard_size
+            gbuf_world_end = min(gbuf_size, gbuf_world_start + max_gbuf_shard_size)
+            gbuf_world_shard = Shard(gbuf_world_start, gbuf_world_end)
+            gbuf_world_all_shards.append(gbuf_world_shard)
+        gbuf_world_shard = gbuf_world_all_shards[data_parallel_rank]
         gbuf_local_shard = gbuf_world_shard.normalize()
         # gbuf_local_shard = Shard(0, gbuf_world_index.size)
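For context on the first hunk: the removed lines computed only the current rank's shard, while the added loop computes one contiguous [start, end) range per data-parallel rank and then picks this rank's entry, so every rank now knows every other rank's slice of the grad buffer. A minimal, self-contained sketch of the same partitioning, assuming only that Megatron's Shard records a start/end pair (its real definition is not shown in this diff):

import math
from collections import namedtuple

# Stand-in for Megatron's Shard class (assumption: it records a [start, end) range).
Shard = namedtuple("Shard", ["start", "end"])

def build_world_shards(gbuf_size, data_parallel_world_size):
    """Split a grad buffer of gbuf_size elements into one shard per rank,
    mirroring the loop added in this hunk; only the last shard can be short."""
    max_gbuf_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size))
    shards = []
    for r in range(data_parallel_world_size):
        start = r * max_gbuf_shard_size
        end = min(gbuf_size, start + max_gbuf_shard_size)
        shards.append(Shard(start, end))
    return shards

# Example: 10 elements across 4 ranks -> [Shard(0, 3), Shard(3, 6), Shard(6, 9), Shard(9, 10)]
print(build_world_shards(10, 4))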
...
@@ -992,6 +996,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         data = {
             "local" : gbuf_local_shard,
             "world" : gbuf_world_shard,
+            "world_all" : gbuf_world_all_shards,
             "param_map" : param_shard_map,
         }
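The new "world_all" entry exposes every rank's range, not just the local and current-rank ones. A hypothetical consumer of that map, assuming Shard objects with start/end attributes as in the sketch above (these helper names are illustrative, not part of the commit):

def shard_for_rank(gbuf_shard_map, rank):
    """Return the range of the world grad buffer owned by the given data-parallel rank."""
    return gbuf_shard_map["world_all"][rank]

def shard_sizes(gbuf_shard_map):
    """Per-rank shard sizes, e.g. the element counts needed to plan a reduce-scatter."""
    return [s.end - s.start for s in gbuf_shard_map["world_all"]]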
...
@@ -1302,12 +1307,17 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
         assert args.use_contiguous_buffers_in_local_ddp
-        for model_index, model in enuemrate(self.models):
+        for model_index, model in enumerate(self.models):
+            for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
+                world_shards = gbuf_shard["world_all"]
-            pax(0, {
-                "model_index" : model_index,
-                "model" : model,
-            })
+                pax(0, {
+                    "model_index" : model_index,
+                    "model" : model,
+                    "dtype" : str(dtype),
+                    "gbuf_shard" : gbuf_shard,
+                    "world_shards" : world_shards,
+                })
         world_sizes = []
         for r in self.world_shard_infos:
...
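The third hunk stops at a pax(...) debug dump (pax appears to be a debugging helper used during development) before the truncated world_sizes loop. The nesting it walks is the one built in the earlier hunks: self.model_gbuf_shards[model_index][dtype]["world_all"]. A hedged sketch of traversing that structure to collect per-rank shard sizes; the attribute and key names come from the diff, the function itself is illustrative:

def collect_world_shard_sizes(model_gbuf_shards):
    """For each (model, dtype) grad buffer, list how many elements each rank owns."""
    sizes = {}
    for model_index, gbuf_shards in enumerate(model_gbuf_shards):
        for dtype, gbuf_shard in gbuf_shards.items():
            world_shards = gbuf_shard["world_all"]
            sizes[(model_index, dtype)] = [s.end - s.start for s in world_shards]
    return sizes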