Commit 12aa565e authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix loss bug

parent 49767099
......@@ -165,7 +165,7 @@ def backbone_loss(
)
# Take the mean over the layer dimension
fape_loss = torch.mean(fape_loss, dim=-1)
fape_loss = torch.mean(fape_loss)
return fape_loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment