"docs/git@developer.sourcefind.cn:change/sglang.git" did not exist on "ef5b0ff90b526e4f8dc301553f3a1ce04e9dd71d"
Commit 3f0bc681 authored by Lawrence McAfee
Browse files

copying model grad slices to main grad

parent 6875dff5
...@@ -934,7 +934,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer): ...@@ -934,7 +934,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
def reduce_gradients(self, model): def reduce_gradients(self, model):
# >>> # >>>
pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]}) # pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
# <<< # <<<
# >>> # >>>
...@@ -968,7 +968,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer): ...@@ -968,7 +968,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Copy model grads to main shard. # Copy model grads to main shard.
local_shard_info_groups = [g[self.data_parallel_rank] local_shard_info_groups = [g[self.data_parallel_rank]
for g in self.world_shard_info_groups] for g in self.world_shard_info_groups]
for group_index, local_shard_info in enumerate(local_shard_info_groups): for group_index, local_shard_info in enumerate(local_shard_info_groups):
...@@ -986,51 +985,49 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer): ...@@ -986,51 +985,49 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
dtype_model_dict = param_model_map[param] dtype_model_dict = param_model_map[param]
dtype = dtype_model_dict["dtype"] dtype = dtype_model_dict["dtype"]
vmodel = dtype_model_dict["model"] vmodel = dtype_model_dict["model"]
model_grad_buffer = vmodel._grad_buffers[dtype] model_grad_buffer = vmodel._grad_buffers[dtype].data
model_grad_buffer_start_index = \ model_grad_buffer_start_index = \
vmodel._grad_buffer_param_index_map[dtype][param][0] + \ vmodel._grad_buffer_param_index_map[dtype][param][0] + \
main_slice_orig_start_index main_slice_orig_start_index
# main_grad_view = self.main_param_shard_groups \ main_grad_view = local_shard_info["data"][torch.float].grad[
# [group_index][torch.float].grad \ main_slice_shard_start_index:
# [shard_indexes["shard"][0]:shard_indexes["shard"][1]] main_slice_shard_start_index + main_slice_size
main_grad_view = local_shard_info["data"][torch.float] ]
model_grad_view = model_grad_buffer[
model_grad_buffer_start_index:
model_grad_buffer_start_index + main_slice_size
]
pax(0, { main_grad_view.detach().copy_(model_grad_view)
"local_shard_info" : local_shard_info,
"main_slice_orig_start_index" : main_slice_orig_start_index,
"main_slice_shard_start_index" : main_slice_shard_start_index,
"main_slice_size" : main_slice_size,
"model_grad_buffer_start_index" :
model_grad_buffer_start_index,
"main_grad_view" : main_grad_view,
})
pax(0, { # pax(0, {
# "dtype" : dtype, # # "local_shard_info" : local_shard_info,
# "vmodel" : vmodel, # "main_slice_orig_start_index" : main_slice_orig_start_index,
"shard_indexes" : shard_indexes, # "main_slice_shard_start_index" : main_slice_shard_start_index,
"grad_buffer_indexes" : grad_buffer_indexes, # "main_slice_size" : main_slice_size,
"model_grad_view" : model_grad_view, # "model_grad_buffer_start_index" :
"main_grad_views" : main_grad_view, # model_grad_buffer_start_index,
}) # "main_grad_view" : tp(main_grad_view),
# "main_grad_view / detach" : tp(main_grad_view.detach()),
# "model_grad_view" : tp(model_grad_view),
# })
pax(0, { # pax(0, {
"group_index" : group_index, # "group_index" : group_index,
"local_shard_info" : local_shard_info, # "local_shard_info" : local_shard_info,
"shard_param_index_map" : shard_param_index_map, # "shard_param_index_map" : shard_param_index_map,
"param" : tp(param), # "param" : tp(param),
"shard_indexes" : shard_indexes, # "shard_indexes" : shard_indexes,
"grad_buffer_indexes" : grad_buffer_indexes, # "grad_buffer_indexes" : grad_buffer_indexes,
}) # })
pax(0, { pax(0, {
# "world_shard_info_groups" : self.world_shard_info_groups, # "world_shard_info_groups" : self.world_shard_info_groups,
# **{"world_shard_info_groups / %d" % i : v # **{"world_shard_info_groups / %d" % i : v
# for i, v in enumerate(self.world_shard_info_groups)}, # for i, v in enumerate(self.world_shard_info_groups)},
"local_shard_info_groups" : local_shard_info_groups, # "local_shard_info_groups" : local_shard_info_groups,
"main_param_shard_groups" : self.main_param_shard_groups, "local_shard_info_groups" : [ g["data"] for g in local_shard_info_groups ],
# "main_param_shard_groups" : self.main_param_shard_groups,
}) })
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment