Unverified Commit 68c0cfbb authored by Peiqi Yin's avatar Peiqi Yin Committed by GitHub
Browse files

[Model] Fix diffpool loss (#3233)

parent 31772b14
...@@ -214,6 +214,8 @@ class DiffPool(nn.Module): ...@@ -214,6 +214,8 @@ class DiffPool(nn.Module):
#softmax + CE #softmax + CE
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
loss = criterion(pred, label) loss = criterion(pred, label)
for key, value in self.first_diffpool_layer.loss_log.items():
loss += value
for diffpool_layer in self.diffpool_layers: for diffpool_layer in self.diffpool_layers:
for key, value in diffpool_layer.loss_log.items(): for key, value in diffpool_layer.loss_log.items():
loss += value loss += value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment