"app/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "c3e62ba38a2bb6b461661f2385180894aae76164"
Unverified commit 88a4f68f authored by amyeroberts, committed by GitHub

[`MaskFormer`, `Mask2Former`] Use einsum where possible (#29544)

* Use einsum where possible

* Fix
parent 62478857
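
The hunks below drop the per-channel accumulation loops in favour of a single `torch.einsum` call whenever the model is not being traced on an older PyTorch. As a quick sanity check (the shapes below are made up for illustration and are not part of the diff), the two formulations of the `'bqc, bchw -> bqhw'` contraction agree numerically:

```python
import torch

# Hypothetical sizes, chosen only for illustration.
batch_size, num_queries, num_channels, height, width = 2, 5, 8, 4, 4
mask_embeddings = torch.randn(batch_size, num_queries, num_channels)
pixel_embeddings = torch.randn(batch_size, num_channels, height, width)

# Loop form used before this commit (kept as the jit-trace-friendly fallback).
loop_mask = torch.zeros((batch_size, num_queries, height, width))
for c in range(num_channels):
    loop_mask += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]

# einsum form introduced by this commit: contract over the channel axis.
einsum_mask = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)

assert torch.allclose(loop_mask, einsum_mask, atol=1e-6)
```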
@@ -34,6 +34,7 @@ from ...file_utils import (
 )
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import is_torch_greater_or_equal_than_2_1
 from ...utils import is_accelerate_available, logging
 from ...utils.backbone_utils import load_backbone
 from .configuration_mask2former import Mask2FormerConfig
@@ -2004,12 +2005,22 @@ class Mask2FormerMaskPredictor(nn.Module):
     def forward(self, outputs: torch.Tensor, pixel_embeddings: torch.Tensor, attention_mask_target_size: int = None):
         mask_embeddings = self.mask_embedder(outputs.transpose(0, 1))
-        # Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
-        batch_size, num_queries, num_channels = mask_embeddings.shape
-        _, _, height, width = pixel_embeddings.shape
-        outputs_mask = torch.zeros((batch_size, num_queries, height, width), device=mask_embeddings.device)
-        for c in range(num_channels):
-            outputs_mask += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
+        is_tracing = (
+            torch.jit.is_tracing()
+            or isinstance(outputs, torch.fx.Proxy)
+            or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+        )
+        # Sum up over the channels
+        if is_tracing and not is_torch_greater_or_equal_than_2_1:
+            # Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
+            batch_size, num_queries, num_channels = mask_embeddings.shape
+            _, _, height, width = pixel_embeddings.shape
+            outputs_mask = torch.zeros((batch_size, num_queries, height, width), device=mask_embeddings.device)
+            for c in range(num_channels):
+                outputs_mask += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
+        else:
+            outputs_mask = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
         attention_mask = nn.functional.interpolate(
             outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False
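
The loop path above is now reached only when the new `is_tracing` guard fires on a PyTorch build older than 2.1. A rough standalone restatement of that guard is sketched below; the helper name `prefers_loop_fallback` is made up here, and the version check approximates transformers' `is_torch_greater_or_equal_than_2_1` helper with `packaging` directly:

```python
import torch
from packaging import version


def prefers_loop_fallback(tensor) -> bool:
    """Illustrative restatement of the guard added in this commit."""
    is_tracing = (
        torch.jit.is_tracing()  # torch.jit.trace is running
        or isinstance(tensor, torch.fx.Proxy)  # symbolic tracing with torch.fx
        or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())  # torch.compile / dynamo
    )
    torch_ge_2_1 = version.parse(torch.__version__) >= version.parse("2.1")
    # Only fall back to the per-channel loop when tracing on torch < 2.1;
    # eager execution and newer torch versions take the einsum path.
    return is_tracing and not torch_ge_2_1
```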
@@ -27,6 +27,7 @@ from ...activations import ACT2FN
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import BaseModelOutputWithCrossAttentions
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import is_torch_greater_or_equal_than_2_1
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -1762,6 +1763,12 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
         pixel_embeddings = outputs.pixel_decoder_last_hidden_state
         # get the auxiliary predictions (one for each decoder's layer)
         auxiliary_logits: List[str, Tensor] = []
+
+        is_tracing = (
+            torch.jit.is_tracing()
+            or isinstance(outputs, torch.fx.Proxy)
+            or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+        )
         # This code is a little bit cumbersome, an improvement can be to return a list of predictions. If we have auxiliary loss then we are going to return more than one element in the list
         if self.config.use_auxiliary_loss:
             stacked_transformer_decoder_outputs = torch.stack(outputs.transformer_decoder_hidden_states)
@@ -1770,14 +1777,17 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
             # get the masks
             mask_embeddings = self.mask_embedder(stacked_transformer_decoder_outputs)
-            # Equivalent to einsum('lbqc, bchw -> lbqhw') but jit friendly
-            num_embeddings, batch_size, num_queries, num_channels = mask_embeddings.shape
-            _, _, height, width = pixel_embeddings.shape
-            binaries_masks = torch.zeros(
-                (num_embeddings, batch_size, num_queries, height, width), device=mask_embeddings.device
-            )
-            for c in range(num_channels):
-                binaries_masks += mask_embeddings[..., c][..., None, None] * pixel_embeddings[None, :, None, c]
+            if is_tracing and not is_torch_greater_or_equal_than_2_1:
+                # Equivalent to einsum('lbqc, bchw -> lbqhw') but jit friendly
+                num_embeddings, batch_size, num_queries, num_channels = mask_embeddings.shape
+                _, _, height, width = pixel_embeddings.shape
+                binaries_masks = torch.zeros(
+                    (num_embeddings, batch_size, num_queries, height, width), device=mask_embeddings.device
+                )
+                for c in range(num_channels):
+                    binaries_masks += mask_embeddings[..., c][..., None, None] * pixel_embeddings[None, :, None, c]
+            else:
+                binaries_masks = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings)
             masks_queries_logits = binaries_masks[-1]
             # go til [:-1] because the last one is always used
@@ -1794,12 +1804,17 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
             mask_embeddings = self.mask_embedder(transformer_decoder_hidden_states)
             # sum up over the channels
-            # Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
-            batch_size, num_queries, num_channels = mask_embeddings.shape
-            _, _, height, width = pixel_embeddings.shape
-            masks_queries_logits = torch.zeros((batch_size, num_queries, height, width), device=mask_embeddings.device)
-            for c in range(num_channels):
-                masks_queries_logits += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
+            if is_tracing and not is_torch_greater_or_equal_than_2_1:
+                # Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
+                batch_size, num_queries, num_channels = mask_embeddings.shape
+                _, _, height, width = pixel_embeddings.shape
+                masks_queries_logits = torch.zeros(
+                    (batch_size, num_queries, height, width), device=mask_embeddings.device
+                )
+                for c in range(num_channels):
+                    masks_queries_logits += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
+            else:
+                masks_queries_logits = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
         return class_queries_logits, masks_queries_logits, auxiliary_logits
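
In the auxiliary-loss branch of MaskFormer the decoder outputs are stacked along a leading layer axis, so the contraction becomes `'lbqc, bchw -> lbqhw'`. A similar illustrative check (again with invented shapes, not taken from the diff) that the stacked loop and the einsum agree:

```python
import torch

# Hypothetical sizes: num_layers decoder layers on top of the earlier shapes.
num_layers, batch_size, num_queries, num_channels, height, width = 3, 2, 5, 8, 4, 4
mask_embeddings = torch.randn(num_layers, batch_size, num_queries, num_channels)
pixel_embeddings = torch.randn(batch_size, num_channels, height, width)

# Loop form kept only for the tracing fallback.
loop_masks = torch.zeros((num_layers, batch_size, num_queries, height, width))
for c in range(num_channels):
    loop_masks += mask_embeddings[..., c][..., None, None] * pixel_embeddings[None, :, None, c]

# einsum form: broadcast pixel embeddings over the layer axis, contract channels.
einsum_masks = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings)

assert torch.allclose(loop_masks, einsum_masks, atol=1e-6)
```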