Commit cb6f96b6 authored by Lawrence McAfee

wip; switching to grad-buffer-centric design

parent a3f3c3ad
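The diff below starts moving the distributed optimizer from partitioning flattened parameter groups toward partitioning each DDP model chunk's contiguous grad buffer, so that every data-parallel rank owns one contiguous slice of each per-dtype buffer. As a reading aid only (not part of the commit), here is a minimal sketch of that per-rank shard arithmetic, assuming nothing beyond a buffer size and a data-parallel world size; the names are illustrative:

```python
# Sketch of the ceil-divide / clamp arithmetic the new optimizer code applies
# per grad buffer and per dtype; illustrative names, not the commit's code.
import math

def gbuf_shard_range(gbuf_numel, data_parallel_world_size, rank):
    """Return the (start, end) range of `rank`'s shard of a grad buffer."""
    max_shard_size = int(math.ceil(gbuf_numel / data_parallel_world_size))
    start = rank * max_shard_size
    end = min(gbuf_numel, start + max_shard_size)
    return start, end

# Example: 10 buffer elements over 4 ranks -> shard sizes 3, 3, 3, 1.
print([gbuf_shard_range(10, 4, r) for r in range(4)])
# [(0, 3), (3, 6), (6, 9), (9, 10)]
```

The same arithmetic appears both in the old per-param-group sharding being removed below and in the new `get_ddp_gbuf_shard` classmethod.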
@@ -130,9 +130,11 @@ def parse_args(extra_args_provider=None, defaults={},
                 args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
     if args.num_layers_per_virtual_pipeline_stage is not None:
-        assert args.pipeline_model_parallel_size > 2, \
-            'pipeline-model-parallel size should be greater than 2 with ' \
-            'interleaved schedule'
+        # >>> [ temporarily turning off ]
+        # assert args.pipeline_model_parallel_size > 2, \
+        #     'pipeline-model-parallel size should be greater than 2 with ' \
+        #     'interleaved schedule'
+        # <<<
         assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers is not divisible by number of layers per virtual ' \
            'pipeline stage'
...
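The hunk above temporarily disables the requirement that the interleaved schedule use a pipeline-model-parallel size greater than 2, while keeping the check that `--num-layers` divide evenly into virtual pipeline stages. For reference, a hedged sketch of that check and the per-rank chunk count it implies, following the usual Megatron-LM interleaving arithmetic (the helper name is illustrative):

```python
# Sketch of the retained divisibility check and the implied number of model
# chunks per pipeline rank; illustrative helper, not code from this commit.
def virtual_pipeline_chunks(num_layers, pipeline_model_parallel_size,
                            num_layers_per_virtual_pipeline_stage):
    assert num_layers % num_layers_per_virtual_pipeline_stage == 0, \
        'number of layers is not divisible by number of layers per virtual ' \
        'pipeline stage'
    layers_per_pipeline_rank = num_layers // pipeline_model_parallel_size
    return layers_per_pipeline_rank // num_layers_per_virtual_pipeline_stage

print(virtual_pipeline_chunks(24, 2, 3))  # 4 model chunks per pipeline rank
```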
@@ -97,11 +97,11 @@ def get_megatron_optimizer(model,
     # from lutil import pax
     # pax(0, {
     #     "model" : model,
-    #     "param_groups" : param_groups,
-    #     "param_groups / 0" : param_groups[0],
-    #     "param_groups / 0 / params" : param_groups[0]["params"],
-    #     "param_groups / 1" : param_groups[1],
-    #     "param_groups / 1 / params" : param_groups[1]["params"],
+    #     # "param_groups" : param_groups,
+    #     # "param_groups / 0" : param_groups[0],
+    #     # "param_groups / 0 / params" : param_groups[0]["params"],
+    #     # "param_groups / 1" : param_groups[1],
+    #     # "param_groups / 1 / params" : param_groups[1]["params"],
     # })
     # <<<
@@ -164,7 +164,8 @@ def get_megatron_optimizer(model,
             params_have_main_grad,
             args.use_contiguous_buffers_in_local_ddp,
             args.bf16,
-            grad_scaler)
+            grad_scaler,
+            model)
         # <<<

     # FP32.
...
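`get_megatron_optimizer` now also forwards the model (the DDP-wrapped chunks) into the float16 optimizer constructor, since the grad-buffer-centric optimizer below reads each chunk's per-dtype contiguous grad buffer instead of re-deriving offsets from parameter groups. A small sketch of that access pattern, using a stand-in `MemoryBuffer` rather than Megatron's class and assuming, as the optimizer.py hunks below do, that each wrapped chunk exposes `_grad_buffers` keyed by dtype:

```python
# Stand-in illustration of why the wrapped model chunks are now passed in:
# the optimizer wants each chunk's contiguous per-dtype grad buffer.
import torch

class MemoryBuffer:                      # stand-in, not Megatron's MemoryBuffer
    def __init__(self, numel, dtype):
        self.numel = numel
        self.data = torch.zeros(numel, dtype=dtype)

class FakeDDPChunk:                      # stand-in for a LocalDDP model chunk
    def __init__(self):
        self._grad_buffers = {torch.float: MemoryBuffer(16, torch.float)}

def summarize_grad_buffers(models):
    # One (chunk index, dtype, element count) triple per contiguous buffer.
    return [(i, dtype, buf.numel)
            for i, m in enumerate(models)
            for dtype, buf in m._grad_buffers.items()]

print(summarize_grad_buffers([FakeDDPChunk()]))  # [(0, torch.float32, 16)]
```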
@@ -184,12 +184,16 @@ class BaseFloat16Optimizer(MegatronOptimizer):
     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                 bf16, grad_scaler):
+                 bf16, grad_scaler,
+                 models):

         super().__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
             params_have_main_grad, use_contiguous_buffers_in_local_ddp)

+        # >>>
+        self.models = models
+        # <<<
         self.bf16 = bf16
         self.grad_scaler = grad_scaler

         # None grad scaler is only supported for bf16.
@@ -697,65 +701,338 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
     # def __init__(self, *_args):
     #     super().__init__(*_args)
+    # def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
+    #              params_have_main_grad, use_contiguous_buffers_in_local_ddp,
+    #              bf16, grad_scaler):
+    #     super().__init__(
+    #         optimizer, clip_grad, log_num_zeros_in_grad,
+    #         params_have_main_grad, use_contiguous_buffers_in_local_ddp,
+    #         bf16, grad_scaler)
+    #     # >>>
+    #     # self.test_reduce_scatter()
+    #     # <<<
+    #     # >>>
+    #     args = get_args()
+    #     # <<<
+    #     # Data parallel info.
+    #     self.data_parallel_group = mpu.get_data_parallel_group()
+    #     self.data_parallel_rank = mpu.get_data_parallel_rank()
+    #     self.data_parallel_world_size = mpu.get_data_parallel_world_size()
+    #     # Total trainable param count.
+    #     # self.total_param_size = sum(
+    #     #     p.numel()
+    #     #     for g in self.param_groups
+    #     #     for p in g["params"]
+    #     #     # if p .requires_grad ???
+    #     # )
+    #     # Model params: group sizes, group offset maps.
+    #     # self.model_params = []
+    #     # self.model_param_group_sizes = []
+    #     # self.model_param_group_offset_maps = []
+    #     self.model_param_groups = []
+    #     for param_group in self.optimizer.param_groups:
+    #         param_group_offset = 0
+    #         param_group_offset_map = {}
+    #         for param in param_group['params']:
+    #             if not param.requires_grad:
+    #                 continue
+    #             # self.model_params.append(param)
+    #             param_group_offset_map[param] = {
+    #                 "start" : param_group_offset,
+    #                 "end" : param_group_offset + param.numel(),
+    #             }
+    #             param_group_offset += param.numel()
+    #         # self.model_param_group_sizes.append(param_group_offset)
+    #         # self.model_param_group_offset_maps.append(param_group_offset_map)
+    #         self.model_param_groups.append({
+    #             "size" : param_group_offset,
+    #             "offset_map" : param_group_offset_map,
+    #         })
+    #     # pax(0, {
+    #     #     "model_params" : model_params,
+    #     #     "model_param_group_sizes" : model_param_group_sizes,
+    #     #     "model_param_group_offset_maps" : model_param_group_offset_maps,
+    #     # })
+    #     # Shard allocator.
+    #     # ** torch.nn.Parameter ??
+    #     # ** MemoryBuffer ??
+    #     allocate_shard = lambda shard_size, dtype : torch.empty(
+    #         (shard_size,),
+    #         dtype = dtype,
+    #         device = torch.cuda.current_device(),
+    #         requires_grad = True)
+    #     # Allocate shards.
+    #     # (Also, collect world DP shard info.)
+    #     # model_main_dtypes = set([ args.params_dtype, torch.float ])
+    #     model_main_dtypes = set([ torch.float ]) # fp32 only, for now
+    #     # self.world_shard_info_groups = [] # world_group_shard_infos ?
+    #     # self.main_param_shard_groups = []
+    #     self.world_shard_infos = [{"groups": []} for _ in self.model_param_groups]
+    #     for group_index, model_param_group in enumerate(self.model_param_groups):
+    #         # Max world shard size.
+    #         model_param_size = model_param_group["size"]
+    #         max_world_shard_size = int(math.ceil(model_param_size /
+    #                                              self.data_parallel_world_size))
+    #         # DP world shard infos.
+    #         # world_shard_infos = []
+    #         for r in range(self.data_parallel_world_size):
+    #             shard_start_index = r * max_world_shard_size
+    #             shard_end_index = min(model_param_size,
+    #                                   shard_start_index + max_world_shard_size)
+    #             # world_shard_infos.append({
+    #             self.world_shard_infos[r]["groups"].append({
+    #                 "start" : shard_start_index,
+    #                 "end" : shard_end_index,
+    #                 "size" : shard_end_index - shard_start_index,
+    #             })
+    #         # self.world_shard_info_groups.append(world_shard_infos)
+    #         # self.world_shard_infos[group_index].append(world_shard_infos)
+    #         # DP local rank's shard info.
+    #         # local_shard_info = world_shard_infos[self.data_parallel_rank]
+    #         local_shard_info = \
+    #             self.world_shard_infos[self.data_parallel_rank]["groups"][-1]
+    #         local_shard_start_index = local_shard_info["start"]
+    #         local_shard_end_index = local_shard_info["end"]
+    #         local_shard_size = local_shard_info["size"]
+    #         # Local shard's param 'slice' index map.
+    #         local_shard_info["param_slice_index_map"] = {}
+    #         for param, offset_dict in model_param_group["offset_map"].items():
+    #             # param_start_index = offset_dict["start"]
+    #             # param_end_index = offset_dict["end"]
+    #             # param_shard_start_index = max(local_shard_start_index,
+    #             #                               param_start_index)
+    #             # param_shard_end_index = min(local_shard_end_index,
+    #             #                             param_end_index)
+    #             orig_start_index = offset_dict["start"]
+    #             orig_end_index = offset_dict["end"]
+    #             shard_start_index = max(
+    #                 0,
+    #                 orig_start_index - local_shard_start_index)
+    #             shard_end_index = min(
+    #                 local_shard_end_index,
+    #                 orig_end_index - local_shard_start_index)
+    #             # if param_shard_end_index > param_shard_start_index:
+    #             #     # Indexes are relative to local shard start index.
+    #             #     # local_shard_info["param_index_map"][param] = {
+    #             #     #     "param" : (
+    #             #     #         param_shard_start_index,
+    #             #     #         param_shard_end_index,
+    #             #     #     ),
+    #             #     #     "shard" : (
+    #             #     #         param_shard_start_index - local_shard_start_index,
+    #             #     #         param_shard_end_index - local_shard_start_index,
+    #             #     #     ),
+    #             #     # }
+    #             #     local_shard_info["param_slice_index_map"][param] = {
+    #             #         "param_start" :
+    #             #             param_shard_start_index,
+    #             #         "shard_start" :
+    #             #             param_shard_start_index - local_shard_start_index,
+    #             #         "size":
+    #             #             param_shard_end_index - param_shard_start_index,
+    #             #     }
+    #             if shard_end_index > shard_start_index:
+    #                 local_shard_info["param_slice_index_map"][param] = {
+    #                     "orig_start" : orig_start_index,
+    #                     "shard_start" : shard_start_index,
+    #                     "size" : shard_end_index - shard_start_index,
+    #                 }
+    #             # pax(0, {
+    #             #     "local index" : "%d, %d" % (
+    #             #         local_shard_start_index,
+    #             #         local_shard_end_index,
+    #             #     ),
+    #             #     "param index" : "%s, %d" % (
+    #             #         param_start_index,
+    #             #         param_end_index,
+    #             #     ),
+    #             #     "param" : tp(param),
+    #             #     "shard_param_index_map" : shard_param_index_map,
+    #             #     "local_shard_info" : local_shard_info,
+    #             # })
+    #         # pax(2, {
+    #         #     "data_parallel_rank" : self.data_parallel_rank,
+    #         #     "local_shard_info" : local_shard_info,
+    #         #     "param_index_map " : [
+    #         #         (str(p.shape), i)
+    #         #         for p, i in local_shard_info["param_index_map"].items()
+    #         #     ],
+    #         # })
+    #         # Allocate shards.
+    #         # (Non-fp32 shards are for convenience; e.g., intermediaries
+    #         #  between model params and main fp32 shard. Necessary???)
+    #         # main_param_shards = {
+    #         #     ty : allocate_shard(local_shard_size, ty)
+    #         #     for ty in model_main_dtypes}
+    #         main_param_shards = {}
+    #         for dtype in model_main_dtypes:
+    #             main_param = allocate_shard(local_shard_size, dtype)
+    #             main_param.grad = allocate_shard(local_shard_size, dtype)
+    #             # pax(0, {"main_param": main_param})
+    #             main_param_shards[dtype] = main_param
+    #         # self.main_param_shard_groups.append(main_param_shards)
+    #         local_shard_info["data"] = main_param_shards
+    #         # Update optimizer group.
+    #         self.optimizer.param_groups[group_index]["params"] = \
+    #             [ main_param_shards[torch.float] ]
+    #         # pax(0, {
+    #         #     "param_groups" : self.optimizer.param_groups,
+    #         #     "params" : self.optimizer.param_groups[group_index]["params"],
+    #         # })
+    #     # Add world start/end indexes, for reduce/gather steps.
+    #     offset = 0
+    #     for r in self.world_shard_infos:
+    #         r["start_index"] = offset
+    #         offset += sum(g["size"] for g in r["groups"])
+    #         r["end_index"] = offset
+    #     # Leverage state_dict() and load_state_dict() to
+    #     # recast preexisting per-param state tensors
+    #     self.optimizer.load_state_dict(self.optimizer.state_dict())
+    #     # >>>
+    #     # pax(0, {
+    #     #     "world_shard_infos" : self.world_shard_infos,
+    #     #     **{
+    #     #         "world_shard_infos / %d" % i : r
+    #     #         for i, r in enumerate(self.world_shard_infos)
+    #     #     },
+    #     # })
+    #     # <<<
+
+    @classmethod
+    # def get_ddp_gbuf_param_shards(cls, model, dtype, gbuf_start):
+    def get_ddp_gbuf_param_shard_map(cls, model, dtype, gbuf_start):
+
+        param_shard_map = {}
+        for param, indexes in \
+            model._grad_buffer_param_index_map[dtype].items():
+
+            param_gbuf_start, param_gbuf_end = indexes
+            param_shard_start = max(
+                0,
+                param_gbuf_start - shard_start)
+            param_shard_end = min(
+                shard_end,
+                param_gbuf_end - shard_start)
+
+            if param_shard_end > param_shard_start:
+                dtype_info["grad_buffer_param_shards"][param] = {
+                    "gbuf_start" : param_gbuf_start,
+                    "shard_start" : param_shard_start,
+                    "size" : param_shard_end - param_shard_start,
+                }
+
+            # pax(0, {
+            #     "param" : param,
+            #     "indexes" : indexes,
+            #     "param_gbuf_start" : param_gbuf_start,
+            #     "param_gbuf_end" : param_gbuf_end,
+            #     "param_shard_start" : param_shard_start,
+            #     "param_shard_end" : param_shard_end,
+            # })
+
+        pax(0, {"param_shard_map": param_shard_map})
+
+        return param_shard_map
+
+    @classmethod
+    def get_ddp_gbuf_shard(cls, model, dtype):
+
+        # Per-dtype info.
+        dtype_info = {}
+        model_info[dtype] = dtype_info
+
+        # Grad buffer shard.
+        model_param_size = grad_buffer.numel
+        max_world_shard_size = int(math.ceil(
+            model_param_size / self.data_parallel_world_size))
+        shard_start = rank * max_world_shard_size
+        shard_end = min(model_param_size,
+                        shard_start + max_world_shard_size)
+        dtype_info["grad_buffer_shard"] = {
+            "start" : shard_start,
+            "end" : shard_end,
+            "size" : shard_end - shard_start,
+        }
+
+        # Grad buffer param shards.
+        dtype_info["grad_buffer_param_shards"] = self.get_ddp_gbuf_param_shards()
+
+        pax(0, { "grad_buffer_param_shards" : [
+            str((str(tuple(p.shape)), i))
+            for p,i in dtype_info["grad_buffer_param_shards"].items()
+        ]})
+
+        return ddp_gbuf_shard
+
+    @classmethod
+    # def get_ddp_gbuf_shards(cls, model):
+    def get_ddp_gbuf_shard_map(cls, model):
+        shard_map = {
+            dtype : cls.get_ddp_gbuf_shard(model, dtype)
+            for dtype in model._grad_buffers
+        }
+        pax(0, {"shard_map": shard_map})
+        return shard_map
+
     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                 bf16, grad_scaler):
+                 bf16, grad_scaler, models):
         super().__init__(
             optimizer, clip_grad, log_num_zeros_in_grad,
             params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-            bf16, grad_scaler)
+            bf16, grad_scaler, models)

-        # >>>
-        # self.test_reduce_scatter()
-        # <<<
         # >>>
         args = get_args()
+        assert args.use_contiguous_buffers_in_local_ddp # already checked in args
         # <<<
+        # pax(0, {"models": models})

         # Data parallel info.
         self.data_parallel_group = mpu.get_data_parallel_group()
         self.data_parallel_rank = mpu.get_data_parallel_rank()
         self.data_parallel_world_size = mpu.get_data_parallel_world_size()

-        # Total trainable param count.
-        # self.total_param_size = sum(
-        #     p.numel()
-        #     for g in self.param_groups
-        #     for p in g["params"]
-        #     # if p .requires_grad ???
-        # )
-        # Model params: group sizes, group offset maps.
-        # self.model_params = []
-        # self.model_param_group_sizes = []
-        # self.model_param_group_offset_maps = []
-        self.model_param_groups = []
-        for param_group in self.optimizer.param_groups:
-            param_group_offset = 0
-            param_group_offset_map = {}
-            for param in param_group['params']:
-                if not param.requires_grad:
-                    continue
-                # self.model_params.append(param)
-                param_group_offset_map[param] = {
-                    "start" : param_group_offset,
-                    "end" : param_group_offset + param.numel(),
-                }
-                param_group_offset += param.numel()
-            # self.model_param_group_sizes.append(param_group_offset)
-            # self.model_param_group_offset_maps.append(param_group_offset_map)
-            self.model_param_groups.append({
-                "size" : param_group_offset,
-                "offset_map" : param_group_offset_map,
-            })
-        # pax(0, {
-        #     "model_params" : model_params,
-        #     "model_param_group_sizes" : model_param_group_sizes,
-        #     "model_param_group_offset_maps" : model_param_group_offset_maps,
-        # })
+        # Param group map.
+        self.param_group_map = {}
+        for group_index, group in enumerate(self.optimizer.param_groups):
+            for param in group["params"]:
+                assert param.requires_grad
+                self.param_group_map[param] = group_index
+        # pax(0, {"param_group_map": [
+        #     (g, str(p.shape))
+        #     for p, g in self.param_group_map.items()
+        # ]})

         # Shard allocator.
         # ** torch.nn.Parameter ??
@@ -766,154 +1043,28 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             device = torch.cuda.current_device(),
             requires_grad = True)

-        # Allocate shards.
-        # (Also, collect world DP shard info.)
-        # model_main_dtypes = set([ args.params_dtype, torch.float ])
-        model_main_dtypes = set([ torch.float ]) # fp32 only, for now
-        # self.world_shard_info_groups = [] # world_group_shard_infos ?
-        # self.main_param_shard_groups = []
-        self.world_shard_infos = [{"groups": []} for _ in self.model_param_groups]
-        for group_index, model_param_group in enumerate(self.model_param_groups):
-            # Max world shard size.
-            model_param_size = model_param_group["size"]
-            max_world_shard_size = int(math.ceil(model_param_size /
-                                                 self.data_parallel_world_size))
-            # DP world shard infos.
-            # world_shard_infos = []
-            for r in range(self.data_parallel_world_size):
-                shard_start_index = r * max_world_shard_size
-                shard_end_index = min(model_param_size,
-                                      shard_start_index + max_world_shard_size)
-                # world_shard_infos.append({
-                self.world_shard_infos[r]["groups"].append({
-                    "start" : shard_start_index,
-                    "end" : shard_end_index,
-                    "size" : shard_end_index - shard_start_index,
-                })
-            # self.world_shard_info_groups.append(world_shard_infos)
-            # self.world_shard_infos[group_index].append(world_shard_infos)
-            # DP local rank's shard info.
-            # local_shard_info = world_shard_infos[self.data_parallel_rank]
-            local_shard_info = \
-                self.world_shard_infos[self.data_parallel_rank]["groups"][-1]
-            local_shard_start_index = local_shard_info["start"]
-            local_shard_end_index = local_shard_info["end"]
-            local_shard_size = local_shard_info["size"]
-            # Local shard's param 'slice' index map.
-            local_shard_info["param_slice_index_map"] = {}
-            for param, offset_dict in model_param_group["offset_map"].items():
-                # param_start_index = offset_dict["start"]
-                # param_end_index = offset_dict["end"]
-                # param_shard_start_index = max(local_shard_start_index,
-                #                               param_start_index)
-                # param_shard_end_index = min(local_shard_end_index,
-                #                             param_end_index)
-                orig_start_index = offset_dict["start"]
-                orig_end_index = offset_dict["end"]
-                shard_start_index = max(
-                    0,
-                    orig_start_index - local_shard_start_index)
-                shard_end_index = min(
-                    local_shard_end_index,
-                    orig_end_index - local_shard_start_index)
-                # if param_shard_end_index > param_shard_start_index:
-                #     # Indexes are relative to local shard start index.
-                #     # local_shard_info["param_index_map"][param] = {
-                #     #     "param" : (
-                #     #         param_shard_start_index,
-                #     #         param_shard_end_index,
-                #     #     ),
-                #     #     "shard" : (
-                #     #         param_shard_start_index - local_shard_start_index,
-                #     #         param_shard_end_index - local_shard_start_index,
-                #     #     ),
-                #     # }
-                #     local_shard_info["param_slice_index_map"][param] = {
-                #         "param_start" :
-                #             param_shard_start_index,
-                #         "shard_start" :
-                #             param_shard_start_index - local_shard_start_index,
-                #         "size":
-                #             param_shard_end_index - param_shard_start_index,
-                #     }
-                if shard_end_index > shard_start_index:
-                    local_shard_info["param_slice_index_map"][param] = {
-                        "orig_start" : orig_start_index,
-                        "shard_start" : shard_start_index,
-                        "size" : shard_end_index - shard_start_index,
-                    }
-                # pax(0, {
-                #     "local index" : "%d, %d" % (
-                #         local_shard_start_index,
-                #         local_shard_end_index,
-                #     ),
-                #     "param index" : "%s, %d" % (
-                #         param_start_index,
-                #         param_end_index,
-                #     ),
-                #     "param" : tp(param),
-                #     "shard_param_index_map" : shard_param_index_map,
-                #     "local_shard_info" : local_shard_info,
-                # })
-            # pax(2, {
-            #     "data_parallel_rank" : self.data_parallel_rank,
-            #     "local_shard_info" : local_shard_info,
-            #     "param_index_map " : [
-            #         (str(p.shape), i)
-            #         for p, i in local_shard_info["param_index_map"].items()
-            #     ],
-            # })
-            # Allocate shards.
-            # (Non-fp32 shards are for convenience; e.g., intermediaries
-            #  between model params and main fp32 shard. Necessary???)
-            # main_param_shards = {
-            #     ty : allocate_shard(local_shard_size, ty)
-            #     for ty in model_main_dtypes}
-            main_param_shards = {}
-            for dtype in model_main_dtypes:
-                main_param = allocate_shard(local_shard_size, dtype)
-                main_param.grad = allocate_shard(local_shard_size, dtype)
-                # pax(0, {"main_param": main_param})
-                main_param_shards[dtype] = main_param
-            # self.main_param_shard_groups.append(main_param_shards)
-            local_shard_info["data"] = main_param_shards
-            # Update optimizer group.
-            self.optimizer.param_groups[group_index]["params"] = \
-                [ main_param_shards[torch.float] ]
-            # pax(0, {
-            #     "param_groups" : self.optimizer.param_groups,
-            #     "params" : self.optimizer.param_groups[group_index]["params"],
-            # })
-        # Add world start/end indexes, for reduce/gather steps.
-        offset = 0
-        for r in self.world_shard_infos:
-            r["start_index"] = offset
-            offset += sum(g["size"] for g in r["groups"])
-            r["end_index"] = offset
+        # World shard infos.
+        self.world_shard_infos = []
+        for rank in range(self.data_parallel_world_size):
+
+            # Per-rank info.
+            rank_info = []
+            self.world_shard_infos.append(rank_info)
+
+            for model_index, model in enumerate(self.models):
+
+                # Per-virtual-model info.
+                # model_info = {}
+                # rank_info.append(model_info)
+                ddp_gbuf_shards = self.get_ddp_gbuf_shards(model)

         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
         self.optimizer.load_state_dict(self.optimizer.state_dict())

         # >>>
-        # pax(0, {
-        #     "world_shard_infos" : self.world_shard_infos,
-        #     **{
-        #         "world_shard_infos / %d" % i : r
-        #         for i, r in enumerate(self.world_shard_infos)
-        #     },
-        # })
+        pax(0, {
+            "world_shard_infos" : self.world_shard_infos,
+        })
         # <<<

     # def get_loss_scale(self):
...
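The new classmethods in the hunk above are still work in progress: `get_ddp_gbuf_param_shard_map` and `get_ddp_gbuf_shard` reference names that are not defined in their scope (`shard_start`, `shard_end`, `dtype_info`, `model_info`, `grad_buffer`, `rank`, `self`), the call site uses the not-yet-defined `self.get_ddp_gbuf_shards`, and the end-of-overlap clamp appears to mix an absolute buffer index (`shard_end`) with shard-relative ones. Below is a self-contained sketch of the intended mapping, under the assumption that each chunk provides a grad-buffer size and a per-dtype param -> (start, end) index map; it is an illustration of the design, not the commit's code:

```python
# Self-contained sketch: one data-parallel rank's shard of a grad buffer and
# the slice of each parameter that falls inside it. The assumed inputs mirror
# model._grad_buffers[dtype].numel and model._grad_buffer_param_index_map[dtype].
import math

def get_gbuf_shard(gbuf_numel, param_index_map, data_parallel_world_size, rank):
    # This rank's contiguous shard of the grad buffer.
    max_shard_size = int(math.ceil(gbuf_numel / data_parallel_world_size))
    shard_start = rank * max_shard_size
    shard_end = min(gbuf_numel, shard_start + max_shard_size)
    shard_size = shard_end - shard_start

    # For each param overlapping the shard: the param's start in the grad
    # buffer, the overlap's start within the shard, and the overlap's size.
    param_shards = {}
    for param, (param_gbuf_start, param_gbuf_end) in param_index_map.items():
        param_shard_start = max(0, param_gbuf_start - shard_start)
        param_shard_end = min(shard_size, param_gbuf_end - shard_start)
        if param_shard_end > param_shard_start:
            param_shards[param] = {
                "gbuf_start" : param_gbuf_start,
                "shard_start" : param_shard_start,
                "size" : param_shard_end - param_shard_start,
            }

    return {
        "grad_buffer_shard" : {"start" : shard_start, "end" : shard_end,
                               "size" : shard_size},
        "grad_buffer_param_shards" : param_shards,
    }

# Example: a 10-element buffer holding "w" at [0, 6) and "b" at [6, 10),
# sharded over 2 ranks. Rank 1 owns elements [5, 10): a 1-element slice of
# "w" and a 4-element slice of "b".
print(get_gbuf_shard(10, {"w": (0, 6), "b": (6, 10)}, 2, 1))
```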
@@ -365,8 +365,12 @@ def setup_model_and_optimizer(model_provider_func,
     unwrapped_model = unwrap_model(model,
                                    (torchDDP, LocalDDP, Float16Module))
-    optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
-                                       scale_lr_cond, lr_mult)
+    # >>>
+    # optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
+    #                                    scale_lr_cond, lr_mult)
+    optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
+                                       scale_lr_cond, lr_mult)
+    # <<<
     opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
...
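`setup_model_and_optimizer` now hands the wrapped `model` list, rather than `unwrapped_model`, to `get_megatron_optimizer`, which matches the optimizer changes above: the optimizer needs the LocalDDP wrappers to reach their grad buffers, so parameter-group construction would presumably unwrap internally (an assumption; that code is not shown in this diff). For reference, a minimal stand-in consistent with how `unwrap_model` is called on the preceding lines, not necessarily Megatron's exact implementation:

```python
# Stand-in sketch: peel wrapper classes (e.g. torchDDP, LocalDDP,
# Float16Module) off each model chunk by following the `.module` attribute.
def unwrap_model(model, module_instances):
    return_list = isinstance(model, list)
    chunks = model if return_list else [model]
    unwrapped = []
    for chunk in chunks:
        while isinstance(chunk, module_instances):
            chunk = chunk.module
        unwrapped.append(chunk)
    return unwrapped if return_list else unwrapped[0]
```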