Commit 91f3579e authored by Lawrence McAfee

cleanup.

parent 4b9a218e
@@ -26,16 +26,12 @@ from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
 from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
-# from .optimizer import Float16OptimizerWithFloat16Params
 # >>>
 from lutil import pax, tp, print_seq
 # <<<
-# >>>
-# class Shard:
 class Range:
-# <<<
     def __init__(self, start, end):
         self.start = start
@@ -47,7 +43,6 @@ class Range:
         return "%d,%d [%d]" % (self.start, self.end, self.size)
-# class DistributedOptimizer(Float16OptimizerWithFloat16Params):
 class DistributedOptimizer(MixedPrecisionOptimizer):
     @classmethod
@@ -82,6 +77,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         return param_range_map
     @classmethod
     def build_model_gbuf_range(cls, model, dtype):
@@ -121,6 +117,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         return data
     @classmethod
     def build_model_gbuf_range_map(cls, model):
         return {
@@ -128,6 +125,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             for dtype in model._grad_buffers
         }
     @classmethod
     def build_model_param_gbuf_map(cls, model_gbuf_ranges):
         '''Create a reverse of the model_gbuf_ranges, for referencing in
@@ -139,42 +137,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 param_gbuf_map[param] = (model_index, dtype)
         return param_gbuf_map
-    # >>>
-    # @classmethod
-    # def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
-    #     num_groups = len(param_groups)
-    #     # Param group map.
-    #     param_group_map = {}
-    #     for group_index, group in enumerate(param_groups):
-    #         for param in group["params"]:
-    #             assert param.requires_grad
-    #             param_group_map[param] = group_index
-    #     # Optimizer group ranges.
-    #     group_ranges = [ {"size": 0, "param_map": {}} for _ in param_groups ]
-    #     for model_gbuf_range_map in model_gbuf_ranges:
-    #         for dtype, gbuf_range_map in model_gbuf_range_map.items():
-    #             for param in gbuf_range_map["param_map"]:
-    #                 group_index = param_group_map[param]
-    #                 group_range = group_ranges[group_index]
-    #                 param_size = gbuf_range_map["param_map"][param]["param"].size
-    #                 param_group_start = group_range["size"]
-    #                 param_group_end = param_group_start + param_size
-    #                 param_group_range = Range(param_group_start, param_group_end)
-    #                 group_range["size"] += param_size
-    #                 group_range["param_map"][param] = param_group_range
-    #     # Squeeze zero-size group ranges.
-    #     for group_index, group_range in enumerate(group_ranges):
-    #         group_range["orig_group"] = param_groups[group_index]
-    #     group_ranges = [ g for g in group_ranges if g["size"] > 0 ]
-    #     return group_ranges
     @classmethod
     def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
@@ -291,6 +254,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             shard_fp32_from_float16_groups,
         )
     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                  fp16, bf16, grad_scaler, models):
@@ -302,11 +266,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # Verify that contiguous buffers are being used
         # - Note: this should already be checked in arguments.py
-        # >>>
-        # args = get_args()
-        # assert args.use_contiguous_buffers_in_local_ddp
         assert use_contiguous_buffers_in_local_ddp
-        # <<<
         # Model grad buffer ranges.
         self.model_gbuf_ranges = []
@@ -331,12 +291,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             self.model_param_gbuf_map,
             self.opt_group_ranges)
-        # print_seq("16 [%d], 16x32 [%d], 32 [%d]." % (
-        #     sum(len(g) for g in self.float16_groups),
-        #     sum(len(g) for g in self.fp32_from_float16_groups),
-        #     sum(len(g) for g in self.fp32_groups),
-        # ))
         # Update optimizer groups.
         # - Also, leverage state_dict() and load_state_dict() to
         #   recast preexisting per-param state tensors.
@@ -344,34 +298,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             [ g["orig_group"] for g in self.opt_group_ranges ]
         self.optimizer.load_state_dict(self.optimizer.state_dict())
-        # >>>
-        # # Initialize main params.
-        # self._copy_model_params_to_main_params()
-        # <<<
-        # >>>
-        # # Params for grad norm.
-        # self.main_grad_views_for_grad_norm = self.build_main_grad_views_for_grad_norm(
-        #     self.opt_group_ranges,
-        #     self.optimizer)
-        # <<<
     def get_model_param_range_map(self, param):
         model_index, dtype = self.model_param_gbuf_map[param]
         gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
         param_range_map = gbuf_range_map["param_map"][param]
-        # >>>
-        # pax(0, {
-        #     "param" : param,
-        #     "model_index" : model_index,
-        #     "dtype" : str(dtype),
-        #     "gbuf_range_map" : gbuf_range_map,
-        #     "param_range_map" : param_range_map,
-        # })
-        # <<<
         return param_range_map
@@ -379,28 +310,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         return None
-    # def get_main_params(self):
-    #     return [ g["params"][0] for g in self.optimizer.param_groups ]
-    # def get_main_grads(self):
-    #     return [ p.grad for p in self.get_main_params() ]
-    # def get_main_param(self, group_index):
-    #     return self.get_main_params()[group_index]
-    # def get_main_grad(self, group_index):
-    #     return self.get_main_param(group_index).grad
     # >>>
-    # def get_main_grads_for_grad_norm(self):
-    #     return self.main_grad_views_for_grad_norm
-    # def get_main_grads_for_grad_norm(self):
-    #     raise Exception("....... use 'super' .......")
-    #     grads_for_norm = super().get_main_grads_for_grad_norm()
-    #     if torch.distributed.get_rank() == 1:
-    #         print_seq([ tp(g) for g in grads_for_norm ])
-    #     return grads_for_norm
-    # <<<
     # def state_dict(self):
     #     state_dict = {}
     #     state_dict['optimizer'] = self.optimizer.state_dict()
@@ -410,8 +320,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     #     return state_dict
     def state_dict(self):
         raise Exception("fix me.")
-    # <<<
-    # >>>
     # def load_state_dict(self, state_dict):
     #     # Optimizer.
     #     optimizer_key = 'optimizer'
@@ -441,20 +353,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     #             current_param.data.copy_(saved_param.data)
     def load_state_dict(self, state_dict):
         raise Exception("hi.")
-    # <<<
-    # >>>
-    # def zero_grad(self, set_to_none=True):
-    #     # Collect model params.
-    #     model_params = []
-    #     for model in self.models:
-    #         for dtype, param_map in model._grad_buffer_param_index_map.items():
-    #             model_params.extend(param_map.keys())
-    #     # Distributed optimizer requires contiguous buffer; don't set to None.
-    #     _zero_grad_group_helper(model_params, set_to_none = False)
-    # def zero_grad(self, set_to_none=True):
-    #     raise Exception("does 'super' work?")
     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
         float16_groups & fp32_groups. We additionally zero
@@ -469,7 +370,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 self.shard_fp32_from_float16_groups):
             for group in groups:
                 _zero_grad_group_helper(group, set_to_none)
-        # <<<
     def get_model_grad_buffer_dp_views(self):
@@ -489,6 +389,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         return gbuf_view_items
     def reduce_model_grads(self, args, timers):
         '''Note: this is a different order of reduction, versus the non-
         distributed optimizer, which reduces: 1) all grads, 2) embedding
@@ -522,6 +423,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         timers('backward-params-all-reduce').stop()
     def gather_model_params(self, args, timers):
         timers('backward-params-all-gather').start()
@@ -552,55 +454,27 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         timers('backward-params-all-gather').stop()
-    # >>>
-    # def _collect_main_grad_data_for_unscaling(self):
-    #     return [ g.data for g in self.get_main_grads() ]
     def _collect_main_grad_data_for_unscaling(self):
-        main_grad_data = [
+        return [
             param.grad.data
             for group in self.optimizer.param_groups
            for param in group["params"]
         ]
-        # print_seq([ tp(g) for g in main_grad_data ])
-        return main_grad_data
-    # <<<
-    # >>>
-    # def _copy_model_params_to_main_params(self):
-    #     for group_index, group_range in enumerate(self.opt_group_ranges):
-    #         main_param = self.get_main_param(group_index)
-    #         for model_param, main_range in group_range["param_map"].items():
-    #             # Model range.
-    #             # model_index, dtype = self.param_gbuf_map[model_param]
-    #             # model_range = self.model_gbuf_ranges \
-    #             #     [model_index][dtype]["param_map"][model_param]["param"]
-    #             model_range = self.get_model_param_range_map(model_param)["param"]
-    #             assert main_range.size == model_range.size
-    #             # Copy shard data.
-    #             main_view = main_param[main_range.start:main_range.end]
-    #             model_view = model_param.view(-1)[model_range.start:model_range.end]
-    #             main_view.detach().copy_(model_view)
-    def _copy_model_params_to_main_params(self):
-        raise Exception("check if super's copy works.")
-    # <<<
-    # >>>
+    def _get_model_and_main_params_data_float16(self):
+        model_data = []
+        main_data = []
+        for model_group, main_group in zip(self.shard_float16_groups,
+                                           self.shard_fp32_from_float16_groups):
+            for model_param, main_param in zip(model_group, main_group):
+                model_data.append(model_param.data)
+                main_data.append(main_param.data)
+        return model_data, main_data
     def _copy_model_grads_to_main_grads(self):
-        # >>>
-        # print_seq([
-        #     "grad = %s." % tp(p.grad)
-        #     for g in self.optimizer.param_groups
-        #     for p in g["params"]
-        # ])
-        # <<<
         def copy_group_grads(full_model_groups, shard_main_groups):
             for full_model_group, shard_main_group in zip(full_model_groups,
                                                           shard_main_groups):
@@ -616,28 +490,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                     [param_range.start:param_range.end]
                 shard_main_param.grad = shard_model_grad.float()
-        # print_seq([ "%s / %d, [%d] %s" % (
-        #     k, i, len(g), ", ".join(str(p.nelement()) for p in g),
-        # ) for k, gs in [
-        #     ("model", self.full_float16_groups),
-        #     ("main", self.shard_fp32_from_float16_groups),
-        # ] for i, g in enumerate(gs)])
         copy_group_grads(self.full_float16_groups,
                          self.shard_fp32_from_float16_groups)
         copy_group_grads(self.full_fp32_groups,
                          self.shard_fp32_groups)
-        # >>>
-        # print_seq([
-        #     "grad = %s." % tp(p.grad)
-        #     for g in self.optimizer.param_groups
-        #     for p in g["params"]
-        # ])
-        # <<<
-    # <<<
-    # >>>
     def _copy_main_params_to_model_params(self):
         def copy_group_params(shard_main_groups, full_model_groups):
@@ -660,12 +518,3 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                           self.full_float16_groups)
         copy_group_params(self.shard_fp32_groups,
                           self.full_fp32_groups)
-        # >>>
-        # print_seq([
-        #     "grad = %s." % tp(p.grad)
-        #     for g in self.optimizer.param_groups
-        #     for p in g["params"]
-        # ])
-        # <<<
-    # <<<
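
The code that survives this cleanup revolves around one idea: each parameter owns a Range (a half-open [start, end) slice, with size = end - start) inside a flat, contiguous grad buffer, and helpers such as copy_group_grads pull only that slice out of the buffer and upcast it to an fp32 main grad. The following is a minimal, self-contained sketch of that slicing pattern, not code from this commit; the flat buffer layout, the parameter names, and the copy_shard_grad helper are illustrative assumptions rather than Megatron APIs.

import torch

class Range:
    """Half-open index range [start, end) into a flat buffer."""
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start

# Hypothetical flat fp16 grad buffer holding two parameters back to back.
grad_buffer = torch.zeros(10, dtype=torch.float16)
param_ranges = {"weight": Range(0, 6), "bias": Range(6, 10)}

def copy_shard_grad(buffer, prange):
    # Same pattern as copy_group_grads in the diff: slice this rank's
    # [start:end) shard out of the flat buffer and upcast it to fp32.
    return buffer[prange.start:prange.end].float()

main_grads = {name: copy_shard_grad(grad_buffer, r)
              for name, r in param_ranges.items()}
assert main_grads["weight"].dtype == torch.float32
assert main_grads["weight"].numel() == param_ranges["weight"].size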