OpenDAS / Megatron-LM / Commits

Commit 34fc1e37, authored Mar 18, 2022 by Lawrence McAfee (parent 12d91733):

    after experiments, before new arg 'distrib-opt-comm-buf-1m'.

Showing 1 changed file with 62 additions and 30 deletions:

megatron/optimizer/distrib_optimizer.py  (+62, -30)
...
@@ -362,11 +362,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             sub_view_items.append((model_index, dtype, sub_views))

         # >>>
-        from lutil import pax
-        pax(0, {
-            "gbuf_view_items" : [(a,b,c.shape) for a,b,c in gbuf_view_items],
-            "sub_view_items" : [(a,b,c.shape) for a,b,c in sub_view_items],
-        })
+        # from lutil import pax
+        # pax(0, {
+        #     "gbuf_view_items" : [(a,b, "%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in gbuf_view_items],
+        #     "sub_view_items" : [(a,b, "%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in sub_view_items],
+        # })
         # <<<

         return sub_view_items
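For readers following the `_SUB` views used throughout this commit: judging from the all-gather loop further down, each appended item pairs (model_index, dtype) with a list of sub-views, one per data-parallel rank, all cut at the same offsets. The sketch below is not part of this commit; `split_into_dp_sub_views` is a hypothetical stand-in for `get_model_grad_buffer_dp_views_SUB` that only illustrates the slicing, assuming the flat grad buffer is already padded to a multiple of the data-parallel world size (which distributed.py arranges).

import torch

def split_into_dp_sub_views(gbuf, world_size, sub_numel):
    # One equal-size view per data-parallel rank (padding guarantees divisibility).
    assert gbuf.nelement() % world_size == 0
    shard_numel = gbuf.nelement() // world_size
    views = [gbuf[r * shard_numel:(r + 1) * shard_numel] for r in range(world_size)]
    # Cut every rank's view at identical start/end offsets, at most sub_numel each.
    sub_view_items = []
    for start in range(0, shard_numel, sub_numel):
        end = min(shard_numel, start + sub_numel)
        sub_view_items.append([v[start:end] for v in views])
    return sub_view_items

# Example: a 4-rank buffer, communicated in 1M-element chunks as in this file.
items = split_into_dp_sub_views(torch.zeros(4 * 2 * 1048576), world_size=4, sub_numel=1048576)
assert all(len(chunk_views) == 4 for chunk_views in items)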
...
@@ -432,10 +432,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         data_parallel_group = mpu.get_data_parallel_group()

         sub_numel = 1 * 1048576
-        sub_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
+        gbuf_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
         for model_index, dtype, gbuf_views in gbuf_view_items:
-            gbuf = self.models[model_index]._grad_buffers[dtype].data
-            gbuf /= data_parallel_world_size
+            # gbuf = self.models[model_index]._grad_buffers[dtype].data
+            # gbuf /= data_parallel_world_size
             torch.distributed.reduce_scatter(
                 gbuf_views[data_parallel_rank],
                 gbuf_views,
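For context on the reduce-scatter pattern above: `torch.distributed.reduce_scatter(output, input_list, ...)` leaves each data-parallel rank holding only the reduced shard it owns (here, `gbuf_views[data_parallel_rank]`), and the division by `data_parallel_world_size` that would turn the sum into an average is commented out in this revision. The snippet below is a single-process simulation of that outcome, not Megatron code, with illustrative names only.

import torch

world_size = 4
shard_numel = 1048576  # 1M elements, matching sub_numel above
# One full (already padded) grad buffer per simulated rank.
grad_buffers = [torch.randn(world_size * shard_numel) for _ in range(world_size)]

def simulated_reduce_scatter(buffers, rank):
    # Rank 'rank' keeps the element-wise sum of everyone's rank-th view.
    views = [buf.view(world_size, shard_numel) for buf in buffers]
    return sum(v[rank] for v in views)

owned = simulated_reduce_scatter(grad_buffers, rank=0)
assert owned.nelement() == shard_numel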
...
@@ -444,7 +444,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         timers('backward-params-all-reduce').stop()
     def reduce_model_grads(self, *args):
         # >>>
-        # return
+        return
         # <<<
         # self.reduce_model_grads_0(*args)
         self.reduce_model_grads_1(*args)
...
@@ -475,6 +475,49 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 param.detach().copy_(param.main_grad)
         timers('backward-params-all-gather').stop()

+    # def gather_model_params_1(self, args, timers):
+    #     timers('backward-params-all-gather').start()
+    #     data_parallel_rank = mpu.get_data_parallel_rank()
+    #     data_parallel_group = mpu.get_data_parallel_group()
+    #     # All-gather updated main params.
+    #     # - All grad buffer views are guaranteed to have the same num elements
+    #     #   across all data parallel ranks, with grad buffer padding that is done
+    #     #   in distributed.py. Thus, all sub-views will have consistent start/end
+    #     #   indexes across data parallel ranks.
+    #     gbuf_view_items = self.get_model_grad_buffer_dp_views()
+    #     # sub_view_numel = 1 * 1024
+    #     # sub_view_numel = 1 * 131072
+    #     sub_view_numel = 256 * 1048576
+    #     for model_index, dtype, gbuf_views in gbuf_view_items:
+    #         # ** Sanity check. ** (should be unnecessary; see comment above)
+    #         view_numel = gbuf_views[0].nelement()
+    #         for view in gbuf_views:
+    #             assert view.nelement() == view_numel
+    #         for start_index in range(0, view_numel, sub_view_numel):
+    #             end_index = min(view_numel, start_index + sub_view_numel)
+    #             sub_views = [ t[start_index:end_index] for t in gbuf_views ]
+    #             torch.distributed.all_gather(
+    #                 sub_views,
+    #                 sub_views[data_parallel_rank],
+    #                 group = data_parallel_group,
+    #             )
+    #     # Each model param now contains its updated values in its
+    #     # '.main_grad' field.
+    #     for model in self.models:
+    #         for dtype, param_map in model._grad_buffer_param_index_map.items():
+    #             for param in param_map:
+    #                 param.detach().copy_(param.main_grad)
+    #     timers('backward-params-all-gather').stop()
     def gather_model_params_1(self, args, timers):
         timers('backward-params-all-gather').start()
...
@@ -487,28 +530,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         #   across all data parallel ranks, with grad buffer padding that is done
         #   in distributed.py. Thus, all sub-views will have consistent start/end
         #   indexes across data parallel ranks.
-        gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        # sub_view_numel = 1 * 1024
-        # sub_view_numel = 1 * 131072
-        sub_view_numel = 1 * 1048576
+        # sub_numel = 1 * 1024
+        # sub_numel = 1 * 131072
+        sub_numel = 1024 * 1048576
+        gbuf_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
         for model_index, dtype, gbuf_views in gbuf_view_items:
-            # ** Sanity check. ** (should be unnecessary; see comment above)
-            view_numel = gbuf_views[0].nelement()
-            for view in gbuf_views:
-                assert view.nelement() == view_numel
-            for start_index in range(0, view_numel, sub_view_numel):
-                end_index = min(view_numel, start_index + sub_view_numel)
-                sub_views = [ t[start_index:end_index] for t in gbuf_views ]
-                torch.distributed.all_gather(
-                    sub_views,
-                    sub_views[data_parallel_rank],
-                    group = data_parallel_group,
-                )
+            torch.distributed.all_gather(
+                gbuf_views,
+                gbuf_views[data_parallel_rank],
+                group = data_parallel_group,
+            )
         # Each model param now contains its updated values in its
         # '.main_grad' field.
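The loop removed above (and preserved as a comment in the previous hunk) gathered each grad buffer in `sub_view_numel`-sized chunks; with this commit the slicing moves into `get_model_grad_buffer_dp_views_SUB` and the loop issues one `all_gather` per pre-sliced chunk. As a rough, runnable illustration of the chunked all-gather pattern itself, the sketch below runs on a single-process gloo group; the sizes and the standalone setup are illustrative, not Megatron's.

import torch
import torch.distributed as dist

# Self-contained single-process group; a real run would use the data-parallel group.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)
rank, world_size = dist.get_rank(), dist.get_world_size()

shard_numel = 4 * 1048576          # per-rank view size (equal on every rank)
sub_numel = 1 * 1048576            # communicate in 1M-element chunks
gbuf = torch.zeros(world_size * shard_numel)
gbuf_views = [gbuf[r * shard_numel:(r + 1) * shard_numel] for r in range(world_size)]

for start in range(0, shard_numel, sub_numel):
    end = min(shard_numel, start + sub_numel)
    sub_views = [v[start:end] for v in gbuf_views]
    chunk = sub_views[rank].clone()        # this rank's contribution for the chunk
    dist.all_gather(sub_views, chunk)      # fill every rank's slice of this chunk

dist.destroy_process_group()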
...
@@ -641,7 +673,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     #     timers('backward-params-all-gather').stop()
     def gather_model_params(self, *args):
         # >>>
-        return
+        # return
         # <<<
         # self.gather_model_params_0(*args)
         self.gather_model_params_1(*args)
...