"vscode:/vscode.git/clone" did not exist on "ff0dfb74d76872bcbcbadb6e1e52c0dcb00bb4ce"
Commit 3ded2425 authored by Lawrence McAfee

included original param index in map

parent a74e245c
@@ -779,7 +779,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             # "offset_map" : [(o,tp(p)) for p, o in model_param_group["offset_map"].items()],
             # })

-            # Group sizes.
+            # Max world shard size.
             model_param_size = model_param_group["size"]
             max_world_shard_size = int(math.ceil(model_param_size /
                                                  self.data_parallel_world_size))
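For context, a minimal standalone sketch of how a flattened parameter group of model_param_size elements ends up split across data-parallel ranks. Only the max_world_shard_size formula is taken from the diff; the per-rank start/end computation lives in the collapsed region between the hunks, so build_world_shard_infos and its boundary logic below are assumptions.

import math

# Hypothetical helper (not in the diff): split a flattened param group of
# model_param_size elements into one contiguous shard per data-parallel rank.
def build_world_shard_infos(model_param_size, data_parallel_world_size):
    max_world_shard_size = int(math.ceil(model_param_size /
                                         data_parallel_world_size))
    world_shard_infos = []
    for rank in range(data_parallel_world_size):
        start = rank * max_world_shard_size
        end = min(model_param_size, start + max_world_shard_size)
        world_shard_infos.append({"start": start, "end": end, "size": end - start})
    return world_shard_infos

# Example: 10 elements over 4 ranks -> shard sizes 3, 3, 3, 1.
print(build_world_shard_infos(10, 4))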
@@ -797,13 +797,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                 })
             self.world_shard_info_groups.append(world_shard_infos)

-            # DP local shard info.
+            # DP local rank's shard info.
             local_shard_info = world_shard_infos[self.data_parallel_rank]
             local_shard_start_index = local_shard_info["start"]
             local_shard_end_index = local_shard_info["end"]
             local_shard_size = local_shard_info["size"]

-            # Shard param index map.
+            # Local shard's param index map.
             local_shard_info["param_index_map"] = {}
             for param, offset_dict in model_param_group["offset_map"].items():
                 param_start_index = offset_dict["start"]
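The loop above walks model_param_group["offset_map"], which maps each model parameter to its [start, end) range inside the flattened group. That map is built outside the visible hunks, so the builder below is only a sketch of its assumed shape; build_offset_map is a hypothetical name.

import torch

# Hypothetical builder (assumed shape of "offset_map"): each param gets a
# contiguous [start, end) range based on cumulative element counts.
def build_offset_map(params):
    offset_map = {}
    offset = 0
    for param in params:
        offset_map[param] = {"start": offset, "end": offset + param.numel()}
        offset += param.numel()
    return offset_map

params = [torch.empty(4, 3), torch.empty(5)]
offset_map = build_offset_map(params)
# First param covers [0, 12), the second covers [12, 17); group size is 17.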
@@ -814,11 +814,16 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                                               param_end_index)
                 if param_shard_end_index > param_shard_start_index:
-                    # Indexes are relative to local shard start index.
                     local_shard_info["param_index_map"][param] = {
-                        "start" :
-                        param_shard_start_index - local_shard_start_index,
-                        "end" :
-                        param_shard_end_index - local_shard_start_index,
+                        "param" : (
+                            param_shard_start_index,
+                            param_shard_end_index,
+                        ),
+                        "shard" : (
+                            param_shard_start_index - local_shard_start_index,
+                            param_shard_end_index - local_shard_start_index,
+                        ),
                     }

             # pax(0, {
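This hunk is the point of the commit: each parameter that overlaps the local shard now records both its original, group-wide index range under "param" and the shard-relative range under "shard", instead of only the shard-relative one. A minimal standalone sketch of that bookkeeping, using plain ints and a hypothetical build_param_index_map helper in place of the optimizer's state:

# Minimal sketch of the index bookkeeping above (standalone; variable names
# mirror the diff, the helper itself is hypothetical).
def build_param_index_map(offset_map, local_shard_start_index,
                          local_shard_end_index):
    param_index_map = {}
    for param, offsets in offset_map.items():
        # Clip the param's group range to the local shard's range.
        param_shard_start_index = max(local_shard_start_index, offsets["start"])
        param_shard_end_index = min(local_shard_end_index, offsets["end"])
        if param_shard_end_index > param_shard_start_index:
            param_index_map[param] = {
                # Original (group-wide) indexes of the overlapping slice.
                "param": (param_shard_start_index, param_shard_end_index),
                # Same slice, relative to the local shard's start.
                "shard": (param_shard_start_index - local_shard_start_index,
                          param_shard_end_index - local_shard_start_index),
            }
    return param_index_map

# Example: the local shard covers [12, 24) of the group; a param occupying
# [10, 20) overlaps as group range (12, 20), i.e. shard-local range (0, 8).
print(build_param_index_map({"p0": {"start": 10, "end": 20}}, 12, 24))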
@@ -835,7 +840,14 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             #     "local_shard_info" : local_shard_info,
             # })
-            pax(0, {"local_shard_info": local_shard_info})
+            # pax(2, {
+            #     "data_parallel_rank" : self.data_parallel_rank,
+            #     "local_shard_info" : local_shard_info,
+            #     "param_index_map " : {
+            #         str(p.shape) : i
+            #         for p, i in local_shard_info["param_index_map"].items()
+            #     },
+            # })

             # Allocate shards.
             # (Non-fp32 shards are for convenience; e.g., intermediaries
...
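One plausible reason to keep both index pairs (an assumption on my part; the shard-allocation code hinted at by "# Allocate shards." is not shown here): the "param" range addresses the flattened model parameter group, while the "shard" range addresses the rank-local buffer, so copying between the two needs no extra arithmetic. A hypothetical illustration:

import torch

# Hypothetical usage (not from the diff): copy each param's overlapping slice
# from the flattened fp16 group into the rank-local fp32 shard.
def copy_params_into_shard(param_index_map, group_flat, local_shard):
    for param, idx in param_index_map.items():
        p_start, p_end = idx["param"]   # indexes into the flattened group
        s_start, s_end = idx["shard"]   # indexes into the local shard buffer
        local_shard[s_start:s_end].copy_(group_flat[p_start:p_end].float())

group_flat = torch.arange(24, dtype=torch.float16)
local_shard = torch.zeros(12, dtype=torch.float32)            # shard [12, 24)
param_index_map = {"p0": {"param": (12, 20), "shard": (0, 8)}}
copy_params_into_shard(param_index_map, group_flat, local_shard)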