Commit 1d45fada authored by Michael Carilli's avatar Michael Carilli
Browse files

Fix to enable freezing params

parent 2063287b
...@@ -136,7 +136,7 @@ class DistributedDataParallel(Module): ...@@ -136,7 +136,7 @@ class DistributedDataParallel(Module):
torch.cuda.current_stream().wait_stream(self.reduction_stream) torch.cuda.current_stream().wait_stream(self.reduction_stream)
for param_i, param in enumerate(list(self.module.parameters())): for param_i, param in enumerate([p for p in self.module.parameters() if p.requires_grad]):
def wrapper(param_i): def wrapper(param_i):
def allreduce_hook(*unused): def allreduce_hook(*unused):
...@@ -203,7 +203,7 @@ class DistributedDataParallel(Module): ...@@ -203,7 +203,7 @@ class DistributedDataParallel(Module):
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
param_list = [param for param in list(self.module.parameters()) if param.requires_grad] param_list = [param for param in self.module.parameters() if param.requires_grad]
#Force needs_refresh to True if there are shared params #Force needs_refresh to True if there are shared params
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment