Commit f1f97f9f authored by Christian Sarofeen's avatar Christian Sarofeen
Browse files

Handle set/get state for DDP, remove stream which cant be pickled.

parent 5c6144e6
......@@ -88,6 +88,17 @@ class DistributedDataParallel(Module):
flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )
def __setstate__(self, state):
super(DistributedDataParallel, self).__setstate__(state)
self.reduction_stream = torch.cuda.Stream()
def __getstate__(self, state):
attrs = copy.copy(self.__dict__)
if dist._backend != dist.dist_backend.NCCL:
del attrs['self.reduction_stream']
return attrs
def create_hooks(self):
#all reduce gradient hook
def allreduce_params():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment