Commit f498a6fe authored by Lawrence McAfee

modularized shard indexing

parent cb6f96b6
@@ -661,6 +661,17 @@ from megatron import get_args
 from lutil import pax, tp
 # <<<
+# class ShardIndex:
+class Shard:
+    def __init__(self, start, end):
+        self.start = start
+        self.end = end
+        self.size = end - start
+    def normalize(self, start = 0):
+        return Shard(start, start + self.size)
+    def __str__(self):
+        return "%d,%d [%d]" % (self.start, self.end, self.size)
 # class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
 # class Float16DistributedOptimizer(MegatronOptimizer):
 class Float16DistributedOptimizer(BaseFloat16Optimizer):
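The `Shard` class added above is just a half-open [start, end) range that remembers its size; `normalize()` returns a same-sized range re-based onto a new start (default 0). A minimal standalone sketch of that behavior, not part of the commit:

# Illustrative sketch of the Shard semantics introduced in this commit.
class Shard:
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start
    def normalize(self, start = 0):
        # Same size, re-based so the range begins at `start`.
        return Shard(start, start + self.size)
    def __str__(self):
        return "%d,%d [%d]" % (self.start, self.end, self.size)

world = Shard(1024, 1536)   # a rank's slice of a flat grad buffer
local = world.normalize()   # same 512 elements, re-based to offset 0
print(world)                # -> 1024,1536 [512]
print(local)                # -> 0,512 [512]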
@@ -921,83 +932,87 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
     # # <<<
     @classmethod
-    # def get_ddp_gbuf_param_shards(cls, model, dtype, gbuf_start):
-    def get_ddp_gbuf_param_shard_map(cls, model, dtype, gbuf_start):
+    # def get_ddp_gbuf_param_shard_map(cls, model, dtype, gbuf_start):
+    # def get_model_gbuf_param_shard_index_map(cls,model,dtype,gbuf_world_index):
+    def get_model_gbuf_param_shard_map(cls, model, dtype, gbuf_world_shard):
+        # Param shard map.
+        param_world_index_map = model._grad_buffer_param_index_map[dtype]
+        param_shard_map = {}
-        for param, indexes in \
-                model._grad_buffer_param_index_map[dtype].items():
+        for param, param_world_indexes in param_world_index_map.items():
-            param_gbuf_start, param_gbuf_end = indexes
-            param_shard_start = max(
+            # Shard range.
+            param_world_start, param_world_end = param_world_indexes
+            param_local_start = max(
                 0,
-                param_gbuf_start - shard_start)
-            param_shard_end = min(
-                shard_end,
-                param_gbuf_end - shard_start)
-            if param_shard_end > param_shard_start:
-                dtype_info["grad_buffer_param_shards"][param] = {
-                    "gbuf_start" : param_gbuf_start,
-                    "shard_start" : param_shard_start,
-                    "size" : param_shard_end - param_shard_start,
+                param_world_start - gbuf_world_shard.start)
+            param_local_end = min(
+                gbuf_world_shard.size,
+                param_world_end - gbuf_world_shard.start)
+            # Add shard, if within range.
+            if param_local_end > param_local_start:
+                param_local_shard = Shard(param_local_start, param_local_end)
+                param_world_shard = param_local_shard.normalize(param_world_start)
+                param_shard_map[param] = {
+                    "local" : param_local_shard,
+                    "world" : param_world_shard,
                 }
-            # pax(0, {
-            #     "param" : param,
-            #     "indexes" : indexes,
-            #     "param_gbuf_start" : param_gbuf_start,
-            #     "param_gbuf_end" : param_gbuf_end,
-            #     "param_shard_start" : param_shard_start,
-            #     "param_shard_end" : param_shard_end,
-            # })
+        pax(0, {"param_shard_map": param_shard_map})
+        # pax(0, {"param_shard_map": [ str((str(p.shape), s)) for p,s in param_shard_map.items() ]})
+        return param_shard_map
     @classmethod
-    def get_ddp_gbuf_shard(cls, model, dtype):
+    # def get_ddp_gbuf_shard(cls, model, dtype):
+    # def get_model_gbuf_shard(cls, model, dtype):
+    # def get_model_gbuf_shard_index(cls, model, dtype):
+    def get_model_gbuf_shard(cls, model, dtype):
-        # Per-dtype info.
-        dtype_info = {}
-        model_info[dtype] = dtype_info
+        data_parallel_rank = mpu.get_data_parallel_rank()
+        data_parallel_world_size = mpu.get_data_parallel_world_size()
         # Grad buffer shard.
-        model_param_size = grad_buffer.numel
-        max_world_shard_size = int(math.ceil(
-            model_param_size / self.data_parallel_world_size))
-        shard_start = rank * max_world_shard_size
-        shard_end = min(model_param_size,
-                        shard_start + max_world_shard_size)
-        dtype_info["grad_buffer_shard"] = {
-            "start" : shard_start,
-            "end" : shard_end,
-            "size" : shard_end - shard_start,
+        grad_buffer = model._grad_buffers[dtype]
+        gbuf_size = grad_buffer.numel
+        max_gbuf_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size))
+        gbuf_world_start = data_parallel_rank * max_gbuf_shard_size
+        gbuf_world_end = min(gbuf_size, gbuf_world_start + max_gbuf_shard_size)
+        gbuf_world_shard = Shard(gbuf_world_start, gbuf_world_end)
+        gbuf_local_shard = gbuf_world_shard.normalize()
+        # gbuf_local_shard = Shard(0, gbuf_world_index.size)
+        # Param shards.
+        param_shard_map = cls.get_model_gbuf_param_shard_map(model,
+                                                             dtype,
+                                                             gbuf_world_shard)
+        # Altogether.
+        data = {
+            "local" : gbuf_local_shard,
+            "world" : gbuf_world_shard,
+            "param_map" : param_shard_map,
         }
-        # Grad buffer param shards.
-        dtype_info["grad_buffer_param_shards"] = self.get_ddp_gbuf_param_shards()
+        # pax(0, {"data": data})
-        pax(0, { "grad_buffer_param_shards" : [
-            str((str(tuple(p.shape)), i))
-            for p,i in dtype_info["grad_buffer_param_shards"].items()
-        ]})
-        return ddp_gbuf_shard
+        return data
     @classmethod
-    # def get_ddp_gbuf_shards(cls, model):
-    def get_ddp_gbuf_shard_map(cls, model):
+    # def get_ddp_gbuf_shard_map(cls, model):
+    # def get_model_gbuf_shard_map(cls, model):
+    # def get_model_gbuf_shard_index_map(cls, model):
+    def get_model_gbuf_shard_map(cls, model):
+        # shard_index_map = {
         shard_map = {
-            dtype : cls.get_ddp_gbuf_shard(model, dtype)
+            dtype : cls.get_model_gbuf_shard(model, dtype)
             for dtype in model._grad_buffers
         }
-        pax(0, {"shard_map": shard_map})
+        # pax(0, {"shard_map": shard_map})
         return shard_map
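Taken together, the two methods above first ceil-divide each dtype's flat grad buffer across the data-parallel ranks (`get_model_gbuf_shard`), then clip every parameter's world index range to the rank's slice (`get_model_gbuf_param_shard_map`). A hedged standalone sketch of that arithmetic, using plain tuples instead of `Shard` and hypothetical helper names (`rank_gbuf_range`, `clip_param_ranges`) that are not part of the commit:

# Illustrative sketch of the shard arithmetic above; function names and the
# tuple-based return format are assumptions, not the commit's API.
import math

def rank_gbuf_range(gbuf_size, rank, world_size):
    # Ceil-divide the flat grad buffer so every rank owns at most
    # `max_shard` elements; the last rank's slice may be shorter.
    max_shard = int(math.ceil(gbuf_size / world_size))
    start = rank * max_shard
    end = min(gbuf_size, start + max_shard)
    return start, end

def clip_param_ranges(param_world_ranges, gbuf_start, gbuf_end):
    # Keep only the parameters (or fragments of them) that overlap this
    # rank's [gbuf_start, gbuf_end) slice, expressed in shard-local offsets.
    gbuf_size = gbuf_end - gbuf_start
    clipped = {}
    for name, (world_start, world_end) in param_world_ranges.items():
        local_start = max(0, world_start - gbuf_start)
        local_end = min(gbuf_size, world_end - gbuf_start)
        if local_end > local_start:
            clipped[name] = (local_start, local_end)
    return clipped

# Example: a 10-element buffer holding params A=[0,4) and B=[4,10),
# split across 3 data-parallel ranks.
ranges = {"A": (0, 4), "B": (4, 10)}
for rank in range(3):
    start, end = rank_gbuf_range(10, rank, 3)
    print(rank, (start, end), clip_param_ranges(ranges, start, end))
# -> 0 (0, 4)  {'A': (0, 4)}
# -> 1 (4, 8)  {'B': (0, 4)}
# -> 2 (8, 10) {'B': (0, 2)}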
@@ -1017,10 +1032,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # pax(0, {"models": models})
-        # Data parallel info.
-        self.data_parallel_group = mpu.get_data_parallel_group()
-        self.data_parallel_rank = mpu.get_data_parallel_rank()
-        self.data_parallel_world_size = mpu.get_data_parallel_world_size()
+        # # Data parallel info.
+        # self.data_parallel_group = mpu.get_data_parallel_group()
+        # self.data_parallel_rank = mpu.get_data_parallel_rank()
+        # self.data_parallel_world_size = mpu.get_data_parallel_world_size()
         # Param group map.
         self.param_group_map = {}
@@ -1037,25 +1052,24 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # Shard allocator.
         # ** torch.nn.Parameter ??
         # ** MemoryBuffer ??
-        allocate_shard = lambda shard_size, dtype : torch.empty(
-            (shard_size,),
-            dtype = dtype,
-            device = torch.cuda.current_device(),
-            requires_grad = True)
-        # World shard infos.
-        self.world_shard_infos = []
-        for rank in range(self.data_parallel_world_size):
-            # Per-rank info.
-            rank_info = []
-            self.world_shard_infos.append(rank_info)
-            for model_index, model in enumerate(self.models):
-                # Per-virtual-model info.
-                # model_info = {}
-                # rank_info.append(model_info)
-                ddp_gbuf_shards = self.get_ddp_gbuf_shards(model)
+        # allocate_shard = lambda shard_size, dtype : torch.empty(
+        #     (shard_size,),
+        #     dtype = dtype,
+        #     device = torch.cuda.current_device(),
+        #     requires_grad = True)
+        # Model grad buffer shards.
+        self.model_gbuf_shards = []
+        for model_index, model in enumerate(self.models):
+            self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))
+        # Allocate main param/grad shard.
+        param_shard_map = self.get_param_shard_map(self.model_gbuf_shards)
+        pax(0, {
+            "model_gbuf_shards" : self.model_gbuf_shards,
+            "param_shard_map" : param_shard_map,
+        })
         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
......
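With the constructor change above, `self.model_gbuf_shards` becomes a flat list with one entry per (virtual) model chunk, each entry being the dict returned by `get_model_gbuf_shard_map`: keyed by grad-buffer dtype and carrying the rank's `world` shard, its zero-based `local` twin, and a per-parameter `param_map`. A rough, hedged mock of that nesting (plain strings and tuples stand in for torch dtypes, parameter tensors, and `Shard` objects):

# Placeholder mock of the nesting built above; only the key names
# ("local", "world", "param_map") follow the methods in this commit.
mock_model_gbuf_shards = [
    {                                    # one dict per (virtual) model chunk
        "float16" : {                    # keyed by grad-buffer dtype
            "local" : (0, 512),          # rank's slice, re-based to zero
            "world" : (1024, 1536),      # rank's slice of the full buffer
            "param_map" : {              # params overlapping this slice
                "<param>" : {"local" : (0, 512), "world" : (1024, 1536)},
            },
        },
    },
]
print(list(mock_model_gbuf_shards[0]["float16"].keys()))
# -> ['local', 'world', 'param_map']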