Commit 3f0bc681 authored by Lawrence McAfee

copying model grad slices to main grad

parent 6875dff5
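
The diff below switches the gradient copy from whole-tensor views to explicit slices: for each parameter, it locates the parameter's slice inside the flat fp16 model grad buffer and copies it into the matching slice of the fp32 main grad shard. A minimal sketch of that copy pattern, with made-up buffer sizes and offsets standing in for the values the real code derives from `vmodel._grad_buffer_param_index_map` and the shard info groups:

```python
import torch

# Hypothetical stand-ins: a flat fp16 buffer holding all model grads, and
# this data-parallel rank's fp32 main grad shard.
model_grad_buffer = torch.randn(1024, dtype=torch.float16)
main_grad_shard = torch.zeros(256, dtype=torch.float32)

# Offsets as computed in the diff: where the param's slice starts in the
# model grad buffer, where it lands in the local shard, and its length.
model_grad_buffer_start_index = 512  # param's buffer offset + orig-slice offset
main_slice_shard_start_index = 64
main_slice_size = 128

# Views over the two slices; copy_ casts fp16 -> fp32 on the way in, and
# detach() keeps the write out of autograd, mirroring the commit's
# `main_grad_view.detach().copy_(model_grad_view)`.
model_grad_view = model_grad_buffer[
    model_grad_buffer_start_index:
    model_grad_buffer_start_index + main_slice_size]
main_grad_view = main_grad_shard[
    main_slice_shard_start_index:
    main_slice_shard_start_index + main_slice_size]
main_grad_view.detach().copy_(model_grad_view)
```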
@@ -934,7 +934,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
     def reduce_gradients(self, model):
         # >>>
-        pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
+        # pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
         # <<<
         # >>>
@@ -968,7 +968,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         # Copy model grads to main shard.
         local_shard_info_groups = [g[self.data_parallel_rank]
                                    for g in self.world_shard_info_groups]
         for group_index, local_shard_info in enumerate(local_shard_info_groups):
@@ -986,51 +985,49 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                 dtype_model_dict = param_model_map[param]
                 dtype = dtype_model_dict["dtype"]
                 vmodel = dtype_model_dict["model"]
-                model_grad_buffer = vmodel._grad_buffers[dtype]
+                model_grad_buffer = vmodel._grad_buffers[dtype].data
                 model_grad_buffer_start_index = \
                     vmodel._grad_buffer_param_index_map[dtype][param][0] + \
                     main_slice_orig_start_index
                 # main_grad_view = self.main_param_shard_groups \
                 #     [group_index][torch.float].grad \
                 #     [shard_indexes["shard"][0]:shard_indexes["shard"][1]]
-                main_grad_view = local_shard_info["data"][torch.float]
+                main_grad_view = local_shard_info["data"][torch.float].grad[
+                    main_slice_shard_start_index:
+                    main_slice_shard_start_index + main_slice_size
+                ]
                 model_grad_view = model_grad_buffer[
                     model_grad_buffer_start_index:
                     model_grad_buffer_start_index + main_slice_size
                 ]
-                pax(0, {
-                    "local_shard_info" : local_shard_info,
-                    "main_slice_orig_start_index" : main_slice_orig_start_index,
-                    "main_slice_shard_start_index" : main_slice_shard_start_index,
-                    "main_slice_size" : main_slice_size,
-                    "model_grad_buffer_start_index" :
-                    model_grad_buffer_start_index,
-                    "main_grad_view" : main_grad_view,
-                })
                 main_grad_view.detach().copy_(model_grad_view)
-                pax(0, {
-                    # "dtype" : dtype,
-                    # "vmodel" : vmodel,
-                    "shard_indexes" : shard_indexes,
-                    "grad_buffer_indexes" : grad_buffer_indexes,
-                    "model_grad_view" : model_grad_view,
-                    "main_grad_views" : main_grad_view,
-                })
+                # pax(0, {
+                #     # "local_shard_info" : local_shard_info,
+                #     "main_slice_orig_start_index" : main_slice_orig_start_index,
+                #     "main_slice_shard_start_index" : main_slice_shard_start_index,
+                #     "main_slice_size" : main_slice_size,
+                #     "model_grad_buffer_start_index" :
+                #     model_grad_buffer_start_index,
+                #     "main_grad_view" : tp(main_grad_view),
+                #     "main_grad_view / detach" : tp(main_grad_view.detach()),
+                #     "model_grad_view" : tp(model_grad_view),
+                # })
-            pax(0, {
-                "group_index" : group_index,
-                "local_shard_info" : local_shard_info,
-                "shard_param_index_map" : shard_param_index_map,
-                "param" : tp(param),
-                "shard_indexes" : shard_indexes,
-                "grad_buffer_indexes" : grad_buffer_indexes,
-            })
+            # pax(0, {
+            #     "group_index" : group_index,
+            #     "local_shard_info" : local_shard_info,
+            #     "shard_param_index_map" : shard_param_index_map,
+            #     "param" : tp(param),
+            #     "shard_indexes" : shard_indexes,
+            #     "grad_buffer_indexes" : grad_buffer_indexes,
+            # })
         pax(0, {
             # "world_shard_info_groups" : self.world_shard_info_groups,
             # **{"world_shard_info_groups / %d" % i : v
             #    for i, v in enumerate(self.world_shard_info_groups)},
-            "local_shard_info_groups" : local_shard_info_groups,
-            "main_param_shard_groups" : self.main_param_shard_groups,
+            # "main_param_shard_groups" : self.main_param_shard_groups,
+            # "local_shard_info_groups" : local_shard_info_groups,
+            "local_shard_info_groups" : [ g["data"] for g in local_shard_info_groups ],
         })
         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
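
Worth noting why the first change in this hunk appends `.data`: in Megatron-LM, `_grad_buffers[dtype]` holds a `MemoryBuffer` wrapper around a flat pre-allocated tensor, not a tensor itself, so Python slice syntax has to go through its `.data` attribute. A reduced sketch (simplified `MemoryBuffer`, hypothetical sizes):

```python
import torch

class MemoryBuffer:
    """Reduced version of Megatron-LM's MemoryBuffer: a pre-allocated flat
    tensor plus a helper that views a sub-range as a param-shaped tensor."""

    def __init__(self, numel, dtype):
        self.data = torch.zeros(numel, dtype=dtype)

    def get(self, shape, start_index):
        # View [start_index, start_index + numel) of the flat buffer as `shape`.
        end_index = start_index + shape.numel()
        return self.data[start_index:end_index].view(shape)

buf = MemoryBuffer(1024, torch.float16)
# buf[0:128] raises TypeError ('MemoryBuffer' object is not subscriptable);
# the commit instead slices the underlying flat tensor via `.data`:
model_grad_view = buf.data[0:128]
```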