Commit 0528bf77 authored by Lawrence McAfee

more cleanup

parent 94a90215
@@ -306,11 +306,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# main_params.extend(main_group["params"])
# ** using contiguous buffer; don't set_to_none **
_zero_grad_group_helper(model_params, set_to_none = False) # set_to_none)
_zero_grad_group_helper(model_params, set_to_none = False)
# _zero_grad_group_helper(params, set_to_none = False)
# pax(0, {"model_params": model_params})
# def get_model_grad_buffer_dp_views(self):
# # >>>
@@ -344,13 +342,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# return gbuf_view_items
def get_model_grad_buffer_dp_views(self):
# >>>
# ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
args = get_args()
assert args.use_contiguous_buffers_in_local_ddp
# <<<
# data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
# Grad buffer views.
@@ -358,27 +349,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for model_index, model in enumerate(self.models):
for dtype, gbuf in model._grad_buffers.items():
# gbuf_size = gbuf.numel_padded
assert gbuf.numel_padded % data_parallel_world_size == 0
shard_size = int(gbuf.numel_padded / data_parallel_world_size)
# pax(0, {
# "numel" : gbuf.numel,
# "numel_padded" : gbuf.numel_padded,
# "shard_size / f" : gbuf.numel_padded/data_parallel_world_size,
# "shard_size / i" : shard_size,
# })
gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
for r in range(data_parallel_world_size)]
gbuf_view_items.append((model_index, dtype, gbuf_views))
# pax(0, {
# "gbuf_view_items" : gbuf_view_items,
# **{
# "views / %d" % i : item[2]
# for i, item in enumerate(gbuf_view_items)
# },
# })
return gbuf_view_items
def reduce_model_grads(self, args, timers):
@@ -417,9 +393,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_group = mpu.get_data_parallel_group()
gbuf_view_items = self.get_model_grad_buffer_dp_views()
# All-gather updated main params.
gbuf_view_items = self.get_model_grad_buffer_dp_views()
for model_index, dtype, gbuf_views in gbuf_view_items:
torch.distributed.all_gather(
gbuf_views,
......
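For context, the pattern this commit tidies up is: the contiguous grad buffer, padded to a multiple of the data-parallel world size, is sliced into one equal-sized view per rank, and the "All-gather updated main params" step copies each rank's updated shard back into every view. The sketch below is a minimal standalone illustration of that pattern, not the Megatron-LM implementation; the function names (get_grad_buffer_dp_views, all_gather_updated_shards) are illustrative, and it assumes torch.distributed has already been initialized.

import torch
import torch.distributed as dist


def get_grad_buffer_dp_views(grad_buffer, world_size):
    # The buffer is padded so it divides evenly across data-parallel ranks
    # (this mirrors the numel_padded assertion in the diff above).
    assert grad_buffer.numel() % world_size == 0
    shard_size = grad_buffer.numel() // world_size
    return [grad_buffer[r * shard_size:(r + 1) * shard_size]
            for r in range(world_size)]


def all_gather_updated_shards(grad_buffer, group=None):
    # Each rank owns the shard at its own index; all_gather copies every
    # rank's shard into the matching view, filling the whole buffer in place.
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    views = get_grad_buffer_dp_views(grad_buffer, world_size)
    dist.all_gather(views, views[rank], group=group)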