Unverified Commit 88a4f68f authored by amyeroberts, committed by GitHub

[`MaskFormer`, `Mask2Former`] Use einsum where possible (#29544)

* Use einsum where possible

* Fix
parent 62478857
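The change swaps a per-channel accumulation loop for a single `torch.einsum` contraction wherever the runtime can handle it. The two forms compute the same thing; a minimal self-contained sketch (shapes here are arbitrary, chosen only for illustration) that checks the equivalence:

```python
import torch

batch_size, num_queries, num_channels, height, width = 2, 5, 8, 4, 4
mask_embeddings = torch.randn(batch_size, num_queries, num_channels)
pixel_embeddings = torch.randn(batch_size, num_channels, height, width)

# Jit-friendly form: accumulate one channel at a time.
loop_out = torch.zeros(batch_size, num_queries, height, width)
for c in range(num_channels):
    loop_out += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]

# Fused form introduced by this commit: one contraction over the channel dim.
einsum_out = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)

assert torch.allclose(loop_out, einsum_out, atol=1e-5)
```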
```diff
@@ -34,6 +34,7 @@ from ...file_utils import (
 )
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import is_torch_greater_or_equal_than_2_1
 from ...utils import is_accelerate_available, logging
 from ...utils.backbone_utils import load_backbone
 from .configuration_mask2former import Mask2FormerConfig
```
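The new import gates the einsum path on the installed PyTorch version. A minimal sketch of how such a flag can be derived, assuming the `packaging` library; the real definition lives in transformers' `pytorch_utils` and may differ in detail:

```python
from packaging import version
import torch

# Hypothetical reconstruction: compare the installed torch release tuple to (2, 1).
is_torch_greater_or_equal_than_2_1 = version.parse(torch.__version__).release >= (2, 1)
```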
```diff
@@ -2004,6 +2005,13 @@ class Mask2FormerMaskPredictor(nn.Module):
     def forward(self, outputs: torch.Tensor, pixel_embeddings: torch.Tensor, attention_mask_target_size: int = None):
         mask_embeddings = self.mask_embedder(outputs.transpose(0, 1))
+        is_tracing = (
+            torch.jit.is_tracing()
+            or isinstance(outputs, torch.fx.Proxy)
+            or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+        )
+        # Sum up over the channels
+        if is_tracing and not is_torch_greater_or_equal_than_2_1:
             # Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
             batch_size, num_queries, num_channels = mask_embeddings.shape
             _, _, height, width = pixel_embeddings.shape
@@ -2011,6 +2019,9 @@ class Mask2FormerMaskPredictor(nn.Module):
             for c in range(num_channels):
                 outputs_mask += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
+        else:
+            outputs_mask = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
         attention_mask = nn.functional.interpolate(
             outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False
         )
```
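The `is_tracing` guard bundles three capture mechanisms into one flag: `torch.jit.trace`, `torch.fx` symbolic tracing, and `torch.compile` (dynamo). A standalone sketch of the same dispatch pattern; `compute_masks` is a hypothetical helper, and the version gate is omitted here for brevity:

```python
import torch

def compute_masks(mask_embeddings: torch.Tensor, pixel_embeddings: torch.Tensor) -> torch.Tensor:
    # True under torch.jit tracing, torch.fx symbolic tracing, or torch.compile.
    is_tracing = (
        torch.jit.is_tracing()
        or isinstance(mask_embeddings, torch.fx.Proxy)
        or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
    )
    if is_tracing:
        # Per-channel accumulation keeps every op simple enough for older tracers.
        batch_size, num_queries, num_channels = mask_embeddings.shape
        _, _, height, width = pixel_embeddings.shape
        out = torch.zeros((batch_size, num_queries, height, width), device=mask_embeddings.device)
        for c in range(num_channels):
            out += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
        return out
    # Eager path: one fused contraction.
    return torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
```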
```diff
@@ -27,6 +27,7 @@ from ...activations import ACT2FN
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import BaseModelOutputWithCrossAttentions
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import is_torch_greater_or_equal_than_2_1
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
```
```diff
@@ -1762,6 +1763,12 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
         pixel_embeddings = outputs.pixel_decoder_last_hidden_state
         # get the auxiliary predictions (one for each decoder's layer)
         auxiliary_logits: List[str, Tensor] = []
+
+        is_tracing = (
+            torch.jit.is_tracing()
+            or isinstance(outputs, torch.fx.Proxy)
+            or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+        )
         # This code is a little bit cumbersome, an improvement can be to return a list of predictions. If we have auxiliary loss then we are going to return more than one element in the list
         if self.config.use_auxiliary_loss:
             stacked_transformer_decoder_outputs = torch.stack(outputs.transformer_decoder_hidden_states)
@@ -1770,6 +1777,7 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
             # get the masks
             mask_embeddings = self.mask_embedder(stacked_transformer_decoder_outputs)
+            if is_tracing and not is_torch_greater_or_equal_than_2_1:
                 # Equivalent to einsum('lbqc, bchw -> lbqhw') but jit friendly
                 num_embeddings, batch_size, num_queries, num_channels = mask_embeddings.shape
                 _, _, height, width = pixel_embeddings.shape
@@ -1778,6 +1786,8 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
                 )
                 for c in range(num_channels):
                     binaries_masks += mask_embeddings[..., c][..., None, None] * pixel_embeddings[None, :, None, c]
+            else:
+                binaries_masks = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings)
             masks_queries_logits = binaries_masks[-1]
             # go til [:-1] because the last one is always used
```
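In the auxiliary-loss branch the mask embeddings carry an extra leading decoder-layer dimension `l`, so the contraction becomes `lbqc, bchw -> lbqhw`. A quick sketch (arbitrary shapes, chosen only for illustration) confirming this is just the per-layer `bqc` contraction stacked:

```python
import torch

num_layers, batch_size, num_queries, num_channels, height, width = 3, 2, 5, 8, 4, 4
mask_embeddings = torch.randn(num_layers, batch_size, num_queries, num_channels)
pixel_embeddings = torch.randn(batch_size, num_channels, height, width)

# Fused form introduced by this commit.
fused = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings)

# Same result, one decoder layer at a time.
per_layer = torch.stack(
    [torch.einsum("bqc, bchw -> bqhw", layer, pixel_embeddings) for layer in mask_embeddings]
)
assert torch.allclose(fused, per_layer, atol=1e-5)
```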
```diff
@@ -1794,12 +1804,17 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
             mask_embeddings = self.mask_embedder(transformer_decoder_hidden_states)
             # sum up over the channels
+            if is_tracing and not is_torch_greater_or_equal_than_2_1:
                 # Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
                 batch_size, num_queries, num_channels = mask_embeddings.shape
                 _, _, height, width = pixel_embeddings.shape
-                masks_queries_logits = torch.zeros((batch_size, num_queries, height, width), device=mask_embeddings.device)
+                masks_queries_logits = torch.zeros(
+                    (batch_size, num_queries, height, width), device=mask_embeddings.device
+                )
                 for c in range(num_channels):
                     masks_queries_logits += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
+            else:
+                masks_queries_logits = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
         return class_queries_logits, masks_queries_logits, auxiliary_logits
```
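As a rough illustration of why "use einsum where possible" pays off in eager mode, a hedged micro-benchmark sketch; timings vary by hardware (CPU float32 assumed), and the shapes are representative rather than taken from the models:

```python
import time
import torch

mask_embeddings = torch.randn(2, 100, 256)
pixel_embeddings = torch.randn(2, 256, 96, 96)

def loop_version():
    # Per-channel accumulation, as in the pre-2.1 tracing path.
    out = torch.zeros(2, 100, 96, 96)
    for c in range(256):
        out += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
    return out

def einsum_version():
    # Single fused contraction, as in the new default path.
    return torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)

for fn in (loop_version, einsum_version):
    fn()  # warm-up
    start = time.perf_counter()
    fn()
    print(f"{fn.__name__}: {time.perf_counter() - start:.4f}s")
```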