Unverified Commit 2411f0e4 authored by Alara Dirik, committed by GitHub
parent 9edf3758
@@ -103,6 +103,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("m2m_100", "M2M100Model"),
("marian", "MarianModel"),
("markuplm", "MarkupLMModel"),
("mask2former", "Mask2FormerModel"),
("maskformer", "MaskFormerModel"), ("maskformer", "MaskFormerModel"),
("maskformer-swin", "MaskFormerSwinModel"), ("maskformer-swin", "MaskFormerSwinModel"),
("mbart", "MBartModel"), ("mbart", "MBartModel"),
...@@ -454,6 +455,7 @@ MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict( ...@@ -454,6 +455,7 @@ MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = OrderedDict(
[ [
# Model for Universal Segmentation mapping # Model for Universal Segmentation mapping
("detr", "DetrForSegmentation"), ("detr", "DetrForSegmentation"),
("mask2former", "Mask2FormerForUniversalSegmentation"),
("maskformer", "MaskFormerForInstanceSegmentation"), ("maskformer", "MaskFormerForInstanceSegmentation"),
] ]
) )
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {
"configuration_mask2former": [
"MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Mask2FormerConfig",
],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_mask2former"] = [
"MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"Mask2FormerForUniversalSegmentation",
"Mask2FormerModel",
"Mask2FormerPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_mask2former import MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, Mask2FormerConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_mask2former import (
MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
Mask2FormerForUniversalSegmentation,
Mask2FormerModel,
Mask2FormerPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
# coding=utf-8
# Copyright 2022 Meta Platforms, Inc.and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Mask2Former model configuration"""
import copy
from typing import Any, Dict, List, Optional
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING
MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/mask2former-swin-small-coco-instance": (
"https://huggingface.co/facebook/mask2former-swin-small-coco-instance/blob/main/config.json"
)
# See all Mask2Former models at https://huggingface.co/models?filter=mask2former
}
logger = logging.get_logger(__name__)
class Mask2FormerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Mask2FormerModel`]. It is used to instantiate a
Mask2Former model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the Mask2Former
[facebook/mask2former-swin-small-coco-instance](https://huggingface.co/facebook/mask2former-swin-small-coco-instance)
architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Currently, Mask2Former only supports the [Swin Transformer](swin) as backbone.
Args:
backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `SwinConfig()`):
The configuration of the backbone model. If unset, a default Swin configuration (`embed_dim=96`,
`depths=[2, 2, 18, 2]`, `window_size=7`) is used.
feature_size (`int`, *optional*, defaults to 256):
The features (channels) of the resulting feature maps.
mask_feature_size (`int`, *optional*, defaults to 256):
The masks' feature size; this value is also used to specify the Feature Pyramid Network features'
size.
hidden_dim (`int`, *optional*, defaults to 256):
Dimensionality of the encoder layers.
encoder_feedforward_dim (`int`, *optional*, defaults to 1024):
Dimension of the feedforward network for the deformable DETR encoder used as part of the pixel decoder.
activation_function (`str`, *optional*, defaults to `"relu"`):
Activation function used in the feedforward networks of the pixel decoder encoder and the transformer decoder.
encoder_layers (`int`, *optional*, defaults to 6):
Number of layers in the deformable DETR encoder used as part of the pixel decoder.
decoder_layers (`int`, *optional*, defaults to 10):
Number of layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings and encoder.
dim_feedforward (`int`, *optional*, defaults to 2048):
Feature dimension of the feedforward network for the transformer decoder.
pre_norm (`bool`, *optional*, defaults to `False`):
Whether to use pre-LayerNorm for the transformer decoder.
enforce_input_projection (`bool`, *optional*, defaults to `False`):
Whether to add an input projection 1x1 convolution even if the input channels and hidden dim are identical
in the Transformer decoder.
common_stride (`int`, *optional*, defaults to 4):
Parameter used for determining the number of FPN levels used as part of the pixel decoder.
ignore_value (`int`, *optional*, defaults to 255):
Category id to be ignored during training.
num_queries (`int`, *optional*, defaults to 100):
Number of queries for the decoder.
no_object_weight (`float`, *optional*, defaults to 0.1):
The weight to apply to the null (no object) class.
class_weight (`float`, *optional*, defaults to 2.0):
The weight for the cross entropy loss.
mask_weight (`float`, *optional*, defaults to 5.0):
The weight for the mask loss.
dice_weight (`float`, *optional*, defaults to 5.0):
The weight for the dice loss.
train_num_points (`int`, *optional*, defaults to 12544):
Number of points used for sampling during loss calculation.
oversample_ratio (`float`, *optional*, defaults to 3.0):
Oversampling parameter used for calculating the number of sampled points.
importance_sample_ratio (`float`, *optional*, defaults to 0.75):
Ratio of points that are sampled via importance sampling.
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
init_xavier_std (`float`, *optional*, defaults to 1.0):
The scaling factor used for the Xavier initialization gain in the HM Attention map module.
use_auxiliary_loss (`bool`, *optional*, defaults to `True`):
If `True`, [`Mask2FormerForUniversalSegmentationOutput`] will contain the auxiliary losses computed using
the logits from each stage of the transformer decoder.
feature_strides (`List[int]`, *optional*, defaults to `[4, 8, 16, 32]`):
Feature strides corresponding to features generated from backbone network.
output_auxiliary_logits (`bool`, *optional*):
Should the model output its `auxiliary_logits` or not.
Examples:
```python
>>> from transformers import Mask2FormerConfig, Mask2FormerModel
>>> # Initializing a Mask2Former facebook/mask2former-swin-small-coco-instance configuration
>>> configuration = Mask2FormerConfig()
>>> # Initializing a model (with random weights) from the facebook/mask2former-swin-small-coco-instance style configuration
>>> model = Mask2FormerModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "mask2former"
backbones_supported = ["swin"]
attribute_map = {"hidden_size": "hidden_dim"}
def __init__(
self,
backbone_config: Optional[Dict] = None,
feature_size: int = 256,
mask_feature_size: int = 256,
hidden_dim: int = 256,
encoder_feedforward_dim: int = 1024,
activation_function: str = "relu",
encoder_layers: int = 6,
decoder_layers: int = 10,
num_attention_heads: int = 8,
dropout: float = 0.0,
dim_feedforward: int = 2048,
pre_norm: bool = False,
enforce_input_projection: bool = False,
common_stride: int = 4,
ignore_value: int = 255,
num_queries: int = 100,
no_object_weight: float = 0.1,
class_weight: float = 2.0,
mask_weight: float = 5.0,
dice_weight: float = 5.0,
train_num_points: int = 12544,
oversample_ratio: float = 3.0,
importance_sample_ratio: float = 0.75,
init_std: float = 0.02,
init_xavier_std: float = 1.0,
use_auxiliary_loss: bool = True,
feature_strides: List[int] = [4, 8, 16, 32],
output_auxiliary_logits: Optional[bool] = None,
**kwargs,
):
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `Swin` backbone.")
backbone_config = CONFIG_MAPPING["swin"](
image_size=224,
in_channels=3,
patch_size=4,
embed_dim=96,
depths=[2, 2, 18, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
drop_path_rate=0.3,
use_absolute_embeddings=False,
out_features=["stage1", "stage2", "stage3", "stage4"],
)
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
self.backbone_config = backbone_config
self.feature_size = feature_size
self.mask_feature_size = mask_feature_size
self.hidden_dim = hidden_dim
self.encoder_feedforward_dim = encoder_feedforward_dim
self.activation_function = activation_function
self.encoder_layers = encoder_layers
self.decoder_layers = decoder_layers
self.num_attention_heads = num_attention_heads
self.dropout = dropout
self.dim_feedforward = dim_feedforward
self.pre_norm = pre_norm
self.enforce_input_projection = enforce_input_projection
self.common_stride = common_stride
self.ignore_value = ignore_value
self.num_queries = num_queries
self.no_object_weight = no_object_weight
self.class_weight = class_weight
self.mask_weight = mask_weight
self.dice_weight = dice_weight
self.train_num_points = train_num_points
self.oversample_ratio = oversample_ratio
self.importance_sample_ratio = importance_sample_ratio
self.init_std = init_std
self.init_xavier_std = init_xavier_std
self.use_auxiliary_loss = use_auxiliary_loss
self.feature_strides = feature_strides
self.output_auxiliary_logits = output_auxiliary_logits
self.num_hidden_layers = decoder_layers
super().__init__(**kwargs)
@classmethod
def from_backbone_config(cls, backbone_config: PretrainedConfig, **kwargs):
"""Instantiate a [`Mask2FormerConfig`] (or a derived class) from a pre-trained backbone model configuration.
Args:
backbone_config ([`PretrainedConfig`]):
The backbone configuration.
Returns:
[`Mask2FormerConfig`]: An instance of a configuration object
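Examples (a minimal illustrative sketch; the `SwinConfig` backbone is the one this model currently supports):

```python
>>> from transformers import Mask2FormerConfig, SwinConfig

>>> # build a Swin backbone configuration and derive a Mask2Former configuration from it
>>> backbone_config = SwinConfig(out_features=["stage1", "stage2", "stage3", "stage4"])
>>> config = Mask2FormerConfig.from_backbone_config(backbone_config)
```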
"""
return cls(
backbone_config=backbone_config,
**kwargs,
)
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
# coding=utf-8
# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import sys
from argparse import ArgumentParser
from dataclasses import dataclass
from pathlib import Path
from pprint import pformat
from typing import Any, Dict, Iterator, List, Set, Tuple
import torch
import torchvision.transforms as T
from PIL import Image
from torch import Tensor, nn
import requests
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.projects.deeplab import add_deeplab_config
from huggingface_hub import hf_hub_download
from transformers import (
Mask2FormerConfig,
Mask2FormerForUniversalSegmentation,
Mask2FormerModel,
MaskFormerImageProcessor,
SwinConfig,
)
from transformers.models.mask2former.modeling_mask2former import (
Mask2FormerForUniversalSegmentationOutput,
Mask2FormerModelOutput,
)
from transformers.utils import logging
StateDict = Dict[str, Tensor]
logging.set_verbosity_info()
logger = logging.get_logger()
torch.manual_seed(0)
class TrackedStateDict:
def __init__(self, to_track: Dict):
"""This class "tracks" a python dictionary by keeping track of which item is accessed.
Args:
to_track (Dict): The dictionary we wish to track
"""
self.to_track = to_track
self._seen: Set[str] = set()
def __getitem__(self, key: str) -> Any:
return self.to_track[key]
def __setitem__(self, key: str, item: Any):
self._seen.add(key)
self.to_track[key] = item
def diff(self) -> Set[str]:
"""Returns the difference between the keys in the tracked state dict and the keys updated so far.
This is an effective way to check whether all the keys have been updated.
Returns:
Set[str]: Set of keys not yet updated
"""
return set(self.to_track.keys()) - self._seen
def copy(self) -> Dict:
# proxy the call to the internal dictionary
return self.to_track.copy()
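# Illustrative sketch (not executed by the conversion flow): only keys written through __setitem__
# are marked as seen, so diff() reports the destination keys that were never copied, e.g.
#
#   tracked = TrackedStateDict({"a": 1, "b": 2})
#   tracked["a"] = 3
#   assert tracked.diff() == {"b"}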
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
img_data = requests.get(url, stream=True).raw
im = Image.open(img_data)
return im
@dataclass
class Args:
"""Fake command line arguments needed by mask2former/detectron implementation"""
config_file: str
def setup_cfg(args: Args):
# load config from file and command-line arguments
cfg = get_cfg()
add_deeplab_config(cfg)
add_maskformer2_config(cfg)
cfg.merge_from_file(args.config_file)
cfg.freeze()
return cfg
class OriginalMask2FormerConfigToOursConverter:
def __call__(self, original_config: object) -> Mask2FormerConfig:
model = original_config.MODEL
repo_id = "huggingface/label-files"
if model.SEM_SEG_HEAD.NUM_CLASSES == 847:
filename = "mask2former-ade20k-full-id2label.json"
elif model.SEM_SEG_HEAD.NUM_CLASSES == 150:
filename = "ade20k-id2label.json"
elif model.SEM_SEG_HEAD.NUM_CLASSES == 80:
filename = "coco-detection-mmdet-id2label.json"
elif model.SEM_SEG_HEAD.NUM_CLASSES == 171:
filename = "mask2former-coco-stuff-id2label.json"
elif model.SEM_SEG_HEAD.NUM_CLASSES == 133:
filename = "coco-panoptic-id2label.json"
elif model.SEM_SEG_HEAD.NUM_CLASSES == 19:
filename = "cityscapes-id2label.json"
elif model.SEM_SEG_HEAD.NUM_CLASSES == 8:
filename = "cityscapes-instance-id2label.json"
elif model.SEM_SEG_HEAD.NUM_CLASSES == 65:
filename = "mapillary-vistas-id2label.json"
else:
raise ValueError(f"Unsupported number of classes for the id2label mapping: {model.SEM_SEG_HEAD.NUM_CLASSES}")
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
label2id = {label: idx for idx, label in id2label.items()}
if model.SWIN.EMBED_DIM == 96:
backbone_config = SwinConfig.from_pretrained(
"microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"]
)
elif model.SWIN.EMBED_DIM == 128:
backbone_config = SwinConfig(
embed_dim=128,
window_size=12,
depths=(2, 2, 18, 2),
num_heads=(4, 8, 16, 32),
out_features=["stage1", "stage2", "stage3", "stage4"],
)
elif model.SWIN.EMBED_DIM == 192:
backbone_config = SwinConfig.from_pretrained(
"microsoft/swin-large-patch4-window12-384", out_features=["stage1", "stage2", "stage3", "stage4"]
)
else:
raise ValueError(f"embed dim {model.SWIN.EMBED_DIM} not supported for Swin!")
backbone_config.drop_path_rate = model.SWIN.DROP_PATH_RATE
backbone_config.attention_probs_dropout_prob = model.SWIN.ATTN_DROP_RATE
backbone_config.depths = model.SWIN.DEPTHS
config: Mask2FormerConfig = Mask2FormerConfig(
ignore_value=model.SEM_SEG_HEAD.IGNORE_VALUE,
num_labels=model.SEM_SEG_HEAD.NUM_CLASSES,
num_queries=model.MASK_FORMER.NUM_OBJECT_QUERIES,
no_object_weight=model.MASK_FORMER.NO_OBJECT_WEIGHT,
class_weight=model.MASK_FORMER.CLASS_WEIGHT,
mask_weight=model.MASK_FORMER.MASK_WEIGHT,
dice_weight=model.MASK_FORMER.DICE_WEIGHT,
train_num_points=model.MASK_FORMER.TRAIN_NUM_POINTS,
oversample_ratio=model.MASK_FORMER.OVERSAMPLE_RATIO,
importance_sample_ratio=model.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
init_std=0.02,
init_xavier_std=1.0,
use_auxiliary_loss=model.MASK_FORMER.DEEP_SUPERVISION,
feature_strides=[4, 8, 16, 32],
backbone_config=backbone_config,
id2label=id2label,
label2id=label2id,
feature_size=model.SEM_SEG_HEAD.CONVS_DIM,
mask_feature_size=model.SEM_SEG_HEAD.MASK_DIM,
hidden_dim=model.MASK_FORMER.HIDDEN_DIM,
encoder_layers=model.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS,
encoder_feedforward_dim=1024,
decoder_layers=model.MASK_FORMER.DEC_LAYERS,
num_attention_heads=model.MASK_FORMER.NHEADS,
dropout=model.MASK_FORMER.DROPOUT,
dim_feedforward=model.MASK_FORMER.DIM_FEEDFORWARD,
pre_norm=model.MASK_FORMER.PRE_NORM,
enforce_input_projection=model.MASK_FORMER.ENFORCE_INPUT_PROJ,
common_stride=model.SEM_SEG_HEAD.COMMON_STRIDE,
)
return config
class OriginalMask2FormerConfigToFeatureExtractorConverter:
def __call__(self, original_config: object) -> MaskFormerImageProcessor:
model = original_config.MODEL
model_input = original_config.INPUT
return MaskFormerImageProcessor(
image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(),
image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(),
size=model_input.MIN_SIZE_TEST,
max_size=model_input.MAX_SIZE_TEST,
num_labels=model.SEM_SEG_HEAD.NUM_CLASSES,
ignore_index=model.SEM_SEG_HEAD.IGNORE_VALUE,
size_divisibility=32,
)
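# Illustrative note (actual values depend on the detectron2 config): Mask2Former configs typically use
# PIXEL_MEAN = [123.675, 116.280, 103.530] and PIXEL_STD = [58.395, 57.120, 57.375], so the resulting
# processor ends up with roughly the ImageNet statistics image_mean ≈ [0.485, 0.456, 0.406] and
# image_std ≈ [0.229, 0.224, 0.225].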
class OriginalMask2FormerCheckpointToOursConverter:
def __init__(self, original_model: nn.Module, config: Mask2FormerConfig):
self.original_model = original_model
self.config = config
def pop_all(self, renamed_keys: List[Tuple[str, str]], dst_state_dict: StateDict, src_state_dict: StateDict):
for src_key, dst_key in renamed_keys:
dst_state_dict[dst_key] = src_state_dict.pop(src_key)
def replace_maskformer_swin_backbone(
self, dst_state_dict: StateDict, src_state_dict: StateDict, config: Mask2FormerConfig
):
dst_prefix: str = "pixel_level_module.encoder"
src_prefix: str = "backbone"
renamed_keys = [
(
f"{src_prefix}.patch_embed.proj.weight",
f"{dst_prefix}.model.embeddings.patch_embeddings.projection.weight",
),
(f"{src_prefix}.patch_embed.proj.bias", f"{dst_prefix}.model.embeddings.patch_embeddings.projection.bias"),
(f"{src_prefix}.patch_embed.norm.weight", f"{dst_prefix}.model.embeddings.norm.weight"),
(f"{src_prefix}.patch_embed.norm.bias", f"{dst_prefix}.model.embeddings.norm.bias"),
]
num_layers = len(config.backbone_config.depths)
for layer_idx in range(num_layers):
for block_idx in range(config.backbone_config.depths[layer_idx]):
renamed_keys.extend(
[ # src, dst
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.weight",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.bias",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.bias",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_bias_table",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_bias_table",
),
]
)
# now we need to handle the attentions
# read in weights + bias of input projection layer of cross-attention
src_att_weight = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight"]
src_att_bias = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias"]
size = src_att_weight.shape[0]
offset = size // 3
dst_state_dict[
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.weight"
] = src_att_weight[:offset, :]
dst_state_dict[
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.bias"
] = src_att_bias[:offset]
dst_state_dict[
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.weight"
] = src_att_weight[offset : offset * 2, :]
dst_state_dict[
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.bias"
] = src_att_bias[offset : offset * 2]
dst_state_dict[
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.weight"
] = src_att_weight[-offset:, :]
dst_state_dict[
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.bias"
] = src_att_bias[-offset:]
# let's pop them
src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight")
src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias")
# proj
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.weight",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.bias",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.bias",
),
]
)
# second norm
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.weight",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.bias",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.bias",
),
]
)
# mlp
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.weight",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.bias",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.bias",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.weight",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.bias",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.bias",
),
]
)
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_index",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_index",
)
]
)
if layer_idx < num_layers - 1:
# patch merging
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.downsample.reduction.weight",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.downsample.reduction.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.downsample.norm.weight",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.downsample.norm.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.downsample.norm.bias",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.downsample.norm.bias",
),
]
)
# hidden states norms
renamed_keys.extend(
[
(
f"{src_prefix}.norm{layer_idx}.weight",
f"{dst_prefix}.hidden_states_norms.{layer_idx}.weight",
),
(
f"{src_prefix}.norm{layer_idx}.bias",
f"{dst_prefix}.hidden_states_norms.{layer_idx}.bias",
),
]
)
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
def replace_swin_backbone(self, dst_state_dict: StateDict, src_state_dict: StateDict, config: Mask2FormerConfig):
dst_prefix: str = "pixel_level_module.encoder"
src_prefix: str = "backbone"
renamed_keys = [
(
f"{src_prefix}.patch_embed.proj.weight",
f"{dst_prefix}.embeddings.patch_embeddings.projection.weight",
),
(f"{src_prefix}.patch_embed.proj.bias", f"{dst_prefix}.embeddings.patch_embeddings.projection.bias"),
(f"{src_prefix}.patch_embed.norm.weight", f"{dst_prefix}.embeddings.norm.weight"),
(f"{src_prefix}.patch_embed.norm.bias", f"{dst_prefix}.embeddings.norm.bias"),
]
for layer_idx in range(len(config.backbone_config.depths)):
for block_idx in range(config.backbone_config.depths[layer_idx]):
renamed_keys.extend(
[ # src, dst
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.weight",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.bias",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.bias",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_bias_table",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_bias_table",
),
]
)
# now we need to handle the attentions
# read in weights + bias of input projection layer of cross-attention
src_att_weight = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight"]
src_att_bias = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias"]
size = src_att_weight.shape[0]
offset = size // 3
dst_state_dict[
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.weight"
] = src_att_weight[:offset, :]
dst_state_dict[
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.bias"
] = src_att_bias[:offset]
dst_state_dict[
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.weight"
] = src_att_weight[offset : offset * 2, :]
dst_state_dict[
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.bias"
] = src_att_bias[offset : offset * 2]
dst_state_dict[
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.weight"
] = src_att_weight[-offset:, :]
dst_state_dict[
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.bias"
] = src_att_bias[-offset:]
# let's pop them
src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight")
src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias")
# proj
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.weight",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.bias",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.bias",
),
]
)
# second norm
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.weight",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.bias",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.bias",
),
]
)
# mlp
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.weight",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.bias",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.bias",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.weight",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.bias",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.bias",
),
]
)
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_index",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_index",
)
]
)
if layer_idx < 3:
# patch merging
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.downsample.reduction.weight",
f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.reduction.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.downsample.norm.weight",
f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.norm.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.downsample.norm.bias",
f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.norm.bias",
),
]
)
# hidden states norms
renamed_keys.extend(
[
(
f"{src_prefix}.norm{layer_idx}.weight",
f"{dst_prefix}.hidden_states_norms.stage{layer_idx+1}.weight",
),
(
f"{src_prefix}.norm{layer_idx}.bias",
f"{dst_prefix}.hidden_states_norms.stage{layer_idx+1}.bias",
),
]
)
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
# Backbone + Pixel Decoder
def replace_pixel_module(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "pixel_level_module.decoder"
src_prefix: str = "sem_seg_head.pixel_decoder"
self.replace_swin_backbone(dst_state_dict, src_state_dict, self.config)
def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
return [
(f"{src_prefix}.weight", f"{dst_prefix}.weight"),
(f"{src_prefix}.bias", f"{dst_prefix}.bias"),
]
def rename_keys_for_self_attn(src_prefix: str, dst_prefix: str):
self_attn_keys = []
self_attn_keys.extend(
rename_keys_for_weight_bias(f"{src_prefix}.attention_weights", f"{dst_prefix}.attention_weights")
)
self_attn_keys.extend(
rename_keys_for_weight_bias(f"{src_prefix}.output_proj", f"{dst_prefix}.output_proj")
)
self_attn_keys.extend(
rename_keys_for_weight_bias(f"{src_prefix}.sampling_offsets", f"{dst_prefix}.sampling_offsets")
)
self_attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.value_proj", f"{dst_prefix}.value_proj"))
return self_attn_keys
def rename_keys_for_encoder_layer(src_prefix: str, dst_prefix: str):
encoder_keys = []
encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.fc1"))
encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.fc2"))
encoder_keys.extend(
rename_keys_for_weight_bias(f"{src_prefix}.norm1", f"{dst_prefix}.self_attn_layer_norm")
)
encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm2", f"{dst_prefix}.final_layer_norm"))
encoder_keys.extend(rename_keys_for_self_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn"))
return encoder_keys
# convolution layer for final features
renamed_keys = [
(f"{src_prefix}.adapter_1.weight", f"{dst_prefix}.adapter_1.0.weight"),
(f"{src_prefix}.adapter_1.norm.weight", f"{dst_prefix}.adapter_1.1.weight"),
(f"{src_prefix}.adapter_1.norm.bias", f"{dst_prefix}.adapter_1.1.bias"),
]
renamed_keys.extend(
[
(f"{src_prefix}.layer_1.weight", f"{dst_prefix}.layer_1.0.weight"),
(f"{src_prefix}.layer_1.norm.weight", f"{dst_prefix}.layer_1.1.weight"),
(f"{src_prefix}.layer_1.norm.bias", f"{dst_prefix}.layer_1.1.bias"),
]
)
# proj layers
for i in range(3):
for j in range(2):
renamed_keys.extend(
[
(f"{src_prefix}.input_proj.{i}.{j}.weight", f"{dst_prefix}.input_projections.{i}.{j}.weight"),
(f"{src_prefix}.input_proj.{i}.{j}.bias", f"{dst_prefix}.input_projections.{i}.{j}.bias"),
]
)
renamed_keys.extend([(f"{src_prefix}.transformer.level_embed", f"{dst_prefix}.level_embed")])
# layers
for layer_idx in range(self.config.encoder_layers):
renamed_keys.extend(
rename_keys_for_encoder_layer(
f"{src_prefix}.transformer.encoder.layers.{layer_idx}", f"{dst_prefix}.encoder.layers.{layer_idx}"
)
)
# proj
renamed_keys.extend(
[
(f"{src_prefix}.mask_features.weight", f"{dst_prefix}.mask_projection.weight"),
(f"{src_prefix}.mask_features.bias", f"{dst_prefix}.mask_projection.bias"),
]
)
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
# Transformer Decoder
def rename_keys_in_masked_attention_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "transformer_module.decoder"
src_prefix: str = "sem_seg_head.predictor"
rename_keys = []
for i in range(self.config.decoder_layers - 1):
rename_keys.append(
(
f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.out_proj.weight",
f"{dst_prefix}.layers.{i}.self_attn.out_proj.weight",
)
)
rename_keys.append(
(
f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.out_proj.bias",
f"{dst_prefix}.layers.{i}.self_attn.out_proj.bias",
)
)
rename_keys.append(
(
f"{src_prefix}.transformer_self_attention_layers.{i}.norm.weight",
f"{dst_prefix}.layers.{i}.self_attn_layer_norm.weight",
)
)
rename_keys.append(
(
f"{src_prefix}.transformer_self_attention_layers.{i}.norm.bias",
f"{dst_prefix}.layers.{i}.self_attn_layer_norm.bias",
)
)
rename_keys.append(
(
f"{src_prefix}.transformer_cross_attention_layers.{i}.multihead_attn.in_proj_weight",
f"{dst_prefix}.layers.{i}.cross_attn.in_proj_weight",
)
)
rename_keys.append(
(
f"{src_prefix}.transformer_cross_attention_layers.{i}.multihead_attn.in_proj_bias",
f"{dst_prefix}.layers.{i}.cross_attn.in_proj_bias",
)
)
rename_keys.append(
(
f"{src_prefix}.transformer_cross_attention_layers.{i}.multihead_attn.out_proj.weight",
f"{dst_prefix}.layers.{i}.cross_attn.out_proj.weight",
)
)
rename_keys.append(
(
f"{src_prefix}.transformer_cross_attention_layers.{i}.multihead_attn.out_proj.bias",
f"{dst_prefix}.layers.{i}.cross_attn.out_proj.bias",
)
)
rename_keys.append(
(
f"{src_prefix}.transformer_cross_attention_layers.{i}.norm.weight",
f"{dst_prefix}.layers.{i}.cross_attn_layer_norm.weight",
)
)
rename_keys.append(
(
f"{src_prefix}.transformer_cross_attention_layers.{i}.norm.bias",
f"{dst_prefix}.layers.{i}.cross_attn_layer_norm.bias",
)
)
rename_keys.append(
(f"{src_prefix}.transformer_ffn_layers.{i}.linear1.weight", f"{dst_prefix}.layers.{i}.fc1.weight")
)
rename_keys.append(
(f"{src_prefix}.transformer_ffn_layers.{i}.linear1.bias", f"{dst_prefix}.layers.{i}.fc1.bias")
)
rename_keys.append(
(f"{src_prefix}.transformer_ffn_layers.{i}.linear2.weight", f"{dst_prefix}.layers.{i}.fc2.weight")
)
rename_keys.append(
(f"{src_prefix}.transformer_ffn_layers.{i}.linear2.bias", f"{dst_prefix}.layers.{i}.fc2.bias")
)
rename_keys.append(
(
f"{src_prefix}.transformer_ffn_layers.{i}.norm.weight",
f"{dst_prefix}.layers.{i}.final_layer_norm.weight",
)
)
rename_keys.append(
(
f"{src_prefix}.transformer_ffn_layers.{i}.norm.bias",
f"{dst_prefix}.layers.{i}.final_layer_norm.bias",
)
)
return rename_keys
def replace_masked_attention_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "transformer_module.decoder"
src_prefix: str = "sem_seg_head.predictor"
renamed_keys = self.rename_keys_in_masked_attention_decoder(dst_state_dict, src_state_dict)
# add more
renamed_keys.extend(
[
(f"{src_prefix}.decoder_norm.weight", f"{dst_prefix}.layernorm.weight"),
(f"{src_prefix}.decoder_norm.bias", f"{dst_prefix}.layernorm.bias"),
]
)
mlp_len = 3
for i in range(mlp_len):
renamed_keys.extend(
[
(
f"{src_prefix}.mask_embed.layers.{i}.weight",
f"{dst_prefix}.mask_predictor.mask_embedder.{i}.0.weight",
),
(
f"{src_prefix}.mask_embed.layers.{i}.bias",
f"{dst_prefix}.mask_predictor.mask_embedder.{i}.0.bias",
),
]
)
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
def replace_keys_qkv_transformer_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "transformer_module.decoder.layers"
src_prefix: str = "sem_seg_head.predictor"
for i in range(self.config.decoder_layers - 1):
# read in weights + bias of input projection layer of self-attention
in_proj_weight = src_state_dict.pop(
f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.in_proj_weight"
)
in_proj_bias = src_state_dict.pop(
f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.in_proj_bias"
)
# next, add query, keys and values (in that order) to the state dict
dst_state_dict[f"{dst_prefix}.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
dst_state_dict[f"{dst_prefix}.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
dst_state_dict[f"{dst_prefix}.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
dst_state_dict[f"{dst_prefix}.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
dst_state_dict[f"{dst_prefix}.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
dst_state_dict[f"{dst_prefix}.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
def replace_transformer_module(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "transformer_module"
src_prefix: str = "sem_seg_head.predictor"
self.replace_masked_attention_decoder(dst_state_dict, src_state_dict)
renamed_keys = [
(f"{src_prefix}.query_embed.weight", f"{dst_prefix}.queries_embedder.weight"),
(f"{src_prefix}.query_feat.weight", f"{dst_prefix}.queries_features.weight"),
(f"{src_prefix}.level_embed.weight", f"{dst_prefix}.level_embed.weight"),
]
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
self.replace_keys_qkv_transformer_decoder(dst_state_dict, src_state_dict)
def replace_universal_segmentation_module(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = ""
src_prefix: str = "sem_seg_head.predictor"
renamed_keys = [
(f"{src_prefix}.class_embed.weight", f"{dst_prefix}class_predictor.weight"),
(f"{src_prefix}.class_embed.bias", f"{dst_prefix}class_predictor.bias"),
]
logger.info(f"Replacing keys {pformat(renamed_keys)}")
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
def convert(self, mask2former: Mask2FormerModel) -> Mask2FormerModel:
dst_state_dict = TrackedStateDict(mask2former.state_dict())
src_state_dict = self.original_model.state_dict()
self.replace_pixel_module(dst_state_dict, src_state_dict)
self.replace_transformer_module(dst_state_dict, src_state_dict)
logger.info(f"Missed keys are {pformat(dst_state_dict.diff())}")
logger.info(f"Not copied keys are {pformat(src_state_dict.keys())}")
logger.info("🙌 Done")
state_dict = {key: dst_state_dict[key] for key in dst_state_dict.to_track.keys()}
mask2former.load_state_dict(state_dict)
return mask2former
def convert_universal_segmentation(
self, mask2former: Mask2FormerForUniversalSegmentation
) -> Mask2FormerForUniversalSegmentation:
dst_state_dict = TrackedStateDict(mask2former.state_dict())
src_state_dict = self.original_model.state_dict()
self.replace_universal_segmentation_module(dst_state_dict, src_state_dict)
state_dict = {key: dst_state_dict[key] for key in dst_state_dict.to_track.keys()}
mask2former.load_state_dict(state_dict)
return mask2former
@staticmethod
def using_dirs(checkpoints_dir: Path, config_dir: Path) -> Iterator[Tuple[object, Path, Path]]:
checkpoints: List[Path] = checkpoints_dir.glob("**/*.pkl")
for checkpoint in checkpoints:
logger.info(f"💪 Converting {checkpoint.stem}")
# find associated config file
# dataset_name e.g 'coco'
dataset_name = checkpoint.parents[2].stem
if dataset_name == "ade":
dataset_name = dataset_name.replace("ade", "ade20k")
# task type e.g 'instance-segmentation'
segmentation_task = checkpoint.parents[1].stem
# config file corresponding to checkpoint
config_file_name = f"{checkpoint.parents[0].stem}.yaml"
config: Path = config_dir / dataset_name / segmentation_task / "swin" / config_file_name
yield config, checkpoint
def test(
original_model,
our_model: Mask2FormerForUniversalSegmentation,
feature_extractor: MaskFormerImageProcessor,
tolerance: float,
):
with torch.no_grad():
original_model = original_model.eval()
our_model = our_model.eval()
im = prepare_img()
x = feature_extractor(images=im, return_tensors="pt")["pixel_values"]
original_model_backbone_features = original_model.backbone(x.clone())
our_model_output: Mask2FormerModelOutput = our_model.model(x.clone(), output_hidden_states=True)
# Test backbone
for original_model_feature, our_model_feature in zip(
original_model_backbone_features.values(), our_model_output.encoder_hidden_states
):
assert torch.allclose(
original_model_feature, our_model_feature, atol=tolerance
), "The backbone features are not the same."
# Test pixel decoder
mask_features, _, multi_scale_features = original_model.sem_seg_head.pixel_decoder.forward_features(
original_model_backbone_features
)
for original_model_feature, our_model_feature in zip(
multi_scale_features, our_model_output.pixel_decoder_hidden_states
):
assert torch.allclose(
original_model_feature, our_model_feature, atol=tolerance
), "The pixel decoder feature are not the same"
# Let's test the full model
tr_complete = T.Compose(
[T.Resize((384, 384)), T.ToTensor()],
)
y = (tr_complete(im) * 255.0).to(torch.int).float()
# modify original Mask2Former code to return mask and class logits
original_class_logits, original_mask_logits = original_model([{"image": y.clone().squeeze(0)}])
our_model_out: Mask2FormerForUniversalSegmentationOutput = our_model(x.clone())
our_mask_logits = our_model_out.masks_queries_logits
our_class_logits = our_model_out.class_queries_logits
assert original_mask_logits.shape == our_mask_logits.shape, "Output masks shapes are not matching."
assert original_class_logits.shape == our_class_logits.shape, "Output class logits shapes are not matching."
assert torch.allclose(
original_class_logits, our_class_logits, atol=tolerance
), "The class logits are not the same."
assert torch.allclose(
original_mask_logits, our_mask_logits, atol=tolerance
), "The predicted masks are not the same."
logger.info("✅ Test passed!")
def get_model_name(checkpoint_file: Path):
# model_name_raw is something like maskformer2_swin_small_bs16_50ep
model_name_raw: str = checkpoint_file.parents[0].stem
# `segmentation_task_name` must be one of the following: `instance-segmentation`, `panoptic-segmentation`, `semantic-segmentation`
segmentation_task_name: str = checkpoint_file.parents[1].stem
if segmentation_task_name not in ["instance-segmentation", "panoptic-segmentation", "semantic-segmentation"]:
raise ValueError(
f"{segmentation_task_name} must be wrong since acceptable values are: instance-segmentation,"
" panoptic-segmentation, semantic-segmentation."
)
# dataset name must be one of the following: `coco`, `ade`, `cityscapes`, `mapillary-vistas`
dataset_name: str = checkpoint_file.parents[2].stem
if dataset_name not in ["coco", "ade", "cityscapes", "mapillary-vistas"]:
raise ValueError(
f"{dataset_name} must be wrong since we didn't find 'coco' or 'ade' or 'cityscapes' or 'mapillary-vistas'"
" in it "
)
backbone = "swin"
backbone_types = ["tiny", "small", "base_IN21k", "base", "large"]
backbone_type = list(filter(lambda x: x in model_name_raw, backbone_types))[0].replace("_", "-")
model_name = f"mask2former-{backbone}-{backbone_type}-{dataset_name}-{segmentation_task_name.split('-')[0]}"
return model_name
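# Illustrative example (hypothetical path): a checkpoint stored at
# <checkpoints_dir>/coco/instance-segmentation/maskformer2_swin_small_bs16_50ep/model_final.pkl
# yields the model name "mask2former-swin-small-coco-instance".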
if __name__ == "__main__":
parser = ArgumentParser(
description="Command line to convert the original mask2formers (with swin backbone) to our implementations."
)
parser.add_argument(
"--checkpoints_dir",
type=Path,
help=(
"A directory containing the model's checkpoints. The directory has to have the following structure:"
" <DIR_NAME>/<DATASET_NAME>/<SEGMENTATION_TASK_NAME>/<CONFIG_NAME>.pkl"
),
)
parser.add_argument(
"--configs_dir",
type=Path,
help=(
"A directory containing the model's configs, see detectron2 doc. The directory has to have the following"
" structure: <DIR_NAME>/<DATASET_NAME>/<SEGMENTATION_TASK_NAME>/<CONFIG_NAME>.yaml"
),
)
parser.add_argument(
"--mask2former_dir",
required=True,
type=Path,
help=(
"A path to Mask2Former's original implementation directory. You can download from here:"
" https://github.com/facebookresearch/Mask2Former"
),
)
args = parser.parse_args()
checkpoints_dir: Path = args.checkpoints_dir
config_dir: Path = args.configs_dir
mask2former_dir: Path = args.mask2former_dir
# append the path to the parents to mask2former dir
sys.path.append(str(mask2former_dir.parent))
# import original Mask2Former config and model from original source code repo
from Mask2Former.mask2former.config import add_maskformer2_config
from Mask2Former.mask2former.maskformer_model import MaskFormer as OriginalMask2Former
for config_file, checkpoint_file in OriginalMask2FormerCheckpointToOursConverter.using_dirs(
checkpoints_dir, config_dir
):
model_name = get_model_name(checkpoint_file)
feature_extractor = OriginalMask2FormerConfigToFeatureExtractorConverter()(
setup_cfg(Args(config_file=config_file))
)
feature_extractor.size = {"height": 384, "width": 384}
original_config = setup_cfg(Args(config_file=config_file))
mask2former_kwargs = OriginalMask2Former.from_config(original_config)
original_model = OriginalMask2Former(**mask2former_kwargs).eval()
DetectionCheckpointer(original_model).load(str(checkpoint_file))
config: Mask2FormerConfig = OriginalMask2FormerConfigToOursConverter()(original_config)
mask2former = Mask2FormerModel(config=config).eval()
converter = OriginalMask2FormerCheckpointToOursConverter(original_model, config)
mask2former = converter.convert(mask2former)
mask2former_for_segmentation = Mask2FormerForUniversalSegmentation(config=config).eval()
mask2former_for_segmentation.model = mask2former
mask2former_for_segmentation = converter.convert_universal_segmentation(mask2former_for_segmentation)
tolerance = 3e-1
high_tolerance_models = [
"mask2former-swin-base-IN21k-coco-instance",
"mask2former-swin-base-coco-instance",
"mask2former-swin-small-cityscapes-semantic",
]
if model_name in high_tolerance_models:
tolerance = 3e-1
logger.info(f"🪄 Testing {model_name}...")
test(original_model, mask2former_for_segmentation, feature_extractor, tolerance)
logger.info(f"🪄 Pushing {model_name} to hub...")
feature_extractor.push_to_hub(model_name)
mask2former_for_segmentation.push_to_hub(model_name)
# coding=utf-8
# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Mask2Former model."""
import math
import random
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import Tensor, nn
from transformers import AutoBackbone, SwinConfig
from transformers.utils import logging
from ...activations import ACT2FN
from ...file_utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_scipy_available,
replace_return_docstrings,
requires_backends,
)
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from .configuration_mask2former import Mask2FormerConfig
if is_scipy_available():
from scipy.optimize import linear_sum_assignment
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Mask2FormerConfig"
_CHECKPOINT_FOR_DOC = "facebook/mask2former-swin-small-coco-instance"
_IMAGE_PROCESSOR_FOR_DOC = "MaskFormerImageProcessor"
MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/mask2former-swin-small-coco-instance",
# See all mask2former models at https://huggingface.co/models?filter=mask2former
]
@dataclass
class Mask2FormerPixelDecoderOutput(ModelOutput):
"""
Mask2Former's pixel decoder module output, practically a Multi-Scale Deformable Attention based decoder. It returns
the mask features and the multiscale features.
Args:
multi_scale_features (`tuple(torch.FloatTensor)`):
Tuple of multi-scale features of scales [1/8, 1/16, 1/32] and shape `(batch_size, num_channels, height,
width)` from the Multi-Scale Deformable Attention based Pixel Decoder.
mask_features (`torch.FloatTensor`):
Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel Decoder
Layer.
attentions (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Attentions weights from pixel decoder. Returned when `output_attentions=True` is passed
or when `config.output_attentions=True`
"""
multi_scale_features: Tuple[torch.FloatTensor] = None
mask_features: torch.FloatTensor = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class Mask2FormerMaskedAttentionDecoderOutput(BaseModelOutputWithCrossAttentions):
"""
Base class for outputs of the Transformer decoder. This class adds two attributes to
BaseModelOutputWithCrossAttentions for mask predictions logits and a tuple of intermediate decoder activations,
i.e. the output of each decoder layer, each of them gone through a layernorm.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
plus the initial embedding outputs. Returned when `output_hidden_states=True`.
attentions (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
the self-attention heads. Returned when `output_attentions=True`.
masks_queries_logits (`tuple(torch.FloatTensor)` of shape `(batch_size, num_queries, height, width)`):
Tuple of mask predictions from all layers of the transformer decoder.
intermediate_hidden_states (`tuple(torch.FloatTensor)` of shape `(num_queries, 1, hidden_size)`):
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
layernorm.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[torch.FloatTensor] = None
masks_queries_logits: Tuple[torch.FloatTensor] = None
intermediate_hidden_states: Tuple[torch.FloatTensor] = None
@dataclass
class Mask2FormerPixelLevelModuleOutput(ModelOutput):
"""
Mask2Former's pixel level module output. It returns the output of the encoder (optional) and all hidden states
(multi-scale features) from the `decoder`. By default, the `encoder` is a Swin Backbone and the `decoder` is a
Multi-Scale Deformable Attention based decoder.
The `decoder_last_hidden_state` contains the **per-pixel embeddings**, while `decoder_hidden_states` refers to the
multi-scale feature maps produced using the **multi-scaling strategy** defined in the paper.
Args:
encoder_last_hidden_state (`torch.FloatTensor`):
Last hidden states (final feature map of shape `(batch_size, num_channels, height, width)`) of the last
stage of the encoder.
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden states (also
called feature maps) of the model at the output of each stage. Returned if output_hidden_states is set to
True.
decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)):
1/4 scale features from the last Pixel Decoder Layer.
decoder_hidden_states (`tuple(torch.FloatTensor)`):
Tuple of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden states (also
called feature maps) of the model at the output of each stage.
"""
encoder_last_hidden_state: torch.FloatTensor = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_last_hidden_state: torch.FloatTensor = None
decoder_hidden_states: Tuple[torch.FloatTensor] = None
@dataclass
class Mask2FormerModelOutput(ModelOutput):
"""
Class for outputs of [`Mask2FormerModel`]. This class returns all the needed hidden states to compute the logits.
Args:
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
Last hidden states (final feature map) of the last stage of the encoder model (backbone). Returned when
`output_hidden_states=True` is passed.
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder
model at the output of each stage. Returned when `output_hidden_states=True` is passed.
pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
Last hidden states (final feature map) of the last stage of the pixel decoder model.
pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel
decoder model at the output of each stage. Returned when `output_hidden_states=True` is passed.
transformer_decoder_last_hidden_state (`tuple(torch.FloatTensor)`):
Final output of the transformer decoder `(batch_size, sequence_length, hidden_size)`.
transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the
transformer decoder at the output of each stage. Returned when `output_hidden_states=True` is passed.
transformer_decoder_intermediate_states (`tuple(torch.FloatTensor)` of shape `(num_queries, 1, hidden_size)`):
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
layernorm.
masks_queries_logits (`tuple(torch.FloatTensor)` of shape `(batch_size, num_queries, height, width)`):
Mask Predictions from each layer in the transformer decoder.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed):
Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Self attentions weights from transformer decoder.
"""
encoder_last_hidden_state: torch.FloatTensor = None
pixel_decoder_last_hidden_state: torch.FloatTensor = None
transformer_decoder_last_hidden_state: torch.FloatTensor = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
transformer_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
transformer_decoder_intermediate_states: Tuple[torch.FloatTensor] = None
masks_queries_logits: Tuple[torch.FloatTensor] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class Mask2FormerForUniversalSegmentationOutput(ModelOutput):
"""
Class for outputs of [`Mask2FormerForUniversalSegmentation`].
This output can be directly passed to [`~MaskFormerImageProcessor.post_process_semantic_segmentation`],
[`~MaskFormerImageProcessor.post_process_instance_segmentation`] or
[`~MaskFormerImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please see
[`~MaskFormerImageProcessor`] for details regarding usage, and the example below.
Args:
loss (`torch.Tensor`, *optional*):
The computed loss, returned when labels are present.
class_queries_logits (`torch.FloatTensor`):
A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
query. Note the `+ 1` is needed because we incorporate the null class.
masks_queries_logits (`torch.FloatTensor`):
A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
query.
auxiliary_logits (`List[Dict[str, torch.FloatTensor]]`, *optional*):
List of class and mask predictions from each layer of the transformer decoder.
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the encoder model (backbone).
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder
model at the output of each stage.
pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the pixel decoder model.
pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel
decoder model at the output of each stage.
transformer_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Final output of the transformer decoder.
transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the
transformer decoder at the output of each stage.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Self and cross attention weights from the transformer decoder.
"""
loss: Optional[torch.FloatTensor] = None
class_queries_logits: torch.FloatTensor = None
masks_queries_logits: torch.FloatTensor = None
auxiliary_logits: Optional[List[Dict[str, torch.FloatTensor]]] = None
encoder_last_hidden_state: torch.FloatTensor = None
pixel_decoder_last_hidden_state: torch.FloatTensor = None
transformer_decoder_last_hidden_state: torch.FloatTensor = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
transformer_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
# Copied from transformers.models.detr.modeling_detr._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, target_len: Optional[int] = None):
"""
Expands attention_mask from `[batch_size, seq_len]` to `[batch_size, 1, target_seq_len, source_seq_len]`.
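Example (a minimal illustrative sketch; the mask values are hypothetical):
```python
>>> import torch
>>> mask = torch.tensor([[1, 1, 0]])  # 1 = attend, 0 = padding
>>> _expand_mask(mask, torch.float32).shape
torch.Size([1, 1, 3, 3])
```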
"""
batch_size, source_len = mask.size()
target_len = target_len if target_len is not None else source_len
expanded_mask = mask[:, None, None, :].expand(batch_size, 1, target_len, source_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
def sample_point(
input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs
) -> torch.Tensor:
"""
A wrapper around `torch.nn.functional.grid_sample` to support 3D point_coordinates tensors.
Args:
input_features (`torch.Tensor` of shape (batch_size, channels, height, width)):
A tensor that contains feature maps on a height * width grid.
point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or (batch_size, grid_height, grid_width, 2)):
A tensor that contains [0, 1] * [0, 1] normalized point coordinates.
add_dim (`bool`):
Boolean value to keep track of the added dimension.
Returns:
point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or (batch_size, channels,
height_grid, width_grid)):
A tensor that contains features for points in `point_coordinates`.
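Example (a minimal illustrative sketch; shapes and values are hypothetical):
```python
>>> import torch
>>> features = torch.randn(1, 3, 4, 4)  # a 4x4 feature map with 3 channels for a single image
>>> points = torch.tensor([[[0.25, 0.75], [0.5, 0.5]]])  # two normalized (x, y) coordinates in [0, 1]
>>> sample_point(features, points, align_corners=False).shape
torch.Size([1, 3, 2])
```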
"""
if point_coordinates.dim() == 3:
add_dim = True
point_coordinates = point_coordinates.unsqueeze(2)
# use nn.function.grid_sample to get features for points in `point_coordinates` via bilinear interpolation
point_features = torch.nn.functional.grid_sample(input_features, 2.0 * point_coordinates - 1.0, **kwargs)
if add_dim:
point_features = point_features.squeeze(3)
return point_features
# Copied from transformers.models.maskformer.modeling_maskformer.dice_loss
def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor:
r"""
Compute the DICE loss, similar to generalized IOU for masks as follows:
$$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x \cap y}{x \cup y + 1} $$
In practice, since `labels` is a binary mask (containing only 0s and 1s), dice can be computed as:
$$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x * y}{x + y + 1} $$
Args:
inputs (`torch.Tensor`):
A tensor representing a mask.
labels (`torch.Tensor`):
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
(0 for the negative class and 1 for the positive class).
num_masks (`int`):
The number of masks present in the current batch, used for normalization.
Returns:
`torch.Tensor`: The computed loss.
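Example (a minimal sketch with toy shapes; not taken from the original source):
```python
>>> import torch
>>> inputs = torch.randn(2, 16)  # logits for 2 predicted masks, each sampled at 16 points
>>> labels = torch.randint(0, 2, (2, 16)).float()  # binary targets with the same shape
>>> loss = dice_loss(inputs, labels, num_masks=2)
>>> loss.ndim  # a scalar loss
0
```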
"""
probs = inputs.sigmoid().flatten(1)
numerator = 2 * (probs * labels).sum(-1)
denominator = probs.sum(-1) + labels.sum(-1)
loss = 1 - (numerator + 1) / (denominator + 1)
loss = loss.sum() / num_masks
return loss
def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor:
r"""
Args:
inputs (`torch.Tensor`):
A float tensor of arbitrary shape.
labels (`torch.Tensor`):
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
(0 for the negative class and 1 for the positive class).
num_masks (`int`):
The number of masks present in the current batch, used for normalization.
Returns:
loss (`torch.Tensor`): The computed loss.
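Example (a minimal sketch with toy shapes; not taken from the original source):
```python
>>> import torch
>>> point_logits = torch.randn(2, 16)  # logits for 2 predicted masks, each sampled at 16 points
>>> point_labels = torch.randint(0, 2, (2, 16)).float()  # binary targets with the same shape
>>> loss = sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks=2)
>>> loss.ndim  # a scalar loss
0
```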
"""
criterion = nn.BCEWithLogitsLoss(reduction="none")
cross_entropy_loss = criterion(inputs, labels)
loss = cross_entropy_loss.mean(1).sum() / num_masks
return loss
# Copied from transformers.models.maskformer.modeling_maskformer.pair_wise_dice_loss
def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
"""
A pair wise version of the dice loss, see `dice_loss` for usage.
Args:
inputs (`torch.Tensor`):
A tensor representing a mask
labels (`torch.Tensor`):
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
(0 for the negative class and 1 for the positive class).
Returns:
`torch.Tensor`: The computed loss between each pairs.
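Example (a minimal sketch with toy shapes; not taken from the original source):
```python
>>> import torch
>>> preds = torch.randn(3, 16)  # 3 predicted masks, each sampled at 16 points
>>> targets = torch.randint(0, 2, (2, 16)).float()  # 2 ground truth masks
>>> pair_wise_dice_loss(preds, targets).shape  # a (num_predictions, num_targets) cost matrix
torch.Size([3, 2])
```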
"""
inputs = inputs.sigmoid().flatten(1)
numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
# using broadcasting to get a [num_queries, NUM_CLASSES] matrix
denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
loss = 1 - (numerator + 1) / (denominator + 1)
return loss
def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
r"""
A pair wise version of the cross entropy loss, see `sigmoid_cross_entropy_loss` for usage.
Args:
inputs (`torch.Tensor`):
A tensor representing a mask.
labels (`torch.Tensor`):
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
(0 for the negative class and 1 for the positive class).
Returns:
loss (`torch.Tensor`): The computed loss between each pairs.
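Example (a minimal sketch with toy shapes; not taken from the original source):
```python
>>> import torch
>>> preds = torch.randn(3, 16)  # 3 predicted masks, each sampled at 16 points
>>> targets = torch.randint(0, 2, (2, 16)).float()  # 2 ground truth masks
>>> pair_wise_sigmoid_cross_entropy_loss(preds, targets).shape  # a (num_predictions, num_targets) cost matrix
torch.Size([3, 2])
```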
"""
height_and_width = inputs.shape[1]
criterion = nn.BCEWithLogitsLoss(reduction="none")
cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
loss = torch.einsum("nc,mc->nm", cross_entropy_loss_pos, labels) + torch.einsum(
"nc,mc->nm", cross_entropy_loss_neg, (1 - labels)
)
loss = loss / height_and_width
return loss
# Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/matcher.py
class Mask2FormerHungarianMatcher(nn.Module):
"""This class computes an assignment between the labels and the predictions of the network.
For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more
predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are
un-matched (and thus treated as non-objects).
"""
def __init__(
self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544
):
"""Creates the matcher
Params:
cost_class (`float`, *optional*, defaults to 1.0):
Relative weight of the classification error in the matching cost.
cost_mask (`float`, *optional*, defaults to 1.0):
This is the relative weight of the focal loss of the binary mask in the matching cost.
cost_dice (`float`, *optional*, defaults to 1.0):
This is the relative weight of the dice loss of the binary mask in the matching cost.
num_points (`int`, *optional*, defaults to 12544):
Number of points to sample, on which the mask loss is calculated. The same set of K points is
uniformly sampled for all prediction and ground truth masks to construct the cost matrix for bipartite
matching.
"""
super().__init__()
if cost_class == 0 and cost_mask == 0 and cost_dice == 0:
raise ValueError("All costs can't be 0")
self.num_points = num_points
self.cost_class = cost_class
self.cost_mask = cost_mask
self.cost_dice = cost_dice
@torch.no_grad()
def forward(
self,
masks_queries_logits: torch.Tensor,
class_queries_logits: torch.Tensor,
mask_labels: torch.Tensor,
class_labels: torch.Tensor,
) -> List[Tuple[Tensor]]:
"""
Params:
masks_queries_logits (`torch.Tensor`):
A tensor of dim `batch_size, num_queries, height, width` with the predicted masks.
class_queries_logits (`torch.Tensor`):
A tensor of dim `batch_size, num_queries, num_labels` with the classification logits.
class_labels (`torch.Tensor`):
A tensor of dim `num_target_boxes` (where num_target_boxes is the number of ground-truth objects in the
target) containing the class labels.
mask_labels (`torch.Tensor`):
A tensor of dim `num_target_boxes, height, width` containing the target masks.
Returns:
matched_indices (`List[Tuple[Tensor]]`): A list of size batch_size, containing tuples of (index_i, index_j)
where:
- index_i is the indices of the selected predictions (in order)
- index_j is the indices of the corresponding selected labels (in order)
For each batch element, it holds:
len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
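Example (an illustrative sketch with toy shapes; all tensors below are hypothetical):
```python
>>> import torch
>>> matcher = Mask2FormerHungarianMatcher(num_points=16)
>>> masks_queries_logits = torch.randn(1, 5, 8, 8)  # 5 queries predicting 8x8 masks
>>> class_queries_logits = torch.randn(1, 5, 4)  # 3 classes + the null class
>>> mask_labels = [torch.randint(0, 2, (2, 8, 8)).float()]  # 2 ground truth masks
>>> class_labels = [torch.tensor([0, 2])]
>>> indices = matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
>>> [len(i) for i in indices[0]]  # one (prediction, label) index pair per ground truth mask
[2, 2]
```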
"""
indices: List[Tuple[np.array]] = []
# iterate through batch size
batch_size = masks_queries_logits.shape[0]
for i in range(batch_size):
pred_probs = class_queries_logits[i].softmax(-1)
pred_mask = masks_queries_logits[i]
# Compute the classification cost. Contrary to the loss, we don't use the NLL, but approximate it as 1 - proba[target class]. The 1 is a constant that doesn't change the matching, so it can be omitted.
cost_class = -pred_probs[:, class_labels[i]]
target_mask = mask_labels[i].to(pred_mask)
target_mask = target_mask[:, None]
pred_mask = pred_mask[:, None]
# Sample ground truth and predicted masks
point_coordinates = torch.rand(1, self.num_points, 2, device=pred_mask.device)
target_coordinates = point_coordinates.repeat(target_mask.shape[0], 1, 1)
target_mask = sample_point(target_mask, target_coordinates, align_corners=False).squeeze(1)
pred_coordinates = point_coordinates.repeat(pred_mask.shape[0], 1, 1)
pred_mask = sample_point(pred_mask, pred_coordinates, align_corners=False).squeeze(1)
# compute the cross entropy loss between each pair of masks -> shape (num_queries, num_labels)
cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask)
# Compute the dice loss between each pair of masks -> shape (num_queries, num_labels)
cost_dice = pair_wise_dice_loss(pred_mask, target_mask)
# final cost matrix
cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
# do the assignment using the Hungarian algorithm in scipy
assigned_indices: Tuple[np.array] = linear_sum_assignment(cost_matrix.cpu())
indices.append(assigned_indices)
# It could be stacked in one tensor
matched_indices = [
(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
]
return matched_indices
# Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py
class Mask2FormerLoss(nn.Module):
def __init__(self, config: Mask2FormerConfig, weight_dict: Dict[str, float]):
"""
The Mask2Former Loss. The loss is computed very similarly to DETR. The process happens in two steps: 1) we
compute the Hungarian assignment between the ground truth masks and the outputs of the model 2) we supervise
each pair of matched ground-truth / prediction (supervise class and mask)
Args:
config (`Mask2FormerConfig`):
The configuration for Mask2Former model also containing loss calculation specific parameters.
weight_dict (`Dict[str, float]`):
A dictionary of weights to be applied to the different losses.
"""
super().__init__()
requires_backends(self, ["scipy"])
self.num_labels = config.num_labels
self.weight_dict = weight_dict
# Weight to apply to the null class
self.eos_coef = config.no_object_weight
empty_weight = torch.ones(self.num_labels + 1)
empty_weight[-1] = self.eos_coef
self.register_buffer("empty_weight", empty_weight)
# pointwise mask loss parameters
self.num_points = config.train_num_points
self.oversample_ratio = config.oversample_ratio
self.importance_sample_ratio = config.importance_sample_ratio
self.matcher = Mask2FormerHungarianMatcher(
cost_class=1.0,
cost_dice=config.dice_weight,
cost_mask=config.mask_weight,
num_points=self.num_points,
)
def _max_by_axis(self, sizes: List[List[int]]) -> List[int]:
maxes = sizes[0]
for sublist in sizes[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
# Adapted from nested_tensor_from_tensor_list() in original implementation
def _pad_images_to_max_in_batch(self, tensors: List[Tensor]) -> Tuple[Tensor, Tensor]:
# get the maximum size in the batch
max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
# compute final size
batch_shape = [len(tensors)] + max_size
batch_size, _, height, width = batch_shape
dtype = tensors[0].dtype
device = tensors[0].device
padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
padding_masks = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
# pad the tensors to the size of the biggest one
for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
padding_mask[: tensor.shape[1], : tensor.shape[2]] = False
return padded_tensors, padding_masks
def loss_labels(
self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array]
) -> Dict[str, Tensor]:
"""Compute the losses related to the labels using cross entropy.
Args:
class_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, num_labels`
class_labels (`List[torch.Tensor]`):
List of class labels of shape `(labels)`.
indices (`Tuple[np.array]`):
The indices computed by the Hungarian matcher.
Returns:
`Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
- **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
"""
pred_logits = class_queries_logits
batch_size, num_queries, _ = pred_logits.shape
criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
idx = self._get_predictions_permutation_indices(indices) # shape of (batch_size, num_queries)
target_classes_o = torch.cat(
[target[j] for target, (_, j) in zip(class_labels, indices)]
) # shape of (batch_size, num_queries)
target_classes = torch.full(
(batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device
)
target_classes[idx] = target_classes_o
# Permute target_classes (batch_size, num_queries, num_labels) -> (batch_size, num_labels, num_queries)
pred_logits_transposed = pred_logits.transpose(1, 2)
loss_ce = criterion(pred_logits_transposed, target_classes)
losses = {"loss_cross_entropy": loss_ce}
return losses
def loss_masks(
self,
masks_queries_logits: torch.Tensor,
mask_labels: List[torch.Tensor],
indices: Tuple[np.array],
num_masks: int,
) -> Dict[str, torch.Tensor]:
"""Compute the losses related to the masks using sigmoid_cross_entropy_loss and dice loss.
Args:
masks_queries_logits (`torch.Tensor`):
A tensor of shape `(batch_size, num_queries, height, width)`.
mask_labels (`List[torch.Tensor]`):
List of mask labels of shape `(labels, height, width)`.
indices (`Tuple[np.array]`):
The indices computed by the Hungarian matcher.
num_masks (`int`):
The number of masks, used for normalization.
Returns:
losses (`Dict[str, Tensor]`): A dict of `torch.Tensor` containing two keys:
- **loss_mask** -- The loss computed using sigmoid cross entropy loss on the predicted and ground truth
masks.
- **loss_dice** -- The loss computed using dice loss on the predicted and ground truth
masks.
"""
src_idx = self._get_predictions_permutation_indices(indices)
tgt_idx = self._get_targets_permutation_indices(indices)
# shape (batch_size * num_queries, height, width)
pred_masks = masks_queries_logits[src_idx]
# shape (batch_size, num_queries, height, width)
# pad all and stack the targets to the num_labels dimension
target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
target_masks = target_masks[tgt_idx]
# No need to upsample predictions as we are using normalized coordinates
pred_masks = pred_masks[:, None]
target_masks = target_masks[:, None]
# Sample point coordinates
with torch.no_grad():
point_coordinates = self.sample_points_using_uncertainty(
pred_masks,
lambda logits: self.calculate_uncertainty(logits),
self.num_points,
self.oversample_ratio,
self.importance_sample_ratio,
)
point_labels = sample_point(target_masks, point_coordinates, align_corners=False).squeeze(1)
point_logits = sample_point(pred_masks, point_coordinates, align_corners=False).squeeze(1)
losses = {
"loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks),
"loss_dice": dice_loss(point_logits, point_labels, num_masks),
}
del pred_masks
del target_masks
return losses
def _get_predictions_permutation_indices(self, indices):
# Permute predictions following indices
batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
predictions_indices = torch.cat([src for (src, _) in indices])
return batch_indices, predictions_indices
def _get_targets_permutation_indices(self, indices):
# Permute labels following indices
batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
target_indices = torch.cat([tgt for (_, tgt) in indices])
return batch_indices, target_indices
def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor:
"""
In the Mask2Former paper, uncertainty is estimated as the L1 distance between 0.0 and the logit prediction in
`logits` for the foreground class.
Args:
logits (`torch.Tensor`):
A tensor of shape (R, 1, ...), where R is the total number of predicted masks in all images. The values are
logits.
Returns:
scores (`torch.Tensor`): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most
uncertain locations having the highest uncertainty score.
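Example (an illustrative sketch of the scoring rule only; the values are hypothetical):
```python
>>> import torch
>>> logits = torch.tensor([[[-3.0, 0.1, 2.5]]])
>>> -(torch.abs(logits))  # the location with the logit closest to 0 is the most uncertain
tensor([[[-3.0000, -0.1000, -2.5000]]])
```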
"""
uncertainty_scores = -(torch.abs(logits))
return uncertainty_scores
def sample_points_using_uncertainty(
self,
logits: torch.Tensor,
uncertainty_function,
num_points: int,
oversample_ratio: int,
importance_sample_ratio: float,
) -> torch.Tensor:
"""
This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. The
uncertainty is calculated for each point using the passed `uncertainty_function` that takes the points' logit
predictions as input.
Args:
logits (`float`):
Logit predictions for P points.
uncertainty_function:
A function that takes logit predictions for P points and returns their uncertainties.
num_points (`int`):
The number of points P to sample.
oversample_ratio (`int`):
Oversampling parameter.
importance_sample_ratio (`float`):
Ratio of points that are sampled via importance sampling.
Returns:
point_coordinates (`torch.Tensor`):
Coordinates for P sampled points.
"""
num_boxes = logits.shape[0]
num_points_sampled = int(num_points * oversample_ratio)
# Get random point coordinates
point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
# Get sampled prediction value for the point coordinates
point_logits = sample_point(logits, point_coordinates, align_corners=False)
# Calculate the uncertainties based on the sampled prediction values of the points
point_uncertainties = uncertainty_function(point_logits)
num_uncertain_points = int(importance_sample_ratio * num_points)
num_random_points = num_points - num_uncertain_points
idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)
idx += shift[:, None]
point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)
if num_random_points > 0:
point_coordinates = torch.cat(
[point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],
dim=1,
)
return point_coordinates
def forward(
self,
masks_queries_logits: torch.Tensor,
class_queries_logits: torch.Tensor,
mask_labels: List[torch.Tensor],
class_labels: List[torch.Tensor],
auxiliary_predictions: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
"""
This performs the loss computation.
Args:
masks_queries_logits (`torch.Tensor`):
A tensor of shape `(batch_size, num_queries, height, width)`.
class_queries_logits (`torch.Tensor`):
A tensor of shape `(batch_size, num_queries, num_labels)`.
mask_labels (`List[torch.Tensor]`):
List of mask labels of shape `(labels, height, width)`.
class_labels (`List[torch.Tensor]`):
List of class labels of shape `(labels)`.
auxiliary_predictions (`Dict[str, torch.Tensor]`, *optional*):
if `use_auxiliary_loss` was set to `true` in [`Mask2FormerConfig`], then it contains the logits from
the inner layers of the Mask2FormerMaskedAttentionDecoder.
Returns:
losses (`Dict[str, Tensor]`): A dict of `torch.Tensor` containing three keys:
- **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
- **loss_mask** -- The loss computed using sigmoid cross entropy loss on the predicted and ground truth
masks.
- **loss_dice** -- The loss computed using dice loss on the predicted and ground truth
masks.
if `use_auxiliary_loss` was set to `true` in [`Mask2FormerConfig`], the dictionary contains additional
losses for each auxiliary prediction.
"""
# retrieve the matching between the outputs of the last layer and the labels
indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
# compute the average number of target masks for normalization purposes
num_masks = self.get_num_masks(class_labels, device=class_labels[0].device)
# get all the losses
losses: Dict[str, Tensor] = {
**self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
**self.loss_labels(class_queries_logits, class_labels, indices),
}
# in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
if auxiliary_predictions is not None:
for idx, aux_outputs in enumerate(auxiliary_predictions):
masks_queries_logits = aux_outputs["masks_queries_logits"]
class_queries_logits = aux_outputs["class_queries_logits"]
loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()}
losses.update(loss_dict)
return losses
def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
"""
Computes the average number of target masks across the batch, for normalization purposes.
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
return num_masks_pt
def multi_scale_deformable_attention(
value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
):
batch_size, _, num_attn_head, hidden_dim = value.shape
_, num_queries, num_attn_head, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for idx, (height, width) in enumerate(value_spatial_shapes):
# batch_size, height*width, num_attn_head, hidden_dim -> batch_size*num_attn_head, hidden_dim, height, width
value_l_ = (
value_list[idx].flatten(2).transpose(1, 2).reshape(batch_size * num_attn_head, hidden_dim, height, width)
)
# (batch_size, num_queries, num_attn_head, num_points, 2) -> (batch_size * num_attn_head, num_queries, num_points, 2)
sampling_grid_l_ = sampling_grids[:, :, :, idx].transpose(1, 2).flatten(0, 1)
# batch_size*num_attn_head, hidden_dim, num_queries, num_points
sampling_value_l_ = torch.nn.functional.grid_sample(
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
)
sampling_value_list.append(sampling_value_l_)
# (batch_size, num_queries, num_attn_head, num_levels, num_points) -> (batch_size, num_attn_head, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
batch_size * num_attn_head, 1, num_queries, num_levels * num_points
)
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(batch_size, num_attn_head * hidden_dim, num_queries)
)
return output.transpose(1, 2).contiguous()
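# Note on shapes (illustrative, not from the original source): with `value` of shape
# (batch_size, sum(height_i * width_i), num_attn_head, hidden_dim), `value_spatial_shapes` listing each
# (height_i, width_i), `sampling_locations` of shape
# (batch_size, num_queries, num_attn_head, num_levels, num_points, 2) and `attention_weights` of shape
# (batch_size, num_queries, num_attn_head, num_levels, num_points), the function returns a tensor of shape
# (batch_size, num_queries, num_attn_head * hidden_dim).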
# Copied from transformers.models.maskformer.modeling_maskformer.MaskFormerSinePositionEmbedding with MaskFormer->Mask2Former
class Mask2FormerSinePositionEmbedding(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
need paper, generalized to work on images.
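Example (a minimal illustrative sketch; the feature map is hypothetical):
```python
>>> import torch
>>> position_embedding = Mask2FormerSinePositionEmbedding(num_pos_feats=64, normalize=True)
>>> features = torch.randn(1, 128, 16, 16)
>>> position_embedding(features).shape  # output channels = 2 * num_pos_feats
torch.Size([1, 128, 16, 16])
```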
"""
def __init__(
self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None
):
super().__init__()
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
self.scale = 2 * math.pi if scale is None else scale
def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
if mask is None:
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
# Modified from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention
class Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module):
"""
Multiscale deformable attention as proposed in Deformable DETR.
"""
def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):
super().__init__()
if embed_dim % num_heads != 0:
raise ValueError(
f"embed_dim (d_model) must be divisible by num_heads, but got {embed_dim} and {num_heads}"
)
dim_per_head = embed_dim // num_heads
# check if dim_per_head is power of 2
if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
warnings.warn(
"You'd better set embed_dim (d_model) in Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention to make"
" the dimension of each attention head a power of 2, which is more efficient in the authors' CUDA"
" implementation."
)
self.im2col_step = 128
self.d_model = embed_dim
self.n_levels = n_levels
self.n_heads = num_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(embed_dim, num_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(embed_dim, num_heads * n_levels * n_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim)
def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
return tensor if position_embeddings is None else tensor + position_embeddings
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None,
encoder_attention_mask=None,
position_embeddings: Optional[torch.Tensor] = None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
output_attentions: bool = False,
):
# add position embeddings to the hidden states before projecting to queries and keys
if position_embeddings is not None:
hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
)
value = self.value_proj(encoder_hidden_states)
if attention_mask is not None:
# we invert the attention_mask
value = value.masked_fill(attention_mask[..., None], float(0))
value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(hidden_states).view(
batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
)
attention_weights = self.attention_weights(hidden_states).view(
batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
)
attention_weights = nn.functional.softmax(attention_weights, -1).view(
batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
)
# batch_size, num_queries, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = (
reference_points[:, :, None, :, None, :]
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
)
elif reference_points.shape[-1] == 4:
sampling_locations = (
reference_points[:, :, None, :, None, :2]
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
)
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
output = self.output_proj(output)
return output, attention_weights
class Mask2FormerPixelDecoderEncoderLayer(nn.Module):
def __init__(self, config: Mask2FormerConfig):
super().__init__()
self.embed_dim = config.feature_size
self.self_attn = Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
n_levels=3,
n_points=4,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = nn.functional.relu
self.activation_dropout = config.dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_feedforward_dim)
self.fc2 = nn.Linear(config.encoder_feedforward_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor = None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
output_attentions: bool = False,
):
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Input to the layer.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Attention mask.
position_embeddings (`torch.FloatTensor`, *optional*):
Position embeddings, to be added to `hidden_states`.
reference_points (`torch.FloatTensor`, *optional*):
Reference points.
spatial_shapes (`torch.LongTensor`, *optional*):
Spatial shapes of the backbone feature maps.
level_start_index (`torch.LongTensor`, *optional*):
Level start index.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
# Apply Multi-scale Deformable Attention Module on the multi-scale feature maps.
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
if self.training:
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights.transpose(1, 0),)
return outputs
# Modified from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrEncoder with DeformableDetrEncoder->Mask2FormerPixelDecoderEncoderOnly
class Mask2FormerPixelDecoderEncoderOnly(nn.Module):
"""
Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a
[`Mask2FormerPixelDecoderEncoderLayer`]. The encoder updates the flattened multi-scale feature maps through
multiple deformable attention layers.
Args:
config: Mask2FormerConfig
"""
def __init__(self, config: Mask2FormerConfig):
super().__init__()
self.config = config
self.dropout = config.dropout
self.layers = nn.ModuleList(
[Mask2FormerPixelDecoderEncoderLayer(config) for _ in range(config.encoder_layers)]
)
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
"""
Get reference points for each feature map. Used in decoder.
Args:
spatial_shapes (`torch.LongTensor`):
Spatial shapes of each feature map, has shape of `(num_feature_levels, 2)`.
valid_ratios (`torch.FloatTensor`):
Valid ratios of each feature map, has shape of `(batch_size, num_feature_levels, 2)`.
device (`torch.device`):
Device on which to create the tensors.
Returns:
`torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
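Example (a minimal illustrative sketch with toy feature map sizes):
```python
>>> import torch
>>> spatial_shapes = torch.tensor([[4, 4], [2, 2]])  # two feature levels
>>> valid_ratios = torch.ones(1, 2, 2)  # batch_size 1, no padding
>>> points = Mask2FormerPixelDecoderEncoderOnly.get_reference_points(spatial_shapes, valid_ratios, device="cpu")
>>> points.shape  # (batch_size, 4*4 + 2*2, num_levels, 2)
torch.Size([1, 20, 2, 2])
```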
"""
reference_points_list = []
for lvl, (height, width) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
indexing="ij",
)
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * height)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * width)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def forward(
self,
inputs_embeds=None,
attention_mask=None,
position_embeddings=None,
spatial_shapes=None,
level_start_index=None,
valid_ratios=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
- 1 for pixel features that are real (i.e. **not masked**),
- 0 for pixel features that are padding (i.e. **masked**).
[What are attention masks?](../glossary#attention-mask)
position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Position embeddings that are added to the queries and keys in each self-attention layer.
spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
Spatial shapes of each feature map.
level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
Starting index of each feature map.
valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
Ratio of valid area in each feature level.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = inputs_embeds
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device)
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, encoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states.transpose(1, 0),)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states.transpose(1, 0),)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
# Modified from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrModel with DeformableDetrModel->Mask2FormerPixelDecoder
class Mask2FormerPixelDecoder(nn.Module):
def __init__(self, config: Mask2FormerConfig, feature_channels):
super().__init__()
self.config = config
feature_dim = config.feature_size
mask_dim = config.mask_feature_size
num_pos_features = feature_dim // 2
self.position_embedding = Mask2FormerSinePositionEmbedding(num_pos_feats=num_pos_features, normalize=True)
self.num_feature_levels = 3
transformer_in_channels = feature_channels[-self.num_feature_levels :]
self.transformer_feature_strides = config.feature_strides[-self.num_feature_levels :]
self.feature_channels = feature_channels
self.level_embed = nn.Parameter(torch.Tensor(self.num_feature_levels, feature_dim))
# Create input projection layers
if self.num_feature_levels > 1:
input_projections_list = []
for in_channels in transformer_in_channels[::-1]:
input_projections_list.append(
nn.Sequential(
nn.Conv2d(in_channels, feature_dim, kernel_size=1),
nn.GroupNorm(32, feature_dim),
)
)
self.input_projections = nn.ModuleList(input_projections_list)
else:
self.input_projections = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(transformer_in_channels[-1], feature_dim, kernel_size=1),
nn.GroupNorm(32, feature_dim),
)
]
)
self.encoder = Mask2FormerPixelDecoderEncoderOnly(config)
self.mask_projection = nn.Conv2d(feature_dim, mask_dim, kernel_size=1, stride=1, padding=0)
# Extra FPN levels
stride = min(self.transformer_feature_strides)
self.common_stride = config.common_stride
self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride))
lateral_convs = []
output_convs = []
for idx, in_channels in enumerate(self.feature_channels[: self.num_fpn_levels]):
lateral_conv = nn.Sequential(
nn.Conv2d(in_channels, feature_dim, kernel_size=1, bias=False),
nn.GroupNorm(32, feature_dim),
)
output_conv = nn.Sequential(
nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1, bias=False),
nn.GroupNorm(32, feature_dim),
nn.ReLU(),
)
self.add_module("adapter_{}".format(idx + 1), lateral_conv)
self.add_module("layer_{}".format(idx + 1), output_conv)
lateral_convs.append(lateral_conv)
output_convs.append(output_conv)
# Order convolutional layers from low to high resolution
self.lateral_convolutions = lateral_convs[::-1]
self.output_convolutions = output_convs[::-1]
def get_valid_ratio(self, mask):
"""Get the valid ratio of all feature maps."""
_, height, width = mask.shape
valid_height = torch.sum(~mask[:, :, 0], 1)
valid_width = torch.sum(~mask[:, 0, :], 1)
valid_ratio_height = valid_height.float() / height
valid_ratio_width = valid_width.float() / width
valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1)
return valid_ratio
def forward(
self,
features,
encoder_outputs=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# Apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
input_embeds = []
position_embeddings = []
for level, x in enumerate(features[::-1][: self.num_feature_levels]):
input_embeds.append(self.input_projections[level](x.float()))
position_embeddings.append(self.position_embedding(x.float()))
masks = [
torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in input_embeds
]
# Prepare encoder inputs (by flattening)
spatial_shapes = [(embed.shape[2], embed.shape[3]) for embed in input_embeds]
input_embeds_flat = torch.cat([embed.flatten(2).transpose(1, 2) for embed in input_embeds], 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=input_embeds_flat.device)
masks_flat = torch.cat([mask.flatten(1) for mask in masks], 1)
position_embeddings = [embed.flatten(2).transpose(1, 2) for embed in position_embeddings]
level_pos_embed_flat = [x + self.level_embed[i].view(1, 1, -1) for i, x in enumerate(position_embeddings)]
level_pos_embed_flat = torch.cat(level_pos_embed_flat, 1)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratio(mask) for mask in masks], 1)
# Send input_embeds_flat + masks_flat + level_pos_embed_flat (backbone + proj layer output) through encoder
if encoder_outputs is None:
encoder_outputs = self.encoder(
inputs_embeds=input_embeds_flat,
attention_mask=masks_flat,
position_embeddings=level_pos_embed_flat,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs.last_hidden_state
batch_size = last_hidden_state.shape[0]
split_sizes = [None] * self.num_feature_levels
for i in range(self.num_feature_levels):
if i < self.num_feature_levels - 1:
split_sizes[i] = level_start_index[i + 1] - level_start_index[i]
else:
split_sizes[i] = last_hidden_state.shape[1] - level_start_index[i]
encoder_output = torch.split(last_hidden_state, split_sizes, dim=1)
# Compute final features
outputs = [
x.transpose(1, 2).view(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1])
for i, x in enumerate(encoder_output)
]
# Append extra FPN levels to outputs, ordered from low to high resolution
for idx, feature in enumerate(features[: self.num_fpn_levels][::-1]):
lateral_conv = self.lateral_convolutions[idx]
output_conv = self.output_convolutions[idx]
current_fpn = lateral_conv(feature.float())
# Following the FPN implementation, we use bilinear upsampling here
out = current_fpn + nn.functional.interpolate(
outputs[-1], size=current_fpn.shape[-2:], mode="bilinear", align_corners=False
)
out = output_conv(out)
outputs.append(out)
num_cur_levels = 0
multi_scale_features = []
for out in outputs:
if num_cur_levels < self.num_feature_levels:
multi_scale_features.append(out)
num_cur_levels += 1
return Mask2FormerPixelDecoderOutput(
mask_features=self.mask_projection(outputs[-1]),
multi_scale_features=tuple(multi_scale_features),
attentions=encoder_outputs.attentions,
)
class Mask2FormerPixelLevelModule(nn.Module):
def __init__(self, config: Mask2FormerConfig):
"""
Pixel Level Module proposed in [Masked-attention Mask Transformer for Universal Image
Segmentation](https://arxiv.org/abs/2112.01527). It runs the input image through a backbone and a pixel
decoder, generating multi-scale feature maps and pixel embeddings.
Args:
config ([`Mask2FormerConfig`]):
The configuration used to instantiate this model.
"""
super().__init__()
backbone_config_dict = config.backbone_config.to_dict()
backbone_config = SwinConfig.from_dict(backbone_config_dict)
self.encoder = AutoBackbone.from_config(backbone_config)
self.decoder = Mask2FormerPixelDecoder(config, feature_channels=self.encoder.channels)
def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> Mask2FormerPixelLevelModuleOutput:
backbone_features = self.encoder(pixel_values).feature_maps
decoder_output = self.decoder(backbone_features, output_hidden_states=output_hidden_states)
return Mask2FormerPixelLevelModuleOutput(
encoder_last_hidden_state=backbone_features[-1],
encoder_hidden_states=tuple(backbone_features) if output_hidden_states else None,
decoder_last_hidden_state=decoder_output.mask_features,
decoder_hidden_states=decoder_output.multi_scale_features,
)
# Modified from transformers.models.detr.modeling_detr.DetrAttention with Detr->Mask2Former
class Mask2FormerAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Here, we add position embeddings to the queries and
keys (as explained in the DETR paper).
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
if self.head_dim * num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
return tensor if position_embeddings is None else tensor + position_embeddings
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
key_value_states: Optional[torch.Tensor] = None,
key_value_position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
hidden_states = hidden_states.permute(1, 0, 2) if hidden_states is not None else None
position_embeddings = position_embeddings.permute(1, 0, 2) if position_embeddings is not None else None
key_value_states = key_value_states.permute(1, 0, 2) if key_value_states is not None else None
key_value_position_embeddings = (
key_value_position_embeddings.permute(1, 0, 2) if key_value_position_embeddings is not None else None
)
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
batch_size, target_len, embed_dim = hidden_states.size()
# add position embeddings to the hidden states before projecting to queries and keys
if position_embeddings is not None:
hidden_states_original = hidden_states
hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
# add key-value position embeddings to the key value states
if key_value_position_embeddings is not None:
key_value_states_original = key_value_states
key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings)
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
source_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
raise ValueError(
f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size * self.num_heads, target_len, source_len):
raise ValueError(
f"Attention mask should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
f" {attention_mask.size()}"
)
attn_weights += attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(batch_size * self.num_heads, target_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
attn_output = self.out_proj(attn_output).permute(1, 0, 2)
return attn_output, attn_weights_reshaped
class Mask2FormerMaskedAttentionDecoderLayer(nn.Module):
"""
The Mask2FormerMaskedAttentionDecoderLayer is made up of self-attention, cross (masked) attention as well as FFN
blocks. The cross attention block used as part of `Mask2FormerMaskedAttentionDecoderLayer` is actually a `masked
attention` block that restricts the attention to localized features centered around predicted segments which leads
to faster convergence and improved performance. The order of self and cross (i.e. masked) attention blocks has
also been swapped in Mask2FormerMaskedAttentionDecoder compared to a standard DetrDecoder as an optimization
improvement.
Args:
config (`Mask2FormerConfig`):
The configuration used to initialize the Mask2FormerMaskedAttentionDecoder.
"""
def __init__(self, config: Mask2FormerConfig):
super().__init__()
self.config = config
self.embed_dim = self.config.hidden_dim
self.pre_norm = self.config.pre_norm
self.self_attn = Mask2FormerAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.dropout,
is_decoder=True,
)
self.dropout = self.config.dropout
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout = self.config.dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.cross_attn = nn.MultiheadAttention(self.embed_dim, self.config.num_attention_heads, self.config.dropout)
self.cross_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, self.config.dim_feedforward)
self.fc2 = nn.Linear(self.config.dim_feedforward, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(
self,
hidden_states: torch.Tensor,
level_index: int = None,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
query_position_embeddings: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
):
# Masked(Cross)-Attention Block
cross_attn_weights = None
self_attn_weights = None
residual = hidden_states
hidden_states, cross_attn_weights = self.cross_attn(
query=self.with_pos_embed(hidden_states, query_position_embeddings),
key=self.with_pos_embed(encoder_hidden_states[level_index], position_embeddings[level_index]),
value=encoder_hidden_states[level_index],
attn_mask=encoder_attention_mask,
key_padding_mask=None,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.cross_attn_layer_norm(hidden_states)
# Self Attention Block
residual = hidden_states
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=query_position_embeddings,
attention_mask=None,
output_attentions=True,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Fully Connected
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
return outputs
def forward_pre(
self,
hidden_states: torch.Tensor,
level_index: int = None,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
query_position_embeddings: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
):
# Masked(Cross)-Attention Block
cross_attn_weights = None
self_attn_weights = None
residual = hidden_states
hidden_states = self.cross_attn_layer_norm(hidden_states)
hidden_states, cross_attn_weights = self.cross_attn(
query=self.with_pos_embed(hidden_states, query_position_embeddings),
key=self.with_pos_embed(encoder_hidden_states[level_index], position_embeddings[level_index]),
value=encoder_hidden_states[level_index],
attn_mask=encoder_attention_mask,
key_padding_mask=None,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# Self Attention Block
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=query_position_embeddings,
attention_mask=None,
output_attentions=True,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
return outputs
def forward(
self,
hidden_states: torch.Tensor,
level_index: int = None,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
query_position_embeddings: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
):
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(seq_len, batch, embed_dim)`.
attention_mask (`torch.FloatTensor`):
Attention mask of shape `(1, seq_len, tgt_len, src_len)`.
position_embeddings (`torch.FloatTensor`, *optional*):
Position embeddings that are added to the keys in the masked-attention layer.
query_position_embeddings (`torch.FloatTensor`, *optional*):
Position embeddings that are added to the queries and keys in the self-attention layer.
encoder_hidden_states (`torch.FloatTensor`):
Cross attention input to the layer of shape `(seq_len, batch, embed_dim)`.
encoder_attention_mask (`torch.FloatTensor`):
Encoder attention mask of size `(1, seq_len, tgt_len, src_len)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
if self.pre_norm:
outputs = self.forward_pre(
hidden_states=hidden_states,
level_index=level_index,
position_embeddings=position_embeddings,
query_position_embeddings=query_position_embeddings,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
else:
outputs = self.forward_post(
hidden_states=hidden_states,
level_index=level_index,
position_embeddings=position_embeddings,
query_position_embeddings=query_position_embeddings,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
return outputs
class Mask2FormerMaskedAttentionDecoder(nn.Module):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a
[`Mask2FormerMaskedAttentionDecoderLayer`]. The decoder updates the query embeddings through multiple cross
(masked) and self-attention layers. The decoder uses a new **masked attention** mechanism instead of the standard
cross-attention, which extracts localized features by constraining cross-attention to within the foreground region
of the predicted mask for each query, instead of attending to the full feature map.
Args:
config (`Mask2FormerConfig`):
Configuration used to instantiate Mask2FormerMaskedAttentionDecoder.
"""
def __init__(self, config: Mask2FormerConfig):
super().__init__()
self.config = config
self.mask_feature_size = config.mask_feature_size
self.dropout = config.dropout
self.layerdrop = config.dropout
self.num_feature_levels = 3 # level embedding (3 scales)
self.decoder_layers = config.decoder_layers - 1
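# One fewer layer is instantiated because the first mask prediction is made directly from the query
# embeddings, before any decoder layer is applied (see `forward`).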
self.layers = nn.ModuleList(
[Mask2FormerMaskedAttentionDecoderLayer(self.config) for _ in range(self.decoder_layers)]
)
self.layernorm = nn.LayerNorm(config.hidden_dim)
self.mask_predictor = Mask2FormerMaskPredictor(
hidden_size=config.hidden_dim,
num_heads=config.num_attention_heads,
mask_feature_size=self.mask_feature_size,
)
self.gradient_checkpointing = False
def forward(
self,
inputs_embeds: torch.Tensor = None,
multi_stage_positional_embeddings: torch.Tensor = None,
pixel_embeddings: torch.Tensor = None,
encoder_hidden_states: torch.Tensor = None,
query_position_embeddings: torch.Tensor = None,
feature_size_list: List = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`):
The query embeddings that are passed into the decoder.
multi_stage_positional_embeddings (`torch.FloatTensor` of shape `(height*width, batch_size, num_channels)`):
Position embeddings that are added to the keys in each cross(masked)-attention layer.
pixel_embeddings (`torch.FloatTensor`):
Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel
Decoder.
query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`, *optional*):
Position embeddings that are added to the queries and keys in each self-attention layer.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the
cross(masked)-attention of the decoder.
feature_size_list (`List[torch.Size]`):
List containing the spatial shapes (height, width) of the multi-scale features from the Pixel Decoder.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is not None:
hidden_states = inputs_embeds
# intermediate hidden states with layernorm applied - required for predicting class logits
intermediate = ()
# decoder layers
all_hidden_states = () if output_hidden_states else None
attentions = () if output_attentions else None
# intermediate mask predictions from transformer decoder layers
intermediate_mask_predictions = ()
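# An initial mask prediction is computed from the layer-normalized query embeddings before any decoder
# layer runs; the resulting attention mask conditions the first masked-attention layer.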
intermediate_hidden_states = self.layernorm(inputs_embeds)
intermediate += (intermediate_hidden_states,)
predicted_mask, attention_mask = self.mask_predictor(
intermediate_hidden_states, pixel_embeddings, feature_size_list[0]
)
intermediate_mask_predictions += (predicted_mask,)
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
encoder_hidden_states,
None,
None,
)
else:
level_index = idx % self.num_feature_levels
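# If a query's predicted mask is empty (every position masked out), disable masking for that query so
# it can still attend to the full feature map of the current level.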
attention_mask[torch.where(attention_mask.sum(-1) == attention_mask.shape[-1])] = False
layer_outputs = decoder_layer(
hidden_states,
level_index=level_index,
position_embeddings=multi_stage_positional_embeddings,
query_position_embeddings=query_position_embeddings,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=attention_mask,
output_attentions=output_attentions,
)
intermediate_hidden_states = self.layernorm(layer_outputs[0])
predicted_mask, attention_mask = self.mask_predictor(
intermediate_hidden_states,
pixel_embeddings,
feature_size_list[(idx + 1) % self.num_feature_levels],
)
intermediate_mask_predictions += (predicted_mask,)
# add intermediate hidden states with layer norm applied which will be used for predicting class logits
intermediate += (intermediate_hidden_states,)
hidden_states = layer_outputs[0]
if output_attentions:
attentions += (layer_outputs[1],)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
hidden_states = hidden_states.transpose(1, 0)
if not return_dict:
outputs = [hidden_states, all_hidden_states, attentions, intermediate, intermediate_mask_predictions]
return tuple(v for v in outputs if v is not None)
return Mask2FormerMaskedAttentionDecoderOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=attentions,
intermediate_hidden_states=intermediate,
masks_queries_logits=intermediate_mask_predictions,
)
# Copied from transformers.models.maskformer.modeling_maskformer.PredictionBlock with MaskFormer->Mask2Former
class Mask2FormerPredictionBlock(nn.Module):
def __init__(self, in_dim: int, out_dim: int, activation: nn.Module) -> None:
super().__init__()
self.layers = [nn.Linear(in_dim, out_dim), activation]
# Maintain submodule indexing as if part of a Sequential block
for i, layer in enumerate(self.layers):
self.add_module(str(i), layer)
def forward(self, input: Tensor) -> Tensor:
hidden_state = input
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
class Mask2FormerMLPPredictionHead(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3):
"""
A classic Multi Layer Perceptron (MLP).
Args:
input_dim (`int`):
The input dimensions.
hidden_dim (`int`):
The hidden dimensions.
output_dim (`int`):
The output dimensions.
num_layers (`int`, *optional*, defaults to 3):
The number of layers.
"""
super().__init__()
in_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
out_dims = [hidden_dim] * (num_layers - 1) + [output_dim]
self.layers = []
for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
activation = nn.ReLU() if i < num_layers - 1 else nn.Identity()
layer = Mask2FormerPredictionBlock(in_dim, out_dim, activation=activation)
self.layers.append(layer)
# Provide backwards compatibility from when the class inherited from nn.Sequential
# In nn.Sequential subclasses, the name given to the layer is its index in the sequence.
# In nn.Module subclasses they derived from the instance attribute they are assigned to e.g.
# self.my_layer_name = Layer()
# We can't give instance attributes integer names i.e. self.0 is not permitted and so need to register
# explicitly
self.add_module(str(i), layer)
def forward(self, input: Tensor) -> Tensor:
hidden_state = input
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
class Mask2FormerMaskPredictor(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mask_feature_size: torch.Tensor):
"""
This class is used to get the predicted mask for a given Mask2FormerMaskedAttentionDecoder layer. It also
generates the binarized attention mask associated with the predicted mask. The attention mask obtained from
the predicted mask of the (l-1)-th decoder layer is fed as input to the cross(masked)-attention block of the
l-th decoder layer.
Args:
hidden_size (`int`):
The feature dimension of the Mask2FormerMaskedAttentionDecoder
num_heads (`int`):
The number of heads used in the Mask2FormerMaskedAttentionDecoder
mask_feature_size (`torch.Tensor`):
One of the output dimensions of the predicted masks for each query.
"""
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.mask_embedder = Mask2FormerMLPPredictionHead(self.hidden_size, self.hidden_size, mask_feature_size)
def forward(self, outputs: torch.Tensor, pixel_embeddings: torch.Tensor, attention_mask_target_size: int = None):
mask_embeddings = self.mask_embedder(outputs.transpose(0, 1))
# Sum over channels: for every query, take the dot product between its mask embedding and each pixel embedding
outputs_mask = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
attention_mask = nn.functional.interpolate(
outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False
)
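# Binarize: positions with predicted probability below 0.5 are masked out (True = not attended); the mask
# is repeated for every attention head and detached so no gradient flows through it.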
attention_mask = attention_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1)
attention_mask = (attention_mask.flatten(0, 1) < 0.5).bool()
attention_mask = attention_mask.detach()
return outputs_mask, attention_mask
class Mask2FormerTransformerModule(nn.Module):
"""
The Mask2Former transformer module.
"""
def __init__(self, in_features: int, config: Mask2FormerConfig):
super().__init__()
hidden_dim = config.hidden_dim
self.num_feature_levels = 3
self.position_embedder = Mask2FormerSinePositionEmbedding(num_pos_feats=hidden_dim // 2, normalize=True)
self.queries_embedder = nn.Embedding(config.num_queries, hidden_dim)
self.queries_features = nn.Embedding(config.num_queries, hidden_dim)
self.input_projections = []
for _ in range(self.num_feature_levels):
if in_features != hidden_dim or config.enforce_input_projection:
self.input_projections.append(nn.Conv2d(in_features, hidden_dim, kernel_size=1))
else:
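# An empty nn.Sequential acts as an identity mapping when the input feature size already matches the
# hidden dimension and no projection is enforced.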
self.input_projections.append(nn.Sequential())
self.decoder = Mask2FormerMaskedAttentionDecoder(config=config)
self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
def forward(
self,
multi_scale_features: List[Tensor],
mask_features: Tensor,
output_hidden_states: bool = False,
output_attentions: bool = False,
) -> Mask2FormerMaskedAttentionDecoderOutput:
multi_stage_features = []
multi_stage_positional_embeddings = []
size_list = []
for i in range(self.num_feature_levels):
size_list.append(multi_scale_features[i].shape[-2:])
multi_stage_positional_embeddings.append(self.position_embedder(multi_scale_features[i], None).flatten(2))
multi_stage_features.append(
self.input_projections[i](multi_scale_features[i]).flatten(2)
+ self.level_embed.weight[i][None, :, None]
)
# Flatten (batch_size, num_channels, height, width) -> (height*width, batch_size, num_channels)
multi_stage_positional_embeddings[-1] = multi_stage_positional_embeddings[-1].permute(2, 0, 1)
multi_stage_features[-1] = multi_stage_features[-1].permute(2, 0, 1)
_, batch_size, _ = multi_stage_features[0].shape
# [num_queries, batch_size, num_channels]
query_embeddings = self.queries_embedder.weight.unsqueeze(1).repeat(1, batch_size, 1)
query_features = self.queries_features.weight.unsqueeze(1).repeat(1, batch_size, 1)
decoder_output = self.decoder(
inputs_embeds=query_features,
multi_stage_positional_embeddings=multi_stage_positional_embeddings,
pixel_embeddings=mask_features,
encoder_hidden_states=multi_stage_features,
query_position_embeddings=query_embeddings,
feature_size_list=size_list,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=True,
)
return decoder_output
MASK2FORMER_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`Mask2FormerConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
MASK2FORMER_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
[`AutoFeatureExtractor.__call__`] for details.
pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
[What are attention masks?](../glossary#attention-mask)
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
output_attentions (`bool`, *optional*):
Whether or not to return the attention tensors of the transformer decoder's attention layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~Mask2FormerModelOutput`] instead of a plain tuple.
"""
class Mask2FormerPreTrainedModel(PreTrainedModel):
config_class = Mask2FormerConfig
base_model_prefix = "model"
main_input_name = "pixel_values"
def _init_weights(self, module: nn.Module):
xavier_std = self.config.init_xavier_std
std = self.config.init_std
if isinstance(module, Mask2FormerTransformerModule):
if module.input_projections is not None:
for input_projection in module.input_projections:
if not isinstance(input_projection, nn.Sequential):
nn.init.xavier_uniform_(input_projection.weight, gain=xavier_std)
nn.init.constant_(input_projection.bias, 0)
elif isinstance(module, Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention):
nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
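# Deformable-attention init: bias the sampling offsets so that each head starts by looking in a different
# direction, with the offset magnitude growing with the sampling-point index.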
thetas = torch.arange(module.n_heads, dtype=torch.float32) * (2.0 * math.pi / module.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
.view(module.n_heads, 1, 1, 2)
.repeat(1, module.n_levels, module.n_points, 1)
)
for i in range(module.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
nn.init.constant_(module.attention_weights.weight.data, 0.0)
nn.init.constant_(module.attention_weights.bias.data, 0.0)
nn.init.xavier_uniform_(module.value_proj.weight.data)
nn.init.constant_(module.value_proj.bias.data, 0.0)
nn.init.xavier_uniform_(module.output_proj.weight.data)
nn.init.constant_(module.output_proj.bias.data, 0.0)
elif isinstance(module, Mask2FormerMaskedAttentionDecoderLayer):
for p in module.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p, gain=xavier_std)
elif isinstance(module, Mask2FormerPixelLevelModule):
for submodule in module.modules():
if isinstance(submodule, (nn.Conv2d, nn.Linear)):
submodule.weight.data.normal_(mean=0.0, std=std)
if submodule.bias is not None:
submodule.bias.data.zero_()
elif isinstance(module, Mask2FormerPixelDecoder):
for p in module.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
nn.init.normal_(module.level_embed, std=0)
elif isinstance(module, Mask2FormerPixelDecoderEncoderOnly):
for p in module.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if hasattr(module, "reference_points"):
nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0)
nn.init.constant_(module.reference_points.bias.data, 0.0)
@add_start_docstrings(
"The bare Mask2Former Model outputting raw hidden-states without any specific head on top.",
MASK2FORMER_START_DOCSTRING,
)
class Mask2FormerModel(Mask2FormerPreTrainedModel):
main_input_name = "pixel_values"
def __init__(self, config: Mask2FormerConfig):
super().__init__(config)
self.pixel_level_module = Mask2FormerPixelLevelModule(config)
self.transformer_module = Mask2FormerTransformerModule(in_features=config.feature_size, config=config)
self.post_init()
@add_start_docstrings_to_model_forward(MASK2FORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Mask2FormerModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Tensor,
pixel_mask: Optional[Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Mask2FormerModelOutput:
r"""
Returns:
`Mask2FormerModelOutput`
Examples:
```python
>>> import torch
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoImageProcessor, Mask2FormerModel
>>> # download a test image
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> # Load image preprocessor and Mask2FormerModel trained on ADE20K instance segmentation dataset
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-ade-instance")
>>> model = Mask2FormerModel.from_pretrained("facebook/mask2former-swin-small-ade-instance")
>>> inputs = image_processor(image, return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
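>>> # Illustrative check (not part of the original example): the transformer decoder returns one
>>> # embedding per query, i.e. a tensor of shape (batch_size, num_queries, hidden_dim)
>>> last_hidden_state = outputs.transformer_decoder_last_hidden_state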
```
"""
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size, _, height, width = pixel_values.shape
if pixel_mask is None:
pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)
pixel_level_module_output = self.pixel_level_module(
pixel_values=pixel_values, output_hidden_states=output_hidden_states
)
transformer_module_output = self.transformer_module(
multi_scale_features=pixel_level_module_output.decoder_hidden_states,
mask_features=pixel_level_module_output.decoder_last_hidden_state,
output_hidden_states=True,
output_attentions=output_attentions,
)
encoder_hidden_states = None
pixel_decoder_hidden_states = None
transformer_decoder_hidden_states = None
transformer_decoder_intermediate_states = None
if output_hidden_states:
encoder_hidden_states = pixel_level_module_output.encoder_hidden_states
pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states
transformer_decoder_hidden_states = transformer_module_output.hidden_states
transformer_decoder_intermediate_states = transformer_module_output.intermediate_hidden_states
output = Mask2FormerModelOutput(
encoder_last_hidden_state=pixel_level_module_output.encoder_last_hidden_state,
pixel_decoder_last_hidden_state=pixel_level_module_output.decoder_last_hidden_state,
transformer_decoder_last_hidden_state=transformer_module_output.last_hidden_state,
encoder_hidden_states=encoder_hidden_states,
pixel_decoder_hidden_states=pixel_decoder_hidden_states,
transformer_decoder_hidden_states=transformer_decoder_hidden_states,
transformer_decoder_intermediate_states=transformer_decoder_intermediate_states,
attentions=transformer_module_output.attentions,
masks_queries_logits=transformer_module_output.masks_queries_logits,
)
if not return_dict:
output = tuple(v for v in output.values() if v is not None)
return output
@add_start_docstrings(
"The Mask2Former Model with heads on top for instance/semantic/panoptic segmentation.",
MASK2FORMER_START_DOCSTRING,
)
class Mask2FormerForUniversalSegmentation(Mask2FormerPreTrainedModel):
main_input_name = "pixel_values"
def __init__(self, config: Mask2FormerConfig):
super().__init__(config)
self.model = Mask2FormerModel(config)
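# Relative weights applied by the criterion to the classification (cross-entropy), mask and dice loss
# terms, including their auxiliary counterparts.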
self.weight_dict: Dict[str, float] = {
"loss_cross_entropy": config.class_weight,
"loss_mask": config.mask_weight,
"loss_dice": config.dice_weight,
}
self.class_predictor = nn.Linear(config.hidden_dim, config.num_labels + 1)
self.criterion = Mask2FormerLoss(config=config, weight_dict=self.weight_dict)
self.post_init()
def get_loss_dict(
self,
masks_queries_logits: Tensor,
class_queries_logits: Tensor,
mask_labels: Tensor,
class_labels: Tensor,
auxiliary_predictions: Dict[str, Tensor],
) -> Dict[str, Tensor]:
loss_dict: Dict[str, Tensor] = self.criterion(
masks_queries_logits=masks_queries_logits,
class_queries_logits=class_queries_logits,
mask_labels=mask_labels,
class_labels=class_labels,
auxiliary_predictions=auxiliary_predictions,
)
# weight each loss by `self.weight_dict[<LOSS_NAME>]` including auxiliary losses
for key, weight in self.weight_dict.items():
for loss_key, loss in loss_dict.items():
if key in loss_key:
loss *= weight
return loss_dict
def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor:
return sum(loss_dict.values())
def get_auxiliary_logits(self, classes: torch.Tensor, output_masks: torch.Tensor):
auxiliary_logits: List[Dict[str, Tensor]] = []
for aux_binary_masks, aux_classes in zip(output_masks[:-1], classes[:-1]):
auxiliary_logits.append({"masks_queries_logits": aux_binary_masks, "class_queries_logits": aux_classes})
return auxiliary_logits
@add_start_docstrings_to_model_forward(MASK2FORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Mask2FormerForUniversalSegmentationOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Tensor,
mask_labels: Optional[List[Tensor]] = None,
class_labels: Optional[List[Tensor]] = None,
pixel_mask: Optional[Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_auxiliary_logits: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Mask2FormerForUniversalSegmentationOutput:
r"""
mask_labels (`List[torch.Tensor]`, *optional*):
List of mask labels of shape `(num_labels, height, width)` to be fed to a model.
class_labels (`List[torch.LongTensor]`, *optional*):
List of target class labels of shape `(num_labels,)` to be fed to a model. They identify the labels of
`mask_labels`, e.g. the label of `mask_labels[i][j]` is `class_labels[i][j]`.
Returns:
`Mask2FormerForUniversalSegmentationOutput`
Examples:
```python
>>> from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
>>> from PIL import Image
>>> import requests
>>> import torch
>>> # Load Mask2Former trained on ADE20K panoptic segmentation dataset
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-small-ade-panoptic")
>>> model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-small-ade-panoptic")
>>> url = (
... "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
... )
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = image_processor(image, return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> # Model predicts class_queries_logits of shape `(batch_size, num_queries, num_labels + 1)`
>>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
>>> class_queries_logits = outputs.class_queries_logits
>>> masks_queries_logits = outputs.masks_queries_logits
>>> # Perform post-processing to get semantic, instance or panoptic segmentation maps
>>> pred_semantic_map = image_processor.post_process_semantic_segmentation(
... outputs, target_sizes=[image.size[::-1]]
... )[0]
>>> pred_instance_map = image_processor.post_process_instance_segmentation(
... outputs, target_sizes=[image.size[::-1]]
... )[0]["segmentation"]
>>> pred_panoptic_map = image_processor.post_process_panoptic_segmentation(
... outputs, target_sizes=[image.size[::-1]]
... )[0]["segmentation"]
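>>> # The panoptic post-processing also returns a `segments_info` list mapping each segment id in the
>>> # panoptic map to its predicted class (shown here as an illustrative addition)
>>> segments_info = image_processor.post_process_panoptic_segmentation(
...     outputs, target_sizes=[image.size[::-1]]
... )[0]["segments_info"]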
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
pixel_values=pixel_values,
pixel_mask=pixel_mask,
output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss,
output_attentions=output_attentions,
return_dict=True,
)
loss, loss_dict, auxiliary_logits = None, None, None
class_queries_logits = ()
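# Class logits are predicted from every intermediate decoder state so that auxiliary losses can be
# computed for each decoder layer; only the last prediction is returned as the main output.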
for decoder_output in outputs.transformer_decoder_intermediate_states:
class_prediction = self.class_predictor(decoder_output.transpose(0, 1))
class_queries_logits += (class_prediction,)
masks_queries_logits = outputs.masks_queries_logits
auxiliary_logits = self.get_auxiliary_logits(class_queries_logits, masks_queries_logits)
if mask_labels is not None and class_labels is not None:
loss_dict = self.get_loss_dict(
masks_queries_logits=masks_queries_logits[-1],
class_queries_logits=class_queries_logits[-1],
mask_labels=mask_labels,
class_labels=class_labels,
auxiliary_predictions=auxiliary_logits,
)
loss = self.get_loss(loss_dict)
encoder_hidden_states = None
pixel_decoder_hidden_states = None
transformer_decoder_hidden_states = None
if output_hidden_states:
encoder_hidden_states = outputs.encoder_hidden_states
pixel_decoder_hidden_states = outputs.pixel_decoder_hidden_states
transformer_decoder_hidden_states = outputs.transformer_decoder_hidden_states
output_auxiliary_logits = (
self.config.use_auxiliary_loss if output_auxiliary_logits is None else output_auxiliary_logits
)
if not output_auxiliary_logits:
auxiliary_logits = None
output = Mask2FormerForUniversalSegmentationOutput(
loss=loss,
class_queries_logits=class_queries_logits[-1],
masks_queries_logits=masks_queries_logits[-1],
auxiliary_logits=auxiliary_logits,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
pixel_decoder_last_hidden_state=outputs.pixel_decoder_last_hidden_state,
transformer_decoder_last_hidden_state=outputs.transformer_decoder_last_hidden_state,
encoder_hidden_states=encoder_hidden_states,
pixel_decoder_hidden_states=pixel_decoder_hidden_states,
transformer_decoder_hidden_states=transformer_decoder_hidden_states,
attentions=outputs.attentions,
)
if not return_dict:
output = tuple(v for v in output.values() if v is not None)
if loss is not None:
output = (loss,) + output
return output
...@@ -3568,6 +3568,30 @@ class MarkupLMPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
MASK2FORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class Mask2FormerForUniversalSegmentation(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Mask2FormerModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Mask2FormerPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Mask2Former model. """
import inspect
import unittest
import numpy as np
from tests.test_modeling_common import floats_tensor
from transformers import Mask2FormerConfig, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin
if is_torch_available():
import torch
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerModel
if is_vision_available():
from transformers import MaskFormerImageProcessor
from PIL import Image
class Mask2FormerModelTester:
def __init__(
self,
parent,
batch_size=2,
is_training=True,
use_auxiliary_loss=False,
num_queries=10,
num_channels=3,
min_size=32 * 8,
max_size=32 * 8,
num_labels=4,
hidden_dim=64,
):
self.parent = parent
self.batch_size = batch_size
self.is_training = is_training
self.use_auxiliary_loss = use_auxiliary_loss
self.num_queries = num_queries
self.num_channels = num_channels
self.min_size = min_size
self.max_size = max_size
self.num_labels = num_labels
self.hidden_dim = hidden_dim
self.mask_feature_size = hidden_dim
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size]).to(
torch_device
)
pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device)
mask_labels = (
torch.rand([self.batch_size, self.num_labels, self.min_size, self.max_size], device=torch_device) > 0.5
).float()
class_labels = (torch.rand((self.batch_size, self.num_labels), device=torch_device) > 0.5).long()
config = self.get_config()
return config, pixel_values, pixel_mask, mask_labels, class_labels
def get_config(self):
config = Mask2FormerConfig(
hidden_size=self.hidden_dim,
)
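# Override the defaults with tiny values so the test model stays small and fast to run.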
config.num_queries = self.num_queries
config.num_labels = self.num_labels
config.backbone_config.depths = [1, 1, 1, 1]
config.backbone_config.num_channels = self.num_channels
config.encoder_feedforward_dim = 64
config.dim_feedforward = 128
config.hidden_dim = self.hidden_dim
config.mask_feature_size = self.hidden_dim
config.feature_size = self.hidden_dim
return config
def prepare_config_and_inputs_for_common(self):
config, pixel_values, pixel_mask, _, _ = self.prepare_config_and_inputs()
inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}
return config, inputs_dict
def check_output_hidden_state(self, output, config):
encoder_hidden_states = output.encoder_hidden_states
pixel_decoder_hidden_states = output.pixel_decoder_hidden_states
transformer_decoder_hidden_states = output.transformer_decoder_hidden_states
self.parent.assertTrue(len(encoder_hidden_states), len(config.backbone_config.depths))
self.parent.assertTrue(len(pixel_decoder_hidden_states), len(config.backbone_config.depths))
self.parent.assertTrue(len(transformer_decoder_hidden_states), config.decoder_layers)
def create_and_check_mask2former_model(self, config, pixel_values, pixel_mask, output_hidden_states=False):
with torch.no_grad():
model = Mask2FormerModel(config=config)
model.to(torch_device)
model.eval()
output = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
output = model(pixel_values, output_hidden_states=True)
self.parent.assertEqual(
output.transformer_decoder_last_hidden_state.shape,
(self.batch_size, self.num_queries, self.hidden_dim),
)
# let's ensure the other two hidden states exist
self.parent.assertTrue(output.pixel_decoder_last_hidden_state is not None)
self.parent.assertTrue(output.encoder_last_hidden_state is not None)
if output_hidden_states:
self.check_output_hidden_state(output, config)
def create_and_check_mask2former_instance_segmentation_head_model(
self, config, pixel_values, pixel_mask, mask_labels, class_labels
):
model = Mask2FormerForUniversalSegmentation(config=config)
model.to(torch_device)
model.eval()
def comm_check_on_output(result):
# let's still check that all the required stuff is there
self.parent.assertTrue(result.transformer_decoder_last_hidden_state is not None)
self.parent.assertTrue(result.pixel_decoder_last_hidden_state is not None)
self.parent.assertTrue(result.encoder_last_hidden_state is not None)
# okay, now we need to check the logits shape
# due to the encoder compression, masks have a //4 spatial size
self.parent.assertEqual(
result.masks_queries_logits.shape,
(self.batch_size, self.num_queries, self.min_size // 4, self.max_size // 4),
)
# + 1 for null class
self.parent.assertEqual(
result.class_queries_logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1)
)
with torch.no_grad():
result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
result = model(pixel_values)
comm_check_on_output(result)
result = model(
pixel_values=pixel_values, pixel_mask=pixel_mask, mask_labels=mask_labels, class_labels=class_labels
)
comm_check_on_output(result)
self.parent.assertTrue(result.loss is not None)
self.parent.assertEqual(result.loss.shape, torch.Size([1]))
@require_torch
class Mask2FormerModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (Mask2FormerModel, Mask2FormerForUniversalSegmentation) if is_torch_available() else ()
is_encoder_decoder = False
test_pruning = False
test_head_masking = False
test_missing_keys = False
def setUp(self):
self.model_tester = Mask2FormerModelTester(self)
self.config_tester = ConfigTester(self, config_class=Mask2FormerConfig, has_text_modality=False)
def test_config(self):
self.config_tester.run_common_tests()
def test_mask2former_model(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.create_and_check_mask2former_model(config, **inputs, output_hidden_states=False)
def test_mask2former_instance_segmentation_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_mask2former_instance_segmentation_head_model(*config_and_inputs)
@unittest.skip(reason="Mask2Former does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="Mask2Former does not have a get_input_embeddings method")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="Mask2Former is not a generative model")
def test_generate_without_input_ids(self):
pass
@unittest.skip(reason="Mask2Former does not use token embeddings")
def test_resize_tokens_embeddings(self):
pass
@require_torch_multi_gpu
@unittest.skip(
reason="Mask2Former has some layers using `add_module` which doesn't work well with `nn.DataParallel`"
)
def test_multi_gpu_data_parallel_forward(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
@slow
def test_model_from_pretrained(self):
for model_name in ["facebook/mask2former-swin-small-coco-instance"]:
model = Mask2FormerModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_model_with_labels(self):
size = (self.model_tester.min_size,) * 2
inputs = {
"pixel_values": torch.randn((2, 3, *size), device=torch_device),
"mask_labels": torch.randn((2, 10, *size), device=torch_device),
"class_labels": torch.zeros(2, 10, device=torch_device).long(),
}
config = self.model_tester.get_config()
model = Mask2FormerForUniversalSegmentation(config).to(torch_device)
outputs = model(**inputs)
self.assertTrue(outputs.loss is not None)
def test_hidden_states_output(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.create_and_check_mask2former_model(config, **inputs, output_hidden_states=True)
def test_attention_outputs(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
outputs = model(**inputs, output_attentions=True)
self.assertTrue(outputs.attentions is not None)
def test_training(self):
if not self.model_tester.is_training:
return
model_class = self.all_model_classes[1]
config, pixel_values, pixel_mask, mask_labels, class_labels = self.model_tester.prepare_config_and_inputs()
model = model_class(config)
model.to(torch_device)
model.train()
loss = model(pixel_values, mask_labels=mask_labels, class_labels=class_labels).loss
loss.backward()
def test_retain_grad_hidden_states_attentions(self):
model_class = self.all_model_classes[1]
config, pixel_values, pixel_mask, mask_labels, class_labels = self.model_tester.prepare_config_and_inputs()
config.output_hidden_states = True
config.output_attentions = True
model = model_class(config).to(torch_device)
model.train()
outputs = model(pixel_values, mask_labels=mask_labels, class_labels=class_labels)
encoder_hidden_states = outputs.encoder_hidden_states[0]
encoder_hidden_states.retain_grad()
pixel_decoder_hidden_states = outputs.pixel_decoder_hidden_states[0]
pixel_decoder_hidden_states.retain_grad()
transformer_decoder_hidden_states = outputs.transformer_decoder_hidden_states[0]
transformer_decoder_hidden_states.retain_grad()
attentions = outputs.attentions[0]
attentions.retain_grad()
outputs.loss.backward(retain_graph=True)
self.assertIsNotNone(encoder_hidden_states.grad)
self.assertIsNotNone(pixel_decoder_hidden_states.grad)
self.assertIsNotNone(transformer_decoder_hidden_states.grad)
self.assertIsNotNone(attentions.grad)
TOLERANCE = 1e-4
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_vision
@slow
class Mask2FormerModelIntegrationTest(unittest.TestCase):
@cached_property
def model_checkpoints(self):
return "facebook/mask2former-swin-small-coco-instance"
@cached_property
def default_feature_extractor(self):
return MaskFormerImageProcessor.from_pretrained(self.model_checkpoints) if is_vision_available() else None
def test_inference_no_head(self):
model = Mask2FormerModel.from_pretrained(self.model_checkpoints).to(torch_device)
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
inputs_shape = inputs["pixel_values"].shape
# check size is divisible by 32
self.assertTrue((inputs_shape[-1] % 32) == 0 and (inputs_shape[-2] % 32) == 0)
# check size
self.assertEqual(inputs_shape, (1, 3, 384, 384))
with torch.no_grad():
outputs = model(**inputs)
expected_slice_hidden_state = torch.tensor(
[[-0.2790, -1.0717, -1.1668], [-0.5128, -0.3128, -0.4987], [-0.5832, 0.1971, -0.0197]]
).to(torch_device)
self.assertTrue(
torch.allclose(
outputs.encoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE
)
)
expected_slice_hidden_state = torch.tensor(
[[0.8973, 1.1847, 1.1776], [1.1934, 1.5040, 1.5128], [1.1153, 1.4486, 1.4951]]
).to(torch_device)
self.assertTrue(
torch.allclose(
outputs.pixel_decoder_last_hidden_state[0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE
)
)
expected_slice_hidden_state = torch.tensor(
[[2.1152, 1.7000, -0.8603], [1.5808, 1.8004, -0.9353], [1.6043, 1.7495, -0.5999]]
).to(torch_device)
self.assertTrue(
torch.allclose(
outputs.transformer_decoder_last_hidden_state[0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE
)
)
def test_inference_universal_segmentation_head(self):
model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
inputs_shape = inputs["pixel_values"].shape
# check size is divisible by 32
self.assertTrue((inputs_shape[-1] % 32) == 0 and (inputs_shape[-2] % 32) == 0)
# check size
self.assertEqual(inputs_shape, (1, 3, 384, 384))
with torch.no_grad():
outputs = model(**inputs)
# masks_queries_logits
masks_queries_logits = outputs.masks_queries_logits
self.assertEqual(
masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4)
)
expected_slice = [
[-8.7839, -9.0056, -8.8121],
[-7.4104, -7.0313, -6.5401],
[-6.6105, -6.3427, -6.4675],
]
expected_slice = torch.tensor(expected_slice).to(torch_device)
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
# class_queries_logits
class_queries_logits = outputs.class_queries_logits
self.assertEqual(class_queries_logits.shape, (1, model.config.num_queries, model.config.num_labels + 1))
expected_slice = torch.tensor(
[
[1.8324, -8.0835, -4.1922],
[0.8450, -9.0050, -3.6053],
[0.3045, -7.7293, -3.0275],
]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
def test_with_segmentation_maps_and_loss(self):
model = Mask2FormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
feature_extractor = self.default_feature_extractor
inputs = feature_extractor(
[np.zeros((3, 800, 1333)), np.zeros((3, 800, 1333))],
segmentation_maps=[np.zeros((384, 384)).astype(np.float32), np.zeros((384, 384)).astype(np.float32)],
return_tensors="pt",
)
inputs["pixel_values"] = inputs["pixel_values"].to(torch_device)
inputs["mask_labels"] = [el.to(torch_device) for el in inputs["mask_labels"]]
inputs["class_labels"] = [el.to(torch_device) for el in inputs["class_labels"]]
with torch.no_grad():
outputs = model(**inputs)
self.assertTrue(outputs.loss is not None)