Commit bf64c85c authored by Lawrence McAfee
Browse files

added 'all' world shards to gbuf map

parent eaa0c1df
...@@ -977,9 +977,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer): ...@@ -977,9 +977,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
gbuf_size = grad_buffer.numel gbuf_size = grad_buffer.numel
max_gbuf_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size)) max_gbuf_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size))
gbuf_world_start = data_parallel_rank * max_gbuf_shard_size gbuf_world_all_shards = []
gbuf_world_end = min(gbuf_size, gbuf_world_start + max_gbuf_shard_size) for r in range(data_parallel_world_size):
gbuf_world_start = r * max_gbuf_shard_size
gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_shard_size)
gbuf_world_shard = Shard(gbuf_world_start, gbuf_world_end) gbuf_world_shard = Shard(gbuf_world_start, gbuf_world_end)
gbuf_world_all_shards.append(gbuf_world_shard)
gbuf_world_shard = gbuf_world_all_shards[data_parallel_rank]
gbuf_local_shard = gbuf_world_shard.normalize() gbuf_local_shard = gbuf_world_shard.normalize()
# gbuf_local_shard = Shard(0, gbuf_world_index.size) # gbuf_local_shard = Shard(0, gbuf_world_index.size)
...@@ -992,6 +996,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer): ...@@ -992,6 +996,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
data = { data = {
"local" : gbuf_local_shard, "local" : gbuf_local_shard,
"world" : gbuf_world_shard, "world" : gbuf_world_shard,
"world_all" : gbuf_world_all_shards,
"param_map" : param_shard_map, "param_map" : param_shard_map,
} }
...@@ -1302,11 +1307,16 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer): ...@@ -1302,11 +1307,16 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# ** only contiguous grad buffer supported, for now [ TEMPORARY ] ** # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
assert args.use_contiguous_buffers_in_local_ddp assert args.use_contiguous_buffers_in_local_ddp
for model_index, model in enuemrate(self.models): for model_index, model in enumerate(self.models):
for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
world_shards = gbuf_shard["world_all"]
pax(0, { pax(0, {
"model_index" : model_index, "model_index" : model_index,
"model" : model, "model" : model,
"dtype" : str(dtype),
"gbuf_shard" : gbuf_shard,
"world_shards" : world_shards,
}) })
world_sizes = [] world_sizes = []
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment