Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
9a08acd8
Commit
9a08acd8
authored
Mar 18, 2022
by
Zhaocheng Zhu
Browse files
fix parameter restorage in validation
parent
c83c42e8
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
1 deletion
+5
-1
train_openfold.py
train_openfold.py
+5
-1
No files found.
train_openfold.py
View file @
9a08acd8
...
...
@@ -125,7 +125,11 @@ class OpenFoldWrapper(pl.LightningModule):
def
validation_step
(
self
,
batch
,
batch_idx
):
# At the start of validation, load the EMA weights
if
(
self
.
cached_weights
is
None
):
self
.
cached_weights
=
self
.
model
.
state_dict
()
# load_state_dict() is an in-place operation
# it will change the content in any reference of model.state_dict()
# therefore we need to explicitly clone the parameters
clone_param
=
lambda
t
:
t
.
clone
().
detach
()
self
.
cached_weights
=
tensor_tree_map
(
clone_param
,
self
.
model
.
state_dict
())
self
.
model
.
load_state_dict
(
self
.
ema
.
state_dict
()[
"params"
])
# Run the model
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment