OpenDAS / OpenFold · Commit fbbb0479 (unverified)
Authored Mar 18, 2022 by Gustaf Ahdritz; committed by GitHub, Mar 18, 2022

Merge pull request #90 from KiddoZhu/main

Fix parameter restoration after loading EMA parameters
Parents: c83c42e8, 9a08acd8
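For context, the fix below belongs to a common EMA validation pattern: before validation the live training weights are cached and the EMA (exponential moving average) weights are swapped in; when validation ends, the cached weights are restored. A minimal sketch of that pattern in plain PyTorch — the class and method names here are illustrative, not OpenFold's actual API:

```python
import torch


class EMAWrapper:
    """Sketch of the cache / swap-in-EMA / restore pattern used around
    validation. Hypothetical helper, not OpenFold's real implementation."""

    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.cached_weights = None
        # Independent EMA shadow copy of the parameters
        self.ema_params = {
            k: v.clone().detach() for k, v in model.state_dict().items()
        }

    def update(self):
        # Standard EMA update: ema = decay * ema + (1 - decay) * param
        with torch.no_grad():
            for k, v in self.model.state_dict().items():
                self.ema_params[k].mul_(self.decay).add_(v, alpha=1 - self.decay)

    def on_validation_start(self):
        # Clone, because load_state_dict() below mutates the live tensors
        # that state_dict() references
        self.cached_weights = {
            k: v.clone().detach() for k, v in self.model.state_dict().items()
        }
        self.model.load_state_dict(self.ema_params)

    def on_validation_end(self):
        # Restore the training weights after validation
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None
```

The clone in `on_validation_start` is exactly the step this commit adds to OpenFold's `validation_step`.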
Showing 1 changed file with 5 additions and 1 deletion
train_openfold.py (+5 −1)

@@ -125,7 +125,11 @@ class OpenFoldWrapper(pl.LightningModule):
     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
         if(self.cached_weights is None):
-            self.cached_weights = self.model.state_dict()
+            # load_state_dict() is an in-place operation
+            # it will change the content in any reference of model.state_dict()
+            # therefore we need to explicitly clone the parameters
+            clone_param = lambda t: t.clone().detach()
+            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
             self.model.load_state_dict(self.ema.state_dict()["params"])

         # Run the model
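The comments in the diff describe the underlying bug: `state_dict()` returns references to the live parameter tensors, and `load_state_dict()` copies new values into those same tensors in place, so a cache taken without cloning is silently overwritten. A small standalone demonstration in plain PyTorch (no OpenFold code involved):

```python
import torch

net = torch.nn.Linear(2, 2, bias=False)
with torch.no_grad():
    net.weight.fill_(1.0)

# state_dict() hands back references to the live tensors, not copies
cached = net.state_dict()
# cloning produces an independent snapshot
safe = {k: v.clone().detach() for k, v in net.state_dict().items()}

# load_state_dict() copies the new values into the existing tensors in place
net.load_state_dict({"weight": torch.zeros(2, 2)})

print(cached["weight"].sum().item())  # 0.0 -- the un-cloned cache was clobbered
print(safe["weight"].sum().item())    # 4.0 -- the cloned snapshot survived
```

This is why the patch replaces the plain `self.model.state_dict()` cache with a `tensor_tree_map(clone_param, ...)` over the state dict before loading the EMA parameters.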