Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
1dcff6aa
Commit
1dcff6aa
authored
Oct 19, 2021
by
Gustaf Ahdritz
Browse files
Update EMA
parent
12aa565e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
1 deletion
+9
-1
openfold/utils/exponential_moving_average.py
openfold/utils/exponential_moving_average.py
+9
-1
No files found.
openfold/utils/exponential_moving_average.py
View file @
1dcff6aa
...
...
@@ -3,6 +3,8 @@ import copy
import
torch
import
torch.nn
as
nn
from
openfold.utils.tensor_utils
import
tensor_tree_map
class
ExponentialMovingAverage
:
"""
...
...
@@ -27,8 +29,14 @@ class ExponentialMovingAverage:
"""
super
(
ExponentialMovingAverage
,
self
).
__init__
()
self
.
params
=
copy
.
deepcopy
(
model
.
state_dict
())
clone_param
=
lambda
t
:
t
.
clone
().
detach
()
self
.
params
=
tensor_tree_map
(
clone_param
,
model
.
state_dict
())
self
.
decay
=
decay
self
.
device
=
next
(
model
.
parameters
()).
device
def
to
(
self
,
device
):
self
.
params
=
tensor_tree_map
(
lambda
t
:
t
.
to
(
device
),
self
.
params
)
self
.
device
=
device
def
_update_state_dict_
(
self
,
update
,
state_dict
):
with
torch
.
no_grad
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment