Commit 9a08acd8 authored by Zhaocheng Zhu's avatar Zhaocheng Zhu
Browse files

fix parameter restorage in validation

parent c83c42e8
...@@ -125,7 +125,11 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -125,7 +125,11 @@ class OpenFoldWrapper(pl.LightningModule):
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
# At the start of validation, load the EMA weights # At the start of validation, load the EMA weights
if(self.cached_weights is None): if(self.cached_weights is None):
self.cached_weights = self.model.state_dict() # load_state_dict() is an in-place operation
# it will change the content in any reference of model.state_dict()
# therefore we need to explicitly clone the parameters
clone_param = lambda t: t.clone().detach()
self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
self.model.load_state_dict(self.ema.state_dict()["params"]) self.model.load_state_dict(self.ema.state_dict()["params"])
# Run the model # Run the model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment