Commit 867105c2 authored by Lawrence McAfee

getting started with no-grad-duplication implementation.

parent c52c9dbf
......@@ -144,6 +144,7 @@ def get_megatron_optimizer(model,
args.log_num_zeros_in_grad,
params_have_main_grad,
args.use_contiguous_buffers_in_local_ddp,
args.fp16,
args.bf16,
grad_scaler,
model)
......
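The distrib_optimizer.py hunks below rework how DistributedOptimizer builds its per-group shards. As orientation, here is a minimal standalone sketch (plain Python with hypothetical names, not the Megatron API) of the bookkeeping that the commented-out reference version of get_optimizer_group_shards performs: each optimizer param group accumulates a contiguous shard, and each parameter's sub-range is recorded as a Shard(start, end) offset into it.

class Shard:
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start

def build_group_shards(param_groups, param_shard_sizes):
    # param_groups: list of {"params": [...]}; param_shard_sizes: param -> int
    # (number of elements of that param owned by this data-parallel rank).
    param_group_map = {param: group_index
                       for group_index, group in enumerate(param_groups)
                       for param in group["params"]}
    group_shards = [{"size": 0, "param_map": {}, "orig_group": group}
                    for group in param_groups]
    for param, size in param_shard_sizes.items():
        group_shard = group_shards[param_group_map[param]]
        group_shard["param_map"][param] = Shard(group_shard["size"],
                                                group_shard["size"] + size)
        group_shard["size"] += size
    # Squeeze zero-size group shards, as in the diff.
    return [g for g in group_shards if g["size"] > 0]

The WIP version in this commit moves away from this single-shard-per-group layout toward a per-group "params" list, which is what allocate_or_view_main_param_shards later iterates over.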
......@@ -27,6 +27,9 @@ from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
# >>>
from lutil import pax, tp, print_seq
# <<<
class Shard:
def __init__(self, start, end):
......@@ -130,6 +133,42 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
param_gbuf_map[param] = (model_index, dtype)
return param_gbuf_map
# >>>
# @classmethod
# def get_optimizer_group_shards(cls, param_groups, model_gbuf_shards):
# num_groups = len(param_groups)
# # Param group map.
# param_group_map = {}
# for group_index, group in enumerate(param_groups):
# for param in group["params"]:
# assert param.requires_grad
# param_group_map[param] = group_index
# # Optimizer group shards.
# group_shards = [ {"size": 0, "param_map": {}} for _ in param_groups ]
# for model_gbuf_shard_map in model_gbuf_shards:
# for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
# for param in gbuf_shard_map["param_map"]:
# group_index = param_group_map[param]
# group_shard = group_shards[group_index]
# param_size = gbuf_shard_map["param_map"][param]["param"].size
# param_group_start = group_shard["size"]
# param_group_end = param_group_start + param_size
# param_group_shard = Shard(param_group_start, param_group_end)
# group_shard["size"] += param_size
# group_shard["param_map"][param] = param_group_shard
# # Squeeze zero-size group shards.
# for group_index, group_shard in enumerate(group_shards):
# group_shard["orig_group"] = param_groups[group_index]
# group_shards = [ g for g in group_shards if g["size"] > 0 ]
# return group_shards
@classmethod
def get_optimizer_group_shards(cls, param_groups, model_gbuf_shards):
......@@ -143,81 +182,165 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
param_group_map[param] = group_index
# Optimizer group shards.
group_shards = [ {"size": 0, "param_map": {}} for _ in param_groups ]
# >>>
# group_shards = [ {"size": 0, "param_map": {}} for _ in param_groups ]
group_shards = [ {"params": []} for _ in param_groups ]
# group_shards = [ [] for _ in param_groups ]
# <<<
for model_gbuf_shard_map in model_gbuf_shards:
for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
for param in gbuf_shard_map["param_map"]:
group_index = param_group_map[param]
group_shard = group_shards[group_index]
param_size = gbuf_shard_map["param_map"][param]["param"].size
param_group_start = group_shard["size"]
param_group_end = param_group_start + param_size
param_group_shard = Shard(param_group_start, param_group_end)
group_shard["size"] += param_size
group_shard["param_map"][param] = param_group_shard
group_shard["params"].append(param)
# Squeeze zero-size group shards.
for group_index, group_shard in enumerate(group_shards):
group_shard["orig_group"] = param_groups[group_index]
group_shards = [ g for g in group_shards if g["size"] > 0 ]
group_shards = [ g for g in group_shards if len(g["params"]) > 0 ]
return group_shards
# >>>
# print_seq("group shards / len = %s." %
# ", ".join(str(len(s["params"])) for s in group_shards))
# <<<
@classmethod
def allocate_main_param_shards(cls, opt_group_shards):
return group_shards
# <<<
# Allocator method.
allocate_shard = lambda shard_size, dtype : torch.empty(
(shard_size,),
dtype = dtype,
device = torch.cuda.current_device(),
requires_grad = True)
# >>>
# @classmethod
# def allocate_main_param_shards(cls, opt_group_shards):
# Allocate each group's param/grad shard.
for group_index, group_shard in enumerate(opt_group_shards):
# # Allocator method.
# allocate_shard = lambda shard_size, dtype : torch.empty(
# (shard_size,),
# dtype = dtype,
# device = torch.cuda.current_device(),
# requires_grad = True)
group_size = group_shard["size"]
assert group_size != 0, "temporary check ... remove me."
# # Allocate each group's param/grad shard.
# for group_index, group_shard in enumerate(opt_group_shards):
# Allocate shard.
main_param = allocate_shard(group_size, torch.float)
main_param.grad = allocate_shard(group_size, torch.float)
mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
# group_size = group_shard["size"]
# assert group_size != 0, "temporary check ... remove me."
# Update group's param.
group_shard["orig_group"]["params"] = [ main_param ]
# # Allocate shard.
# main_param = allocate_shard(group_size, torch.float)
# main_param.grad = allocate_shard(group_size, torch.float)
# mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
# # Update group's param.
# group_shard["orig_group"]["params"] = [ main_param ]
@classmethod
def get_main_grad_views_for_grad_norm(cls, opt_group_shards, optimizer):
grad_views = []
for group_index, opt_group_shard in enumerate(opt_group_shards):
opt_grad = optimizer.param_groups[group_index]["params"][0].grad
for param, shard in opt_group_shard["param_map"].items():
if param_is_not_shared(param) and \
param_is_not_tensor_parallel_duplicate(param):
# def allocate_main_params(cls, opt_group_shards):
def allocate_or_view_main_param_shards(cls,
model_gbuf_shards,
param_gbuf_map,
opt_group_shards):
# # Allocator method.
# allocate_shard = lambda shard_size, dtype : torch.empty(
# (shard_size,),
# dtype = dtype,
# device = torch.cuda.current_device(),
# requires_grad = True)
# Allocate each group's param/grad shard.
for group_index, group_shard in enumerate(opt_group_shards):
# group_size = group_shard["size"]
# assert group_size != 0, "temporary check ... remove me."
# # Allocate shard.
# main_param = allocate_shard(group_size, torch.float)
# main_param.grad = allocate_shard(group_size, torch.float)
# mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
# # Update group's param.
# group_shard["orig_group"]["params"] = [ main_param ]
group_main_params = []
group_shard["orig_group"]["params"] = group_main_params
for param in group_shard["params"]:
model_index, dtype = param_gbuf_map[param]
gbuf_shard = model_gbuf_shards[model_index][dtype]
param_shard = gbuf_shard["param_map"][param]["param"]
pax(0, {
"model_index" : model_index,
"dtype" : dtype,
"gbuf_shard" : gbuf_shard,
"param_shard" : param_shard,
})
# fp16, bf16 params.
if param.type() in ['torch.cuda.HalfTensor',
'torch.cuda.BFloat16Tensor']:
# Allocate/copy main param/grad.
main_param = param.detach()[param_shard.start:param_shard.end].clone().float()
if accumulate_allreduce_grads_in_fp32:
main_param.grad = param.main_grad[param_shard.start:param_shard.end]
else:
main_param.grad = param.main_grad.detach()[param_shard.start:param_shard.end].clone().float()
# Copy tensor model parallel attributes.
mpu.copy_tensor_model_parallel_attributes(main_param, param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
# fp32 params.
elif param.type() == 'torch.cuda.FloatTensor':
main_param = param
main_param.grad = param.main_grad
else:
raise TypeError('Wrapped parameters must be one of '
'torch.cuda.FloatTensor, '
'torch.cuda.HalfTensor, or '
'torch.cuda.BFloat16Tensor. '
'Received {}'.format(param.type()))
# Add to group.
group_main_params.append(main_param)
# <<<
# >>>
# @classmethod
# def get_main_grad_views_for_grad_norm(cls, opt_group_shards, optimizer):
# grad_views = []
# for group_index, opt_group_shard in enumerate(opt_group_shards):
# opt_grad = optimizer.param_groups[group_index]["params"][0].grad
# for param, shard in opt_group_shard["param_map"].items():
# if param_is_not_shared(param) and \
# param_is_not_tensor_parallel_duplicate(param):
grad_view = opt_grad[shard.start:shard.end]
grad_views.append(grad_view)
# grad_view = opt_grad[shard.start:shard.end]
# grad_views.append(grad_view)
return grad_views
# return grad_views
# <<<
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler, models):
fp16, bf16, grad_scaler, models):
super().__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler, models)
fp16, bf16, grad_scaler, models)
# Verify that contiguous buffers are being used
# - Note: this should already be checked in arguments.py
args = get_args()
assert args.use_contiguous_buffers_in_local_ddp
# >>>
# args = get_args()
# assert args.use_contiguous_buffers_in_local_ddp
assert use_contiguous_buffers_in_local_ddp
# <<<
# Model grad buffer shards.
self.model_gbuf_shards = []
......@@ -231,7 +354,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
self.model_gbuf_shards)
# Allocate main param shards.
self.allocate_main_param_shards(self.opt_group_shards)
self.allocate_or_view_main_param_shards(self.model_gbuf_shards,
self.param_gbuf_map,
self.opt_group_shards)
# Update optimizer groups.
# - Also, leverage state_dict() and load_state_dict() to
......@@ -243,10 +368,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
# Initialize main params.
self._copy_model_params_to_main_params()
# Params for grad norm.
self.main_grad_views_for_grad_norm = self.get_main_grad_views_for_grad_norm(
self.opt_group_shards,
self.optimizer)
# >>>
# # Params for grad norm.
# self.main_grad_views_for_grad_norm = self.get_main_grad_views_for_grad_norm(
# self.opt_group_shards,
# self.optimizer)
# <<<
def get_model_parallel_group(self):
......@@ -263,8 +390,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
return self.get_main_param(group_index).grad
def get_main_grads_for_grad_norm(self):
return self.main_grad_views_for_grad_norm
# >>>
# def get_main_grads_for_grad_norm(self):
# return self.main_grad_views_for_grad_norm
# <<<
def state_dict(self):
......
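The central change in allocate_or_view_main_param_shards above is that, when gradients are already accumulated in an fp32 main_grad buffer, the optimizer's main grad becomes a slice view of that buffer instead of a second allocation, which is where the "no-grad-duplication" of the commit message comes from. A minimal plain-PyTorch sketch of that idea (tensor names and the shard range are illustrative, not Megatron code):

import torch

# fp16 model param with an fp32 grad-accumulation buffer attached, as in
# Megatron's local DDP path with accumulate_allreduce_grads_in_fp32.
model_param = torch.nn.Parameter(torch.randn(16).half(), requires_grad=False)
model_param.main_grad = torch.randn(16, dtype=torch.float)

start, end = 4, 12  # this rank's shard of the parameter

# Main param: an fp32 copy of this rank's slice (the param copy is still made).
main_param = model_param.detach()[start:end].clone().float()

# Main grad: a *view* into the existing fp32 buffer -- no duplicate grad storage.
main_param.grad = model_param.main_grad[start:end]

# Writes through the view land in the shared buffer, and vice versa.
main_param.grad.zero_()
assert model_param.main_grad[start:end].abs().sum() == 0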
......@@ -277,7 +277,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler,
fp16, bf16, grad_scaler,
models):
super().__init__(
......@@ -285,6 +285,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
models)
self.fp16 = fp16
self.bf16 = bf16
self.grad_scaler = grad_scaler
......@@ -481,12 +482,12 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler, models):
fp16, bf16, grad_scaler, models):
super().__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler, models)
fp16, bf16, grad_scaler, models)
# ======================
# main parameter stuff
......
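The optimizer.py hunks above thread a new fp16 flag, alongside the existing bf16 flag, from get_megatron_optimizer (args.fp16 in the first hunk) through Float16OptimizerWithFloat16Params and into MixedPrecisionOptimizer, which now stores it as self.fp16. A sketch of the resulting constructor shape, with Megatron's MegatronOptimizer base class replaced by a stub so the snippet stands alone; only the fp16/bf16/grad_scaler handling is taken from the diff:

class _MegatronOptimizerStub:
    # Stand-in for Megatron's MegatronOptimizer base class.
    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 models):
        self.optimizer = optimizer
        self.models = models

class MixedPrecisionOptimizer(_MegatronOptimizerStub):
    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler, models):
        super().__init__(optimizer, clip_grad, log_num_zeros_in_grad,
                         params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                         models)
        self.fp16 = fp16          # new in this commit
        self.bf16 = bf16
        self.grad_scaler = grad_scaler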