Commit 41c68b1e authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix bug in training script

parent 1d657f55
......@@ -73,7 +73,7 @@ class OpenFoldWrapper(pl.LightningModule):
def validation_step(self, batch, batch_idx):
# At the start of validation, load the EMA weights
if(self.cached_weights is None):
self.cached_weights = model.state_dict()
self.cached_weights = self.model.state_dict()
self.model.load_state_dict(self.ema.state_dict()["params"])
# Calculate validation loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment