Commit ae6cd9b5 authored by Lawrence McAfee

some cleanup

parent e5db0fda
@@ -233,6 +233,10 @@ class BaseFloat16Optimizer(MegatronOptimizer):
return self.grad_scaler.scale
def reload_model_params(self):
self._copy_model_params_to_main_params()
def _unscale_main_grads_and_check_for_nan(self):
# Collect main grads.
@@ -269,6 +273,11 @@ class BaseFloat16Optimizer(MegatronOptimizer):
self._copy_model_grads_to_main_grads()
timers('optimizer-copy-to-main-grad').stop()
# pax(0, {
# "params" : self.get_parameters(), # self.main_param_shards,
# "grads" : [ p.grad for p in self.get_parameters() ], # self.main_param_shards ],
# })
# Do unscale, check for inf, and update grad scaler only for
# the case that grad scaler is provided.
if self.grad_scaler:
@@ -284,6 +293,7 @@ class BaseFloat16Optimizer(MegatronOptimizer):
# If we found inf/nan, skip the update.
if found_inf_flag:
pax(0, {"found_inf_flag": found_inf_flag})
return False, None, None
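A minimal sketch (not part of this commit) of what the unscale-and-check step referenced above amounts to, using plain tensor ops; the helper name and arguments are illustrative, and the real implementation also all-reduces the found-inf flag across model-parallel ranks before deciding to skip the update.

import torch

def unscale_main_grads_and_check_for_nan_sketch(main_grads, loss_scale):
    # Undo the loss scaling applied during the backward pass.
    inv_scale = 1.0 / loss_scale
    found_inf = False
    for grad in main_grads:
        grad.mul_(inv_scale)
        # Any inf/nan in any grad means the whole update must be skipped.
        if not torch.isfinite(grad).all():
            found_inf = True
    return found_inf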
# Clip the main gradients.
@@ -301,12 +311,16 @@ class BaseFloat16Optimizer(MegatronOptimizer):
self.optimizer.step()
# >>>
# from lutil import pax, tp
# pax(0, {
# "optimizer / state" :
# { hash(k):tp(v) for k,v in self.optimizer.state.items() },
# "optimizer / state / len" : len(self.optimizer.state),
# "optimizer / state / 0" : list(self.optimizer.state.values())[0],
# # "optimizer / state" :
# # { hash(k):tp(v) for k,v in self.optimizer.state.items() },
# # "optimizer / state / len" : len(self.optimizer.state),
# # "optimizer / state / 0" : list(self.optimizer.state.values())[0],
# **{"optimizer / state / %s" % hash(k) : tp(v) for k, v in self.optimizer.state.items() },
# "params" : sum(
# s["exp_avg"].numel()
# for s in self.optimizer.state.values()
# ),
# })
# <<<
@@ -536,8 +550,7 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
timers('backward-embedding-all-reduce').stop()
def gather_params(self):
raise Exception("hi.")
pass
def _copy_model_grads_to_main_grads(self):
# This only needs to be done for the float16 group.
@@ -621,10 +634,6 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
overflow_buf=self._dummy_overflow_buf)
def reload_model_params(self):
self._copy_model_params_to_main_params()
def state_dict(self):
state_dict = {}
state_dict['optimizer'] = self.optimizer.state_dict()
@@ -669,13 +678,7 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
# >>>
import math
# from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
# from megatron import get_timers
# from megatron.model import DistributedDataParallel as LocalDDP
# from megatron.model import Float16Module
# from megatron.utils import unwrap_model
# class ShardIndex:
class Shard:
@@ -726,230 +729,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
})
# <<<
# def __init__(self, *_args):
# super().__init__(*_args)
# def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
# params_have_main_grad, use_contiguous_buffers_in_local_ddp,
# bf16, grad_scaler):
# super().__init__(
# optimizer, clip_grad, log_num_zeros_in_grad,
# params_have_main_grad, use_contiguous_buffers_in_local_ddp,
# bf16, grad_scaler)
# # >>>
# # self.test_reduce_scatter()
# # <<<
# # >>>
# args = get_args()
# # <<<
# # Data parallel info.
# self.data_parallel_group = mpu.get_data_parallel_group()
# self.data_parallel_rank = mpu.get_data_parallel_rank()
# self.data_parallel_world_size = mpu.get_data_parallel_world_size()
# # Total trainable param count.
# # self.total_param_size = sum(
# # p.numel()
# # for g in self.param_groups
# # for p in g["params"]
# # # if p .requires_grad ???
# # )
# # Model params: group sizes, group offset maps.
# # self.model_params = []
# # self.model_param_group_sizes = []
# # self.model_param_group_offset_maps = []
# self.model_param_groups = []
# for param_group in self.optimizer.param_groups:
# param_group_offset = 0
# param_group_offset_map = {}
# for param in param_group['params']:
# if not param.requires_grad:
# continue
# # self.model_params.append(param)
# param_group_offset_map[param] = {
# "start" : param_group_offset,
# "end" : param_group_offset + param.numel(),
# }
# param_group_offset += param.numel()
# # self.model_param_group_sizes.append(param_group_offset)
# # self.model_param_group_offset_maps.append(param_group_offset_map)
# self.model_param_groups.append({
# "size" : param_group_offset,
# "offset_map" : param_group_offset_map,
# })
# # pax(0, {
# # "model_params" : model_params,
# # "model_param_group_sizes" : model_param_group_sizes,
# # "model_param_group_offset_maps" : model_param_group_offset_maps,
# # })
# # Shard allocator.
# # ** torch.nn.Parameter ??
# # ** MemoryBuffer ??
# allocate_shard = lambda shard_size, dtype : torch.empty(
# (shard_size,),
# dtype = dtype,
# device = torch.cuda.current_device(),
# requires_grad = True)
# # Allocate shards.
# # (Also, collect world DP shard info.)
# # model_main_dtypes = set([ args.params_dtype, torch.float ])
# model_main_dtypes = set([ torch.float ]) # fp32 only, for now
# # self.world_shard_info_groups = [] # world_group_shard_infos ?
# # self.main_param_shard_groups = []
# self.world_shard_infos = [{"groups": []} for _ in self.model_param_groups]
# for group_index, model_param_group in enumerate(self.model_param_groups):
# # Max world shard size.
# model_param_size = model_param_group["size"]
# max_world_shard_size = int(math.ceil(model_param_size /
# self.data_parallel_world_size))
# # DP world shard infos.
# # world_shard_infos = []
# for r in range(self.data_parallel_world_size):
# shard_start_index = r * max_world_shard_size
# shard_end_index = min(model_param_size,
# shard_start_index + max_world_shard_size)
# # world_shard_infos.append({
# self.world_shard_infos[r]["groups"].append({
# "start" : shard_start_index,
# "end" : shard_end_index,
# "size" : shard_end_index - shard_start_index,
# })
# # self.world_shard_info_groups.append(world_shard_infos)
# # self.world_shard_infos[group_index].append(world_shard_infos)
# # DP local rank's shard info.
# # local_shard_info = world_shard_infos[self.data_parallel_rank]
# local_shard_info = \
# self.world_shard_infos[self.data_parallel_rank]["groups"][-1]
# local_shard_start_index = local_shard_info["start"]
# local_shard_end_index = local_shard_info["end"]
# local_shard_size = local_shard_info["size"]
# # Local shard's param 'slice' index map.
# local_shard_info["param_slice_index_map"] = {}
# for param, offset_dict in model_param_group["offset_map"].items():
# # param_start_index = offset_dict["start"]
# # param_end_index = offset_dict["end"]
# # param_shard_start_index = max(local_shard_start_index,
# # param_start_index)
# # param_shard_end_index = min(local_shard_end_index,
# # param_end_index)
# orig_start_index = offset_dict["start"]
# orig_end_index = offset_dict["end"]
# shard_start_index = max(
# 0,
# orig_start_index - local_shard_start_index)
# shard_end_index = min(
# local_shard_end_index,
# orig_end_index - local_shard_start_index)
# # if param_shard_end_index > param_shard_start_index:
# # # Indexes are relative to local shard start index.
# # # local_shard_info["param_index_map"][param] = {
# # # "param" : (
# # # param_shard_start_index,
# # # param_shard_end_index,
# # # ),
# # # "shard" : (
# # # param_shard_start_index - local_shard_start_index,
# # # param_shard_end_index - local_shard_start_index,
# # # ),
# # # }
# # local_shard_info["param_slice_index_map"][param] = {
# # "param_start" :
# # param_shard_start_index,
# # "shard_start" :
# # param_shard_start_index - local_shard_start_index,
# # "size":
# # param_shard_end_index - param_shard_start_index,
# # }
# if shard_end_index > shard_start_index:
# local_shard_info["param_slice_index_map"][param] = {
# "orig_start" : orig_start_index,
# "shard_start" : shard_start_index,
# "size" : shard_end_index - shard_start_index,
# }
# # pax(0, {
# # "local index" : "%d, %d" % (
# # local_shard_start_index,
# # local_shard_end_index,
# # ),
# # "param index" : "%s, %d" % (
# # param_start_index,
# # param_end_index,
# # ),
# # "param" : tp(param),
# # "shard_param_index_map" : shard_param_index_map,
# # "local_shard_info" : local_shard_info,
# # })
# # pax(2, {
# # "data_parallel_rank" : self.data_parallel_rank,
# # "local_shard_info" : local_shard_info,
# # "param_index_map " : [
# # (str(p.shape), i)
# # for p, i in local_shard_info["param_index_map"].items()
# # ],
# # })
# # Allocate shards.
# # (Non-fp32 shards are for convenience; e.g., intermediaries
# # between model params and main fp32 shard. Necessary???)
# # main_param_shards = {
# # ty : allocate_shard(local_shard_size, ty)
# # for ty in model_main_dtypes}
# main_param_shards = {}
# for dtype in model_main_dtypes:
# main_param = allocate_shard(local_shard_size, dtype)
# main_param.grad = allocate_shard(local_shard_size, dtype)
# # pax(0, {"main_param": main_param})
# main_param_shards[dtype] = main_param
# # self.main_param_shard_groups.append(main_param_shards)
# local_shard_info["data"] = main_param_shards
# # Update optimizer group.
# self.optimizer.param_groups[group_index]["params"] = \
# [ main_param_shards[torch.float] ]
# # pax(0, {
# # "param_groups" : self.optimizer.param_groups,
# # "params" : self.optimizer.param_groups[group_index]["params"],
# # })
# # Add world start/end indexes, for reduce/gather steps.
# offset = 0
# for r in self.world_shard_infos:
# r["start_index"] = offset
# offset += sum(g["size"] for g in r["groups"])
# r["end_index"] = offset
# # Leverage state_dict() and load_state_dict() to
# # recast preexisting per-param state tensors
# self.optimizer.load_state_dict(self.optimizer.state_dict())
# # >>>
# # pax(0, {
# # "world_shard_infos" : self.world_shard_infos,
# # **{
# # "world_shard_infos / %d" % i : r
# # for i, r in enumerate(self.world_shard_infos)
# # },
# # })
# # <<<
@classmethod
# def get_ddp_gbuf_param_shards(cls, model, dtype, gbuf_start):
# def get_ddp_gbuf_param_shard_map(cls, model, dtype, gbuf_start):
# def get_model_gbuf_param_shard_index_map(cls,model,dtype,gbuf_world_index):
def get_model_gbuf_param_shard_map(cls, model, dtype, gbuf_world_shard):
# Param shard map.
@@ -980,9 +760,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
return param_shard_map
@classmethod
# def get_ddp_gbuf_shard(cls, model, dtype):
# def get_model_gbuf_shard(cls, model, dtype):
# def get_model_gbuf_shard_index(cls, model, dtype):
def get_model_gbuf_shard(cls, model, dtype):
data_parallel_rank = mpu.get_data_parallel_rank()
@@ -1001,7 +778,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
gbuf_world_all_shards.append(gbuf_world_shard)
gbuf_world_shard = gbuf_world_all_shards[data_parallel_rank]
gbuf_local_shard = gbuf_world_shard.normalize()
# gbuf_local_shard = Shard(0, gbuf_world_index.size)
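The body of the Shard class declared earlier is collapsed in this diff; a minimal, hypothetical version consistent with how it is used here (a [start, end) index range with a size, plus normalize() to re-base the range at zero) might look like:

class Shard:
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start
    def normalize(self, start=0):
        # Re-base the range at `start` (default 0), keeping its size.
        return Shard(start, start + self.size)
    def __str__(self):
        return "%d,%d [%d]" % (self.start, self.end, self.size)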
# Param shards.
param_shard_map = cls.get_model_gbuf_param_shard_map(model,
@@ -1021,10 +797,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
return data
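A worked example (not from the commit) of the ceil-based partitioning that get_model_gbuf_shard performs per grad buffer; the helper name is illustrative:

import math

def partition_gbuf(gbuf_size, data_parallel_world_size):
    # Each rank gets a contiguous shard of ceil(size / world_size) elements,
    # except the last rank, whose shard is clamped to the buffer size.
    max_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size))
    shards = []
    for rank in range(data_parallel_world_size):
        start = rank * max_shard_size
        end = min(gbuf_size, start + max_shard_size)
        shards.append((start, end))
    return shards

# A 10-element buffer across 4 data-parallel ranks:
assert partition_gbuf(10, 4) == [(0, 3), (3, 6), (6, 9), (9, 10)]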
@classmethod
# def get_ddp_gbuf_shards(cls, model):
# def get_ddp_gbuf_shard_map(cls, model):
# def get_model_gbuf_shard_map(cls, model):
# def get_model_gbuf_shard_index_map(cls, model):
def get_model_gbuf_shard_map(cls, model):
# shard_index_map = {
@@ -1038,11 +810,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
return shard_map
@classmethod
# def get_param_size_map(cls, model_gbuf_shards):
# def get_param_model_gbuf_map(cls, model_gbuf_shards):
def get_param_gbuf_map(cls, model_gbuf_shards):
# param_size_map = {}
param_gbuf_map = {}
for model_index, model_gbuf_shard_map in enumerate(model_gbuf_shards):
for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
@@ -1064,7 +833,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# "param_gbuf_map" : param_gbuf_map,
# })
# return param_size_map
return param_gbuf_map
@classmethod
@@ -1120,8 +888,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
assert args.use_contiguous_buffers_in_local_ddp # already checked in args
# <<<
# pax(1, {"models": models})
# # Data parallel info.
# self.data_parallel_group = mpu.get_data_parallel_group()
# self.data_parallel_rank = mpu.get_data_parallel_rank()
@@ -1138,8 +904,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
self.optimizer.param_groups,
self.model_gbuf_shards)
# pax(0, {"opt_group_shards": self.opt_group_shards})
# Allocate main param/grad shard.
# ** torch.nn.Parameter ??
# ** MemoryBuffer ??
@@ -1165,6 +929,9 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# Update optimizer group.
self.optimizer.param_groups[group_index]["params"] = [ main_param ]
# Initialize main params.
self._copy_model_params_to_main_params()
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
self.optimizer.load_state_dict(self.optimizer.state_dict())
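A minimal sketch (not part of this commit) of why this state_dict()/load_state_dict() round trip works: PyTorch's Optimizer.load_state_dict() re-keys saved per-param state onto whatever params currently sit in param_groups, casting floating-point state tensors to each param's dtype and device, so swapping in the fp32 main shards and round-tripping leaves exp_avg/exp_avg_sq attached to the new params.

import torch

param = torch.nn.Parameter(torch.zeros(8))
opt = torch.optim.Adam([param], lr=1e-3)
param.grad = torch.ones_like(param)
opt.step()                                    # creates exp_avg / exp_avg_sq for `param`

main_shard = torch.nn.Parameter(param.detach().clone())
opt.param_groups[0]["params"] = [main_shard]  # swap in the new main param
opt.load_state_dict(opt.state_dict())         # state is re-keyed onto `main_shard`

assert main_shard in opt.state and param not in opt.state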
@@ -1177,11 +944,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# })
# <<<
# def get_loss_scale(self):
# if self.grad_scaler is None:
# return self._scale_one
# return self.grad_scaler.scale
def load_state_dict(self):
raise Exception("hi.")
def reload_model_params(self):
@@ -1189,21 +951,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
def state_dict(self):
raise Exception("hi.")
# def zero_grad(self, set_to_none=True):
# params = []
# for model_param_group in self.model_param_groups:
# params.extend(model_param_group["offset_map"].keys())
# for main_group in self.optimizer.param_groups:
# params.extend(main_group["params"])
# # _zero_grad_group_helper(params, set_to_none)
# _zero_grad_group_helper(params, set_to_none = False)
# # pax(0, {
# # "model_param_groups" : self.model_param_groups,
# # "params" : params,
# # })
def zero_grad(self, set_to_none=True):
model_params = []
@@ -1219,110 +966,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# pax(0, {"params": params})
# def reduce_gradients(self, model):
# # >>>
# # pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
# # <<<
# # >>>
# args = get_args()
# # timers = get_timers()
# # <<<
# # >>> [ temporary requirement ... and already checked in arguments.py ]
# assert args.use_contiguous_buffers_in_local_ddp
# # <<<
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# # Map param to virtual model.
# # ** ideally, this should happen once, during construction.
# param_model_map = {}
# for vmodel in model:
# for dtype, param_index_map in \
# vmodel._grad_buffer_param_index_map.items():
# for param in param_index_map:
# param_model_map[param] = {
# "dtype" : dtype,
# "model" : vmodel,
# }
# # pax(0, {
# # "param_model_map" : [
# # (str(tuple(p.shape)), m)
# # for p, m in param_model_map.items()
# # ],
# # })
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# # Copy model grads to main shard.
# local_shard_info_groups = [g[self.data_parallel_rank]
# for g in self.world_shard_info_groups]
# for group_index, local_shard_info in enumerate(local_shard_info_groups):
# # model_param_index_map =
# # shard_param_index_map = local_shard_info["param_index_map"]
# # main_index_map = local_shard_info["param_index_map"]
# main_slice_index_map = local_shard_info["param_slice_index_map"]
# for param, main_slice_indexes in main_slice_index_map.items():
# main_slice_orig_start_index = main_slice_indexes["orig_start"]
# main_slice_shard_start_index = main_slice_indexes["shard_start"]
# main_slice_size = main_slice_indexes["size"]
# dtype_model_dict = param_model_map[param]
# dtype = dtype_model_dict["dtype"]
# vmodel = dtype_model_dict["model"]
# model_grad_buffer = vmodel._grad_buffers[dtype].data
# model_grad_buffer_start_index = \
# vmodel._grad_buffer_param_index_map[dtype][param][0] + \
# main_slice_orig_start_index
# main_grad_view = local_shard_info["data"][torch.float].grad[
# main_slice_shard_start_index:
# main_slice_shard_start_index + main_slice_size
# ]
# model_grad_view = model_grad_buffer[
# model_grad_buffer_start_index:
# model_grad_buffer_start_index + main_slice_size
# ]
# main_grad_view.detach().copy_(model_grad_view)
# # pax(0, {
# # # "local_shard_info" : local_shard_info,
# # "main_slice_orig_start_index" : main_slice_orig_start_index,
# # "main_slice_shard_start_index" : main_slice_shard_start_index,
# # "main_slice_size" : main_slice_size,
# # "model_grad_buffer_start_index" :
# # model_grad_buffer_start_index,
# # "main_grad_view" : tp(main_grad_view),
# # "main_grad_view / detach" : tp(main_grad_view.detach()),
# # "model_grad_view" : tp(model_grad_view),
# # })
# # pax(0, {
# # "group_index" : group_index,
# # "local_shard_info" : local_shard_info,
# # "shard_param_index_map" : shard_param_index_map,
# # "param" : tp(param),
# # "shard_indexes" : shard_indexes,
# # "grad_buffer_indexes" : grad_buffer_indexes,
# # })
# pax(0, {
# # "world_shard_info_groups" : self.world_shard_info_groups,
# # **{"world_shard_info_groups / %d" % i : v
# # for i, v in enumerate(self.world_shard_info_groups)},
# # "local_shard_info_groups" : local_shard_info_groups,
# "local_shard_info_groups" : [ g["data"] for g in local_shard_info_groups ],
# })
def get_model_grad_buffer_dp_views(self):
# >>>
# ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
args = get_args()
assert args.use_contiguous_buffers_in_local_ddp
# <<<
# Grad buffer views.
gbuf_view_items = []
@@ -1343,11 +993,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
def reduce_gradients(self, model):
# >>>
args = get_args()
# timers = get_timers()
# <<<
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Sync word embedding params.
@@ -1360,36 +1005,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Reduce-scatter.
# # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
# assert args.use_contiguous_buffers_in_local_ddp
# data_parallel_rank = mpu.get_data_parallel_rank()
# data_parallel_group = mpu.get_data_parallel_group()
# for model_index, model in enumerate(self.models):
# for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
# world_shards = gbuf_shard["world_all"]
# gbuf = model._grad_buffers[dtype]
# gbuf_views = []
# for shard in world_shards:
# gbuf_views.append(gbuf.data[shard.start:shard.end])
# torch.distributed.reduce_scatter(
# gbuf_views[data_parallel_rank],
# gbuf_views,
# group = data_parallel_group,
# )
# # pax(0, {
# # "model_index" : model_index,
# # "model" : model,
# # "dtype" : str(dtype),
# # "gbuf_shard" : gbuf_shard,
# # "world_shards" : world_shards,
# # "gbuf_views" : gbuf_views,
# # })
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_group = mpu.get_data_parallel_group()
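A minimal sketch (not part of this commit) of the reduce-scatter step over a contiguous grad buffer, assuming the buffer holds data_parallel_world_size equal-sized shards; the function name is illustrative, and the divide reflects the usual data-parallel gradient averaging convention:

import torch

def reduce_scatter_grad_buffer(gbuf, data_parallel_rank,
                               data_parallel_world_size, data_parallel_group):
    # One view per data-parallel rank over the contiguous buffer.
    shard_size = gbuf.numel() // data_parallel_world_size
    gbuf_views = [gbuf[r * shard_size:(r + 1) * shard_size]
                  for r in range(data_parallel_world_size)]
    # Average rather than sum across replicas.
    gbuf /= data_parallel_world_size
    # Each rank ends up with fully reduced gradients for its own shard only.
    torch.distributed.reduce_scatter(gbuf_views[data_parallel_rank],
                                     gbuf_views,
                                     group=data_parallel_group)
    return gbuf_views[data_parallel_rank]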
@@ -1411,6 +1026,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
gbuf_view_items = self.get_model_grad_buffer_dp_views()
# All-gather updated main params.
for model_index, dtype, gbuf_views in gbuf_view_items:
torch.distributed.all_gather(
gbuf_views,
@@ -1418,15 +1034,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
group = data_parallel_group,
)
# for param, (model_index, dtype) in self.param_gbuf_map.items():
# gbuf = self.model_gbuf_shards[model_index][dtype]
# pax(0, {
# "param" : tp(param),
# "model_index" : model_index,
# "dtype" : str(dtype),
# "gbuf" : gbuf,
# })
# Each model param now contains its updated values in its
# '.main_grad' field.
for param in self.param_gbuf_map:
param.detach().copy_(param.main_grad)
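A minimal sketch (not part of this commit) of the gather path shown above: the per-rank shards are all-gathered back into the reused grad buffers, after which each model param's `.main_grad` view holds its updated values and is copied into the fp16 param itself; names follow the attributes used in this file.

import torch

def gather_model_params_sketch(gbuf_view_items, param_gbuf_map,
                               data_parallel_rank, data_parallel_group):
    for model_index, dtype, gbuf_views in gbuf_view_items:
        # All ranks receive every shard; each rank contributes its own view.
        torch.distributed.all_gather(gbuf_views,
                                     gbuf_views[data_parallel_rank],
                                     group=data_parallel_group)
    # `.main_grad` aliases each param's slice of the contiguous buffer, which
    # now holds updated parameter values rather than gradients.
    for param in param_gbuf_map:
        param.detach().copy_(param.main_grad)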
# pax(0, {
@@ -1443,15 +1052,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# ],
# })
# def step(self):
# raise Exception("step.")
def _collect_main_grad_data_for_unscaling(self):
# pax(1, {
# "main_param_shards" : self.main_param_shards,
# "grads" : [ p.grad.data for p in self.main_param_shards ],
# })
return [ p.grad.data for p in self.main_param_shards ]
def _copy_model_grads_to_main_grads(self):
@@ -1465,11 +1066,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
assert main_shard.size == model_shard.size
# Copy from DDP's contiguous buffer to main shard's grad.
model_grad_tensor = \
self.models[model_index]._grad_buffers[gbuf_dtype].data
main_grad_tensor = \
self.main_param_shards[group_index].grad
# Copy sub-range within tensor.
model_grad_view = \
model_grad_tensor[model_shard.start:model_shard.end]
main_grad_view = \
@@ -1491,10 +1094,20 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# "main_shard" : str(main_shard),
# })
# >>>
# pax(0, {
# "model_gbuf_shards" : self.model_gbuf_shards,
# "opt_group_shards" : self.opt_group_shards,
# })
for param in self.main_param_shards:
grad = param.grad
is_nan = torch.any(torch.isnan(grad)).item()
if is_nan:
pax(0, {
"grad" : tp(grad),
"is_nan" : is_nan,
})
# <<<
def _copy_main_params_to_model_params(self):
@@ -1538,6 +1151,15 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# "model_gbuf_shards" : self.model_gbuf_shards,
# "opt_group_shards" : self.opt_group_shards,
# })
# >>>
for param in self.param_gbuf_map:
is_nan = torch.any(torch.isnan(param)).item()
if is_nan:
pax(0, {
"param" : tp(param),
"is_nan" : is_nan,
})
# <<<
# <<<
......