Commit 34fc1e37 authored by Lawrence McAfee

after experiments, before new arg 'distrib-opt-comm-buf-1m'.

parent 12d91733
@@ -362,11 +362,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             sub_view_items.append((model_index, dtype, sub_views))
 
         # >>>
-        from lutil import pax
-        pax(0, {
-            "gbuf_view_items" : [(a,b,c.shape) for a,b,c in gbuf_view_items],
-            "sub_view_items" : [(a,b,c.shape) for a,b,c in sub_view_items],
-        })
+        # from lutil import pax
+        # pax(0, {
+        #     "gbuf_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in gbuf_view_items],
+        #     "sub_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in sub_view_items],
+        # })
         # <<<
 
         return sub_view_items
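
The body of get_model_grad_buffer_dp_views_SUB is largely outside this hunk, so the following is only an inferred sketch of what such a chunking helper could look like, pieced together from the append/return lines above and the call sites further down. The standalone function name, its (model_index, dtype, buffer) input format, and the exact chunking policy are assumptions, not the commit's implementation.

import torch

def get_grad_buffer_dp_views_sub(grad_buffers, world_size, sub_numel):
    # For each (model_index, dtype) flat grad buffer, split every rank's
    # equal-sized shard into sub-views of at most sub_numel elements and
    # emit one (model_index, dtype, sub_views) item per chunk, where
    # sub_views holds world_size aligned slices (one per DP rank).
    sub_view_items = []
    for model_index, dtype, gbuf in grad_buffers:
        assert gbuf.nelement() % world_size == 0  # assumes padding done in distributed.py
        shard_numel = gbuf.nelement() // world_size
        rank_views = [gbuf[r * shard_numel:(r + 1) * shard_numel]
                      for r in range(world_size)]
        for start in range(0, shard_numel, sub_numel):
            end = min(shard_numel, start + sub_numel)
            sub_views = [v[start:end] for v in rank_views]
            sub_view_items.append((model_index, dtype, sub_views))
    return sub_view_items

# Example: one fp32 buffer of 4 * 1048576 elements, 4 DP ranks, 1048576-element chunks.
items = get_grad_buffer_dp_views_sub(
    [(0, torch.float32, torch.zeros(4 * 1048576))], world_size=4, sub_numel=1048576)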
@@ -432,10 +432,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         data_parallel_group = mpu.get_data_parallel_group()
 
         sub_numel = 1 * 1048576
-        sub_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
+        gbuf_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
         for model_index, dtype, gbuf_views in gbuf_view_items:
-            gbuf = self.models[model_index]._grad_buffers[dtype].data
-            gbuf /= data_parallel_world_size
+            # gbuf = self.models[model_index]._grad_buffers[dtype].data
+            # gbuf /= data_parallel_world_size
             torch.distributed.reduce_scatter(
                 gbuf_views[data_parallel_rank],
                 gbuf_views,
@@ -444,7 +444,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         timers('backward-params-all-reduce').stop()
     def reduce_model_grads(self, *args):
         # >>>
-        # return
+        return
         # <<<
         # self.reduce_model_grads_0(*args)
         self.reduce_model_grads_1(*args)
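
For reference, a minimal standalone sketch of the chunked reduce-scatter pattern that reduce_model_grads_1 exercises above. It is not the code in this commit: the helper names make_dp_views and chunked_reduce_scatter are hypothetical, the collective is skipped unless a process group is initialized, and the buffer is assumed to be pre-padded to a multiple of the data-parallel world size, as the diff's comments say distributed.py guarantees.

import torch
import torch.distributed as dist

def make_dp_views(gbuf, world_size):
    # Split a padded 1-D grad buffer into one equal-sized view per data-parallel rank.
    assert gbuf.nelement() % world_size == 0
    shard_numel = gbuf.nelement() // world_size
    return [gbuf[r * shard_numel:(r + 1) * shard_numel] for r in range(world_size)]

def chunked_reduce_scatter(gbuf, world_size, rank, sub_numel=1 * 1048576, group=None):
    # Reduce-scatter the buffer in fixed-size sub-views, so each collective
    # moves at most sub_numel elements per rank.
    views = make_dp_views(gbuf, world_size)
    shard_numel = views[0].nelement()
    for start in range(0, shard_numel, sub_numel):
        end = min(shard_numel, start + sub_numel)
        sub_views = [v[start:end] for v in views]
        if dist.is_available() and dist.is_initialized():
            # This rank receives the elementwise sum of its sub-view across all DP ranks.
            dist.reduce_scatter(sub_views[rank], sub_views, group=group)

# The view/chunk bookkeeping can be exercised without a process group:
chunked_reduce_scatter(torch.zeros(8 * 1048576), world_size=4, rank=0)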
@@ -475,6 +475,49 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                     param.detach().copy_(param.main_grad)
 
         timers('backward-params-all-gather').stop()
+
+    # def gather_model_params_1(self, args, timers):
+
+    #     timers('backward-params-all-gather').start()
+
+    #     data_parallel_rank = mpu.get_data_parallel_rank()
+    #     data_parallel_group = mpu.get_data_parallel_group()
+
+    #     # All-gather updated main params.
+    #     # - All grad buffer views are guaranteed to have the same num elements
+    #     # across all data parallel ranks, with grad buffer padding that is done
+    #     # in distributed.py. Thus, all sub-views will have consistent start/end
+    #     # indexes across data parallel ranks.
+    #     gbuf_view_items = self.get_model_grad_buffer_dp_views()
+
+    #     # sub_view_numel = 1 * 1024
+    #     # sub_view_numel = 1 * 131072
+    #     sub_view_numel = 256 * 1048576
+    #     for model_index, dtype, gbuf_views in gbuf_view_items:
+
+    #         # ** Sanity check. ** (should be unnecessary; see comment above)
+    #         view_numel = gbuf_views[0].nelement()
+    #         for view in gbuf_views:
+    #             assert view.nelement() == view_numel
+
+    #         for start_index in range(0, view_numel, sub_view_numel):
+    #             end_index = min(view_numel, start_index + sub_view_numel)
+    #             sub_views = [ t[start_index:end_index] for t in gbuf_views ]
+    #             torch.distributed.all_gather(
+    #                 sub_views,
+    #                 sub_views[data_parallel_rank],
+    #                 group = data_parallel_group,
+    #             )
+
+    #     # Each model param now contains its updated values in its
+    #     # '.main_grad' field.
+    #     for model in self.models:
+    #         for dtype, param_map in model._grad_buffer_param_index_map.items():
+    #             for param in param_map:
+    #                 param.detach().copy_(param.main_grad)
+
+    #     timers('backward-params-all-gather').stop()
+
     def gather_model_params_1(self, args, timers):
 
         timers('backward-params-all-gather').start()
@@ -487,28 +530,17 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # across all data parallel ranks, with grad buffer padding that is done
         # in distributed.py. Thus, all sub-views will have consistent start/end
         # indexes across data parallel ranks.
-        gbuf_view_items = self.get_model_grad_buffer_dp_views()
-
-        # sub_view_numel = 1 * 1024
-        # sub_view_numel = 1 * 131072
-        sub_view_numel = 1 * 1048576
+        # sub_numel = 1 * 1024
+        # sub_numel = 1 * 131072
+        sub_numel = 1024 * 1048576
+        gbuf_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
         for model_index, dtype, gbuf_views in gbuf_view_items:
-
-            # ** Sanity check. ** (should be unnecessary; see comment above)
-            view_numel = gbuf_views[0].nelement()
-            for view in gbuf_views:
-                assert view.nelement() == view_numel
-
-            for start_index in range(0, view_numel, sub_view_numel):
-                end_index = min(view_numel, start_index + sub_view_numel)
-                sub_views = [ t[start_index:end_index] for t in gbuf_views ]
-                torch.distributed.all_gather(
-                    sub_views,
-                    sub_views[data_parallel_rank],
-                    group = data_parallel_group,
-                )
+            torch.distributed.all_gather(
+                gbuf_views,
+                gbuf_views[data_parallel_rank],
+                group = data_parallel_group,
+            )
 
         # Each model param now contains its updated values in its
         # '.main_grad' field.
@@ -641,7 +673,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     # timers('backward-params-all-gather').stop()
     def gather_model_params(self, *args):
         # >>>
-        return
+        # return
        # <<<
         # self.gather_model_params_0(*args)
         self.gather_model_params_1(*args)
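
For comparison with the single all_gather per view item that the commit switches to, here is a hedged sketch of the chunked all-gather path preserved in the commented-out gather_model_params_1 above. The function name chunked_all_gather is illustrative only; it leans on the guarantee quoted in the diff comment that every data-parallel rank's view has the same number of elements, so sub-view start/end indexes line up across ranks.

import torch
import torch.distributed as dist

def chunked_all_gather(gbuf_views, rank, sub_numel=1 * 1048576, group=None):
    # All-gather updated param shards in fixed-size chunks. gbuf_views is the
    # per-rank view list for one (model_index, dtype) grad buffer.
    view_numel = gbuf_views[0].nelement()
    # Sanity check kept from the old code path; the padding in distributed.py
    # should already make this hold.
    assert all(v.nelement() == view_numel for v in gbuf_views)
    for start in range(0, view_numel, sub_numel):
        end = min(view_numel, start + sub_numel)
        sub_views = [v[start:end] for v in gbuf_views]
        if dist.is_available() and dist.is_initialized():
            # Each rank contributes its own sub-view and receives everyone else's.
            dist.all_gather(sub_views, sub_views[rank], group=group)

# Offline check of the chunking arithmetic only:
buf = torch.zeros(4 * 1048576)
views = [buf[i * 1048576:(i + 1) * 1048576] for i in range(4)]
chunked_all_gather(views, rank=0)

Chunking trades one large collective per grad buffer for several smaller ones; the various sub_numel settings visible in this diff (from 1024 up to 1024 * 1048576 elements) appear to be the communication-buffer sizes the commit message's experiments were probing.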