"src/turbomind/vscode:/vscode.git/clone" did not exist on "911c0a85dadbf1783940138d7c6aafdbc88d6a17"
Commit c5f93269 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

map param to originating virtual model; eventually move this to constructor

parent 3ded2425
@@ -123,11 +123,16 @@ class DistributedDataParallel(DistributedDataParallelBase):
         self._grad_buffers = None
         # >>>
         from collections import defaultdict
-        self._grad_buffer_param_offsets = None
+        # self._grad_buffer_param_offsets = None
+        self._grad_buffer_param_index_map = None
         # <<<
         if self.use_contiguous_buffers:
             self._grad_buffers = {}
-            self._grad_buffer_param_offsets = defaultdict(dict)
+            # >>>
+            # self._grad_buffer_param_offsets = defaultdict(dict)
+            # self._grad_buffer_param_index_map = defaultdict(dict)
+            self._grad_buffer_param_index_map = {}
+            # <<<

             # Simple function to define buffer type.
             def _get_buffer_type(param):
@@ -154,8 +159,16 @@ class DistributedDataParallel(DistributedDataParallelBase):
                     type_num_elements[dtype] -= param.data.nelement()
                     param.main_grad = self._grad_buffers[dtype].get(
                         param.data.shape, type_num_elements[dtype])
-                    self._grad_buffer_param_offsets[dtype][param] = \
-                        type_num_elements[dtype]
+                    # >>>
+                    # self._grad_buffer_param_offsets[dtype][param] = \
+                    #     type_num_elements[dtype]
+                    if dtype not in self._grad_buffer_param_index_map:
+                        self._grad_buffer_param_index_map[dtype] = {}
+                    self._grad_buffer_param_index_map[dtype][param] = {
+                        "start" : type_num_elements[dtype],
+                        "end" : param.data.nelement(),
+                    }
+                    # <<<

             # Backward hook.
             # Accumalation function for the gradients. We need
...
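The new _grad_buffer_param_index_map records, per dtype, the slice of the contiguous grad buffer that backs each parameter's main_grad. Below is a minimal standalone sketch of that bookkeeping pattern, not the Megatron-LM code itself: pack_params is a hypothetical helper, and it records "end" as start + numel (the usual absolute-offset convention), whereas the diff above stores param.data.nelement() in that slot.

# Minimal standalone sketch (assumption: helper name pack_params is made up
# for illustration). Pack parameters into one flat buffer per dtype and
# remember each parameter's slice, mirroring the decrement-then-assign
# offset pattern used in the diff.
import torch

def pack_params(params):
    """Return ({dtype: flat_buffer}, {dtype: {param: {"start", "end"}}})."""
    buffers = {}
    index_map = {}
    # Total element count per dtype, so one flat buffer can hold them all.
    type_num_elements = {}
    for p in params:
        type_num_elements[p.dtype] = \
            type_num_elements.get(p.dtype, 0) + p.data.nelement()
    for dtype, n in type_num_elements.items():
        buffers[dtype] = torch.zeros(n, dtype=dtype)
        index_map[dtype] = {}
    # Decrement the running offset before assigning, as in the diff, so each
    # parameter is placed at the current end of the unused buffer region.
    for p in params:
        dtype = p.dtype
        type_num_elements[dtype] -= p.data.nelement()
        start = type_num_elements[dtype]
        end = start + p.data.nelement()  # assumption: absolute end offset
        index_map[dtype][p] = {"start": start, "end": end}
    return buffers, index_map

if __name__ == "__main__":
    params = [torch.nn.Parameter(torch.randn(4, 3)),
              torch.nn.Parameter(torch.randn(7))]
    buffers, index_map = pack_params(params)
    for dtype, per_param in index_map.items():
        for p, idx in per_param.items():
            print(dtype, tuple(p.shape), idx)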
@@ -775,7 +775,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # pax(0, {
         #     "model_param_group" : model_param_group,
-        #     # "offset_map" : {str(p.shape):o for p, o in model_param_group["offset_map"].items()},
         #     "offset_map" : [(o,tp(p)) for p, o in model_param_group["offset_map"].items()],
         # })
@@ -843,10 +842,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # pax(2, {
         #     "data_parallel_rank" : self.data_parallel_rank,
         #     "local_shard_info" : local_shard_info,
-        #     "param_index_map " : {
-        #         str(p.shape) : i
+        #     "param_index_map " : [
+        #         (str(p.shape), i)
         #         for p, i in local_shard_info["param_index_map"].items()
-        #     },
+        #     ],
         # })

         # Allocate shards.
@@ -904,15 +903,57 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # timers = get_timers()
         # <<<

-        # >>> [ already checked in arguments.py ]
+        # >>> [ temporary requirement ... and already checked in arguments.py ]
         assert args.use_contiguous_buffers_in_local_ddp
         # <<<

+        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+        # Map param to virtual model.
+        # ** ideally, this should happen once, during construction.
+        param_model_map = {}
+        for vmodel in model:
+            for dtype, param_index_map in \
+                vmodel._grad_buffer_param_index_map.items():
+                for param in param_index_map:
+                    param_model_map[param] = {
+                        "dtype" : dtype,
+                        "model" : vmodel,
+                    }
+
+        # pax(0, {
+        #     "param_model_map" : [
+        #         (str(tuple(p.shape)), m)
+        #         for p, m in param_model_map.items()
+        #     ],
+        # })

         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         # Copy model grads to main shard.
         local_shard_info_groups = [g[self.data_parallel_rank]
                                    for g in self.world_shard_info_groups]
+        for group_index, local_shard_info in enumerate(local_shard_info_groups):
+
+            # model_param_index_map =
+            shard_param_index_map = local_shard_info["param_index_map"]
+            for param, shard_indexes in shard_param_index_map.items():
+                dtype_model_dict = param_model_map[param]
+                dtype = dtype_model_dict["dtype"]
+                vmodel = dtype_model_dict["model"]
+                grad_buffer_indexes = \
+                    vmodel._grad_buffer_param_index_map[dtype][param]
+
+                pax(0, {"dtype": dtype})
+
+                pax(0, {
+                    "group_index" : group_index,
+                    "local_shard_info" : local_shard_info,
+                    "shard_param_index_map" : shard_param_index_map,
+                    "param" : tp(param),
+                    "shard_indexes" : shard_indexes,
+                    "grad_buffer_indexes" : grad_buffer_indexes,
+                })
+
         pax(0, {
             # "world_shard_info_groups" : self.world_shard_info_groups,
...
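The optimizer-side change first builds param_model_map so that each parameter can be traced back to its originating virtual model and grad-buffer dtype, then walks the local shard's param_index_map and looks up the parameter's slice in that model's contiguous buffer; per the commit message, the map construction should eventually move into the constructor. The sketch below illustrates that lookup pattern under assumed data layouts: FakeModel, build_param_model_map, and gather_shard_grads are hypothetical stand-ins for illustration, not the actual Megatron-LM classes or shard-info structures.

# Hedged sketch of the param -> originating-model lookup. All names and dict
# layouts here are assumptions made for illustration only.
import torch

class FakeModel:
    """Stand-in for a virtual (pipeline) model chunk."""
    def __init__(self, grad_buffer_param_index_map, grad_buffers):
        self._grad_buffer_param_index_map = grad_buffer_param_index_map
        self._grad_buffers = grad_buffers  # {dtype: flat 1-D tensor}

def build_param_model_map(models):
    # One pass over all virtual models; afterwards every param resolves in O(1).
    # This is the step the commit message says should move to the constructor.
    param_model_map = {}
    for vmodel in models:
        for dtype, param_index_map in \
                vmodel._grad_buffer_param_index_map.items():
            for param in param_index_map:
                param_model_map[param] = {"dtype": dtype, "model": vmodel}
    return param_model_map

def gather_shard_grads(local_shard_info, param_model_map):
    # For each param owned by this data-parallel shard, pull its gradient
    # slice out of the originating model's contiguous grad buffer.
    grads = {}
    for param, shard_indexes in local_shard_info["param_index_map"].items():
        entry = param_model_map[param]
        dtype, vmodel = entry["dtype"], entry["model"]
        idx = vmodel._grad_buffer_param_index_map[dtype][param]
        buf = vmodel._grad_buffers[dtype]
        grads[param] = buf[idx["start"]:idx["end"]]
    return grads

if __name__ == "__main__":
    p = torch.nn.Parameter(torch.randn(6))
    buf = torch.arange(10, dtype=torch.float32)
    model = FakeModel({torch.float32: {p: {"start": 2, "end": 8}}},
                      {torch.float32: buf})
    pmap = build_param_model_map([model])
    shard = {"param_index_map": {p: (0, 6)}}
    print(gather_shard_grads(shard, pmap))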