Unverified Commit 3b6e95ec authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

[`Mask2Former`] Move normalization for numerical stability (#29542)

* Move normalization for numerical stability

* Apply suggestions from code review

Remove useless x=x line

* PR comment - normalize later to preserve var name meaning
parent 350c5d15
...@@ -377,10 +377,9 @@ def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Ten ...@@ -377,10 +377,9 @@ def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Ten
cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs)) cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs)) cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
loss_pos = torch.matmul(cross_entropy_loss_pos, labels.T) loss_pos = torch.matmul(cross_entropy_loss_pos / height_and_width, labels.T)
loss_neg = torch.matmul(cross_entropy_loss_neg, (1 - labels).T) loss_neg = torch.matmul(cross_entropy_loss_neg / height_and_width, (1 - labels).T)
loss = loss_pos + loss_neg loss = loss_pos + loss_neg
loss = loss / height_and_width
return loss return loss
......
...@@ -201,10 +201,9 @@ def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Ten ...@@ -201,10 +201,9 @@ def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Ten
cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs)) cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs)) cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
loss_pos = torch.matmul(cross_entropy_loss_pos, labels.T) loss_pos = torch.matmul(cross_entropy_loss_pos / height_and_width, labels.T)
loss_neg = torch.matmul(cross_entropy_loss_neg, (1 - labels).T) loss_neg = torch.matmul(cross_entropy_loss_neg / height_and_width, (1 - labels).T)
loss = loss_pos + loss_neg loss = loss_pos + loss_neg
loss = loss / height_and_width
return loss return loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment