Commit 1215c420 authored by Lawrence McAfee

tweaked slice index naming convention

parent c5f93269
@@ -164,10 +164,14 @@ class DistributedDataParallel(DistributedDataParallelBase):
                 # type_num_elements[dtype]
                 if dtype not in self._grad_buffer_param_index_map:
                     self._grad_buffer_param_index_map[dtype] = {}
-                self._grad_buffer_param_index_map[dtype][param] = {
-                    "start" : type_num_elements[dtype],
-                    "end" : param.data.nelement(),
-                }
+                # self._grad_buffer_param_index_map[dtype][param] = {
+                #     "start" : type_num_elements[dtype],
+                #     "end" : type_num_elements[dtype] + param.data.nelement(),
+                # }
+                self._grad_buffer_param_index_map[dtype][param] = (
+                    type_num_elements[dtype],
+                    type_num_elements[dtype] + param.data.nelement(),
+                )
         # <<<
         # Backward hook.
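For orientation: this hunk changes the _grad_buffer_param_index_map entries from a {"start", "end"} dict to a bare (start, end) tuple, and makes the end offset cumulative (start + nelement) rather than just the param's element count. The standalone sketch below mirrors that bookkeeping outside of Megatron; the two parameters and the loop are made up for illustration, only the (start, end) convention is taken from the hunk.

import torch

# Made-up parameters; in Megatron these come from the wrapped model.
params = [torch.nn.Parameter(torch.zeros(4, 3)),
          torch.nn.Parameter(torch.zeros(5))]

grad_buffer_param_index_map = {}   # dtype -> {param: (start, end)}
type_num_elements = {}             # running element count per dtype

for param in params:
    dtype = param.dtype
    start = type_num_elements.get(dtype, 0)
    end = start + param.data.nelement()
    type_num_elements[dtype] = end
    # New convention: a (start, end) tuple of cumulative offsets into the
    # flat per-dtype grad buffer, instead of a {"start", "end"} dict.
    grad_buffer_param_index_map.setdefault(dtype, {})[param] = (start, end)

# The 4x3 param occupies [0, 12); the 5-element param occupies [12, 17).
for param, (start, end) in grad_buffer_param_index_map[torch.float32].items():
    print(tuple(param.shape), (start, end))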
@@ -802,27 +802,48 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             local_shard_end_index = local_shard_info["end"]
             local_shard_size = local_shard_info["size"]

-            # Local shard's param index map.
-            local_shard_info["param_index_map"] = {}
+            # Local shard's param 'slice' index map.
+            local_shard_info["param_slice_index_map"] = {}
             for param, offset_dict in model_param_group["offset_map"].items():

-                param_start_index = offset_dict["start"]
-                param_end_index = offset_dict["end"]
-                param_shard_start_index = max(local_shard_start_index,
-                                              param_start_index)
-                param_shard_end_index = min(local_shard_end_index,
-                                            param_end_index)
+                # param_start_index = offset_dict["start"]
+                # param_end_index = offset_dict["end"]
+                # param_shard_start_index = max(local_shard_start_index,
+                #                               param_start_index)
+                # param_shard_end_index = min(local_shard_end_index,
+                #                             param_end_index)
+                orig_start_index = offset_dict["start"]
+                orig_end_index = offset_dict["end"]
+                shard_start_index = max(
+                    0,
+                    orig_start_index - local_shard_start_index)
+                shard_end_index = min(
+                    local_shard_end_index,
+                    orig_end_index - local_shard_start_index)

                 if param_shard_end_index > param_shard_start_index:

                     # Indexes are relative to local shard start index.
-                    local_shard_info["param_index_map"][param] = {
-                        "param" : (
-                            param_shard_start_index,
-                            param_shard_end_index,
-                        ),
-                        "shard" : (
-                            param_shard_start_index - local_shard_start_index,
-                            param_shard_end_index - local_shard_start_index,
-                        ),
-                    }
+                    # local_shard_info["param_index_map"][param] = {
+                    #     "param" : (
+                    #         param_shard_start_index,
+                    #         param_shard_end_index,
+                    #     ),
+                    #     "shard" : (
+                    #         param_shard_start_index - local_shard_start_index,
+                    #         param_shard_end_index - local_shard_start_index,
+                    #     ),
+                    # }
+                    # local_shard_info["param_slice_index_map"][param] = {
+                    #     "param_start" :
+                    #         param_shard_start_index,
+                    #     "shard_start" :
+                    #         param_shard_start_index - local_shard_start_index,
+                    #     "size":
+                    #         param_shard_end_index - param_shard_start_index,
+                    # }
+                    local_shard_info["param_slice_index_map"][param] = {
+                        "orig_start" : orig_start_index,
+                        "shard_start" : shard_start_index,
+                        "size" : shard_end_index - shard_start_index,
+                    }

             # pax(0, {
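The renamed param_slice_index_map records, per parameter, where its slice of the gradient buffer sits relative to this rank's local shard. The new code above is still mid-refactor (it mixes old and new variable names), so the sketch below only illustrates the underlying range-intersection arithmetic; the function name and the exact clamping are my reading of the intent, not the commit's final code.

def param_slice_in_shard(param_start, param_end, shard_start, shard_end):
    """Intersect a param's [param_start, param_end) range with a local
    shard's [shard_start, shard_end) range; all offsets are positions in
    the flat, per-dtype gradient buffer.  Returns None on no overlap."""
    lo = max(param_start, shard_start)
    hi = min(param_end, shard_end)
    if hi <= lo:
        return None
    return {
        "orig_start"  : lo,                # slice start in the full buffer
        "shard_start" : lo - shard_start,  # slice start relative to the shard
        "size"        : hi - lo,           # number of overlapping elements
    }

# A param spanning [10, 30) seen from a shard covering [16, 32):
print(param_slice_in_shard(10, 30, 16, 32))
# -> {'orig_start': 16, 'shard_start': 0, 'size': 14}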
@@ -854,7 +875,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             main_param_shards = {
                 ty : allocate_shard(local_shard_size, ty)
                 for ty in model_main_dtypes}
-            self.main_param_shard_groups.append(main_param_shards)
+            # self.main_param_shard_groups.append(main_param_shards)
+            local_shard_info["data"] = main_param_shards

             # Update optimizer group.
             self.optimizer.param_groups[group_index]["params"] = \
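This hunk moves the allocated main-param shards off the separate self.main_param_shard_groups list and onto the per-group local_shard_info dict under a "data" key. A minimal sketch of that bookkeeping, with allocate_shard as a stand-in allocator and a made-up shard size:

import torch

def allocate_shard(numel, dtype):
    # Stand-in for the real allocator: a flat, zeroed tensor.
    return torch.zeros(numel, dtype=dtype)

local_shard_info = {"size": 1024}        # made-up per-group shard info
model_main_dtypes = [torch.float32]      # e.g. fp32 master params

main_param_shards = {
    ty: allocate_shard(local_shard_info["size"], ty)
    for ty in model_main_dtypes}

# New convention: the shards hang off the shard-info dict itself rather
# than accumulating in a separate self.main_param_shard_groups list.
local_shard_info["data"] = main_param_shards
print({ty: tuple(t.shape) for ty, t in local_shard_info["data"].items()})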
@@ -935,16 +957,41 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         for group_index, local_shard_info in enumerate(local_shard_info_groups):

             # model_param_index_map =
-            shard_param_index_map = local_shard_info["param_index_map"]
-            for param, shard_indexes in shard_param_index_map.items():
+            # shard_param_index_map = local_shard_info["param_index_map"]
+            # main_index_map = local_shard_info["param_index_map"]
+            main_slice_index_map = local_shard_info["param_slice_index_map"]
+            for param, main_slice_indexes in main_slice_index_map.items():
+
+                main_param_start_index = main_slice_indexes["param_start"]
+                main_shard_start_index = main_slice_indexes["shard_start"]
+                main_slice_size = ddd
+                main_size = main_shard_indexesddd

                 dtype_model_dict = param_model_map[param]
                 dtype = dtype_model_dict["dtype"]
                 vmodel = dtype_model_dict["model"]

-                grad_buffer_indexes = \
-                    vmodel._grad_buffer_param_index_map[dtype][param]
+                model_grad_buffer = vmodel._grad_buffers[dtype]
+                model_grad_buffer_start_index = \
+                    vmodel._grad_buffer_param_index_map[dtype][param][0]
+                # model_grad_buffer_indexes = [ model_grad_buffer_start_index + i
+                #                               for i in main_
+                # model_grad_view = model_grad_buffer.data[
+                pax(0, {"model_grad_buffer_indexes": model_grad_buffer_indexes})

+                main_grad_view = self.main_param_shard_groups \
+                    [group_index][torch.float].grad \
+                    [shard_indexes["shard"][0]:shard_indexes["shard"][1]]

-                pax(0, {"dtype": dtype})
+                pax(0, {
+                    # "dtype" : dtype,
+                    # "vmodel" : vmodel,
+                    "shard_indexes" : shard_indexes,
+                    "grad_buffer_indexes" : grad_buffer_indexes,
+                    "model_grad_view" : model_grad_view,
+                    "main_grad_views" : main_grad_view,
+                })

                 pax(0, {
                     "group_index" : group_index,
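The last hunk begins rewriting the grad-copy loop against the new naming: rather than a per-param (start, end) pair (grad_buffer_indexes), it reads only the buffer start offset (element [0] of the new tuple) and would index into the model grad buffer from there. The hunk is clearly unfinished (the ddd placeholders and pax debug stops), so the following is only a guess at the intended copy, with all offsets and tensors made up:

import torch

# Hypothetical flat fp16 model grad buffer and fp32 main-grad shard.
model_grad_buffer = torch.arange(32, dtype=torch.float16)
main_grad_shard = torch.zeros(16, dtype=torch.float32)

# Made-up bookkeeping for one param:
model_buffer_start = 10   # _grad_buffer_param_index_map[dtype][param][0]
shard_start = 4           # param_slice_index_map[param]["shard_start"]
size = 6                  # param_slice_index_map[param]["size"]
# (If the shard covered only part of the param, a within-param offset
#  would also be added to model_buffer_start.)

# Copy the param's grad slice from the fp16 buffer into the fp32 shard.
model_grad_view = model_grad_buffer[model_buffer_start : model_buffer_start + size]
main_grad_view = main_grad_shard[shard_start : shard_start + size]
main_grad_view.copy_(model_grad_view)

print(main_grad_shard)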