"...text-generation-inference.git" did not exist on "09674e6df99918dad5169b5cc81241c884c543d0"
Commit 12d91733 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

split reduce_grads/gather_params using gbuf sub-views.

parent e30ad67e
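The diff below replaces whole-buffer collectives with collectives over fixed-size sub-views of each gradient buffer. As a rough standalone sketch of that idea, assuming only the public torch.distributed API (the function name chunked_reduce_scatter and the variables gbuf and sub_numel are illustrative, not taken from this commit):

# Illustrative sketch only (not part of this diff): chunked reduce-scatter over
# equal per-rank views of a flat, padded gradient buffer.
import torch
import torch.distributed as dist

def chunked_reduce_scatter(gbuf, sub_numel, group=None):
    """Reduce-scatter `gbuf` in fixed-size sub-views.

    Assumes `gbuf` is a 1-D tensor whose length is a multiple of the
    data parallel world size (Megatron pads its grad buffers this way).
    """
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    shard_size = gbuf.nelement() // world_size

    # One contiguous, equal-size view per rank.
    views = [ gbuf[r * shard_size:(r + 1) * shard_size] for r in range(world_size) ]

    # Walk every view in lock-step sub-views of `sub_numel` elements.
    for start in range(0, shard_size, sub_numel):
        end = min(shard_size, start + sub_numel)
        sub_views = [ v[start:end] for v in views ]
        dist.reduce_scatter(sub_views[rank], sub_views, group=group)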
@@ -230,14 +230,15 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    def get_model_parallel_group(self):
        return None

    # @staticmethod
    # def has_nan_debug(tensors):
    #     if isinstance(tensors, torch.Tensor):
    #         tensors = [ tensors ]
    #     assert isinstance(tensors, list)
    #     has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
    #     has_nan = any(has_nans)
    #     return has_nan
    # >>>
    @staticmethod
    def has_nan_debug(tensors):
        if isinstance(tensors, torch.Tensor):
            tensors = [ tensors ]
        assert isinstance(tensors, list)
        has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
        has_nan = any(has_nans)
        return has_nan
    # def get_local_model_param_views(self):
    #     '''** FOR DEBUGGING. **'''
    #     model_param_views = []
@@ -269,6 +270,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    # def get_world_model_grads(self):
    #     '''** FOR DEBUGGING. **'''
    #     return [ p.main_grad for p in self.get_world_model_params() ]
    # <<<

    def get_main_params(self):
        return [ g["params"][0] for g in self.optimizer.param_groups ]
@@ -327,6 +329,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        # Distributed optimizer requires contiguous buffer; don't set to None.
        _zero_grad_group_helper(model_params, set_to_none = False)

    # >>>
    def get_model_grad_buffer_dp_views(self):
        data_parallel_world_size = mpu.get_data_parallel_world_size()
@@ -343,8 +346,48 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                gbuf_view_items.append((model_index, dtype, gbuf_views))
        return gbuf_view_items
    def get_model_grad_buffer_dp_views_SUB(self, sub_view_numel):

        gbuf_view_items = self.get_model_grad_buffer_dp_views()

        sub_view_items = []
        for model_index, dtype, gbuf_views in gbuf_view_items:

            # ** Sanity check. ** (should be unnecessary; see comment above)
            view_numel = gbuf_views[0].nelement()
            for view in gbuf_views:
                assert view.nelement() == view_numel

            for start_index in range(0, view_numel, sub_view_numel):
                end_index = min(view_numel, start_index + sub_view_numel)
                sub_views = [ t[start_index:end_index] for t in gbuf_views ]
                sub_view_items.append((model_index, dtype, sub_views))

        # >>>
        from lutil import pax
        pax(0, {
            "gbuf_view_items" : [ (a, b, [ v.shape for v in c ]) for a, b, c in gbuf_view_items ],
            "sub_view_items" : [ (a, b, [ v.shape for v in c ]) for a, b, c in sub_view_items ],
        })
        # <<<

        return sub_view_items
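To make the sub-view slicing above concrete, here is a small hypothetical example (tensor sizes invented for illustration; this code is not part of the commit). With an 8-element padded buffer and a data parallel world size of 2, each rank's view has 4 elements, and a sub_view_numel of 3 produces the same [0:3] and [3:4] slice boundaries in every view, so chunk indexes line up across ranks:

# Hypothetical illustration of the sub-view slicing (not part of the diff).
import torch

gbuf = torch.arange(8.)                           # padded buffer, 8 elements
world_size = 2
shard = gbuf.nelement() // world_size             # 4 elements per rank view
views = [ gbuf[r * shard:(r + 1) * shard] for r in range(world_size) ]

sub_view_numel = 3
for start in range(0, shard, sub_view_numel):
    end = min(shard, start + sub_view_numel)
    sub_views = [ v[start:end] for v in views ]
    print(start, end, [ tuple(v.shape) for v in sub_views ])
# Prints: 0 3 [(3,), (3,)]   then   3 4 [(1,), (1,)]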
    # def get_model_grad_buffers_SINGLE(self):
    #     data_parallel_world_size = mpu.get_data_parallel_world_size()
    #     # Grad buffers.
    #     gbuf_items = []
    #     for model_index, model in enumerate(self.models):
    #         for dtype, gbuf in model._grad_buffers.items():
    #             assert gbuf.numel_padded % data_parallel_world_size == 0
    #             shard_size = int(gbuf.numel_padded / data_parallel_world_size)
    #             gbuf_items.append((model_index, dtype, gbuf.data))
    #     return gbuf_items
    # <<<
    # >>>
    def reduce_model_grads_0(self, args, timers):
        '''Note: this is a different order of reduction, versus the non-
        distributed optimizer, which reduces: 1) all grads, 2) embedding
        grads.
@@ -371,9 +414,44 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                group = data_parallel_group,
            )
        timers('backward-params-all-reduce').stop()

    def reduce_model_grads_1(self, args, timers):
        '''Note: this is a different order of reduction, versus the non-
        distributed optimizer, which reduces: 1) all grads, 2) embedding
        grads.
        '''
        # All-reduce embedding grads.
        timers('backward-embedding-all-reduce').start()
        self.allreduce_embedding_grads(args)
        timers('backward-embedding-all-reduce').stop()

        # Reduce-scatter all grads.
        timers('backward-params-all-reduce').start()

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        data_parallel_group = mpu.get_data_parallel_group()

        sub_numel = 1 * 1048576
        sub_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
        for model_index, dtype, sub_views in sub_view_items:
            # Each buffer element appears in exactly one (chunk, rank) sub-view,
            # so scaling here divides the local grad buffer by the data parallel
            # world size exactly once before the reduce-scatter.
            for sub_view in sub_views:
                sub_view /= data_parallel_world_size
            torch.distributed.reduce_scatter(
                sub_views[data_parallel_rank],
                sub_views,
                group = data_parallel_group,
            )

        timers('backward-params-all-reduce').stop()

    def reduce_model_grads(self, *args):
        # >>>
        # return
        # <<<
        # self.reduce_model_grads_0(*args)
        self.reduce_model_grads_1(*args)
    # <<<
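For context on the scaling in reduce_model_grads_1 above: reduce_scatter defaults to a SUM reduction, so dividing each local gradient by the data parallel world size before the collective leaves each rank holding the data-parallel average of its shard. A tiny hypothetical check (values invented for illustration, not from the diff):

# Hypothetical check: pre-dividing by world size turns a SUM reduce into an average.
world_size = 4
local_grads = [1.0, 3.0, 5.0, 7.0]          # the same grad element on 4 DP ranks
summed = sum(g / world_size for g in local_grads)
assert summed == sum(local_grads) / world_size == 4.0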
    # >>>
    def gather_model_params_0(self, args, timers):

        timers('backward-params-all-gather').start()
@@ -397,7 +475,181 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                    param.detach().copy_(param.main_grad)

        timers('backward-params-all-gather').stop()

    def gather_model_params_1(self, args, timers):

        timers('backward-params-all-gather').start()

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group = mpu.get_data_parallel_group()

        # All-gather updated main params.
        # - All grad buffer views are guaranteed to have the same num elements
        #   across all data parallel ranks, with grad buffer padding that is done
        #   in distributed.py. Thus, all sub-views will have consistent start/end
        #   indexes across data parallel ranks.
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        # sub_view_numel = 1 * 1024
        # sub_view_numel = 1 * 131072
        sub_view_numel = 1 * 1048576
        for model_index, dtype, gbuf_views in gbuf_view_items:

            # ** Sanity check. ** (should be unnecessary; see comment above)
            view_numel = gbuf_views[0].nelement()
            for view in gbuf_views:
                assert view.nelement() == view_numel

            for start_index in range(0, view_numel, sub_view_numel):
                end_index = min(view_numel, start_index + sub_view_numel)
                sub_views = [ t[start_index:end_index] for t in gbuf_views ]
                torch.distributed.all_gather(
                    sub_views,
                    sub_views[data_parallel_rank],
                    group = data_parallel_group,
                )

        # Each model param now contains its updated values in its
        # '.main_grad' field.
        for model in self.models:
            for dtype, param_map in model._grad_buffer_param_index_map.items():
                for param in param_map:
                    param.detach().copy_(param.main_grad)

        timers('backward-params-all-gather').stop()
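The comment in gather_model_params_1 relies on the grad buffer padding performed in distributed.py, which rounds the buffer size up to a multiple of the data parallel world size so that every rank's view is the same length. A hypothetical illustration of that padding arithmetic (function name and numbers invented for the example; not code from this commit):

# Hypothetical padding arithmetic: pad the element count up to a multiple of the
# data parallel world size so every rank's view is equal-sized.
def padded_numel(numel, data_parallel_world_size):
    pad = (data_parallel_world_size - numel % data_parallel_world_size) \
        % data_parallel_world_size
    return numel + pad

assert padded_numel(10, 4) == 12   # each of 4 ranks owns a 3-element view
assert padded_numel(12, 4) == 12   # already divisible; no padding needed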
# def gather_model_params_2(self, args, timers):
# raise Exception("_all_gather_base not applicable when each DP rank owns contiguous range of grad buffer.")
# timers('backward-params-all-gather').start()
# data_parallel_rank = mpu.get_data_parallel_rank()
# data_parallel_world_size = mpu.get_data_parallel_world_size()
# data_parallel_group = mpu.get_data_parallel_group()
# # All-gather updated main params.
# # - All grad buffer views are guaranteed to have the same num elements
# # across all data parallel ranks, with grad buffer padding that is done
# # in distributed.py. Thus, all sub-views will have consistent start/end
# # indexes across data parallel ranks.
# gbuf_items = self.get_model_grad_buffers_SINGLE()
# # local_sub_numel = 1 * 1024
# # local_sub_numel = 1 * 131072
# ideal_local_numel = 128 * 1048576
# ideal_world_numel = data_parallel_world_size * ideal_local_numel
# for model_index, dtype, gbuf in gbuf_items:
# gbuf_numel = gbuf.nelement()
# # >>>
# # from lutil import pax
# # pax(0, {
# # "gbuf_items" : [ (a, b, c.shape) for a, b, c in gbuf_items ],
# # "gbuf" : str(gbuf.shape),
# # "gbuf_numel" : gbuf_numel,
# # "local_sub_numel" : local_sub_numel,
# # "world_sub_numel" : world_sub_numel,
# # })
# # <<<
# for world_start_index in range(0, gbuf_numel, ideal_world_numel):
# world_end_index = \
# min(gbuf_numel, world_start_index + ideal_world_numel)
# world_numel = world_end_index - world_start_index
# assert world_numel % data_parallel_world_size == 0
# local_numel = int(world_numel / data_parallel_world_size)
# local_start_index = \
# world_start_index + data_parallel_rank * local_numel
# local_end_index = \
# min(gbuf_numel, local_start_index + local_numel)
# try:
# world_view = gbuf[world_start_index:world_end_index]
# local_view = gbuf[local_start_index:local_end_index]
# except:
# # >>>
# from lutil import pax
# pax(0, {
# "world_start_index" : world_start_index,
# "world_end_index" : world_end_index,
# "local_start_index" : local_start_index,
# "local_end_index" : local_end_index,
# })
# # <<<
# try:
# torch.distributed._all_gather_base(
# world_view,
# local_view,
# group = data_parallel_group,
# )
# except:
# # >>>
# from lutil import pax
# pax(0, {
# "data_parallel_rank" : data_parallel_rank,
# # "local_sub_numel" : local_sub_numel,
# # "world_sub_numel" : world_sub_numel,
# "world_start_index" : world_start_index,
# "world_end_index" : world_end_index,
# "local_start_index" : local_start_index,
# "local_end_index" : local_end_index,
# "gbuf" : str(gbuf.shape),
# "world_view" : str(world_view.shape),
# "local_view" : str(local_view.shape),
# "local_sub_numel / ideal" : local_sub_numel,
# "local_sub_numel / act" :
# local_end_index - local_start_index,
# })
# # <<<
# # >>>
# # from lutil import pax, tp
# # pax(0, {
# # # "gbuf" : tp(gbuf),
# # "world range" : "%d, %d"%(world_start_index, world_end_index),
# # "local range" : "%d, %d"%(local_start_index, local_end_index),
# # "world_view" : tp(world_view),
# # "local_view" : tp(local_view),
# # "gbuf view" : tp(gbuf[world_start_index:world_end_index]),
# # })
# # <<<
# # >>>
# for model_index, dtype, gbuf in gbuf_items:
# if self.has_nan_debug(gbuf):
# raise Exception("hi.")
# # from lutil import pax, tp
# # pax(0, {
# # "gbuf_items" : [ (a, b, tp(c)) for a, b, c in gbuf_items ],
# # })
# # <<<
# # Each model param now contains its updated values in its
# # '.main_grad' field.
# for model in self.models:
# for dtype, param_map in model._grad_buffer_param_index_map.items():
# for param in param_map:
# param.detach().copy_(param.main_grad)
# # >>>
# if self.has_nan_debug(param):
# raise Exception("wha?")
# # <<<
# timers('backward-params-all-gather').stop()
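The commented-out gather_model_params_2 above experiments with torch.distributed._all_gather_base which, unlike all_gather's list-of-tensors interface, gathers into one flat output tensor of world_size times the input's size, with rank r's contribution landing at offset r * input_numel. A minimal sketch of the interface difference, not part of this commit (_all_gather_base is a private PyTorch API in the versions this code targets):

# Illustrative comparison of the two gather interfaces (not part of the diff).
import torch
import torch.distributed as dist

def gather_with_list(local, world_size, group=None):
    # all_gather: caller supplies one output tensor per rank.
    outs = [ torch.empty_like(local) for _ in range(world_size) ]
    dist.all_gather(outs, local, group=group)
    return torch.cat(outs)

def gather_into_flat(local, world_size, group=None):
    # _all_gather_base: one flat output; rank r's data lands at r * local.nelement().
    out = torch.empty(world_size * local.nelement(),
                      dtype=local.dtype, device=local.device)
    dist._all_gather_base(out, local, group=group)
    return out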
    def gather_model_params(self, *args):
        # >>>
        return
        # <<<
        # self.gather_model_params_0(*args)
        self.gather_model_params_1(*args)
        # self.gather_model_params_2(*args)
        # ~~~
        # self.debug_model(0, "after / gather_model_params", 0)
    # <<<

    def _collect_main_grad_data_for_unscaling(self):
        return [ g.data for g in self.get_main_grads() ]
@@ -327,6 +327,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
        from megatron import get_args
        args = get_args()
        my_rank = torch.distributed.get_rank()
        DEBUG_ITERATION = ITERATION
        if ITERATION != DEBUG_ITERATION:
            return
        for r in range(torch.distributed.get_world_size()):