Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
53eae198
Commit
53eae198
authored
Aug 29, 2019
by
Deyu Fu
Committed by
mcarilli
Aug 29, 2019
Browse files
[novograd] move exp_avg_sq to param device in load_state_dict (#459)
parent
dec4fdd6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
0 deletions
+8
-0
apex/optimizers/fused_novograd.py
apex/optimizers/fused_novograd.py
+8
-0
No files found.
apex/optimizers/fused_novograd.py
View file @
53eae198
...
@@ -95,6 +95,14 @@ class FusedNovoGrad(torch.optim.Optimizer):
...
@@ -95,6 +95,14 @@ class FusedNovoGrad(torch.optim.Optimizer):
else
:
else
:
super
(
FusedNovoGrad
,
self
).
zero_grad
()
super
(
FusedNovoGrad
,
self
).
zero_grad
()
def
load_state_dict
(
self
,
state_dict
):
super
(
FusedNovoGrad
,
self
).
load_state_dict
(
state_dict
)
# in case exp_avg_sq is not on the same device as params, move it there
for
group
in
self
.
param_groups
:
if
len
(
group
[
'params'
])
>
0
:
group
[
'exp_avg_sq'
][
0
]
=
group
[
'exp_avg_sq'
][
0
].
to
(
group
[
'params'
][
0
].
device
)
group
[
'exp_avg_sq'
][
1
]
=
group
[
'exp_avg_sq'
][
1
].
to
(
group
[
'params'
][
0
].
device
)
def
step
(
self
,
closure
=
None
):
def
step
(
self
,
closure
=
None
):
"""Performs a single optimization step.
"""Performs a single optimization step.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment