Commit ed90a1b2 authored by Lawrence McAfee

mem savings factor training; need to check loss.

parent 06b9ebe0
@@ -733,7 +733,7 @@ def _add_distributed_args(parser):
                        'affects the encoder embedding.)')
     group.add_argument('--use-distributed-optimizer', action='store_true',
                        help='Use distributed optimizer.')
-    group.add_argument('--disrib-opt-comm-mem-savings', default=0., type=float,
+    group.add_argument('--distrib-opt-comm-mem-savings', default=0., type=float,
                        help='Trade-off memory savings & iteration time, for '
                        'the distributed optimizer\'s communication operations '
                        '(i.e., reduce/gather). This value ranges from 0.0 (default, '
...
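The new `--distrib-opt-comm-mem-savings` flag is consumed by the distributed optimizer's chunked grad-buffer views in the next diff (class DistributedOptimizer). As a reading aid only, here is a minimal sketch of how the factor maps to a communication chunk size; the standalone helper `chunk_numel_for` and the example buffer size are illustrative, while the 1 Mi-element floor (1024**2) and the linear interpolation come from `get_model_grad_buffer_dp_views_chunked()` below. With the default `0.0` the chunk equals the full per-rank view (one collective per grad buffer); with `1.0` the buffer is reduced/gathered in 1 Mi-element chunks, e.g. `--use-distributed-optimizer --distrib-opt-comm-mem-savings 1.0`.

```python
# Illustrative sketch only -- not part of the commit. Mirrors the chunk-size
# computation in get_model_grad_buffer_dp_views_chunked(); the helper name and
# the 512 Mi-element example are made up for this note.
def chunk_numel_for(mem_savings_factor, view_numel, chunk_numel_min=1024**2):
    # factor = 0.0 -> chunk == full per-rank view (fewest, largest collectives)
    # factor = 1.0 -> chunk == 1 Mi elements (most, smallest collectives)
    chunk_numel_max = view_numel
    return int(mem_savings_factor * chunk_numel_min
               + (1 - mem_savings_factor) * chunk_numel_max)

view_numel = 512 * 1024**2                                  # per-rank view size
assert chunk_numel_for(0.0, view_numel) == view_numel       # one big chunk
assert chunk_numel_for(1.0, view_numel) == 1024**2          # 1 Mi-element chunks
assert chunk_numel_for(0.5, view_numel) == (view_numel + 1024**2) // 2
```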
@@ -346,9 +346,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             gbuf_view_items.append((model_index, dtype, gbuf_views))
         return gbuf_view_items
-    def get_model_grad_buffer_dp_views_SUB(self, sub_view_numel):
+    def get_model_grad_buffer_dp_views_chunked(self, mem_savings_factor):
         gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        sub_view_items = []
+        chunk_view_items = []
         for model_index, dtype, gbuf_views in gbuf_view_items:
             # ** Sanity check. ** (should be unnecessary; see comment above)
@@ -356,65 +358,77 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             for view in gbuf_views:
                 assert view.nelement() == view_numel
-            for start_index in range(0, view_numel, sub_view_numel):
-                end_index = min(view_numel, start_index + sub_view_numel)
-                sub_views = [ t[start_index:end_index] for t in gbuf_views ]
-                sub_view_items.append((model_index, dtype, sub_views))
+            chunk_numel_min = 1024**2
+            chunk_numel_max = view_numel
+            # chunk_numel_min_log = math.log(chunk_numel_min)
+            # chunk_numel_max_log = math.log(chunk_numel_max)
+            # chunk_numel_log = (chunk_numel_min_log + chunk_numel_max_log) / 2
+            # chunk_numel = int(math.exp(chunk_numel_log))
+            chunk_numel = int(
+                mem_savings_factor * chunk_numel_min
+                + (1 - mem_savings_factor) * chunk_numel_max
+            )
             # >>>
             # from lutil import pax
             # pax(0, {
-            #     "gbuf_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in gbuf_view_items],
-            #     "sub_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in sub_view_items],
+            #     "view_numel" : view_numel,
+            #     "chunk_numel_min" : chunk_numel_min,
+            #     "chunk_numel_max" : chunk_numel_max,
+            #     "chunk_numel_min_log" : chunk_numel_min_log,
+            #     "chunk_numel_max_log" : chunk_numel_max_log,
+            #     "chunk_numel_log" : chunk_numel_log,
+            #     "chunk_numel" : chunk_numel,
+            #     "mem_savings_factor" : mem_savings_factor,
             # })
             # <<<
-        return sub_view_items
-    # def get_model_grad_buffers_SINGLE(self):
-    #     data_parallel_world_size = mpu.get_data_parallel_world_size()
-    #     # Grad buffers.
-    #     gbuf_items = []
-    #     for model_index, model in enumerate(self.models):
-    #         for dtype, gbuf in model._grad_buffers.items():
-    #             assert gbuf.numel_padded % data_parallel_world_size == 0
-    #             shard_size = int(gbuf.numel_padded / data_parallel_world_size)
-    #             gbuf_items.append((model_index, dtype, gbuf.data))
-    #     return gbuf_items
+            for start_index in range(0, view_numel, chunk_numel):
+                end_index = min(view_numel, start_index + chunk_numel)
+                chunk_views = [ t[start_index:end_index] for t in gbuf_views ]
+                chunk_view_items.append((model_index, dtype, chunk_views))
+        # >>>
+        # from lutil import pax
+        # pax(0, {
+        #     "gbuf_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in gbuf_view_items],
+        #     "chunk_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in chunk_view_items],
+        # })
+        # <<<
+        return chunk_view_items
     # <<<
     # >>>
-    def reduce_model_grads_0(self, args, timers):
-        '''Note: this is a different order of reduction, versus the non-
-        distributed optimizer, which reduces: 1) all grads, 2) embedding
-        grads.
-        '''
-        # All-reduce embedding grads.
-        timers('backward-embedding-all-reduce').start()
-        self.allreduce_embedding_grads(args)
-        timers('backward-embedding-all-reduce').stop()
-        # Reduce-scatter all grads.
-        timers('backward-params-all-reduce').start()
-        data_parallel_rank = mpu.get_data_parallel_rank()
-        data_parallel_world_size = mpu.get_data_parallel_world_size()
-        data_parallel_group = mpu.get_data_parallel_group()
-        gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        for model_index, dtype, gbuf_views in gbuf_view_items:
-            gbuf = self.models[model_index]._grad_buffers[dtype].data
-            gbuf /= data_parallel_world_size
-            torch.distributed.reduce_scatter(
-                gbuf_views[data_parallel_rank],
-                gbuf_views,
-                group = data_parallel_group,
-            )
-        timers('backward-params-all-reduce').stop()
-    def reduce_model_grads_1(self, args, timers):
+    # def reduce_model_grads_0(self, args, timers):
+    #     '''Note: this is a different order of reduction, versus the non-
+    #     distributed optimizer, which reduces: 1) all grads, 2) embedding
+    #     grads.
+    #     '''
+    #     # All-reduce embedding grads.
+    #     timers('backward-embedding-all-reduce').start()
+    #     self.allreduce_embedding_grads(args)
+    #     timers('backward-embedding-all-reduce').stop()
+    #     # Reduce-scatter all grads.
+    #     timers('backward-params-all-reduce').start()
+    #     data_parallel_rank = mpu.get_data_parallel_rank()
+    #     data_parallel_world_size = mpu.get_data_parallel_world_size()
+    #     data_parallel_group = mpu.get_data_parallel_group()
+    #     gbuf_view_items = self.get_model_grad_buffer_dp_views()
+    #     for model_index, dtype, gbuf_views in gbuf_view_items:
+    #         gbuf = self.models[model_index]._grad_buffers[dtype].data
+    #         gbuf /= data_parallel_world_size
+    #         torch.distributed.reduce_scatter(
+    #             gbuf_views[data_parallel_rank],
+    #             gbuf_views,
+    #             group = data_parallel_group,
+    #         )
+    #     timers('backward-params-all-reduce').stop()
+    # def reduce_model_grads_1(self, args, timers):
+    def reduce_model_grads(self, args, timers):
         '''Note: this is a different order of reduction, versus the non-
         distributed optimizer, which reduces: 1) all grads, 2) embedding
         grads.
@@ -425,14 +439,21 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         self.allreduce_embedding_grads(args)
         timers('backward-embedding-all-reduce').stop()
-        # Reduce-scatter all grads.
+        # Reduce-scatter setup.
         timers('backward-params-all-reduce').start()
         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_world_size = mpu.get_data_parallel_world_size()
         data_parallel_group = mpu.get_data_parallel_group()
-        sub_numel = 1 * 1048576
-        gbuf_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
+        mem_savings_factor = args.distrib_opt_comm_mem_savings
+        # Scale grad buffers by '1 / data_parallel_world_size'.
+        for model in self.models:
+            for dtype, gbuf in model._grad_buffers.items():
+                gbuf.data /= data_parallel_world_size
+        # Reduce scatter all grads.
+        gbuf_view_items = \
+            self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
         for model_index, dtype, gbuf_views in gbuf_view_items:
             # gbuf = self.models[model_index]._grad_buffers[dtype].data
             # gbuf /= data_parallel_world_size
@@ -442,39 +463,39 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 group = data_parallel_group,
             )
         timers('backward-params-all-reduce').stop()
-    def reduce_model_grads(self, *args):
-        # >>>
-        return
-        # <<<
-        # self.reduce_model_grads_0(*args)
-        self.reduce_model_grads_1(*args)
+    # def reduce_model_grads(self, *args):
+    #     # >>>
+    #     return
+    #     # <<<
+    #     # self.reduce_model_grads_0(*args)
+    #     self.reduce_model_grads_1(*args)
     # <<<
     # >>>
-    def gather_model_params_0(self, args, timers):
-        timers('backward-params-all-gather').start()
-        data_parallel_rank = mpu.get_data_parallel_rank()
-        data_parallel_group = mpu.get_data_parallel_group()
-        # All-gather updated main params.
-        gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        for model_index, dtype, gbuf_views in gbuf_view_items:
-            torch.distributed.all_gather(
-                gbuf_views,
-                gbuf_views[data_parallel_rank],
-                group = data_parallel_group,
-            )
-        # Each model param now contains its updated values in its
-        # '.main_grad' field.
-        for model in self.models:
-            for dtype, param_map in model._grad_buffer_param_index_map.items():
-                for param in param_map:
-                    param.detach().copy_(param.main_grad)
-        timers('backward-params-all-gather').stop()
+    # def gather_model_params_0(self, args, timers):
+    #     timers('backward-params-all-gather').start()
+    #     data_parallel_rank = mpu.get_data_parallel_rank()
+    #     data_parallel_group = mpu.get_data_parallel_group()
+    #     # All-gather updated main params.
+    #     gbuf_view_items = self.get_model_grad_buffer_dp_views()
+    #     for model_index, dtype, gbuf_views in gbuf_view_items:
+    #         torch.distributed.all_gather(
+    #             gbuf_views,
+    #             gbuf_views[data_parallel_rank],
+    #             group = data_parallel_group,
+    #         )
+    #     # Each model param now contains its updated values in its
+    #     # '.main_grad' field.
+    #     for model in self.models:
+    #         for dtype, param_map in model._grad_buffer_param_index_map.items():
+    #             for param in param_map:
+    #                 param.detach().copy_(param.main_grad)
+    #     timers('backward-params-all-gather').stop()
     # def gather_model_params_1(self, args, timers):
     #     timers('backward-params-all-gather').start()
@@ -518,12 +539,14 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    #                 param.detach().copy_(param.main_grad)
    #     timers('backward-params-all-gather').stop()
-    def gather_model_params_1(self, args, timers):
+    # def gather_model_params_1(self, args, timers):
+    def gather_model_params(self, args, timers):
         timers('backward-params-all-gather').start()
         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_group = mpu.get_data_parallel_group()
+        mem_savings_factor = args.distrib_opt_comm_mem_savings
         # All-gather updated main params.
         # - All grad buffer views are guaranteed to have the same num elements
@@ -533,8 +556,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # sub_numel = 1 * 1024
         # sub_numel = 1 * 131072
-        sub_numel = 1024 * 1048576
-        gbuf_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
+        # sub_numel = 1024 * 1048576
+        # gbuf_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
+        gbuf_view_items = \
+            self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
         for model_index, dtype, gbuf_views in gbuf_view_items:
             torch.distributed.all_gather(
                 gbuf_views,
@@ -671,16 +696,16 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    #     # <<<
    #     timers('backward-params-all-gather').stop()
-    def gather_model_params(self, *args):
-        # >>>
-        # return
-        # <<<
-        # self.gather_model_params_0(*args)
-        self.gather_model_params_1(*args)
-        # self.gather_model_params_2(*args)
-        # ~~~
-        # self.debug_model(0, "after / gather_model_params", 0)
+    # def gather_model_params(self, *args):
+    #     # >>>
+    #     # return
+    #     # <<<
+    #     # self.gather_model_params_0(*args)
+    #     self.gather_model_params_1(*args)
+    #     # self.gather_model_params_2(*args)
+    #     # ~~~
+    #     # self.debug_model(0, "after / gather_model_params", 0)
     # <<<
     def _collect_main_grad_data_for_unscaling(self):
...
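Taken together, the DistributedOptimizer changes above replace the single whole-buffer reduce-scatter and all-gather with a loop over fixed-size chunks of the per-rank views, with the grad buffers pre-scaled by `1 / data_parallel_world_size` so the summing reduce-scatter yields an average. The sketch below compresses that pattern into one illustrative function; it is not code from the commit. The `optimizer` and `mpu` parameters are stand-ins for the DistributedOptimizer instance and Megatron's parallel-state module, the optimizer step between the two loops is elided, and (per the commented-out `gather_model_params_0` above) the updated parameter values are read back from each param's `.main_grad` view of the same buffer.

```python
# Sketch of the chunked collective pattern in reduce_model_grads() /
# gather_model_params() above; `optimizer` and `mpu` are assumed stand-ins.
import torch

def chunked_reduce_and_gather(optimizer, mpu, mem_savings_factor):
    rank = mpu.get_data_parallel_rank()
    world_size = mpu.get_data_parallel_world_size()
    group = mpu.get_data_parallel_group()

    # Pre-scale whole grad buffers so the summing reduce-scatter averages them.
    for model in optimizer.models:
        for dtype, gbuf in model._grad_buffers.items():
            gbuf.data /= world_size

    chunks = optimizer.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)

    # Reduce-scatter chunk by chunk: rank r keeps the reduced chunk views[r].
    for model_index, dtype, views in chunks:
        torch.distributed.reduce_scatter(views[rank], views, group=group)

    # ... optimizer step on the local shard happens here (elided) ...

    # All-gather chunk by chunk so every rank sees the full updated buffer.
    for model_index, dtype, views in chunks:
        torch.distributed.all_gather(views, views[rank], group=group)
```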
@@ -322,61 +322,61 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         return found_inf_flag
     # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
-    @classmethod
-    def debug_base(cls, ITERATION, key, value):
-        from megatron import get_args
-        args = get_args()
-        my_rank = torch.distributed.get_rank()
-        DEBUG_ITERATION = ITERATION
-        if ITERATION != DEBUG_ITERATION:
-            return
-        for r in range(torch.distributed.get_world_size()):
-            if my_rank == r:
-                # prefix = " + "
-                prefix = ""
-                print("%sbr/%s; [r%d, i%d]; %s, %.12e" % (prefix, "fix " if args.use_distributed_optimizer else "main", my_rank, ITERATION, key, value))
-            torch.distributed.barrier()
-        torch.distributed.barrier()
-        # if my_rank == 0:
-        #     raise Exception("debug.")
-        # else:
-        #     exit(0)
-        exit(0)
-    def debug_model(self, ITERATION, key, use_grad):
-        use_grad = bool(use_grad)
-        tensors = [
-            (p.main_grad.float() if use_grad else p.float())
-            for m in self.models for p in m.parameters()
-        ]
-        count = sum(t.nelement() for t in tensors)
-        return self.debug_base(
-            ITERATION,
-            "model/%s, %s [count %d]" % (
-                "grad" if use_grad else "param",
-                key,
-                count,
-            ),
-            # sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
-            sum(torch.sum(torch.abs(t)) for t in tensors),
-        )
-    def debug_main(self, ITERATION, key, use_grad):
-        use_grad = bool(use_grad)
-        tensors = [
-            p.grad if use_grad else p
-            for g in self.optimizer.param_groups
-            for p in g["params"]
-        ]
-        tensors = [ t.float() for t in tensors ]
-        count = sum(t.nelement() for t in tensors)
-        return self.debug_base(
-            ITERATION,
-            "main/%s, %s [count %d]" % (
-                "grad" if use_grad else "param",
-                key,
-                count,
-            ),
-            sum(torch.sum(torch.abs(t)) for t in tensors),
-        )
+    # @classmethod
+    # def debug_base(cls, ITERATION, key, value):
+    #     from megatron import get_args
+    #     args = get_args()
+    #     my_rank = torch.distributed.get_rank()
+    #     DEBUG_ITERATION = ITERATION
+    #     if ITERATION != DEBUG_ITERATION:
+    #         return
+    #     for r in range(torch.distributed.get_world_size()):
+    #         if my_rank == r:
+    #             # prefix = " + "
+    #             prefix = ""
+    #             print("%sbr/%s; [r%d, i%d]; %s, %.12e" % (prefix, "fix " if args.use_distributed_optimizer else "main", my_rank, ITERATION, key, value))
+    #         torch.distributed.barrier()
+    #     torch.distributed.barrier()
+    #     # if my_rank == 0:
+    #     #     raise Exception("debug.")
+    #     # else:
+    #     #     exit(0)
+    #     exit(0)
+    # def debug_model(self, ITERATION, key, use_grad):
+    #     use_grad = bool(use_grad)
+    #     tensors = [
+    #         (p.main_grad.float() if use_grad else p.float())
+    #         for m in self.models for p in m.parameters()
+    #     ]
+    #     count = sum(t.nelement() for t in tensors)
+    #     return self.debug_base(
+    #         ITERATION,
+    #         "model/%s, %s [count %d]" % (
+    #             "grad" if use_grad else "param",
+    #             key,
+    #             count,
+    #         ),
+    #         # sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
+    #         sum(torch.sum(torch.abs(t)) for t in tensors),
+    #     )
+    # def debug_main(self, ITERATION, key, use_grad):
+    #     use_grad = bool(use_grad)
+    #     tensors = [
+    #         p.grad if use_grad else p
+    #         for g in self.optimizer.param_groups
+    #         for p in g["params"]
+    #     ]
+    #     tensors = [ t.float() for t in tensors ]
+    #     count = sum(t.nelement() for t in tensors)
+    #     return self.debug_base(
+    #         ITERATION,
+    #         "main/%s, %s [count %d]" % (
+    #             "grad" if use_grad else "param",
+    #             key,
+    #             count,
+    #         ),
+    #         sum(torch.sum(torch.abs(t)) for t in tensors),
+    #     )
     # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
     @torch.no_grad()
...