Commit fd95b577 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix undefined reference in training script

parent e6263425
......@@ -63,7 +63,7 @@ class OpenFoldWrapper(pl.LightningModule):
def forward(self, batch):
return self.model(batch)
def _log(self, loss, loss_breakdown, train=True):
def _log(self, loss, loss_breakdown, batch, train=True):
phase = "train" if train else "val"
for loss_name, indiv_loss in loss_breakdown.items():
self.log(
......@@ -109,7 +109,7 @@ class OpenFoldWrapper(pl.LightningModule):
)
# Log it
self._log(loss_breakdown)
self._log(loss_breakdown, batch)
return loss
......@@ -138,7 +138,7 @@ class OpenFoldWrapper(pl.LightningModule):
outputs, batch, _return_breakdown=True
)
self._log(loss_breakdown, train=False)
self._log(loss_breakdown, batch, train=False)
def validation_epoch_end(self, _):
# Restore the model weights to normal
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment