Commit ed47ebff authored by Michael Carilli
Browse files

Forward compatibility fixes for distributed backend, thanks to @Ssnl

parent 0ec8addb
...@@ -129,7 +129,17 @@ class DistributedDataParallel(Module): ...@@ -129,7 +129,17 @@ class DistributedDataParallel(Module):
def __init__(self, module, message_size=10000000, shared_param=False): def __init__(self, module, message_size=10000000, shared_param=False):
super(DistributedDataParallel, self).__init__() super(DistributedDataParallel, self).__init__()
self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
# Backward/forward compatibility around
# https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36
if(hasattr(dist, "get_backend")):
self._backend = dist.get_backend()
self.backend_enum_holder = dist.DistBackend
else:
self._backend = dist._backend
self.backend_enum_holder = dist.dist_backend
self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False
self.shared_param = shared_param self.shared_param = shared_param
self.message_size = message_size self.message_size = message_size
...@@ -141,7 +151,7 @@ class DistributedDataParallel(Module): ...@@ -141,7 +151,7 @@ class DistributedDataParallel(Module):
self.module = module self.module = module
self.param_list = list(self.module.parameters()) self.param_list = list(self.module.parameters())
if dist._backend == dist.dist_backend.NCCL: if self._backend == self.backend_enum_holder.NCCL:
for param in self.param_list: for param in self.param_list:
assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU." assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
...@@ -156,7 +166,7 @@ class DistributedDataParallel(Module): ...@@ -156,7 +166,7 @@ class DistributedDataParallel(Module):
def __getstate__(self): def __getstate__(self):
attrs = copy.copy(self.__dict__) attrs = copy.copy(self.__dict__)
if dist._backend != dist.dist_backend.NCCL: if self._backend != self.backend_enum_holder.NCCL:
del attrs['self.reduction_stream'] del attrs['self.reduction_stream']
return attrs return attrs
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment