Commit 3d433e8a authored by Lawrence McAfee

cleanup.

parent 19730725
@@ -737,7 +737,7 @@ def _add_distributed_args(parser):
help='Trade-off memory savings & iteration time, for '
'distributed optimizer\'s communication operations (i.e., '
'reduce/gather). This value ranges from 0.0 (default, '
'no memory savings) to 1.0 (max memory savings, at '
'no memory savings) to 1.0 (max memory savings, at the '
'expense of iteration time).')
return parser
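As a rough sketch of the trade-off described in the help text above: the factor linearly interpolates between a small minimum chunk size and the full per-rank view size, mirroring the chunking code later in this diff (the 131072-element minimum is taken from that code; the buffer sizes below are purely illustrative).

def chunk_numel_for_factor(mem_savings_factor, view_numel, chunk_numel_min=131072):
    # 0.0 -> one chunk spanning the whole view (fastest, no memory savings);
    # 1.0 -> many chunks of chunk_numel_min elements (max savings, more
    #        collective calls per iteration).
    assert 0.0 <= mem_savings_factor <= 1.0
    return int(mem_savings_factor * chunk_numel_min
               + (1 - mem_savings_factor) * view_numel)

# e.g., for a 100M-element per-rank view:
#   chunk_numel_for_factor(0.0, 100_000_000) -> 100_000_000 (single chunk)
#   chunk_numel_for_factor(1.0, 100_000_000) -> 131072 (~763 chunks)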
...
@@ -231,14 +231,14 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
return None
# >>>
@staticmethod
def has_nan_debug(tensors):
if isinstance(tensors, torch.Tensor):
tensors = [ tensors ]
assert isinstance(tensors, list)
has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
has_nan = any(has_nans)
return has_nan
# @staticmethod
# def has_nan_debug(tensors):
# if isinstance(tensors, torch.Tensor):
# tensors = [ tensors ]
# assert isinstance(tensors, list)
# has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
# has_nan = any(has_nans)
# return has_nan
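For reference, a minimal standalone equivalent of the has_nan_debug helper above (toy tensors only; the function name here is illustrative and not part of this commit):

import torch

def has_non_finite(tensors):
    # Accept a single tensor or a list of tensors; flag any inf/nan value.
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    return any(not torch.all(torch.isfinite(t)).item() for t in tensors)

assert has_non_finite(torch.tensor([1.0, float('nan')]))
assert not has_non_finite([torch.ones(3), torch.zeros(2)])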
# def get_local_model_param_views(self):
# '''** FOR DEBUGGING. **'''
# model_param_views = []
@@ -329,7 +329,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# Distributed optimizer requires contiguous buffer; don't set to None.
_zero_grad_group_helper(model_params, set_to_none = False)
# >>>
def get_model_grad_buffer_dp_views(self):
data_parallel_world_size = mpu.get_data_parallel_world_size()
@@ -349,6 +348,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
def get_model_grad_buffer_dp_views_chunked(self, mem_savings_factor):
# Iterate grad buffers & chunk.
gbuf_view_items = self.get_model_grad_buffer_dp_views()
chunk_view_items = []
for model_index, dtype, gbuf_views in gbuf_view_items:
@@ -358,79 +358,22 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for view in gbuf_views:
assert view.nelement() == view_numel
# chunk_numel_min = 1024
# chunk_numel_min = 16384
# Compute chunk size (via savings factor).
chunk_numel_min = 131072
# chunk_numel_min = 1048576
chunk_numel_max = view_numel
# chunk_numel_min_log = math.log(chunk_numel_min)
# chunk_numel_max_log = math.log(chunk_numel_max)
# chunk_numel_log = (chunk_numel_min_log + chunk_numel_max_log) / 2
# chunk_numel = int(math.exp(chunk_numel_log))
chunk_numel = int(
mem_savings_factor * chunk_numel_min
+ (1 - mem_savings_factor) * chunk_numel_max
)
# >>>
# from lutil import pax
# pax(0, {
# "view_numel" : view_numel,
# "chunk_numel_min" : chunk_numel_min,
# "chunk_numel_max" : chunk_numel_max,
# "chunk_numel_min_log" : chunk_numel_min_log,
# "chunk_numel_max_log" : chunk_numel_max_log,
# "chunk_numel_log" : chunk_numel_log,
# "chunk_numel" : chunk_numel,
# "mem_savings_factor" : mem_savings_factor,
# })
# <<<
# Chunk views.
for start_index in range(0, view_numel, chunk_numel):
end_index = min(view_numel, start_index + chunk_numel)
chunk_views = [ t[start_index:end_index] for t in gbuf_views ]
chunk_view_items.append((model_index, dtype, chunk_views))
# >>>
# from lutil import pax
# pax(0, {
# "gbuf_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in gbuf_view_items],
# "chunk_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in chunk_view_items],
# })
# <<<
return chunk_view_items
# <<<
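A minimal standalone sketch of the chunking above, assuming a single flat grad buffer evenly divisible by the data-parallel world size (names and sizes are illustrative, not Megatron's API):

import torch

def chunk_dp_views(gbuf, data_parallel_world_size, chunk_numel):
    # Split the buffer into one equal view per data-parallel rank, then
    # slice every view into chunks of at most chunk_numel elements.
    view_numel = gbuf.nelement() // data_parallel_world_size
    gbuf_views = [
        gbuf[r * view_numel:(r + 1) * view_numel]
        for r in range(data_parallel_world_size)
    ]
    chunk_view_groups = []
    for start_index in range(0, view_numel, chunk_numel):
        end_index = min(view_numel, start_index + chunk_numel)
        chunk_view_groups.append([v[start_index:end_index] for v in gbuf_views])
    return chunk_view_groups

# e.g., a 16-element buffer, 4 ranks, 2-element chunks:
#   chunk_dp_views(torch.zeros(16), 4, 2) -> 2 groups of 4 views, 2 elements each.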
# >>>
# def reduce_model_grads_0(self, args, timers):
# '''Note: this is a different order of reduction, versus the non-
# distributed optimizer, which reduces: 1) all grads, 2) embedding
# grads.
# '''
# # All-reduce embedding grads.
# timers('backward-embedding-all-reduce').start()
# self.allreduce_embedding_grads(args)
# timers('backward-embedding-all-reduce').stop()
# # Reduce-scatter all grads.
# timers('backward-params-all-reduce').start()
# data_parallel_rank = mpu.get_data_parallel_rank()
# data_parallel_world_size = mpu.get_data_parallel_world_size()
# data_parallel_group = mpu.get_data_parallel_group()
# gbuf_view_items = self.get_model_grad_buffer_dp_views()
# for model_index, dtype, gbuf_views in gbuf_view_items:
# gbuf = self.models[model_index]._grad_buffers[dtype].data
# gbuf /= data_parallel_world_size
# torch.distributed.reduce_scatter(
# gbuf_views[data_parallel_rank],
# gbuf_views,
# group = data_parallel_group,
# )
# timers('backward-params-all-reduce').stop()
# def reduce_model_grads_1(self, args, timers):
def reduce_model_grads(self, args, timers):
'''Note: this is a different order of reduction, versus the non-
distributed optimizer, which reduces: 1) all grads, 2) embedding
@@ -458,91 +401,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
gbuf_view_items = \
self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
for model_index, dtype, gbuf_views in gbuf_view_items:
# gbuf = self.models[model_index]._grad_buffers[dtype].data
# gbuf /= data_parallel_world_size
torch.distributed.reduce_scatter(
gbuf_views[data_parallel_rank],
gbuf_views,
group = data_parallel_group,
)
timers('backward-params-all-reduce').stop()
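The memory/iteration-time trade-off comes from issuing one reduce-scatter per chunk instead of one per grad buffer. A back-of-the-envelope sketch of how the call count scales, assuming the 131072-element minimum chunk used above (buffer and world sizes are illustrative):

def num_reduce_scatter_calls(gbuf_numel, dp_world_size, mem_savings_factor,
                             chunk_numel_min=131072):
    # Calls per grad buffer: ceil(per-rank view size / chunk size).
    view_numel = gbuf_numel // dp_world_size
    chunk_numel = int(mem_savings_factor * chunk_numel_min
                      + (1 - mem_savings_factor) * view_numel)
    return -(-view_numel // chunk_numel)  # ceiling division

# e.g., a 2**30-element grad buffer across 8 data-parallel ranks:
#   num_reduce_scatter_calls(2**30, 8, 0.0) -> 1 call over the whole view
#   num_reduce_scatter_calls(2**30, 8, 1.0) -> 1024 calls of 131072 elements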
# def reduce_model_grads(self, *args):
# # >>>
# return
# # <<<
# # self.reduce_model_grads_0(*args)
# self.reduce_model_grads_1(*args)
# <<<
# >>>
# def gather_model_params_0(self, args, timers):
# timers('backward-params-all-gather').start()
# data_parallel_rank = mpu.get_data_parallel_rank()
# data_parallel_group = mpu.get_data_parallel_group()
# # All-gather updated main params.
# gbuf_view_items = self.get_model_grad_buffer_dp_views()
# for model_index, dtype, gbuf_views in gbuf_view_items:
# torch.distributed.all_gather(
# gbuf_views,
# gbuf_views[data_parallel_rank],
# group = data_parallel_group,
# )
# # Each model param now contains its updated values in its
# # '.main_grad' field.
# for model in self.models:
# for dtype, param_map in model._grad_buffer_param_index_map.items():
# for param in param_map:
# param.detach().copy_(param.main_grad)
# timers('backward-params-all-gather').stop()
# def gather_model_params_1(self, args, timers):
# timers('backward-params-all-gather').start()
# data_parallel_rank = mpu.get_data_parallel_rank()
# data_parallel_group = mpu.get_data_parallel_group()
# # All-gather updated main params.
# # - All grad buffer views are guaranteed to have the same num elements
# # across all data parallel ranks, with grad buffer padding that is done
# # in distributed.py. Thus, all sub-views will have consistent start/end
# # indexes across data parallel ranks.
# gbuf_view_items = self.get_model_grad_buffer_dp_views()
# # sub_view_numel = 1 * 1024
# # sub_view_numel = 1 * 131072
# sub_view_numel = 256 * 1048576
# for model_index, dtype, gbuf_views in gbuf_view_items:
# # ** Sanity check. ** (should be unnecessary; see comment above)
# view_numel = gbuf_views[0].nelement()
# for view in gbuf_views:
# assert view.nelement() == view_numel
# for start_index in range(0, view_numel, sub_view_numel):
# end_index = min(view_numel, start_index + sub_view_numel)
# sub_views = [ t[start_index:end_index] for t in gbuf_views ]
# torch.distributed.all_gather(
# sub_views,
# sub_views[data_parallel_rank],
# group = data_parallel_group,
# )
# # Each model param now contains its updated values in its
# # '.main_grad' field.
# for model in self.models:
# for dtype, param_map in model._grad_buffer_param_index_map.items():
# for param in param_map:
# param.detach().copy_(param.main_grad)
# timers('backward-params-all-gather').stop()
# def gather_model_params_1(self, args, timers):
def gather_model_params(self, args, timers):
timers('backward-params-all-gather').start()
@@ -556,11 +420,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# across all data parallel ranks, with grad buffer padding that is done
# in distributed.py. Thus, all sub-views will have consistent start/end
# indexes across data parallel ranks.
# sub_numel = 1 * 1024
# sub_numel = 1 * 131072
# sub_numel = 1024 * 1048576
# gbuf_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
gbuf_view_items = \
self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
for model_index, dtype, gbuf_views in gbuf_view_items:
@@ -578,138 +437,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
param.detach().copy_(param.main_grad)
timers('backward-params-all-gather').stop()
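A toy, single-process sketch of the final copy step above, assuming (as in Megatron) that each model param's '.main_grad' attribute is a view into the shared buffer that the all-gather just filled with updated parameter values:

import torch

buffer = torch.zeros(8)                        # stand-in for a grad buffer
param = torch.nn.Parameter(torch.ones(4))      # toy model parameter
param.main_grad = buffer[2:6]                  # view into the shared buffer

buffer[2:6] = torch.arange(4, dtype=torch.float)  # pretend all-gather output
param.detach().copy_(param.main_grad)             # same copy as in the loop above
assert torch.equal(param.detach(), torch.arange(4, dtype=torch.float))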
# def gather_model_params_2(self, args, timers):
# raise Exception("_all_gather_base not applicable when each DP rank owns contiguous range of grad buffer.")
# timers('backward-params-all-gather').start()
# data_parallel_rank = mpu.get_data_parallel_rank()
# data_parallel_world_size = mpu.get_data_parallel_world_size()
# data_parallel_group = mpu.get_data_parallel_group()
# # All-gather updated main params.
# # - All grad buffer views are guaranteed to have the same num elements
# # across all data parallel ranks, with grad buffer padding that is done
# # in distributed.py. Thus, all sub-views will have consistent start/end
# # indexes across data parallel ranks.
# gbuf_items = self.get_model_grad_buffers_SINGLE()
# # local_sub_numel = 1 * 1024
# # local_sub_numel = 1 * 131072
# ideal_local_numel = 128 * 1048576
# ideal_world_numel = data_parallel_world_size * ideal_local_numel
# for model_index, dtype, gbuf in gbuf_items:
# gbuf_numel = gbuf.nelement()
# # >>>
# # from lutil import pax
# # pax(0, {
# # "gbuf_items" : [ (a, b, c.shape) for a, b, c in gbuf_items ],
# # "gbuf" : str(gbuf.shape),
# # "gbuf_numel" : gbuf_numel,
# # "local_sub_numel" : local_sub_numel,
# # "world_sub_numel" : world_sub_numel,
# # })
# # <<<
# for world_start_index in range(0, gbuf_numel, ideal_world_numel):
# world_end_index = \
# min(gbuf_numel, world_start_index + ideal_world_numel)
# world_numel = world_end_index - world_start_index
# assert world_numel % data_parallel_world_size == 0
# local_numel = int(world_numel / data_parallel_world_size)
# local_start_index = \
# world_start_index + data_parallel_rank * local_numel
# local_end_index = \
# min(gbuf_numel, local_start_index + local_numel)
# try:
# world_view = gbuf[world_start_index:world_end_index]
# local_view = gbuf[local_start_index:local_end_index]
# except:
# # >>>
# from lutil import pax
# pax(0, {
# "world_start_index" : world_start_index,
# "world_end_index" : world_end_index,
# "local_start_index" : local_start_index,
# "local_end_index" : local_end_index,
# })
# # <<<
# try:
# torch.distributed._all_gather_base(
# world_view,
# local_view,
# group = data_parallel_group,
# )
# except:
# # >>>
# from lutil import pax
# pax(0, {
# "data_parallel_rank" : data_parallel_rank,
# # "local_sub_numel" : local_sub_numel,
# # "world_sub_numel" : world_sub_numel,
# "world_start_index" : world_start_index,
# "world_end_index" : world_end_index,
# "local_start_index" : local_start_index,
# "local_end_index" : local_end_index,
# "gbuf" : str(gbuf.shape),
# "world_view" : str(world_view.shape),
# "local_view" : str(local_view.shape),
# "local_sub_numel / ideal" : local_sub_numel,
# "local_sub_numel / act" :
# local_end_index - local_start_index,
# })
# # <<<
# # >>>
# # from lutil import pax, tp
# # pax(0, {
# # # "gbuf" : tp(gbuf),
# # "world range" : "%d, %d"%(world_start_index, world_end_index),
# # "local range" : "%d, %d"%(local_start_index, local_end_index),
# # "world_view" : tp(world_view),
# # "local_view" : tp(local_view),
# # "gbuf view" : tp(gbuf[world_start_index:world_end_index]),
# # })
# # <<<
# # >>>
# for model_index, dtype, gbuf in gbuf_items:
# if self.has_nan_debug(gbuf):
# raise Exception("hi.")
# # from lutil import pax, tp
# # pax(0, {
# # "gbuf_items" : [ (a, b, tp(c)) for a, b, c in gbuf_items ],
# # })
# # <<<
# # Each model param now contains its updated values in its
# # '.main_grad' field.
# for model in self.models:
# for dtype, param_map in model._grad_buffer_param_index_map.items():
# for param in param_map:
# param.detach().copy_(param.main_grad)
# # >>>
# if self.has_nan_debug(param):
# raise Exception("wha?")
# # <<<
# timers('backward-params-all-gather').stop()
# def gather_model_params(self, *args):
# # >>>
# # return
# # <<<
# # self.gather_model_params_0(*args)
# self.gather_model_params_1(*args)
# # self.gather_model_params_2(*args)
# # ~~~
# # self.debug_model(0, "after / gather_model_params", 0)
# <<<
def _collect_main_grad_data_for_unscaling(self):
return [ g.data for g in self.get_main_grads() ]
...