Commit 4b843668 authored by Lawrence McAfee

fixed param_world_shard bug.

parent c13c0a3e
@@ -756,7 +756,9 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             # Add shard, if within range.
             if param_local_end > param_local_start:
                 param_local_shard = Shard(param_local_start, param_local_end)
-                param_world_shard = param_local_shard.normalize(param_world_start)
+                # param_world_shard = param_local_shard.normalize(param_world_start)
+                param_world_shard = param_local_shard.normalize(
+                    param_local_start + gbuf_world_shard.start)
                 sub_param_start = max(0, gbuf_world_shard.start-param_world_start)
                 sub_param_shard = param_local_shard.normalize(sub_param_start)
                 param_shard_map[param] = {
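For context, below is a minimal sketch of the shard arithmetic this hunk changes. It assumes normalize(start) returns an equal-length shard re-anchored at start (which is how the surrounding code uses it); the Shard class, the local-index formulas, and the example numbers are all illustrative stand-ins, not the module's real implementation. The old expression anchored the world shard at the parameter's own world start; the fix anchors it at the local start plus the rank's grad-buffer world offset, which only differs when a parameter begins before this rank's grad-buffer shard, the exact condition probed by the commented-out debug check in the next hunk.

# Hypothetical sketch; Shard/normalize semantics assumed from usage above.
class Shard:
    def __init__(self, start, end):
        self.start, self.end = start, end
    def normalize(self, start=0):
        # Re-anchor this shard at `start`, keeping its length.
        return Shard(start, start + self.end - self.start)
    def __repr__(self):
        return "Shard(%d, %d)" % (self.start, self.end)

# Assumed example: this data-parallel rank owns world grad-buffer indices
# [1000, 2000); a parameter occupies world indices [800, 1600).
gbuf_world_shard = Shard(1000, 2000)
param_world_start, param_world_end = 800, 1600

# The parameter's overlap with this rank, in rank-local buffer coordinates.
param_local_start = max(0, param_world_start - gbuf_world_shard.start)          # 0
param_local_end = min(gbuf_world_shard.end, param_world_end) - gbuf_world_shard.start  # 600
param_local_shard = Shard(param_local_start, param_local_end)                   # Shard(0, 600)

# Old (buggy): anchored at the parameter's world start, ignoring where this
# rank's buffer shard begins.
buggy_world_shard = param_local_shard.normalize(param_world_start)              # Shard(800, 1400)

# New: anchored at local start + the rank's world offset, landing on the slice
# of the world buffer this rank actually owns.
fixed_world_shard = param_local_shard.normalize(
    param_local_start + gbuf_world_shard.start)                                 # Shard(1000, 1600)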
@@ -764,6 +766,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                     "gbuf_local" : param_local_shard,
                     "param" : sub_param_shard,
                 }
+                pax(1, {
+                    "gbuf_world_shard" : gbuf_world_shard,
+                    "param shards" : param_shard_map[param],
+                })
                 # >>>
                 # if param_world_start < gbuf_world_shard.start:
                 #     pax({"param shards": param_shard_map[param]})
@@ -806,7 +812,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             "param_map" : param_shard_map,
         }
-        # pax(0, {"data": data})
+        # pax(1, {"data": data})
         return data
@@ -1155,9 +1161,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         #     ],
         # })
         pax(1, {
+            "data_parallel_rank" : data_parallel_rank,
             "main params" : self.get_main_params(),
-            "model params / world" : self.get_world_model_params(),
-            "gbuf_view_item" : tp(gbuf_view[data_parallel_rank]),
+            # "model params / world" : self.get_world_model_params(),
+            **{"gbuf_view_items / %d"%i:v[2] for i,v in enumerate(gbuf_view_items)},
+            # "gbuf_view_item" : tp(gbuf_view[data_parallel_rank]),
             # "model params / local" : self.get_local_model_param_views(),
         })
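A small aside on the **{...} line added in this hunk: a dict comprehension is unpacked into the debug dict so every entry of gbuf_view_items gets its own key. The sketch below is illustrative only; the tuple contents and the printed values are stand-ins, and v[2] is assumed to be the buffer view being inspected.

# Illustrative stand-ins for gbuf_view_items; only the unpacking idiom matters.
gbuf_view_items = [
    ("model-0", "fp16", "view-tensor-0"),
    ("model-0", "fp32", "view-tensor-1"),
]

debug_args = {
    "data_parallel_rank" : 3,
    # Unpack one keyed entry per view item, mirroring the line above.
    **{"gbuf_view_items / %d" % i : v[2] for i, v in enumerate(gbuf_view_items)},
}
print(debug_args)
# {'data_parallel_rank': 3,
#  'gbuf_view_items / 0': 'view-tensor-0',
#  'gbuf_view_items / 1': 'view-tensor-1'}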
@@ -1311,19 +1319,19 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                 model_view.detach().copy_(main_view)

                 # Debug.
-                # pax(0, {
-                #     "group_index" : group_index,
-                #     "group_shard" : group_shard,
-                #     "param" : tp(param),
-                #     "model_index" : model_index,
-                #     "gbuf_dtype" : str(gbuf_dtype),
-                #     "model_grad_tensor" : tp(model_grad_tensor),
-                #     "main_grad_tensor" : tp(main_grad_tensor),
-                #     "model_grad_view" : tp(model_grad_view),
-                #     "main_grad_view" : tp(main_grad_view),
-                #     "model_shard" : str(model_shard),
-                #     "main_shard" : str(main_shard),
-                # })
+                pax(1, {
+                    "group_index" : group_index,
+                    "group_shard" : group_shard,
+                    "model_param" : tp(model_param),
+                    "model_index" : model_index,
+                    "dtype" : str(dtype),
+                    "model_param" : tp(model_param),
+                    "main_param" : tp(main_param),
+                    "model_view" : tp(model_view),
+                    "main_view" : tp(main_view),
+                    "model_shard" : str(model_shard),
+                    "main_shard" : str(main_shard),
+                })

         # pax(0, {
         #     "model_gbuf_shards" : self.model_gbuf_shards,
...