Commit 64b94f00 authored by Lawrence McAfee

setup code to try _reduce_scatter_base, _all_gather_base.

parent 6728a780
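The change below swaps the list-based collectives `torch.distributed.reduce_scatter` / `torch.distributed.all_gather` for the flat-tensor variants `torch.distributed._reduce_scatter_base` / `torch.distributed._all_gather_base`, which take one contiguous buffer instead of a Python list of per-rank views. The following is a minimal sketch of that signature difference, assuming an already-initialized process group; the names `flat_buffer` and `reduce_scatter_both_ways` are illustrative and not part of this commit, and newer PyTorch releases expose the same collectives publicly as `reduce_scatter_tensor` / `all_gather_into_tensor`.

```python
import torch
import torch.distributed as dist

def reduce_scatter_both_ways(flat_buffer: torch.Tensor, group=None):
    """Illustrative only: contrasts the list-based and flat-tensor collectives.

    Assumes torch.distributed is initialized (CUDA tensors if using NCCL) and
    that flat_buffer's length is divisible by the group's world size -- Megatron
    pads its grad buffers so that this holds.
    """
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)

    # List-based API: the caller passes one view per rank.
    views = list(flat_buffer.chunk(world_size))
    dist.reduce_scatter(views[rank], views, group=group)

    # Flat-tensor API: the caller passes the whole contiguous buffer once;
    # each rank receives its 1/world_size slice in the output tensor.
    output = torch.empty_like(views[rank])
    dist._reduce_scatter_base(output, flat_buffer, group=group)

    # The inverse direction, gathering each rank's shard back into the full
    # buffer, follows the same pattern with _all_gather_base.
    dist._all_gather_base(flat_buffer, output, group=group)
```

The flat-tensor form avoids building a Python list of views for every call and hands NCCL one contiguous region, which is presumably the motivation for the new `get_model_grad_buffer_dp_views_SINGLE` helper introduced in the diff.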
@@ -72,6 +72,11 @@ def clip_grad_norm_fp32(parameters, grads_for_norm,
     #         grads_for_norm.append(grad)
     # <<<

+    # >>>
+    # Grads.
+    grads = [ p.grad for p in parameters if p is not None ]
+    # <<<
+
     # Norm parameters.
     max_norm = float(max_norm)
     norm_type = float(norm_type)
@@ -115,14 +120,14 @@ def clip_grad_norm_fp32(parameters, grads_for_norm,
         total_norm = total_norm.item() ** (1.0 / norm_type)

     # >>>
-    from megatron import get_args
-    from lutil import pax
-    args = get_args()
-    pax(0, {
-        "use distrib opt" : args.use_distributed_optimizer,
-        "norm_type" : norm_type,
-        "total_norm" : total_norm,
-    })
+    # from megatron import get_args
+    # from lutil import pax
+    # args = get_args()
+    # pax(0, {
+    #     "use distrib opt" : args.use_distributed_optimizer,
+    #     "norm_type" : norm_type,
+    #     "total_norm" : total_norm,
+    # })
     # <<<

     # Scale.
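For context, the `total_norm = total_norm.item() ** (1.0 / norm_type)` line above is the tail of the usual distributed gradient-norm computation: each rank sums |g|^p over its local grads, the partial sums are reduced across ranks, and the p-th root is taken before the grads are scaled. Below is a generic sketch of that pattern, assuming a finite `norm_type` and an initialized process group; it is not Megatron's implementation, which uses a fused multi-tensor kernel for the 2-norm and reduces over the model-parallel group.

```python
import torch
import torch.distributed as dist

def clip_grad_norm_distributed(grads, max_norm, norm_type=2.0, group=None):
    """Generic sketch: clip a sharded set of grads by their global p-norm.

    `grads` holds this rank's (non-empty) list of gradient tensors; the norm is
    accumulated across `group` so every rank ends up with the same total_norm.
    The infinity-norm case is omitted for brevity.
    """
    # Local contribution: sum of |g|^p over this rank's gradients.
    local = torch.zeros(1, dtype=torch.float32, device=grads[0].device)
    for g in grads:
        local += torch.norm(g.float(), norm_type) ** norm_type

    # Combine partial sums across ranks, then take the p-th root.
    dist.all_reduce(local, op=dist.ReduceOp.SUM, group=group)
    total_norm = local.item() ** (1.0 / norm_type)

    # Scale grads in place only if the global norm exceeds the limit.
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        for g in grads:
            g.mul_(clip_coeff)
    return total_norm
```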
@@ -413,6 +413,19 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             gbuf_view_items.append((model_index, dtype, gbuf_views))

         return gbuf_view_items

+    # >>>
+    def get_model_grad_buffer_dp_views_SINGLE(self):
+
+        data_parallel_world_size = mpu.get_data_parallel_world_size()
+
+        # Grad buffer views.
+        gbuf_items = []
+        for model_index, model in enumerate(self.models):
+            for dtype, gbuf in model._grad_buffers.items():
+                gbuf_items.append((model_index, dtype, gbuf.data))
+
+        return gbuf_items
+    # <<<
+
     def get_model_grad_buffer_dp_views_chunked(self, mem_savings_factor):
@@ -466,14 +479,36 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 gbuf.data /= data_parallel_world_size

         # Reduce scatter all grads.
-        gbuf_view_items = \
-            self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
-        for model_index, dtype, gbuf_views in gbuf_view_items:
-            torch.distributed.reduce_scatter(
-                gbuf_views[data_parallel_rank],
-                gbuf_views,
-                group = data_parallel_group,
-            )
+        # >>>
+        # gbuf_view_items = \
+        #     self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
+        # for model_index, dtype, gbuf_views in gbuf_view_items:
+        #     torch.distributed.reduce_scatter(
+        #         gbuf_views[data_parallel_rank],
+        #         gbuf_views,
+        #         group = data_parallel_group,
+        #     )
+        # +++
+        gbuf_view_items = self.get_model_grad_buffer_dp_views()
+        gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
+        for index, (model_index, dtype, gbuf_views) in enumerate(gbuf_view_items):
+            # >>>
+            pax(0, {
+                "gbuf_view" : gbuf_views[data_parallel_rank].shape,
+                "gbuf SINGLE" : gbuf_view_items_SINGLE[index][2].shape,
+            })
+            # <<<
+            torch.distributed._reduce_scatter_base(
+                gbuf_views[data_parallel_rank],
+                gbuf_view_items_SINGLE[index][2],
+                group = data_parallel_group,
+            )
+            # torch.distributed.reduce_scatter(
+            #     gbuf_views[data_parallel_rank],
+            #     gbuf_views,
+            #     group = data_parallel_group,
+            # )
+        # <<<

         timers('backward-params-all-reduce').stop()

     def gather_model_params(self, args, timers):
@@ -489,14 +524,25 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # across all data parallel ranks, with grad buffer padding that is done
         # in distributed.py. Thus, all sub-views will have consistent start/end
         # indexes across data parallel ranks.
-        gbuf_view_items = \
-            self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
-        for model_index, dtype, gbuf_views in gbuf_view_items:
-            torch.distributed.all_gather(
-                gbuf_views,
-                gbuf_views[data_parallel_rank],
-                group = data_parallel_group,
-            )
+        # >>>
+        # gbuf_view_items = \
+        #     self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
+        # for model_index, dtype, gbuf_views in gbuf_view_items:
+        #     torch.distributed.all_gather(
+        #         gbuf_views,
+        #         gbuf_views[data_parallel_rank],
+        #         group = data_parallel_group,
+        #     )
+        # +++
+        gbuf_view_items = self.get_model_grad_buffer_dp_views()
+        gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
+        for index, (model_index, dtype, gbuf_views) in enumerate(gbuf_view_items):
+            torch.distributed._all_gather_base(
+                # gbuf_view_items_SINGLE[index][2],
+                # gbuf_views[data_parallel_rank],
+                # group = data_parallel_group,
+            )
+        # <<<

         # Each model param now contains its updated values in its
         # '.main_grad' field.
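In the reduce-scatter path above, the new `_reduce_scatter_base` call pairs each per-rank view from `get_model_grad_buffer_dp_views()` (used as the output shard) with the full flat buffer from `get_model_grad_buffer_dp_views_SINGLE()` (used as the input); the temporary `pax(...)` dump appears to be there to confirm that shape relationship before the collective runs, and the `_all_gather_base` call in `gather_model_params` is still left with its arguments commented out in this commit. Below is a small standalone sketch of that shape relationship, with hypothetical sizes and no Megatron code involved.

```python
import torch

# Hypothetical numbers: a padded flat grad buffer split across 4 data-parallel ranks.
data_parallel_world_size = 4
numel_padded = 1024  # grad buffer padded to a multiple of the world size

flat_gbuf = torch.zeros(numel_padded)           # what ..._dp_views_SINGLE() returns
shard = numel_padded // data_parallel_world_size

# What get_model_grad_buffer_dp_views() returns: one contiguous view per rank.
gbuf_views = [flat_gbuf[r * shard:(r + 1) * shard]
              for r in range(data_parallel_world_size)]

# The relationship the pax() debug dump checks before the collective call:
assert flat_gbuf.numel() == data_parallel_world_size * gbuf_views[0].numel()

# _reduce_scatter_base(output, input): output is one shard, input is the flat buffer.
# _all_gather_base(output, input):     output is the flat buffer, input is one shard.
```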