Commit f498a6fe authored by Lawrence McAfee

modularized shard indexing

parent cb6f96b6
@@ -661,6 +661,17 @@ from megatron import get_args
 from lutil import pax, tp
 # <<<
+# class ShardIndex:
+class Shard:
+    def __init__(self, start, end):
+        self.start = start
+        self.end = end
+        self.size = end - start
+    def normalize(self, start = 0):
+        return Shard(start, start + self.size)
+    def __str__(self):
+        return "%d,%d [%d]" % (self.start, self.end, self.size)
+
 # class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
 # class Float16DistributedOptimizer(MegatronOptimizer):
 class Float16DistributedOptimizer(BaseFloat16Optimizer):
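
Note (not part of the commit): a minimal sketch of how the new Shard helper behaves. A Shard is a half-open [start, end) range over a flat grad buffer, and normalize() re-bases the same-sized range at a new start, which is how world (whole-buffer) coordinates are converted to rank-local coordinates and back. The class body is copied from the hunk above so the snippet runs standalone; the example values are made up.

    class Shard:
        def __init__(self, start, end):
            self.start = start
            self.end = end
            self.size = end - start
        def normalize(self, start = 0):
            return Shard(start, start + self.size)
        def __str__(self):
            return "%d,%d [%d]" % (self.start, self.end, self.size)

    world_shard = Shard(1024, 1536)        # a rank's slice, in whole-buffer coords
    local_shard = world_shard.normalize()  # same-sized slice re-based at 0
    print(world_shard)                     # 1024,1536 [512]
    print(local_shard)                     # 0,512 [512]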
@@ -921,83 +932,87 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
     # # <<<
     @classmethod
     # def get_ddp_gbuf_param_shards(cls, model, dtype, gbuf_start):
-    def get_ddp_gbuf_param_shard_map(cls, model, dtype, gbuf_start):
-        param_shard_map = {}
-        for param, indexes in \
-                model._grad_buffer_param_index_map[dtype].items():
-            param_gbuf_start, param_gbuf_end = indexes
-            param_shard_start = max(
-                0,
-                param_gbuf_start - shard_start)
-            param_shard_end = min(
-                shard_end,
-                param_gbuf_end - shard_start)
-            if param_shard_end > param_shard_start:
-                dtype_info["grad_buffer_param_shards"][param] = {
-                    "gbuf_start" : param_gbuf_start,
-                    "shard_start" : param_shard_start,
-                    "size" : param_shard_end - param_shard_start,
-                }
-
-        # pax(0, {
-        #     "param" : param,
-        #     "indexes" : indexes,
-        #     "param_gbuf_start" : param_gbuf_start,
-        #     "param_gbuf_end" : param_gbuf_end,
-        #     "param_shard_start" : param_shard_start,
-        #     "param_shard_end" : param_shard_end,
-        # })
-        pax(0, {"param_shard_map": param_shard_map})
+    # def get_ddp_gbuf_param_shard_map(cls, model, dtype, gbuf_start):
+    # def get_model_gbuf_param_shard_index_map(cls,model,dtype,gbuf_world_index):
+    def get_model_gbuf_param_shard_map(cls, model, dtype, gbuf_world_shard):
+
+        # Param shard map.
+        param_world_index_map = model._grad_buffer_param_index_map[dtype]
+        param_shard_map = {}
+        for param, param_world_indexes in param_world_index_map.items():
+
+            # Shard range.
+            param_world_start, param_world_end = param_world_indexes
+            param_local_start = max(
+                0,
+                param_world_start - gbuf_world_shard.start)
+            param_local_end = min(
+                gbuf_world_shard.size,
+                param_world_end - gbuf_world_shard.start)
+
+            # Add shard, if within range.
+            if param_local_end > param_local_start:
+                param_local_shard = Shard(param_local_start, param_local_end)
+                param_world_shard = param_local_shard.normalize(param_world_start)
+                param_shard_map[param] = {
+                    "local" : param_local_shard,
+                    "world" : param_world_shard,
+                }
+
+        # pax(0, {"param_shard_map": [ str((str(p.shape), s)) for p,s in param_shard_map.items() ]})

         return param_shard_map

     @classmethod
-    def get_ddp_gbuf_shard(cls, model, dtype):
-
-        # Per-dtype info.
-        dtype_info = {}
-        model_info[dtype] = dtype_info
+    # def get_ddp_gbuf_shard(cls, model, dtype):
+    # def get_model_gbuf_shard(cls, model, dtype):
+    # def get_model_gbuf_shard_index(cls, model, dtype):
+    def get_model_gbuf_shard(cls, model, dtype):
+
+        data_parallel_rank = mpu.get_data_parallel_rank()
+        data_parallel_world_size = mpu.get_data_parallel_world_size()

         # Grad buffer shard.
-        model_param_size = grad_buffer.numel
-        max_world_shard_size = int(math.ceil(
-            model_param_size / self.data_parallel_world_size))
-        shard_start = rank * max_world_shard_size
-        shard_end = min(model_param_size,
-                        shard_start + max_world_shard_size)
-
-        dtype_info["grad_buffer_shard"] = {
-            "start" : shard_start,
-            "end" : shard_end,
-            "size" : shard_end - shard_start,
-        }
-
-        # Grad buffer param shards.
-        dtype_info["grad_buffer_param_shards"] = self.get_ddp_gbuf_param_shards()
-        pax(0, { "grad_buffer_param_shards" : [
-            str((str(tuple(p.shape)), i))
-            for p,i in dtype_info["grad_buffer_param_shards"].items()
-        ]})
-
-        return ddp_gbuf_shard
+        grad_buffer = model._grad_buffers[dtype]
+        gbuf_size = grad_buffer.numel
+        max_gbuf_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size))
+        gbuf_world_start = data_parallel_rank * max_gbuf_shard_size
+        gbuf_world_end = min(gbuf_size, gbuf_world_start + max_gbuf_shard_size)
+        gbuf_world_shard = Shard(gbuf_world_start, gbuf_world_end)
+        gbuf_local_shard = gbuf_world_shard.normalize()
+        # gbuf_local_shard = Shard(0, gbuf_world_index.size)
+
+        # Param shards.
+        param_shard_map = cls.get_model_gbuf_param_shard_map(model,
+                                                             dtype,
+                                                             gbuf_world_shard)
+
+        # Altogether.
+        data = {
+            "local" : gbuf_local_shard,
+            "world" : gbuf_world_shard,
+            "param_map" : param_shard_map,
+        }
+
+        # pax(0, {"data": data})
+
+        return data

     @classmethod
     # def get_ddp_gbuf_shards(cls, model):
-    def get_ddp_gbuf_shard_map(cls, model):
+    # def get_ddp_gbuf_shard_map(cls, model):
+    # def get_model_gbuf_shard_map(cls, model):
+    # def get_model_gbuf_shard_index_map(cls, model):
+    def get_model_gbuf_shard_map(cls, model):
+        # shard_index_map = {
         shard_map = {
-            dtype : cls.get_ddp_gbuf_shard(model, dtype)
+            dtype : cls.get_model_gbuf_shard(model, dtype)
             for dtype in model._grad_buffers
         }
-        pax(0, {"shard_map": shard_map})
+        # pax(0, {"shard_map": shard_map})
         return shard_map
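
Note (not part of the commit): a minimal, self-contained sketch of the sharding arithmetic that get_model_gbuf_shard and get_model_gbuf_param_shard_map perform. The buffer size, rank count, and parameter index ranges below are made-up stand-ins for model._grad_buffers, model._grad_buffer_param_index_map, and the mpu data-parallel queries; only the index arithmetic mirrors the diff.

    import math

    # Stand-in inputs (assumed values, not from the commit).
    gbuf_size = 1000                       # grad buffer elements for one dtype
    data_parallel_world_size = 4
    param_world_index_map = {              # param name -> (start, end) in the buffer
        "w1": (0, 300),
        "w2": (300, 650),
        "w3": (650, 1000),
    }

    # Each rank owns a contiguous, equally sized (except possibly the last) slice.
    max_gbuf_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size))  # 250

    for data_parallel_rank in range(data_parallel_world_size):
        # Rank's slice of the grad buffer, in world (whole-buffer) coordinates.
        gbuf_world_start = data_parallel_rank * max_gbuf_shard_size
        gbuf_world_end = min(gbuf_size, gbuf_world_start + max_gbuf_shard_size)
        gbuf_shard_size = gbuf_world_end - gbuf_world_start

        # Intersect each param's world range with this rank's slice; params that
        # fall entirely outside the slice are skipped.
        owned = {}
        for name, (param_world_start, param_world_end) in param_world_index_map.items():
            local_start = max(0, param_world_start - gbuf_world_start)
            local_end = min(gbuf_shard_size, param_world_end - gbuf_world_start)
            if local_end > local_start:
                owned[name] = (local_start, local_end)   # shard-local coordinates

        print(data_parallel_rank, (gbuf_world_start, gbuf_world_end), owned)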
@@ -1017,10 +1032,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # pax(0, {"models": models})

-        # Data parallel info.
-        self.data_parallel_group = mpu.get_data_parallel_group()
-        self.data_parallel_rank = mpu.get_data_parallel_rank()
-        self.data_parallel_world_size = mpu.get_data_parallel_world_size()
+        # # Data parallel info.
+        # self.data_parallel_group = mpu.get_data_parallel_group()
+        # self.data_parallel_rank = mpu.get_data_parallel_rank()
+        # self.data_parallel_world_size = mpu.get_data_parallel_world_size()

         # Param group map.
         self.param_group_map = {}
@@ -1037,25 +1052,24 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # Shard allocator.
         # ** torch.nn.Parameter ??
         # ** MemoryBuffer ??
-        allocate_shard = lambda shard_size, dtype : torch.empty(
-            (shard_size,),
-            dtype = dtype,
-            device = torch.cuda.current_device(),
-            requires_grad = True)
-
-        # World shard infos.
-        self.world_shard_infos = []
-        for rank in range(self.data_parallel_world_size):
-
-            # Per-rank info.
-            rank_info = []
-            self.world_shard_infos.append(rank_info)
-
-            for model_index, model in enumerate(self.models):
-
-                # Per-virtual-model info.
-                # model_info = {}
-                # rank_info.append(model_info)
-
-                ddp_gbuf_shards = self.get_ddp_gbuf_shards(model)
+        # allocate_shard = lambda shard_size, dtype : torch.empty(
+        #     (shard_size,),
+        #     dtype = dtype,
+        #     device = torch.cuda.current_device(),
+        #     requires_grad = True)
+
+        # Model grad buffer shards.
+        self.model_gbuf_shards = []
+        for model_index, model in enumerate(self.models):
+            self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))
+
+        # Allocate main param/grad shard.
+        param_shard_map = self.get_param_shard_map(self.model_gbuf_shards)
+
+        pax(0, {
+            "model_gbuf_shards" : self.model_gbuf_shards,
+            "param_shard_map" : param_shard_map,
+        })

         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
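
Note (not part of the commit): get_param_shard_map is called above but its implementation is not part of this diff. Based only on the hunks shown here, self.model_gbuf_shards ends up as a list with one entry per model chunk, each mapping a grad-buffer dtype to the {"local", "world", "param_map"} dict returned by get_model_gbuf_shard. The traversal below is a hypothetical illustration of walking that nested structure; the function name and the yielded tuple are assumptions, not code from the commit.

    def iter_param_shards(model_gbuf_shards):
        # model_gbuf_shards:
        #   [ {dtype: {"local": Shard, "world": Shard,
        #              "param_map": {param: {"local": Shard, "world": Shard}}}} ]
        for model_index, gbuf_shard_map in enumerate(model_gbuf_shards):
            for dtype, gbuf_shard in gbuf_shard_map.items():
                for param, shards in gbuf_shard["param_map"].items():
                    yield model_index, dtype, param, shards["local"], shards["world"]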