Commit ff4c4273 authored by Lawrence McAfee

reduce scatter working

parent bf64c85c
@@ -1307,30 +1307,36 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
        assert args.use_contiguous_buffers_in_local_ddp
        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group = mpu.get_data_parallel_group()
        # Reduce-scatter each dtype's contiguous grad buffer so that each
        # data-parallel rank ends up with the reduced grads of its own shard.
        for model_index, model in enumerate(self.models):
            for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
                world_shards = gbuf_shard["world_all"]
                # Debug dump of the shard metadata.
                pax(0, {
                    "model_index" : model_index,
                    "model" : model,
                    "dtype" : str(dtype),
                    "gbuf_shard" : gbuf_shard,
                    "world_shards" : world_shards,
                })
                # One view into the contiguous grad buffer per data-parallel rank.
                gbuf = model._grad_buffers[dtype]
                gbuf_views = []
                for shard in world_shards:
                    gbuf_views.append(gbuf.data[shard.start:shard.end])
                # This rank's view receives the sum of its shard across all ranks.
                torch.distributed.reduce_scatter(
                    gbuf_views[data_parallel_rank],
                    gbuf_views,
                    group = data_parallel_group,
                )
        world_sizes = []
        for r in self.world_shard_infos:
            # world_sizes.append(sum(g["size"] for g in r))
            world_sizes.append([ g["size"] for g in r["groups"] ])
        # grad_refs ...
        pax(0, {"world_sizes": world_sizes})
        # for world_grads = []
        # for world_shard_info_group
        # x ?
        # pax(0, {
        #     "model_index" : model_index,
        #     "model" : model,
        #     "dtype" : str(dtype),
        #     "gbuf_shard" : gbuf_shard,
        #     "world_shards" : world_shards,
        #     "gbuf_views" : gbuf_views,
        # })
        # >>>
        torch.distributed.barrier()
        raise Exception("hi.")
        # <<<
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
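
For context, here is a minimal, self-contained sketch of the pattern this commit is wiring up: carve a contiguous gradient buffer into one view per data-parallel rank and reduce-scatter it so each rank keeps the reduced gradients for only its own shard. The helper name reduce_scatter_grad_buffer, the even-divisibility assumption, and the standalone launch are illustrative assumptions, not Megatron's actual API.

# Minimal sketch (not the commit's actual code): reduce-scatter a contiguous
# grad buffer so each data-parallel rank keeps the summed gradients for its
# own shard only. Launch with e.g.:
#   torchrun --nproc_per_node=2 reduce_scatter_sketch.py
import torch
import torch.distributed as dist

def reduce_scatter_grad_buffer(gbuf, group=None):
    """Reduce-scatter `gbuf` in place; return a view of this rank's shard."""
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    assert gbuf.numel() % world_size == 0, \
        "pad the buffer to a multiple of the data-parallel world size"
    shard_size = gbuf.numel() // world_size
    # Equal-sized views into the single contiguous buffer; no copies are made.
    gbuf_views = [gbuf[r * shard_size : (r + 1) * shard_size]
                  for r in range(world_size)]
    # Each rank's own view receives the element-wise sum of that slice
    # contributed by every rank in the group.
    dist.reduce_scatter(gbuf_views[rank], gbuf_views, group=group)
    return gbuf_views[rank]

if __name__ == "__main__":
    dist.init_process_group(backend="nccl")  # reduce_scatter needs NCCL on GPUs
    torch.cuda.set_device(dist.get_rank())
    gbuf = torch.ones(16, device="cuda")  # stand-in for a flattened grad buffer
    shard = reduce_scatter_grad_buffer(gbuf)
    print(f"rank {dist.get_rank()} owns {shard.numel()} reduced elements")
    dist.destroy_process_group()

In the commit itself the views come from model._grad_buffers[dtype] and the group from mpu.get_data_parallel_group(), but the collective has the same shape: the output tensor is this rank's view and the input list is all ranks' views.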