Commit fd95b577 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix undefined reference in training script

parent e6263425
...@@ -63,7 +63,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -63,7 +63,7 @@ class OpenFoldWrapper(pl.LightningModule):
def forward(self, batch): def forward(self, batch):
return self.model(batch) return self.model(batch)
def _log(self, loss, loss_breakdown, train=True): def _log(self, loss, loss_breakdown, batch, train=True):
phase = "train" if train else "val" phase = "train" if train else "val"
for loss_name, indiv_loss in loss_breakdown.items(): for loss_name, indiv_loss in loss_breakdown.items():
self.log( self.log(
...@@ -109,7 +109,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -109,7 +109,7 @@ class OpenFoldWrapper(pl.LightningModule):
) )
# Log it # Log it
self._log(loss_breakdown) self._log(loss_breakdown, batch)
return loss return loss
...@@ -138,7 +138,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -138,7 +138,7 @@ class OpenFoldWrapper(pl.LightningModule):
outputs, batch, _return_breakdown=True outputs, batch, _return_breakdown=True
) )
self._log(loss_breakdown, train=False) self._log(loss_breakdown, batch, train=False)
def validation_epoch_end(self, _): def validation_epoch_end(self, _):
# Restore the model weights to normal # Restore the model weights to normal
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment