Unverified Commit fbbb0479 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #90 from KiddoZhu/main

Fix parameter restoration after loading EMA parameters
parents c83c42e8 9a08acd8
......@@ -125,7 +125,11 @@ class OpenFoldWrapper(pl.LightningModule):
def validation_step(self, batch, batch_idx):
# At the start of validation, load the EMA weights
if(self.cached_weights is None):
self.cached_weights = self.model.state_dict()
# load_state_dict() is an in-place operation
# it will change the content in any reference of model.state_dict()
# therefore we need to explicitly clone the parameters
clone_param = lambda t: t.clone().detach()
self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
self.model.load_state_dict(self.ema.state_dict()["params"])
# Run the model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment