Commit 64b94f00 authored by Lawrence McAfee

setup code to try _reduce_scatter_base, _all_gather_base.

parent 6728a780
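
Editor's note: the diff below replaces Megatron's chunked, list-of-views collectives with PyTorch's flat-tensor variants. For context, here is a minimal sketch (not part of the commit) contrasting the two reduce-scatter call shapes on a flat gradient buffer. The function names and the `flat_grads` tensor are illustrative only, and the sketch assumes the buffer size divides evenly by the data-parallel world size (Megatron pads its grad buffers so that this holds). `_reduce_scatter_base` was a private API at the time; later PyTorch releases expose the same operation as `reduce_scatter_tensor`.

# Illustrative sketch, not from the commit: two ways to reduce-scatter a flat
# gradient buffer across the data-parallel group. Assumes torch.distributed
# has already been initialized (e.g. the script was launched with torchrun).
import torch
import torch.distributed as dist

def reduce_scatter_with_views(flat_grads, group=None):
    """List-based collective: chunk the buffer into one view per rank."""
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    views = list(torch.chunk(flat_grads, world_size))
    # Each rank receives the reduced copy of its own chunk, in place.
    dist.reduce_scatter(views[rank], views, group=group)
    return views[rank]

def reduce_scatter_flat(flat_grads, group=None):
    """Flat-tensor collective: pass the whole buffer plus the local shard."""
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    shard = flat_grads.numel() // world_size
    local_view = flat_grads[rank * shard : (rank + 1) * shard]
    # The output is a view into the same buffer, mirroring how the diff pairs
    # gbuf_views[data_parallel_rank] with the full gbuf.data tensor.
    dist._reduce_scatter_base(local_view, flat_grads, group=group)
    return local_view
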
@@ -72,6 +72,11 @@ def clip_grad_norm_fp32(parameters, grads_for_norm,
    # grads_for_norm.append(grad)
    # <<<
    # >>>
    # Grads.
    grads = [ p.grad for p in parameters if p is not None ]
    # <<<
    # Norm parameters.
    max_norm = float(max_norm)
    norm_type = float(norm_type)
@@ -115,14 +120,14 @@ def clip_grad_norm_fp32(parameters, grads_for_norm,
        total_norm = total_norm.item() ** (1.0 / norm_type)
        # >>>
        from megatron import get_args
        from lutil import pax
        args = get_args()
        pax(0, {
            "use distrib opt" : args.use_distributed_optimizer,
            "norm_type" : norm_type,
            "total_norm" : total_norm,
        })
        # from megatron import get_args
        # from lutil import pax
        # args = get_args()
        # pax(0, {
        #     "use distrib opt" : args.use_distributed_optimizer,
        #     "norm_type" : norm_type,
        #     "total_norm" : total_norm,
        # })
        # <<<
    # Scale.
......
@@ -413,6 +413,19 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                gbuf_view_items.append((model_index, dtype, gbuf_views))
        return gbuf_view_items

    # >>>
    def get_model_grad_buffer_dp_views_SINGLE(self):
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        # Grad buffer views.
        gbuf_items = []
        for model_index, model in enumerate(self.models):
            for dtype, gbuf in model._grad_buffers.items():
                gbuf_items.append((model_index, dtype, gbuf.data))
        return gbuf_items
    # <<<

    def get_model_grad_buffer_dp_views_chunked(self, mem_savings_factor):
@@ -466,14 +479,36 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                gbuf.data /= data_parallel_world_size

        # Reduce scatter all grads.
        gbuf_view_items = \
            self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
        for model_index, dtype, gbuf_views in gbuf_view_items:
            torch.distributed.reduce_scatter(
        # >>>
        # gbuf_view_items = \
        #     self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
        # for model_index, dtype, gbuf_views in gbuf_view_items:
        #     torch.distributed.reduce_scatter(
        #         gbuf_views[data_parallel_rank],
        #         gbuf_views,
        #         group = data_parallel_group,
        #     )
        # +++
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
        for index, (model_index, dtype, gbuf_views) in enumerate(gbuf_view_items):
            # >>>
            pax(0, {
                "gbuf_view" : gbuf_views[data_parallel_rank].shape,
                "gbuf SINGLE" : gbuf_view_items_SINGLE[index][2].shape,
            })
            # <<<
            torch.distributed._reduce_scatter_base(
                gbuf_views[data_parallel_rank],
                gbuf_view_items_SINGLE[index][2],
                group = data_parallel_group,
            )
            # torch.distributed.reduce_scatter(
            #     gbuf_views[data_parallel_rank],
            #     gbuf_views,
            #     group = data_parallel_group,
            # )
        # <<<

        timers('backward-params-all-reduce').stop()

    def gather_model_params(self, args, timers):
@@ -489,14 +524,25 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        # across all data parallel ranks, with grad buffer padding that is done
        # in distributed.py. Thus, all sub-views will have consistent start/end
        # indexes across data parallel ranks.
        gbuf_view_items = \
            self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
        for model_index, dtype, gbuf_views in gbuf_view_items:
            torch.distributed.all_gather(
                gbuf_views,
                gbuf_views[data_parallel_rank],
                group = data_parallel_group,
            )
        # >>>
        # gbuf_view_items = \
        #     self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
        # for model_index, dtype, gbuf_views in gbuf_view_items:
        #     torch.distributed.all_gather(
        #         gbuf_views,
        #         gbuf_views[data_parallel_rank],
        #         group = data_parallel_group,
        #     )
        # +++
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
        for index, (model_index, dtype, gbuf_views) in enumerate(gbuf_view_items):
            torch.distributed._all_gather_base(
                gbuf_view_items_SINGLE[index][2],
                gbuf_views[data_parallel_rank],
                group = data_parallel_group,
            )
        # <<<

        # Each model param now contains its updated values in its
        # '.main_grad' field.
......
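
Editor's note: on the gather side the same flat-buffer idea runs in reverse: each rank contributes its updated shard and the collective writes every rank's shard back into one contiguous buffer. A hedged sketch of that call shape follows, again with illustrative names and the same even-padding assumption; it mirrors the `_all_gather_base` call set up in the diff above but is not taken verbatim from it. Later PyTorch releases expose this operation as `all_gather_into_tensor`.

# Illustrative sketch, not from the commit: gathering updated parameter shards
# back into one contiguous buffer with the flat-tensor all-gather. Assumes
# torch.distributed is already initialized.
import torch
import torch.distributed as dist

def all_gather_flat(flat_buf, group=None):
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    shard = flat_buf.numel() // world_size
    # The local shard is a view of the destination buffer, so the collective
    # fills the buffer in place, analogous to pairing the full SINGLE buffer
    # with gbuf_views[data_parallel_rank] in the diff above.
    local_view = flat_buf[rank * shard : (rank + 1) * shard]
    dist._all_gather_base(flat_buf, local_view, group=group)
    return flat_buf
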