Commit e5db0fda authored by Lawrence McAfee

modularized reduce_gradients, gather_params; training runs, but loss==nan

parent a7782b21
@@ -1318,6 +1318,29 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        #     "local_shard_info_groups" : [ g["data"] for g in local_shard_info_groups ],
        # })

    def get_model_grad_buffer_dp_views(self):

        # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
        args = get_args()
        assert args.use_contiguous_buffers_in_local_ddp

        # Grad buffer views.
        gbuf_view_items = []
        for model_index, model in enumerate(self.models):
            for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
                world_shards = gbuf_shard["world_all"]
                gbuf = model._grad_buffers[dtype]
                gbuf_views = []
                for shard in world_shards:
                    gbuf_views.append(gbuf.data[shard.start:shard.end])
                gbuf_view_items.append((model_index, dtype, gbuf_views))

        # pax(0, {"gbuf_view_items": gbuf_view_items})

        return gbuf_view_items
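For context on what these views are: reduce_scatter and all_gather both consume a list of data_parallel_world_size slices that together tile the contiguous grad buffer, one slice per rank. A minimal single-process sketch of that slicing, purely illustrative (make_dp_views is a stand-in helper, not part of this commit; the real ranges come from model_gbuf_shards["world_all"]):

import torch

def make_dp_views(gbuf, world_size):
    # Assumes the contiguous grad buffer was padded so its length divides
    # evenly by the data-parallel world size, as the shard setup intends.
    assert gbuf.numel() % world_size == 0
    shard_size = gbuf.numel() // world_size
    # Each entry aliases a slice of the same underlying storage; no copies.
    return [gbuf[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]

# A 12-element buffer split across 4 data-parallel ranks -> 4 views of 3 elements.
views = make_dp_views(torch.arange(12, dtype=torch.float32), 4)
assert [v.numel() for v in views] == [3, 3, 3, 3]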
    def reduce_gradients(self, model):
        # >>>
@@ -1338,43 +1361,87 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Reduce-scatter.
        # # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
        # assert args.use_contiguous_buffers_in_local_ddp
        # data_parallel_rank = mpu.get_data_parallel_rank()
        # data_parallel_group = mpu.get_data_parallel_group()
        # for model_index, model in enumerate(self.models):
        #     for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
        #         world_shards = gbuf_shard["world_all"]
        #         gbuf = model._grad_buffers[dtype]
        #         gbuf_views = []
        #         for shard in world_shards:
        #             gbuf_views.append(gbuf.data[shard.start:shard.end])
        #         torch.distributed.reduce_scatter(
        #             gbuf_views[data_parallel_rank],
        #             gbuf_views,
        #             group = data_parallel_group,
        #         )
        #         # pax(0, {
        #         #     "model_index" : model_index,
        #         #     "model" : model,
        #         #     "dtype" : str(dtype),
        #         #     "gbuf_shard" : gbuf_shard,
        #         #     "world_shards" : world_shards,
        #         #     "gbuf_views" : gbuf_views,
        #         # })
        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group = mpu.get_data_parallel_group()

        gbuf_view_items = self.get_model_grad_buffer_dp_views()

        for model_index, dtype, gbuf_views in gbuf_view_items:
            torch.distributed.reduce_scatter(
                gbuf_views[data_parallel_rank],
                gbuf_views,
                group = data_parallel_group,
            )

        # pax(0, {"gbuf_view_items": gbuf_view_items})
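As a reminder of the collective's semantics (background, not code from this commit): torch.distributed.reduce_scatter(output, input_list, group=...) reduces input_list element-wise across the group and writes the rank-th reduced chunk into output, which here is the rank's own slice of the grad buffer. A single-process simulation of that behavior, assuming the default sum reduction:

import torch

def simulated_reduce_scatter(per_rank_views):
    # per_rank_views[rank][i] plays the role of that rank's i-th grad-buffer slice.
    # After the collective, rank r's r-th slice holds the sum over all ranks.
    world_size = len(per_rank_views)
    for r in range(world_size):
        reduced = torch.stack([per_rank_views[src][r]
                               for src in range(world_size)]).sum(dim=0)
        per_rank_views[r][r].copy_(reduced)

# Two "ranks", each holding two one-element slices of its local grad buffer.
rank_views = [[torch.tensor([1.0]), torch.tensor([2.0])],
              [torch.tensor([3.0]), torch.tensor([4.0])]]
simulated_reduce_scatter(rank_views)
assert rank_views[0][0].item() == 4.0 and rank_views[1][1].item() == 6.0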
    def gather_params(self):

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group = mpu.get_data_parallel_group()

        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        for model_index, dtype, gbuf_views in gbuf_view_items:
            torch.distributed.all_gather(
                gbuf_views,
                gbuf_views[data_parallel_rank],
                group = data_parallel_group,
            )

        # for param, (model_index, dtype) in self.param_gbuf_map.items():
        #     gbuf = self.model_gbuf_shards[model_index][dtype]
        #     pax(0, {
        #         "param" : tp(param),
        #         "model_index" : model_index,
        #         "dtype" : str(dtype),
        #         "gbuf" : gbuf,
        #     })

        for param in self.param_gbuf_map:
            param.detach().copy_(param.main_grad)
            # pax(0, {
            #     "param" : tp(param),
            #     "main_grad" : tp(param.main_grad),
            #     # "grad" : tp(param.grad),
            # })

        # pax(0, {
        #     "gbuf_view_items" : gbuf_view_items,
        #     "param_gbuf_map" : [
        #         (str(tuple(p.shape)), d)
        #         for p, d in self.param_gbuf_map.items()
        #     ],
        # })
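gather_params relies on the inverse collective: torch.distributed.all_gather(tensor_list, tensor, group=...) fills every slot of tensor_list with the corresponding rank's contribution, so afterwards the whole contiguous buffer holds updated parameter values on every rank. The trailing loop then copies each parameter's slice of that buffer, exposed here as param.main_grad, back into the parameter itself. A single-process sketch of just that copy step, with illustrative stand-in names (flat_buffer plays the role of the gathered grad buffer):

import torch

# Pretend the gathered buffer already holds this parameter's updated values.
flat_buffer = torch.arange(6, dtype=torch.float32)
param = torch.nn.Parameter(torch.zeros(2, 3))

# View the parameter's slice of the flat buffer in the parameter's shape,
# then copy it in without recording the copy in autograd.
param_view = flat_buffer[0:param.numel()].view_as(param)
param.detach().copy_(param_view)
assert torch.equal(param.detach(), param_view)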
    # def step(self):
@@ -1429,6 +1496,49 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        #     "opt_group_shards" : self.opt_group_shards,
        # })
    def _copy_main_params_to_model_params(self):

        for group_index, group_shard in enumerate(self.opt_group_shards):
            for param, main_shard in group_shard["param_map"].items():

                model_index, gbuf_dtype = self.param_gbuf_map[param]
                model_shard = self.model_gbuf_shards \
                    [model_index][gbuf_dtype]["param_map"][param]["world"]

                assert main_shard.size == model_shard.size

                # Use DDP's contiguous buffer to temporarily hold params.
                model_tensor = \
                    self.models[model_index]._grad_buffers[gbuf_dtype].data
                main_tensor = self.main_param_shards[group_index]

                # Copy sub-range within tensor.
                model_view = model_tensor[model_shard.start:model_shard.end]
                main_view = main_tensor[main_shard.start:main_shard.end]
                model_view.detach().copy_(main_view)

                # Debug.
                # pax(0, {
                #     "group_index" : group_index,
                #     "group_shard" : group_shard,
                #     "param" : tp(param),
                #     "model_index" : model_index,
                #     "gbuf_dtype" : str(gbuf_dtype),
                #     "model_grad_tensor" : tp(model_grad_tensor),
                #     "main_grad_tensor" : tp(main_grad_tensor),
                #     "model_grad_view" : tp(model_grad_view),
                #     "main_grad_view" : tp(main_grad_view),
                #     "model_shard" : str(model_shard),
                #     "main_shard" : str(main_shard),
                # })

        # pax(0, {
        #     "model_gbuf_shards" : self.model_gbuf_shards,
        #     "opt_group_shards" : self.opt_group_shards,
        # })
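The shard objects used above only need start, end, and size attributes, and the copy itself is a plain range-to-range assignment from the fp32 main-param tensor into the model-side contiguous buffer. A minimal sketch with a stand-in Range type, purely illustrative (the buffer is shown as fp16 here; the actual dtype depends on the DDP grad-buffer configuration):

import torch
from dataclasses import dataclass

@dataclass
class Range:
    start: int
    end: int

    @property
    def size(self):
        return self.end - self.start

main_tensor = torch.randn(8, dtype=torch.float32)    # this rank's fp32 main-param shard
model_tensor = torch.zeros(16, dtype=torch.float16)  # model-side contiguous buffer

main_shard, model_shard = Range(0, 8), Range(4, 12)
assert main_shard.size == model_shard.size

# Copy the fp32 values into the matching fp16 buffer range (values are downcast).
model_tensor[model_shard.start:model_shard.end].copy_(
    main_tensor[main_shard.start:main_shard.end])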
        # <<<