Commit 2c3cb9fc authored by Lawrence McAfee

many edits; working towards first draft.

parent 867105c2
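The core of this commit is the Shard -> Range rename and the per-rank partitioning of each model's gradient buffer (see build_model_gbuf_range in the diff below, which splits the buffer evenly across data-parallel ranks). As a minimal standalone sketch of that partitioning — not part of the commit; Range mirrors the class in the diff, build_gbuf_ranges is an illustrative name:

import math

class Range:
    """Contiguous [start, end) index range (mirrors the Range class in this commit)."""
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start
    def normalize(self, start=0):
        return Range(start, start + self.size)
    def __str__(self):
        return "%d,%d [%d]" % (self.start, self.end, self.size)

def build_gbuf_ranges(gbuf_size, data_parallel_world_size):
    """Split a grad buffer of gbuf_size elements into one contiguous world
    range per data-parallel rank (the last rank may get a short range)."""
    max_range_size = int(math.ceil(gbuf_size / data_parallel_world_size))
    ranges = []
    for r in range(data_parallel_world_size):
        start = r * max_range_size
        end = min(gbuf_size, start + max_range_size)
        ranges.append(Range(start, end))
    return ranges

# Example: a 10-element buffer split across 4 data-parallel ranks.
for rank, rng in enumerate(build_gbuf_ranges(10, 4)):
    print(rank, rng)   # 0 -> 0,3 [3], 1 -> 3,6 [3], 2 -> 6,9 [3], 3 -> 9,10 [1]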
@@ -26,116 +26,122 @@ from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
 from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
-# from .optimizer import Float16OptimizerWithFloat16Params

 # >>>
 from lutil import pax, tp, print_seq
 # <<<

-class Shard:
+# >>>
+# class Shard:
+class Range:
+# <<<

     def __init__(self, start, end):
         self.start = start
         self.end = end
         self.size = end - start

     def normalize(self, start = 0):
-        return Shard(start, start + self.size)
+        return Range(start, start + self.size)

     def __str__(self):
         return "%d,%d [%d]" % (self.start, self.end, self.size)

-# class DistributedOptimizer(Float16OptimizerWithFloat16Params):
 class DistributedOptimizer(MixedPrecisionOptimizer):

     @classmethod
-    def get_model_gbuf_param_shard_map(cls, model, dtype, gbuf_world_shard):
+    def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):

-        # Param shard map.
+        # Param range map.
         param_world_index_map = model._grad_buffer_param_index_map[dtype]
-        param_shard_map = {}
+        param_range_map = {}
         for param, param_world_indexes in param_world_index_map.items():

-            # Shard range.
+            # Param range.
             param_world_start, param_world_end = param_world_indexes
             param_local_start = max(
                 0,
-                param_world_start - gbuf_world_shard.start)
+                param_world_start - gbuf_world_range.start)
             param_local_end = min(
-                gbuf_world_shard.size,
-                param_world_end - gbuf_world_shard.start)
+                gbuf_world_range.size,
+                param_world_end - gbuf_world_range.start)

-            # Add shard, if within range.
+            # Add param, if within local gbuf range.
             if param_local_end > param_local_start:
-                param_local_shard = Shard(param_local_start, param_local_end)
-                param_world_shard = param_local_shard.normalize(
-                    param_local_start + gbuf_world_shard.start)
-                sub_param_start = max(0, gbuf_world_shard.start-param_world_start)
-                sub_param_shard = param_local_shard.normalize(sub_param_start)
-                param_shard_map[param] = {
-                    "gbuf_world" : param_world_shard,
-                    "gbuf_local" : param_local_shard,
-                    "param" : sub_param_shard,
+                param_local_range = Range(param_local_start, param_local_end)
+                param_world_range = param_local_range.normalize(
+                    param_local_start + gbuf_world_range.start)
+                sub_param_start = max(0, gbuf_world_range.start-param_world_start)
+                sub_param_range = param_local_range.normalize(sub_param_start)
+                param_range_map[param] = {
+                    "gbuf_world" : param_world_range,
+                    "gbuf_local" : param_local_range,
+                    "param" : sub_param_range,
                 }

-        return param_shard_map
+        return param_range_map

     @classmethod
-    def get_model_gbuf_shard(cls, model, dtype):
+    def build_model_gbuf_range(cls, model, dtype):

         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_world_size = mpu.get_data_parallel_world_size()

-        # Grad buffer shard.
+        # Grad buffer range.
         grad_buffer = model._grad_buffers[dtype]
         gbuf_size = grad_buffer.numel
-        max_gbuf_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size))
+        max_gbuf_range_size = int(math.ceil(gbuf_size / data_parallel_world_size))

-        # All world shards. (i.e., across all data parallel ranks)
-        gbuf_world_all_shards = []
+        # All world ranges. (i.e., across all data parallel ranks)
+        gbuf_world_all_ranges = []
         for r in range(data_parallel_world_size):
-            gbuf_world_start = r * max_gbuf_shard_size
-            gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_shard_size)
-            gbuf_world_shard = Shard(gbuf_world_start, gbuf_world_end)
-            gbuf_world_all_shards.append(gbuf_world_shard)
+            gbuf_world_start = r * max_gbuf_range_size
+            gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_range_size)
+            gbuf_world_range = Range(gbuf_world_start, gbuf_world_end)
+            gbuf_world_all_ranges.append(gbuf_world_range)

-        # Local DP's shards.
-        gbuf_world_shard = gbuf_world_all_shards[data_parallel_rank]
-        gbuf_local_shard = gbuf_world_shard.normalize()
+        # Local DP's ranges.
+        gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank]
+        gbuf_local_range = gbuf_world_range.normalize()

-        # Get each param's shards.
-        param_shard_map = cls.get_model_gbuf_param_shard_map(model,
-                                                             dtype,
-                                                             gbuf_world_shard)
+        # Get each param's ranges.
+        param_range_map = cls.build_model_gbuf_param_range_map(model,
+                                                               dtype,
+                                                               gbuf_world_range)

         # Altogether.
         data = {
-            "local" : gbuf_local_shard,
-            "world" : gbuf_world_shard,
-            "world_all" : gbuf_world_all_shards,
-            "param_map" : param_shard_map,
-            "max_shard_size" : max_gbuf_shard_size,
+            "local" : gbuf_local_range,
+            "world" : gbuf_world_range,
+            "world_all" : gbuf_world_all_ranges,
+            "param_map" : param_range_map,
+            "max_range_size" : max_gbuf_range_size,
         }

         return data

     @classmethod
-    def get_model_gbuf_shard_map(cls, model):
+    def build_model_gbuf_range_map(cls, model):
         return {
-            dtype : cls.get_model_gbuf_shard(model, dtype)
+            dtype : cls.build_model_gbuf_range(model, dtype)
             for dtype in model._grad_buffers
         }

     @classmethod
-    def get_param_gbuf_map(cls, model_gbuf_shards):
-        '''Create a reverse of the model_gbuf_shards, for referencing in
+    def build_model_param_gbuf_map(cls, model_gbuf_ranges):
+        '''Create a reverse of the model_gbuf_ranges, for referencing in
         opposite direction.'''
         param_gbuf_map = {}
-        for model_index, model_gbuf_shard_map in enumerate(model_gbuf_shards):
-            for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
-                for param, param_shard_map in gbuf_shard_map["param_map"].items():
+        for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
+            for dtype, gbuf_range_map in model_gbuf_range_map.items():
+                for param, param_range_map in gbuf_range_map["param_map"].items():
                     param_gbuf_map[param] = (model_index, dtype)
         return param_gbuf_map

     # >>>
     # @classmethod
-    # def get_optimizer_group_shards(cls, param_groups, model_gbuf_shards):
+    # def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):

     #     num_groups = len(param_groups)
@@ -146,31 +152,31 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     #             assert param.requires_grad
     #             param_group_map[param] = group_index

-    #     # Optimizer group shards.
-    #     group_shards = [ {"size": 0, "param_map": {}} for _ in param_groups ]
-    #     for model_gbuf_shard_map in model_gbuf_shards:
-    #         for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
-    #             for param in gbuf_shard_map["param_map"]:
+    #     # Optimizer group ranges.
+    #     group_ranges = [ {"size": 0, "param_map": {}} for _ in param_groups ]
+    #     for model_gbuf_range_map in model_gbuf_ranges:
+    #         for dtype, gbuf_range_map in model_gbuf_range_map.items():
+    #             for param in gbuf_range_map["param_map"]:
     #                 group_index = param_group_map[param]
-    #                 group_shard = group_shards[group_index]
-    #                 param_size = gbuf_shard_map["param_map"][param]["param"].size
-    #                 param_group_start = group_shard["size"]
+    #                 group_range = group_ranges[group_index]
+    #                 param_size = gbuf_range_map["param_map"][param]["param"].size
+    #                 param_group_start = group_range["size"]
     #                 param_group_end = param_group_start + param_size
-    #                 param_group_shard = Shard(param_group_start, param_group_end)
-    #                 group_shard["size"] += param_size
-    #                 group_shard["param_map"][param] = param_group_shard
+    #                 param_group_range = Range(param_group_start, param_group_end)
+    #                 group_range["size"] += param_size
+    #                 group_range["param_map"][param] = param_group_range

-    #     # Squeeze zero-size group shards.
-    #     for group_index, group_shard in enumerate(group_shards):
-    #         group_shard["orig_group"] = param_groups[group_index]
-    #     group_shards = [ g for g in group_shards if g["size"] > 0 ]
+    #     # Squeeze zero-size group ranges.
+    #     for group_index, group_range in enumerate(group_ranges):
+    #         group_range["orig_group"] = param_groups[group_index]
+    #     group_ranges = [ g for g in group_ranges if g["size"] > 0 ]

-    #     return group_shards
+    #     return group_ranges

     @classmethod
-    def get_optimizer_group_shards(cls, param_groups, model_gbuf_shards):
+    def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):

         num_groups = len(param_groups)
@@ -181,35 +187,35 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             assert param.requires_grad
             param_group_map[param] = group_index

-        # Optimizer group shards.
+        # Optimizer group ranges.
         # >>>
-        # group_shards = [ {"size": 0, "param_map": {}} for _ in param_groups ]
-        group_shards = [ {"params": []} for _ in param_groups ]
-        # group_shards = [ [] for _ in param_groups ]
+        # group_ranges = [ {"size": 0, "param_map": {}} for _ in param_groups ]
+        group_ranges = [ {"params": []} for _ in param_groups ]
+        # group_ranges = [ [] for _ in param_groups ]
         # <<<
-        for model_gbuf_shard_map in model_gbuf_shards:
-            for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
-                for param in gbuf_shard_map["param_map"]:
+        for model_gbuf_range_map in model_gbuf_ranges:
+            for dtype, gbuf_range_map in model_gbuf_range_map.items():
+                for param in gbuf_range_map["param_map"]:
                     group_index = param_group_map[param]
-                    group_shard = group_shards[group_index]
-                    group_shard["params"].append(param)
+                    group_range = group_ranges[group_index]
+                    group_range["params"].append(param)

-        # Squeeze zero-size group shards.
-        for group_index, group_shard in enumerate(group_shards):
-            group_shard["orig_group"] = param_groups[group_index]
-        group_shards = [ g for g in group_shards if len(g["params"]) > 0 ]
+        # Squeeze zero-size group ranges.
+        for group_index, group_range in enumerate(group_ranges):
+            group_range["orig_group"] = param_groups[group_index]
+        group_ranges = [ g for g in group_ranges if len(g["params"]) > 0 ]

         # >>>
-        # print_seq("group shards / len = %s." %
-        #           ", ".join(str(len(s["params"])) for s in group_shards))
+        # print_seq("group ranges / len = %s." %
+        #           ", ".join(str(len(s["params"])) for s in group_ranges))
         # <<<

-        return group_shards
+        return group_ranges
     # <<<

     # >>>
     # @classmethod
-    # def allocate_main_param_shards(cls, opt_group_shards):
+    # def allocate_main_param_shards(cls, opt_group_ranges):

     #     # Allocator method.
     #     allocate_shard = lambda shard_size, dtype : torch.empty(
@@ -219,9 +225,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     #         requires_grad = True)

     #     # Allocate each group's param/grad shard.
-    #     for group_index, group_shard in enumerate(opt_group_shards):
-    #         group_size = group_shard["size"]
+    #     for group_index, group_range in enumerate(opt_group_ranges):
+    #         group_size = group_range["size"]
     #         assert group_size != 0, "temporary check ... remove me."

     #         # Allocate shard.
@@ -230,71 +236,74 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     #             mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)

     #         # Update group's param.
-    #         group_shard["orig_group"]["params"] = [ main_param ]
+    #         group_range["orig_group"]["params"] = [ main_param ]

     @classmethod
-    # def allocate_main_params(cls, opt_group_shards):
-    def allocate_or_view_main_param_shards(cls,
-                                           model_gbuf_shards,
-                                           param_gbuf_map,
-                                           opt_group_shards):
-
-        # # Allocator method.
-        # allocate_shard = lambda shard_size, dtype : torch.empty(
-        #     (shard_size,),
-        #     dtype = dtype,
-        #     device = torch.cuda.current_device(),
-        #     requires_grad = True)
-
-        # Allocate each group's param/grad shard.
-        for group_index, group_shard in enumerate(opt_group_shards):
-
-            # group_size = group_shard["size"]
-            # assert group_size != 0, "temporary check ... remove me."
-
-            # # Allocate shard.
-            # main_param = allocate_shard(group_size, torch.float)
-            # main_param.grad = allocate_shard(group_size, torch.float)
-            # mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
-
-            # # Update group's param.
-            # group_shard["orig_group"]["params"] = [ main_param ]
-
-            group_main_params = []
-            group_shard["orig_group"]["params"] = group_main_params
-
-            for param in group_shard["params"]:
-
-                model_index, dtype = param_gbuf_map[param]
-                gbuf_shard = model_gbuf_shards[model_index][dtype]
-                param_shard = gbuf_shard["param_map"][param]["param"]
-
-                pax(0, {
-                    "model_index" : model_index,
-                    "dtype" : dtype,
-                    "gbuf_shard" : gbuf_shard,
-                    "param_shard" : param_shard,
-                })
+    # def allocate_main_params(cls, opt_group_ranges):
+    # def allocate_or_view_main_param_shards(cls,
+    def build_model_and_main_param_groups(cls,
+                                          model_gbuf_ranges,
+                                          param_gbuf_map,
+                                          opt_group_ranges):
+
+        # Three groups of parameters:
+        #   float16_groups: original float16 parameters
+        #   fp32_from_float16_groups: fp32 copy of float16 parameters
+        #   fp32_groups: original fp32 parameters
+        full_float16_groups = []
+        full_fp32_groups = []
+        shard_float16_groups = []
+        shard_fp32_groups = []
+        shard_fp32_from_float16_groups = []
+
+        # Allocate each group's param shard.
+        for group_index, group_range in enumerate(opt_group_ranges):
+
+            # Params of this group.
+            full_float16_params_this_group = []
+            full_fp32_params_this_group = []
+            shard_float16_params_this_group = []
+            shard_fp32_params_this_group = []
+            shard_fp32_from_float16_params_this_group = []
+            full_float16_groups.append(full_float16_params_this_group)
+            full_fp32_groups.append(full_fp32_params_this_group)
+            shard_float16_groups.append(shard_float16_params_this_group)
+            shard_fp32_groups.append(shard_fp32_params_this_group)
+            shard_fp32_from_float16_groups.append(
+                shard_fp32_from_float16_params_this_group)
+
+            for model_param in group_range["params"]:
+
+                model_index, dtype = param_gbuf_map[model_param]
+                gbuf_range = model_gbuf_ranges[model_index][dtype]
+                param_range = gbuf_range["param_map"][model_param]["param"]

                 # fp16, bf16 params.
-                if param.type() in ['torch.cuda.HalfTensor',
-                                    'torch.cuda.BFloat16Tensor']:
-
-                    # Allocate/copy main param/grad.
-                    main_param = param.detach()[param_shard.start:param_shard.end].clone().float()
-                    if accumulate_allreduce_grads_in_fp32:
-                        main_param.grad = param.main_grad[param_shard.start:param_shard.end]
-                    else:
-                        main_param.grad = param.main_grad.detach()[param_shard.start:param_shard.end].clone().float()
-
-                    # Copy tensor model parallel attributes.
-                    mpu.copy_tensor_model_parallel_attributes(main_param, param)
-                    if hasattr(param, 'shared'):
-                        main_param.shared = param.shared
+                if model_param.type() in ['torch.cuda.HalfTensor',
+                                          'torch.cuda.BFloat16Tensor']:
+
+                    # Clone model -> main.
+                    shard_model_param = \
+                        model_param.detach()[param_range.start:param_range.end]
+                    shard_main_param = shard_model_param.clone().float()
+                    mpu.copy_tensor_model_parallel_attributes(
+                        shard_model_param, model_param)
+                    mpu.copy_tensor_model_parallel_attributes(
+                        shard_main_param, model_param)
+                    if hasattr(model_param, 'shared'):
+                        shard_model_param.shared = model_param.shared
+                        shard_main_param.shared = model_param.shared
+
+                    # Add to group.
+                    full_float16_params_this_group.append(model_param)
+                    shard_float16_params_this_group.append(shard_model_param)
+                    shard_fp32_from_float16_params_this_group.append(shard_main_param)

                 # fp32 params.
-                elif param.type() == 'torch.cuda.FloatTensor':
-                    main_param = param
-                    main_param.grad = param.main_grad
+                elif param.type() == 'torch.cuda.FloatTensor':
+                    shard_model_param = \
+                        model_param[param_range.start:param_range.end]
+                    full_fp32_params_this_group.append(model_param)
+                    shard_fp32_params_this_group.append(shard_model_param)

                 else:
                     raise TypeError('Wrapped parameters must be one of '
@@ -303,23 +312,35 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                                     'torch.cuda.BFloat16Tensor. '
                                     'Received {}'.format(param.type()))

-                # Add to group.
-                group_main_params.append(main_param)
+                # # Add to group.
+                # group_main_params.append(main_param)
+
+            group_range["orig_group"]["params"] = [
+                *shard_fp32_params_this_group,
+                *shard_fp32_from_float16_params_this_group,
+            ]
+
+        return (
+            full_float16_groups,
+            full_fp32_groups,
+            shard_float16_groups,
+            shard_fp32_groups,
+            shard_fp32_from_float16_groups,
+        )
     # <<<

     # >>>
     # @classmethod
-    # def get_main_grad_views_for_grad_norm(cls, opt_group_shards, optimizer):
+    # def build_main_grad_views_for_grad_norm(cls, opt_group_ranges, optimizer):
     #     grad_views = []
-    #     for group_index, opt_group_shard in enumerate(opt_group_shards):
+    #     for group_index, opt_group_range in enumerate(opt_group_ranges):
     #         opt_grad = optimizer.param_groups[group_index]["params"][0].grad
-    #         for param, shard in opt_group_shard["param_map"].items():
+    #         for param, range in opt_group_range["param_map"].items():
     #             if param_is_not_shared(param) and \
     #                 param_is_not_tensor_parallel_duplicate(param):
-    #                 grad_view = opt_grad[shard.start:shard.end]
+    #                 grad_view = opt_grad[range.start:range.end]
     #                 grad_views.append(grad_view)
     #     return grad_views
@@ -342,108 +363,162 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         assert use_contiguous_buffers_in_local_ddp
         # <<<

-        # Model grad buffer shards.
-        self.model_gbuf_shards = []
+        # Model grad buffer ranges.
+        self.model_gbuf_ranges = []
         for model_index, model in enumerate(self.models):
-            self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))
-        self.param_gbuf_map = self.get_param_gbuf_map(self.model_gbuf_shards)
+            self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model))
+        self.model_param_gbuf_map = \
+            self.build_model_param_gbuf_map(self.model_gbuf_ranges)

-        # Optimizer shards.
-        self.opt_group_shards = self.get_optimizer_group_shards(
+        # Optimizer ranges.
+        self.opt_group_ranges = self.build_optimizer_group_ranges(
             self.optimizer.param_groups,
-            self.model_gbuf_shards)
+            self.model_gbuf_ranges)

         # Allocate main param shards.
-        self.allocate_or_view_main_param_shards(self.model_gbuf_shards,
-                                                self.param_gbuf_map,
-                                                self.opt_group_shards)
+        (
+            self.full_float16_groups,
+            self.full_fp32_groups,
+            self.shard_float16_groups,
+            self.shard_fp32_groups,
+            self.shard_fp32_from_float16_groups,
+        ) = self.build_model_and_main_param_groups(self.model_gbuf_ranges,
+                                                   self.model_param_gbuf_map,
+                                                   self.opt_group_ranges)
+
+        # print_seq("16 [%d], 16x32 [%d], 32 [%d]." % (
+        #     sum(len(g) for g in self.float16_groups),
+        #     sum(len(g) for g in self.fp32_from_float16_groups),
+        #     sum(len(g) for g in self.fp32_groups),
+        # ))

         # Update optimizer groups.
         # - Also, leverage state_dict() and load_state_dict() to
         #   recast preexisting per-param state tensors.
         self.optimizer.param_groups = \
-            [ g["orig_group"] for g in self.opt_group_shards ]
+            [ g["orig_group"] for g in self.opt_group_ranges ]
         self.optimizer.load_state_dict(self.optimizer.state_dict())

-        # Initialize main params.
-        self._copy_model_params_to_main_params()
+        # >>>
+        # # Initialize main params.
+        # self._copy_model_params_to_main_params()
+        # <<<

         # >>>
         # # Params for grad norm.
-        # self.main_grad_views_for_grad_norm = self.get_main_grad_views_for_grad_norm(
-        #     self.opt_group_shards,
+        # self.main_grad_views_for_grad_norm = self.build_main_grad_views_for_grad_norm(
+        #     self.opt_group_ranges,
         #     self.optimizer)
         # <<<
+    def get_model_param_range_map(self, param):
+        model_index, dtype = self.model_param_gbuf_map[param]
+        gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
+        param_range_map = gbuf_range_map["param_map"][param]
+        # >>>
+        # pax(0, {
+        #     "param" : param,
+        #     "model_index" : model_index,
+        #     "dtype" : str(dtype),
+        #     "gbuf_range_map" : gbuf_range_map,
+        #     "param_range_map" : param_range_map,
+        # })
+        # <<<
+        return param_range_map

     def get_model_parallel_group(self):
         return None

-    def get_main_params(self):
-        return [ g["params"][0] for g in self.optimizer.param_groups ]
-    def get_main_grads(self):
-        return [ p.grad for p in self.get_main_params() ]
-    def get_main_param(self, group_index):
-        return self.get_main_params()[group_index]
-    def get_main_grad(self, group_index):
-        return self.get_main_param(group_index).grad
+    # def get_main_params(self):
+    #     return [ g["params"][0] for g in self.optimizer.param_groups ]
+    # def get_main_grads(self):
+    #     return [ p.grad for p in self.get_main_params() ]
+    # def get_main_param(self, group_index):
+    #     return self.get_main_params()[group_index]
+    # def get_main_grad(self, group_index):
+    #     return self.get_main_param(group_index).grad

     # >>>
     # def get_main_grads_for_grad_norm(self):
     #     return self.main_grad_views_for_grad_norm
+    def get_main_grads_for_grad_norm(self):
+        raise Exception("does 'super' work?")
     # <<<

+    # def state_dict(self):
+    #     state_dict = {}
+    #     state_dict['optimizer'] = self.optimizer.state_dict()
+    #     if self.grad_scaler:
+    #         state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+    #     state_dict['groups'] = [g['params'] for g in self.optimizer.param_groups]
+    #     return state_dict
     def state_dict(self):
-        state_dict = {}
-        state_dict['optimizer'] = self.optimizer.state_dict()
-        if self.grad_scaler:
-            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-        state_dict['groups'] = [g['params'] for g in self.optimizer.param_groups]
-        return state_dict
+        raise Exception("fix me.")
+
+    # def load_state_dict(self, state_dict):
+
+    #     # Optimizer.
+    #     optimizer_key = 'optimizer'
+    #     if optimizer_key not in state_dict:
+    #         optimizer_key = 'optimizer_state_dict'
+    #         print_rank_0('***WARNING*** loading optimizer from '
+    #                      'an old checkpoint ...')
+    #     self.optimizer.load_state_dict(state_dict[optimizer_key])
+
+    #     # Grad scaler.
+    #     if 'grad_scaler' not in state_dict:
+    #         print_rank_0('***WARNING*** found an old checkpoint, will not '
+    #                      'load grad scaler ...')
+    #     else:
+    #         if self.grad_scaler:
+    #             self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
+    #         else:
+    #             print_rank_0('***WARNING*** fould the grad scaler in the '
+    #                          'checkpoint but it is None in the class. '
+    #                          'Skipping loading grad scaler ...')
+
+    #     # Copy data for the main params.
+    #     current_groups = [ g["params"] for g in self.optimizer.param_groups ]
+    #     assert "groups" in state_dict, "key 'groups' not in state_dict."
+    #     for current_group, saved_group in zip(current_groups, state_dict["groups"]):
+    #         for current_param, saved_param in zip(current_group, saved_group):
+    #             current_param.data.copy_(saved_param.data)

     def load_state_dict(self, state_dict):
-
-        # Optimizer.
-        optimizer_key = 'optimizer'
-        if optimizer_key not in state_dict:
-            optimizer_key = 'optimizer_state_dict'
-            print_rank_0('***WARNING*** loading optimizer from '
-                         'an old checkpoint ...')
-        self.optimizer.load_state_dict(state_dict[optimizer_key])
-
-        # Grad scaler.
-        if 'grad_scaler' not in state_dict:
-            print_rank_0('***WARNING*** found an old checkpoint, will not '
-                         'load grad scaler ...')
-        else:
-            if self.grad_scaler:
-                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
-            else:
-                print_rank_0('***WARNING*** fould the grad scaler in the '
-                             'checkpoint but it is None in the class. '
-                             'Skipping loading grad scaler ...')
-
-        # Copy data for the main params.
-        current_groups = [ g["params"] for g in self.optimizer.param_groups ]
-        assert "groups" in state_dict, "key 'groups' not in state_dict."
-        for current_group, saved_group in zip(current_groups, state_dict["groups"]):
-            for current_param, saved_param in zip(current_group, saved_group):
-                current_param.data.copy_(saved_param.data)
+        raise Exception("hi.")
+
+    # def zero_grad(self, set_to_none=True):
+
+    #     # Collect model params.
+    #     model_params = []
+    #     for model in self.models:
+    #         for dtype, param_map in model._grad_buffer_param_index_map.items():
+    #             model_params.extend(param_map.keys())
+
+    #     # Distributed optimizer requires contiguous buffer; don't set to None.
+    #     _zero_grad_group_helper(model_params, set_to_none = False)
+
+    # def zero_grad(self, set_to_none=True):
+    #     raise Exception("does 'super' work?")
+
+    # >>>
     def zero_grad(self, set_to_none=True):
-
-        # Collect model params.
-        model_params = []
-        for model in self.models:
-            for dtype, param_map in model._grad_buffer_param_index_map.items():
-                model_params.extend(param_map.keys())
-
-        # Distributed optimizer requires contiguous buffer; don't set to None.
-        _zero_grad_group_helper(model_params, set_to_none = False)
+        """We only need to zero the model related parameters, i.e.,
+        float16_groups & fp32_groups. We additionally zero
+        fp32_from_float16_groups as a memory optimization to reduce
+        fragmentation; in the case of set_to_none==True, the space
+        used by this field can be safely deallocated at this point."""
+        for groups in (
+                self.full_float16_groups,
+                self.full_fp32_groups,
+                self.shard_fp32_from_float16_groups):
+            for group in groups:
+                _zero_grad_group_helper(group, set_to_none)
+    # <<<

     def get_model_grad_buffer_dp_views(self):
@@ -469,6 +544,14 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         grads.
         '''

+        # >>>
+        # print_seq([
+        #     tp(b.data)
+        #     for m in self.models
+        #     for b in m._grad_buffers.values()
+        # ])
+        # <<<
+
         # All-reduce embedding grads.
         timers('backward-embedding-all-reduce').start()
         self.allreduce_embedding_grads(args)
@@ -498,6 +581,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

     def gather_model_params(self, args, timers):
+        raise Exception("hi.")

         timers('backward-params-all-gather').start()

         data_parallel_rank = mpu.get_data_parallel_rank()
@@ -526,69 +611,151 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         timers('backward-params-all-gather').stop()

     def _collect_main_grad_data_for_unscaling(self):
+        raise Exception("hi.")
         return [ g.data for g in self.get_main_grads() ]

+    # >>>
+    # def _copy_model_params_to_main_params(self):
+
+    #     for group_index, group_range in enumerate(self.opt_group_ranges):
+    #         main_param = self.get_main_param(group_index)
+    #         for model_param, main_range in group_range["param_map"].items():
+
+    #             # Model range.
+    #             # model_index, dtype = self.param_gbuf_map[model_param]
+    #             # model_range = self.model_gbuf_ranges \
+    #             #     [model_index][dtype]["param_map"][model_param]["param"]
+    #             model_range = self.get_model_param_range_map(model_param)["param"]
+
+    #             assert main_range.size == model_range.size
+
+    #             # Copy shard data.
+    #             main_view = main_param[main_range.start:main_range.end]
+    #             model_view = model_param.view(-1)[model_range.start:model_range.end]
+    #             main_view.detach().copy_(model_view)

     def _copy_model_params_to_main_params(self):
+        raise Exception("check if super's copy works.")
+    # <<<
-        for group_index, group_shard in enumerate(self.opt_group_shards):
-            main_param = self.get_main_param(group_index)
-            for model_param, main_shard in group_shard["param_map"].items():
-
-                # Model shard.
-                model_index, dtype = self.param_gbuf_map[model_param]
-                model_shard = self.model_gbuf_shards \
-                    [model_index][dtype]["param_map"][model_param]["param"]
-
-                assert main_shard.size == model_shard.size
-
-                # Copy shard data.
-                main_view = main_param[main_shard.start:main_shard.end]
-                model_view = model_param.view(-1)[model_shard.start:model_shard.end]
-                main_view.detach().copy_(model_view)
-
-    def _copy_model_grads_to_main_grads(self):
-
-        for group_index, group_shard in enumerate(self.opt_group_shards):
-            for model_param, main_shard in group_shard["param_map"].items():
+    # >>>
+    # def _copy_model_grads_to_main_grads(self):
+
+    #     for group_index, group_range in enumerate(self.opt_group_ranges):
+    #         for model_param, main_range in group_range["param_map"].items():
+
+    #             # Model range.
+    #             # model_index, dtype = self.param_gbuf_map[model_param]
+    #             # model_range = self.model_gbuf_ranges \
+    #             #     [model_index][dtype]["param_map"][model_param]["gbuf_world"]
+    #             model_range = self.get_model_param_range_map(model_param)["gbuf_world"]
+
+    #             assert main_range.size == model_range.size
+
+    #             # Copy from DDP's contiguous buffer to main shard's grad.
+    #             model_grad = self.models[model_index]._grad_buffers[dtype].data
+    #             main_grad = self.get_main_grad(group_index)
+
+    #             # Copy sub-range within tensor.
+    #             model_view = model_grad[model_range.start:model_range.end]
+    #             main_view = main_grad[main_range.start:main_range.end]
+    #             main_view.detach().copy_(model_view)
+
+    # def _copy_model_grads_to_main_grads(self):
+    #     super()._copy_model_grads_to_main_grads()
+    #     raise Exception("check main param '.grad'.")
+
+    #     for group in self.optimizer.param_groups:
+    #         for param in group["params"]:
+    #             param.grad =
-
-                # Model shard.
-                model_index, dtype = self.param_gbuf_map[model_param]
-                model_shard = self.model_gbuf_shards \
-                    [model_index][dtype]["param_map"][model_param]["gbuf_world"]
-
-                assert main_shard.size == model_shard.size
-
-                # Copy from DDP's contiguous buffer to main shard's grad.
-                model_grad = self.models[model_index]._grad_buffers[dtype].data
-                main_grad = self.get_main_grad(group_index)
-
-                # Copy sub-range within tensor.
-                model_view = model_grad[model_shard.start:model_shard.end]
-                main_view = main_grad[main_shard.start:main_shard.end]
-
-                main_view.detach().copy_(model_view)
+    def _copy_model_grads_to_main_grads(self):
+
+        # >>>
+        # print_seq([
+        #     "grad = %s." % tp(p.grad)
+        #     for g in self.optimizer.param_groups
+        #     for p in g["params"]
+        # ])
+        # <<<
+
+        # This only needs to be done for the float16 group.
+        for full_model_group, shard_main_group in zip(
+                self.full_float16_groups,
+                self.shard_fp32_from_float16_groups):
+            for full_model_param, shard_main_param in zip(full_model_group,
+                                                          shard_main_group):
+
+                param_range_map = self.get_model_param_range_map(full_model_param)
+                param_range = param_range_map["param"]
+                full_model_grad = full_model_param.main_grad
+                shard_model_grad = \
+                    full_model_grad[param_range.start:param_range.end]
+                shard_main_param.grad = shard_model_grad.float()
+
+                # >>>
+                if full_model_param.nelement() != shard_main_param.nelement():
+                    pax(0, {
+                        "param_range_map" : param_range_map,
+                        "param_range" : param_range,
+                        "full_model_param" : tp(full_model_param),
+                        "full_model_grad" : tp(full_model_grad),
+                        "shard_model_grad" : tp(shard_model_grad),
+                        "shard_main_grad" : tp(shard_main_param.grad),
+                        "shard_main_param" : tp(shard_main_param),
+                    })
+                # <<<
+
+        # For fp32 grads, we need to reset the grads to main grad.
+        for group in self.fp32_groups:
+            for param in group:
+                param.grad = param.main_grad
+
+        # >>>
+        print_seq([
+            "grad = %s." % tp(p.grad)
+            for g in self.optimizer.param_groups
+            for p in g["params"]
+        ])
+        # <<<
-    def _copy_main_params_to_model_params(self):
-
-        for group_index, group_shard in enumerate(self.opt_group_shards):
-            for model_param, main_shard in group_shard["param_map"].items():
-                model_index, dtype = self.param_gbuf_map[model_param]
-                model_shard = self.model_gbuf_shards \
-                    [model_index][dtype]["param_map"][model_param]["gbuf_world"]
-                assert main_shard.size == model_shard.size
-
-                # Use DDP's contiguous buffer to temporarily hold params.
-                model_param = self.models[model_index]._grad_buffers[dtype].data
-                main_param = self.get_main_param(group_index)
-
-                # Copy sub-range within tensor.
-                model_view = model_param[model_shard.start:model_shard.end]
-                main_view = main_param[main_shard.start:main_shard.end]
-                model_view.detach().copy_(main_view)
+    # <<<
+
+    # >>>
+    # def _copy_main_params_to_model_params(self):
+
+    #     for group_index, group_range in enumerate(self.opt_group_ranges):
+    #         for model_param, main_range in group_range["param_map"].items():
+
+    #             # model_index, dtype = self.param_gbuf_map[model_param]
+    #             # model_range = self.model_gbuf_ranges \
+    #             #     [model_index][dtype]["param_map"][model_param]["gbuf_world"]
+    #             model_range = self.get_model_param_range_map(model_param)["gbuf_world"]
+
+    #             assert main_range.size == model_range.size
+
+    #             # Use DDP's contiguous buffer to temporarily hold params.
+    #             model_param = self.models[model_index]._grad_buffers[dtype].data
+    #             main_param = self.get_main_param(group_index)
+
+    #             # Copy sub-range within tensor.
+    #             model_view = model_param[model_range.start:model_range.end]
+    #             main_view = main_param[main_range.start:main_range.end]
+    #             model_view.detach().copy_(main_view)
+
+    # def _copy_main_params_to_model_params(self):
+    #     super()._copy_main_params_to_model_params()
+    #     raise Exception("check main param '.grad'.")
+
+    def _copy_main_params_to_model_params(self):
+        raise Exception("hi.")
+
+        # This only needs to be done for the float16 group.
+        for model_group, main_group in zip(self.float16_groups,
+                                           self.fp32_from_float16_groups):
+            for model_param, main_param in zip(model_group, main_group):
+                model_param.main_grad.detach().copy_(main_param)
+
+        # For fp32 grads, we need to reset the grads to main grad.
+        for group in self.fp32_groups:
+            for param in group:
+                param.main_grad.detach().copy_(param)
+    # <<<
@@ -323,6 +323,22 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         self._copy_model_params_to_main_params()

+    # >>>
+    # def zero_grad(self, set_to_none=True):
+    #     """We only need to zero the model related parameters, i.e.,
+    #     float16_groups & fp32_from_fp32_groups. We additionally zero
+    #     fp32_from_float16_groups as a memory optimization to reduce
+    #     fragmentation; in the case of set_to_none==True, the space
+    #     used by this field can be safely deallocated at this point."""
+    #     for group in self.float16_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    #     for group in self.fp32_from_float16_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    #     for group in self.fp32_from_fp32_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    # <<<
+
     def _unscale_main_grads_and_check_for_nan(self):

         # Collect main grads.
@@ -552,18 +568,20 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
         self.optimizer.load_state_dict(self.optimizer.state_dict())

-    def zero_grad(self, set_to_none=True):
-        """We only need to zero the model related parameters, i.e.,
-        float16_groups & fp32_from_fp32_groups. We additionally zero
-        fp32_from_float16_groups as a memory optimization to reduce
-        fragmentation; in the case of set_to_none==True, the space
-        used by this field can be safely deallocated at this point."""
-        for group in self.float16_groups:
-            _zero_grad_group_helper(group, set_to_none)
-        for group in self.fp32_from_float16_groups:
-            _zero_grad_group_helper(group, set_to_none)
-        for group in self.fp32_from_fp32_groups:
-            _zero_grad_group_helper(group, set_to_none)
+    # >>>
+    # def zero_grad(self, set_to_none=True):
+    #     """We only need to zero the model related parameters, i.e.,
+    #     float16_groups & fp32_from_fp32_groups. We additionally zero
+    #     fp32_from_float16_groups as a memory optimization to reduce
+    #     fragmentation; in the case of set_to_none==True, the space
+    #     used by this field can be safely deallocated at this point."""
+    #     for group in self.float16_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    #     for group in self.fp32_from_float16_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    #     for group in self.fp32_from_fp32_groups:
+    #         _zero_grad_group_helper(group, set_to_none)
+    # <<<

     def _collect_main_grad_data_for_unscaling(self):
...
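For reference, the clipping done in build_model_gbuf_param_range_map above maps one parameter's world slot in the grad buffer onto a single rank's range, producing the "gbuf_world", "gbuf_local" and "param" sub-ranges. A self-contained sketch with made-up numbers — the function name and tuple layout are illustrative, not the commit's API:

def intersect_param_with_gbuf_range(param_world, gbuf_world):
    """Clip one parameter's [start, end) slot in the grad buffer against one
    rank's [start, end) buffer range. Returns offsets as (start, end) tuples,
    or None if the param lies entirely outside the rank's range."""
    p_start, p_end = param_world
    g_start, g_end = gbuf_world
    g_size = g_end - g_start
    local_start = max(0, p_start - g_start)
    local_end = min(g_size, p_end - g_start)
    if local_end <= local_start:
        return None
    size = local_end - local_start
    world_start = local_start + g_start            # offset within the full grad buffer
    sub_start = max(0, g_start - p_start)          # offset within the parameter itself
    return {
        "gbuf_world": (world_start, world_start + size),
        "gbuf_local": (local_start, local_end),
        "param": (sub_start, sub_start + size),
    }

# A param occupying buffer slots [2, 8), clipped to rank 1's range [3, 6):
print(intersect_param_with_gbuf_range((2, 8), (3, 6)))
# {'gbuf_world': (3, 6), 'gbuf_local': (0, 3), 'param': (1, 4)}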