Commit 7f0d8c87 authored by Michael Carilli

multi-tensor apply for DDP unflatten

parent 65ca6177
@@ -6,6 +6,7 @@ from collections import OrderedDict
 from itertools import chain
 import copy
 import importlib
+from ..multi_tensor_apply import multi_tensor_applier
 
 imported_flatten_impl = False
@@ -226,7 +227,13 @@ class DistributedDataParallel(Module):
         self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0,
                                     "torch.cuda.FloatTensor" : 1,
                                     "torch.cuda.DoubleTensor" : 2}
 
+        if multi_tensor_applier.available:
+            # TODO: I really need to centralize the C++ backed imports
+            import amp_C
+            self.multi_tensor_scale = amp_C.multi_tensor_scale
+            self._overflow_buf = torch.cuda.IntTensor([0])
+
         self.create_hooks()
 
         flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )
@@ -396,8 +403,15 @@ class DistributedDataParallel(Module):
                             "allreduce buffer. This is almost certainly an error.")
             self.allreduce_buffers[bucket_idx] = allreduced
         else:
-            for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
-                buf.copy_(synced)
+            if multi_tensor_applier.available:
+                multi_tensor_applier(
+                    self.multi_tensor_scale,
+                    self._overflow_buf,
+                    [unflatten(allreduced, bucket), bucket],
+                    1.0)
+            else:
+                for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
+                    buf.copy_(synced)
 
     def allreduce_fallback(self):
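The new path above replaces the per-tensor copy loop with one batched call: the unflattened allreduced chunks and the bucket's tensors are handed to multi_tensor_applier together with amp_C.multi_tensor_scale and a scale of 1.0, so the whole bucket is copied in a fused fashion instead of one copy_ launch per tensor. A minimal standalone sketch of that pattern, assuming apex is installed with its amp_C CUDA extension built and a CUDA device is available:

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

# Two source tensors standing in for the unflattened allreduced gradients,
# and matching destination tensors standing in for the bucket's tensors.
src = [torch.randn(4, device="cuda"), torch.randn(8, device="cuda")]
dst = [torch.empty_like(t) for t in src]

# Flag buffer the fused kernels use to report inf/nan; stays 0 for a plain copy.
overflow_buf = torch.cuda.IntTensor([0])

# multi_tensor_scale writes dst = src * scale across all tensor pairs;
# with scale=1.0 it acts as a batched copy, like the zip/copy_ loop it replaces.
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [src, dst], 1.0)

for s, d in zip(src, dst):
    assert torch.equal(s, d)

If amp_C cannot be imported, multi_tensor_applier.available is False and the original zip/copy_ loop remains the fallback, as the diff shows.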