Commit 867105c2 authored by Lawrence McAfee

getting started with no-grad-duplication implementation.

parent c52c9dbf
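
In short: args.fp16 is threaded through the mixed-precision optimizer constructors, optimizer group shards now track their parameters directly instead of flattened size/offset maps, and the distributed optimizer's fp32 main params take their .grad from each parameter's existing main_grad buffer rather than allocating a second contiguous fp32 grad shard per group. Below is a minimal, standalone sketch of that view-based "no grad duplication" idea; it is illustrative only (Shard is re-declared here, and make_main_param and the toy tensors are not part of the commit):

import torch

class Shard:
    """Half-open [start, end) index range into a flattened parameter."""
    def __init__(self, start, end):
        self.start = start
        self.end = end

def make_main_param(param, shard, grads_in_fp32=True):
    # fp32 working copy of this rank's shard of the (fp16/bf16) model param.
    main_param = param.detach()[shard.start:shard.end].clone().float()
    if grads_in_fp32:
        # Grad is a *view* into the existing fp32 buffer: no duplicate memory.
        main_param.grad = param.main_grad[shard.start:shard.end]
    else:
        # Fallback: materialize an fp32 copy of the grad shard.
        main_param.grad = param.main_grad.detach()[shard.start:shard.end].clone().float()
    return main_param

# Toy usage: an 8-element fp16 "parameter" with an attached fp32 grad buffer.
param = torch.zeros(8, dtype=torch.float16)
param.main_grad = torch.ones(8, dtype=torch.float32)
main_param = make_main_param(param, Shard(2, 6))
main_param.grad += 1.0            # writes through to param.main_grad[2:6]
print(param.main_grad)            # tensor([1., 1., 2., 2., 2., 2., 1., 1.])
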
@@ -144,6 +144,7 @@ def get_megatron_optimizer(model,
         args.log_num_zeros_in_grad,
         params_have_main_grad,
         args.use_contiguous_buffers_in_local_ddp,
+        args.fp16,
         args.bf16,
         grad_scaler,
         model)
...
@@ -27,6 +27,9 @@ from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
 from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper

+# >>>
+from lutil import pax, tp, print_seq
+# <<<

 class Shard:
     def __init__(self, start, end):
@@ -130,6 +133,42 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 param_gbuf_map[param] = (model_index, dtype)

         return param_gbuf_map

+    # >>>
+    # @classmethod
+    # def get_optimizer_group_shards(cls, param_groups, model_gbuf_shards):
+
+    #     num_groups = len(param_groups)
+
+    #     # Param group map.
+    #     param_group_map = {}
+    #     for group_index, group in enumerate(param_groups):
+    #         for param in group["params"]:
+    #             assert param.requires_grad
+    #             param_group_map[param] = group_index
+
+    #     # Optimizer group shards.
+    #     group_shards = [ {"size": 0, "param_map": {}} for _ in param_groups ]
+    #     for model_gbuf_shard_map in model_gbuf_shards:
+    #         for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
+    #             for param in gbuf_shard_map["param_map"]:
+    #                 group_index = param_group_map[param]
+    #                 group_shard = group_shards[group_index]
+    #                 param_size = gbuf_shard_map["param_map"][param]["param"].size
+    #                 param_group_start = group_shard["size"]
+    #                 param_group_end = param_group_start + param_size
+    #                 param_group_shard = Shard(param_group_start, param_group_end)
+    #                 group_shard["size"] += param_size
+    #                 group_shard["param_map"][param] = param_group_shard
+
+    #     # Squeeze zero-size group shards.
+    #     for group_index, group_shard in enumerate(group_shards):
+    #         group_shard["orig_group"] = param_groups[group_index]
+    #     group_shards = [ g for g in group_shards if g["size"] > 0 ]
+
+    #     return group_shards
     @classmethod
     def get_optimizer_group_shards(cls, param_groups, model_gbuf_shards):
@@ -143,81 +182,165 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 param_group_map[param] = group_index

         # Optimizer group shards.
-        group_shards = [ {"size": 0, "param_map": {}} for _ in param_groups ]
+        # >>>
+        # group_shards = [ {"size": 0, "param_map": {}} for _ in param_groups ]
+        group_shards = [ {"params": []} for _ in param_groups ]
+        # group_shards = [ [] for _ in param_groups ]
+        # <<<
         for model_gbuf_shard_map in model_gbuf_shards:
             for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
                 for param in gbuf_shard_map["param_map"]:
                     group_index = param_group_map[param]
                     group_shard = group_shards[group_index]
-                    param_size = gbuf_shard_map["param_map"][param]["param"].size
-                    param_group_start = group_shard["size"]
-                    param_group_end = param_group_start + param_size
-                    param_group_shard = Shard(param_group_start, param_group_end)
-                    group_shard["size"] += param_size
-                    group_shard["param_map"][param] = param_group_shard
+                    group_shard["params"].append(param)

         # Squeeze zero-size group shards.
         for group_index, group_shard in enumerate(group_shards):
             group_shard["orig_group"] = param_groups[group_index]
-        group_shards = [ g for g in group_shards if g["size"] > 0 ]
+        group_shards = [ g for g in group_shards if len(g["params"]) > 0 ]
+
+        # >>>
+        # print_seq("group shards / len = %s." %
+        #           ", ".join(str(len(s["params"])) for s in group_shards))
+        # <<<

         return group_shards
+    # <<<

+    # >>>
+    # @classmethod
+    # def allocate_main_param_shards(cls, opt_group_shards):
+
+    #     # Allocator method.
+    #     allocate_shard = lambda shard_size, dtype : torch.empty(
+    #         (shard_size,),
+    #         dtype = dtype,
+    #         device = torch.cuda.current_device(),
+    #         requires_grad = True)
+
+    #     # Allocate each group's param/grad shard.
+    #     for group_index, group_shard in enumerate(opt_group_shards):
+
+    #         group_size = group_shard["size"]
+    #         assert group_size != 0, "temporary check ... remove me."
+
+    #         # Allocate shard.
+    #         main_param = allocate_shard(group_size, torch.float)
+    #         main_param.grad = allocate_shard(group_size, torch.float)
+    #         mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
+
+    #         # Update group's param.
+    #         group_shard["orig_group"]["params"] = [ main_param ]
     @classmethod
-    def allocate_main_param_shards(cls, opt_group_shards):
-
-        # Allocator method.
-        allocate_shard = lambda shard_size, dtype : torch.empty(
-            (shard_size,),
-            dtype = dtype,
-            device = torch.cuda.current_device(),
-            requires_grad = True)
-
-        # Allocate each group's param/grad shard.
-        for group_index, group_shard in enumerate(opt_group_shards):
-
-            group_size = group_shard["size"]
-            assert group_size != 0, "temporary check ... remove me."
-
-            # Allocate shard.
-            main_param = allocate_shard(group_size, torch.float)
-            main_param.grad = allocate_shard(group_size, torch.float)
-            mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
-
-            # Update group's param.
-            group_shard["orig_group"]["params"] = [ main_param ]
+    # def allocate_main_params(cls, opt_group_shards):
+    def allocate_or_view_main_param_shards(cls,
+                                           model_gbuf_shards,
+                                           param_gbuf_map,
+                                           opt_group_shards):
+
+        # # Allocator method.
+        # allocate_shard = lambda shard_size, dtype : torch.empty(
+        #     (shard_size,),
+        #     dtype = dtype,
+        #     device = torch.cuda.current_device(),
+        #     requires_grad = True)
+
+        # Allocate each group's param/grad shard.
+        for group_index, group_shard in enumerate(opt_group_shards):
+
+            # group_size = group_shard["size"]
+            # assert group_size != 0, "temporary check ... remove me."
+
+            # # Allocate shard.
+            # main_param = allocate_shard(group_size, torch.float)
+            # main_param.grad = allocate_shard(group_size, torch.float)
+            # mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
+
+            # # Update group's param.
+            # group_shard["orig_group"]["params"] = [ main_param ]
+
+            group_main_params = []
+            group_shard["orig_group"]["params"] = group_main_params
+            for param in group_shard["params"]:
+
+                model_index, dtype = param_gbuf_map[param]
+                gbuf_shard = model_gbuf_shards[model_index][dtype]
+                param_shard = gbuf_shard["param_map"][param]["param"]
+
+                pax(0, {
+                    "model_index" : model_index,
+                    "dtype" : dtype,
+                    "gbuf_shard" : gbuf_shard,
+                    "param_shard" : param_shard,
+                })
+
+                # fp16, bf16 params.
+                if param.type() in ['torch.cuda.HalfTensor',
+                                    'torch.cuda.BFloat16Tensor']:
+
+                    # Allocate/copy main param/grad.
+                    main_param = param.detach()[param_shard.start:param_shard.end].clone().float()
+                    if accumulate_allreduce_grads_in_fp32:
+                        main_param.grad = param.main_grad[param_shard.start:param_shard.end]
+                    else:
+                        main_param.grad = param.main_grad.detach()[param_shard.start:param_shard.end].clone().float()
+
+                    # Copy tensor model parallel attributes.
+                    mpu.copy_tensor_model_parallel_attributes(main_param, param)
+                    if hasattr(param, 'shared'):
+                        main_param.shared = param.shared
+
+                # fp32 params.
+                elif param.type() == 'torch.cuda.FloatTensor':
+                    main_param = param
+                    main_param.grad = param.main_grad
+
+                else:
+                    raise TypeError('Wrapped parameters must be one of '
+                                    'torch.cuda.FloatTensor, '
+                                    'torch.cuda.HalfTensor, or '
+                                    'torch.cuda.BFloat16Tensor. '
+                                    'Received {}'.format(param.type()))
+
+                # Add to group.
+                group_main_params.append(main_param)
+    # <<<

-    @classmethod
-    def get_main_grad_views_for_grad_norm(cls, opt_group_shards, optimizer):
-
-        grad_views = []
-        for group_index, opt_group_shard in enumerate(opt_group_shards):
-            opt_grad = optimizer.param_groups[group_index]["params"][0].grad
-            for param, shard in opt_group_shard["param_map"].items():
-                if param_is_not_shared(param) and \
-                   param_is_not_tensor_parallel_duplicate(param):
-                    grad_view = opt_grad[shard.start:shard.end]
-                    grad_views.append(grad_view)
-
-        return grad_views
+    # >>>
+    # @classmethod
+    # def get_main_grad_views_for_grad_norm(cls, opt_group_shards, optimizer):
+
+    #     grad_views = []
+    #     for group_index, opt_group_shard in enumerate(opt_group_shards):
+    #         opt_grad = optimizer.param_groups[group_index]["params"][0].grad
+    #         for param, shard in opt_group_shard["param_map"].items():
+    #             if param_is_not_shared(param) and \
+    #                param_is_not_tensor_parallel_duplicate(param):
+    #                 grad_view = opt_grad[shard.start:shard.end]
+    #                 grad_views.append(grad_view)
+
+    #     return grad_views
+    # <<<

     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                 bf16, grad_scaler, models):
+                 fp16, bf16, grad_scaler, models):

         super().__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
             params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-            bf16, grad_scaler, models)
+            fp16, bf16, grad_scaler, models)

         # Verify that contiguous buffers are being used
         # - Note: this should already be checked in arguments.py
-        args = get_args()
-        assert args.use_contiguous_buffers_in_local_ddp
+        # >>>
+        # args = get_args()
+        # assert args.use_contiguous_buffers_in_local_ddp
+        assert use_contiguous_buffers_in_local_ddp
+        # <<<

         # Model grad buffer shards.
         self.model_gbuf_shards = []
@@ -231,7 +354,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             self.model_gbuf_shards)

         # Allocate main param shards.
-        self.allocate_main_param_shards(self.opt_group_shards)
+        self.allocate_or_view_main_param_shards(self.model_gbuf_shards,
+                                                self.param_gbuf_map,
+                                                self.opt_group_shards)

         # Update optimizer groups.
         # - Also, leverage state_dict() and load_state_dict() to
@@ -243,10 +368,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # Initialize main params.
         self._copy_model_params_to_main_params()

-        # Params for grad norm.
-        self.main_grad_views_for_grad_norm = self.get_main_grad_views_for_grad_norm(
-            self.opt_group_shards,
-            self.optimizer)
+        # >>>
+        # # Params for grad norm.
+        # self.main_grad_views_for_grad_norm = self.get_main_grad_views_for_grad_norm(
+        #     self.opt_group_shards,
+        #     self.optimizer)
+        # <<<

     def get_model_parallel_group(self):
@@ -263,8 +390,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         return self.get_main_param(group_index).grad

-    def get_main_grads_for_grad_norm(self):
-        return self.main_grad_views_for_grad_norm
+    # >>>
+    # def get_main_grads_for_grad_norm(self):
+    #     return self.main_grad_views_for_grad_norm
+    # <<<

     def state_dict(self):
...
@@ -277,7 +277,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):

     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                 bf16, grad_scaler,
+                 fp16, bf16, grad_scaler,
                  models):

         super().__init__(
@@ -285,6 +285,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
             params_have_main_grad, use_contiguous_buffers_in_local_ddp,
             models)

+        self.fp16 = fp16
         self.bf16 = bf16
         self.grad_scaler = grad_scaler
@@ -481,12 +482,12 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):

     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                 bf16, grad_scaler, models):
+                 fp16, bf16, grad_scaler, models):

         super().__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
             params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-            bf16, grad_scaler, models)
+            fp16, bf16, grad_scaler, models)

         # ======================
         # main parameter stuff
...
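
For reference, a standalone sketch of the reworked optimizer-group bookkeeping from get_optimizer_group_shards above: groups now simply collect their params (sizes and offsets stay with the grad-buffer shards), keep a handle to the original param group, and groups that own no params on this rank are squeezed out. Names here (build_group_shards, params_per_gbuf, the toy strings) are illustrative, not the commit's API:

def build_group_shards(param_groups, params_per_gbuf):
    # Map each param to its optimizer group index.
    param_group_map = {}
    for group_index, group in enumerate(param_groups):
        for param in group["params"]:
            param_group_map[param] = group_index

    # Collect this rank's params under their groups.
    group_shards = [{"params": []} for _ in param_groups]
    for params in params_per_gbuf:        # e.g. one list per (model, dtype) grad buffer
        for param in params:
            group_shards[param_group_map[param]]["params"].append(param)

    # Remember the original group (lr, weight decay, ...) and drop empty groups.
    for group_index, group_shard in enumerate(group_shards):
        group_shard["orig_group"] = param_groups[group_index]
    return [g for g in group_shards if len(g["params"]) > 0]

# Toy usage with strings standing in for parameter tensors.
groups = [{"params": ["w1", "w2"]}, {"params": ["b1"]}, {"params": []}]
print([g["params"] for g in build_group_shards(groups, [["w2", "b1"]])])
# -> [['w2'], ['b1']]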