Commit 7f0d8c87 authored by Michael Carilli

multi-tensor apply for DDP unflatten

parent 65ca6177
@@ -6,6 +6,7 @@ from collections import OrderedDict
 from itertools import chain
 import copy
 import importlib
+from ..multi_tensor_apply import multi_tensor_applier
 
 imported_flatten_impl = False
@@ -227,6 +228,12 @@ class DistributedDataParallel(Module):
                                     "torch.cuda.FloatTensor" : 1,
                                     "torch.cuda.DoubleTensor" : 2}
 
+        if multi_tensor_applier.available:
+            # TODO:  I really need to centralize the C++ backed imports
+            import amp_C
+            self.multi_tensor_scale = amp_C.multi_tensor_scale
+            self._overflow_buf = torch.cuda.IntTensor([0])
+
         self.create_hooks()
 
         flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )
@@ -395,6 +402,13 @@ class DistributedDataParallel(Module):
                 raise RuntimeError("The backward pass is attempting to replace an already-filled "
                                    "allreduce buffer. This is almost certainly an error.")
             self.allreduce_buffers[bucket_idx] = allreduced
         else:
+            if multi_tensor_applier.available:
+                multi_tensor_applier(
+                    self.multi_tensor_scale,
+                    self._overflow_buf,
+                    [unflatten(allreduced, bucket), bucket],
+                    1.0)
+            else:
                 for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
                     buf.copy_(synced)
...
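What the change does: after a bucket's flat allreduce buffer comes back, the synced values have to be copied into the bucket's individual tensors. Previously this was a Python loop issuing one copy_ kernel per tensor; when amp_C is available, multi_tensor_applier performs the same copies (a scale by 1.0) across the whole bucket in a handful of fused kernel launches. Below is a minimal sketch of the two paths; the helper names copy_back_loop and copy_back_fused are illustrative and not part of the commit, and unflatten stands in for the flatten/unflatten implementation apex selects at import time.

    # Illustrative sketch (not from the commit): two ways of scattering the
    # allreduced flat buffer back into the bucket's individual tensors.
    import torch
    from torch._utils import _unflatten_dense_tensors as unflatten
    from apex.multi_tensor_apply import multi_tensor_applier

    def copy_back_loop(allreduced, bucket):
        # Fallback path: one copy_ kernel launch per tensor in the bucket.
        for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
            buf.copy_(synced)

    def copy_back_fused(allreduced, bucket):
        # Fused path: multi_tensor_scale writes 1.0 * src into dst for every
        # (src, dst) pair, batching the whole bucket into a few kernel launches.
        import amp_C
        overflow_buf = torch.cuda.IntTensor([0])  # flag the kernel can set on inf/nan
        multi_tensor_applier(
            amp_C.multi_tensor_scale,
            overflow_buf,
            [unflatten(allreduced, bucket), bucket],  # [source list, destination list]
            1.0)                                      # scale of 1.0 acts as a plain copy

A scale factor of 1.0 turns multi_tensor_scale into a batched copy; it appears to be the same kernel apex's amp uses to scale gradients, so this path needs no new CUDA code.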