Commit b4af1c94 authored by zehaos's avatar zehaos
Browse files

Fix index error when class_agnostic is true.

parent c95c6373
......@@ -97,7 +97,11 @@ class FCNMaskHead(nn.Module):
def loss(self, mask_pred, mask_targets, labels):
loss = dict()
loss_mask = mask_cross_entropy(mask_pred, mask_targets, labels)
if self.class_agnostic:
loss_mask = mask_cross_entropy(mask_pred, mask_targets,
torch.zeros_like(labels))
else:
loss_mask = mask_cross_entropy(mask_pred, mask_targets, labels)
loss['loss_mask'] = loss_mask
return loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment