Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
dd06b323
"git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "da38e96abfacc93f6e2fb0d7b9a141ab03435b9c"
Commit
dd06b323
authored
Oct 04, 2021
by
Gustaf Ahdritz
Browse files
Add barebones EMA
parent
ff969b98
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
57 additions
and
0 deletions
+57
-0
openfold/utils/exponential_moving_average.py
openfold/utils/exponential_moving_average.py
+57
-0
No files found.
openfold/utils/exponential_moving_average.py
0 → 100644
View file @
dd06b323
from
collections
import
OrderedDict
import
copy
import
torch
import
torch.nn
as
nn
class
ExponentialMovingAverage
:
"""
Maintains moving averages of parameters with exponential decay
At each step, the stored copy `copy` of each parameter `param` is
updated as follows:
`copy = decay * copy + (1 - decay) * param`
where `decay` is an attribute of the ExponentialMovingAverage object.
"""
def
__init__
(
self
,
model
:
nn
.
Module
,
decay
:
float
):
"""
Args:
model:
A torch.nn.Module whose parameters are to be tracked
decay:
A value (usually close to 1.) by which updates are
weighted as part of the above formula
"""
super
(
ExponentialMovingAverage
,
self
).
__init__
()
self
.
params
=
copy
.
deepcopy
(
model
.
state_dict
())
self
.
decay
=
decay
def
_update_state_dict_
(
self
,
update
,
state_dict
):
for
k
,
v
in
update
.
items
():
stored
=
state_dict
[
k
]
if
(
not
isinstance
(
v
,
torch
.
Tensor
)):
self
.
_update_state_dict_
(
v
,
stored
)
else
:
diff
=
stored
-
v
diff
*=
(
1
-
self
.
decay
)
stored
-=
diff
def
update
(
self
,
model
:
torch
.
nn
.
Module
)
->
None
:
"""
Updates the stored parameters using the state dict of the provided
module. The module should have the same structure as that used to
initialize the ExponentialMovingAverage object.
"""
self
.
_update_state_dict_
(
model
.
state_dict
(),
self
.
params
)
def
load_state_dict
(
self
,
state_dict
:
OrderedDict
)
->
None
:
self
.
params
=
state_dict
[
"params"
]
self
.
decay
=
state_dict
[
"decay"
]
def
state_dict
(
self
)
->
OrderedDict
:
return
OrderedDict
({
"params"
:
self
.
params
,
"decay"
:
self
.
decay
,
})
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment