Commit 1d45fada authored by Michael Carilli

Fix to enable freezing params

parent 2063287b
@@ -136,7 +136,7 @@ class DistributedDataParallel(Module):
             torch.cuda.current_stream().wait_stream(self.reduction_stream)
-        for param_i, param in enumerate(list(self.module.parameters())):
+        for param_i, param in enumerate([p for p in self.module.parameters() if p.requires_grad]):
             def wrapper(param_i):
                 def allreduce_hook(*unused):
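For context, DDP-style wrappers typically attach a per-parameter allreduce hook through the parameter's gradient accumulator node, and a parameter with requires_grad=False never builds that part of the autograd graph, which is why frozen parameters have to be filtered out before hook registration. The sketch below only illustrates that general pattern; the helper name attach_allreduce_hook is made up for the illustration and this is not the repository's actual code.

```python
import torch

def attach_allreduce_hook(param, hook):
    # Hypothetical helper: hook a leaf parameter's gradient accumulation.
    # expand_as creates a view whose grad_fn chains back to the parameter's
    # AccumulateGrad node, the usual place DDP-style hooks are registered.
    p_tmp = param.expand_as(param)
    grad_acc = p_tmp.grad_fn.next_functions[0][0]
    grad_acc.register_hook(hook)

w = torch.nn.Parameter(torch.randn(4))
frozen = torch.nn.Parameter(torch.randn(4), requires_grad=False)

attach_allreduce_hook(w, lambda *unused: print("would allreduce w.grad"))

# A frozen parameter builds no autograd graph: frozen.expand_as(frozen).grad_fn
# is None, so the lookup above would raise AttributeError. Filtering on
# requires_grad, as this commit does, avoids ever reaching that point.

(w * 2.0).sum().backward()   # fires the hook once w's gradient is accumulated
```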
@@ -203,7 +203,7 @@ class DistributedDataParallel(Module):
     def forward(self, *inputs, **kwargs):
-        param_list = [param for param in list(self.module.parameters()) if param.requires_grad]
+        param_list = [param for param in self.module.parameters() if param.requires_grad]
         #Force needs_refresh to True if there are shared params
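With both hunks applied, a module that contains frozen parameters can be wrapped without the hook setup touching them. A minimal usage sketch, assuming this class is exposed as apex.parallel.DistributedDataParallel and that the process group has already been initialized by the launcher:

```python
import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP

# Assumes a standard distributed launcher set the usual environment
# variables and that dist.init_process_group(backend="nccl") plus
# torch.cuda.set_device(...) were called before this point.

model = torch.nn.Sequential(
    torch.nn.Linear(128, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
).cuda()

# Freeze the first layer; with this fix its parameters are simply
# skipped when the per-parameter allreduce hooks are registered.
for p in model[0].parameters():
    p.requires_grad = False

model = DDP(model)

out = model(torch.randn(32, 128).cuda())
out.sum().backward()   # gradients are allreduced only for trainable params
```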