Commit f7232502 authored by Lawrence McAfee

collect param offsets for contiguous grad buffer

parent 5706ba42
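
For context on the commit message: the DDP change below records, per dtype, the starting offset of every parameter's gradient inside the contiguous grad buffer (dtype -> {param: offset}). The following is a minimal, self-contained sketch of that pattern; SimpleGradBuffer and build_grad_buffers are illustrative stand-ins, not Megatron's actual MemoryBuffer / DistributedDataParallel code.

# Illustrative sketch only: carve one flat tensor per dtype into per-parameter
# gradient views, and record each parameter's starting offset so the buffer
# can later be addressed by index.
from collections import defaultdict

import torch

class SimpleGradBuffer:
    """A flat tensor that hands out shaped views at given offsets."""
    def __init__(self, numel, dtype):
        self.numel = numel
        self.data = torch.zeros(numel, dtype=dtype)
    def get(self, shape, start_index):
        end_index = start_index + int(torch.Size(shape).numel())
        assert end_index <= self.numel
        return self.data[start_index:end_index].view(*shape)

def build_grad_buffers(params):
    params = list(params)
    # Total elements per dtype.
    type_num_elements = defaultdict(int)
    for param in params:
        type_num_elements[param.dtype] += param.data.nelement()
    grad_buffers = {dtype: SimpleGradBuffer(n, dtype)
                    for dtype, n in type_num_elements.items()}
    # Walk parameters in reverse (as the DDP wrapper does), assign each a view
    # into its dtype's buffer, and record its offset: dtype -> {param: offset}.
    grad_buffer_param_offsets = defaultdict(dict)
    for param in params[::-1]:
        dtype = param.dtype
        type_num_elements[dtype] -= param.data.nelement()
        param.main_grad = grad_buffers[dtype].get(
            param.data.shape, type_num_elements[dtype])
        grad_buffer_param_offsets[dtype][param] = type_num_elements[dtype]
    return grad_buffers, grad_buffer_param_offsets
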
......@@ -121,8 +121,13 @@ class DistributedDataParallel(DistributedDataParallelBase):
# the case we use contiguous buffers.
# ===================================
self._grad_buffers = None
# >>>
from collections import defaultdict
# Map: dtype -> {param: offset of the param's grad within the contiguous buffer}.
self._grad_buffer_param_offsets = None
# <<<
if self.use_contiguous_buffers:
self._grad_buffers = {}
self._grad_buffer_param_offsets = defaultdict(dict)
# Simple function to define buffer type.
def _get_buffer_type(param):
......@@ -149,6 +154,8 @@ class DistributedDataParallel(DistributedDataParallelBase):
type_num_elements[dtype] -= param.data.nelement()
param.main_grad = self._grad_buffers[dtype].get(
param.data.shape, type_num_elements[dtype])
# Record this param's start offset within its dtype's contiguous buffer.
self._grad_buffer_param_offsets[dtype][param] = \
type_num_elements[dtype]
# Backward hook.
# Accumulation function for the gradients. We need
......@@ -164,6 +171,17 @@ class DistributedDataParallel(DistributedDataParallelBase):
grad_acc.register_hook(self._make_param_hook(param))
self.grad_accs.append(grad_acc)
# >>>
# from lutil import pax, tp
# pax(0, {
# "_grad_buffers" : {k:b.numel for k,b in self._grad_buffers.items()},
# "_grad_buffer_param_offsets" : self._grad_buffer_param_offsets,
# **{"_grad_buffer_param_offsets / %s" % ty : {
# str(p.shape) : o for p, o in po.items()
# } for ty, po in self._grad_buffer_param_offsets.items()},
# })
# <<<
def _make_param_hook(self, param):
"""Create the all-reduce hook for backprop."""
......
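
Continuing the illustrative sketch above (an assumption, not code from this commit): the recorded offset lets later code address a parameter's gradient as a flat slice of the buffer, which is what makes the buffer shardable by index.

# Usage of the sketch: the shaped view handed out as `main_grad` and the flat
# slice computed from the recorded offset address the same storage.
params = [torch.nn.Parameter(torch.randn(4, 3)),
          torch.nn.Parameter(torch.randn(8))]
grad_buffers, offsets = build_grad_buffers(params)

p = params[0]
start = offsets[p.dtype][p]
end = start + p.data.nelement()
flat_slice = grad_buffers[p.dtype].data[start:end]
assert flat_slice.data_ptr() == p.main_grad.data_ptr()
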
......@@ -758,13 +758,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# })
# Shard allocator.
# (Open question: should this return a torch.nn.Parameter, or use MemoryBuffer?)
allocate_shard = lambda shard_size, dtype: torch.empty(
(shard_size,),
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=True)
# return torch.nn.Parameter ?
# allocate_shard = lambda dtype : MemoryBuffer(self.shard_size, dtype)
# Allocate shards.
# (Also, collect world DP shard info.)
......@@ -860,6 +860,19 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
assert args.use_contiguous_buffers_in_local_ddp
# <<<
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Copy model grads to main shard.
self.world_shard_info_groups = [] # world_group_shard_infos ?
self.main_param_shard_groups = []
# Temporary debug breakpoint; requires `from lutil import pax`.
pax(0, {"world_shard_info_groups": self.world_shard_info_groups})
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Reduce-scatter.
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# grad_buffers = [ m._grad_buffers for m in model ]
for virtual_model in model:
......
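
The optimizer-side hunks above are still work in progress. The direction they point at, splitting each contiguous grad buffer into equal per-rank shards and, using the offsets collected in the DDP wrapper, working out which parameter ranges overlap each shard before a reduce-scatter, can be sketched as follows. This is an illustrative sketch; the function name and shard math are assumptions, not the commit's final implementation.

# Illustrative only: given a flat grad buffer of `grad_buffer_numel` elements
# and the per-param offsets recorded in the DDP wrapper, find which parameter
# slices fall inside one data-parallel rank's shard.
def compute_shard_param_ranges(grad_buffer_numel, param_offsets,
                               data_parallel_world_size, rank):
    # Equal-sized shards; assume the buffer is padded to a multiple of the
    # data-parallel world size.
    shard_size = grad_buffer_numel // data_parallel_world_size
    shard_start = rank * shard_size
    shard_end = shard_start + shard_size
    ranges = {}
    for param, param_start in param_offsets.items():
        param_end = param_start + param.data.nelement()
        # Intersect [param_start, param_end) with [shard_start, shard_end).
        lo = max(param_start, shard_start)
        hi = min(param_end, shard_end)
        if lo < hi:
            # Offsets relative to the start of this rank's shard.
            ranges[param] = (lo - shard_start, hi - shard_start)
    return ranges
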