Commit a3f3c3ad authored by Lawrence McAfee

todo; align shards with model's contiguous buffer

parent 3f0bc681
@@ -122,7 +122,7 @@ class DistributedDataParallel(DistributedDataParallelBase):
         # ===================================
         self._grad_buffers = None
         # >>>
-        from collections import defaultdict
+        # from collections import defaultdict
         # self._grad_buffer_param_offsets = None
         self._grad_buffer_param_index_map = None
         # <<<
...
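Note: _grad_buffer_param_index_map records where each parameter's gradient lives inside the DDP wrapper's contiguous grad buffer, keyed by dtype. A minimal, self-contained illustration of that bookkeeping (made-up parameter shapes; this is a sketch, not the actual DistributedDataParallel code):

    import torch

    # Pack two gradients into one contiguous buffer and remember each
    # parameter's (start, end) offsets, keyed by dtype.
    params = [torch.nn.Parameter(torch.zeros(3, 4)),
              torch.nn.Parameter(torch.zeros(5))]
    num_elements = sum(p.numel() for p in params)
    grad_buffer = torch.zeros(num_elements, dtype=torch.float)

    grad_buffer_param_index_map = {torch.float: {}}
    offset = 0
    for p in params:
        grad_buffer_param_index_map[torch.float][p] = (offset, offset + p.numel())
        offset += p.numel()

    print(grad_buffer_param_index_map[torch.float][params[1]])  # (12, 17)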
@@ -770,35 +770,35 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # (Also, collect world DP shard info.)
         # model_main_dtypes = set([ args.params_dtype, torch.float ])
         model_main_dtypes = set([ torch.float ]) # fp32 only, for now
-        self.world_shard_info_groups = [] # world_group_shard_infos ?
+        # self.world_shard_info_groups = [] # world_group_shard_infos ?
         # self.main_param_shard_groups = []
+        self.world_shard_infos = [{"groups": []} for _ in self.model_param_groups]
         for group_index, model_param_group in enumerate(self.model_param_groups):
-            # pax(0, {
-            #     "model_param_group" : model_param_group,
-            #     "offset_map" : [(o,tp(p)) for p, o in model_param_group["offset_map"].items()],
-            # })
             # Max world shard size.
             model_param_size = model_param_group["size"]
             max_world_shard_size = int(math.ceil(model_param_size /
                                                  self.data_parallel_world_size))
             # DP world shard infos.
-            world_shard_infos = []
+            # world_shard_infos = []
             for r in range(self.data_parallel_world_size):
                 shard_start_index = r * max_world_shard_size
                 shard_end_index = min(model_param_size,
                                       shard_start_index + max_world_shard_size)
-                world_shard_infos.append({
+                # world_shard_infos.append({
+                self.world_shard_infos[r]["groups"].append({
                     "start" : shard_start_index,
                     "end" : shard_end_index,
                     "size" : shard_end_index - shard_start_index,
                 })
-            self.world_shard_info_groups.append(world_shard_infos)
+            # self.world_shard_info_groups.append(world_shard_infos)
+            # self.world_shard_infos[group_index].append(world_shard_infos)
             # DP local rank's shard info.
-            local_shard_info = world_shard_infos[self.data_parallel_rank]
+            # local_shard_info = world_shard_infos[self.data_parallel_rank]
+            local_shard_info = \
+                self.world_shard_infos[self.data_parallel_rank]["groups"][-1]
             local_shard_start_index = local_shard_info["start"]
            local_shard_end_index = local_shard_info["end"]
             local_shard_size = local_shard_info["size"]
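Note: the loop above partitions each parameter group evenly across data-parallel ranks; rank r owns the range [r * ceil(size / dp_world_size), min(size, (r + 1) * ceil(size / dp_world_size))). A standalone sketch of the same computation (hypothetical helper name and group sizes, not the optimizer's API):

    import math

    def build_world_shard_infos(group_sizes, data_parallel_world_size):
        # One entry per DP rank; each entry collects one shard per param group.
        world_shard_infos = [{"groups": []}
                             for _ in range(data_parallel_world_size)]
        for model_param_size in group_sizes:
            # Max shard size: ceil-divide the group across DP ranks.
            max_world_shard_size = int(math.ceil(model_param_size /
                                                 data_parallel_world_size))
            for r in range(data_parallel_world_size):
                start = r * max_world_shard_size
                end = min(model_param_size, start + max_world_shard_size)
                world_shard_infos[r]["groups"].append({
                    "start": start,
                    "end": end,
                    "size": end - start,
                })
        return world_shard_infos

    # e.g. a 10-element group split across 4 ranks -> shard sizes [3, 3, 3, 1]
    infos = build_world_shard_infos([10], 4)
    print([r["groups"][0]["size"] for r in infos])  # [3, 3, 3, 1]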
@@ -895,12 +895,25 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         #     "params" : self.optimizer.param_groups[group_index]["params"],
         # })
+        # Add world start/end indexes, for reduce/gather steps.
+        offset = 0
+        for r in self.world_shard_infos:
+            r["start_index"] = offset
+            offset += sum(g["size"] for g in r["groups"])
+            r["end_index"] = offset
         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
         self.optimizer.load_state_dict(self.optimizer.state_dict())
         # >>>
-        # pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
+        # pax(0, {
+        #     "world_shard_infos" : self.world_shard_infos,
+        #     **{
+        #         "world_shard_infos / %d" % i : r
+        #         for i, r in enumerate(self.world_shard_infos)
+        #     },
+        # })
         # <<<
         # def get_loss_scale(self):
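Note: the added offset pass gives each DP rank a cumulative [start_index, end_index) range over the concatenation of its per-group shard sizes, which the planned reduce/gather steps can use. A minimal sketch of that prefix-sum, using the same hypothetical world_shard_infos layout as the previous note:

    def add_world_indexes(world_shard_infos):
        # Prefix-sum each rank's total shard size so every rank owns a
        # contiguous [start_index, end_index) range of the flattened params.
        offset = 0
        for r in world_shard_infos:
            r["start_index"] = offset
            offset += sum(g["size"] for g in r["groups"])
            r["end_index"] = offset
        return world_shard_infos

    # Two param groups sharded across two ranks (sizes are made up).
    infos = add_world_indexes([
        {"groups": [{"size": 5}, {"size": 3}]},   # rank 0
        {"groups": [{"size": 5}, {"size": 2}]},   # rank 1
    ])
    print([(r["start_index"], r["end_index"]) for r in infos])  # [(0, 8), (8, 15)]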
@@ -931,107 +944,129 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         #     "params" : params,
         # })
-    def reduce_gradients(self, model):
-        # >>>
-        # pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
-        # <<<
+    # def reduce_gradients(self, model):
+    #     # >>>
+    #     # pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
+    #     # <<<
+    #     # >>>
+    #     args = get_args()
+    #     # timers = get_timers()
+    #     # <<<
+    #     # >>> [ temporary requirement ... and already checked in arguments.py ]
+    #     assert args.use_contiguous_buffers_in_local_ddp
+    #     # <<<
+    #     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+    #     # Map param to virtual model.
+    #     # ** ideally, this should happen once, during construction.
+    #     param_model_map = {}
+    #     for vmodel in model:
+    #         for dtype, param_index_map in \
+    #             vmodel._grad_buffer_param_index_map.items():
+    #             for param in param_index_map:
+    #                 param_model_map[param] = {
+    #                     "dtype" : dtype,
+    #                     "model" : vmodel,
+    #                 }
+    #     # pax(0, {
+    #     #     "param_model_map" : [
+    #     #         (str(tuple(p.shape)), m)
+    #     #         for p, m in param_model_map.items()
+    #     #     ],
+    #     # })
+    #     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+    #     # Copy model grads to main shard.
+    #     local_shard_info_groups = [g[self.data_parallel_rank]
+    #                                for g in self.world_shard_info_groups]
+    #     for group_index, local_shard_info in enumerate(local_shard_info_groups):
+    #         # model_param_index_map =
+    #         # shard_param_index_map = local_shard_info["param_index_map"]
+    #         # main_index_map = local_shard_info["param_index_map"]
+    #         main_slice_index_map = local_shard_info["param_slice_index_map"]
+    #         for param, main_slice_indexes in main_slice_index_map.items():
+    #             main_slice_orig_start_index = main_slice_indexes["orig_start"]
+    #             main_slice_shard_start_index = main_slice_indexes["shard_start"]
+    #             main_slice_size = main_slice_indexes["size"]
+    #             dtype_model_dict = param_model_map[param]
+    #             dtype = dtype_model_dict["dtype"]
+    #             vmodel = dtype_model_dict["model"]
+    #             model_grad_buffer = vmodel._grad_buffers[dtype].data
+    #             model_grad_buffer_start_index = \
+    #                 vmodel._grad_buffer_param_index_map[dtype][param][0] + \
+    #                 main_slice_orig_start_index
+    #             main_grad_view = local_shard_info["data"][torch.float].grad[
+    #                 main_slice_shard_start_index:
+    #                 main_slice_shard_start_index + main_slice_size
+    #             ]
+    #             model_grad_view = model_grad_buffer[
+    #                 model_grad_buffer_start_index:
+    #                 model_grad_buffer_start_index + main_slice_size
+    #             ]
+    #             main_grad_view.detach().copy_(model_grad_view)
+    #             # pax(0, {
+    #             #     # "local_shard_info" : local_shard_info,
+    #             #     "main_slice_orig_start_index" : main_slice_orig_start_index,
+    #             #     "main_slice_shard_start_index" : main_slice_shard_start_index,
+    #             #     "main_slice_size" : main_slice_size,
+    #             #     "model_grad_buffer_start_index" :
+    #             #     model_grad_buffer_start_index,
+    #             #     "main_grad_view" : tp(main_grad_view),
+    #             #     "main_grad_view / detach" : tp(main_grad_view.detach()),
+    #             #     "model_grad_view" : tp(model_grad_view),
+    #             # })
+    #             # pax(0, {
+    #             #     "group_index" : group_index,
+    #             #     "local_shard_info" : local_shard_info,
+    #             #     "shard_param_index_map" : shard_param_index_map,
+    #             #     "param" : tp(param),
+    #             #     "shard_indexes" : shard_indexes,
+    #             #     "grad_buffer_indexes" : grad_buffer_indexes,
+    #             # })
+    #     pax(0, {
+    #         # "world_shard_info_groups" : self.world_shard_info_groups,
+    #         # **{"world_shard_info_groups / %d" % i : v
+    #         #    for i, v in enumerate(self.world_shard_info_groups)},
+    #         # "local_shard_info_groups" : local_shard_info_groups,
+    #         "local_shard_info_groups" : [ g["data"] for g in local_shard_info_groups ],
+    #     })
+    def reduce_gradients(self, model):
         # >>>
         args = get_args()
         # timers = get_timers()
         # <<<
-        # >>> [ temporary requirement ... and already checked in arguments.py ]
-        assert args.use_contiguous_buffers_in_local_ddp
-        # <<<
-        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-        # Map param to virtual model.
-        # ** ideally, this should happen once, during construction.
-        param_model_map = {}
-        for vmodel in model:
-            for dtype, param_index_map in \
-                vmodel._grad_buffer_param_index_map.items():
-                for param in param_index_map:
-                    param_model_map[param] = {
-                        "dtype" : dtype,
-                        "model" : vmodel,
-                    }
-        # pax(0, {
-        #     "param_model_map" : [
-        #         (str(tuple(p.shape)), m)
-        #         for p, m in param_model_map.items()
-        #     ],
-        # })
-        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-        # Copy model grads to main shard.
-        local_shard_info_groups = [g[self.data_parallel_rank]
-                                   for g in self.world_shard_info_groups]
-        for group_index, local_shard_info in enumerate(local_shard_info_groups):
-            # model_param_index_map =
-            # shard_param_index_map = local_shard_info["param_index_map"]
-            # main_index_map = local_shard_info["param_index_map"]
-            main_slice_index_map = local_shard_info["param_slice_index_map"]
-            for param, main_slice_indexes in main_slice_index_map.items():
-                main_slice_orig_start_index = main_slice_indexes["orig_start"]
-                main_slice_shard_start_index = main_slice_indexes["shard_start"]
-                main_slice_size = main_slice_indexes["size"]
-                dtype_model_dict = param_model_map[param]
-                dtype = dtype_model_dict["dtype"]
-                vmodel = dtype_model_dict["model"]
-                model_grad_buffer = vmodel._grad_buffers[dtype].data
-                model_grad_buffer_start_index = \
-                    vmodel._grad_buffer_param_index_map[dtype][param][0] + \
-                    main_slice_orig_start_index
-                main_grad_view = local_shard_info["data"][torch.float].grad[
-                    main_slice_shard_start_index:
-                    main_slice_shard_start_index + main_slice_size
-                ]
-                model_grad_view = model_grad_buffer[
-                    model_grad_buffer_start_index:
-                    model_grad_buffer_start_index + main_slice_size
-                ]
-                main_grad_view.detach().copy_(model_grad_view)
-                # pax(0, {
-                #     # "local_shard_info" : local_shard_info,
-                #     "main_slice_orig_start_index" : main_slice_orig_start_index,
-                #     "main_slice_shard_start_index" : main_slice_shard_start_index,
-                #     "main_slice_size" : main_slice_size,
-                #     "model_grad_buffer_start_index" :
-                #     model_grad_buffer_start_index,
-                #     "main_grad_view" : tp(main_grad_view),
-                #     "main_grad_view / detach" : tp(main_grad_view.detach()),
-                #     "model_grad_view" : tp(model_grad_view),
-                # })
-                # pax(0, {
-                #     "group_index" : group_index,
-                #     "local_shard_info" : local_shard_info,
-                #     "shard_param_index_map" : shard_param_index_map,
-                #     "param" : tp(param),
-                #     "shard_indexes" : shard_indexes,
-                #     "grad_buffer_indexes" : grad_buffer_indexes,
-                # })
-        pax(0, {
-            # "world_shard_info_groups" : self.world_shard_info_groups,
-            # **{"world_shard_info_groups / %d" % i : v
-            #    for i, v in enumerate(self.world_shard_info_groups)},
-            # "local_shard_info_groups" : local_shard_info_groups,
-            "local_shard_info_groups" : [ g["data"] for g in local_shard_info_groups ],
-        })
         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         # Reduce-scatter.
+        # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
+        assert args.use_contiguous_buffers_in_local_ddp
+        world_sizes = []
+        for r in self.world_shard_infos:
+            # world_sizes.append(sum(g["size"] for g in r))
+            world_sizes.append([ g["size"] for g in r["groups"] ])
+        # grad_refs ...
+        pax(0, {"world_sizes": world_sizes})
+        # for world_grads = []
+        # for world_shard_info_group
+        #     x ?
+        raise Exception("hi.")
         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
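Note: the reduce_gradients body being commented out copied each rank-owned gradient slice from a virtual model's contiguous grad buffer into the rank-local fp32 main shard, using _grad_buffer_param_index_map plus per-param slice offsets; the new stub stops after collecting per-rank shard sizes (pax + raise) ahead of the planned reduce-scatter. A toy sketch of just the slice-copy idea, with flat tensors and hypothetical index maps standing in for the real structures:

    import torch

    def copy_model_grads_to_main_shard(model_grad_buffer, main_shard_grad,
                                       grad_buffer_param_index_map,
                                       param_slice_index_map):
        # model_grad_buffer: flat tensor holding all model grads contiguously.
        # main_shard_grad: flat fp32 tensor holding this rank's grad shard.
        # grad_buffer_param_index_map[param] = (start, end) within the buffer.
        # param_slice_index_map[param] = {"orig_start", "shard_start", "size"}:
        #     which slice of the param this rank owns, and where it lands
        #     inside the shard.
        for param, s in param_slice_index_map.items():
            buf_start = grad_buffer_param_index_map[param][0] + s["orig_start"]
            src = model_grad_buffer[buf_start : buf_start + s["size"]]
            dst = main_shard_grad[s["shard_start"] : s["shard_start"] + s["size"]]
            dst.copy_(src)

    # Toy usage: two "params" packed into one 10-element grad buffer; this rank
    # owns the last 2 elements of p0 and the first 3 elements of p1.
    p0, p1 = object(), object()
    grad_buffer = torch.arange(10, dtype=torch.float)
    shard_grad = torch.zeros(5)
    copy_model_grads_to_main_shard(
        grad_buffer, shard_grad,
        {p0: (0, 4), p1: (4, 10)},
        {p0: {"orig_start": 2, "shard_start": 0, "size": 2},
         p1: {"orig_start": 0, "shard_start": 2, "size": 3}},
    )
    print(shard_grad)  # tensor([2., 3., 4., 5., 6.])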