"vscode:/vscode.git/clone" did not exist on "2b3a1b6dfc50a4daf9c9c5cf76606b4cccc61892"
Commit c5f93269 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

map param to originating virtual model; eventually move this to constructor

parent 3ded2425
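The commit builds a reverse map from each parameter to the virtual-pipeline model chunk (and grad-buffer dtype) it belongs to, so the distributed optimizer can later find that parameter's slice of the chunk's contiguous grad buffer. Below is a minimal, self-contained sketch of that inversion, using mock chunk objects in place of the real `DistributedDataParallel` wrappers; the attribute name `_grad_buffer_param_index_map` comes from the diff, while the mock class and its forward-order offsets are illustrative only.

```python
import torch

class MockChunk:
    """Stand-in for one virtual-pipeline model chunk wrapped in local DDP."""
    def __init__(self, params):
        # dtype -> param -> {"start", "end"}: each param's range in the
        # chunk's contiguous grad buffer (offsets assigned in forward order
        # here for simplicity; the real code assigns them in reverse).
        self._grad_buffer_param_index_map = {torch.float: {}}
        offset = 0
        for p in params:
            n = p.nelement()
            self._grad_buffer_param_index_map[torch.float][p] = {
                "start": offset, "end": offset + n}
            offset += n

chunks = [MockChunk([torch.empty(4, 4), torch.empty(4)]),
          MockChunk([torch.empty(8)])]

# Invert the per-chunk maps: param -> originating chunk and buffer dtype.
param_model_map = {}
for chunk in chunks:
    for dtype, param_index_map in chunk._grad_buffer_param_index_map.items():
        for param in param_index_map:
            param_model_map[param] = {"dtype": dtype, "model": chunk}

# Every param of chunk 0 now points back at chunk 0.
assert all(param_model_map[p]["model"] is chunks[0]
           for p in chunks[0]._grad_buffer_param_index_map[torch.float])
```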
@@ -123,11 +123,16 @@ class DistributedDataParallel(DistributedDataParallelBase):
        self._grad_buffers = None
        # >>>
        from collections import defaultdict
        self._grad_buffer_param_offsets = None
        # self._grad_buffer_param_offsets = None
        self._grad_buffer_param_index_map = None
        # <<<
        if self.use_contiguous_buffers:
            self._grad_buffers = {}
            self._grad_buffer_param_offsets = defaultdict(dict)
            # >>>
            # self._grad_buffer_param_offsets = defaultdict(dict)
            # self._grad_buffer_param_index_map = defaultdict(dict)
            self._grad_buffer_param_index_map = {}
            # <<<
            # Simple function to define buffer type.
            def _get_buffer_type(param):
@@ -154,8 +159,16 @@ class DistributedDataParallel(DistributedDataParallelBase):
                    type_num_elements[dtype] -= param.data.nelement()
                    param.main_grad = self._grad_buffers[dtype].get(
                        param.data.shape, type_num_elements[dtype])
                    self._grad_buffer_param_offsets[dtype][param] = \
                        type_num_elements[dtype]
                    # >>>
                    # self._grad_buffer_param_offsets[dtype][param] = \
                    #     type_num_elements[dtype]
                    if dtype not in self._grad_buffer_param_index_map:
                        self._grad_buffer_param_index_map[dtype] = {}
                    self._grad_buffer_param_index_map[dtype][param] = {
                        "start" : type_num_elements[dtype],
                        "end" : type_num_elements[dtype] + param.data.nelement(),
                    }
                    # <<<
            # Backward hook.
            # Accumulation function for the gradients. We need
......
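For reference, the `_grad_buffer_param_index_map` added above records, per dtype, each parameter's [start, end) range inside the chunk's contiguous grad buffer, with end = start + param.data.nelement() (the same offsets handed to `MemoryBuffer.get`). Here is a minimal sketch of that layout, assuming a plain flat tensor in place of Megatron's `MemoryBuffer` and the same reverse-order offset assignment as the loop above.

```python
import torch

params = [torch.nn.Parameter(torch.randn(3, 2)),
          torch.nn.Parameter(torch.randn(5))]
total = sum(p.data.nelement() for p in params)

# Flat buffer standing in for MemoryBuffer; all main grads live inside it.
grad_buffer = torch.zeros(total, dtype=torch.float)

# Hand out [start, end) slices in reverse order, as in the loop above.
index_map, remaining = {}, total
for p in params:
    remaining -= p.data.nelement()
    start, end = remaining, remaining + p.data.nelement()
    p.main_grad = grad_buffer[start:end].view(p.data.shape)
    index_map[p] = {"start": start, "end": end}

# Writing through main_grad shows up in the flat buffer at the recorded range.
params[0].main_grad.fill_(1.0)
r = index_map[params[0]]
assert grad_buffer[r["start"]:r["end"]].sum().item() == params[0].data.nelement()
```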
@@ -775,7 +775,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # pax(0, {
        #     "model_param_group" : model_param_group,
        #     # "offset_map" : {str(p.shape):o for p, o in model_param_group["offset_map"].items()},
        #     "offset_map" : [(o,tp(p)) for p, o in model_param_group["offset_map"].items()],
        # })
@@ -843,10 +842,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # pax(2, {
        #     "data_parallel_rank" : self.data_parallel_rank,
        #     "local_shard_info" : local_shard_info,
        #     "param_index_map " : {
        #         str(p.shape) : i
        #     "param_index_map " : [
        #         (str(p.shape), i)
        #         for p, i in local_shard_info["param_index_map"].items()
        #     },
        #     ],
        # })
        # Allocate shards.
@@ -904,15 +903,57 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # timers = get_timers()
        # <<<
        # >>> [ already checked in arguments.py ]
        # >>> [ temporary requirement ... and already checked in arguments.py ]
        assert args.use_contiguous_buffers_in_local_ddp
        # <<<
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Map param to virtual model.
        # ** ideally, this should happen once, during construction.
        param_model_map = {}
        for vmodel in model:
            for dtype, param_index_map in \
                    vmodel._grad_buffer_param_index_map.items():
                for param in param_index_map:
                    param_model_map[param] = {
                        "dtype" : dtype,
                        "model" : vmodel,
                    }
        # pax(0, {
        #     "param_model_map" : [
        #         (str(tuple(p.shape)), m)
        #         for p, m in param_model_map.items()
        #     ],
        # })
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Copy model grads to main shard.
        local_shard_info_groups = [g[self.data_parallel_rank]
                                   for g in self.world_shard_info_groups]
        for group_index, local_shard_info in enumerate(local_shard_info_groups):
            # model_param_index_map =
            shard_param_index_map = local_shard_info["param_index_map"]
            for param, shard_indexes in shard_param_index_map.items():
                dtype_model_dict = param_model_map[param]
                dtype = dtype_model_dict["dtype"]
                vmodel = dtype_model_dict["model"]
                grad_buffer_indexes = \
                    vmodel._grad_buffer_param_index_map[dtype][param]
                pax(0, {"dtype": dtype})
                pax(0, {
                    "group_index" : group_index,
                    "local_shard_info" : local_shard_info,
                    "shard_param_index_map" : shard_param_index_map,
                    "param" : tp(param),
                    "shard_indexes" : shard_indexes,
                    "grad_buffer_indexes" : grad_buffer_indexes,
                })
        pax(0, {
            # "world_shard_info_groups" : self.world_shard_info_groups,
......
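The hunk above stops at the `pax(...)` debug dumps, before anything is actually copied, but lining up `shard_indexes` with `grad_buffer_indexes` points at the next step: copy the part of each parameter's grad-buffer range that falls inside this rank's shard into the local main-grad shard. The sketch below shows only that range arithmetic under assumed semantics (both ranges expressed as [start, end) offsets in the same flat-buffer coordinates, with `shard_range` covering the slice this rank owns); the function and variable names are illustrative and not taken from the diff.

```python
import torch

def copy_overlap_to_shard(grad_buffer, shard, shard_range, param_range):
    """Copy the part of a param's grad-buffer range that falls in this shard."""
    lo = max(param_range["start"], shard_range["start"])
    hi = min(param_range["end"], shard_range["end"])
    if lo >= hi:
        return  # param does not intersect the local shard
    offset = shard_range["start"]
    shard[lo - offset:hi - offset].copy_(grad_buffer[lo:hi])

# Toy example: 10-element flat grad buffer; this rank owns world range [4, 10).
grad_buffer = torch.arange(10, dtype=torch.float)
shard = torch.zeros(6)
copy_overlap_to_shard(grad_buffer, shard,
                      shard_range={"start": 4, "end": 10},
                      param_range={"start": 2, "end": 7})
assert shard[:3].tolist() == [4.0, 5.0, 6.0]
```

If the per-parameter `shard_indexes` in the real code are instead stored relative to the shard, the overlap computation stays the same and only the destination offset changes.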