Commit 4b843668 authored by Lawrence McAfee

fixed param_world_shard bug.

parent c13c0a3e
@@ -756,7 +756,9 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
     # Add shard, if within range.
     if param_local_end > param_local_start:
         param_local_shard = Shard(param_local_start, param_local_end)
-        param_world_shard = param_local_shard.normalize(param_world_start)
+        # param_world_shard = param_local_shard.normalize(param_world_start)
+        param_world_shard = param_local_shard.normalize(
+            param_local_start + gbuf_world_shard.start)
         sub_param_start = max(0, gbuf_world_shard.start-param_world_start)
         sub_param_shard = param_local_shard.normalize(sub_param_start)
         param_shard_map[param] = {
@@ -764,6 +766,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             "gbuf_local" : param_local_shard,
             "param" : sub_param_shard,
         }
+        pax(1, {
+            "gbuf_world_shard" : gbuf_world_shard,
+            "param shards" : param_shard_map[param],
+        })
         # >>>
         # if param_world_start < gbuf_world_shard.start:
         #     pax({"param shards": param_shard_map[param]})
@@ -806,7 +812,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         "param_map" : param_shard_map,
     }
-    # pax(0, {"data": data})
+    # pax(1, {"data": data})
     return data
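The pax and tp calls sprinkled through these hunks are the author's debug helpers and are not defined in this diff. A rough stand-in, assuming pax pretty-prints a labeled dict and halts (print-and-exit semantics and the meaning of the leading integer are guesses) and tp summarizes a tensor:

```python
# Hypothetical stand-ins for the pax/tp debug helpers used in this commit,
# so the debug calls can run outside the author's environment.
import sys

def tp(t):
    # One-line tensor summary: shape and dtype, falling back to repr().
    try:
        return "%s %s" % (str(tuple(t.shape)), str(t.dtype))
    except AttributeError:
        return repr(t)

def pax(depth, d=None):
    # Accept both call forms seen in the diff: pax({...}) and pax(0, {...}).
    if d is None:
        depth, d = 0, depth
    # Print each labeled entry, then halt so the state can be inspected.
    # `depth` is ignored here; its meaning in the author's helper is unknown.
    for key, value in d.items():
        print("%s : %s" % (key, value))
    sys.exit(0)
```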
@@ -1155,9 +1161,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
     # ],
     # })
     pax(1, {
         "data_parallel_rank" : data_parallel_rank,
         "main params" : self.get_main_params(),
-        "model params / world" : self.get_world_model_params(),
-        "gbuf_view_item" : tp(gbuf_view[data_parallel_rank]),
+        # "model params / world" : self.get_world_model_params(),
+        **{"gbuf_view_items / %d"%i:v[2] for i,v in enumerate(gbuf_view_items)},
+        # "gbuf_view_item" : tp(gbuf_view[data_parallel_rank]),
         # "model params / local" : self.get_local_model_param_views(),
     })
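The **{...} line added above splices one labeled entry per buffer view into the debug dict; taking v[2] suggests each gbuf_view_items entry is a tuple whose third field is the view. The pattern in isolation, with illustrative names and values:

```python
# Illustrative only: splice an enumerated family of entries into a dict
# literal via ** unpacking, as the pax(...) call above does for its views.
gbuf_view_items = [("model0", "fp16", "view-a"), ("model0", "fp32", "view-b")]
debug = {
    "data_parallel_rank": 3,
    **{"gbuf_view_items / %d" % i: v[2] for i, v in enumerate(gbuf_view_items)},
}
print(debug)
# {'data_parallel_rank': 3,
#  'gbuf_view_items / 0': 'view-a',
#  'gbuf_view_items / 1': 'view-b'}
```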
@@ -1311,19 +1319,19 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
     model_view.detach().copy_(main_view)

     # Debug.
-    # pax(0, {
-    #     "group_index" : group_index,
-    #     "group_shard" : group_shard,
-    #     "param" : tp(param),
-    #     "model_index" : model_index,
-    #     "gbuf_dtype" : str(gbuf_dtype),
-    #     "model_grad_tensor" : tp(model_grad_tensor),
-    #     "main_grad_tensor" : tp(main_grad_tensor),
-    #     "model_grad_view" : tp(model_grad_view),
-    #     "main_grad_view" : tp(main_grad_view),
-    #     "model_shard" : str(model_shard),
-    #     "main_shard" : str(main_shard),
-    # })
+    pax(1, {
+        "group_index" : group_index,
+        "group_shard" : group_shard,
+        "model_param" : tp(model_param),
+        "model_index" : model_index,
+        "dtype" : str(dtype),
+        "model_param" : tp(model_param),
+        "main_param" : tp(main_param),
+        "model_view" : tp(model_view),
+        "main_view" : tp(main_view),
+        "model_shard" : str(model_shard),
+        "main_shard" : str(main_shard),
+    })
     # pax(0, {
     #     "model_gbuf_shards" : self.model_gbuf_shards,
......
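The context line at the top of the last hunk, model_view.detach().copy_(main_view), is the operation this debug block instruments: an fp32 "main" param shard is written back into the matching slice of the fp16 model param. A minimal sketch of that copy, where shapes, dtypes, and slice bounds are illustrative assumptions:

```python
# Minimal sketch of model_view.detach().copy_(main_view): write an fp32
# "main" shard back into the matching slice of an fp16 model param.
import torch

model_param = torch.zeros(10, dtype=torch.float16)  # flattened model weight
main_param = torch.randn(4, dtype=torch.float32)    # this rank's fp32 shard

model_view = model_param[3:7]  # where the shard lives inside the model param
main_view = main_param[:]      # the whole local fp32 shard

# detach() drops autograd tracking but shares storage, so the in-place
# copy_() (which also downcasts fp32 -> fp16) updates model_param itself.
model_view.detach().copy_(main_view)
print(model_param)
```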