Unverified Commit 77ee4bcd authored by Christian Sarofeen's avatar Christian Sarofeen Committed by GitHub
Browse files

Merge pull request #30 from NVIDIA/checkpoint_fix

Handle set/get state for DDP, remove stream which can't be pickled.
parents 5c6144e6 f1f97f9f
...@@ -87,6 +87,17 @@ class DistributedDataParallel(Module): ...@@ -87,6 +87,17 @@ class DistributedDataParallel(Module):
self.create_hooks() self.create_hooks()
flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) ) flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )
def __setstate__(self, state):
    """Restore this module from ``state`` and attach a new CUDA stream.

    The parent class's ``__setstate__`` is applied first; afterwards a
    newly constructed ``torch.cuda.Stream`` is stored on
    ``self.reduction_stream``.

    Args:
        state: The state object handed over by the unpickling machinery;
            it is forwarded unchanged to the parent implementation.
    """
    super(DistributedDataParallel, self).__setstate__(state)
    # Build the stream only after the parent has restored the attributes,
    # so the assignment below is the last write to ``reduction_stream``.
    fresh_stream = torch.cuda.Stream()
    self.reduction_stream = fresh_stream
def __getstate__(self, state=None):
    """Return a shallow copy of the instance ``__dict__`` for pickling.

    When the distributed backend is not NCCL, the ``reduction_stream``
    entry is removed from the copy before it is returned.

    Args:
        state: Unused. Kept (now optional) only for backward
            compatibility with callers that passed an argument; the
            pickle/copy protocol invokes ``__getstate__()`` with no
            arguments, which the previous required parameter rejected
            with a ``TypeError``.

    Returns:
        dict: A shallow copy of ``self.__dict__``; the live instance
        dictionary itself is never modified.
    """
    attrs = copy.copy(self.__dict__)
    # NOTE(review): the stream is dropped only for non-NCCL backends, while
    # __setstate__ recreates it unconditionally -- confirm whether it should
    # be dropped for every backend.
    if dist._backend != dist.dist_backend.NCCL:
        # The attribute is stored under 'reduction_stream'; the previous key
        # 'self.reduction_stream' never exists in __dict__ and raised
        # KeyError. pop() with a default also tolerates an absent stream.
        attrs.pop('reduction_stream', None)
    return attrs
def create_hooks(self): def create_hooks(self):
#all reduce gradient hook #all reduce gradient hook
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment