Commit 1215c420 authored by Lawrence McAfee

tweaked slice index naming convention

parent c5f93269
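In rough terms, this commit renames the per-param index bookkeeping from absolute (start, end) style entries to explicit slice fields. The shapes below are an illustrative sketch inferred from the hunks that follow; the numbers are made up and not taken from the code.

# DistributedDataParallel._grad_buffer_param_index_map[dtype][param]
old_ddp_entry = {"start": 0, "end": 1024}   # before: dict with absolute indexes
new_ddp_entry = (0, 1024)                   # after: plain (start, end) tuple

# Float16DistributedOptimizer, per-param entry on the local shard info
old_shard_entry = {                         # before: "param_index_map"
    "param": (256, 512),                    #   overlap, absolute indexes
    "shard": (0, 256),                      #   overlap, shard-relative indexes
}
new_shard_entry = {                         # after: "param_slice_index_map"
    "orig_start": 256,                      #   param's start within its group
    "shard_start": 0,                       #   slice start within the local shard
    "size": 256,                            #   slice length
}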
@@ -164,10 +164,14 @@ class DistributedDataParallel(DistributedDataParallelBase):
# type_num_elements[dtype]
if dtype not in self._grad_buffer_param_index_map:
self._grad_buffer_param_index_map[dtype] = {}
            # self._grad_buffer_param_index_map[dtype][param] = {
            #     "start" : type_num_elements[dtype],
            #     "end" : param.data.nelement(),
            # }
# self._grad_buffer_param_index_map[dtype][param] = {
# "start" : type_num_elements[dtype],
# "end" : type_num_elements[dtype] + param.data.nelement(),
# }
self._grad_buffer_param_index_map[dtype][param] = (
type_num_elements[dtype],
type_num_elements[dtype] + param.data.nelement(),
)
# <<<
# Backward hook.
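As a standalone illustration of the tuple convention adopted above, the sketch below builds a (start, end) index map over a flat grad buffer; the params and buffer here are invented for the example, not Megatron objects.

import torch

params = [torch.nn.Parameter(torch.zeros(n)) for n in (3, 5, 2)]  # toy params

index_map = {}
num_elements = 0
for p in params:
    # Same (start, end) tuple shape as _grad_buffer_param_index_map above.
    index_map[p] = (num_elements, num_elements + p.data.nelement())
    num_elements += p.data.nelement()

flat_grad_buffer = torch.zeros(num_elements)

# Each param's grads occupy a contiguous slice of the flat buffer.
for p, (start, end) in index_map.items():
    grad_view = flat_grad_buffer[start:end].view(p.shape)
    assert grad_view.nelement() == p.data.nelement()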
@@ -802,27 +802,48 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
local_shard_end_index = local_shard_info["end"]
local_shard_size = local_shard_info["size"]
# Local shard's param index map.
            # local_shard_info["param_index_map"] = {}
# Local shard's param 'slice' index map.
local_shard_info["param_slice_index_map"] = {}
for param, offset_dict in model_param_group["offset_map"].items():
# param_start_index = offset_dict["start"]
# param_end_index = offset_dict["end"]
# param_shard_start_index = max(local_shard_start_index,
# param_start_index)
# param_shard_end_index = min(local_shard_end_index,
# param_end_index)
orig_start_index = offset_dict["start"]
orig_end_index = offset_dict["end"]
shard_start_index = max(
0,
orig_start_index - local_shard_start_index)
                shard_end_index = min(
                    local_shard_size,
                    orig_end_index - local_shard_start_index)
                if shard_end_index > shard_start_index:
                    # 'shard_start' is relative to the local shard's start index.
# local_shard_info["param_index_map"][param] = {
# "param" : (
# param_shard_start_index,
# param_shard_end_index,
# ),
# "shard" : (
# param_shard_start_index - local_shard_start_index,
# param_shard_end_index - local_shard_start_index,
# ),
# }
# local_shard_info["param_slice_index_map"][param] = {
# "param_start" :
# param_shard_start_index,
# "shard_start" :
# param_shard_start_index - local_shard_start_index,
# "size":
# param_shard_end_index - param_shard_start_index,
# }
local_shard_info["param_slice_index_map"][param] = {
"orig_start" : orig_start_index,
"shard_start" : shard_start_index,
"size" : shard_end_index - shard_start_index,
}
# pax(0, {
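The overlap computation above can be summarized with a small standalone helper; the function name and the numbers are illustrative only, and the end index is capped at the shard size as in the code above.

def param_slice_for_shard(orig_start, orig_end,
                          local_shard_start, local_shard_end):
    """Shard-relative slice of the part of a param that overlaps the local shard."""
    local_shard_size = local_shard_end - local_shard_start
    shard_start = max(0, orig_start - local_shard_start)
    shard_end = min(local_shard_size, orig_end - local_shard_start)
    if shard_end <= shard_start:
        return None  # param does not overlap this shard
    return {"orig_start": orig_start,
            "shard_start": shard_start,
            "size": shard_end - shard_start}

# Example: shard covers group indexes [100, 200); param covers [150, 250).
# The overlap [150, 200) starts 50 elements into the shard and is 50 long.
assert param_slice_for_shard(150, 250, 100, 200) == \
    {"orig_start": 150, "shard_start": 50, "size": 50}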
@@ -854,7 +875,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
main_param_shards = {
ty : allocate_shard(local_shard_size, ty)
for ty in model_main_dtypes}
# self.main_param_shard_groups.append(main_param_shards)
local_shard_info["data"] = main_param_shards
# Update optimizer group.
self.optimizer.param_groups[group_index]["params"] = \
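A minimal sketch of keeping the main-param shards on the shard descriptor itself (local_shard_info["data"]) instead of a separate self.main_param_shard_groups list. The allocate_shard stand-in below is an assumption made for the example, not the actual Megatron helper.

import torch

def allocate_shard(size, dtype):
    # Assumed behavior for this sketch: a flat, zero-initialized shard tensor.
    return torch.zeros(size, dtype=dtype)

local_shard_info = {"size": 1024}                # toy shard descriptor
model_main_dtypes = [torch.float16, torch.float]

# One flat main-param shard per dtype, kept on the shard descriptor itself.
local_shard_info["data"] = {
    ty: allocate_shard(local_shard_info["size"], ty)
    for ty in model_main_dtypes}

assert local_shard_info["data"][torch.float].shape == (1024,)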
@@ -935,16 +957,41 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
for group_index, local_shard_info in enumerate(local_shard_info_groups):
            # shard_param_index_map = local_shard_info["param_index_map"]
            # for param, shard_indexes in shard_param_index_map.items():
            main_slice_index_map = local_shard_info["param_slice_index_map"]
            for param, main_slice_indexes in main_slice_index_map.items():

                main_orig_start_index = main_slice_indexes["orig_start"]
                main_shard_start_index = main_slice_indexes["shard_start"]
                main_slice_size = main_slice_indexes["size"]
dtype_model_dict = param_model_map[param]
dtype = dtype_model_dict["dtype"]
vmodel = dtype_model_dict["model"]
grad_buffer_indexes = \
vmodel._grad_buffer_param_index_map[dtype][param]
model_grad_buffer = vmodel._grad_buffers[dtype]
model_grad_buffer_start_index = \
vmodel._grad_buffer_param_index_map[dtype][param][0]
                # model_grad_buffer_indexes = [ model_grad_buffer_start_index + i
                #     for i in main_
                # pax(0, {"model_grad_buffer_indexes": model_grad_buffer_indexes})
                # Model grad view: this param's slice of the model's grad buffer.
                model_grad_view = model_grad_buffer.data[
                    grad_buffer_indexes[0]:grad_buffer_indexes[1]]
                # Main grad view: this param's slice of the local fp32 shard.
                # main_grad_view = self.main_param_shard_groups \
                #     [group_index][torch.float].grad \
                #     [shard_indexes["shard"][0]:shard_indexes["shard"][1]]
                main_grad_view = local_shard_info["data"][torch.float].grad[
                    main_shard_start_index:
                    main_shard_start_index + main_slice_size]
                # pax(0, {"dtype": dtype})
pax(0, {
# "dtype" : dtype,
# "vmodel" : vmodel,
"shard_indexes" : shard_indexes,
"grad_buffer_indexes" : grad_buffer_indexes,
"model_grad_view" : model_grad_view,
"main_grad_views" : main_grad_view,
})
pax(0, {
"group_index" : group_index,
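Tying the index maps together, the sketch below shows how a param's model-grad slice and the matching slice of the fp32 main shard could be viewed and copied. All tensors and index values here are invented, and recovering the offset within the param additionally needs the shard's absolute start index, which is assumed known for the example.

import torch

model_grad_buffer = torch.randn(300).to(torch.float16)    # toy flat fp16 grads
main_shard_grad = torch.zeros(100, dtype=torch.float)      # toy fp32 shard grads

grad_buffer_indexes = (150, 250)        # param's (start, end) in the grad buffer
slice_indexes = {"orig_start": 150,     # param's start within its group
                 "shard_start": 50,     # slice start within the local shard
                 "size": 50}            # slice length
local_shard_start = 100                 # shard's absolute start within the group

# Offset of the overlapping slice inside the param itself.
offset_in_param = (local_shard_start + slice_indexes["shard_start"]
                   - slice_indexes["orig_start"])
size = slice_indexes["size"]

# View of the overlapping model grads inside the flat grad buffer.
model_grad_view = model_grad_buffer[
    grad_buffer_indexes[0] + offset_in_param:
    grad_buffer_indexes[0] + offset_in_param + size]

# Matching view inside the local fp32 main-shard grads.
main_grad_view = main_shard_grad[
    slice_indexes["shard_start"]:
    slice_indexes["shard_start"] + size]

# Copy (and upcast) the model grads into the main-shard grads.
main_grad_view.copy_(model_grad_view.float())
assert model_grad_view.shape == main_grad_view.shape == (size,)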