Commit ff29bddb authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix another undefined reference in training script

parent 8f9865c3
......@@ -63,7 +63,7 @@ class OpenFoldWrapper(pl.LightningModule):
def forward(self, batch):
return self.model(batch)
def _log(self, loss_breakdown, batch, train=True):
def _log(self, loss_breakdown, batch, outputs, train=True):
phase = "train" if train else "val"
for loss_name, indiv_loss in loss_breakdown.items():
self.log(
......@@ -109,7 +109,7 @@ class OpenFoldWrapper(pl.LightningModule):
)
# Log it
self._log(loss_breakdown, batch)
self._log(loss_breakdown, batch, outputs)
return loss
......@@ -138,7 +138,7 @@ class OpenFoldWrapper(pl.LightningModule):
outputs, batch, _return_breakdown=True
)
self._log(loss_breakdown, batch, train=False)
self._log(loss_breakdown, batch, outputs, train=False)
def validation_epoch_end(self, _):
# Restore the model weights to normal
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment