OpenDAS / Megatron-LM

Commit 41038d54, authored Mar 22, 2022 by Lawrence McAfee
Parent: 64b94f00

    running & saving memory w/ _reduce_scatter_base/_all_gather_base.

Showing 1 changed file with 51 additions and 48 deletions (+51, -48):

    megatron/optimizer/distrib_optimizer.py
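The diff below revolves around the distributed optimizer's flat, padded grad buffers: each buffer is split into one contiguous shard per data-parallel rank, and after this commit get_model_grad_buffer_dp_views returns the whole flat tensor (gbuf.data) alongside the per-rank shard views so the _base collectives can consume both directly. A minimal, standalone sketch of that layout; the buffer size and the names used outside the file (gbuf_data, numel_padded as a plain int) are illustrative only:

import torch

# Illustrative stand-ins for one grad buffer and the data-parallel world size.
data_parallel_world_size = 4
numel_padded = 1024                      # padded so it divides evenly across ranks
gbuf_data = torch.zeros(numel_padded)    # the flat, contiguous grad buffer

# Same slicing as get_model_grad_buffer_dp_views: one contiguous shard per rank.
shard_size = numel_padded // data_parallel_world_size
gbuf_views = [gbuf_data[(r * shard_size):((r + 1) * shard_size)]
              for r in range(data_parallel_world_size)]

# The views alias the flat buffer, so no extra memory is allocated for them.
assert gbuf_views[1].data_ptr() == gbuf_data.data_ptr() + shard_size * gbuf_data.element_size()
assert sum(v.numel() for v in gbuf_views) == numel_padded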
...
@@ -410,50 +410,53 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 shard_size = int(gbuf.numel_padded / data_parallel_world_size)
                 gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
                               for r in range(data_parallel_world_size)]
-                gbuf_view_items.append((model_index, dtype, gbuf_views))
+                # gbuf_view_items.append((model_index, dtype, gbuf_views))
+                gbuf_view_items.append((model_index, dtype, gbuf.data, gbuf_views))
 
         return gbuf_view_items
 
     # >>>
-    def get_model_grad_buffer_dp_views_SINGLE(self):
-        data_parallel_world_size = mpu.get_data_parallel_world_size()
-        # Grad buffer views.
-        gbuf_items = []
-        for model_index, model in enumerate(self.models):
-            for dtype, gbuf in model._grad_buffers.items():
-                gbuf_items.append((model_index, dtype, gbuf.data))
-        return gbuf_items
+    # def get_model_grad_buffer_dp_views_SINGLE(self):
+    #     data_parallel_world_size = mpu.get_data_parallel_world_size()
+    #     # Grad buffer views.
+    #     gbuf_items = []
+    #     for model_index, model in enumerate(self.models):
+    #         for dtype, gbuf in model._grad_buffers.items():
+    #             gbuf_items.append((model_index, dtype, gbuf.data))
+    #     return gbuf_items
     # <<<
 
-    def get_model_grad_buffer_dp_views_chunked(self, mem_savings_factor):
-        # Iterate grad buffers & chunk.
-        gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        chunk_view_items = []
-        for model_index, dtype, gbuf_views in gbuf_view_items:
-            # ** Sanity check. ** (should be unnecessary; see comment above)
-            view_numel = gbuf_views[0].nelement()
-            for view in gbuf_views:
-                assert view.nelement() == view_numel
-            # Compute chunk size (via savings factor).
-            chunk_numel_min = 131072
-            chunk_numel_max = view_numel
-            chunk_numel = int(
-                mem_savings_factor * chunk_numel_min
-                + (1 - mem_savings_factor) * chunk_numel_max
-            )
-            # Chunk views.
-            for start_index in range(0, view_numel, chunk_numel):
-                end_index = min(view_numel, start_index + chunk_numel)
-                chunk_views = [ t[start_index:end_index] for t in gbuf_views ]
-                chunk_view_items.append((model_index, dtype, chunk_views))
-        return chunk_view_items
+    # >>>
+    # def get_model_grad_buffer_dp_views_chunked(self, mem_savings_factor):
+    #     # Iterate grad buffers & chunk.
+    #     gbuf_view_items = self.get_model_grad_buffer_dp_views()
+    #     chunk_view_items = []
+    #     for model_index, dtype, gbuf_views in gbuf_view_items:
+    #         # ** Sanity check. ** (should be unnecessary; see comment above)
+    #         view_numel = gbuf_views[0].nelement()
+    #         for view in gbuf_views:
+    #             assert view.nelement() == view_numel
+    #         # Compute chunk size (via savings factor).
+    #         chunk_numel_min = 131072
+    #         chunk_numel_max = view_numel
+    #         chunk_numel = int(
+    #             mem_savings_factor * chunk_numel_min
+    #             + (1 - mem_savings_factor) * chunk_numel_max
+    #         )
+    #         # Chunk views.
+    #         for start_index in range(0, view_numel, chunk_numel):
+    #             end_index = min(view_numel, start_index + chunk_numel)
+    #             chunk_views = [ t[start_index:end_index] for t in gbuf_views ]
+    #             chunk_view_items.append((model_index, dtype, chunk_views))
+    #     return chunk_view_items
+    # <<<
 
     def reduce_model_grads(self, args, timers):
         '''Note: this is a different order of reduction, versus the non-
...
@@ -490,17 +493,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # )
         # +++
         gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
-        for index, (model_index, dtype, gbuf_views) in enumerate(gbuf_view_items):
+        # gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
+        for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
             # >>>
-            pax(0, {
-                "gbuf_view" : gbuf_views[data_parallel_rank].shape,
-                "gbuf SINGLE" : gbuf_view_items_SINGLE[index][2].shape,
-            })
+            # pax(0, {
+            #     "gbuf_view" : gbuf_views[data_parallel_rank].shape,
+            #     "gbuf SINGLE" : gbuf_view_items_SINGLE[index][2].shape,
+            # })
             # <<<
             torch.distributed._reduce_scatter_base(
                 gbuf_views[data_parallel_rank],
-                gbuf_view_items_SINGLE[index][2],
+                gbuf, # gbuf_view_items_SINGLE[index][2],
                 group = data_parallel_group,
             )
             # torch.distributed.reduce_scatter(
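For context on the call above: torch.distributed._reduce_scatter_base(output, input, ...) (a private API at the time, later exposed as reduce_scatter_tensor) reduces one flat input tensor across the process group and writes only this rank's shard into output. Because the shard view aliases the flat buffer, no per-rank list of tensors has to be built the way torch.distributed.reduce_scatter requires, which is presumably the memory saving the commit message refers to. A hedged, self-contained sketch, assuming a single-node NCCL job launched with torchrun rather than Megatron's mpu plumbing; sizes and names are illustrative:

# Sketch only: mirrors the grad reduce-scatter above with made-up sizes.
# Launch with, e.g.: torchrun --nproc_per_node=2 reduce_scatter_sketch.py
import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)  # single-node assumption

    # Flat "grad buffer" plus one shard-sized view per rank, as in the diff.
    shard_size = 1024
    gbuf = torch.ones(shard_size * world_size, device="cuda")
    gbuf_views = [gbuf[(r * shard_size):((r + 1) * shard_size)]
                  for r in range(world_size)]

    # Sum-reduce the whole flat buffer; this rank keeps only its reduced shard.
    # (Megatron passes group=data_parallel_group; the default group is used here.)
    torch.distributed._reduce_scatter_base(
        gbuf_views[rank],   # output: this rank's shard, a view into gbuf
        gbuf,               # input: the full flat buffer
    )
    assert torch.all(gbuf_views[rank] == world_size)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()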
...
@@ -535,12 +538,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # )
         # +++
         gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
-        for index, (model_index, dtype, gbuf_views) in enumerate(gbuf_view_items):
+        # gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
+        for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
             torch.distributed._all_gather_base(
-                # gbuf_view_items_SINGLE[index][2],
-                # gbuf_views[data_parallel_rank],
-                # group = data_parallel_group,
+                gbuf, # gbuf_view_items_SINGLE[index][2],
+                gbuf_views[data_parallel_rank],
+                group = data_parallel_group,
             )
             # <<<
...
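torch.distributed._all_gather_base(output, input, ...) (later exposed as all_gather_into_tensor) is the mirror image used by the last hunk: every rank contributes its shard and the concatenated result lands directly in one pre-allocated flat tensor, instead of in a Python list of per-rank output tensors as torch.distributed.all_gather would require. A companion sketch under the same assumptions as the reduce-scatter example (torchrun, NCCL, illustrative sizes):

# Sketch only: companion to the reduce-scatter sketch above.
# Launch with, e.g.: torchrun --nproc_per_node=2 all_gather_sketch.py
import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)  # single-node assumption

    # Flat buffer plus per-rank shard views; only this rank's shard is filled in.
    shard_size = 1024
    gbuf = torch.zeros(shard_size * world_size, device="cuda")
    gbuf_views = [gbuf[(r * shard_size):((r + 1) * shard_size)]
                  for r in range(world_size)]
    gbuf_views[rank].fill_(float(rank + 1))

    # Gather every rank's shard straight into the flat buffer (NCCL in-place style).
    torch.distributed._all_gather_base(
        gbuf,              # output: the whole flat buffer
        gbuf_views[rank],  # input: this rank's shard, a view into gbuf
    )
    assert all(torch.all(gbuf_views[r] == r + 1) for r in range(world_size))

    dist.destroy_process_group()

if __name__ == "__main__":
    main()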