Commit 34fc1e37 authored by Lawrence McAfee
Browse files

after experiments, before new arg 'distrib-opt-comm-buf-1m'.

parent 12d91733
......@@ -362,11 +362,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
sub_view_items.append((model_index, dtype, sub_views))
# >>>
from lutil import pax
pax(0, {
"gbuf_view_items" : [(a,b,c.shape) for a,b,c in gbuf_view_items],
"sub_view_items" : [(a,b,c.shape) for a,b,c in sub_view_items],
})
# from lutil import pax
# pax(0, {
# "gbuf_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in gbuf_view_items],
# "sub_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in sub_view_items],
# })
# <<<
return sub_view_items
......@@ -432,10 +432,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
data_parallel_group = mpu.get_data_parallel_group()
sub_numel = 1 * 1048576
sub_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
gbuf_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
for model_index, dtype, gbuf_views in gbuf_view_items:
gbuf = self.models[model_index]._grad_buffers[dtype].data
gbuf /= data_parallel_world_size
# gbuf = self.models[model_index]._grad_buffers[dtype].data
# gbuf /= data_parallel_world_size
torch.distributed.reduce_scatter(
gbuf_views[data_parallel_rank],
gbuf_views,
......@@ -444,7 +444,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
timers('backward-params-all-reduce').stop()
def reduce_model_grads(self, *args):
# Entry point for gradient reduction: forwards every positional argument
# unchanged to reduce_model_grads_1. The reduce_model_grads_0 variant is
# present only as a commented-out alternative.
# >>>
# NOTE(review): the '>>>' / '<<<' markers bracket what looks like a temporary
# debug toggle. This diff view lists a commented and an un-commented 'return'
# back to back with no +/- markers. If the bare 'return' is the live line, the
# method exits here and reduce_model_grads_1 below is never reached; confirm
# that this is intended.
# return
return
# <<<
# self.reduce_model_grads_0(*args)
self.reduce_model_grads_1(*args)
......@@ -475,6 +475,49 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
param.detach().copy_(param.main_grad)
timers('backward-params-all-gather').stop()
# def gather_model_params_1(self, args, timers):
# timers('backward-params-all-gather').start()
# data_parallel_rank = mpu.get_data_parallel_rank()
# data_parallel_group = mpu.get_data_parallel_group()
# # All-gather updated main params.
# # - All grad buffer views are guaranteed to have the same num elements
# # across all data parallel ranks, with grad buffer padding that is done
# # in distributed.py. Thus, all sub-views will have consistent start/end
# # indexes across data parallel ranks.
# gbuf_view_items = self.get_model_grad_buffer_dp_views()
# # sub_view_numel = 1 * 1024
# # sub_view_numel = 1 * 131072
# sub_view_numel = 256 * 1048576
# for model_index, dtype, gbuf_views in gbuf_view_items:
# # ** Sanity check. ** (should be unnecessary; see comment above)
# view_numel = gbuf_views[0].nelement()
# for view in gbuf_views:
# assert view.nelement() == view_numel
# for start_index in range(0, view_numel, sub_view_numel):
# end_index = min(view_numel, start_index + sub_view_numel)
# sub_views = [ t[start_index:end_index] for t in gbuf_views ]
# torch.distributed.all_gather(
# sub_views,
# sub_views[data_parallel_rank],
# group = data_parallel_group,
# )
# # Each model param now contains its updated values in its
# # '.main_grad' field.
# for model in self.models:
# for dtype, param_map in model._grad_buffer_param_index_map.items():
# for param in param_map:
# param.detach().copy_(param.main_grad)
# timers('backward-params-all-gather').stop()
def gather_model_params_1(self, args, timers):
timers('backward-params-all-gather').start()
......@@ -487,26 +530,15 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# across all data parallel ranks, with grad buffer padding that is done
# in distributed.py. Thus, all sub-views will have consistent start/end
# indexes across data parallel ranks.
gbuf_view_items = self.get_model_grad_buffer_dp_views()
# sub_view_numel = 1 * 1024
# sub_view_numel = 1 * 131072
sub_view_numel = 1 * 1048576
# sub_numel = 1 * 1024
# sub_numel = 1 * 131072
sub_numel = 1024 * 1048576
gbuf_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
for model_index, dtype, gbuf_views in gbuf_view_items:
# ** Sanity check. ** (should be unnecessary; see comment above)
view_numel = gbuf_views[0].nelement()
for view in gbuf_views:
assert view.nelement() == view_numel
for start_index in range(0, view_numel, sub_view_numel):
end_index = min(view_numel, start_index + sub_view_numel)
sub_views = [ t[start_index:end_index] for t in gbuf_views ]
torch.distributed.all_gather(
sub_views,
sub_views[data_parallel_rank],
gbuf_views,
gbuf_views[data_parallel_rank],
group = data_parallel_group,
)
......@@ -641,7 +673,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# timers('backward-params-all-gather').stop()
def gather_model_params(self, *args):
# Entry point for parameter gathering: forwards every positional argument
# unchanged to gather_model_params_1. The gather_model_params_0 variant is
# present only as a commented-out alternative.
# >>>
# NOTE(review): the '>>>' / '<<<' markers bracket what looks like a temporary
# debug toggle. This diff view lists an un-commented and a commented 'return'
# back to back with no +/- markers. If the bare 'return' is the live line, the
# method exits here and gather_model_params_1 below is never reached; confirm
# which of the two lines is current.
return
# return
# <<<
# self.gather_model_params_0(*args)
self.gather_model_params_1(*args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment