Commit 3ded2425 authored by Lawrence McAfee
Browse files

included original param index in map

parent a74e245c
......@@ -779,7 +779,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# "offset_map" : [(o,tp(p)) for p, o in model_param_group["offset_map"].items()],
# })
# Group sizes.
# Max world shard size.
model_param_size = model_param_group["size"]
max_world_shard_size = int(math.ceil(model_param_size /
self.data_parallel_world_size))
......@@ -797,13 +797,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
})
self.world_shard_info_groups.append(world_shard_infos)
# DP local shard info.
# DP local rank's shard info.
local_shard_info = world_shard_infos[self.data_parallel_rank]
local_shard_start_index = local_shard_info["start"]
local_shard_end_index = local_shard_info["end"]
local_shard_size = local_shard_info["size"]
# Shard param index map.
# Local shard's param index map.
local_shard_info["param_index_map"] = {}
for param, offset_dict in model_param_group["offset_map"].items():
param_start_index = offset_dict["start"]
......@@ -814,11 +814,16 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
param_end_index)
if param_shard_end_index > param_shard_start_index:
# Indexes are relative to local shard start index.
local_shard_info["param_index_map"][param] = {
"start" :
param_shard_start_index - local_shard_start_index,
"end" :
param_shard_end_index - local_shard_start_index,
"param" : (
param_shard_start_index,
param_shard_end_index,
),
"shard" : (
param_shard_start_index - local_shard_start_index,
param_shard_end_index - local_shard_start_index,
),
}
# pax(0, {
......@@ -835,7 +840,14 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# "local_shard_info" : local_shard_info,
# })
pax(0, {"local_shard_info": local_shard_info})
# pax(2, {
# "data_parallel_rank" : self.data_parallel_rank,
# "local_shard_info" : local_shard_info,
# "param_index_map " : {
# str(p.shape) : i
# for p, i in local_shard_info["param_index_map"].items()
# },
# })
# Allocate shards.
# (Non-fp32 shards are for convenience; e.g., intermediaries
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment