"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "ee11621ce43097762612df46f6e34bec9cdc915d"
Commit ff29bddb authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix another undefined reference in training script

parent 8f9865c3
...@@ -63,7 +63,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -63,7 +63,7 @@ class OpenFoldWrapper(pl.LightningModule):
def forward(self, batch): def forward(self, batch):
return self.model(batch) return self.model(batch)
def _log(self, loss_breakdown, batch, train=True): def _log(self, loss_breakdown, batch, outputs, train=True):
phase = "train" if train else "val" phase = "train" if train else "val"
for loss_name, indiv_loss in loss_breakdown.items(): for loss_name, indiv_loss in loss_breakdown.items():
self.log( self.log(
...@@ -109,7 +109,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -109,7 +109,7 @@ class OpenFoldWrapper(pl.LightningModule):
) )
# Log it # Log it
self._log(loss_breakdown, batch) self._log(loss_breakdown, batch, outputs)
return loss return loss
...@@ -138,7 +138,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -138,7 +138,7 @@ class OpenFoldWrapper(pl.LightningModule):
outputs, batch, _return_breakdown=True outputs, batch, _return_breakdown=True
) )
self._log(loss_breakdown, batch, train=False) self._log(loss_breakdown, batch, outputs, train=False)
def validation_epoch_end(self, _): def validation_epoch_end(self, _):
# Restore the model weights to normal # Restore the model weights to normal
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment