Commit 12aa565e authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix loss bug

parent 49767099
...@@ -165,7 +165,7 @@ def backbone_loss( ...@@ -165,7 +165,7 @@ def backbone_loss(
) )
# Take the mean over the layer dimension # Take the mean over the layer dimension
fape_loss = torch.mean(fape_loss, dim=-1) fape_loss = torch.mean(fape_loss)
return fape_loss return fape_loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment