"tests/vscode:/vscode.git/clone" did not exist on "327deaee4122b3ff7780e36d0e481c5997dbe1fa"
Commit 7f0d8c87 authored by Michael Carilli

multi-tensor apply for DDP unflatten

parent 65ca6177
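
For context: multi_tensor_applier is apex's helper that launches a fused CUDA kernel over lists of tensors in chunks, so a whole bucket of parameters can be copied (here via amp_C.multi_tensor_scale with a scale of 1.0) in a handful of kernel launches instead of one copy_ per tensor. The sketch below shows the pattern in isolation; it assumes apex is installed with its CUDA extensions (amp_C) built, and the names src, dst, and overflow_buf are illustrative, not part of the patch.

    import torch
    from apex.multi_tensor_apply import multi_tensor_applier

    # Illustrative stand-ins for a DDP bucket and the unflattened
    # views of its allreduced flat buffer.
    src = [torch.randn(2 ** i, device="cuda") for i in range(4, 8)]
    dst = [torch.empty_like(t) for t in src]

    if multi_tensor_applier.available:
        import amp_C
        # amp_C kernels signal inf/nan through this flag; a plain copy
        # (scale == 1.0) of finite inputs leaves it at zero.
        overflow_buf = torch.cuda.IntTensor([0])
        # One chunked, fused launch computes dst[i] = src[i] * 1.0 for
        # every pair, replacing a separate copy_ kernel per tensor.
        multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf,
                             [src, dst], 1.0)
    else:
        # Fallback identical to the loop the patch keeps for builds
        # where amp_C is unavailable.
        for s, d in zip(src, dst):
            d.copy_(s)
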
@@ -6,6 +6,7 @@ from collections import OrderedDict
 from itertools import chain
 import copy
 import importlib
+from ..multi_tensor_apply import multi_tensor_applier
 imported_flatten_impl = False
@@ -226,7 +227,13 @@ class DistributedDataParallel(Module):
         self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0,
                                     "torch.cuda.FloatTensor" : 1,
                                     "torch.cuda.DoubleTensor" : 2}
+        if multi_tensor_applier.available:
+            # TODO: I really need to centralize the C++ backed imports
+            import amp_C
+            self.multi_tensor_scale = amp_C.multi_tensor_scale
+            self._overflow_buf = torch.cuda.IntTensor([0])
+
         self.create_hooks()
         flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )
@@ -396,8 +403,15 @@ class DistributedDataParallel(Module):
                                    "allreduce buffer. This is almost certainly an error.")
             self.allreduce_buffers[bucket_idx] = allreduced
         else:
-            for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
-                buf.copy_(synced)
+            if multi_tensor_applier.available:
+                multi_tensor_applier(
+                    self.multi_tensor_scale,
+                    self._overflow_buf,
+                    [unflatten(allreduced, bucket), bucket],
+                    1.0)
+            else:
+                for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
+                    buf.copy_(synced)

     def allreduce_fallback(self):
...
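
A note on the likely rationale (not spelled out in the commit message): a gradient bucket can hold many small parameter tensors, and the old path issued one copy_ kernel per tensor when scattering the allreduced flat buffer back into the parameters, so kernel launch overhead could dominate for fine-grained models. multi_tensor_scale with a scale of 1.0 performs the same element-wise copies across the whole list in a few chunked launches, and the per-tensor loop remains as the fallback when amp_C is not available.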