Unverified Commit 36d5b8b0 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

MaskFormer, Mask2Former - replace einsum for tracing (#25297)

* Replace einsum with ops for tracing

* Fix comment
parent dedd1116
......@@ -359,7 +359,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
`torch.Tensor`: The computed loss between each pairs.
"""
inputs = inputs.sigmoid().flatten(1)
numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
numerator = 2 * torch.matmul(inputs, labels.T)
# using broadcasting to get a [num_queries, NUM_CLASSES] matrix
denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
loss = 1 - (numerator + 1) / (denominator + 1)
......@@ -387,9 +387,9 @@ def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Ten
cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
loss = torch.einsum("nc,mc->nm", cross_entropy_loss_pos, labels) + torch.einsum(
"nc,mc->nm", cross_entropy_loss_neg, (1 - labels)
)
loss_pos = torch.matmul(cross_entropy_loss_pos, labels.T)
loss_neg = torch.matmul(cross_entropy_loss_neg, (1 - labels).T)
loss = loss_pos + loss_neg
loss = loss / height_and_width
return loss
......@@ -2012,7 +2012,12 @@ class Mask2FormerMaskPredictor(nn.Module):
mask_embeddings = self.mask_embedder(outputs.transpose(0, 1))
# Sum up over the channels
outputs_mask = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
# (batch_size, num_queries, num_channels, 1, 1)
mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
# (batch_size, 1, num_channels, height, width)
pixel_embeddings = pixel_embeddings.unsqueeze(1)
# (batch_size, num_queries, height, width)
outputs_mask = (mask_embeddings * pixel_embeddings).sum(2)
attention_mask = nn.functional.interpolate(
outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False
......
......@@ -355,7 +355,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
`torch.Tensor`: The computed loss between each pairs.
"""
inputs = inputs.sigmoid().flatten(1)
numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
numerator = 2 * torch.matmul(inputs, labels.T)
# using broadcasting to get a [num_queries, NUM_CLASSES] matrix
denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
loss = 1 - (numerator + 1) / (denominator + 1)
......@@ -397,7 +397,7 @@ def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float =
focal_neg = (prob**gamma) * cross_entropy_loss_neg
focal_neg *= 1 - alpha
loss = torch.einsum("nc,mc->nm", focal_pos, labels) + torch.einsum("nc,mc->nm", focal_neg, (1 - labels))
loss = torch.matmul(focal_pos, labels.T) + torch.matmul(focal_neg, (1 - labels).T)
return loss / height_and_width
......@@ -1712,7 +1712,13 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
# get the masks
mask_embeddings = self.mask_embedder(stacked_transformer_decoder_outputs)
# sum up over the channels for each embedding
binaries_masks = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings)
# (num_embeddings, batch_size, num_queries, num_channels, 1, 1)
mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
# (1, batch_size, 1, num_channels, height, width)
pixel_embeddings = pixel_embeddings.unsqueeze(0).unsqueeze(2)
# (num_embeddings, batch_size, num_queries, height, width)
binaries_masks = (mask_embeddings * pixel_embeddings).sum(dim=3)
masks_queries_logits = binaries_masks[-1]
# go til [:-1] because the last one is always used
for aux_binary_masks, aux_classes in zip(binaries_masks[:-1], classes[:-1]):
......@@ -1727,7 +1733,12 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
# get the masks
mask_embeddings = self.mask_embedder(transformer_decoder_hidden_states)
# sum up over the channels
masks_queries_logits = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
# (batch_size, num_queries, num_channels, 1, 1)
mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
# (batch_size, 1, num_channels, height, width)
pixel_embeddings = pixel_embeddings.unsqueeze(1)
# (batch_size, num_queries, height, width)
masks_queries_logits = (mask_embeddings * pixel_embeddings).sum(dim=2)
return class_queries_logits, masks_queries_logits, auxiliary_logits
......
......@@ -167,7 +167,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
`torch.Tensor`: The computed loss between each pairs.
"""
inputs = inputs.sigmoid().flatten(1)
numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
numerator = 2 * torch.matmul(inputs, labels.T)
# using broadcasting to get a [num_queries, NUM_CLASSES] matrix
denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
loss = 1 - (numerator + 1) / (denominator + 1)
......@@ -196,9 +196,9 @@ def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Ten
cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
loss = torch.einsum("nc,mc->nm", cross_entropy_loss_pos, labels) + torch.einsum(
"nc,mc->nm", cross_entropy_loss_neg, (1 - labels)
)
loss_pos = torch.matmul(cross_entropy_loss_pos, labels.T)
loss_neg = torch.matmul(cross_entropy_loss_neg, (1 - labels).T)
loss = loss_pos + loss_neg
loss = loss / height_and_width
return loss
......
Markdown is supported
Attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment