Commit a74e245c authored by Lawrence McAfee

built local shard param index map

parent f7232502
@@ -773,6 +773,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
self.main_param_shard_groups = []
for group_index, model_param_group in enumerate(self.model_param_groups):
# pax(0, {
# "model_param_group" : model_param_group,
# # "offset_map" : {str(p.shape):o for p, o in model_param_group["offset_map"].items()},
# "offset_map" : [(o,tp(p)) for p, o in model_param_group["offset_map"].items()],
# })
# Group sizes.
model_param_size = model_param_group["size"]
max_world_shard_size = int(math.ceil(model_param_size /
                                     self.data_parallel_world_size))
@@ -790,20 +797,49 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
})
self.world_shard_info_groups.append(world_shard_infos)
# pax(0, {"world_shard_infos": world_shard_infos})

# DP local shard info.
local_shard_info = world_shard_infos[self.data_parallel_rank]
local_shard_start_index = local_shard_info["start"]
local_shard_end_index = local_shard_info["end"]
local_shard_size = local_shard_info["size"]
# Shard param index map.
local_shard_info["param_index_map"] = {}
for param, offset_dict in model_param_group["offset_map"].items():
param_start_index = offset_dict["start"]
param_end_index = offset_dict["end"]
param_shard_start_index = max(local_shard_start_index,
param_start_index)
param_shard_end_index = min(local_shard_end_index,
param_end_index)
if param_shard_end_index > param_shard_start_index:
local_shard_info["param_index_map"][param] = {
"start" :
param_shard_start_index - local_shard_start_index,
"end" :
param_shard_end_index - local_shard_start_index,
}
# pax(0, {
# "local index" : "%d, %d" % (
# local_shard_start_index,
# local_shard_end_index,
# ),
# "param index" : "%s, %d" % (
# param_start_index,
# param_end_index,
# ),
# "param" : tp(param),
# "shard_param_index_map" : shard_param_index_map,
# "local_shard_info" : local_shard_info,
# })
pax(0, {"local_shard_info": local_shard_info})
# Allocate shards.
# (Non-fp32 shards are for convenience; e.g., intermediaries
# between model params and main fp32 shard. Necessary???)
local_shard_size = world_shard_infos[self.data_parallel_rank]["size"]
# # self.main_param_shard = allocate_shard(torch.float)
# # self.main_grad_shard = allocate_shard(torch.float)
# self.param_shard_map = {ty:allocate_shard(ty) for ty in dtypes}
# self.grad_shard_map = {ty:allocate_shard(ty) for ty in dtypes}
# self.adam_m_shard = allocate_shard(torch.float)
# self.adam_v_shard = allocate_shard(torch.float)
main_param_shards = {
    ty : allocate_shard(local_shard_size, ty)
    for ty in model_main_dtypes}
@@ -863,9 +899,17 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Copy model grads to main shard.
self.world_shard_info_groups = [] # world_group_shard_infos ?
self.main_param_shard_groups = []
pax(0, {"main_shard_info_groups": self.main_shard_info_groups})
local_shard_info_groups = [g[self.data_parallel_rank]
                           for g in self.world_shard_info_groups]
pax(0, {
# "world_shard_info_groups" : self.world_shard_info_groups,
# **{"world_shard_info_groups / %d" % i : v
# for i, v in enumerate(self.world_shard_info_groups)},
"local_shard_info_groups" : local_shard_info_groups,
"main_param_shard_groups" : self.main_param_shard_groups,
# "main_param_shard_groups" : self.main_param_shard_groups,
})
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Reduce-scatter.
...
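The core of this commit is the intersection arithmetic that maps each parameter's flat [start, end) range within its group onto the local data-parallel rank's shard. Below is a minimal standalone sketch of that logic; it is not the Megatron-LM source, and the helper name build_param_index_map, the name-keyed offset dict, and the toy sizes are illustrative assumptions.

import math

def build_param_index_map(param_offsets, shard_start, shard_end):
    # Keep only params that overlap [shard_start, shard_end); store each
    # overlapping [start, end) range relative to the shard's start index.
    index_map = {}
    for name, (param_start, param_end) in param_offsets.items():
        overlap_start = max(shard_start, param_start)
        overlap_end = min(shard_end, param_end)
        if overlap_end > overlap_start:
            index_map[name] = {
                "start" : overlap_start - shard_start,
                "end" : overlap_end - shard_start,
            }
    return index_map

if __name__ == "__main__":
    # Three flattened params totalling 10 elements, sharded over 2 DP ranks.
    param_offsets = {"w0": (0, 4), "w1": (4, 7), "w2": (7, 10)}
    group_size = 10
    dp_world_size = 2
    shard_size = int(math.ceil(group_size / dp_world_size))  # 5, as in the diff
    for rank in range(dp_world_size):
        start = rank * shard_size
        end = min(start + shard_size, group_size)
        print(rank, build_param_index_map(param_offsets, start, end))
        # rank 0 -> {'w0': {'start': 0, 'end': 4}, 'w1': {'start': 4, 'end': 5}}
        # rank 1 -> {'w1': {'start': 0, 'end': 2}, 'w2': {'start': 2, 'end': 5}}

Parameters that do not overlap the local shard are simply omitted, which is why the resulting map only ever holds the slice of each parameter that the local rank owns.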