Commit 53eae198 authored by Deyu Fu, committed by mcarilli
Browse files

[novograd] move exp_avg_sq to param device in load_state_dict (#459)

parent dec4fdd6
...@@ -95,6 +95,14 @@ class FusedNovoGrad(torch.optim.Optimizer): ...@@ -95,6 +95,14 @@ class FusedNovoGrad(torch.optim.Optimizer):
else: else:
super(FusedNovoGrad, self).zero_grad() super(FusedNovoGrad, self).zero_grad()
def load_state_dict(self, state_dict):
    """Load optimizer state and co-locate ``exp_avg_sq`` with the params.

    The state is first restored by the parent class. Afterwards, for every
    param group that has at least one parameter, the two ``exp_avg_sq``
    entries stored on the group are moved to the device of the group's
    first parameter.

    Args:
        state_dict: optimizer state, passed unchanged to the parent
            ``load_state_dict``.
    """
    super(FusedNovoGrad, self).load_state_dict(state_dict)
    # in case exp_avg_sq is not on the same device as params, move it there
    for group in self.param_groups:
        params = group['params']
        if not params:
            continue
        # NOTE(review): 'exp_avg_sq' may be absent from a group when the
        # checkpoint was saved before it was ever populated (presumably it
        # is created lazily in step() -- confirm). Skip such groups instead
        # of raising KeyError.
        norms = group.get('exp_avg_sq')
        if norms is None:
            continue
        device = params[0].device
        # Entries are replaced in place so the list object held by the
        # group stays the same.
        for idx in (0, 1):
            norms[idx] = norms[idx].to(device)
def step(self, closure=None): def step(self, closure=None):
"""Performs a single optimization step. """Performs a single optimization step.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment