Unverified Commit 36d5b8b0 authored by amyeroberts, committed by GitHub

MaskFormer, Mask2Former - replace einsum for tracing (#25297)

* Replace einsum with ops for tracing

* Fix comment
parent dedd1116
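Note on the motivation: `torch.jit.trace` records concrete tensor ops, and some export paths downstream of tracing have historically handled `einsum` less reliably than plain `matmul`/broadcast ops, which is why the commit swaps them. A minimal sanity-check sketch (mine, not part of the commit; function names and shapes are illustrative) that the traced matmul form agrees with the eager einsum form:

```python
import torch


def pairwise_einsum(inputs, labels):
    # original formulation: contract over the shared last dimension
    return torch.einsum("nc,mc->nm", inputs, labels)


def pairwise_matmul(inputs, labels):
    # trace-friendly equivalent: plain matrix multiply against the transpose
    return torch.matmul(inputs, labels.T)


inputs = torch.randn(4, 16)                    # (num_queries, flattened mask)
labels = torch.randint(0, 2, (3, 16)).float()  # (num_labels, flattened mask)

traced = torch.jit.trace(pairwise_matmul, (inputs, labels))
assert torch.allclose(traced(inputs, labels), pairwise_einsum(inputs, labels), atol=1e-6)
```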
@@ -359,7 +359,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
         `torch.Tensor`: The computed loss between each pairs.
     """
     inputs = inputs.sigmoid().flatten(1)
-    numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
+    numerator = 2 * torch.matmul(inputs, labels.T)
     # using broadcasting to get a [num_queries, NUM_CLASSES] matrix
     denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
     loss = 1 - (numerator + 1) / (denominator + 1)
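The hunk above covers nearly the whole function; a self-contained sketch of the updated `pair_wise_dice_loss` for quick experimentation (the input shapes in the usage line are my assumption, not taken from the commit):

```python
import torch
from torch import Tensor


def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
    # inputs: (num_queries, ...) mask logits; labels: (num_labels, height * width) binary masks
    inputs = inputs.sigmoid().flatten(1)
    numerator = 2 * torch.matmul(inputs, labels.T)
    # broadcasting yields a (num_queries, num_labels) denominator
    denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
    return 1 - (numerator + 1) / (denominator + 1)


loss = pair_wise_dice_loss(torch.randn(4, 8, 8), torch.randint(0, 2, (3, 64)).float())
print(loss.shape)  # torch.Size([4, 3])
```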
@@ -387,9 +387,9 @@ def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Ten
     cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
     cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
-    loss = torch.einsum("nc,mc->nm", cross_entropy_loss_pos, labels) + torch.einsum(
-        "nc,mc->nm", cross_entropy_loss_neg, (1 - labels)
-    )
+    loss_pos = torch.matmul(cross_entropy_loss_pos, labels.T)
+    loss_neg = torch.matmul(cross_entropy_loss_neg, (1 - labels).T)
+    loss = loss_pos + loss_neg
     loss = loss / height_and_width
     return loss
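This hunk elides the top of the function; a plausible self-contained version, assuming (as elsewhere in these loss modules) that `criterion` is `nn.BCEWithLogitsLoss(reduction="none")` and `height_and_width` is `inputs.shape[1]`:

```python
import torch
from torch import Tensor, nn


def pair_wise_sigmoid_cross_entropy_loss(inputs: Tensor, labels: Tensor) -> Tensor:
    height_and_width = inputs.shape[1]
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
    cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
    # pair every prediction with every label: (num_queries, num_labels)
    loss_pos = torch.matmul(cross_entropy_loss_pos, labels.T)
    loss_neg = torch.matmul(cross_entropy_loss_neg, (1 - labels).T)
    return (loss_pos + loss_neg) / height_and_width
```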
@@ -2012,7 +2012,12 @@ class Mask2FormerMaskPredictor(nn.Module):
         mask_embeddings = self.mask_embedder(outputs.transpose(0, 1))
         # Sum up over the channels
-        outputs_mask = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
+        # (batch_size, num_queries, num_channels, 1, 1)
+        mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
+        # (batch_size, 1, num_channels, height, width)
+        pixel_embeddings = pixel_embeddings.unsqueeze(1)
+        # (batch_size, num_queries, height, width)
+        outputs_mask = (mask_embeddings * pixel_embeddings).sum(2)
         attention_mask = nn.functional.interpolate(
             outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False
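A quick equivalence check (my own, with arbitrary shapes) that the unsqueeze-multiply-sum sequence reproduces `einsum("bqc, bchw -> bqhw", ...)`:

```python
import torch

batch_size, num_queries, num_channels, height, width = 2, 5, 8, 16, 16
mask_embeddings = torch.randn(batch_size, num_queries, num_channels)
pixel_embeddings = torch.randn(batch_size, num_channels, height, width)

reference = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
# (b, q, c, 1, 1) * (b, 1, c, h, w) -> (b, q, c, h, w), then sum over channels
replacement = (
    mask_embeddings.unsqueeze(-1).unsqueeze(-1) * pixel_embeddings.unsqueeze(1)
).sum(2)
assert torch.allclose(reference, replacement, atol=1e-5)
```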
...
@@ -355,7 +355,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
         `torch.Tensor`: The computed loss between each pairs.
     """
     inputs = inputs.sigmoid().flatten(1)
-    numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
+    numerator = 2 * torch.matmul(inputs, labels.T)
     # using broadcasting to get a [num_queries, NUM_CLASSES] matrix
     denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
     loss = 1 - (numerator + 1) / (denominator + 1)
@@ -397,7 +397,7 @@ def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float =
     focal_neg = (prob**gamma) * cross_entropy_loss_neg
     focal_neg *= 1 - alpha
-    loss = torch.einsum("nc,mc->nm", focal_pos, labels) + torch.einsum("nc,mc->nm", focal_neg, (1 - labels))
+    loss = torch.matmul(focal_pos, labels.T) + torch.matmul(focal_neg, (1 - labels).T)
     return loss / height_and_width
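As with the cross-entropy variant, this hunk shows only the tail of the function; a self-contained sketch under assumptions about the elided lines (that `prob` is `inputs.sigmoid()` and the positive/negative terms are the usual focal-loss weightings):

```python
import torch
from torch import Tensor, nn


def pair_wise_sigmoid_focal_loss(
    inputs: Tensor, labels: Tensor, alpha: float = 0.25, gamma: float = 2.0
) -> Tensor:
    height_and_width = inputs.shape[1]
    criterion = nn.BCEWithLogitsLoss(reduction="none")
    prob = inputs.sigmoid()
    cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
    focal_pos = ((1 - prob) ** gamma) * cross_entropy_loss_pos * alpha
    cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
    focal_neg = (prob**gamma) * cross_entropy_loss_neg * (1 - alpha)
    loss = torch.matmul(focal_pos, labels.T) + torch.matmul(focal_neg, (1 - labels).T)
    return loss / height_and_width
```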
@@ -1712,7 +1712,13 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
             # get the masks
             mask_embeddings = self.mask_embedder(stacked_transformer_decoder_outputs)
             # sum up over the channels for each embedding
-            binaries_masks = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings)
+            # (num_embeddings, batch_size, num_queries, num_channels, 1, 1)
+            mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
+            # (1, batch_size, 1, num_channels, height, width)
+            pixel_embeddings = pixel_embeddings.unsqueeze(0).unsqueeze(2)
+            # (num_embeddings, batch_size, num_queries, height, width)
+            binaries_masks = (mask_embeddings * pixel_embeddings).sum(dim=3)
             masks_queries_logits = binaries_masks[-1]
             # go til [:-1] because the last one is always used
             for aux_binary_masks, aux_classes in zip(binaries_masks[:-1], classes[:-1]):
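The same equivalence check for the auxiliary-mask variant, where the extra leading `num_embeddings` dimension means `pixel_embeddings` needs two inserted axes (shapes here are illustrative):

```python
import torch

num_embeddings, batch_size, num_queries, num_channels, height, width = 3, 2, 5, 8, 16, 16
mask_embeddings = torch.randn(num_embeddings, batch_size, num_queries, num_channels)
pixel_embeddings = torch.randn(batch_size, num_channels, height, width)

reference = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings)
replacement = (
    mask_embeddings.unsqueeze(-1).unsqueeze(-1)   # (l, b, q, c, 1, 1)
    * pixel_embeddings.unsqueeze(0).unsqueeze(2)  # (1, b, 1, c, h, w)
).sum(dim=3)                                      # sum over channels -> (l, b, q, h, w)
assert torch.allclose(reference, replacement, atol=1e-5)
```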
@@ -1727,7 +1733,12 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
             # get the masks
             mask_embeddings = self.mask_embedder(transformer_decoder_hidden_states)
             # sum up over the channels
-            masks_queries_logits = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
+            # (batch_size, num_queries, num_channels, 1, 1)
+            mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
+            # (batch_size, 1, num_channels, height, width)
+            pixel_embeddings = pixel_embeddings.unsqueeze(1)
+            # (batch_size, num_queries, height, width)
+            masks_queries_logits = (mask_embeddings * pixel_embeddings).sum(dim=2)
         return class_queries_logits, masks_queries_logits, auxiliary_logits
...
@@ -167,7 +167,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
         `torch.Tensor`: The computed loss between each pairs.
     """
     inputs = inputs.sigmoid().flatten(1)
-    numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
+    numerator = 2 * torch.matmul(inputs, labels.T)
     # using broadcasting to get a [num_queries, NUM_CLASSES] matrix
     denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
     loss = 1 - (numerator + 1) / (denominator + 1)
@@ -196,9 +196,9 @@ def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Ten
     cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
     cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
-    loss = torch.einsum("nc,mc->nm", cross_entropy_loss_pos, labels) + torch.einsum(
-        "nc,mc->nm", cross_entropy_loss_neg, (1 - labels)
-    )
+    loss_pos = torch.matmul(cross_entropy_loss_pos, labels.T)
+    loss_neg = torch.matmul(cross_entropy_loss_neg, (1 - labels).T)
+    loss = loss_pos + loss_neg
     loss = loss / height_and_width
     return loss
...