Commit ed90a1b2 authored by Lawrence McAfee

Mem savings factor for training; still need to check loss.

parent 06b9ebe0
......@@ -733,7 +733,7 @@ def _add_distributed_args(parser):
'affects the encoder embedding.)')
group.add_argument('--use-distributed-optimizer', action='store_true',
help='Use distributed optimizer.')
group.add_argument('--disrib-opt-comm-mem-savings', default=0., type=float,
group.add_argument('--distrib-opt-comm-mem-savings', default=0., type=float,
help='Trade off memory savings vs. iteration time for the '
'distributed optimizer\'s communication operations (i.e., '
'reduce/gather). This value ranges from 0.0 (default, '
......
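For reference, here is a minimal sketch of how this factor maps onto a communication chunk size, mirroring the interpolation introduced below in `get_model_grad_buffer_dp_views_chunked`. This is an editorial illustration, not part of the commit, and the helper name is hypothetical:

```python
# Hypothetical helper illustrating the commit's interpolation:
#   factor = 0.0 -> one chunk spanning the whole per-rank view (fastest, most memory)
#   factor = 1.0 -> ~1 Mi-element chunks (slowest, least communication workspace)
def chunk_numel_from_factor(mem_savings_factor: float, view_numel: int) -> int:
    chunk_numel_min = 1024 ** 2   # 1 Mi elements
    chunk_numel_max = view_numel  # the entire per-rank grad-buffer view
    return int(mem_savings_factor * chunk_numel_min
               + (1 - mem_savings_factor) * chunk_numel_max)
```

For example, a hypothetical 100M-element view with `--distrib-opt-comm-mem-savings 0.75` gives chunks of roughly 25.8M elements, i.e. four chunks per view.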
......@@ -346,9 +346,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
gbuf_view_items.append((model_index, dtype, gbuf_views))
return gbuf_view_items
def get_model_grad_buffer_dp_views_SUB(self, sub_view_numel):
def get_model_grad_buffer_dp_views_chunked(self, mem_savings_factor):
gbuf_view_items = self.get_model_grad_buffer_dp_views()
sub_view_items = []
chunk_view_items = []
for model_index, dtype, gbuf_views in gbuf_view_items:
# ** Sanity check. ** (should be unnecessary; see comment above)
......@@ -356,65 +358,77 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
for view in gbuf_views:
assert view.nelement() == view_numel
for start_index in range(0, view_numel, sub_view_numel):
end_index = min(view_numel, start_index + sub_view_numel)
sub_views = [ t[start_index:end_index] for t in gbuf_views ]
sub_view_items.append((model_index, dtype, sub_views))
chunk_numel_min = 1024**2
chunk_numel_max = view_numel
# chunk_numel_min_log = math.log(chunk_numel_min)
# chunk_numel_max_log = math.log(chunk_numel_max)
# chunk_numel_log = (chunk_numel_min_log + chunk_numel_max_log) / 2
# chunk_numel = int(math.exp(chunk_numel_log))
chunk_numel = int(
mem_savings_factor * chunk_numel_min
+ (1 - mem_savings_factor) * chunk_numel_max
)
# >>>
# from lutil import pax
# pax(0, {
# "view_numel" : view_numel,
# "chunk_numel_min" : chunk_numel_min,
# "chunk_numel_max" : chunk_numel_max,
# "chunk_numel_min_log" : chunk_numel_min_log,
# "chunk_numel_max_log" : chunk_numel_max_log,
# "chunk_numel_log" : chunk_numel_log,
# "chunk_numel" : chunk_numel,
# "mem_savings_factor" : mem_savings_factor,
# })
# <<<
for start_index in range(0, view_numel, chunk_numel):
end_index = min(view_numel, start_index + chunk_numel)
chunk_views = [ t[start_index:end_index] for t in gbuf_views ]
chunk_view_items.append((model_index, dtype, chunk_views))
# >>>
# from lutil import pax
# pax(0, {
# "gbuf_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in gbuf_view_items],
# "sub_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in sub_view_items],
# "chunk_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in chunk_view_items],
# })
# <<<
return sub_view_items
# def get_model_grad_buffers_SINGLE(self):
# data_parallel_world_size = mpu.get_data_parallel_world_size()
# # Grad buffers.
# gbuf_items = []
# for model_index, model in enumerate(self.models):
# for dtype, gbuf in model._grad_buffers.items():
# assert gbuf.numel_padded % data_parallel_world_size == 0
# shard_size = int(gbuf.numel_padded / data_parallel_world_size)
# gbuf_items.append((model_index, dtype, gbuf.data))
# return gbuf_items
return chunk_view_items
# <<<
# >>>
def reduce_model_grads_0(self, args, timers):
'''Note: this is a different order of reduction, versus the non-
distributed optimizer, which reduces: 1) all grads, 2) embedding
grads.
'''
# All-reduce embedding grads.
timers('backward-embedding-all-reduce').start()
self.allreduce_embedding_grads(args)
timers('backward-embedding-all-reduce').stop()
# Reduce-scatter all grads.
timers('backward-params-all-reduce').start()
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
data_parallel_group = mpu.get_data_parallel_group()
# def reduce_model_grads_0(self, args, timers):
# '''Note: this is a different order of reduction, versus the non-
# distributed optimizer, which reduces: 1) all grads, 2) embedding
# grads.
# '''
# # All-reduce embedding grads.
# timers('backward-embedding-all-reduce').start()
# self.allreduce_embedding_grads(args)
# timers('backward-embedding-all-reduce').stop()
# # Reduce-scatter all grads.
# timers('backward-params-all-reduce').start()
# data_parallel_rank = mpu.get_data_parallel_rank()
# data_parallel_world_size = mpu.get_data_parallel_world_size()
# data_parallel_group = mpu.get_data_parallel_group()
gbuf_view_items = self.get_model_grad_buffer_dp_views()
for model_index, dtype, gbuf_views in gbuf_view_items:
gbuf = self.models[model_index]._grad_buffers[dtype].data
gbuf /= data_parallel_world_size
torch.distributed.reduce_scatter(
gbuf_views[data_parallel_rank],
gbuf_views,
group = data_parallel_group,
)
timers('backward-params-all-reduce').stop()
def reduce_model_grads_1(self, args, timers):
# gbuf_view_items = self.get_model_grad_buffer_dp_views()
# for model_index, dtype, gbuf_views in gbuf_view_items:
# gbuf = self.models[model_index]._grad_buffers[dtype].data
# gbuf /= data_parallel_world_size
# torch.distributed.reduce_scatter(
# gbuf_views[data_parallel_rank],
# gbuf_views,
# group = data_parallel_group,
# )
# timers('backward-params-all-reduce').stop()
# def reduce_model_grads_1(self, args, timers):
def reduce_model_grads(self, args, timers):
'''Note: this is a different order of reduction, versus the non-
distributed optimizer, which reduces: 1) all grads, 2) embedding
grads.
......@@ -425,14 +439,21 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self.allreduce_embedding_grads(args)
timers('backward-embedding-all-reduce').stop()
# Reduce-scatter all grads.
# Reduce-scatter setup.
timers('backward-params-all-reduce').start()
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
data_parallel_group = mpu.get_data_parallel_group()
mem_savings_factor = args.distrib_opt_comm_mem_savings
sub_numel = 1 * 1048576
gbuf_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
# Scale grad buffers by '1 / data_parallel_world_size'.
for model in self.models:
for dtype, gbuf in model._grad_buffers.items():
gbuf.data /= data_parallel_world_size
# Reduce-scatter all grads.
gbuf_view_items = \
self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
for model_index, dtype, gbuf_views in gbuf_view_items:
# gbuf = self.models[model_index]._grad_buffers[dtype].data
# gbuf /= data_parallel_world_size
......@@ -442,39 +463,39 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
group = data_parallel_group,
)
timers('backward-params-all-reduce').stop()
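Because the grad buffers are pre-scaled by `1 / data_parallel_world_size`, the SUM reduce-scatter above yields an average. As context, here is a self-contained sketch of the chunked reduce-scatter pattern; it is an editorial illustration with a hypothetical helper name, assuming an initialized `torch.distributed` process group and a flat buffer padded to a multiple of the data-parallel world size:

```python
import torch
import torch.distributed as dist

def chunked_reduce_scatter(buf: torch.Tensor, chunk_numel: int, group=None):
    # Hypothetical helper: reduce-scatter a flat, world-size-padded buffer in chunks.
    rank = dist.get_rank(group=group)
    world = dist.get_world_size(group=group)
    shard_numel = buf.numel() // world                 # per-rank view size
    views = [buf[r * shard_numel:(r + 1) * shard_numel] for r in range(world)]
    for start in range(0, shard_numel, chunk_numel):
        end = min(shard_numel, start + chunk_numel)
        chunks = [v[start:end] for v in views]         # same-size slice of every view
        # This rank receives the reduced (summed) result for its own slice.
        dist.reduce_scatter(chunks[rank], chunks, group=group)
```

A smaller `chunk_numel` (a higher `--distrib-opt-comm-mem-savings`) lowers the peak communication workspace at the cost of more collective calls per iteration.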
def reduce_model_grads(self, *args):
# >>>
return
# <<<
# self.reduce_model_grads_0(*args)
self.reduce_model_grads_1(*args)
# def reduce_model_grads(self, *args):
# # >>>
# return
# # <<<
# # self.reduce_model_grads_0(*args)
# self.reduce_model_grads_1(*args)
# <<<
# >>>
def gather_model_params_0(self, args, timers):
# def gather_model_params_0(self, args, timers):
timers('backward-params-all-gather').start()
# timers('backward-params-all-gather').start()
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_group = mpu.get_data_parallel_group()
# data_parallel_rank = mpu.get_data_parallel_rank()
# data_parallel_group = mpu.get_data_parallel_group()
# All-gather updated main params.
gbuf_view_items = self.get_model_grad_buffer_dp_views()
for model_index, dtype, gbuf_views in gbuf_view_items:
torch.distributed.all_gather(
gbuf_views,
gbuf_views[data_parallel_rank],
group = data_parallel_group,
)
# # All-gather updated main params.
# gbuf_view_items = self.get_model_grad_buffer_dp_views()
# for model_index, dtype, gbuf_views in gbuf_view_items:
# torch.distributed.all_gather(
# gbuf_views,
# gbuf_views[data_parallel_rank],
# group = data_parallel_group,
# )
# Each model param now contains its updated values in its
# '.main_grad' field.
for model in self.models:
for dtype, param_map in model._grad_buffer_param_index_map.items():
for param in param_map:
param.detach().copy_(param.main_grad)
# # Each model param now contains its updated values in its
# # '.main_grad' field.
# for model in self.models:
# for dtype, param_map in model._grad_buffer_param_index_map.items():
# for param in param_map:
# param.detach().copy_(param.main_grad)
timers('backward-params-all-gather').stop()
# timers('backward-params-all-gather').stop()
# def gather_model_params_1(self, args, timers):
# timers('backward-params-all-gather').start()
......@@ -518,12 +539,14 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# param.detach().copy_(param.main_grad)
# timers('backward-params-all-gather').stop()
def gather_model_params_1(self, args, timers):
# def gather_model_params_1(self, args, timers):
def gather_model_params(self, args, timers):
timers('backward-params-all-gather').start()
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_group = mpu.get_data_parallel_group()
mem_savings_factor = args.distrib_opt_comm_mem_savings
# All-gather updated main params.
# - All grad buffer views are guaranteed to have the same num elements
......@@ -533,8 +556,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# sub_numel = 1 * 1024
# sub_numel = 1 * 131072
sub_numel = 1024 * 1048576
gbuf_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
# sub_numel = 1024 * 1048576
# gbuf_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
gbuf_view_items = \
self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
for model_index, dtype, gbuf_views in gbuf_view_items:
torch.distributed.all_gather(
gbuf_views,
......@@ -671,16 +696,16 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# # <<<
# timers('backward-params-all-gather').stop()
def gather_model_params(self, *args):
# >>>
# return
# <<<
# self.gather_model_params_0(*args)
self.gather_model_params_1(*args)
# self.gather_model_params_2(*args)
# def gather_model_params(self, *args):
# # >>>
# # return
# # <<<
# # self.gather_model_params_0(*args)
# self.gather_model_params_1(*args)
# # self.gather_model_params_2(*args)
# ~~~
# self.debug_model(0, "after / gather_model_params", 0)
# # ~~~
# # self.debug_model(0, "after / gather_model_params", 0)
# <<<
def _collect_main_grad_data_for_unscaling(self):
......
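The all-gather path in `gather_model_params` applies the same chunking in the opposite direction: each rank broadcasts its updated shard chunk by chunk. A matching editorial sketch (hypothetical helper name, same assumptions as above):

```python
import torch
import torch.distributed as dist

def chunked_all_gather(buf: torch.Tensor, chunk_numel: int, group=None):
    # Hypothetical helper: all-gather a flat, world-size-padded buffer in chunks.
    rank = dist.get_rank(group=group)
    world = dist.get_world_size(group=group)
    shard_numel = buf.numel() // world
    views = [buf[r * shard_numel:(r + 1) * shard_numel] for r in range(world)]
    for start in range(0, shard_numel, chunk_numel):
        end = min(shard_numel, start + chunk_numel)
        chunks = [v[start:end] for v in views]
        # Each rank contributes its own updated slice and receives everyone else's.
        dist.all_gather(chunks, chunks[rank], group=group)
```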
......@@ -322,61 +322,61 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
return found_inf_flag
# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
@classmethod
def debug_base(cls, ITERATION, key, value):
from megatron import get_args
args = get_args()
my_rank = torch.distributed.get_rank()
DEBUG_ITERATION = ITERATION
if ITERATION != DEBUG_ITERATION:
return
for r in range(torch.distributed.get_world_size()):
if my_rank == r:
# prefix = " + "
prefix = ""
print("%sbr/%s; [r%d, i%d]; %s, %.12e" % (prefix, "fix " if args.use_distributed_optimizer else "main", my_rank, ITERATION, key, value))
torch.distributed.barrier()
torch.distributed.barrier()
# if my_rank == 0:
# raise Exception("debug.")
# else:
# exit(0)
exit(0)
def debug_model(self, ITERATION, key, use_grad):
use_grad = bool(use_grad)
tensors = [
(p.main_grad.float() if use_grad else p.float())
for m in self.models for p in m.parameters()
]
count = sum(t.nelement() for t in tensors)
return self.debug_base(
ITERATION,
"model/%s, %s [count %d]" % (
"grad" if use_grad else "param",
key,
count,
),
# sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
sum(torch.sum(torch.abs(t)) for t in tensors),
)
def debug_main(self, ITERATION, key, use_grad):
use_grad = bool(use_grad)
tensors = [
p.grad if use_grad else p
for g in self.optimizer.param_groups
for p in g["params"]
]
tensors = [ t.float() for t in tensors ]
count = sum(t.nelement() for t in tensors)
return self.debug_base(
ITERATION,
"main/%s, %s [count %d]" % (
"grad" if use_grad else "param",
key,
count,
),
sum(torch.sum(torch.abs(t)) for t in tensors),
)
# @classmethod
# def debug_base(cls, ITERATION, key, value):
# from megatron import get_args
# args = get_args()
# my_rank = torch.distributed.get_rank()
# DEBUG_ITERATION = ITERATION
# if ITERATION != DEBUG_ITERATION:
# return
# for r in range(torch.distributed.get_world_size()):
# if my_rank == r:
# # prefix = " + "
# prefix = ""
# print("%sbr/%s; [r%d, i%d]; %s, %.12e" % (prefix, "fix " if args.use_distributed_optimizer else "main", my_rank, ITERATION, key, value))
# torch.distributed.barrier()
# torch.distributed.barrier()
# # if my_rank == 0:
# # raise Exception("debug.")
# # else:
# # exit(0)
# exit(0)
# def debug_model(self, ITERATION, key, use_grad):
# use_grad = bool(use_grad)
# tensors = [
# (p.main_grad.float() if use_grad else p.float())
# for m in self.models for p in m.parameters()
# ]
# count = sum(t.nelement() for t in tensors)
# return self.debug_base(
# ITERATION,
# "model/%s, %s [count %d]" % (
# "grad" if use_grad else "param",
# key,
# count,
# ),
# # sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
# sum(torch.sum(torch.abs(t)) for t in tensors),
# )
# def debug_main(self, ITERATION, key, use_grad):
# use_grad = bool(use_grad)
# tensors = [
# p.grad if use_grad else p
# for g in self.optimizer.param_groups
# for p in g["params"]
# ]
# tensors = [ t.float() for t in tensors ]
# count = sum(t.nelement() for t in tensors)
# return self.debug_base(
# ITERATION,
# "main/%s, %s [count %d]" % (
# "grad" if use_grad else "param",
# key,
# count,
# ),
# sum(torch.sum(torch.abs(t)) for t in tensors),
# )
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
@torch.no_grad()
......