"tests/models/roberta/test_tokenization_roberta.py" did not exist on "6c41a8f5dc5c630f31fda7b1617b701b40ea27d6"
Unverified Commit 5b949623 authored by Jitesh Jain, committed by GitHub

Add OneFormer Model (#20577)

* Add Oneformer Model

* Add OneFormer Tests

* Add UNIVERSAL_SEGMENTATION_MAPPING

* Fix config

* 🐛 Fix error encountered while writing tests

* 🔨 Fix instance segmentation post processing

* Format Files and Add Documentation

* Add Documentation mdx file

* Run make fixup

* Run make fix-copies

* Remove unnecessary code

* Format modeling_oneformer.py

* Add OneFormer to ImageSegmentationPipeline

* Format files

* Add Demo link to Readme

* Fix formatting errors

* Fix test failures

* Update Table in index.mdx

* Fix version

* Fix style

* Remove OneFormer from TF

* Fix Imports

* Fix dummy objects

* Fix tests

* Add newline

* Remove OneFormerFeatureExtractor

* Remove CUDA Kernels

* Use AutoBackbone for Swin

* Fix description

* Use Image Processor

* Fix copies

* Fix formatting

* Fix import order

* Fix flake8 errors

* Fix doc errors

* Add Hindi Readme entry

* Update supported backbones

* Update supported backbones

* Undo Changes

* Fix type of config

* Fix isort

* Fix auto.mdx

* Fix swin config

* Replace DinatBackbone with AutoBackbone

* Use SwinBackbone

* Use SwinBackbone

* Fix conversion script

* Fix arguments

* Add argument description

* Fix style

* Add OneFormerProcessor

* Fix OneFormerProcessor Tests

* Fix mapping

* Fix imports

* Fix inits

* Fix style

* Fix comment

* Fix docstring

* Move OneFormer to MultiModal

* Fix Copies

* Remove size divisor

* Fix check_repo.py

* Fix copies

* Add Processor for Testing Pipeline

* Fix padding for tokens

* Fix variables

* Fix formatting with correct black version

* Add Image Processor Test

* Apply suggestions

* Revert common modeling

* Add check for task

* Fix conversion script

* Fix initialization order

* Fix tests

* Undo Pipeline Changes

* Fix layers in MLP

* Fix copies

* Update image paths

* Fix copies

* Apply suggestions
parent 6d676643
......@@ -53,6 +53,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("layoutlmv3", "LayoutLMv3Processor"),
("layoutxlm", "LayoutXLMProcessor"),
("markuplm", "MarkupLMProcessor"),
("oneformer", "OneFormerProcessor"),
("owlvit", "OwlViTProcessor"),
("sew", "Wav2Vec2Processor"),
("sew-d", "Wav2Vec2Processor"),
......
......@@ -208,6 +208,7 @@ else:
"AlbertTokenizerFast" if is_tokenizers_available() else None,
),
),
("oneformer", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("openai-gpt", ("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None)),
("opt", ("GPT2Tokenizer", None)),
("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
......
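# Illustrative use of the two auto-mapping entries added above (a sketch, assuming the
# "shi-labs/oneformer_ade20k_swin_tiny" checkpoint referenced later in this PR is available):
# "oneformer" now resolves to OneFormerProcessor via AutoProcessor and to CLIPTokenizer(/Fast)
# via AutoTokenizer.
from transformers import AutoProcessor, AutoTokenizer

processor = AutoProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
tokenizer = AutoTokenizer.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")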
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {
"configuration_oneformer": ["ONEFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "OneFormerConfig"],
"processing_oneformer": ["OneFormerProcessor"],
}
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_oneformer"] = ["OneFormerImageProcessor"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_oneformer"] = [
"ONEFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"OneFormerForUniversalSegmentation",
"OneFormerModel",
"OneFormerPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_oneformer import ONEFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, OneFormerConfig
from .processing_oneformer import OneFormerProcessor
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_oneformer import OneFormerImageProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_oneformer import (
ONEFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
OneFormerForUniversalSegmentation,
OneFormerModel,
OneFormerPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
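# With the lazy module above in place, the new public classes can be imported from the top-level
# package (illustration only; assumes the torch and vision extras are installed):
from transformers import (
    OneFormerConfig,
    OneFormerForUniversalSegmentation,
    OneFormerImageProcessor,
    OneFormerModel,
    OneFormerProcessor,
)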
# coding=utf-8
# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OneFormer model configuration"""
import copy
from typing import Dict, Optional
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
ONEFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"shi-labs/oneformer_ade20k_swin_tiny": (
"https://huggingface.co/shi-labs/oneformer_ade20k_swin_tiny/blob/main/config.json"
),
# See all OneFormer models at https://huggingface.co/models?filter=oneformer
}
class OneFormerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`OneFormerModel`]. It is used to instantiate a
OneFormer model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the OneFormer
[shi-labs/oneformer_ade20k_swin_tiny](https://huggingface.co/shi-labs/oneformer_ade20k_swin_tiny) architecture
trained on [ADE20k-150](https://huggingface.co/datasets/scene_parse_150).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
backbone_config (`PretrainedConfig`, *optional*, defaults to `SwinConfig`)
The configuration of the backbone model.
ignore_value (`int`, *optional*, defaults to 255)
Value to be ignored in the GT labels while computing the loss.
num_queries (`int`, *optional*, defaults to 150)
Number of object queries.
no_object_weight (`float`, *optional*, defaults to 0.1)
Weight for no-object class predictions.
class_weight (`float`, *optional*, defaults to 2.0)
Weight for Classification CE loss.
mask_weight (`float`, *optional*, defaults to 5.0)
Weight for binary CE loss.
dice_weight (`float`, *optional*, defaults to 5.0)
Weight for dice loss.
contrastive_weight (`float`, *optional*, defaults to 0.5)
Weight for contrastive loss.
contrastive_temperature (`float`, *optional*, defaults to 0.07)
Initial value for scaling the contrastive logits.
train_num_points (`int`, *optional*, defaults to 12544)
Number of points to sample while calculating losses on mask predictions.
oversample_ratio (`float`, *optional*, defaults to 3.0)
Ratio to decide how many points to oversample.
importance_sample_ratio (`float`, *optional*, defaults to 0.75)
Ratio of points that are sampled via importance sampling.
init_std (`float`, *optional*, defaults to 0.02)
Standard deviation for normal initialization.
init_xavier_std (`float`, *optional*, defaults to 1.0)
Standard deviation for Xavier uniform initialization.
layer_norm_eps (`float`, *optional*, defaults to 1e-05)
Epsilon for layer normalization.
is_training (`bool`, *optional*, defaults to False)
Whether to run in training or inference mode.
use_auxiliary_loss (`bool`, *optional*, defaults to True)
Whether to calculate loss using intermediate predictions from transformer decoder.
output_auxiliary_logits (`bool`, *optional*, defaults to True)
Whether to return intermediate predictions from transformer decoder.
strides (`list`, *optional*, defaults to [4, 8, 16, 32])
List containing the strides for feature maps in the encoder.
task_seq_len (`int`, *optional*, defaults to 77)
Sequence length for tokenizing the task input.
max_seq_len (`int`, *optional*, defaults to 77)
Sequence length for tokenizing the text list input.
text_encoder_width (`int`, *optional*, defaults to 256)
Hidden size for text encoder.
text_encoder_context_length (`int`, *optional*, defaults to 77):
Input sequence length for text encoder.
text_encoder_num_layers (`int`, *optional*, defaults to 6)
Number of layers for transformer in text encoder.
text_encoder_vocab_size (`int`, *optional*, defaults to 49408)
Vocabulary size for tokenizer.
text_encoder_proj_layers (`int`, *optional*, defaults to 2)
Number of layers in the MLP used to project the text queries.
text_encoder_n_ctx (`int`, *optional*, defaults to 16)
Number of learnable text context queries.
conv_dim (`int`, *optional*, defaults to 256)
Feature map dimension to map outputs from the backbone.
mask_dim (`int`, *optional*, defaults to 256)
Dimension for feature maps in pixel decoder.
hidden_dim (`int`, *optional*, defaults to 256)
Dimension for hidden states in transformer decoder.
encoder_feedforward_dim (`int`, *optional*, defaults to 1024)
Dimension for FFN layer in pixel decoder.
norm (`str`, *optional*, defaults to `GN`)
Type of normalization.
encoder_layers (`int`, *optional*, defaults to 6)
Number of layers in pixel decoder.
decoder_layers (`int`, *optional*, defaults to 10)
Number of layers in transformer decoder.
use_task_norm (`bool`, *optional*, defaults to `True`)
Whether to normalize the task token.
num_attention_heads (`int`, *optional*, defaults to 8)
Number of attention heads in transformer layers in the pixel and transformer decoders.
dropout (`float`, *optional*, defaults to 0.1)
Dropout probability for pixel and transformer decoders.
dim_feedforward (`int`, *optional*, defaults to 2048)
Dimension for FFN layer in transformer decoder.
pre_norm (`bool`, *optional*, defaults to `False`)
Whether to normalize hidden states before attention layers in transformer decoder.
enforce_input_proj (`bool`, *optional*, defaults to `False`)
Whether to project hidden states in transformer decoder.
query_dec_layers (`int`, *optional*, defaults to 2)
Number of layers in query transformer.
common_stride (`int`, *optional*, defaults to 4)
Common stride used for features in pixel decoder.
Examples:
```python
>>> from transformers import OneFormerConfig, OneFormerModel
>>> # Initializing a OneFormer shi-labs/oneformer_ade20k_swin_tiny configuration
>>> configuration = OneFormerConfig()
>>> # Initializing a model (with random weights) from the shi-labs/oneformer_ade20k_swin_tiny style configuration
>>> model = OneFormerModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "oneformer"
attribute_map = {"hidden_size": "hidden_dim"}
def __init__(
self,
backbone_config: Optional[Dict] = None,
ignore_value: int = 255,
num_queries: int = 150,
no_object_weight: float = 0.1,
class_weight: float = 2.0,
mask_weight: float = 5.0,
dice_weight: float = 5.0,
contrastive_weight: float = 0.5,
contrastive_temperature: float = 0.07,
train_num_points: int = 12544,
oversample_ratio: float = 3.0,
importance_sample_ratio: float = 0.75,
init_std: float = 0.02,
init_xavier_std: float = 1.0,
layer_norm_eps: float = 1e-05,
is_training: bool = False,
use_auxiliary_loss: bool = True,
output_auxiliary_logits: bool = True,
strides: Optional[list] = [4, 8, 16, 32],
task_seq_len: int = 77,
max_seq_len: int = 77,
text_encoder_width: int = 256,
text_encoder_context_length: int = 77,
text_encoder_num_layers: int = 6,
text_encoder_vocab_size: int = 49408,
text_encoder_proj_layers: int = 2,
text_encoder_n_ctx: int = 16,
conv_dim: int = 256,
mask_dim: int = 256,
hidden_dim: int = 256,
encoder_feedforward_dim: int = 1024,
norm: str = "GN",
encoder_layers: int = 6,
decoder_layers: int = 10,
use_task_norm: bool = True,
num_attention_heads: int = 8,
dropout: float = 0.1,
dim_feedforward: int = 2048,
pre_norm: bool = False,
enforce_input_proj: bool = False,
query_dec_layers: int = 2,
common_stride: int = 4,
**kwargs,
):
if backbone_config is None:
logger.info("`backbone_config` is unset. Initializing the config with the default `Swin` backbone.")
backbone_config = CONFIG_MAPPING["swin"](
image_size=224,
in_channels=3,
patch_size=4,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
drop_path_rate=0.3,
use_absolute_embeddings=False,
out_features=["stage1", "stage2", "stage3", "stage4"],
)
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
self.backbone_config = backbone_config
self.ignore_value = ignore_value
self.num_queries = num_queries
self.no_object_weight = no_object_weight
self.class_weight = class_weight
self.mask_weight = mask_weight
self.dice_weight = dice_weight
self.contrastive_weight = contrastive_weight
self.contrastive_temperature = contrastive_temperature
self.train_num_points = train_num_points
self.oversample_ratio = oversample_ratio
self.importance_sample_ratio = importance_sample_ratio
self.init_std = init_std
self.init_xavier_std = init_xavier_std
self.layer_norm_eps = layer_norm_eps
self.is_training = is_training
self.use_auxiliary_loss = use_auxiliary_loss
self.output_auxiliary_logits = output_auxiliary_logits
self.strides = strides
self.task_seq_len = task_seq_len
self.max_seq_len = max_seq_len
self.text_encoder_width = text_encoder_width
self.text_encoder_context_length = text_encoder_context_length
self.text_encoder_num_layers = text_encoder_num_layers
self.text_encoder_vocab_size = text_encoder_vocab_size
self.text_encoder_proj_layers = text_encoder_proj_layers
self.text_encoder_n_ctx = text_encoder_n_ctx
self.conv_dim = conv_dim
self.mask_dim = mask_dim
self.hidden_dim = hidden_dim
self.encoder_feedforward_dim = encoder_feedforward_dim
self.norm = norm
self.encoder_layers = encoder_layers
self.decoder_layers = decoder_layers
self.use_task_norm = use_task_norm
self.num_attention_heads = num_attention_heads
self.dropout = dropout
self.dim_feedforward = dim_feedforward
self.pre_norm = pre_norm
self.enforce_input_proj = enforce_input_proj
self.query_dec_layers = query_dec_layers
self.common_stride = common_stride
self.num_hidden_layers = decoder_layers
super().__init__(**kwargs)
def to_dict(self) -> Dict[str, any]:
"""
Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
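# Quick illustration of the `to_dict` override above (illustration only, not part of the file):
# the nested backbone config is serialized to a plain dict, and a plain dict can be passed back
# to `OneFormerConfig`, which rebuilds the backbone config object in `__init__`.
from transformers import OneFormerConfig

_config = OneFormerConfig()
_config_dict = _config.to_dict()
assert isinstance(_config_dict["backbone_config"], dict) and _config_dict["model_type"] == "oneformer"
_rebuilt = OneFormerConfig(backbone_config=_config_dict["backbone_config"])
assert _rebuilt.backbone_config.model_type == "swin"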
# coding=utf-8
# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert OneFormer checkpoints from the original repository. URL: https://github.com/SHI-Labs/OneFormer"""
import os
import sys
from argparse import ArgumentParser
from dataclasses import dataclass
from pathlib import Path
from pprint import pformat
from typing import Any, Dict, Iterator, List, Set, Tuple
import torch
import torchvision.transforms as T
from PIL import Image
from torch import Tensor, nn
import requests
try:
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.projects.deeplab import add_deeplab_config
except ImportError:
pass
from transformers import CLIPTokenizer, DinatConfig, SwinConfig
from transformers.models.oneformer.image_processing_oneformer import OneFormerImageProcessor
from transformers.models.oneformer.modeling_oneformer import (
OneFormerConfig,
OneFormerForUniversalSegmentation,
OneFormerForUniversalSegmentationOutput,
OneFormerModel,
OneFormerModelOutput,
)
from transformers.models.oneformer.processing_oneformer import OneFormerProcessor
from transformers.utils import logging
StateDict = Dict[str, Tensor]
logging.set_verbosity_info()
logger = logging.get_logger()
torch.manual_seed(0)
class TrackedStateDict:
def __init__(self, to_track: Dict):
"""This class "tracks" a python dictionary by keeping track of which item is accessed.
Args:
to_track (Dict): The dictionary we wish to track
"""
self.to_track = to_track
self._seen: Set[str] = set()
def __getitem__(self, key: str) -> Any:
return self.to_track[key]
def __setitem__(self, key: str, item: Any):
self._seen.add(key)
self.to_track[key] = item
def diff(self) -> List[str]:
"""This method returns a set difference between the keys in the tracked state dict and the one we have access so far.
This is an effective method to check if we have update all the keys
Returns:
List[str]: List of keys not yet updated
"""
return set(list(self.to_track.keys())) - self._seen
def copy(self) -> Dict:
# proxy the call to the internal dictionary
return self.to_track.copy()
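# Minimal illustration of how the converter relies on this below (hypothetical values):
# only keys written through `__setitem__` are marked as seen, so `diff()` lists every
# destination key the conversion has not filled in yet.
_tracked = TrackedStateDict({"a": 1, "b": 2})
_tracked["a"] = 10
assert _tracked.diff() == {"b"}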
# Image to verify the result
def prepare_img():
url = "https://praeclarumjj3.github.io/files/coco.jpeg"
img_data = requests.get(url, stream=True).raw
im = Image.open(img_data)
return im
@dataclass
class Args:
"""Fake command line arguments needed by oneformer/detectron2 implementation"""
config_file: str
def setup_cfg(args: Args):
# load config from file and command-line arguments
cfg = get_cfg()
add_deeplab_config(cfg)
add_common_config(cfg)
add_oneformer_config(cfg)
add_swin_config(cfg)
add_dinat_config(cfg)
cfg.merge_from_file(args.config_file)
cfg.freeze()
return cfg
class OriginalOneFormerConfigToOursConverter:
def __call__(self, original_config: object, is_swin: bool) -> OneFormerConfig:
model = original_config.MODEL
dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST_PANOPTIC[0])
id2label = {idx: label for idx, label in enumerate(dataset_catalog.stuff_classes)}
label2id = {label: idx for idx, label in id2label.items()}
if is_swin:
if model.SWIN.EMBED_DIM == 96:
backbone_config = SwinConfig.from_pretrained(
"microsoft/swin-tiny-patch4-window7-224",
drop_path_rate=model.SWIN.DROP_PATH_RATE,
out_features=["stage1", "stage2", "stage3", "stage4"],
)
elif model.SWIN.EMBED_DIM == 192:
backbone_config = SwinConfig.from_pretrained(
"microsoft/swin-large-patch4-window12-384",
drop_path_rate=model.SWIN.DROP_PATH_RATE,
out_features=["stage1", "stage2", "stage3", "stage4"],
)
else:
raise ValueError(f"embed dim {model.SWIN.EMBED_DIM} not supported for Swin!")
else:
backbone_config = DinatConfig.from_pretrained(
"shi-labs/dinat-large-11x11-in22k-in1k-384",
dilations=model.DiNAT.DILATIONS,
kernel_size=model.DiNAT.KERNEL_SIZE,
out_features=["stage1", "stage2", "stage3", "stage4"],
)
config: OneFormerConfig = OneFormerConfig(
backbone_config=backbone_config,
output_attentions=True,
output_hidden_states=True,
return_dict=True,
ignore_value=model.SEM_SEG_HEAD.IGNORE_VALUE,
num_classes=model.SEM_SEG_HEAD.NUM_CLASSES,
num_queries=model.ONE_FORMER.NUM_OBJECT_QUERIES,
no_object_weight=model.ONE_FORMER.NO_OBJECT_WEIGHT,
class_weight=model.ONE_FORMER.CLASS_WEIGHT,
mask_weight=model.ONE_FORMER.MASK_WEIGHT,
dice_weight=model.ONE_FORMER.DICE_WEIGHT,
contrastive_weight=model.ONE_FORMER.CONTRASTIVE_WEIGHT,
contrastive_temperature=model.ONE_FORMER.CONTRASTIVE_TEMPERATURE,
train_num_points=model.ONE_FORMER.TRAIN_NUM_POINTS,
oversample_ratio=model.ONE_FORMER.OVERSAMPLE_RATIO,
importance_sample_ratio=model.ONE_FORMER.IMPORTANCE_SAMPLE_RATIO,
init_std=0.02,
init_xavier_std=1.0,
layer_norm_eps=1e-05,
is_training=False,
use_auxiliary_loss=model.ONE_FORMER.DEEP_SUPERVISION,
output_auxiliary_logits=True,
strides=[4, 8, 16, 32],
task_seq_len=original_config.INPUT.TASK_SEQ_LEN,
max_seq_len=original_config.INPUT.MAX_SEQ_LEN,
text_encoder_width=model.TEXT_ENCODER.WIDTH,
text_encoder_context_length=model.TEXT_ENCODER.CONTEXT_LENGTH,
text_encoder_num_layers=model.TEXT_ENCODER.NUM_LAYERS,
text_encoder_vocab_size=model.TEXT_ENCODER.VOCAB_SIZE,
text_encoder_proj_layers=model.TEXT_ENCODER.PROJ_NUM_LAYERS,
text_encoder_n_ctx=model.TEXT_ENCODER.N_CTX,
conv_dim=model.SEM_SEG_HEAD.CONVS_DIM,
mask_dim=model.SEM_SEG_HEAD.MASK_DIM,
hidden_dim=model.ONE_FORMER.HIDDEN_DIM,
norm=model.SEM_SEG_HEAD.NORM,
encoder_layers=model.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS,
encoder_feedforward_dim=1024,
decoder_layers=model.ONE_FORMER.DEC_LAYERS,
use_task_norm=model.ONE_FORMER.USE_TASK_NORM,
num_attention_heads=model.ONE_FORMER.NHEADS,
dropout=model.ONE_FORMER.DROPOUT,
dim_feedforward=model.ONE_FORMER.DIM_FEEDFORWARD,
pre_norm=model.ONE_FORMER.PRE_NORM,
enforce_input_proj=model.ONE_FORMER.ENFORCE_INPUT_PROJ,
query_dec_layers=model.ONE_FORMER.CLASS_DEC_LAYERS,
common_stride=model.SEM_SEG_HEAD.COMMON_STRIDE,
id2label=id2label,
label2id=label2id,
)
return config
class OriginalOneFormerConfigToProcessorConverter:
def __call__(self, original_config: object, model_repo: str) -> OneFormerProcessor:
model = original_config.MODEL
model_input = original_config.INPUT
dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST_PANOPTIC[0])
if "ade20k" in model_repo:
class_info_file = "ade20k_panoptic.json"
elif "coco" in model_repo:
class_info_file = "coco_panoptic.json"
elif "cityscapes" in model_repo:
class_info_file = "cityscapes_panoptic.json"
else:
raise ValueError("Invalid Dataset!")
image_processor = OneFormerImageProcessor(
image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(),
image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(),
size=model_input.MIN_SIZE_TEST,
max_size=model_input.MAX_SIZE_TEST,
num_labels=model.SEM_SEG_HEAD.NUM_CLASSES,
ignore_index=dataset_catalog.ignore_label,
class_info_file=class_info_file,
)
tokenizer = CLIPTokenizer.from_pretrained(model_repo)
return OneFormerProcessor(
image_processor=image_processor,
tokenizer=tokenizer,
task_seq_length=original_config.INPUT.TASK_SEQ_LEN,
max_seq_length=original_config.INPUT.MAX_SEQ_LEN,
)
class OriginalOneFormerCheckpointToOursConverter:
def __init__(self, original_model: nn.Module, config: OneFormerConfig):
self.original_model = original_model
self.config = config
def pop_all(self, renamed_keys: List[Tuple[str, str]], dst_state_dict: StateDict, src_state_dict: StateDict):
for src_key, dst_key in renamed_keys:
dst_state_dict[dst_key] = src_state_dict.pop(src_key)
# Swin Backbone
def replace_swin_backbone(self, dst_state_dict: StateDict, src_state_dict: StateDict, config: OneFormerConfig):
dst_prefix: str = "pixel_level_module.encoder"
src_prefix: str = "backbone"
renamed_keys = [
(
f"{src_prefix}.patch_embed.proj.weight",
f"{dst_prefix}.embeddings.patch_embeddings.projection.weight",
),
(f"{src_prefix}.patch_embed.proj.bias", f"{dst_prefix}.embeddings.patch_embeddings.projection.bias"),
(f"{src_prefix}.patch_embed.norm.weight", f"{dst_prefix}.embeddings.norm.weight"),
(f"{src_prefix}.patch_embed.norm.bias", f"{dst_prefix}.embeddings.norm.bias"),
]
num_layers = len(config.backbone_config.depths)
for layer_idx in range(num_layers):
for block_idx in range(config.backbone_config.depths[layer_idx]):
renamed_keys.extend(
[ # src, dst
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.weight",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.bias",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.bias",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_bias_table",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_bias_table",
),
]
)
# now we need to handle the attentions
# read in weights + bias of input projection layer of cross-attention
src_att_weight = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight"]
src_att_bias = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias"]
size = src_att_weight.shape[0]
offset = size // 3
dst_state_dict[
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.weight"
] = src_att_weight[:offset, :]
dst_state_dict[
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.bias"
] = src_att_bias[:offset]
dst_state_dict[
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.weight"
] = src_att_weight[offset : offset * 2, :]
dst_state_dict[
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.bias"
] = src_att_bias[offset : offset * 2]
dst_state_dict[
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.weight"
] = src_att_weight[-offset:, :]
dst_state_dict[
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.bias"
] = src_att_bias[-offset:]
# let's pop them
src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight")
src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias")
# proj
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.weight",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.bias",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.bias",
),
]
)
# second norm
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.weight",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.bias",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.bias",
),
]
)
# mlp
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.weight",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.bias",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.bias",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.weight",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.bias",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.bias",
),
]
)
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_index",
f"{dst_prefix}.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_index",
)
]
)
if layer_idx < num_layers - 1:
# patch merging
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.downsample.reduction.weight",
f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.reduction.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.downsample.norm.weight",
f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.norm.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.downsample.norm.bias",
f"{dst_prefix}.encoder.layers.{layer_idx}.downsample.norm.bias",
),
]
)
# hidden states norms
renamed_keys.extend(
[
(
f"{src_prefix}.norm{layer_idx}.weight",
f"{dst_prefix}.hidden_states_norms.stage{layer_idx+1}.weight",
),
(
f"{src_prefix}.norm{layer_idx}.bias",
f"{dst_prefix}.hidden_states_norms.stage{layer_idx+1}.bias",
),
]
)
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
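# Sketch of the fused-qkv split performed in the loop above (illustration only): the original
# checkpoints stack the query, key and value projections along dim 0, so each projection is a
# contiguous third of the fused weight (and bias).
_fused_weight = torch.randn(3 * 96, 96)  # hypothetical stage width of 96 (Swin-tiny)
_offset = _fused_weight.shape[0] // 3
_q, _k, _v = _fused_weight[:_offset], _fused_weight[_offset : 2 * _offset], _fused_weight[-_offset:]
assert torch.equal(torch.cat([_q, _k, _v], dim=0), _fused_weight)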
# Dinat Backbone
def replace_dinat_backbone(self, dst_state_dict: StateDict, src_state_dict: StateDict, config: OneFormerConfig):
dst_prefix: str = "pixel_level_module.encoder"
src_prefix: str = "backbone"
def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
return [
(f"{src_prefix}.weight", f"{dst_prefix}.weight"),
(f"{src_prefix}.bias", f"{dst_prefix}.bias"),
]
renamed_keys = rename_keys_for_weight_bias(f"{src_prefix}.patch_embed.norm", f"{dst_prefix}.embeddings.norm")
for i in range(2):
renamed_keys.extend(
rename_keys_for_weight_bias(
f"{src_prefix}.patch_embed.proj.{i}",
f"{dst_prefix}.embeddings.patch_embeddings.projection.{i}",
)
)
num_layers = len(config.backbone_config.depths)
for layer_idx in range(num_layers):
for block_idx in range(config.backbone_config.depths[layer_idx]):
renamed_keys.extend(
rename_keys_for_weight_bias(
f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.norm1",
f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.layernorm_before",
)
)
renamed_keys.extend(
rename_keys_for_weight_bias(
f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.norm2",
f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.layernorm_after",
)
)
renamed_keys.extend(
[ # src, dst
(
f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.rpb",
f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.rpb",
),
]
)
# now we need to handle the attentions
# read in weights + bias of input projection layer of cross-attention
src_att_weight = src_state_dict[f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.weight"]
src_att_bias = src_state_dict[f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.bias"]
size = src_att_weight.shape[0]
offset = size // 3
dst_state_dict[
f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.query.weight"
] = src_att_weight[:offset, :]
dst_state_dict[
f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.query.bias"
] = src_att_bias[:offset]
dst_state_dict[
f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.key.weight"
] = src_att_weight[offset : offset * 2, :]
dst_state_dict[
f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.key.bias"
] = src_att_bias[offset : offset * 2]
dst_state_dict[
f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.value.weight"
] = src_att_weight[-offset:, :]
dst_state_dict[
f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.self.value.bias"
] = src_att_bias[-offset:]
# let's pop them
src_state_dict.pop(f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.weight")
src_state_dict.pop(f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.qkv.bias")
# proj
renamed_keys.extend(
rename_keys_for_weight_bias(
f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.attn.proj",
f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.attention.output.dense",
)
)
# mlp
renamed_keys.extend(
rename_keys_for_weight_bias(
f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.mlp.fc1",
f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.intermediate.dense",
)
)
renamed_keys.extend(
rename_keys_for_weight_bias(
f"{src_prefix}.levels.{layer_idx}.blocks.{block_idx}.mlp.fc2",
f"{dst_prefix}.encoder.levels.{layer_idx}.layers.{block_idx}.output.dense",
)
)
if layer_idx < num_layers - 1:
# patch merging
renamed_keys.extend(
[
(
f"{src_prefix}.levels.{layer_idx}.downsample.reduction.weight",
f"{dst_prefix}.encoder.levels.{layer_idx}.downsample.reduction.weight",
),
(
f"{src_prefix}.levels.{layer_idx}.downsample.norm.weight",
f"{dst_prefix}.encoder.levels.{layer_idx}.downsample.norm.weight",
),
(
f"{src_prefix}.levels.{layer_idx}.downsample.norm.bias",
f"{dst_prefix}.encoder.levels.{layer_idx}.downsample.norm.bias",
),
]
)
# hidden states norms
renamed_keys.extend(
[
(
f"{src_prefix}.norm{layer_idx}.weight",
f"{dst_prefix}.hidden_states_norms.stage{layer_idx+1}.weight",
),
(
f"{src_prefix}.norm{layer_idx}.bias",
f"{dst_prefix}.hidden_states_norms.stage{layer_idx+1}.bias",
),
]
)
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
# Backbone + Pixel Decoder
def replace_pixel_module(self, dst_state_dict: StateDict, src_state_dict: StateDict, is_swin: bool):
dst_prefix: str = "pixel_level_module.decoder"
src_prefix: str = "sem_seg_head.pixel_decoder"
if is_swin:
self.replace_swin_backbone(dst_state_dict, src_state_dict, self.config)
else:
self.replace_dinat_backbone(dst_state_dict, src_state_dict, self.config)
def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
return [
(f"{src_prefix}.weight", f"{dst_prefix}.weight"),
(f"{src_prefix}.bias", f"{dst_prefix}.bias"),
]
def rename_keys_for_self_attn(src_prefix: str, dst_prefix: str):
self_attn_keys = []
self_attn_keys.extend(
rename_keys_for_weight_bias(f"{src_prefix}.attention_weights", f"{dst_prefix}.attention_weights")
)
self_attn_keys.extend(
rename_keys_for_weight_bias(f"{src_prefix}.output_proj", f"{dst_prefix}.output_proj")
)
self_attn_keys.extend(
rename_keys_for_weight_bias(f"{src_prefix}.sampling_offsets", f"{dst_prefix}.sampling_offsets")
)
self_attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.value_proj", f"{dst_prefix}.value_proj"))
return self_attn_keys
def rename_keys_for_encoder_layer(src_prefix: str, dst_prefix: str):
encoder_keys = []
encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.fc1"))
encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.fc2"))
encoder_keys.extend(
rename_keys_for_weight_bias(f"{src_prefix}.norm1", f"{dst_prefix}.self_attn_layer_norm")
)
encoder_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm2", f"{dst_prefix}.final_layer_norm"))
encoder_keys.extend(rename_keys_for_self_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn"))
return encoder_keys
# convolution layer for final features
renamed_keys = [
(f"{src_prefix}.adapter_1.weight", f"{dst_prefix}.adapter_1.0.weight"),
(f"{src_prefix}.adapter_1.norm.weight", f"{dst_prefix}.adapter_1.1.weight"),
(f"{src_prefix}.adapter_1.norm.bias", f"{dst_prefix}.adapter_1.1.bias"),
]
renamed_keys.extend(
[
(f"{src_prefix}.layer_1.weight", f"{dst_prefix}.layer_1.0.weight"),
(f"{src_prefix}.layer_1.norm.weight", f"{dst_prefix}.layer_1.1.weight"),
(f"{src_prefix}.layer_1.norm.bias", f"{dst_prefix}.layer_1.1.bias"),
]
)
# proj layers
for i in range(3):
for j in range(2):
renamed_keys.extend(
[
(f"{src_prefix}.input_proj.{i}.{j}.weight", f"{dst_prefix}.input_projections.{i}.{j}.weight"),
(f"{src_prefix}.input_proj.{i}.{j}.bias", f"{dst_prefix}.input_projections.{i}.{j}.bias"),
]
)
renamed_keys.extend([(f"{src_prefix}.transformer.level_embed", f"{dst_prefix}.level_embed")])
# layers
for layer_idx in range(self.config.encoder_layers):
renamed_keys.extend(
rename_keys_for_encoder_layer(
f"{src_prefix}.transformer.encoder.layers.{layer_idx}", f"{dst_prefix}.encoder.layers.{layer_idx}"
)
)
# proj
renamed_keys.extend(
[
(f"{src_prefix}.mask_features.weight", f"{dst_prefix}.mask_projection.weight"),
(f"{src_prefix}.mask_features.bias", f"{dst_prefix}.mask_projection.bias"),
]
)
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
# Transformer Decoder
def replace_keys_qkv_transformer_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "transformer_module.decoder.layers"
src_prefix: str = "sem_seg_head.predictor"
for i in range(self.config.decoder_layers - 1):
# read in weights + bias of input projection layer of self-attention
in_proj_weight = src_state_dict.pop(
f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.in_proj_weight"
)
in_proj_bias = src_state_dict.pop(
f"{src_prefix}.transformer_self_attention_layers.{i}.self_attn.in_proj_bias"
)
# next, add query, keys and values (in that order) to the state dict
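# (the fused in_proj tensors have shape (3 * hidden_dim, hidden_dim); the hard-coded 256 below
# corresponds to the hidden_dim used by the released OneFormer checkpoints)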
dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.q_proj.bias"] = in_proj_bias[:256]
dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.k_proj.bias"] = in_proj_bias[256:512]
dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
dst_state_dict[f"{dst_prefix}.{i}.self_attn.self_attn.v_proj.bias"] = in_proj_bias[-256:]
def replace_transformer_module(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "transformer_module"
src_prefix: str = "sem_seg_head.predictor"
def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
return [
(f"{src_prefix}.weight", f"{dst_prefix}.weight"),
(f"{src_prefix}.bias", f"{dst_prefix}.bias"),
]
def rename_keys_for_attn(src_prefix: str, dst_prefix: str):
attn_keys = [
(f"{src_prefix}.in_proj_bias", f"{dst_prefix}.in_proj_bias"),
(f"{src_prefix}.in_proj_weight", f"{dst_prefix}.in_proj_weight"),
]
attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.out_proj", f"{dst_prefix}.out_proj"))
return attn_keys
def rename_keys_for_self_attn(src_prefix: str, dst_prefix: str):
attn_keys = []
attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.out_proj", f"{dst_prefix}.out_proj"))
return attn_keys
def rename_keys_for_query_transformer_layer(src_prefix: str, dst_prefix: str):
query_transformer_layer_keys = []
query_transformer_layer_keys.extend(
rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.linear1")
)
query_transformer_layer_keys.extend(
rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.linear2")
)
query_transformer_layer_keys.extend(
rename_keys_for_weight_bias(f"{src_prefix}.norm1", f"{dst_prefix}.norm1")
)
query_transformer_layer_keys.extend(
rename_keys_for_weight_bias(f"{src_prefix}.norm2", f"{dst_prefix}.norm2")
)
query_transformer_layer_keys.extend(
rename_keys_for_weight_bias(f"{src_prefix}.norm3", f"{dst_prefix}.norm3")
)
query_transformer_layer_keys.extend(
rename_keys_for_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn")
)
query_transformer_layer_keys.extend(
rename_keys_for_attn(f"{src_prefix}.multihead_attn", f"{dst_prefix}.multihead_attn")
)
return query_transformer_layer_keys
def rename_keys_for_cross_attn_layer(src_prefix: str, dst_prefix: str):
cross_attn_layer_keys = []
cross_attn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm"))
cross_attn_layer_keys.extend(
rename_keys_for_attn(f"{src_prefix}.multihead_attn", f"{dst_prefix}.multihead_attn")
)
return cross_attn_layer_keys
def rename_keys_for_self_attn_layer(src_prefix: str, dst_prefix: str):
self_attn_layer_keys = []
self_attn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm"))
self_attn_layer_keys.extend(
rename_keys_for_self_attn(f"{src_prefix}.self_attn", f"{dst_prefix}.self_attn")
)
return self_attn_layer_keys
def rename_keys_for_ffn_layer(src_prefix: str, dst_prefix: str):
ffn_layer_keys = []
ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear1", f"{dst_prefix}.linear1"))
ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.linear2", f"{dst_prefix}.linear2"))
ffn_layer_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.norm", f"{dst_prefix}.norm"))
return ffn_layer_keys
def rename_keys_for_transformer_decoder_layer(src_prefix: str, dst_prefix: str, idx: int):
transformer_decoder_layer_keys = []
transformer_decoder_layer_keys.extend(
rename_keys_for_cross_attn_layer(
f"{src_prefix}.transformer_cross_attention_layers.{idx}", f"{dst_prefix}.{idx}.cross_attn"
)
)
transformer_decoder_layer_keys.extend(
rename_keys_for_self_attn_layer(
f"{src_prefix}.transformer_self_attention_layers.{idx}", f"{dst_prefix}.{idx}.self_attn"
)
)
transformer_decoder_layer_keys.extend(
rename_keys_for_ffn_layer(f"{src_prefix}.transformer_ffn_layers.{idx}", f"{dst_prefix}.{idx}.ffn")
)
return transformer_decoder_layer_keys
# positional embedding for object queries
renamed_keys = [
(f"{src_prefix}.query_embed.weight", f"{dst_prefix}.queries_embedder.weight"),
(f"{src_prefix}.level_embed.weight", f"{dst_prefix}.level_embed.weight"),
]
# norm
renamed_keys.extend(
rename_keys_for_weight_bias(f"{src_prefix}.decoder_norm", f"{dst_prefix}.decoder.decoder_norm")
)
# proj
renamed_keys.extend(
rename_keys_for_weight_bias(
f"{src_prefix}.class_input_proj", f"{dst_prefix}.decoder.query_input_projection"
)
)
renamed_keys.extend(
rename_keys_for_weight_bias(f"{src_prefix}.class_embed", f"{dst_prefix}.decoder.class_embed")
)
for i in range(3):
renamed_keys.extend(
rename_keys_for_weight_bias(
f"{src_prefix}.mask_embed.layers.{i}", f"{dst_prefix}.decoder.mask_embed.layers.{i}.0"
)
)
# norm
renamed_keys.extend(
rename_keys_for_weight_bias(
f"{src_prefix}.class_transformer.decoder.norm", f"{dst_prefix}.decoder.query_transformer.decoder.norm"
)
)
# transformer to update queries with task tokens
for i in range(self.config.query_dec_layers):
renamed_keys.extend(
rename_keys_for_query_transformer_layer(
f"{src_prefix}.class_transformer.decoder.layers.{i}",
f"{dst_prefix}.decoder.query_transformer.decoder.layers.{i}",
)
)
# decoder layers
for i in range(self.config.decoder_layers - 1):
renamed_keys.extend(
rename_keys_for_transformer_decoder_layer(
f"{src_prefix}",
f"{dst_prefix}.decoder.layers",
i,
)
)
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
self.replace_keys_qkv_transformer_decoder(dst_state_dict, src_state_dict)
def replace_task_mlp(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "task_encoder"
src_prefix: str = "task_mlp"
def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
return [
(f"{src_prefix}.weight", f"{dst_prefix}.weight"),
(f"{src_prefix}.bias", f"{dst_prefix}.bias"),
]
renamed_keys = []
for i in range(2):
renamed_keys.extend(
rename_keys_for_weight_bias(f"{src_prefix}.layers.{i}", f"{dst_prefix}.task_mlp.layers.{i}.0")
)
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
def replace_text_projector(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "text_mapper.text_projector"
src_prefix: str = "text_projector"
def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
return [
(f"{src_prefix}.weight", f"{dst_prefix}.weight"),
(f"{src_prefix}.bias", f"{dst_prefix}.bias"),
]
renamed_keys = []
for i in range(self.config.text_encoder_config["text_encoder_proj_layers"]):
renamed_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.layers.{i}", f"{dst_prefix}.{i}.0"))
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
def replace_text_mapper(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "text_mapper.text_encoder"
src_prefix: str = "text_encoder"
self.replace_text_projector(dst_state_dict, src_state_dict)
def rename_keys_for_weight_bias(src_prefix: str, dst_prefix: str):
return [
(f"{src_prefix}.weight", f"{dst_prefix}.weight"),
(f"{src_prefix}.bias", f"{dst_prefix}.bias"),
]
def rename_keys_for_attn(src_prefix: str, dst_prefix: str):
attn_keys = [
(f"{src_prefix}.in_proj_bias", f"{dst_prefix}.in_proj_bias"),
(f"{src_prefix}.in_proj_weight", f"{dst_prefix}.in_proj_weight"),
]
attn_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.out_proj", f"{dst_prefix}.out_proj"))
return attn_keys
def rename_keys_for_layer(src_prefix: str, dst_prefix: str):
resblock_keys = []
resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.mlp.c_fc", f"{dst_prefix}.mlp.fc1"))
resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.mlp.c_proj", f"{dst_prefix}.mlp.fc2"))
resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.ln_1", f"{dst_prefix}.layer_norm1"))
resblock_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.ln_2", f"{dst_prefix}.layer_norm2"))
resblock_keys.extend(rename_keys_for_attn(f"{src_prefix}.attn", f"{dst_prefix}.self_attn"))
return resblock_keys
renamed_keys = [
("prompt_ctx.weight", "text_mapper.prompt_ctx.weight"),
]
renamed_keys.extend(
[
(f"{src_prefix}.positional_embedding", f"{dst_prefix}.positional_embedding"),
(f"{src_prefix}.token_embedding.weight", f"{dst_prefix}.token_embedding.weight"),
]
)
renamed_keys.extend(rename_keys_for_weight_bias(f"{src_prefix}.ln_final", f"{dst_prefix}.ln_final"))
for i in range(self.config.text_encoder_config["text_encoder_num_layers"]):
renamed_keys.extend(
rename_keys_for_layer(
f"{src_prefix}.transformer.resblocks.{i}", f"{dst_prefix}.transformer.layers.{i}"
)
)
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
def convert(self, oneformer: OneFormerModel, is_swin: bool) -> OneFormerModel:
dst_state_dict = TrackedStateDict(oneformer.state_dict())
src_state_dict = self.original_model.state_dict()
self.replace_pixel_module(dst_state_dict, src_state_dict, is_swin)
self.replace_transformer_module(dst_state_dict, src_state_dict)
self.replace_task_mlp(dst_state_dict, src_state_dict)
if self.config.is_training:
self.replace_text_mapper(dst_state_dict, src_state_dict)
logger.info(f"Missed keys are {pformat(dst_state_dict.diff())}")
logger.info(f"Not copied keys are {pformat(src_state_dict.keys())}")
logger.info("🙌 Done")
oneformer.load_state_dict(dst_state_dict)
return oneformer
@staticmethod
def using_dirs(checkpoints_dir: Path, config_dir: Path) -> Iterator[Tuple[Path, Path]]:
checkpoints: List[Path] = checkpoints_dir.glob("**/*.pth")
for checkpoint in checkpoints:
logger.info(f"💪 Converting {checkpoint.stem}")
# find associated config file
config: Path = config_dir / f"{checkpoint.stem}.yaml"
yield config, checkpoint
def post_process_sem_seg_output(outputs: OneFormerForUniversalSegmentationOutput, target_size: Tuple[int, int]):
# class_queries_logits has shape [BATCH, QUERIES, CLASSES + 1]
class_queries_logits = outputs.class_queries_logits
# masks_queries_logits has shape [BATCH, QUERIES, HEIGHT, WIDTH]
masks_queries_logits = outputs.masks_queries_logits
if target_size is not None:
masks_queries_logits = torch.nn.functional.interpolate(
masks_queries_logits,
size=target_size,
mode="bilinear",
align_corners=False,
)
# remove the null class `[..., :-1]`
masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
# mask probs has shape [BATCH, QUERIES, HEIGHT, WIDTH]
masks_probs = masks_queries_logits.sigmoid()
# now we want to sum over the queries,
# $ out_{c,h,w} = \sum_q p_{q,c} * m_{q,h,w} $
# where $ softmax(p) \in R^{q, c} $ is the mask classes
# and $ sigmoid(m) \in R^{q, h, w}$ is the mask probabilities
# b(atch)q(uery)c(lasses), b(atch)q(uery)h(eight)w(idth)
segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
return segmentation
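# Shape sanity check for the einsum above (random tensors, hypothetical sizes):
# [batch, queries, classes] x [batch, queries, height, width] -> [batch, classes, height, width]
_masks_classes = torch.rand(1, 5, 3)
_masks_probs = torch.rand(1, 5, 16, 16)
assert torch.einsum("bqc, bqhw -> bchw", _masks_classes, _masks_probs).shape == (1, 3, 16, 16)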
def test(
original_model,
our_model: OneFormerForUniversalSegmentation,
processor: OneFormerProcessor,
model_repo: str,
):
def _preprocess_text(text_list=None, max_length=77):
if text_list is None:
raise ValueError("tokens cannot be None.")
tokens = tokenizer(text_list, padding="max_length", max_length=max_length, truncation=True)
attention_masks, input_ids = tokens["attention_mask"], tokens["input_ids"]
token_inputs = []
for attn_mask, input_id in zip(attention_masks, input_ids):
token = torch.tensor(attn_mask) * torch.tensor(input_id)
token_inputs.append(token.unsqueeze(0))
token_inputs = torch.cat(token_inputs, dim=0)
return token_inputs
with torch.no_grad():
tokenizer = CLIPTokenizer.from_pretrained(model_repo)
original_model = original_model.eval()
our_model = our_model.eval()
im = prepare_img()
tr = T.Compose(
[
T.Resize((640, 640)),
T.ToTensor(),
T.Normalize(
mean=torch.tensor([123.675, 116.280, 103.530]) / 255.0,
std=torch.tensor([58.395, 57.120, 57.375]) / 255.0,
),
],
)
x = tr(im).unsqueeze(0)
task_input = ["the task is semantic"]
task_token = _preprocess_text(task_input, max_length=processor.task_seq_length)
original_model_backbone_features = original_model.backbone(x.clone())
our_model_output: OneFormerModelOutput = our_model.model(x.clone(), task_token, output_hidden_states=True)
for original_model_feature, our_model_feature in zip(
original_model_backbone_features.values(), our_model_output.encoder_hidden_states
):
assert torch.allclose(
original_model_feature, our_model_feature, atol=3e-3
), "The backbone features are not the same."
mask_features, _, multi_scale_features, _, _ = original_model.sem_seg_head.pixel_decoder.forward_features(
original_model_backbone_features
)
original_pixel_decoder_features = []
original_pixel_decoder_features.append(mask_features)
for i in range(len(multi_scale_features)):
original_pixel_decoder_features.append(multi_scale_features[i])
for original_model_feature, our_model_feature in zip(
original_pixel_decoder_features, our_model_output.pixel_decoder_hidden_states
):
assert torch.allclose(
original_model_feature, our_model_feature, atol=3e-4
), "The pixel decoder feature are not the same"
tr_complete = T.Compose(
[
T.Resize((640, 640)),
T.ToTensor(),
],
)
y = (tr_complete(im) * 255.0).to(torch.int).float()
# let's test the full model
original_model_out = original_model([{"image": y.clone(), "task": "The task is semantic"}])
original_segmentation = original_model_out[0]["sem_seg"]
our_model_out: OneFormerForUniversalSegmentationOutput = our_model(
x.clone(), task_token, output_hidden_states=True
)
our_segmentation = post_process_sem_seg_output(our_model_out, target_size=(640, 640))[0]
assert torch.allclose(
original_segmentation, our_segmentation, atol=1e-3
), "The segmentation image is not the same."
logger.info("✅ Test passed!")
def get_name(checkpoint_file: Path):
model_name_raw: str = checkpoint_file.stem
backbone = "swin" if "swin" in model_name_raw else "dinat"
dataset = ""
if "coco" in model_name_raw:
dataset = "coco"
elif "ade20k" in model_name_raw:
dataset = "ade20k"
elif "cityscapes" in model_name_raw:
dataset = "cityscapes"
else:
raise ValueError(
f"{model_name_raw} must be wrong since we didn't find 'coco' or 'ade20k' or 'cityscapes' in it "
)
backbone_types = ["tiny", "large"]
backbone_type = list(filter(lambda x: x in model_name_raw, backbone_types))[0]
model_name = f"oneformer_{dataset}_{backbone}_{backbone_type}"
return model_name
if __name__ == "__main__":
parser = ArgumentParser(
description=(
"Command line to convert the original oneformer models (with swin backbone) to transformers"
" implementation."
)
)
parser.add_argument(
"--checkpoints_dir",
type=Path,
help=(
"A directory containing the model's checkpoints. The directory has to have the following structure:"
" structure: <DIR_NAME>/<DATASET_NAME>/<CONFIG_NAME>.pth; where <CONFIG_NAME> name must follow the"
" following nomenclature nomenclature: oneformer_<DATASET_NAME>_<BACKBONE>_<BACKBONE_TYPE>"
),
)
parser.add_argument(
"--configs_dir",
type=Path,
help=(
"A directory containing the model's configs, see detectron2 doc. The directory has to have the following"
" structure: <DIR_NAME>/<DATASET_NAME>/<CONFIG_NAME>.yaml; where <CONFIG_NAME> name must follow the"
" following nomenclature nomenclature: oneformer_<DATASET_NAME>_<BACKBONE>_<BACKBONE_TYPE>"
),
)
parser.add_argument(
"--pytorch_dump_folder_path",
required=True,
type=Path,
help="Path to the folder to output PyTorch models.",
)
parser.add_argument(
"--oneformer_dir",
required=True,
type=Path,
help=(
"A path to OneFormer's original implementation directory. You can download from here:"
"https://github.com/SHI-Labs/OneFormer"
),
)
args = parser.parse_args()
checkpoints_dir: Path = args.checkpoints_dir
config_dir: Path = args.configs_dir
save_directory: Path = args.pytorch_dump_folder_path
oneformer_dir: Path = args.oneformer_dir
# append the path to the parents to oneformer dir
sys.path.append(str(oneformer_dir.parent))
# and import what's needed
from OneFormer.oneformer import add_common_config, add_dinat_config, add_oneformer_config, add_swin_config
from OneFormer.oneformer.oneformer_model import OneFormer as OriginalOneFormer
if not save_directory.exists():
save_directory.mkdir(parents=True)
for config_file, checkpoint_file in OriginalOneFormerCheckpointToOursConverter.using_dirs(
checkpoints_dir, config_dir
):
processor = OriginalOneFormerConfigToProcessorConverter()(
setup_cfg(Args(config_file=config_file)), os.path.join("shi-labs", config_file.stem)
)
original_config = setup_cfg(Args(config_file=config_file))
oneformer_kwargs = OriginalOneFormer.from_config(original_config)
original_model = OriginalOneFormer(**oneformer_kwargs).eval()
DetectionCheckpointer(original_model).load(str(checkpoint_file))
is_swin = "swin" in config_file.stem
config: OneFormerConfig = OriginalOneFormerConfigToOursConverter()(original_config, is_swin)
oneformer = OneFormerModel(config=config).eval()
converter = OriginalOneFormerCheckpointToOursConverter(original_model, config)
oneformer = converter.convert(oneformer, is_swin)
oneformer_for_universal_segmentation = OneFormerForUniversalSegmentation(config=config).eval()
oneformer_for_universal_segmentation.model = oneformer
test(
original_model,
oneformer_for_universal_segmentation,
processor,
os.path.join("shi-labs", config_file.stem),
)
model_name = get_name(checkpoint_file)
logger.info(f"🪄 Saving {model_name}")
processor.save_pretrained(save_directory / model_name)
oneformer_for_universal_segmentation.save_pretrained(save_directory / model_name)
processor.push_to_hub(
repo_id=os.path.join("shi-labs", config_file.stem),
commit_message="Add configs",
use_temp_dir=True,
)
oneformer_for_universal_segmentation.push_to_hub(
repo_id=os.path.join("shi-labs", config_file.stem),
commit_message="Add model",
use_temp_dir=True,
)
# coding=utf-8
# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for OneFormer."""
import json
import warnings
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import numpy as np
from huggingface_hub import hf_hub_download
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_transforms import (
PaddingMode,
get_resize_output_image_size,
normalize,
pad,
rescale,
resize,
to_channel_dimension_format,
to_numpy_array,
)
from transformers.image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_batched,
valid_images,
)
from transformers.utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
TensorType,
is_torch_available,
is_torch_tensor,
logging,
)
logger = logging.get_logger(__name__)
if is_torch_available():
import torch
from torch import nn
# Copied from transformers.models.detr.image_processing_detr.max_across_indices
def max_across_indices(values: Iterable[Any]) -> List[Any]:
"""
Return the maximum value across all indices of an iterable of values.
"""
return [max(values_i) for values_i in zip(*values)]
# Copied from transformers.models.detr.image_processing_detr.get_max_height_width
def get_max_height_width(images: List[np.ndarray]) -> List[int]:
"""
Get the maximum height and width across all images in a batch.
"""
input_channel_dimension = infer_channel_dimension_format(images[0])
if input_channel_dimension == ChannelDimension.FIRST:
_, max_height, max_width = max_across_indices([img.shape for img in images])
elif input_channel_dimension == ChannelDimension.LAST:
max_height, max_width, _ = max_across_indices([img.shape for img in images])
else:
raise ValueError(f"Invalid channel dimension format: {input_channel_dimension}")
return (max_height, max_width)
# Copied from transformers.models.detr.image_processing_detr.make_pixel_mask
def make_pixel_mask(image: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
"""
Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
Args:
image (`np.ndarray`):
Image to make the pixel mask for.
output_size (`Tuple[int, int]`):
Output size of the mask.
"""
input_height, input_width = get_image_size(image)
mask = np.zeros(output_size, dtype=np.int64)
mask[:input_height, :input_width] = 1
return mask
# Copied from transformers.models.detr.image_processing_detr.binary_mask_to_rle
def binary_mask_to_rle(mask):
"""
Converts given binary mask of shape `(height, width)` to the run-length encoding (RLE) format.
Args:
mask (`torch.Tensor` or `numpy.array`):
A binary mask tensor of shape `(height, width)` where 0 denotes background and 1 denotes the target
segment_id or class_id.
Returns:
`List`: Run-length encoded list of the binary mask. Refer to COCO API for more information about the RLE
format.
"""
if is_torch_tensor(mask):
mask = mask.numpy()
pixels = mask.flatten()
pixels = np.concatenate([[0], pixels, [0]])
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
runs[1::2] -= runs[::2]
return [x for x in runs]
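# Example (illustrative): for a flattened mask [0, 0, 1, 1, 1, 0], the function returns [3, 3],
# i.e. the single run of 1s starts at (1-indexed) position 3 of the flattened mask and is
# 3 pixels long; the output alternates start-offset and run-length pairs.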
# Copied from transformers.models.detr.image_processing_detr.convert_segmentation_to_rle
def convert_segmentation_to_rle(segmentation):
"""
Converts given segmentation map of shape `(height, width)` to the run-length encoding (RLE) format.
Args:
segmentation (`torch.Tensor` or `numpy.array`):
A segmentation map of shape `(height, width)` where each value denotes a segment or class id.
Returns:
`List[List]`: A list of lists, where each list is the run-length encoding of a segment / class id.
"""
segment_ids = torch.unique(segmentation)
run_length_encodings = []
for idx in segment_ids:
mask = torch.where(segmentation == idx, 1, 0)
rle = binary_mask_to_rle(mask)
run_length_encodings.append(rle)
return run_length_encodings
# Copied from transformers.models.detr.image_processing_detr.remove_low_and_no_objects
def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels):
"""
Binarizes the given masks using `object_mask_threshold` and returns the associated values of `masks`, `scores`
and `labels`.
Args:
masks (`torch.Tensor`):
A tensor of shape `(num_queries, height, width)`.
scores (`torch.Tensor`):
A tensor of shape `(num_queries)`.
labels (`torch.Tensor`):
A tensor of shape `(num_queries)`.
object_mask_threshold (`float`):
A number between 0 and 1 used to binarize the masks.
Raises:
`ValueError`: Raised when the first dimension doesn't match in all input tensors.
Returns:
`Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`]`: The `masks`, `scores` and `labels` without the region
< `object_mask_threshold`.
"""
if not (masks.shape[0] == scores.shape[0] == labels.shape[0]):
raise ValueError("mask, scores and labels must have the same shape!")
to_keep = labels.ne(num_labels) & (scores > object_mask_threshold)
return masks[to_keep], scores[to_keep], labels[to_keep]
# Copied from transformers.models.detr.image_processing_detr.check_segment_validity
def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8):
# Get the mask associated with the k class
mask_k = mask_labels == k
mask_k_area = mask_k.sum()
# Compute the area of all the stuff in query k
original_area = (mask_probs[k] >= mask_threshold).sum()
mask_exists = mask_k_area > 0 and original_area > 0
# Eliminate disconnected tiny segments
if mask_exists:
area_ratio = mask_k_area / original_area
if not area_ratio.item() > overlap_mask_area_threshold:
mask_exists = False
return mask_exists, mask_k
# Copied from transformers.models.detr.image_processing_detr.compute_segments
def compute_segments(
mask_probs,
pred_scores,
pred_labels,
mask_threshold: float = 0.5,
overlap_mask_area_threshold: float = 0.8,
label_ids_to_fuse: Optional[Set[int]] = None,
target_size: Tuple[int, int] = None,
):
height = mask_probs.shape[1] if target_size is None else target_size[0]
width = mask_probs.shape[2] if target_size is None else target_size[1]
segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
segments: List[Dict] = []
if target_size is not None:
mask_probs = nn.functional.interpolate(
mask_probs.unsqueeze(0), size=target_size, mode="bilinear", align_corners=False
)[0]
current_segment_id = 0
# Weigh each mask by its prediction score
mask_probs *= pred_scores.view(-1, 1, 1)
mask_labels = mask_probs.argmax(0) # [height, width]
# Keep track of instances of each class
stuff_memory_list: Dict[str, int] = {}
for k in range(pred_labels.shape[0]):
pred_class = pred_labels[k].item()
should_fuse = pred_class in label_ids_to_fuse
# Check if mask exists and large enough to be a segment
mask_exists, mask_k = check_segment_validity(
mask_labels, mask_probs, k, mask_threshold, overlap_mask_area_threshold
)
if mask_exists:
if pred_class in stuff_memory_list:
current_segment_id = stuff_memory_list[pred_class]
else:
current_segment_id += 1
# Add current object segment to final segmentation map
segmentation[mask_k] = current_segment_id
segment_score = round(pred_scores[k].item(), 6)
segments.append(
{
"id": current_segment_id,
"label_id": pred_class,
"was_fused": should_fuse,
"score": segment_score,
}
)
if should_fuse:
stuff_memory_list[pred_class] = current_segment_id
return segmentation, segments
# Copied from transformers.models.maskformer.image_processing_maskformer.convert_segmentation_map_to_binary_masks
def convert_segmentation_map_to_binary_masks(
segmentation_map: "np.ndarray",
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
ignore_index: Optional[int] = None,
reduce_labels: bool = False,
):
if reduce_labels and ignore_index is None:
raise ValueError("If `reduce_labels` is True, `ignore_index` must be provided.")
if reduce_labels:
segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1)
# Get unique ids (class or instance ids based on input)
all_labels = np.unique(segmentation_map)
# Drop background label if applicable
if ignore_index is not None:
all_labels = all_labels[all_labels != ignore_index]
# Generate a binary mask for each object instance
binary_masks = [(segmentation_map == i) for i in all_labels]
binary_masks = np.stack(binary_masks, axis=0) # (num_labels, height, width)
# Convert instance ids to class ids
if instance_id_to_semantic_id is not None:
labels = np.zeros(all_labels.shape[0])
for label in all_labels:
class_id = instance_id_to_semantic_id[label + 1 if reduce_labels else label]
labels[all_labels == label] = class_id - 1 if reduce_labels else class_id
else:
labels = all_labels
return binary_masks.astype(np.float32), labels.astype(np.int64)
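# Example (illustrative): with reduce_labels=False and ignore_index=None, a 2x2 segmentation map
# [[0, 0], [3, 5]] yields three binary masks of shape (2, 2) (one per unique id 0, 3 and 5) and
# labels [0, 3, 5]. When an `instance_id_to_semantic_id` mapping is passed, the returned labels
# are the mapped class ids rather than the raw instance ids.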
def get_oneformer_resize_output_image_size(
image: np.ndarray,
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
max_size: Optional[int] = None,
default_to_square: bool = True,
) -> tuple:
"""
Computes the output size given the desired size.
Args:
image (`np.ndarray`):
The input image.
size (`int`, `Tuple[int, int]`, `List[int]`, `Tuple[int]`):
The size of the output image.
default_to_square (`bool`, *optional*, defaults to `True`):
Whether to default to square if no size is provided.
max_size (`int`, *optional*):
The maximum size of the output image.
Returns:
`Tuple[int, int]`: The output size.
"""
output_size = get_resize_output_image_size(
input_image=image, size=size, default_to_square=default_to_square, max_size=max_size
)
return output_size
def prepare_metadata(repo_path, class_info_file):
with open(hf_hub_download(repo_path, class_info_file, repo_type="dataset"), "r") as f:
class_info = json.load(f)
metadata = {}
class_names = []
thing_ids = []
for key, info in class_info.items():
metadata[key] = info["name"]
class_names.append(info["name"])
if info["isthing"]:
thing_ids.append(int(key))
metadata["thing_ids"] = thing_ids
metadata["class_names"] = class_names
return metadata
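# The class-info file is expected to be a JSON mapping from class ids (string keys) to dicts with
# at least "name" and "isthing" entries, e.g. (schematic, dataset-dependent):
# {"0": {"name": "wall", "isthing": 0}, "1": {"name": "building", "isthing": 0}, ...}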
class OneFormerImageProcessor(BaseImageProcessor):
r"""
Constructs a OneFormer image processor. The image processor can be used to prepare image(s), task input(s) and
optional text inputs and targets for the model.
This image processor inherits from [`BaseImageProcessor`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input to a certain `size`.
size (`int`, *optional*, defaults to 800):
Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. If size is a
sequence like `(width, height)`, output size will be matched to this. If size is an int, smaller edge of
the image will be matched to this number. i.e, if `height > width`, then image will be rescaled to `(size *
height / width, size)`.
max_size (`int`, *optional*, defaults to 1333):
The largest size an image dimension can have (otherwise it's capped). Only has an effect if `do_resize` is
set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.Resampling.BILINEAR`):
An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
`PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
`PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
to `True`.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the input to a certain `scale`.
rescale_factor (`float`, *optional*, defaults to 1 / 255):
Rescale the input by the given factor. Only has an effect if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
image_mean (`int`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.
image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
ImageNet std.
ignore_index (`int`, *optional*):
Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
denoted with 0 (background) will be replaced with `ignore_index`.
reduce_labels (`bool`, *optional*, defaults to `False`):
Whether or not to decrement all label values of segmentation maps by 1. Usually used for datasets where 0
is used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k).
The background label will be replaced by `ignore_index`.
repo_path (`str`, defaults to `shi-labs/oneformer_demo`):
Dataset repository on huggingface hub containing the JSON file with class information for the dataset.
class_info_file (`str`):
JSON file containing class information for the dataset. It is stored inside on the `repo_path` dataset
repository.
num_text (`int`, *optional*):
Number of text entries in the text input list.
"""
model_input_names = ["pixel_values", "pixel_mask", "task_inputs"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: float = 1 / 255,
do_normalize: bool = True,
image_mean: Union[float, List[float]] = None,
image_std: Union[float, List[float]] = None,
ignore_index: Optional[int] = None,
reduce_labels: bool = False,
repo_path: str = "shi-labs/oneformer_demo",
class_info_file: str = None,
num_text: Optional[int] = None,
**kwargs
):
if "max_size" in kwargs:
self._max_size = kwargs.pop("max_size")
else:
self._max_size = 1333
size = size if size is not None else {"shortest_edge": 800, "longest_edge": self._max_size}
size = get_size_dict(size, max_size=self._max_size, default_to_square=False)
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.ignore_index = ignore_index
self.reduce_labels = reduce_labels
self.class_info_file = class_info_file
self.repo_path = repo_path
self.metadata = prepare_metadata(repo_path, class_info_file)
self.num_text = num_text
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BILINEAR,
data_format=None,
**kwargs
) -> np.ndarray:
"""
Resize the image to the given size. Size can be min_size (scalar) or `(height, width)` tuple. If size is an
int, smaller edge of the image will be matched to this number.
"""
if "max_size" in kwargs:
warnings.warn(
"The `max_size` parameter is deprecated and will be removed in v4.27. "
"Please specify in `size['longest_edge'] instead`.",
FutureWarning,
)
max_size = kwargs.pop("max_size")
else:
max_size = None
size = get_size_dict(size, max_size=max_size, default_to_square=False)
if "shortest_edge" in size and "longest_edge" in size:
size, max_size = size["shortest_edge"], size["longest_edge"]
elif "height" in size and "width" in size:
size = (size["height"], size["width"])
max_size = None
else:
raise ValueError(
"Size must contain 'height' and 'width' keys or 'shortest_edge' and 'longest_edge' keys. Got"
f" {size.keys()}."
)
size = get_oneformer_resize_output_image_size(
image=image,
size=size,
max_size=max_size,
default_to_square=False,
)
image = resize(image, size=size, resample=resample, data_format=data_format)
return image
# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.rescale
def rescale(
self, image: np.ndarray, rescale_factor: float, data_format: Optional[ChannelDimension] = None
) -> np.ndarray:
"""
Rescale the image by the given factor.
"""
return rescale(image, rescale_factor, data_format=data_format)
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.normalize
def normalize(
self,
image: np.ndarray,
mean: Union[float, Iterable[float]],
std: Union[float, Iterable[float]],
data_format: Optional[ChannelDimension] = None,
) -> np.ndarray:
"""
Normalize the image with the given mean and standard deviation.
"""
return normalize(image, mean=mean, std=std, data_format=data_format)
# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.convert_segmentation_map_to_binary_masks
def convert_segmentation_map_to_binary_masks(
self,
segmentation_map: "np.ndarray",
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
ignore_index: Optional[int] = None,
reduce_labels: bool = False,
):
reduce_labels = reduce_labels if reduce_labels is not None else self.reduce_labels
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
return convert_segmentation_map_to_binary_masks(
segmentation_map=segmentation_map,
instance_id_to_semantic_id=instance_id_to_semantic_id,
ignore_index=ignore_index,
reduce_labels=reduce_labels,
)
def __call__(self, images, task_inputs, segmentation_maps=None, **kwargs) -> BatchFeature:
return self.preprocess(images, task_inputs, segmentation_maps=segmentation_maps, **kwargs)
def _preprocess(
self,
image: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
):
if do_resize:
image = self.resize(image, size=size, resample=resample)
if do_rescale:
image = self.rescale(image, rescale_factor=rescale_factor)
if do_normalize:
image = self.normalize(image, mean=image_mean, std=image_std)
return image
def _preprocess_image(
self,
image: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
"""Preprocesses a single image."""
# All transformations expect numpy arrays.
image = to_numpy_array(image)
image = self._preprocess(
image=image,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
)
if data_format is not None:
image = to_channel_dimension_format(image, data_format)
return image
def _preprocess_mask(
self,
segmentation_map: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
) -> np.ndarray:
"""Preprocesses a single mask."""
segmentation_map = to_numpy_array(segmentation_map)
# Add channel dimension if missing - needed for certain transformations
added_channel_dim = False
if segmentation_map.ndim == 2:
added_channel_dim = True
segmentation_map = segmentation_map[None, ...]
# TODO: (Amy)
# Rework segmentation map processing to include reducing labels and resizing which doesn't
# drop segment IDs > 255.
segmentation_map = self._preprocess(
image=segmentation_map,
do_resize=do_resize,
resample=PILImageResampling.NEAREST,
size=size,
do_rescale=False,
do_normalize=False,
)
# Remove extra channel dimension if added for processing
if added_channel_dim:
segmentation_map = segmentation_map.squeeze(0)
return segmentation_map
def preprocess(
self,
images: ImageInput,
task_inputs: List[str],
segmentation_maps: Optional[ImageInput] = None,
instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
ignore_index: Optional[int] = None,
reduce_labels: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST,
**kwargs
) -> BatchFeature:
if "pad_and_return_pixel_mask" in kwargs:
warnings.warn(
"The `pad_and_return_pixel_mask` argument is deprecated and will be removed in a future version",
FutureWarning,
)
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False, max_size=self._max_size)
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
ignore_index = ignore_index if ignore_index is not None else self.ignore_index
reduce_labels = reduce_labels if reduce_labels is not None else self.reduce_labels
if do_resize is not None and size is None:
raise ValueError("If `do_resize` is True, `size` must be provided.")
if do_rescale is not None and rescale_factor is None:
raise ValueError("If `do_rescale` is True, `rescale_factor` must be provided.")
if do_normalize is not None and (image_mean is None or image_std is None):
raise ValueError("If `do_normalize` is True, `image_mean` and `image_std` must be provided.")
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if segmentation_maps is not None and not valid_images(segmentation_maps):
raise ValueError(
"Invalid segmentation map type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if not is_batched(images):
images = [images]
segmentation_maps = [segmentation_maps] if segmentation_maps is not None else None
if segmentation_maps is not None and len(images) != len(segmentation_maps):
raise ValueError("Images and segmentation maps must have the same length.")
images = [
self._preprocess_image(
image,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
data_format=data_format,
)
for image in images
]
if segmentation_maps is not None:
segmentation_maps = [
self._preprocess_mask(segmentation_map, do_resize, size) for segmentation_map in segmentation_maps
]
encoded_inputs = self.encode_inputs(
images,
task_inputs,
segmentation_maps,
instance_id_to_semantic_id,
ignore_index,
reduce_labels,
return_tensors,
)
return encoded_inputs
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image
def _pad_image(
self,
image: np.ndarray,
output_size: Tuple[int, int],
constant_values: Union[float, Iterable[float]] = 0,
data_format: Optional[ChannelDimension] = None,
) -> np.ndarray:
"""
Pad an image with zeros to the given size.
"""
input_height, input_width = get_image_size(image)
output_height, output_width = output_size
pad_bottom = output_height - input_height
pad_right = output_width - input_width
padding = ((0, pad_bottom), (0, pad_right))
padded_image = pad(
image, padding, mode=PaddingMode.CONSTANT, constant_values=constant_values, data_format=data_format
)
return padded_image
# Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor.pad
def pad(
self,
images: List[np.ndarray],
constant_values: Union[float, Iterable[float]] = 0,
return_pixel_mask: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = None,
) -> np.ndarray:
"""
Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
in the batch and optionally returns their corresponding pixel mask.
Args:
images (`List[np.ndarray]`):
Images to pad.
constant_values (`float` or `Iterable[float]`, *optional*):
The value to use for the padding if `mode` is `"constant"`.
return_pixel_mask (`bool`, *optional*, defaults to `True`):
Whether to return a pixel mask.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. If left unset, lists of NumPy arrays are returned.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
pad_size = get_max_height_width(images)
padded_images = [
self._pad_image(image, pad_size, constant_values=constant_values, data_format=data_format)
for image in images
]
data = {"pixel_values": padded_images}
if return_pixel_mask:
masks = [make_pixel_mask(image=image, output_size=pad_size) for image in images]
data["pixel_mask"] = masks
return BatchFeature(data=data, tensor_type=return_tensors)
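# Example (illustrative): padding a batch with channel-first images of shapes (3, 480, 640) and
# (3, 512, 600) produces two padded images of shape (3, 512, 640) (stacked to (2, 3, 512, 640)
# when `return_tensors` is set), with `pixel_mask` equal to 1 over each original image area and
# 0 over the padded bottom/right border.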
def get_semantic_annotations(self, label, num_class_obj):
annotation_classes = label["classes"]
annotation_masks = label["masks"]
texts = ["a semantic photo"] * self.num_text
classes = []
masks = []
for idx in range(len(annotation_classes)):
class_id = annotation_classes[idx]
mask = annotation_masks[idx]
if mask.any():
if class_id not in classes:
cls_name = self.metadata[str(class_id)]
classes.append(class_id)
masks.append(mask)
num_class_obj[cls_name] += 1
else:
idx = classes.index(class_id)
masks[idx] += mask
masks[idx] = np.clip(masks[idx], 0, 1)
num = 0
for i, cls_name in enumerate(self.metadata["class_names"]):
if num_class_obj[cls_name] > 0:
for _ in range(num_class_obj[cls_name]):
if num >= len(texts):
break
texts[num] = f"a photo with a {cls_name}"
num += 1
classes = np.array(classes)
masks = np.array(masks)
return classes, masks, texts
def get_instance_annotations(self, label, num_class_obj):
annotation_classes = label["classes"]
annotation_masks = label["masks"]
texts = ["an instance photo"] * self.num_text
classes = []
masks = []
for idx in range(len(annotation_classes)):
class_id = annotation_classes[idx]
mask = annotation_masks[idx]
if class_id in self.metadata["thing_ids"]:
if mask.any():
cls_name = self.metadata[str(class_id)]
classes.append(class_id)
masks.append(mask)
num_class_obj[cls_name] += 1
num = 0
for i, cls_name in enumerate(self.metadata["class_names"]):
if num_class_obj[cls_name] > 0:
for _ in range(num_class_obj[cls_name]):
if num >= len(texts):
break
texts[num] = f"a photo with a {cls_name}"
num += 1
classes = np.array(classes)
masks = np.array(masks)
return classes, masks, texts
def get_panoptic_annotations(self, label, num_class_obj):
annotation_classes = label["classes"]
annotation_masks = label["masks"]
texts = ["an panoptic photo"] * self.num_text
classes = []
masks = []
for idx in range(len(annotation_classes)):
class_id = annotation_classes[idx]
mask = annotation_masks[idx].data
if mask.any():
cls_name = self.metadata[str(class_id)]
classes.append(class_id)
masks.append(mask)
num_class_obj[cls_name] += 1
num = 0
for i, cls_name in enumerate(self.metadata["class_names"]):
if num_class_obj[cls_name] > 0:
for _ in range(num_class_obj[cls_name]):
if num >= len(texts):
break
texts[num] = f"a photo with a {cls_name}"
num += 1
classes = np.array(classes)
masks = np.array(masks)
return classes, masks, texts
def encode_inputs(
self,
pixel_values_list: List[ImageInput],
task_inputs: List[str],
segmentation_maps: ImageInput = None,
instance_id_to_semantic_id: Optional[Union[List[Dict[int, int]], Dict[int, int]]] = None,
ignore_index: Optional[int] = None,
reduce_labels: bool = False,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs
):
"""
Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
OneFormer addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps
will be converted to lists of binary masks and their respective labels. Let's see an example, assuming
`segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels =
[[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for
each mask.
Args:
pixel_values_list (`List[ImageInput]`):
List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
width)`.
task_inputs (`List[str]`):
List of task values.
segmentation_maps (`ImageInput`, *optional*):
The corresponding semantic segmentation maps with the pixel-wise annotations.
instance_id_to_semantic_id (`List[Dict[int, int]]` or `Dict[int, int]`, *optional*):
A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an
instance segmentation map where each pixel represents an instance id. Can be provided as a single
dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map
instance ids in each image separately.
return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model.
- **pixel_mask** -- Pixel mask to be fed to a model (when `pixel_mask` is in `self.model_input_names`).
Its entries are 1 for real pixels (i.e. **not masked**) and 0 for padding pixels (i.e. **masked**).
- **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model
(when `annotations` are provided).
- **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when
`annotations` are provided). They identify the labels of `mask_labels`, e.g. the label of
`mask_labels[i][j]` is `class_labels[i][j]`.
- **text_inputs** -- Optional list of text string entries to be fed to a model (when `annotations` are
provided). They identify the binary masks present in the image.
"""
ignore_index = self.ignore_index if ignore_index is None else ignore_index
reduce_labels = self.reduce_labels if reduce_labels is None else reduce_labels
if "pad_and_return_pixel_mask" in kwargs:
warnings.warn(
"The `pad_and_return_pixel_mask` argument has no effect and will be removed in v4.27", FutureWarning
)
pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list]
pad_size = get_max_height_width(pixel_values_list)
encoded_inputs = self.pad(pixel_values_list, return_tensors=return_tensors)
annotations = None
if segmentation_maps is not None:
segmentation_maps = map(np.array, segmentation_maps)
annotations = []
for idx, segmentation_map in enumerate(segmentation_maps):
# Use instance2class_id mapping per image
if isinstance(instance_id_to_semantic_id, list):
instance_id = instance_id_to_semantic_id[idx]
else:
instance_id = instance_id_to_semantic_id
masks, classes = self.convert_segmentation_map_to_binary_masks(
segmentation_map, instance_id, ignore_index=ignore_index, reduce_labels=reduce_labels
)
annotations.append({"masks": masks, "classes": classes})
if annotations is not None:
mask_labels = []
class_labels = []
text_inputs = []
num_class_obj = {}
for cls_name in self.metadata["class_names"]:
num_class_obj[cls_name] = 0
for i, label in enumerate(annotations):
task = task_inputs[i]
if task == "semantic":
classes, masks, texts = self.get_semantic_annotations(label, num_class_obj)
elif task == "instance":
classes, masks, texts = self.get_instance_annotations(label, num_class_obj)
if task == "panoptic":
classes, masks, texts = self.get_panoptic_annotations(label, num_class_obj)
# we cannot batch them since they don't share a common class size
masks = [mask[None, ...] for mask in masks]
masks = [
self._pad_image(image=mask, output_size=pad_size, constant_values=ignore_index) for mask in masks
]
masks = np.concatenate(masks, axis=0)
mask_labels.append(torch.from_numpy(masks))
class_labels.append(torch.from_numpy(classes).long())
text_inputs.append(texts)
encoded_inputs["mask_labels"] = mask_labels
encoded_inputs["class_labels"] = class_labels
encoded_inputs["text_inputs"] = text_inputs
return encoded_inputs
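# Note on text_inputs: each entry holds `self.num_text` prompt strings, starting from a
# task-specific default ("a semantic photo", "an instance photo" or "an panoptic photo") and
# overwritten with "a photo with a {class_name}" for every annotated object, as built by the
# get_*_annotations helpers below.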
# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.post_process_semantic_segmentation
def post_process_semantic_segmentation(
self, outputs, target_sizes: Optional[List[Tuple[int, int]]] = None
) -> "torch.Tensor":
"""
Converts the output of [`MaskFormerForInstanceSegmentation`] into semantic segmentation maps. Only supports
PyTorch.
Args:
outputs ([`MaskFormerForInstanceSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple[int, int]]`, *optional*):
List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
final size (height, width) of each prediction. If left to None, predictions will not be resized.
Returns:
`List[torch.Tensor]`:
A list of length `batch_size`, where each item is a semantic segmentation map of shape (height, width)
corresponding to the target_sizes entry (if `target_sizes` is specified). Each entry of each
`torch.Tensor` correspond to a semantic class id.
"""
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
# Remove the null class `[..., :-1]`
masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
# Semantic segmentation logits of shape (batch_size, num_classes, height, width)
segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
batch_size = class_queries_logits.shape[0]
# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if batch_size != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)
semantic_segmentation = []
for idx in range(batch_size):
resized_logits = torch.nn.functional.interpolate(
segmentation[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = segmentation.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
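# Rough usage sketch (illustrative; `processor`, `image_processor` and `model` are placeholder names):
#   inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt")
#   outputs = model(**inputs)
#   semantic_map = image_processor.post_process_semantic_segmentation(
#       outputs, target_sizes=[(height, width)]
#   )[0]  # (height, width) tensor of semantic class ids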
def post_process_instance_segmentation(
self,
outputs,
task_type: str = "instance",
is_demo: bool = True,
threshold: float = 0.5,
mask_threshold: float = 0.5,
overlap_mask_area_threshold: float = 0.8,
target_sizes: Optional[List[Tuple[int, int]]] = None,
return_coco_annotation: Optional[bool] = False,
):
"""
Converts the output of [`OneFormerForUniversalSegmentationOutput`] into image instance segmentation
predictions. Only supports PyTorch.
Args:
outputs ([`OneFormerForUniversalSegmentationOutput`]):
The outputs from [`OneFormerForUniversalSegmentationOutput`].
task_type (`str`, *optional*, defaults to "instance"):
The post processing depends on the task token input. If the `task_type` is "panoptic", we need to
ignore the stuff predictions.
is_demo (`bool`, *optional*, defaults to `True`):
Whether the model is in demo mode. If true, use threshold to predict final masks.
threshold (`float`, *optional*, defaults to 0.5):
The probability score threshold to keep predicted instance masks.
mask_threshold (`float`, *optional*, defaults to 0.5):
Threshold to use when turning the predicted masks into binary values.
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
The overlap mask area threshold to merge or discard small disconnected parts within each binary
instance mask.
target_sizes (`List[Tuple]`, *optional*):
List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
final size (height, width) of each prediction in batch. If left to None, predictions will not be
resized.
return_coco_annotation (`bool`, *optional*, defaults to `False`):
Whether to return predictions in COCO format.
Returns:
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
- **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, set
to `None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized
to the corresponding `target_sizes` entry.
- **segments_info** -- A dictionary that contains additional information on each segment.
- **id** -- an integer representing the `segment_id`.
- **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
- **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
Multiple instances of the same class / label were fused and assigned a single `segment_id`.
- **score** -- Prediction score of segment with `segment_id`.
"""
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
batch_size = class_queries_logits.shape[0]
num_queries = class_queries_logits.shape[1]
num_classes = class_queries_logits.shape[-1] - 1
# Loop over items in batch size
results: List[Dict[str, torch.Tensor]] = []
for i in range(batch_size):
# [Q, K]
scores = torch.nn.functional.softmax(class_queries_logits[i], dim=-1)[:, :-1]
labels = torch.arange(num_classes).unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
# scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
scores_per_image, topk_indices = scores.flatten(0, 1).topk(num_queries, sorted=False)
labels_per_image = labels[topk_indices]
topk_indices = topk_indices // num_classes
# mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
mask_pred = masks_queries_logits[i][topk_indices]
# Only consider scores with confidence over [threshold] for demo
if is_demo:
keep = scores_per_image > threshold
scores_per_image = scores_per_image[keep]
labels_per_image = labels_per_image[keep]
mask_pred = mask_pred[keep]
# if this is panoptic segmentation, we only keep the "thing" classes
if task_type == "panoptic":
keep = torch.zeros_like(scores_per_image).bool()
for j, lab in enumerate(labels_per_image):
keep[j] = lab in self.metadata["thing_ids"]
scores_per_image = scores_per_image[keep]
labels_per_image = labels_per_image[keep]
mask_pred = mask_pred[keep]
if mask_pred.shape[0] <= 0:
height, width = target_sizes[i] if target_sizes is not None else mask_pred.shape[1:]
segmentation = torch.zeros((height, width)) - 1
results.append({"segmentation": segmentation, "segments_info": []})
continue
if "ade20k" in self.class_info_file and not is_demo and "instance" in task_type:
for j in range(labels_per_image.shape[0]):
labels_per_image[j] = self.metadata["thing_ids"].index(labels_per_image[j].item())
# Get segmentation map and segment information of batch item
target_size = target_sizes[i] if target_sizes is not None else None
segmentation, segments = compute_segments(
mask_pred,
scores_per_image,
labels_per_image,
mask_threshold,
overlap_mask_area_threshold,
set(),
target_size,
)
# Return segmentation map in run-length encoding (RLE) format
if return_coco_annotation:
segmentation = convert_segmentation_to_rle(segmentation)
results.append({"segmentation": segmentation, "segments_info": segments})
return results
# Copied from transformers.models.maskformer.image_processing_maskformer.MaskFormerImageProcessor.post_process_panoptic_segmentation
def post_process_panoptic_segmentation(
self,
outputs,
threshold: float = 0.5,
mask_threshold: float = 0.5,
overlap_mask_area_threshold: float = 0.8,
label_ids_to_fuse: Optional[Set[int]] = None,
target_sizes: Optional[List[Tuple[int, int]]] = None,
) -> List[Dict]:
"""
Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image panoptic segmentation
predictions. Only supports PyTorch.
Args:
outputs ([`MaskFormerForInstanceSegmentationOutput`]):
The outputs from [`MaskFormerForInstanceSegmentation`].
threshold (`float`, *optional*, defaults to 0.5):
The probability score threshold to keep predicted instance masks.
mask_threshold (`float`, *optional*, defaults to 0.5):
Threshold to use when turning the predicted masks into binary values.
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
The overlap mask area threshold to merge or discard small disconnected parts within each binary
instance mask.
label_ids_to_fuse (`Set[int]`, *optional*):
The labels in this set will have all their instances fused together. For instance, we could say
there can only be one sky in an image, but several persons, so the label ID for sky would be in that
set, but not the one for person.
target_sizes (`List[Tuple]`, *optional*):
List of length (batch_size), where each list item (`Tuple[int, int]`) corresponds to the requested
final size (height, width) of each prediction in batch. If left to None, predictions will not be
resized.
Returns:
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
- **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`, set
to `None` if no mask is found above `threshold`. If `target_sizes` is specified, segmentation is resized
to the corresponding `target_sizes` entry.
- **segments_info** -- A dictionary that contains additional information on each segment.
- **id** -- an integer representing the `segment_id`.
- **label_id** -- An integer representing the label / semantic class id corresponding to `segment_id`.
- **was_fused** -- a boolean, `True` if `label_id` was in `label_ids_to_fuse`, `False` otherwise.
Multiple instances of the same class / label were fused and assigned a single `segment_id`.
- **score** -- Prediction score of segment with `segment_id`.
"""
if label_ids_to_fuse is None:
logger.warning("`label_ids_to_fuse` unset. No instance will be fused.")
label_ids_to_fuse = set()
class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1]
masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width]
batch_size = class_queries_logits.shape[0]
num_labels = class_queries_logits.shape[-1] - 1
mask_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width]
# Predicted label and score of each query (batch_size, num_queries)
pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
# Loop over items in batch size
results: List[Dict[str, TensorType]] = []
for i in range(batch_size):
mask_probs_item, pred_scores_item, pred_labels_item = remove_low_and_no_objects(
mask_probs[i], pred_scores[i], pred_labels[i], threshold, num_labels
)
# No mask found
if mask_probs_item.shape[0] <= 0:
height, width = target_sizes[i] if target_sizes is not None else mask_probs_item.shape[1:]
segmentation = torch.zeros((height, width)) - 1
results.append({"segmentation": segmentation, "segments_info": []})
continue
# Get segmentation map and segment information of batch item
target_size = target_sizes[i] if target_sizes is not None else None
segmentation, segments = compute_segments(
mask_probs=mask_probs_item,
pred_scores=pred_scores_item,
pred_labels=pred_labels_item,
mask_threshold=mask_threshold,
overlap_mask_area_threshold=overlap_mask_area_threshold,
label_ids_to_fuse=label_ids_to_fuse,
target_size=target_size,
)
results.append({"segmentation": segmentation, "segments_info": segments})
return results
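# Note: the three post-processing methods above mirror the three task tokens ("semantic",
# "instance", "panoptic"); post_process_instance_segmentation additionally accepts a `task_type`
# argument so that stuff predictions can be dropped when the task token was "panoptic" but
# instance-style output is desired.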
# coding=utf-8
# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch OneFormer model."""
import copy
import math
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import Tensor, nn
from torch.cuda.amp import autocast
from transformers import AutoBackbone
from transformers.utils import logging
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_scipy_available,
replace_return_docstrings,
requires_backends,
)
from .configuration_oneformer import OneFormerConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "OneFormerConfig"
_CHECKPOINT_FOR_DOC = "shi-labs/oneformer_ade20k_swin_tiny"
_IMAGE_PROCESSOR_FOR_DOC = "OneFormerImageProcessor"
ONEFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"shi-labs/oneformer_ade20k_swin_tiny",
# See all OneFormer models at https://huggingface.co/models?filter=oneformer
]
if is_scipy_available():
from scipy.optimize import linear_sum_assignment
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def multiscale_deform_attn_core_pytorch(
value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
) -> Tensor:
batch_size, _, num_heads, hidden_dim = value.shape
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
value_list = value.split([height * width for height, width in value_spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1
sampling_value_list = []
for level_id, (height, width) in enumerate(value_spatial_shapes):
# batch_size, height*width, num_heads, hidden_dim
# -> batch_size, height*width, num_heads*hidden_dim
# -> batch_size, num_heads*hidden_dim, height*width
# -> batch_size*num_heads, hidden_dim, height, width
value_l_ = (
value_list[level_id].flatten(2).transpose(1, 2).reshape(batch_size * num_heads, hidden_dim, height, width)
)
# batch_size, num_queries, num_heads, num_points, 2
# -> batch_size, num_heads, num_queries, num_points, 2
# -> batch_size*num_heads, num_queries, num_points, 2
sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
# batch_size*num_heads, hidden_dim, num_queries, num_points
sampling_value_l_ = nn.functional.grid_sample(
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
)
sampling_value_list.append(sampling_value_l_)
# (batch_size, num_queries, num_heads, num_levels, num_points)
# -> (batch_size, num_heads, num_queries, num_levels, num_points)
# -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
attention_weights = attention_weights.transpose(1, 2).reshape(
batch_size * num_heads, 1, num_queries, num_levels * num_points
)
output = (
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
.sum(-1)
.view(batch_size, num_heads * hidden_dim, num_queries)
)
return output.transpose(1, 2).contiguous()
# Copied from transformers.models.maskformer.modeling_maskformer.dice_loss
def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor:
r"""
Compute the DICE loss, similar to generalized IOU for masks, as follows:
$$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x \cap y}{x \cup y + 1} $$
In practice, since `labels` is a binary mask (only 0s and 1s), the dice loss can be computed as follows:
$$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x * y}{x + y + 1} $$
Args:
inputs (`torch.Tensor`):
A tensor representing a mask.
labels (`torch.Tensor`):
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
(0 for the negative class and 1 for the positive class).
num_masks (`int`):
The number of masks present in the current batch, used for normalization.
Returns:
`torch.Tensor`: The computed loss.
"""
probs = inputs.sigmoid().flatten(1)
numerator = 2 * (probs * labels).sum(-1)
denominator = probs.sum(-1) + labels.sum(-1)
loss = 1 - (numerator + 1) / (denominator + 1)
loss = loss.sum() / num_masks
return loss
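# Illustrative sanity check: when the sigmoid of `inputs` exactly equals the binary `labels`,
# the numerator and denominator coincide and the per-mask dice loss is 0; for fully disjoint
# predictions the numerator is 0 and the loss approaches 1 for large masks.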
# Copied from transformers.models.mask2former.modeling_mask2former.sigmoid_cross_entropy_loss
def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor:
r"""
Args:
inputs (`torch.Tensor`):
A float tensor of arbitrary shape.
labels (`torch.Tensor`):
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
(0 for the negative class and 1 for the positive class).
Returns:
loss (`torch.Tensor`): The computed loss.
"""
criterion = nn.BCEWithLogitsLoss(reduction="none")
cross_entropy_loss = criterion(inputs, labels)
loss = cross_entropy_loss.mean(1).sum() / num_masks
return loss
# Copied from transformers.models.maskformer.modeling_maskformer.pair_wise_dice_loss
def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
"""
A pair wise version of the dice loss, see `dice_loss` for usage.
Args:
inputs (`torch.Tensor`):
A tensor representing a mask
labels (`torch.Tensor`):
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
(0 for the negative class and 1 for the positive class).
Returns:
`torch.Tensor`: The computed loss between each pair.
"""
inputs = inputs.sigmoid().flatten(1)
numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
# using broadcasting to get a [num_queries, NUM_CLASSES] matrix
denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
loss = 1 - (numerator + 1) / (denominator + 1)
return loss
# Copied from transformers.models.mask2former.modeling_mask2former.pair_wise_sigmoid_cross_entropy_loss
def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
r"""
A pair wise version of the cross entropy loss, see `sigmoid_cross_entropy_loss` for usage.
Args:
inputs (`torch.Tensor`):
A tensor representing a mask.
labels (`torch.Tensor`):
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
(0 for the negative class and 1 for the positive class).
Returns:
loss (`torch.Tensor`): The computed loss between each pair.
"""
height_and_width = inputs.shape[1]
criterion = nn.BCEWithLogitsLoss(reduction="none")
cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
loss = torch.einsum("nc,mc->nm", cross_entropy_loss_pos, labels) + torch.einsum(
"nc,mc->nm", cross_entropy_loss_neg, (1 - labels)
)
loss = loss / height_and_width
return loss
# Copied from transformers.models.mask2former.modeling_mask2former.sample_point
def sample_point(
input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs
) -> torch.Tensor:
"""
A wrapper around `torch.nn.functional.grid_sample` to support 3D point_coordinates tensors.
Args:
input_features (`torch.Tensor` of shape (batch_size, channels, height, width)):
A tensor that contains features map on a height * width grid
point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or (batch_size, grid_height, grid_width,
2)):
A tensor that contains [0, 1] * [0, 1] normalized point coordinates
add_dim (`bool`):
boolean value to keep track of added dimension
Returns:
point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or (batch_size, channels,
height_grid, width_grid)):
A tensor that contains features for points in `point_coordinates`.
"""
if point_coordinates.dim() == 3:
add_dim = True
point_coordinates = point_coordinates.unsqueeze(2)
# use nn.function.grid_sample to get features for points in `point_coordinates` via bilinear interpolation
point_features = torch.nn.functional.grid_sample(input_features, 2.0 * point_coordinates - 1.0, **kwargs)
if add_dim:
point_features = point_features.squeeze(3)
return point_features
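# Shape walkthrough (illustrative): input_features of shape (batch_size, channels, height, width)
# sampled at point_coordinates of shape (batch_size, num_points, 2) returns features of shape
# (batch_size, channels, num_points); the [0, 1] coordinates are mapped to the [-1, 1] range
# expected by torch.nn.functional.grid_sample via 2.0 * point_coordinates - 1.0.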
# Refactored from https://github.com/SHI-Labs/OneFormer/blob/33ebb56ed34f970a30ae103e786c0cb64c653d9a/oneformer/modeling/matcher.py#L93
class OneFormerHungarianMatcher(nn.Module):
def __init__(
self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544
):
"""This class computes an assignment between the labels and the predictions of the network.
For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more
predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are
un-matched (and thus treated as non-objects).
Params:
cost_class (float, *optional*, defaults to 1.0):
This is the relative weight of the classification error in the matching cost.
cost_mask (float, *optional*, defaults to 1.0):
This is the relative weight of the sigmoid ce loss of the binary mask in the matching cost.
cost_dice (float, *optional*, defaults to 1.0):
This is the relative weight of the dice loss of the binary mask in the matching cost
num_points (int, *optional*, defaults to 12544):
Number of points to be sampled for dice and mask loss matching cost.
"""
super().__init__()
if cost_class == 0 and cost_mask == 0 and cost_dice == 0:
raise ValueError("All costs cant be 0")
self.cost_class = cost_class
self.cost_mask = cost_mask
self.cost_dice = cost_dice
self.num_points = num_points
@torch.no_grad()
def forward(self, masks_queries_logits, class_queries_logits, mask_labels, class_labels) -> List[Tuple[Tensor]]:
"""Performs the matching
Params:
masks_queries_logits (`torch.Tensor`):
A tensor of dim `batch_size, num_queries, height, width` with the
predicted masks.
class_queries_logits (`torch.Tensor`):
A tensor of dim `batch_size, num_queries, num_labels` with the
classification logits.
class_labels (`torch.Tensor`):
A tensor of dim `num_target_boxes` (where num_target_boxes is the number
of ground-truth objects in the target) containing the class labels.
mask_labels (`torch.Tensor`):
A tensor of dim `num_target_boxes, height, width` containing the target
masks.
Returns:
`List[Tuple[Tensor]]`: A list of size batch_size, containing tuples of (index_i, index_j) where:
- index_i is the indices of the selected predictions (in order)
- index_j is the indices of the corresponding selected labels (in order)
For each batch element, it holds:
len(index_i) = len(index_j) = min(num_queries, num_targets).
"""
indices: List[Tuple[np.array]] = []
num_queries = class_queries_logits.shape[1]
preds_masks = masks_queries_logits
preds_probs = class_queries_logits
# iterate through batch size
for pred_probs, pred_mask, target_mask, labels in zip(preds_probs, preds_masks, mask_labels, class_labels):
pred_probs = pred_probs.softmax(-1)
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it with 1 - proba[target class].
# The 1 is a constant that doesn't change the matching, so it can be omitted.
cost_class = -pred_probs[:, labels]
pred_mask = pred_mask[:, None]
target_mask = target_mask[:, None].to(pred_mask.device)
# all masks share the same set of points for efficient matching!
point_coords = torch.rand(1, self.num_points, 2, device=pred_mask.device)
# get ground truth labels
target_mask = sample_point(
target_mask,
point_coords.repeat(target_mask.shape[0], 1, 1),
align_corners=False,
).squeeze(1)
pred_mask = sample_point(
pred_mask,
point_coords.repeat(pred_mask.shape[0], 1, 1),
align_corners=False,
).squeeze(1)
with autocast(enabled=False):
pred_mask = pred_mask.float()
target_mask = target_mask.float()
# compute the sigmoid ce loss
cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask)
# Compute the dice loss
cost_dice = pair_wise_dice_loss(pred_mask, target_mask)
# final cost matrix
cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
cost_matrix = cost_matrix.reshape(num_queries, -1).cpu()
# do the assignment using the Hungarian algorithm in scipy
assigned_indices: Tuple[np.array] = linear_sum_assignment(cost_matrix.cpu())
indices.append(assigned_indices)
# It could be stacked in one tensor
matched_indices = [
(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
]
return matched_indices
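# Example (illustrative): with 150 queries and an image annotated with 4 masks, the tuple
# returned for that batch element contains 4 query indices and the 4 corresponding label
# indices, i.e. min(num_queries, num_targets) matched pairs.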
class OneFormerLoss(nn.Module):
def __init__(
self,
num_classes: int,
matcher: OneFormerHungarianMatcher,
weight_dict: Dict[str, float],
eos_coef: float,
num_points: int,
oversample_ratio: float,
importance_sample_ratio: float,
contrastive_temperature: float = None,
):
"""
This class computes the losses using the class predictions, mask predictions and the contrastive queries.
OneFormer calculates the classification CE loss on the class predictions. Mask predictions are used for
calculating the binary CE loss and dice loss. The contrastive queries are used for calculating the contrastive
loss.
Args:
num_classes (`int`):
The number of classes.
matcher (`OneFormerHungarianMatcher`):
A torch module that computes the assignments between the predictions and labels.
weight_dict (`Dict[str, float]`):
A dictionary of weights to be applied to the different losses.
eos_coef (`float`):
Weight to apply to the null class.
num_points (`int`):
Number of points to be sampled for dice and mask loss calculations.
oversample_ratio (`float`):
Required for pointwise loss calculation.
importance_sample_ratio (`float`):
Required for pointwise loss calculation.
contrastive_temperature (`float`):
Temperature for scaling the contrastive logits.
"""
requires_backends(self, ["scipy"])
super().__init__()
self.num_classes = num_classes
self.matcher = matcher
self.weight_dict = weight_dict
self.eos_coef = eos_coef
empty_weight = torch.ones(self.num_classes + 1)
empty_weight[-1] = self.eos_coef
self.register_buffer("empty_weight", empty_weight)
# pointwise mask loss parameters
self.num_points = num_points
self.oversample_ratio = oversample_ratio
self.importance_sample_ratio = importance_sample_ratio
self.contrastive_temperature = contrastive_temperature
if self.contrastive_temperature is not None:
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrastive_temperature))
def _max_by_axis(self, the_list: List[List[int]]) -> List[int]:
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
def _pad_images_to_max_in_batch(self, tensors: List[Tensor]) -> Tuple[Tensor, Tensor]:
# get the maximum size in the batch
max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
batch_size = len(tensors)
# compute final size
batch_shape = [batch_size] + max_size
b, _, h, w = batch_shape
# get metadata
dtype = tensors[0].dtype
device = tensors[0].device
padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
padding_masks = torch.ones((b, h, w), dtype=torch.bool, device=device)
# pad the tensors to the size of the biggest one
for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
padding_mask[: tensor.shape[1], : tensor.shape[2]] = False
return padded_tensors, padding_masks
def loss_contrastive(self, contrastive_queries_logits: Tensor, text_queries: Tensor):
"""Compute the query-text contrastive loss.
Args:
contrastive_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, hidden_dim`
text_queries (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, hidden_dim`
Returns:
`Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
- **loss_contrastive** -- The query-text contrastive loss computed using task-guided queries
and text queries derived from input text list.
"""
image_queries = contrastive_queries_logits.float()
# [batch_size, hidden_dim]
image_queries = nn.functional.normalize(image_queries.flatten(1), dim=-1)
text_queries = nn.functional.normalize(text_queries.flatten(1), dim=-1)
logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
logits_per_text = torch.matmul(text_queries, image_queries.t()) * logit_scale
logits_per_img = logits_per_text.t()
loss_img = nn.functional.cross_entropy(
logits_per_img, torch.arange(len(logits_per_img), device=logits_per_text.device)
)
loss_text = nn.functional.cross_entropy(
logits_per_text, torch.arange(len(logits_per_text), device=logits_per_text.device)
)
loss_contrastive = loss_img + loss_text
losses = {"loss_contrastive": loss_contrastive}
return losses
def loss_labels(
self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array]
) -> Dict[str, Tensor]:
"""Compute the losses related to the labels using cross entropy.
Args:
class_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, num_labels`
class_labels (`List[torch.Tensor]`):
List of class labels of shape `(labels)`.
indices (`Tuple[np.array]`):
The indices computed by the Hungarian matcher.
Returns:
`Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
- **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
"""
pred_logits = class_queries_logits
batch_size, num_queries, _ = pred_logits.shape
criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
idx = self._get_predictions_permutation_indices(indices)
# concatenated class labels of the matched targets, shape (num_matched_targets,)
target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)])
# shape = (batch_size, num_queries)
target_classes = torch.full(
(batch_size, num_queries), fill_value=self.num_classes, dtype=torch.int64, device=pred_logits.device
)
target_classes[idx] = target_classes_o
# permute pred_logits (batch_size, num_queries, num_labels) -> (batch_size, num_labels, num_queries)
pred_logits_transposed = pred_logits.transpose(1, 2)
loss_ce = criterion(pred_logits_transposed, target_classes)
losses = {"loss_cross_entropy": loss_ce}
return losses
def loss_masks(
self, masks_queries_logits: Tensor, mask_labels: List[Tensor], indices: Tuple[np.array], num_masks: int
) -> Dict[str, Tensor]:
"""Compute the losses related to the masks using focal and dice loss.
Args:
masks_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, height, width`
mask_labels (`torch.Tensor`):
List of mask labels of shape `(labels, height, width)`.
indices (`Tuple[np.array]`):
The indices computed by the Hungarian matcher.
num_masks (`int`):
The number of masks, used for normalization.
Returns:
`Dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys:
- **loss_mask** -- The loss computed using sigmoid ce loss on the predicted and ground truth masks.
- **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
"""
src_idx = self._get_predictions_permutation_indices(indices)
tgt_idx = self._get_targets_permutation_indices(indices)
# shape (batch_size * num_queries, height, width)
pred_masks = masks_queries_logits[src_idx]
# shape (batch_size, num_queries, height, width)
# pad all and stack the targets to the num_labels dimension
# upsample predictions to the target size, we have to add one dim to use interpolate
target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
target_masks = target_masks[tgt_idx]
pred_masks = pred_masks[:, None]
target_masks = target_masks[:, None]
with torch.no_grad():
# sample point_coords
point_coords = self.sample_points_using_uncertainty(
pred_masks,
self.calculate_uncertainty,
self.num_points,
self.oversample_ratio,
self.importance_sample_ratio,
)
# get ground-truth labels
point_labels = sample_point(target_masks, point_coords, align_corners=False).squeeze(1)
point_logits = sample_point(pred_masks, point_coords, align_corners=False).squeeze(1)
losses = {
"loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks),
"loss_dice": dice_loss(point_logits, point_labels, num_masks),
}
del pred_masks
del target_masks
return losses
# Copied from transformers.models.mask2former.modeling_mask2former.Mask2FormerLoss.calculate_uncertainty
def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor:
"""
In Mask2Former paper, uncertainty is estimated as L1 distance between 0.0 and the logit prediction in 'logits'
for the foreground class in `classes`.
Args:
logits (`torch.Tensor`):
A tensor of shape (R, 1, ...) for class-specific or class-agnostic, where R is the total number of
predicted masks in all images and C is the number of foreground classes. The values are logits.
Returns:
scores (`torch.Tensor`): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most
uncertain locations having the highest uncertainty score.
"""
uncertainty_scores = -(torch.abs(logits))
return uncertainty_scores
# Copied from transformers.models.mask2former.modeling_mask2former.Mask2FormerLoss.sample_points_using_uncertainty
def sample_points_using_uncertainty(
self,
logits: torch.Tensor,
uncertainty_function,
num_points: int,
oversample_ratio: int,
importance_sample_ratio: float,
) -> torch.Tensor:
"""
This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. The
uncertainty is calculated for each point using the passed `uncertainty function` that takes points logit
prediction as input.
Args:
logits (`float`):
Logit predictions for P points.
uncertainty_function:
A function that takes logit predictions for P points and returns their uncertainties.
num_points (`int`):
The number of points P to sample.
oversample_ratio (`int`):
Oversampling parameter.
importance_sample_ratio (`float`):
Ratio of points that are sampled via importance sampling.
Returns:
point_coordinates (`torch.Tensor`):
Coordinates for P sampled points.
"""
num_boxes = logits.shape[0]
num_points_sampled = int(num_points * oversample_ratio)
# Get random point coordinates
point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
# Get sampled prediction value for the point coordinates
point_logits = sample_point(logits, point_coordinates, align_corners=False)
# Calculate the uncertainties based on the sampled prediction values of the points
point_uncertainties = uncertainty_function(point_logits)
num_uncertain_points = int(importance_sample_ratio * num_points)
num_random_points = num_points - num_uncertain_points
idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)
idx += shift[:, None]
point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)
if num_random_points > 0:
point_coordinates = torch.cat(
[point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],
dim=1,
)
return point_coordinates
def _get_predictions_permutation_indices(self, indices):
# permute predictions following indices
batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
predictions_indices = torch.cat([src for (src, _) in indices])
return batch_indices, predictions_indices
def _get_targets_permutation_indices(self, indices):
# permute labels following indices
batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
target_indices = torch.cat([tgt for (_, tgt) in indices])
return batch_indices, target_indices
def forward(
self,
masks_queries_logits: Tensor,
class_queries_logits: Tensor,
contrastive_queries_logits: Tensor,
mask_labels: List[Tensor],
class_labels: List[Tensor],
text_queries: Tensor,
auxiliary_predictions: Optional[Dict[str, Tensor]] = None,
calculate_contrastive_loss: bool = True,
) -> Dict[str, Tensor]:
"""
This performs the loss computation.
Args:
masks_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, height, width`
class_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, num_labels`
contrastive_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, hidden_dim`
mask_labels (`torch.Tensor`):
List of mask labels of shape `(labels, height, width)`.
class_labels (`List[torch.Tensor]`):
List of class labels of shape `(labels)`.
text_queries (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, hidden_dim`
auxiliary_predictions (`Dict[str, torch.Tensor]`, *optional*):
if `use_auxiliary_loss` was set to `True` in [`OneFormerConfig`], then it contains the logits from the
inner layers of the transformer decoder.
calculate_contrastive_loss (`bool`, *optional*, defaults to `True`):
Whether or not to calculate the contrastive loss.
Returns:
`Dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys:
- **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
- **loss_mask** -- The loss computed using sigmoid ce loss on the predicted and ground truth masks.
- **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
- **loss_contrastive** -- The query-text contrastive loss computed using object and text queries.
if `use_auxiliary_loss` was set to `True` in [`OneFormerConfig`], the dictionary contains additional
losses for each auxiliary prediction.
"""
# retrieve the matching between the outputs of the last layer and the labels
indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
# compute the total number of target masks for normalization purposes
num_masks = self.get_num_masks(class_labels, device=class_labels[0].device)
# get all the losses
losses: Dict[str, Tensor] = {
**self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
**self.loss_labels(class_queries_logits, class_labels, indices),
}
if calculate_contrastive_loss:
losses = {**losses, **self.loss_contrastive(contrastive_queries_logits, text_queries)}
# in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
if auxiliary_predictions is not None:
for idx, aux_outputs in enumerate(auxiliary_predictions):
masks_queries_logits = aux_outputs["masks_queries_logits"]
class_queries_logits = aux_outputs["class_queries_logits"]
loss_dict = self.forward(
masks_queries_logits,
class_queries_logits,
None,
mask_labels,
class_labels,
None,
calculate_contrastive_loss=False,
)
loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()}
losses.update(loss_dict)
return losses
def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
"""
Computes the total number of target masks across the batch, used to normalize the mask losses.
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
return num_masks_pt
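# Illustrative sketches, not part of the model: two minimal, self-contained examples of the less
# obvious pieces of `OneFormerLoss` above, using made-up shapes, values and helper names. The
# second sketch uses `torch.nn.functional.grid_sample` directly in place of the module's
# `sample_point` helper.
def _example_no_object_cross_entropy():
    import torch
    from torch import nn

    batch_size, num_queries, num_classes, eos_coef = 2, 4, 3, 0.1
    logits = torch.randn(batch_size, num_queries, num_classes + 1)
    # by default every query is assigned the null class (index num_classes) ...
    target_classes = torch.full((batch_size, num_queries), fill_value=num_classes, dtype=torch.long)
    # ... except the queries the Hungarian matcher paired with a ground-truth label
    target_classes[0, 1] = 2
    empty_weight = torch.ones(num_classes + 1)
    empty_weight[-1] = eos_coef  # down-weight the null class
    criterion = nn.CrossEntropyLoss(weight=empty_weight)
    # cross entropy expects (batch_size, num_labels, num_queries) against (batch_size, num_queries)
    return criterion(logits.transpose(1, 2), target_classes)


def _example_point_sampled_mask_losses():
    import torch
    from torch import nn

    num_masks, height, width = 2, 16, 16
    num_points, oversample_ratio, importance_sample_ratio = 16, 3, 0.75
    pred_logits = torch.randn(num_masks, 1, height, width)
    target_masks = (torch.rand(num_masks, 1, height, width) > 0.5).float()
    # 1. oversample random points and keep the most uncertain ones (logits closest to 0)
    coords = torch.rand(num_masks, num_points * oversample_ratio, 2)  # in [0, 1] x [0, 1]
    grid = 2.0 * coords.unsqueeze(1) - 1.0  # grid_sample expects coordinates in [-1, 1]
    sampled_logits = nn.functional.grid_sample(pred_logits, grid, align_corners=False).squeeze(1).squeeze(1)
    uncertainty = -sampled_logits.abs()
    num_uncertain = int(importance_sample_ratio * num_points)
    idx = torch.topk(uncertainty, k=num_uncertain, dim=1).indices
    kept = torch.gather(coords, 1, idx.unsqueeze(-1).expand(-1, -1, 2))
    point_coords = torch.cat([kept, torch.rand(num_masks, num_points - num_uncertain, 2)], dim=1)
    # 2. evaluate the sigmoid cross-entropy and dice losses only on those points
    grid = 2.0 * point_coords.unsqueeze(1) - 1.0
    point_logits = nn.functional.grid_sample(pred_logits, grid, align_corners=False).squeeze(1).squeeze(1)
    point_labels = nn.functional.grid_sample(target_masks, grid, align_corners=False).squeeze(1).squeeze(1)
    loss_mask = nn.functional.binary_cross_entropy_with_logits(point_logits, point_labels)
    probs = point_logits.sigmoid()
    loss_dice = 1 - (2 * (probs * point_labels).sum(-1) + 1) / (probs.sum(-1) + point_labels.sum(-1) + 1)
    return loss_mask, loss_dice.sum() / num_masks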
@dataclass
class OneFormerTransformerDecoderOutput(BaseModelOutput):
"""
Base class for outputs of the Transformer decoder. This class adds attributes for class predictions, mask
predictions and contrastive logits to BaseModelOutput.
Args:
object_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
Queries representation for the region proposals.
contrastive_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
Queries representation for the contrastive loss.
prediction_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
Mask predictions from last layer of the transformer decoder.
prediction_class (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes+1)`):
Class predictions from last layer of the transformer decoder.
auxiliary_predictions (Tuple of Dict of `str, torch.FloatTensor`, *optional*):
Tuple of class and mask predictions from each layer of the transformer decoder.
"""
object_queries: torch.FloatTensor = None
contrastive_logits: Optional[torch.FloatTensor] = None
prediction_masks: torch.FloatTensor = None
prediction_class: torch.FloatTensor = None
auxiliary_predictions: Optional[Tuple[Dict[str, torch.FloatTensor]]] = None
@dataclass
# Copied from transformers.models.mask2former.modeling_mask2former.Mask2FormerPixelDecoderOutput with Mask2->One
class OneFormerPixelDecoderOutput(ModelOutput):
"""
OneFormer's pixel decoder module output, practically a Multi-Scale Deformable Attention based decoder. It returns
the mask features and the multiscale features.
Args:
multi_scale_features (`tuple(torch.FloatTensor)`):
Tuple of multi-scale features of scales [1/8, 1/16, 1/32] and shape `(batch_size, num_channels, height,
width)` from the Multi-Scale Deformable Attention based Pixel Decoder.
mask_features (`torch.FloatTensor`):
Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel Decoder
Layer.
attentions (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Attentions weights from pixel decoder. Returned when `output_attentions=True` is passed
or when `config.output_attentions=True`
"""
multi_scale_features: Tuple[torch.FloatTensor] = None
mask_features: torch.FloatTensor = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class OneFormerPixelLevelModuleOutput(ModelOutput):
"""
OneFormer's pixel level module output. It returns both the last and (optionally) the hidden states from the
`encoder` and `decoder`. By default, the `encoder` is a Swin/Dinat Backbone and the `decoder` is a Multi-Scale
Deformable Attention based decoder.
Args:
encoder_features (List of `(torch.FloatTensor)`):
List of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
called feature maps) of the model at the output of each stage.
decoder_features (List of `(torch.FloatTensor)`):
List of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
called feature maps) of the model at the output of each stage.
decoder_last_feature (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1/4 scale features from the last Pixel Decoder Layer.
"""
encoder_features: List[torch.FloatTensor] = None
decoder_features: List[torch.FloatTensor] = None
decoder_last_feature: torch.FloatTensor = None
@dataclass
class OneFormerModelOutput(ModelOutput):
"""
Class for outputs of [`OneFormerModel`]. This class returns all the needed hidden states to compute the logits.
Args:
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder
model at the output of each stage.
pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel
decoder model at the output of each stage.
transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the
transformer decoder at the output of each stage.
transformer_decoder_object_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
Output object queries from the last layer in the transformer decoder.
transformer_decoder_contrastive_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
Contrastive queries from the transformer decoder.
transformer_decoder_mask_predictions (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
Mask Predictions from the last layer in the transformer decoder.
transformer_decoder_class_predictions (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes+1)`):
Class Predictions from the last layer in the transformer decoder.
transformer_decoder_auxiliary_predictions (Tuple of Dict of `str, torch.FloatTensor`, *optional*):
Tuple of class and mask predictions from each layer of the transformer decoder.
text_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`, *optional*):
Text queries derived from the input text list used for calculating contrastive loss during training.
task_token (`torch.FloatTensor` of shape `(batch_size, hidden_dim)`):
1D task token to condition the queries.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Self and Cross Attentions weights from transformer decoder.
"""
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
transformer_decoder_hidden_states: Optional[torch.FloatTensor] = None
transformer_decoder_object_queries: torch.FloatTensor = None
transformer_decoder_contrastive_queries: Optional[torch.FloatTensor] = None
transformer_decoder_mask_predictions: torch.FloatTensor = None
transformer_decoder_class_predictions: torch.FloatTensor = None
transformer_decoder_auxiliary_predictions: Optional[Tuple[Dict[str, torch.FloatTensor]]] = None
text_queries: Optional[torch.FloatTensor] = None
task_token: torch.FloatTensor = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class OneFormerForUniversalSegmentationOutput(ModelOutput):
"""
Class for outputs of [`OneFormerForUniversalSegmentation`].
This output can be directly passed to [`~OneFormerImageProcessor.post_process_semantic_segmentation`],
[`~OneFormerImageProcessor.post_process_instance_segmentation`] or
[`~OneFormerImageProcessor.post_process_panoptic_segmentation`] depending on the task. Please see
[`~OneFormerImageProcessor`] for details regarding usage.
Args:
loss (`torch.Tensor`, *optional*):
The computed loss, returned when labels are present.
class_queries_logits (`torch.FloatTensor`):
A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
query. Note the `+ 1` is needed because we incorporate the null class.
masks_queries_logits (`torch.FloatTensor`):
A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
query.
auxiliary_predictions (List of Dict of `str, torch.FloatTensor`, *optional*):
List of class and mask predictions from each layer of the transformer decoder.
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder
model at the output of each stage.
pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel
decoder model at the output of each stage.
transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the
transformer decoder at the output of each stage.
transformer_decoder_object_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
Output object queries from the last layer in the transformer decoder.
transformer_decoder_contrastive_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
Contrastive queries from the transformer decoder.
transformer_decoder_mask_predictions (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
Mask Predictions from the last layer in the transformer decoder.
transformer_decoder_class_predictions (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes+1)`):
Class Predictions from the last layer in the transformer decoder.
transformer_decoder_auxiliary_predictions (List of Dict of `str, torch.FloatTensor`, *optional*):
List of class and mask predictions from each layer of the transformer decoder.
text_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`, *optional*):
Text queries derived from the input text list used for calculating contrastive loss during training.
task_token (`torch.FloatTensor` of shape `(batch_size, hidden_dim)`):
1D task token to condition the queries.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Self and Cross Attentions weights from transformer decoder.
"""
loss: Optional[torch.FloatTensor] = None
class_queries_logits: torch.FloatTensor = None
masks_queries_logits: torch.FloatTensor = None
auxiliary_predictions: List[Dict[str, torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
pixel_decoder_hidden_states: Optional[List[torch.FloatTensor]] = None
transformer_decoder_hidden_states: Optional[torch.FloatTensor] = None
transformer_decoder_object_queries: torch.FloatTensor = None
transformer_decoder_contrastive_queries: Optional[torch.FloatTensor] = None
transformer_decoder_mask_predictions: torch.FloatTensor = None
transformer_decoder_class_predictions: torch.FloatTensor = None
transformer_decoder_auxiliary_predictions: Optional[List[Dict[str, torch.FloatTensor]]] = None
text_queries: Optional[torch.FloatTensor] = None
task_token: torch.FloatTensor = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
# Modified from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrFrozenBatchNorm2d with DeformableDetr->OneFormerPixelDecoder
class OneFormerPixelDecoderFrozenBatchNorm2d(nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
Copy-paste from torchvision.misc.ops with an added eps before rsqrt, without which models other than
torchvision.models.resnet[18,34,50,101] produce nans.
"""
def __init__(self, n):
super().__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def forward(self, x):
weight = self.weight.reshape(1, -1, 1, 1)
bias = self.bias.reshape(1, -1, 1, 1)
running_var = self.running_var.reshape(1, -1, 1, 1)
running_mean = self.running_mean.reshape(1, -1, 1, 1)
epsilon = 1e-5
scale = weight * (running_var + epsilon).rsqrt()
bias = bias - running_mean * scale
return x * scale + bias
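# Illustrative sketch, not part of the model: a frozen batch norm is just a fixed affine transform,
# so with default buffers it matches `nn.BatchNorm2d` in eval mode up to the eps used above. The
# helper name and sizes are made up.
def _example_frozen_batchnorm_equivalence():
    import torch
    from torch import nn

    num_channels = 4
    frozen = OneFormerPixelDecoderFrozenBatchNorm2d(num_channels)
    reference = nn.BatchNorm2d(num_channels, eps=1e-5).eval()
    pixel_features = torch.randn(2, num_channels, 8, 8)
    with torch.no_grad():
        return torch.allclose(frozen(pixel_features), reference(pixel_features), atol=1e-5)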
# Modified from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->OneFormerPixelDecoderEncoder
class OneFormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module):
"""
Multiscale deformable attention as proposed in Deformable DETR.
"""
def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):
super().__init__()
if embed_dim % num_heads != 0:
raise ValueError(
f"embed_dim (d_model) must be divisible by num_heads, but got {embed_dim} and {num_heads}"
)
dim_per_head = embed_dim // num_heads
# check if dim_per_head is power of 2
if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
warnings.warn(
"You'd better set embed_dim (d_model) in DeformableDetrMultiscaleDeformableAttention to make the"
" dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
" implementation."
)
self.im2col_step = 128
self.d_model = embed_dim
self.n_levels = n_levels
self.n_heads = num_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(embed_dim, num_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(embed_dim, num_heads * n_levels * n_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim)
def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
return tensor if position_embeddings is None else tensor + position_embeddings
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None,
encoder_attention_mask=None,
position_embeddings: Optional[torch.Tensor] = None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
output_attentions: bool = False,
):
# add position embeddings to the hidden states before projecting to queries and keys
if position_embeddings is not None:
hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
batch_size, num_queries, _ = hidden_states.shape
batch_size, sequence_length, _ = encoder_hidden_states.shape
if (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() != sequence_length:
raise ValueError(
"Make sure to align the spatial shapes with the sequence length of the encoder hidden states"
)
value = self.value_proj(encoder_hidden_states)
if attention_mask is not None:
# we invert the attention_mask
value = value.masked_fill(attention_mask[..., None], float(0))
value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(hidden_states).view(
batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
)
attention_weights = self.attention_weights(hidden_states).view(
batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
)
attention_weights = nn.functional.softmax(attention_weights, -1).view(
batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
)
# batch_size, num_queries, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = (
reference_points[:, :, None, :, None, :]
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
)
elif reference_points.shape[-1] == 4:
sampling_locations = (
reference_points[:, :, None, :, None, :2]
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
)
else:
raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
# pure PyTorch (CPU-compatible) implementation of multi-scale deformable attention
output = multiscale_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
output = self.output_proj(output)
return output, attention_weights
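# Illustrative sketch, not part of the model: how pixel-space sampling offsets are combined with
# normalized 2D reference points in the deformable attention above. Only the 2-coordinate branch
# is shown, and the helper name and sizes are made up.
def _example_deformable_sampling_locations():
    import torch

    batch_size, num_queries, num_heads, num_levels, num_points = 1, 2, 2, 3, 4
    spatial_shapes = torch.tensor([[8, 8], [4, 4], [2, 2]])  # (height, width) of each feature level
    reference_points = torch.rand(batch_size, num_queries, num_levels, 2)  # normalized (x, y)
    sampling_offsets = torch.randn(batch_size, num_queries, num_heads, num_levels, num_points, 2)
    # offsets are predicted in pixels, so normalize them by (width, height) of the matching level
    offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
    sampling_locations = (
        reference_points[:, :, None, :, None, :]
        + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
    )
    return sampling_locations  # (batch_size, num_queries, num_heads, num_levels, num_points, 2)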
class OneFormerPixelDecoderEncoderLayer(nn.Module):
def __init__(self, config: OneFormerConfig):
super().__init__()
self.embed_dim = config.conv_dim
self.self_attn = OneFormerPixelDecoderEncoderMultiscaleDeformableAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
n_levels=3,
n_points=4,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = nn.functional.relu
self.activation_dropout = config.dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_feedforward_dim)
self.fc2 = nn.Linear(config.encoder_feedforward_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
self.is_training = config.is_training
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor = None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
output_attentions: bool = False,
):
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Input to the layer.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Attention mask.
position_embeddings (`torch.FloatTensor`, *optional*):
Position embeddings, to be added to `hidden_states`.
reference_points (`torch.FloatTensor`, *optional*):
Reference points.
spatial_shapes (`torch.LongTensor`, *optional*):
Spatial shapes of the backbone feature maps.
level_start_index (`torch.LongTensor`, *optional*):
Level start index.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
# Apply Multi-scale Deformable Attention Module on the multi-scale feature maps.
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.is_training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.is_training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.is_training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
if self.is_training:
if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
# Modified from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrEncoder with DeformableDetrEncoder->OneFormerPixelDecoderEncoderOnly
class OneFormerPixelDecoderEncoderOnly(nn.Module):
"""
Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a
[`OneFormerPixelDecoderEncoderLayer`].
The encoder updates the flattened multi-scale feature maps through multiple deformable attention layers.
Args:
config: OneFormerConfig
"""
def __init__(self, config: OneFormerConfig):
super().__init__()
self.config = config
self.dropout = config.dropout
self.layers = nn.ModuleList([OneFormerPixelDecoderEncoderLayer(config) for _ in range(config.encoder_layers)])
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
"""
Get reference points for each feature map. Used in the encoder.
Args:
spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
Spatial shapes of each feature map.
valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
Valid ratios of each feature map.
device (`torch.device`):
Device on which to create the tensors.
Returns:
`torch.FloatTensor` of shape `(batch_size, num_queries, num_feature_levels, 2)`
"""
reference_points_list = []
for lvl, (height, width) in enumerate(spatial_shapes):
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device),
)
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * height)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * width)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def forward(
self,
inputs_embeds=None,
attention_mask=None,
position_embeddings=None,
spatial_shapes=None,
level_start_index=None,
valid_ratios=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Flattened feature map (output of the backbone + projection layer) that is passed to the encoder.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`:
- 1 for pixel features that are real (i.e. **not masked**),
- 0 for pixel features that are padding (i.e. **masked**).
[What are attention masks?](../glossary#attention-mask)
position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Position embeddings that are added to the queries and keys in each self-attention layer.
spatial_shapes (`torch.LongTensor` of shape `(num_feature_levels, 2)`):
Spatial shapes of each feature map.
level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`):
Starting index of each feature map.
valid_ratios (`torch.FloatTensor` of shape `(batch_size, num_feature_levels, 2)`):
Ratio of valid area in each feature level.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = inputs_embeds
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=inputs_embeds.device)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
position_embeddings=position_embeddings,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
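# Illustrative sketch, not part of the model: `get_reference_points` above returns the normalized
# center of every location of every feature map (before scaling by the valid ratios). The helper
# name and spatial shapes below are made up.
def _example_reference_points():
    import torch

    spatial_shapes = [(4, 6), (2, 3)]  # (height, width) per feature level
    reference_points = []
    for height, width in spatial_shapes:
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, height - 0.5, height),
            torch.linspace(0.5, width - 0.5, width),
        )
        reference_points.append(torch.stack((ref_x.reshape(-1) / width, ref_y.reshape(-1) / height), -1))
    return torch.cat(reference_points, 0)  # (sum of height * width over levels, 2), values in (0, 1)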
# Modified from transformers.models.mask2former.modeling_mask2former.Mask2FormerPixelDecoder with Mask2->One
class OneFormerPixelDecoder(nn.Module):
def __init__(self, config: OneFormerConfig, feature_channels):
super().__init__()
self.config = config
# positional encoding
self.position_embedding = OneFormerSinePositionEmbedding(num_pos_feats=config.conv_dim // 2, normalize=True)
self.num_feature_levels = 3
transformer_in_channels = feature_channels[-self.num_feature_levels :]
self.transformer_feature_strides = config.strides[-self.num_feature_levels :]
self.feature_channels = feature_channels
self.level_embed = nn.Parameter(torch.Tensor(self.num_feature_levels, config.conv_dim))
# Create input projection layers
if self.num_feature_levels > 1:
input_projections_list = []
for in_channels in transformer_in_channels[::-1]:
input_projections_list.append(
nn.Sequential(
nn.Conv2d(in_channels, config.conv_dim, kernel_size=1),
nn.GroupNorm(32, config.conv_dim),
)
)
self.input_projections = nn.ModuleList(input_projections_list)
else:
self.input_projections = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(transformer_in_channels[-1], config.conv_dim, kernel_size=1),
nn.GroupNorm(32, config.conv_dim),
)
]
)
self.encoder = OneFormerPixelDecoderEncoderOnly(config)
self.mask_projection = nn.Conv2d(
config.conv_dim,
config.mask_dim,
kernel_size=1,
stride=1,
padding=0,
)
self.common_stride = config.common_stride
# extra fpn levels
stride = min(self.transformer_feature_strides)
self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride))
lateral_convs = []
output_convs = []
for idx, in_channels in enumerate(self.feature_channels[: self.num_fpn_levels]):
lateral_conv = nn.Sequential(
nn.Conv2d(
in_channels,
config.conv_dim,
kernel_size=1,
bias=False,
),
nn.GroupNorm(32, config.conv_dim),
)
output_conv = nn.Sequential(
nn.Conv2d(
config.conv_dim,
config.conv_dim,
kernel_size=3,
stride=1,
padding=1,
bias=False,
),
nn.GroupNorm(32, config.conv_dim),
nn.ReLU(),
)
self.add_module("adapter_{}".format(idx + 1), lateral_conv)
self.add_module("layer_{}".format(idx + 1), output_conv)
lateral_convs.append(lateral_conv)
output_convs.append(output_conv)
# Place convs into top-down order (from low to high resolution)
# to make the top-down computation in forward clearer.
self.lateral_convs = lateral_convs[::-1]
self.output_convs = output_convs[::-1]
def get_valid_ratio(self, mask):
"""Get the valid ratio of all feature maps."""
_, height, width = mask.shape
valid_height = torch.sum(~mask[:, :, 0], 1)
valid_width = torch.sum(~mask[:, 0, :], 1)
valid_ratio_height = valid_height.float() / height
valid_ratio_width = valid_width.float() / width
valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1)
return valid_ratio
def forward(
self,
features,
encoder_outputs=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# Apply a 1x1 convolution (input projection) to each feature level to reduce the channel dimension to d_model (256 by default)
sources = []
position_embeddings_list = []
for level, source in enumerate(features[::-1][: self.num_feature_levels]):
feats = source.float()
sources.append(self.input_projections[level](feats))
position_embeddings_list.append(self.position_embedding(feats))
masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in sources]
# Prepare encoder inputs (by flattening)
source_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for level, (source, mask, pos_embed) in enumerate(zip(sources, masks, position_embeddings_list)):
batch_size, num_channels, height, width = source.shape
spatial_shape = (height, width)
spatial_shapes.append(spatial_shape)
source = source.flatten(2).transpose(1, 2)
mask = mask.flatten(1)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
lvl_pos_embed = pos_embed + self.level_embed[level].view(1, 1, -1)
lvl_pos_embed_flatten.append(lvl_pos_embed)
source_flatten.append(source)
mask_flatten.append(mask)
source_flatten = torch.cat(source_flatten, 1)
mask_flatten = torch.cat(mask_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
valid_ratios = valid_ratios.float()
# Send source_flatten + mask_flatten + lvl_pos_embed_flatten (backbone + proj layer output) through the encoder
# Also provide spatial_shapes, level_start_index and valid_ratios
if encoder_outputs is None:
encoder_outputs = self.encoder(
inputs_embeds=source_flatten,
attention_mask=mask_flatten,
position_embeddings=lvl_pos_embed_flatten,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
y = encoder_outputs.last_hidden_state
bs = y.shape[0]
split_size_or_sections = [None] * self.num_feature_levels
for i in range(self.num_feature_levels):
if i < self.num_feature_levels - 1:
split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i]
else:
split_size_or_sections[i] = y.shape[1] - level_start_index[i]
y = torch.split(y, split_size_or_sections, dim=1)
out = []
multi_scale_features = []
num_cur_levels = 0
for i, z in enumerate(y):
out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1]))
# append extra FPN levels to `out`
# Reverse feature maps into top-down order (from low to high resolution)
for idx, feats in enumerate(features[: self.num_fpn_levels][::-1]):
feats = feats.float()
lateral_conv = self.lateral_convs[idx]
output_conv = self.output_convs[idx]
cur_fpn = lateral_conv(feats)
# Following the FPN design, upsample the previous output and add it to the lateral features (bilinear upsampling is used here)
y = cur_fpn + nn.functional.interpolate(
out[-1], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False
)
y = output_conv(y)
out.append(y)
for o in out:
if num_cur_levels < self.num_feature_levels:
multi_scale_features.append(o)
num_cur_levels += 1
return OneFormerPixelDecoderOutput(
mask_features=self.mask_projection(out[-1]),
multi_scale_features=multi_scale_features,
attentions=encoder_outputs.attentions,
)
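# Illustrative sketch, not part of the model: the valid ratio computed by `get_valid_ratio` above is
# the fraction of a padded feature map that holds real pixels, derived from a boolean padding mask
# where `True` marks padding. The helper name and mask below are made up.
def _example_valid_ratio():
    import torch

    mask = torch.ones(1, 8, 10, dtype=torch.bool)  # start fully padded
    mask[:, :6, :5] = False  # the top-left 6 x 5 region holds real pixels
    valid_height = torch.sum(~mask[:, :, 0], 1)  # 6
    valid_width = torch.sum(~mask[:, 0, :], 1)  # 5
    return torch.stack([valid_width / 10, valid_height / 8], -1)  # tensor([[0.5000, 0.7500]])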
# Modified from transformers.models.mask2former.modeling_mask2former.Mask2FormerPixelLevelModule with Mask2->One
class OneFormerPixelLevelModule(nn.Module):
def __init__(self, config: OneFormerConfig):
"""
Pixel Level Module proposed in [Masked-attention Mask Transformer for Universal Image
Segmentation](https://arxiv.org/abs/2112.01527). It runs the input image through a backbone and a pixel
decoder, generating multi-scale feature maps and pixel embeddings.
Args:
config ([`OneFormerConfig`]):
The configuration used to instantiate this model.
"""
super().__init__()
backbone_config = config.backbone_config
self.encoder = AutoBackbone.from_config(backbone_config)
self.decoder = OneFormerPixelDecoder(config, feature_channels=self.encoder.channels)
def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> OneFormerPixelLevelModuleOutput:
features: List[Tensor] = self.encoder(pixel_values).feature_maps
decoder_output: OneFormerPixelDecoderOutput = self.decoder(features, output_hidden_states=output_hidden_states)
return OneFormerPixelLevelModuleOutput(
encoder_features=tuple(features),
decoder_features=decoder_output.multi_scale_features,
decoder_last_feature=decoder_output.mask_features,
)
# Modified from transformers.models.detr.modeling_detr.DetrAttention with Detr->OneFormer
class OneFormerAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper. Here, we add position embeddings to the queries and
keys (as explained in the DETR paper).
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
if self.head_dim * num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int):
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
return tensor if position_embeddings is None else tensor + position_embeddings
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
key_value_states: Optional[torch.Tensor] = None,
key_value_position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
hidden_states = hidden_states.permute(1, 0, 2) if hidden_states is not None else None
position_embeddings = position_embeddings.permute(1, 0, 2) if position_embeddings is not None else None
key_value_states = key_value_states.permute(1, 0, 2) if key_value_states is not None else None
key_value_position_embeddings = (
key_value_position_embeddings.permute(1, 0, 2) if key_value_position_embeddings is not None else None
)
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
batch_size, target_len, embed_dim = hidden_states.size()
# add position embeddings to the hidden states before projecting to queries and keys
if position_embeddings is not None:
hidden_states_original = hidden_states
hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
# add key-value position embeddings to the key value states
if key_value_position_embeddings is not None:
key_value_states_original = key_value_states
key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings)
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, batch_size)
value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, batch_size)
value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size)
proj_shape = (batch_size * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
source_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len):
raise ValueError(
f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size * self.num_heads, target_len, source_len):
raise ValueError(
f"Attention mask should be of size {(target_len, batch_size * self.num_heads, source_len)}, but is"
f" {attention_mask.size()}"
)
attn_weights += attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len)
attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, target_len, embed_dim)
attn_output = self.out_proj(attn_output).permute(1, 0, 2)
return attn_output, attn_weights_reshaped
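# Illustrative sketch, not part of the model: in this DETR-style attention the position embeddings
# are added to the inputs of the query and key projections only, never to the values. The helper
# name and sizes are made up.
def _example_pos_embed_on_queries_and_keys():
    import torch

    sequence_length, batch_size, embed_dim = 5, 2, 16
    hidden_states = torch.randn(sequence_length, batch_size, embed_dim)
    position_embeddings = torch.randn(sequence_length, batch_size, embed_dim)
    queries_and_keys_input = hidden_states + position_embeddings  # fed to q_proj and k_proj
    values_input = hidden_states  # v_proj sees the original hidden states
    return queries_and_keys_input.shape == values_input.shape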
class OneFormerTransformerDecoderSelfAttentionLayer(nn.Module):
def __init__(self, embed_dim, num_heads, dropout=0.0, activation="relu", normalize_before=False):
super().__init__()
self.self_attn = OneFormerAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, is_decoder=True)
self.norm = nn.LayerNorm(embed_dim)
self.dropout = nn.Dropout(dropout)
self.activation = ACT2FN[activation]
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(
self,
output,
output_mask: Optional[Tensor] = None,
output_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
output2, attention_weights = self.self_attn(
hidden_states=output, position_embeddings=query_pos, attention_mask=output_mask, output_attentions=True
)
output = output + self.dropout(output2)
output = self.norm(output)
return output, attention_weights
def forward_pre(
self,
output,
output_mask: Optional[Tensor] = None,
output_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
output2 = self.norm(output)
output2, attention_weights = self.self_attn(
hidden_states=output2, position_embeddings=query_pos, attention_mask=output_mask, output_attentions=True
)
output = output + self.dropout(output2)
return output, attention_weights
def forward(
self,
output,
output_mask: Optional[Tensor] = None,
output_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
if self.normalize_before:
return self.forward_pre(output, output_mask, output_key_padding_mask, query_pos)
return self.forward_post(output, output_mask, output_key_padding_mask, query_pos)
class OneFormerTransformerDecoderCrossAttentionLayer(nn.Module):
def __init__(self, embed_dim, num_heads, dropout=0.0, activation="relu", normalize_before=False):
super().__init__()
self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
self.norm = nn.LayerNorm(embed_dim)
self.dropout = nn.Dropout(dropout)
self.activation = ACT2FN[activation]
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(
self,
output,
memory,
memory_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
output2, attention_weights = self.multihead_attn(
query=self.with_pos_embed(output, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)
output = output + self.dropout(output2)
output = self.norm(output)
return output, attention_weights
def forward_pre(
self,
output,
memory,
memory_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
output2 = self.norm(output)
output2, attention_weights = self.multihead_attn(
query=self.with_pos_embed(output2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)
output = output + self.dropout(output2)
return output, attention_weights
def forward(
self,
output,
memory,
memory_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
if self.normalize_before:
return self.forward_pre(output, memory, memory_mask, memory_key_padding_mask, pos, query_pos)
return self.forward_post(output, memory, memory_mask, memory_key_padding_mask, pos, query_pos)
class OneFormerTransformerDecoderFFNLayer(nn.Module):
def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, activation="relu", normalize_before=False):
super().__init__()
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm = nn.LayerNorm(d_model)
self.activation = ACT2FN[activation]
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(self, output):
output2 = self.linear2(self.dropout(self.activation(self.linear1(output))))
output = output + self.dropout(output2)
output = self.norm(output)
return output
def forward_pre(self, output):
output2 = self.norm(output)
output2 = self.linear2(self.dropout(self.activation(self.linear1(output2))))
output = output + self.dropout(output2)
return output
def forward(self, output):
if self.normalize_before:
return self.forward_pre(output)
return self.forward_post(output)
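# Illustrative sketch, not part of the model: the only difference between the `forward_pre` and
# `forward_post` paths in the decoder layers above is whether LayerNorm is applied before the
# sub-layer or after the residual addition. The helper name and sizes are made up.
def _example_pre_vs_post_norm():
    import torch
    from torch import nn

    d_model = 8
    norm = nn.LayerNorm(d_model)
    sublayer = nn.Linear(d_model, d_model)  # stand-in for the attention / feed-forward sub-layer
    hidden = torch.randn(2, d_model)
    post_norm = norm(hidden + sublayer(hidden))  # post-norm: normalize after the residual addition
    pre_norm = hidden + sublayer(norm(hidden))  # pre-norm: normalize the sub-layer input
    return post_norm.shape == pre_norm.shape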
class OneFormerMLPPredictionHead(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3):
"""
A classic Multi Layer Perceptron (MLP).
Args:
input_dim (`int`):
The input dimensions.
hidden_dim (`int`):
The hidden dimensions.
output_dim (`int`):
The output dimensions.
num_layers (int, *optional*, defaults to 3):
The number of layers.
"""
super().__init__()
in_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
out_dims = [hidden_dim] * (num_layers - 1) + [output_dim]
layers = []
for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
layers.append(
PredictionBlock(in_dim, out_dim, activation=nn.ReLU() if i < num_layers - 1 else nn.Identity())
)
self.layers = nn.Sequential(*layers)
def forward(self, input: Tensor) -> Tensor:
return self.layers(input)
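# Illustrative sketch, not part of the model: the layer sizes produced by the MLP head above,
# written with plain `nn.Linear`/`nn.ReLU` layers instead of `PredictionBlock`. The helper name
# and dimensions are made up.
def _example_mlp_prediction_head():
    import torch
    from torch import nn

    input_dim, hidden_dim, output_dim, num_layers = 256, 256, 128, 3
    in_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
    out_dims = [hidden_dim] * (num_layers - 1) + [output_dim]
    layers = []
    for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
        layers.append(nn.Linear(in_dim, out_dim))
        if i < num_layers - 1:
            layers.append(nn.ReLU())  # every layer but the last is followed by a ReLU
    mlp = nn.Sequential(*layers)
    return mlp(torch.randn(2, 100, input_dim)).shape  # torch.Size([2, 100, 128])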
# refactored from original implementation
class OneFormerTransformerDecoderLayer(nn.Module):
def __init__(self, config: OneFormerConfig):
super().__init__()
self.embed_dim = config.hidden_dim
self.num_feature_levels = 3
self.cross_attn = OneFormerTransformerDecoderCrossAttentionLayer(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=0.0,
normalize_before=config.pre_norm,
)
self.self_attn = OneFormerTransformerDecoderSelfAttentionLayer(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=0.0,
normalize_before=config.pre_norm,
)
self.ffn = OneFormerTransformerDecoderFFNLayer(
d_model=self.embed_dim,
dim_feedforward=config.dim_feedforward,
dropout=0.0,
normalize_before=config.pre_norm,
)
def forward(
self,
index: int,
output: torch.Tensor,
multi_stage_features: List[torch.Tensor],
multi_stage_positional_embeddings: List[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
query_embeddings: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
):
"""
Args:
index (`int`): index of the layer in the Transformer decoder.
output (`torch.FloatTensor`): the object queries of shape `(N, batch, hidden_dim)`
multi_stage_features (`List[torch.Tensor]`): the multi-scale features from the pixel decoder.
multi_stage_positional_embeddings (`List[torch.Tensor]`):
positional embeddings for the multi_stage_features
attention_mask (`torch.FloatTensor`): attention mask for the masked cross attention layer
query_embeddings (`torch.FloatTensor`, *optional*):
position embeddings that are added to the queries and keys in the self-attention layer.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
level_index = index % self.num_feature_levels
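# if a query's mask blocks every position (all True along the last dimension), re-enable attention for
# that query; an all-blocked row would otherwise turn into an all-`-inf` softmax in the cross attention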
attention_mask[torch.where(attention_mask.sum(-1) == attention_mask.shape[-1])] = False
# Masked Cross Attention
output, cross_attn_weights = self.cross_attn(
output,
multi_stage_features[level_index],
memory_mask=attention_mask,
memory_key_padding_mask=None, # here we do not apply masking on padded region
pos=multi_stage_positional_embeddings[level_index],
query_pos=query_embeddings,
)
# Self Attention
output, self_attn_weights = self.self_attn(
output,
output_mask=None,
output_key_padding_mask=None,
query_pos=query_embeddings,
)
# Fully Connected
output = self.ffn(output)
outputs = (output,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
return outputs
class OneFormerTransformerDecoderQueryTransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
def forward(
self,
output,
memory,
output_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
output_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
intermediate = []
for layer in self.layers:
output = layer(
output,
memory,
output_mask=output_mask,
memory_mask=memory_mask,
output_key_padding_mask=output_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos,
query_pos=query_pos,
)
if self.return_intermediate:
intermediate.append(self.norm(output))
if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
intermediate.pop()
intermediate.append(output)
if self.return_intermediate:
return torch.stack(intermediate)
return output.unsqueeze(0)
class OneFormerTransformerDecoderQueryTransformerDecoderLayer(nn.Module):
def __init__(
self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = ACT2FN[activation]
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(
self,
output,
memory,
output_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
output_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
q = k = self.with_pos_embed(output, query_pos)
output2 = self.self_attn(q, k, value=output, attn_mask=output_mask, key_padding_mask=output_key_padding_mask)
output2 = output2[0]
output = output + self.dropout1(output2)
output = self.norm1(output)
output2 = self.multihead_attn(
query=self.with_pos_embed(output, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)
output2 = output2[0]
output = output + self.dropout2(output2)
output = self.norm2(output)
output2 = self.linear2(self.dropout(self.activation(self.linear1(output))))
output = output + self.dropout3(output2)
output = self.norm3(output)
return output
def forward_pre(
self,
output,
memory,
output_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
output_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
output2 = self.norm1(output)
q = k = self.with_pos_embed(output2, query_pos)
output2 = self.self_attn(q, k, value=output2, attn_mask=output_mask, key_padding_mask=output_key_padding_mask)
output2 = output2[0]
output = output + self.dropout1(output2)
output2 = self.norm2(output)
output2 = self.multihead_attn(
query=self.with_pos_embed(output2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)
output2 = output2[0]
output = output + self.dropout2(output2)
output2 = self.norm3(output)
output2 = self.linear2(self.dropout(self.activation(self.linear1(output2))))
output = output + self.dropout3(output2)
return output
def forward(
self,
output,
memory,
output_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
output_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
if self.normalize_before:
return self.forward_pre(
output,
memory,
output_mask,
memory_mask,
output_key_padding_mask,
memory_key_padding_mask,
pos,
query_pos,
)
return self.forward_post(
output,
memory,
output_mask,
memory_mask,
output_key_padding_mask,
memory_key_padding_mask,
pos,
query_pos,
)
class OneFormerTransformerDecoderQueryTransformer(nn.Module):
def __init__(
self,
d_model=512,
nhead=8,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
return_intermediate_dec=False,
):
super().__init__()
decoder_layer = OneFormerTransformerDecoderQueryTransformerDecoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = OneFormerTransformerDecoderQueryTransformerDecoder(
decoder_layer,
num_decoder_layers,
decoder_norm,
return_intermediate=return_intermediate_dec,
)
self.d_model = d_model
self.nhead = nhead
def forward(self, src, mask, query_embed, pos_embed, task_token=None):
batch_size = src.shape[0]
src = src.flatten(2).permute(2, 0, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, batch_size, 1)
if mask is not None:
mask = mask.flatten(1)
if task_token is None:
queries = torch.zeros_like(query_embed)
else:
queries = task_token.repeat(query_embed.shape[0], 1, 1)
queries = self.decoder(queries, src, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
return queries.transpose(1, 2)
class OneFormerTransformerDecoder(nn.Module):
"""
Transformer decoder
"""
def __init__(self, in_channels: int, config: OneFormerConfig):
super().__init__()
self.config = config
self.dropout = config.dropout
self.num_heads = config.num_attention_heads
self.is_training = config.is_training
self.use_task_norm = config.use_task_norm
self.use_auxiliary_loss = config.use_auxiliary_loss
self.query_transformer = OneFormerTransformerDecoderQueryTransformer(
d_model=config.hidden_dim,
dropout=config.dropout,
nhead=config.num_attention_heads,
dim_feedforward=config.dim_feedforward,
num_decoder_layers=config.query_dec_layers,
normalize_before=config.pre_norm,
return_intermediate_dec=False,
)
self.decoder_norm = nn.LayerNorm(config.hidden_dim)
self.num_feature_levels = 3
self.layers = nn.ModuleList(
[OneFormerTransformerDecoderLayer(config) for _ in range(config.decoder_layers - 1)]
)
self.query_input_projection = nn.Conv2d(in_channels, config.hidden_dim, kernel_size=1)
self.class_embed = nn.Linear(config.hidden_dim, config.num_labels + 1)
self.mask_embed = OneFormerMLPPredictionHead(
config.hidden_dim,
config.hidden_dim,
config.mask_dim,
3,
)
def forward(
self,
task_token=None,
multi_stage_features=None,
multi_stage_positional_embeddings=None,
mask_features=None,
query_features=None,
query_embeddings=None,
query_embedder=None,
size_list=None,
output_attentions=None,
):
if self.use_task_norm:
task_token = self.decoder_norm(task_token)
object_queries = self.query_transformer(
query_features,
None,
query_embedder.weight[:-1],
self.query_input_projection(mask_features),
task_token if self.use_task_norm else None,
)
object_queries = object_queries[0].permute(1, 0, 2)
queries = torch.cat([object_queries, task_token], dim=0)
output = queries.clone()
intermediate_class_predictions = []
intermediate_mask_predictions = []
# prediction heads on learnable query features
outputs_class, outputs_mask, attention_mask = self.forward_prediction_heads(
output, mask_features, attention_mask_target_size=size_list[0]
)
intermediate_class_predictions.append(outputs_class)
intermediate_mask_predictions.append(outputs_mask)
attentions = ()
for index, layer in enumerate(self.layers):
layer_outputs = layer(
index=index,
output=output,
multi_stage_features=multi_stage_features,
multi_stage_positional_embeddings=multi_stage_positional_embeddings,
attention_mask=attention_mask,
query_embeddings=query_embeddings,
output_attentions=output_attentions,
)
output = layer_outputs[0]
attentions += (layer_outputs[1:],)
outputs_class, outputs_mask, attention_mask = self.forward_prediction_heads(
output, mask_features, attention_mask_target_size=size_list[(index + 1) % self.num_feature_levels]
)
intermediate_class_predictions.append(outputs_class)
intermediate_mask_predictions.append(outputs_mask)
if not len(intermediate_mask_predictions) == len(self.layers) + 1:
raise ValueError(
"Intermediate predictions in the transformer decoder must have the same number of elements as number"
" of layers"
)
object_queries = layer_outputs[0].permute(1, 0, 2)
contrastive_logits = queries.permute(1, 0, 2)
return OneFormerTransformerDecoderOutput(
object_queries=object_queries,
contrastive_logits=contrastive_logits,
prediction_masks=intermediate_mask_predictions[-1],
prediction_class=intermediate_class_predictions[-1],
auxiliary_predictions=self._get_aux_predictions(
intermediate_class_predictions, intermediate_mask_predictions
)
if self.use_auxiliary_loss
else None,
attentions=attentions,
)
def forward_prediction_heads(self, output, mask_features, attention_mask_target_size):
decoder_output = self.decoder_norm(output)
decoder_output = decoder_output.transpose(0, 1)
outputs_class = self.class_embed(decoder_output)
mask_embed = self.mask_embed(decoder_output)
outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features)
attention_mask = nn.functional.interpolate(
outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False
)
# must use bool type
# If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged.
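# i.e. a query may only attend to locations where its predicted mask probability is >= 0.5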
attention_mask = (
attention_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5
).bool()
attention_mask = attention_mask.detach()
return outputs_class, outputs_mask, attention_mask
@torch.jit.unused
def _get_aux_predictions(self, outputs_class, outputs_seg_masks):
# this is a workaround to make torchscript happy, as torchscript
# doesn't support dictionary with non-homogeneous values, such
# as a dict having both a Tensor and a list.
aux_list = [
{"class_queries_logits": a, "masks_queries_logits": b}
for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
]
return tuple(aux_list)
class OneFormerTransformerModule(nn.Module):
"""
The OneFormer's transformer module.
"""
def __init__(self, in_features: int, config: OneFormerConfig):
super().__init__()
hidden_dim = config.hidden_dim
self.num_feature_levels = 3
self.position_embedder = OneFormerSinePositionEmbedding(num_pos_feats=hidden_dim // 2, normalize=True)
self.queries_embedder = nn.Embedding(config.num_queries, hidden_dim)
self.input_projections = []
for _ in range(self.num_feature_levels):
if in_features != hidden_dim or config.enforce_input_proj:
self.input_projections.append(nn.Conv2d(in_features, hidden_dim, kernel_size=1))
else:
self.input_projections.append(nn.Sequential())
self.decoder = OneFormerTransformerDecoder(in_channels=in_features, config=config)
self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
def forward(
self,
multi_scale_features: List[Tensor],
mask_features: Tensor,
task_token: Tensor,
output_attentions: bool = False,
) -> OneFormerTransformerDecoderOutput:
if not len(multi_scale_features) == self.num_feature_levels:
raise ValueError(
f"Number of elements in multi_scale_features ({len(multi_scale_features)}) and num_feature_levels"
f" ({self.num_feature_levels}) do not match!"
)
multi_stage_features = []
multi_stage_positional_embeddings = []
size_list = []
for i in range(self.num_feature_levels):
size_list.append(multi_scale_features[i].shape[-2:])
multi_stage_positional_embeddings.append(self.position_embedder(multi_scale_features[i], None).flatten(2))
multi_stage_features.append(
self.input_projections[i](multi_scale_features[i]).flatten(2)
+ self.level_embed.weight[i][None, :, None]
)
# flatten NxCxHxW to HWxNxC
multi_stage_positional_embeddings[-1] = multi_stage_positional_embeddings[-1].permute(2, 0, 1)
multi_stage_features[-1] = multi_stage_features[-1].permute(2, 0, 1)
_, batch_size, _ = multi_stage_features[0].shape
# QxNxC
query_embeddings = self.queries_embedder.weight.unsqueeze(1).repeat(1, batch_size, 1)
task_token = task_token.unsqueeze(0)
query_features = self.position_embedder(mask_features, None)
return self.decoder(
task_token=task_token,
multi_stage_features=multi_stage_features,
multi_stage_positional_embeddings=multi_stage_positional_embeddings,
mask_features=mask_features,
query_features=query_features,
query_embeddings=query_embeddings,
query_embedder=self.queries_embedder,
size_list=size_list,
output_attentions=output_attentions,
)
# Copied from transformers.models.maskformer.modeling_maskformer.MaskFormerSinePositionEmbedding with Mask->One
class OneFormerSinePositionEmbedding(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
need paper, generalized to work on images.
"""
def __init__(
self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None
):
super().__init__()
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
self.scale = 2 * math.pi if scale is None else scale
def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
if mask is None:
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
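# Shape note (not part of the original file): for an input of shape (batch, channels, height, width) the
# embedding returned above has shape (batch, 2 * num_pos_feats, height, width); with num_pos_feats set to
# hidden_dim // 2 in OneFormerTransformerModule, the channel count matches hidden_dim.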
# Copied from transformers.models.maskformer.modeling_maskformer.PredictionBlock
class PredictionBlock(nn.Module):
def __init__(self, in_dim: int, out_dim: int, activation: nn.Module) -> None:
super().__init__()
self.layers = [nn.Linear(in_dim, out_dim), activation]
# Maintain submodule indexing as if part of a Sequential block
for i, layer in enumerate(self.layers):
self.add_module(str(i), layer)
def forward(self, input: Tensor) -> Tensor:
hidden_state = input
for layer in self.layers:
hidden_state = layer(hidden_state)
return hidden_state
class OneFormerTextMapperAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE: the scale factor was wrong in the original implementation; it can be set manually to stay compatible with previous weights
self.scale = qk_scale or head_dim**-0.5
self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, q, k, v):
batch_size, q_sequence_length, num_channels = q.shape
if not k.shape == v.shape:
raise ValueError(f"keys ({list(k.shape)}) and values ({list(v.shape)}) have different shapes!")
batch_size, k_sequence_length, num_channels = k.shape
q = self.q_proj(q).reshape(batch_size, q_sequence_length, self.num_heads, num_channels // self.num_heads)
k = self.k_proj(k).reshape(batch_size, k_sequence_length, self.num_heads, num_channels // self.num_heads)
v = self.v_proj(v).reshape(batch_size, k_sequence_length, self.num_heads, num_channels // self.num_heads)
attn = torch.einsum("bnkc,bmkc->bknm", q, k) * self.scale
attn = attn.softmax(dim=-1)
output = torch.einsum("bknm,bmkc->bnkc", attn, v).reshape(batch_size, q_sequence_length, num_channels)
output = self.proj(output)
output = self.proj_drop(output)
return output
class OneFormerTextTransformerDecoderLayer(nn.Module):
def __init__(
self,
d_model,
nhead,
dropout=0.1,
):
super().__init__()
self.self_attn = OneFormerTextMapperAttention(d_model, nhead, proj_drop=dropout)
self.cross_attn = OneFormerTextMapperAttention(d_model, nhead, proj_drop=dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.mlp = nn.Sequential(
nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model * 4, d_model)
)
def forward(self, hidden_state, mem):
q = k = v = self.norm1(hidden_state)
hidden_state = hidden_state + self.self_attn(q, k, v)
q = self.norm2(hidden_state)
hidden_state = hidden_state + self.cross_attn(q, mem, mem)
hidden_state = hidden_state + self.dropout(self.mlp(self.norm3(hidden_state)))
return hidden_state
class OneFormerTextContextDecoder(nn.Module):
def __init__(
self, transformer_width=256, transformer_heads=4, transformer_layers=6, visual_dim=1024, dropout=0.1, **kwargs
):
super().__init__()
self.memory_proj = nn.Sequential(
nn.LayerNorm(visual_dim),
nn.Linear(visual_dim, transformer_width),
nn.LayerNorm(transformer_width),
)
self.text_proj = nn.Sequential(
nn.LayerNorm(visual_dim),
nn.Linear(visual_dim, transformer_width),
)
self.decoder = nn.ModuleList(
[
OneFormerTextTransformerDecoderLayer(transformer_width, transformer_heads, dropout)
for _ in range(transformer_layers)
]
)
self.out_proj = nn.Sequential(nn.LayerNorm(transformer_width), nn.Linear(transformer_width, visual_dim))
def forward(self, text, visual):
visual = self.memory_proj(visual)
hidden_state = self.text_proj(text)
for layer in self.decoder:
hidden_state = layer(hidden_state, visual)
return self.out_proj(hidden_state)
class OneFormerTextMLP(nn.Module):
def __init__(
self,
hidden_size: Optional[int] = None,
intermediate_size: Optional[int] = None,
output_size: Optional[int] = None,
):
super().__init__()
self.activation_fn = ACT2FN["quick_gelu"]
self.fc1 = nn.Linear(hidden_size, intermediate_size)
self.fc2 = nn.Linear(intermediate_size, output_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class OneFormerTextTransformerLayer(nn.Module):
def __init__(self, width: int, heads: int, attn_mask: torch.Tensor):
super().__init__()
self.self_attn = nn.MultiheadAttention(width, heads)
self.layer_norm1 = nn.LayerNorm(width)
self.mlp = OneFormerTextMLP(width, width * 4, width)
self.layer_norm2 = nn.LayerNorm(width)
self.attn_mask = attn_mask
def forward(
self,
hidden_states: torch.Tensor,
key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(
hidden_states,
hidden_states,
hidden_states,
need_weights=False,
key_padding_mask=key_padding_mask,
)[0]
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class OneFormerTextTransformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_checkpoint=False):
super().__init__()
self.width = width
self.num_layers = layers
self.layers = nn.Sequential(*[OneFormerTextTransformerLayer(width, heads, attn_mask) for _ in range(layers)])
self.use_checkpoint = use_checkpoint
def forward(self, hidden_states: torch.Tensor):
for layer in self.layers:
if self.use_checkpoint:
hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states)
else:
hidden_states = layer(hidden_states)
return hidden_states
class OneFormerTextEncoder(nn.Module):
def __init__(
self,
context_length: int,
width: int,
layers: int,
vocab_size,
use_checkpoint=False,
):
super().__init__()
heads = width // 64
self.context_length = context_length
self.width = width
self.transformer = OneFormerTextTransformer(
width=width,
layers=layers,
heads=heads,
attn_mask=self.build_attention_mask(),
use_checkpoint=use_checkpoint,
)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
self.ln_final = nn.LayerNorm(width)
self.token_embedding = nn.Embedding(vocab_size, width)
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1)  # keep -inf strictly above the diagonal (zero out on and below it) so each token attends only to itself and earlier positions
return mask
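# Illustration (not part of the original file): with context_length = 4 the mask built above is
#   [[0., -inf, -inf, -inf],
#    [0.,   0., -inf, -inf],
#    [0.,   0.,   0., -inf],
#    [0.,   0.,   0.,   0.]]
# so token i can only attend to tokens 0..i.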
def forward(self, text):
hidden_state = self.token_embedding(text)
hidden_state = hidden_state + self.positional_embedding
hidden_state = hidden_state.permute(1, 0, 2)
hidden_state = self.transformer(hidden_state)
hidden_state = hidden_state.permute(1, 0, 2)
hidden_state = self.ln_final(hidden_state)
hidden_state = hidden_state[torch.arange(hidden_state.shape[0]), text.argmax(dim=-1)]
return hidden_state
class OneFormerTextMapper(nn.Module):
def __init__(self, config: OneFormerConfig):
super().__init__()
self.text_encoder = OneFormerTextEncoder(
context_length=config.text_encoder_context_length,
width=config.text_encoder_width,
layers=config.text_encoder_num_layers,
vocab_size=config.text_encoder_vocab_size,
)
self.text_projector = OneFormerMLPPredictionHead(
config.text_encoder_width,
config.hidden_dim,
config.hidden_dim,
config.text_encoder_proj_layers,
)
if config.text_encoder_n_ctx > 0:
self.prompt_ctx = nn.Embedding(
config.text_encoder_n_ctx,
config.text_encoder_width,
)
else:
self.prompt_ctx = None
def forward(
self,
inputs: Tensor,
) -> Tensor:
text_queries = self.encode_text(inputs)
return text_queries
def encode_text(self, text):
if text is None:
raise ValueError("text must not be NoneType")
if text.ndim not in [2, 3]:
raise ValueError("Number of dimensions in text must be 2 or 3")
squeeze_dim = False
num_text = 1
if text.ndim == 3:
num_text = text.shape[1]
batch_size, num_text, hidden_dim = text.shape
text = text.reshape(batch_size * num_text, hidden_dim)
squeeze_dim = True
# [batch_size, num_channels]
encoded_text = self.text_encoder(text)
text_queries = self.text_projector(encoded_text)
if squeeze_dim:
_, hidden_dim = text_queries.shape
text_queries = text_queries.reshape(batch_size, num_text, hidden_dim)
if self.prompt_ctx is not None:
text_queries_ctx = self.prompt_ctx.weight.unsqueeze(0).repeat(text_queries.shape[0], 1, 1)
text_queries = torch.cat([text_queries, text_queries_ctx], dim=1)
return text_queries
class OneFormerTaskModel(nn.Module):
def __init__(self, config: OneFormerConfig):
super().__init__()
self.task_mlp = OneFormerMLPPredictionHead(
config.task_seq_len,
config.hidden_dim,
config.hidden_dim,
2,
)
def forward(self, inputs: Tensor) -> Tensor:
task_tokens = self.task_mlp(inputs.float())
return task_tokens
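# Shape note (not part of the original file): `inputs` is the tokenized task text of shape
# (batch_size, task_seq_len); the two-layer MLP above maps it to a task token of shape
# (batch_size, hidden_dim).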
ONEFORMER_START_DOCSTRING = r"""
This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use it as a
regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
Parameters:
config ([`OneFormerConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
ONEFORMER_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`OneFormerProcessor`]. See
[`OneFormerProcessor.__call__`] for details.
task_inputs (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Task inputs. Task inputs can be obtained using [`OneFormerProcessor`]. See
[`OneFormerProcessor.__call__`] for details.
pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
[What are attention masks?](../glossary#attention-mask)
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of the transformer decoder's attention layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~OneFormerModelOutput`] instead of a plain tuple.
"""
class OneFormerPreTrainedModel(PreTrainedModel):
config_class = OneFormerConfig
base_model_prefix = "model"
main_input_name = "pixel_values"
def _init_weights(self, module: nn.Module):
xavier_std = self.config.init_xavier_std
std = self.config.init_std
if isinstance(module, OneFormerTransformerModule):
if module.input_projections is not None:
for input_projection in module.input_projections:
if not isinstance(input_projection, nn.Sequential):
nn.init.xavier_uniform_(input_projection.weight, gain=xavier_std)
nn.init.constant_(input_projection.bias, 0)
elif isinstance(module, OneFormerTransformerDecoder):
nn.init.xavier_uniform_(module.query_input_projection.weight, gain=xavier_std)
nn.init.constant_(module.query_input_projection.bias, 0)
elif isinstance(module, OneFormerPixelDecoderEncoderMultiscaleDeformableAttention):
nn.init.constant_(module.sampling_offsets.weight.data, 0.0)
thetas = torch.arange(module.n_heads, dtype=torch.float32) * (2.0 * math.pi / module.n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (
(grid_init / grid_init.abs().max(-1, keepdim=True)[0])
.view(module.n_heads, 1, 1, 2)
.repeat(1, module.n_levels, module.n_points, 1)
)
for i in range(module.n_points):
grid_init[:, :, i, :] *= i + 1
with torch.no_grad():
module.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
nn.init.constant_(module.attention_weights.weight.data, 0.0)
nn.init.constant_(module.attention_weights.bias.data, 0.0)
nn.init.xavier_uniform_(module.value_proj.weight.data)
nn.init.constant_(module.value_proj.bias.data, 0.0)
nn.init.xavier_uniform_(module.output_proj.weight.data)
nn.init.constant_(module.output_proj.bias.data, 0.0)
elif isinstance(module, OneFormerPixelDecoderEncoderOnly):
for p in module.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
elif isinstance(module, OneFormerPixelDecoder):
for p in module.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
nn.init.normal_(module.level_embed, std=0)
elif isinstance(module, OneFormerTransformerDecoderSelfAttentionLayer):
for p in module.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p, gain=xavier_std)
elif isinstance(module, OneFormerTransformerDecoderCrossAttentionLayer):
for p in module.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p, gain=xavier_std)
elif isinstance(module, OneFormerTransformerDecoderFFNLayer):
for p in module.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p, gain=xavier_std)
elif isinstance(module, OneFormerTransformerDecoderQueryTransformer):
for p in module.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p, gain=xavier_std)
elif isinstance(module, OneFormerPixelLevelModule):
for submodule in module.modules():
if isinstance(submodule, (nn.Conv2d, nn.Linear)):
submodule.weight.data.normal_(mean=0.0, std=std)
if submodule.bias is not None:
submodule.bias.data.zero_()
elif isinstance(module, OneFormerTextContextDecoder):
for submodule in module.modules():
if isinstance(submodule, nn.Linear):
nn.init.trunc_normal_(submodule.weight, std=0.02)
if isinstance(submodule, nn.Linear) and submodule.bias is not None:
nn.init.constant_(submodule.bias, 0)
elif isinstance(submodule, nn.LayerNorm):
nn.init.constant_(submodule.bias, 0)
nn.init.constant_(submodule.weight, 1.0)
elif isinstance(module, OneFormerTextTransformer):
proj_std = (module.width**-0.5) * ((2 * module.num_layers) ** -0.5)
attn_std = module.width**-0.5
fc_std = (2 * module.width) ** -0.5
for layer in module.layers:
nn.init.normal_(layer.self_attn.in_proj_weight, std=attn_std)
nn.init.normal_(layer.self_attn.out_proj.weight, std=proj_std)
nn.init.normal_(layer.mlp.fc1.weight, std=fc_std)
nn.init.normal_(layer.mlp.fc2.weight, std=proj_std)
elif isinstance(module, OneFormerTextEncoder):
nn.init.normal_(module.token_embedding.weight, std=0.02)
nn.init.normal_(module.positional_embedding, std=0.01)
if hasattr(module, "reference_points"):
nn.init.xavier_uniform_(module.reference_points.weight.data, gain=1.0)
nn.init.constant_(module.reference_points.bias.data, 0.0)
elif isinstance(module, OneFormerTaskModel):
for submodule in module.modules():
if isinstance(submodule, OneFormerMLPPredictionHead):
for mlp_submodule in submodule.modules():
if isinstance(mlp_submodule, nn.Linear):
nn.init.xavier_uniform_(mlp_submodule.weight, gain=xavier_std)
nn.init.constant_(mlp_submodule.bias, 0)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.MultiheadAttention):
module.in_proj_weight.data.normal_(mean=0.0, std=std)
module.in_proj_bias.data.zero_()
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@add_start_docstrings(
"The bare OneFormer Model outputting raw hidden-states without any specific head on top.",
ONEFORMER_START_DOCSTRING,
)
class OneFormerModel(OneFormerPreTrainedModel):
main_input_name = ["pixel_values", "task_inputs"]
def __init__(self, config: OneFormerConfig):
super().__init__(config)
self.pixel_level_module = OneFormerPixelLevelModule(config)
self.transformer_module = OneFormerTransformerModule(in_features=config.conv_dim, config=config)
self.task_encoder = OneFormerTaskModel(config)
self.is_training = config.is_training
if self.is_training:
self.text_mapper = OneFormerTextMapper(config)
else:
self.text_mapper = None
self.post_init()
@add_start_docstrings_to_model_forward(ONEFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=OneFormerModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Tensor,
task_inputs: Tensor,
text_inputs: Optional[Tensor] = None,
pixel_mask: Optional[Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> OneFormerModelOutput:
r"""
Returns:
`OneFormerModelOutput`
Example:
```python
>>> import torch
>>> from PIL import Image
>>> import requests
>>> from transformers import OneFormerProcessor, OneFormerModel
>>> # download an image for testing
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> # load processor for preprocessing the inputs
>>> processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
>>> model = OneFormerModel.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
>>> inputs = processor(image, ["semantic"], return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> mask_predictions = outputs.transformer_decoder_mask_predictions
>>> class_predictions = outputs.transformer_decoder_class_predictions
>>> f"👉 Mask Predictions Shape: {list(mask_predictions.shape)}, Class Predictions Shape: {list(class_predictions.shape)}"
'👉 Mask Predictions Shape: [1, 150, 128, 176], Class Predictions Shape: [1, 150, 151]'
```"""
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size, _, height, width = pixel_values.shape
if pixel_mask is None:
pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)
pixel_level_module_output = self.pixel_level_module(pixel_values, output_hidden_states)
multi_scale_features = pixel_level_module_output.decoder_features
mask_features = pixel_level_module_output.decoder_last_feature
task_token = self.task_encoder(task_inputs)
if self.is_training:
text_queries = self.text_mapper(text_inputs)
else:
text_queries = None
transformer_module_output = self.transformer_module(
multi_scale_features=multi_scale_features,
mask_features=mask_features,
task_token=task_token,
output_attentions=output_attentions,
)
queries = transformer_module_output.object_queries
encoder_hidden_states = None
pixel_decoder_hidden_states = None
transformer_decoder_hidden_states = None
if output_hidden_states:
encoder_hidden_states = pixel_level_module_output.encoder_features
pixel_decoder_hidden_states = (pixel_level_module_output.decoder_last_feature,)
for f in pixel_level_module_output.decoder_features:
pixel_decoder_hidden_states += (f,)
transformer_decoder_hidden_states = transformer_module_output.auxiliary_predictions
output = OneFormerModelOutput(
encoder_hidden_states=encoder_hidden_states,
pixel_decoder_hidden_states=pixel_decoder_hidden_states,
transformer_decoder_hidden_states=transformer_decoder_hidden_states,
transformer_decoder_object_queries=queries,
transformer_decoder_contrastive_queries=transformer_module_output.contrastive_logits,
transformer_decoder_mask_predictions=transformer_module_output.prediction_masks,
transformer_decoder_class_predictions=transformer_module_output.prediction_class,
transformer_decoder_auxiliary_predictions=transformer_module_output.auxiliary_predictions,
text_queries=text_queries,
task_token=task_token,
attentions=transformer_module_output.attentions,
)
if not return_dict:
output = tuple(v for v in output.values())
return output
@add_start_docstrings(
"OneFormer Model for instance, semantic and panoptic image segmentation.",
ONEFORMER_START_DOCSTRING,
)
class OneFormerForUniversalSegmentation(OneFormerPreTrainedModel):
main_input_name = ["pixel_values", "task_inputs"]
def __init__(self, config: OneFormerConfig):
super().__init__(config)
self.model = OneFormerModel(config)
self.matcher = OneFormerHungarianMatcher(
cost_class=config.class_weight,
cost_dice=config.dice_weight,
cost_mask=config.mask_weight,
num_points=config.train_num_points,
)
self.weight_dict: Dict[str, float] = {
"loss_cross_entropy": config.class_weight,
"loss_mask": config.mask_weight,
"loss_dice": config.dice_weight,
"loss_contrastive": config.contrastive_weight,
}
self.criterion = OneFormerLoss(
num_classes=config.num_labels,
matcher=self.matcher,
weight_dict=self.weight_dict,
eos_coef=config.no_object_weight,
num_points=config.train_num_points,
oversample_ratio=config.oversample_ratio,
importance_sample_ratio=config.importance_sample_ratio,
contrastive_temperature=config.contrastive_temperature,
)
self.post_init()
def get_loss_dict(
self,
masks_queries_logits: Tensor,
class_queries_logits: Tensor,
contrastive_queries_logits: Tensor,
mask_labels: Tensor,
class_labels: Tensor,
text_queries: Tensor,
auxiliary_predictions: Dict[str, Tensor],
calculate_contrastive_loss: bool,
) -> Dict[str, Tensor]:
loss_dict: Dict[str, Tensor] = self.criterion(
masks_queries_logits=masks_queries_logits,
class_queries_logits=class_queries_logits,
contrastive_queries_logits=contrastive_queries_logits,
mask_labels=mask_labels,
class_labels=class_labels,
text_queries=text_queries,
auxiliary_predictions=auxiliary_predictions,
calculate_contrastive_loss=calculate_contrastive_loss,
)
# weight each loss by `self.weight_dict[<LOSS_NAME>]` including auxiliary losses
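# the substring check (`key in loss_key`) also matches the auxiliary entries of each loss, so the
# auxiliary losses are scaled by the same weights as the main ones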
for key, weight in self.weight_dict.items():
for loss_key, loss in loss_dict.items():
if key in loss_key:
loss *= weight
return loss_dict
def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor:
return sum(loss_dict.values())
@add_start_docstrings_to_model_forward(ONEFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=OneFormerForUniversalSegmentationOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Tensor,
task_inputs: Tensor,
text_inputs: Optional[Tensor] = None,
mask_labels: Optional[List[Tensor]] = None,
class_labels: Optional[List[Tensor]] = None,
pixel_mask: Optional[Tensor] = None,
output_auxiliary_logits: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> OneFormerForUniversalSegmentationOutput:
r"""
text_inputs (`List[torch.Tensor]`, *optional*):
Tensor of shape `(num_queries, sequence_length)` to be fed to a model
mask_labels (`List[torch.Tensor]`, *optional*):
List of mask labels of shape `(num_labels, height, width)` to be fed to a model
class_labels (`List[torch.LongTensor]`, *optional*):
list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the
labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` is `class_labels[i][j]`.
Returns:
`OneFormerForUniversalSegmentationOutput`
Example:
Universal segmentation example:
```python
>>> from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
>>> from PIL import Image
>>> import requests
>>> import torch
>>> # load OneFormer fine-tuned on ADE20k for universal segmentation
>>> processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
>>> model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
>>> url = (
... "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
... )
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> # Semantic Segmentation
>>> inputs = processor(image, ["semantic"], return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> # model predicts class_queries_logits of shape `(batch_size, num_queries, num_labels + 1)`
>>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
>>> class_queries_logits = outputs.class_queries_logits
>>> masks_queries_logits = outputs.masks_queries_logits
>>> # you can pass them to processor for semantic postprocessing
>>> predicted_semantic_map = processor.post_process_semantic_segmentation(
... outputs, target_sizes=[image.size[::-1]]
... )[0]
>>> f"👉 Semantic Predictions Shape: {list(predicted_semantic_map.shape)}"
'👉 Semantic Predictions Shape: [512, 683]'
>>> # Instance Segmentation
>>> inputs = processor(image, ["instance"], return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> # model predicts class_queries_logits of shape `(batch_size, num_queries, num_labels + 1)`
>>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
>>> class_queries_logits = outputs.class_queries_logits
>>> masks_queries_logits = outputs.masks_queries_logits
>>> # you can pass them to processor for instance postprocessing
>>> predicted_instance_map = processor.post_process_instance_segmentation(
... outputs, target_sizes=[image.size[::-1]]
... )[0]["segmentation"]
>>> f"👉 Instance Predictions Shape: {list(predicted_instance_map.shape)}"
'👉 Instance Predictions Shape: [512, 683]'
>>> # Panoptic Segmentation
>>> inputs = processor(image, ["panoptic"], return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> # model predicts class_queries_logits of shape `(batch_size, num_queries, num_labels + 1)`
>>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
>>> class_queries_logits = outputs.class_queries_logits
>>> masks_queries_logits = outputs.masks_queries_logits
>>> # you can pass them to processor for panoptic postprocessing
>>> predicted_panoptic_map = processor.post_process_panoptic_segmentation(
... outputs, target_sizes=[image.size[::-1]]
... )[0]["segmentation"]
>>> f"👉 Panoptic Predictions Shape: {list(predicted_panoptic_map.shape)}"
'👉 Panoptic Predictions Shape: [512, 683]'
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
pixel_values=pixel_values,
task_inputs=task_inputs,
text_inputs=text_inputs,
pixel_mask=pixel_mask,
output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss,
output_attentions=output_attentions,
return_dict=True,
)
loss, loss_dict, auxiliary_predictions = None, None, None
class_queries_logits = outputs.transformer_decoder_class_predictions
masks_queries_logits = outputs.transformer_decoder_mask_predictions
contrastive_queries_logits = outputs.transformer_decoder_contrastive_queries
auxiliary_predictions = outputs.transformer_decoder_auxiliary_predictions
text_queries = outputs.text_queries
if mask_labels is not None and class_labels is not None:
loss_dict: Dict[str, Tensor] = self.get_loss_dict(
masks_queries_logits=masks_queries_logits,
class_queries_logits=class_queries_logits,
contrastive_queries_logits=contrastive_queries_logits,
mask_labels=mask_labels,
class_labels=class_labels,
text_queries=text_queries,
auxiliary_predictions=auxiliary_predictions,
calculate_contrastive_loss=self.config.contrastive_temperature is not None,
)
loss = self.get_loss(loss_dict)
output_auxiliary_logits = (
self.config.output_auxiliary_logits if output_auxiliary_logits is None else output_auxiliary_logits
)
if not output_auxiliary_logits:
auxiliary_predictions = None
output = OneFormerForUniversalSegmentationOutput(
class_queries_logits=class_queries_logits,
masks_queries_logits=masks_queries_logits,
auxiliary_predictions=auxiliary_predictions,
loss=loss,
**outputs,
)
if not return_dict:
output = tuple(v for v in output.values())
if loss is not None:
output = (loss,) + output
return output
# coding=utf-8
# Copyright 2022 SHI Labs and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image/Text processor class for OneFormer
"""
from typing import List
from transformers.utils import is_torch_available
from ...processing_utils import ProcessorMixin
if is_torch_available():
import torch
class OneFormerProcessor(ProcessorMixin):
r"""
Constructs a OneFormer processor which wraps [`OneFormerImageProcessor`] and
[`CLIPTokenizer`]/[`CLIPTokenizerFast`] into a single processor that inherits both the image processor and
tokenizer functionalities.
Args:
image_processor ([`OneFormerImageProcessor`]):
The image processor is a required input.
tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]):
The tokenizer is a required input.
max_seq_length (`int`, *optional*, defaults to 77):
Sequence length for input text list.
task_seq_length (`int`, *optional*, defaults to 77):
Sequence length for input task token.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "OneFormerImageProcessor"
tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
def __init__(
self, image_processor=None, tokenizer=None, max_seq_length: int = 77, task_seq_length: int = 77, **kwargs
):
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")
self.max_seq_length = max_seq_length
self.task_seq_length = task_seq_length
super().__init__(image_processor, tokenizer)
def _preprocess_text(self, text_list=None, max_length=77):
if text_list is None:
raise ValueError("tokens cannot be None.")
tokens = self.tokenizer(text_list, padding="max_length", max_length=max_length, truncation=True)
attention_masks, input_ids = tokens["attention_mask"], tokens["input_ids"]
token_inputs = []
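# multiplying the attention mask with the input ids zeroes out the ids at padding positions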
for attn_mask, input_id in zip(attention_masks, input_ids):
token = torch.tensor(attn_mask) * torch.tensor(input_id)
token_inputs.append(token.unsqueeze(0))
token_inputs = torch.cat(token_inputs, dim=0)
return token_inputs
def __call__(self, images=None, task_inputs=None, segmentation_maps=None, **kwargs):
"""
Main method to prepare for the model one or several task input(s) and image(s). This method forwards the
`task_inputs` and `kwargs` arguments to CLIPTokenizer's [`~CLIPTokenizer.__call__`] if `task_inputs` is not
`None` to encode. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
OneFormerImageProcessor's [`~OneFormerImageProcessor.__call__`] if `images` is not `None`. Please refer to the
docstring of the above two methods for more information.
Args:
task_inputs (`str`, `List[str]`):
The sequence or batch of task_inputs sequences to be encoded. Each sequence can be a string or a list
of strings of the template "the task is {task}".
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
`List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
segmentation_maps (`ImageInput`, *optional*):
The corresponding semantic segmentation maps with the pixel-wise annotations.
(`bool`, *optional*, defaults to `True`):
Whether or not to pad images up to the largest image in a batch and create a pixel mask.
If left to the default, will return a pixel mask that is:
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **task_inputs** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
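Example (a minimal usage sketch; the checkpoint name below mirrors the one used in the model docstrings):
```python
>>> from transformers import OneFormerProcessor
>>> from PIL import Image
>>> import requests
>>> processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> # `inputs` contains `pixel_values` and the tokenized `task_inputs`
>>> inputs = processor(image, ["semantic"], return_tensors="pt")
```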
"""
if task_inputs is None:
raise ValueError("You have to specify the task_input. Found None.")
elif images is None:
raise ValueError("You have to specify the image. Found None.")
if not all(task in ["semantic", "instance", "panoptic"] for task in task_inputs):
raise ValueError("task_inputs must be semantic, instance, or panoptic.")
encoded_inputs = self.image_processor(images, task_inputs, segmentation_maps, **kwargs)
if isinstance(task_inputs, str):
task_inputs = [task_inputs]
if isinstance(task_inputs, List) and all(isinstance(task_input, str) for task_input in task_inputs):
task_token_inputs = []
for task in task_inputs:
task_input = f"the task is {task}"
task_token_inputs.append(task_input)
encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length)
else:
raise TypeError("Task Inputs should be a string or a list of strings.")
if hasattr(encoded_inputs, "text_inputs"):
texts_list = encoded_inputs.text_inputs
text_inputs = []
for texts in texts_list:
text_input_list = self._preprocess_text(texts, max_length=self.max_seq_length)
text_inputs.append(text_input_list.unsqueeze(0))
encoded_inputs["text_inputs"] = torch.cat(text_inputs, dim=0)
return encoded_inputs
def encode_inputs(self, images=None, task_inputs=None, segmentation_maps=None, **kwargs):
"""
This method forwards all its arguments to [`OneFormerImageProcessor.encode_inputs`] and then tokenizes the
task_inputs. Please refer to the docstring of this method for more information.
"""
if task_inputs is None:
raise ValueError("You have to specify the task_input. Found None.")
elif images is None:
raise ValueError("You have to specify the image. Found None.")
if not all(task in ["semantic", "instance", "panoptic"] for task in task_inputs):
raise ValueError("task_inputs must be semantic, instance, or panoptic.")
encoded_inputs = self.image_processor.encode_inputs(images, task_inputs, segmentation_maps, **kwargs)
if isinstance(task_inputs, str):
task_inputs = [task_inputs]
if isinstance(task_inputs, List) and all(isinstance(task_input, str) for task_input in task_inputs):
task_token_inputs = []
for task in task_inputs:
task_input = f"the task is {task}"
task_token_inputs.append(task_input)
encoded_inputs["task_inputs"] = self._preprocess_text(task_token_inputs, max_length=self.task_seq_length)
else:
raise TypeError("Task Inputs should be a string or a list of strings.")
if hasattr(encoded_inputs, "text_inputs"):
texts_list = encoded_inputs.text_inputs
text_inputs = []
for texts in texts_list:
text_input_list = self._preprocess_text(texts, max_length=self.max_seq_length)
text_inputs.append(text_input_list.unsqueeze(0))
encoded_inputs["text_inputs"] = torch.cat(text_inputs, dim=0)
return encoded_inputs
def post_process_semantic_segmentation(self, *args, **kwargs):
"""
This method forwards all its arguments to [`OneFormerImageProcessor.post_process_semantic_segmentation`].
Please refer to the docstring of this method for more information.
"""
return self.image_processor.post_process_semantic_segmentation(*args, **kwargs)
def post_process_instance_segmentation(self, *args, **kwargs):
"""
This method forwards all its arguments to [`OneFormerImageProcessor.post_process_instance_segmentation`].
Please refer to the docstring of this method for more information.
"""
return self.image_processor.post_process_instance_segmentation(*args, **kwargs)
def post_process_panoptic_segmentation(self, *args, **kwargs):
"""
This method forwards all its arguments to [`OneFormerImageProcessor.post_process_panoptic_segmentation`].
Please refer to the docstring of this method for more information.
"""
return self.image_processor.post_process_panoptic_segmentation(*args, **kwargs)
......@@ -4242,6 +4242,30 @@ class NystromformerPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
ONEFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class OneFormerForUniversalSegmentation(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class OneFormerModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class OneFormerPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
......@@ -318,6 +318,13 @@ class MobileViTImageProcessor(metaclass=DummyObject):
requires_backends(self, ["vision"])
class OneFormerImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class OwlViTFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
......
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import unittest
import numpy as np
from huggingface_hub import hf_hub_download
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from transformers import OneFormerImageProcessor
from transformers.models.oneformer.image_processing_oneformer import binary_mask_to_rle
from transformers.models.oneformer.modeling_oneformer import OneFormerForUniversalSegmentationOutput
if is_vision_available():
from PIL import Image
def prepare_metadata(class_info_file, repo_path="shi-labs/oneformer_demo"):
with open(hf_hub_download(repo_path, class_info_file, repo_type="dataset"), "r") as f:
class_info = json.load(f)
metadata = {}
class_names = []
thing_ids = []
for key, info in class_info.items():
metadata[key] = info["name"]
class_names.append(info["name"])
if info["isthing"]:
thing_ids.append(int(key))
metadata["thing_ids"] = thing_ids
metadata["class_names"] = class_names
return metadata
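# For illustration (hypothetical entry, not from the original file): a class_info item such as
#   "0": {"name": "wall", "isthing": 0}
# becomes metadata["0"] = "wall", appends "wall" to class_names, and is skipped for thing_ids
# because it is a "stuff" class.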
class OneFormerImageProcessorTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
min_resolution=30,
max_resolution=400,
size=None,
do_resize=True,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
num_labels=10,
reduce_labels=False,
ignore_index=255,
repo_path="shi-labs/oneformer_demo",
class_info_file="ade20k_panoptic.json",
num_text=10,
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = {"shortest_edge": 32, "longest_edge": 1333} if size is None else size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.class_info_file = class_info_file
self.metadata = prepare_metadata(class_info_file, repo_path)
self.num_text = num_text
self.repo_path = repo_path
# for the post_process_functions
self.batch_size = 2
self.num_queries = 10
self.num_classes = 10
self.height = 3
self.width = 4
self.num_labels = num_labels
self.reduce_labels = reduce_labels
self.ignore_index = ignore_index
def prepare_feat_extract_dict(self):
return {
"do_resize": self.do_resize,
"size": self.size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"num_labels": self.num_labels,
"reduce_labels": self.reduce_labels,
"ignore_index": self.ignore_index,
"class_info_file": self.class_info_file,
"metadata": self.metadata,
"num_text": self.num_text,
}
def get_expected_values(self, image_inputs, batched=False):
"""
This function computes the expected height and width when providing images to OneFormerImageProcessor,
assuming do_resize is set to True with a scalar size.
"""
if not batched:
image = image_inputs[0]
if isinstance(image, Image.Image):
w, h = image.size
else:
h, w = image.shape[1], image.shape[2]
if w < h:
expected_height = int(self.size["shortest_edge"] * h / w)
expected_width = self.size["shortest_edge"]
elif w > h:
expected_height = self.size["shortest_edge"]
expected_width = int(self.size["shortest_edge"] * w / h)
else:
expected_height = self.size["shortest_edge"]
expected_width = self.size["shortest_edge"]
else:
expected_values = []
for image in image_inputs:
expected_height, expected_width = self.get_expected_values([image])
expected_values.append((expected_height, expected_width))
expected_height = max(expected_values, key=lambda item: item[0])[0]
expected_width = max(expected_values, key=lambda item: item[1])[1]
return expected_height, expected_width
def get_fake_oneformer_outputs(self):
return OneFormerForUniversalSegmentationOutput(
# +1 for null class
class_queries_logits=torch.randn((self.batch_size, self.num_queries, self.num_classes + 1)),
masks_queries_logits=torch.randn((self.batch_size, self.num_queries, self.height, self.width)),
)
@require_torch
@require_vision
class OneFormerImageProcessingTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
image_processing_class = OneFormerImageProcessor if (is_vision_available() and is_torch_available()) else None
    # only for test_feature_extraction_common.test_feat_extract_to_json_string
feature_extraction_class = image_processing_class
def setUp(self):
self.image_processing_tester = OneFormerImageProcessorTester(self)
@property
def feat_extract_dict(self):
return self.image_processing_tester.prepare_feat_extract_dict()
def test_feat_extract_properties(self):
image_processor = self.image_processing_class(**self.feat_extract_dict)
self.assertTrue(hasattr(image_processor, "image_mean"))
self.assertTrue(hasattr(image_processor, "image_std"))
self.assertTrue(hasattr(image_processor, "do_normalize"))
self.assertTrue(hasattr(image_processor, "do_resize"))
self.assertTrue(hasattr(image_processor, "size"))
self.assertTrue(hasattr(image_processor, "ignore_index"))
self.assertTrue(hasattr(image_processor, "class_info_file"))
self.assertTrue(hasattr(image_processor, "num_text"))
self.assertTrue(hasattr(image_processor, "repo_path"))
self.assertTrue(hasattr(image_processor, "metadata"))
self.assertTrue(hasattr(image_processor, "reduce_labels"))
def test_batch_feature(self):
pass
def test_call_pil(self):
# Initialize image_processor
image_processor = self.image_processing_class(**self.feat_extract_dict)
# create random PIL images
image_inputs = prepare_image_inputs(self.image_processing_tester, equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = image_processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values
expected_height, expected_width = self.image_processing_tester.get_expected_values(image_inputs)
self.assertEqual(
encoded_images.shape,
(1, self.image_processing_tester.num_channels, expected_height, expected_width),
)
# Test batched
expected_height, expected_width = self.image_processing_tester.get_expected_values(image_inputs, batched=True)
encoded_images = image_processor(
image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt"
).pixel_values
self.assertEqual(
encoded_images.shape,
(
self.image_processing_tester.batch_size,
self.image_processing_tester.num_channels,
expected_height,
expected_width,
),
)
def test_call_numpy(self):
# Initialize image_processor
image_processor = self.image_processing_class(**self.feat_extract_dict)
# create random numpy tensors
image_inputs = prepare_image_inputs(self.image_processing_tester, equal_resolution=False, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = image_processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values
expected_height, expected_width = self.image_processing_tester.get_expected_values(image_inputs)
self.assertEqual(
encoded_images.shape,
(1, self.image_processing_tester.num_channels, expected_height, expected_width),
)
# Test batched
expected_height, expected_width = self.image_processing_tester.get_expected_values(image_inputs, batched=True)
encoded_images = image_processor(
image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt"
).pixel_values
self.assertEqual(
encoded_images.shape,
(
self.image_processing_tester.batch_size,
self.image_processing_tester.num_channels,
expected_height,
expected_width,
),
)
def test_call_pytorch(self):
# Initialize image_processor
image_processor = self.image_processing_class(**self.feat_extract_dict)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.image_processing_tester, equal_resolution=False, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = image_processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values
expected_height, expected_width = self.image_processing_tester.get_expected_values(image_inputs)
self.assertEqual(
encoded_images.shape,
(1, self.image_processing_tester.num_channels, expected_height, expected_width),
)
# Test batched
expected_height, expected_width = self.image_processing_tester.get_expected_values(image_inputs, batched=True)
encoded_images = image_processor(
image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt"
).pixel_values
self.assertEqual(
encoded_images.shape,
(
self.image_processing_tester.batch_size,
self.image_processing_tester.num_channels,
expected_height,
expected_width,
),
)
def test_equivalence_pad_and_create_pixel_mask(self):
# Initialize image_processors
image_processor_1 = self.image_processing_class(**self.feat_extract_dict)
image_processor_2 = self.image_processing_class(
do_resize=False,
do_normalize=False,
do_rescale=False,
num_labels=self.image_processing_tester.num_classes,
class_info_file="ade20k_panoptic.json",
num_text=self.image_processing_tester.num_text,
repo_path="shi-labs/oneformer_demo",
)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.image_processing_tester, equal_resolution=False, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
        # Test whether encoding via `encode_inputs` and calling the image processor directly return the same padded tensors
encoded_images_with_method = image_processor_1.encode_inputs(
image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt"
)
encoded_images = image_processor_2(image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt")
self.assertTrue(
torch.allclose(encoded_images_with_method["pixel_values"], encoded_images["pixel_values"], atol=1e-4)
)
self.assertTrue(
torch.allclose(encoded_images_with_method["pixel_mask"], encoded_images["pixel_mask"], atol=1e-4)
)
def comm_get_image_processor_inputs(
self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np"
):
image_processor = self.image_processing_class(**self.feat_extract_dict)
# prepare image and target
num_labels = self.image_processing_tester.num_labels
annotations = None
instance_id_to_semantic_id = None
image_inputs = prepare_image_inputs(self.image_processing_tester, equal_resolution=False)
if with_segmentation_maps:
high = num_labels
if is_instance_map:
labels_expanded = list(range(num_labels)) * 2
instance_id_to_semantic_id = {
instance_id: label_id for instance_id, label_id in enumerate(labels_expanded)
}
annotations = [
np.random.randint(0, high * 2, (img.size[1], img.size[0])).astype(np.uint8) for img in image_inputs
]
if segmentation_type == "pil":
annotations = [Image.fromarray(annotation) for annotation in annotations]
inputs = image_processor(
image_inputs,
["semantic"] * len(image_inputs),
annotations,
return_tensors="pt",
instance_id_to_semantic_id=instance_id_to_semantic_id,
pad_and_return_pixel_mask=True,
)
return inputs
def test_init_without_params(self):
pass
def test_call_with_segmentation_maps(self):
def common(is_instance_map=False, segmentation_type=None):
inputs = self.comm_get_image_processor_inputs(
with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type
)
mask_labels = inputs["mask_labels"]
class_labels = inputs["class_labels"]
pixel_values = inputs["pixel_values"]
text_inputs = inputs["text_inputs"]
# check the batch_size
for mask_label, class_label, text_input in zip(mask_labels, class_labels, text_inputs):
self.assertEqual(mask_label.shape[0], class_label.shape[0])
                # this ensures padding has happened
self.assertEqual(mask_label.shape[1:], pixel_values.shape[2:])
self.assertEqual(len(text_input), self.image_processing_tester.num_text)
common()
common(is_instance_map=True)
common(is_instance_map=False, segmentation_type="pil")
common(is_instance_map=True, segmentation_type="pil")
def test_binary_mask_to_rle(self):
fake_binary_mask = np.zeros((20, 50))
fake_binary_mask[0, 20:] = 1
fake_binary_mask[1, :15] = 1
fake_binary_mask[5, :10] = 1
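        # the asserts below are consistent with binary_mask_to_rle returning alternating
        # (1-indexed start, run length) values for the runs of 1s in the row-major flattened mask:
        # the 30 ones ending row 0 and the 15 ones starting row 1 merge into one run
        # (start 21, length 45), and the 10 ones in row 5 form a second run, hence 4 values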
rle = binary_mask_to_rle(fake_binary_mask)
self.assertEqual(len(rle), 4)
self.assertEqual(rle[0], 21)
self.assertEqual(rle[1], 45)
def test_post_process_semantic_segmentation(self):
        image_processor = self.image_processing_class(
num_labels=self.image_processing_tester.num_classes,
max_seq_length=77,
task_seq_length=77,
class_info_file="ade20k_panoptic.json",
num_text=self.image_processing_tester.num_text,
repo_path="shi-labs/oneformer_demo",
)
outputs = self.image_processing_tester.get_fake_oneformer_outputs()
        segmentation = image_processor.post_process_semantic_segmentation(outputs)
self.assertEqual(len(segmentation), self.image_processing_tester.batch_size)
self.assertEqual(
segmentation[0].shape,
(
self.image_processing_tester.height,
self.image_processing_tester.width,
),
)
target_sizes = [(1, 4) for i in range(self.image_processing_tester.batch_size)]
        segmentation = image_processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)
self.assertEqual(segmentation[0].shape, target_sizes[0])
def test_post_process_instance_segmentation(self):
image_processor = self.image_processing_class(
num_labels=self.image_processing_tester.num_classes,
max_seq_length=77,
task_seq_length=77,
class_info_file="ade20k_panoptic.json",
num_text=self.image_processing_tester.num_text,
repo_path="shi-labs/oneformer_demo",
)
outputs = self.image_processing_tester.get_fake_oneformer_outputs()
segmentation = image_processor.post_process_instance_segmentation(outputs, threshold=0)
self.assertTrue(len(segmentation) == self.image_processing_tester.batch_size)
for el in segmentation:
self.assertTrue("segmentation" in el)
self.assertTrue("segments_info" in el)
self.assertEqual(type(el["segments_info"]), list)
self.assertEqual(
el["segmentation"].shape, (self.image_processing_tester.height, self.image_processing_tester.width)
)
def test_post_process_panoptic_segmentation(self):
image_processor = self.image_processing_class(
num_labels=self.image_processing_tester.num_classes,
max_seq_length=77,
task_seq_length=77,
class_info_file="ade20k_panoptic.json",
num_text=self.image_processing_tester.num_text,
repo_path="shi-labs/oneformer_demo",
)
outputs = self.image_processing_tester.get_fake_oneformer_outputs()
segmentation = image_processor.post_process_panoptic_segmentation(outputs, threshold=0)
self.assertTrue(len(segmentation) == self.image_processing_tester.batch_size)
for el in segmentation:
self.assertTrue("segmentation" in el)
self.assertTrue("segments_info" in el)
self.assertEqual(type(el["segments_info"]), list)
self.assertEqual(
el["segmentation"].shape, (self.image_processing_tester.height, self.image_processing_tester.width)
)
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch OneFormer model. """
import copy
import inspect
import unittest
import numpy as np
from tests.test_modeling_common import floats_tensor
from transformers import OneFormerConfig, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
from transformers.utils import cached_property
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin
if is_torch_available():
import torch
from transformers import OneFormerForUniversalSegmentation, OneFormerModel
if is_vision_available():
from transformers import OneFormerProcessor
if is_vision_available():
from PIL import Image
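# _config_zero_init shrinks every init-scale attribute (*_range, *_std, initializer_factor, layer_scale)
# to ~0 so that test_initialization can check freshly initialized parameters land on 0.0 or 1.0.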
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
for key in configs_no_init.__dict__.keys():
if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
setattr(configs_no_init, key, 1e-10)
return configs_no_init
class OneFormerModelTester:
def __init__(
self,
parent,
batch_size=2,
is_training=True,
use_auxiliary_loss=False,
num_queries=10,
num_channels=3,
min_size=32 * 8,
max_size=32 * 8,
num_labels=4,
hidden_dim=64,
sequence_length=77,
n_ctx=4,
):
self.parent = parent
self.batch_size = batch_size
self.is_training = is_training
self.use_auxiliary_loss = use_auxiliary_loss
self.num_queries = num_queries
self.num_channels = num_channels
self.min_size = min_size
self.max_size = max_size
self.num_labels = num_labels
self.hidden_dim = hidden_dim
self.sequence_length = sequence_length
self.n_ctx = n_ctx
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size]).to(
torch_device
)
task_inputs = torch.randint(high=49408, size=(self.batch_size, self.sequence_length)).to(torch_device).long()
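        # token ids are drawn from a 49408-entry vocab, the size of the CLIP tokenizer used by OneFormerProcessor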
pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device)
text_inputs = (
torch.randint(high=49408, size=(self.batch_size, self.num_queries - self.n_ctx, self.sequence_length))
.to(torch_device)
.long()
)
mask_labels = (
torch.rand([self.batch_size, self.num_labels, self.min_size, self.max_size], device=torch_device) > 0.5
).float()
class_labels = (torch.rand((self.batch_size, self.num_labels), device=torch_device) > 0.5).long()
config = self.get_config()
return config, pixel_values, task_inputs, text_inputs, pixel_mask, mask_labels, class_labels
def get_config(self):
config = OneFormerConfig(
hidden_size=self.hidden_dim,
)
config.num_queries = self.num_queries
config.num_labels = self.num_labels
config.backbone_config.depths = [1, 1, 1, 1]
config.backbone_config.num_channels = self.num_channels
config.encoder_feedforward_dim = 64
config.dim_feedforward = 128
config.hidden_dim = self.hidden_dim
config.mask_dim = self.hidden_dim
config.conv_dim = self.hidden_dim
config.text_encoder_width = self.hidden_dim
config.task_seq_len = self.sequence_length
config.max_seq_len = self.sequence_length
config.text_encoder_context_length = self.sequence_length
config.text_encoder_n_ctx = self.n_ctx
return config
def prepare_config_and_inputs_for_common(self):
        config, pixel_values, task_inputs, _, pixel_mask, _, _ = self.prepare_config_and_inputs()
inputs_dict = {"pixel_values": pixel_values, "pixel_mask": pixel_mask, "task_inputs": task_inputs}
return config, inputs_dict
def check_output_hidden_state(self, output, config):
encoder_hidden_states = output.encoder_hidden_states
pixel_decoder_hidden_states = output.pixel_decoder_hidden_states
transformer_decoder_hidden_states = output.transformer_decoder_hidden_states
self.parent.assertTrue(len(encoder_hidden_states), len(config.backbone_config.depths))
self.parent.assertTrue(len(pixel_decoder_hidden_states), config.encoder_layers)
self.parent.assertTrue(len(transformer_decoder_hidden_states), config.decoder_layers - 1)
def create_and_check_oneformer_model(
self, config, pixel_values, task_inputs, pixel_mask, output_hidden_states=False
):
with torch.no_grad():
model = OneFormerModel(config=config)
model.to(torch_device)
model.eval()
output = model(pixel_values=pixel_values, task_inputs=task_inputs, pixel_mask=pixel_mask)
output = model(pixel_values, task_inputs=task_inputs, output_hidden_states=True)
        # the correct shape of output.transformer_decoder_object_queries ensures the correctness of the
        # encoder and pixel decoder
self.parent.assertEqual(
output.transformer_decoder_object_queries.shape,
(self.batch_size, self.num_queries, self.hidden_dim),
)
        # let's ensure the other two hidden states exist
self.parent.assertTrue(output.pixel_decoder_hidden_states is not None)
self.parent.assertTrue(output.encoder_hidden_states is not None)
if output_hidden_states:
self.check_output_hidden_state(output, config)
def create_and_check_oneformer_universal_segmentation_head_model(
self, config, pixel_values, task_inputs, text_inputs, pixel_mask, mask_labels, class_labels
):
model = OneFormerForUniversalSegmentation(config=config)
model.to(torch_device)
model.eval()
def comm_check_on_output(result):
# let's still check that all the required stuff is there
self.parent.assertTrue(result.transformer_decoder_hidden_states is not None)
self.parent.assertTrue(result.pixel_decoder_hidden_states is not None)
self.parent.assertTrue(result.encoder_hidden_states is not None)
# okay, now we need to check the logits shape
# due to the encoder compression, masks have a //4 spatial size
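            # with min_size = max_size = 32 * 8 = 256, this is a (batch_size, num_queries, 64, 64) tensor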
self.parent.assertEqual(
result.masks_queries_logits.shape,
(self.batch_size, self.num_queries, self.min_size // 4, self.max_size // 4),
)
# + 1 for null class
self.parent.assertEqual(
result.class_queries_logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1)
)
with torch.no_grad():
result = model(pixel_values=pixel_values, task_inputs=task_inputs, pixel_mask=pixel_mask)
result = model(pixel_values, task_inputs)
comm_check_on_output(result)
config.is_training = True
model = OneFormerForUniversalSegmentation(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(
pixel_values=pixel_values,
task_inputs=task_inputs,
pixel_mask=pixel_mask,
mask_labels=mask_labels,
class_labels=class_labels,
text_inputs=text_inputs,
)
comm_check_on_output(result)
self.parent.assertTrue(result.loss is not None)
self.parent.assertEqual(result.loss.shape, torch.Size([1]))
@require_torch
class OneFormerModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (OneFormerModel, OneFormerForUniversalSegmentation) if is_torch_available() else ()
is_encoder_decoder = False
test_pruning = False
test_head_masking = False
test_missing_keys = False
def setUp(self):
self.model_tester = OneFormerModelTester(self)
self.config_tester = ConfigTester(self, config_class=OneFormerConfig, has_text_modality=False)
def test_config(self):
self.config_tester.run_common_tests()
def test_oneformer_model(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.create_and_check_oneformer_model(config, **inputs, output_hidden_states=False)
def test_oneformer_universal_segmentation_head_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_oneformer_universal_segmentation_head_model(*config_and_inputs)
def test_model_main_input_name(self):
for model_class in self.all_model_classes:
model_signature = inspect.signature(getattr(model_class, "forward"))
            # OneFormer has two main inputs, so the main input names are the first two arguments after `self`
observed_main_input_name = list(model_signature.parameters.keys())[1:3]
self.assertEqual(model_class.main_input_name, observed_main_input_name)
@unittest.skip(reason="OneFormer uses two main inputs")
def test_torchscript_simple(self):
pass
@unittest.skip(reason="OneFormer uses two main inputs")
def test_torchscript_output_attentions(self):
pass
@unittest.skip(reason="OneFormer uses two main inputs")
def test_torchscript_output_hidden_state(self):
pass
@unittest.skip(reason="OneFormer does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="OneFormer does not have a get_input_embeddings method")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="OneFormer is not a generative model")
def test_generate_without_input_ids(self):
pass
@unittest.skip(reason="OneFormer does not use token embeddings")
def test_resize_tokens_embeddings(self):
pass
@require_torch_multi_gpu
@unittest.skip(
reason="OneFormer has some layers using `add_module` which doesn't work well with `nn.DataParallel`"
)
def test_multi_gpu_data_parallel_forward(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values", "task_inputs"]
self.assertListEqual(arg_names[:2], expected_arg_names)
@slow
def test_model_from_pretrained(self):
for model_name in ["shi-labs/oneformer_ade20k_swin_tiny"]:
model = OneFormerModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_model_with_labels(self):
size = (self.model_tester.min_size,) * 2
inputs = {
"pixel_values": torch.randn((2, 3, *size), device=torch_device),
"task_inputs": torch.randint(high=49408, size=(2, 77), device=torch_device).long(),
"text_inputs": torch.randint(high=49408, size=(2, 134, 77), device=torch_device).long(),
"mask_labels": torch.randn((2, 150, *size), device=torch_device),
"class_labels": torch.zeros(2, 150, device=torch_device).long(),
}
config = OneFormerConfig()
config.is_training = True
model = OneFormerForUniversalSegmentation(config).to(torch_device)
outputs = model(**inputs)
self.assertTrue(outputs.loss is not None)
def test_hidden_states_output(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
self.model_tester.create_and_check_oneformer_model(config, **inputs, output_hidden_states=True)
def test_attention_outputs(self):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
outputs = model(**inputs, output_attentions=True)
self.assertTrue(outputs.attentions is not None)
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.contrastive_temperature = 1
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if param.requires_grad:
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
def test_training(self):
if not self.model_tester.is_training:
return
# only OneFormerForUniversalSegmentation has the loss
model_class = self.all_model_classes[1]
(
config,
pixel_values,
task_inputs,
text_inputs,
pixel_mask,
mask_labels,
class_labels,
) = self.model_tester.prepare_config_and_inputs()
config.is_training = True
model = model_class(config)
model.to(torch_device)
model.train()
loss = model(
pixel_values, task_inputs, text_inputs=text_inputs, mask_labels=mask_labels, class_labels=class_labels
).loss
loss.backward()
def test_retain_grad_hidden_states_attentions(self):
# only OneFormerForUniversalSegmentation has the loss
model_class = self.all_model_classes[1]
(
config,
pixel_values,
task_inputs,
text_inputs,
pixel_mask,
mask_labels,
class_labels,
) = self.model_tester.prepare_config_and_inputs()
config.output_hidden_states = True
config.output_attentions = True
config.is_training = True
model = model_class(config)
model.to(torch_device)
model.train()
outputs = model(
pixel_values, task_inputs, text_inputs=text_inputs, mask_labels=mask_labels, class_labels=class_labels
)
encoder_hidden_states = outputs.encoder_hidden_states[0]
encoder_hidden_states.retain_grad()
pixel_decoder_hidden_states = outputs.pixel_decoder_hidden_states[0]
pixel_decoder_hidden_states.retain_grad()
transformer_decoder_class_predictions = outputs.transformer_decoder_class_predictions
transformer_decoder_class_predictions.retain_grad()
transformer_decoder_mask_predictions = outputs.transformer_decoder_mask_predictions
transformer_decoder_mask_predictions.retain_grad()
attentions = outputs.attentions[0][0]
attentions.retain_grad()
outputs.loss.backward(retain_graph=True)
self.assertIsNotNone(encoder_hidden_states.grad)
self.assertIsNotNone(pixel_decoder_hidden_states.grad)
self.assertIsNotNone(transformer_decoder_class_predictions.grad)
self.assertIsNotNone(transformer_decoder_mask_predictions.grad)
self.assertIsNotNone(attentions.grad)
TOLERANCE = 1e-4
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_vision
@slow
class OneFormerModelIntegrationTest(unittest.TestCase):
@cached_property
def model_checkpoints(self):
return "shi-labs/oneformer_ade20k_swin_tiny"
@cached_property
def default_processor(self):
return OneFormerProcessor.from_pretrained(self.model_checkpoints) if is_vision_available() else None
def test_inference_no_head(self):
model = OneFormerModel.from_pretrained(self.model_checkpoints).to(torch_device)
processor = self.default_processor
image = prepare_img()
inputs = processor(image, ["semantic"], return_tensors="pt").to(torch_device)
inputs_shape = inputs["pixel_values"].shape
# check size
self.assertEqual(inputs_shape, (1, 3, 512, 682))
task_inputs_shape = inputs["task_inputs"].shape
# check size
self.assertEqual(task_inputs_shape, (1, 77))
with torch.no_grad():
outputs = model(**inputs)
expected_slice_hidden_state = torch.tensor(
[[0.2724, 0.8287, 0.6025], [1.2706, 1.1252, 1.1445], [1.1357, 0.6150, 0.4185]]
).to(torch_device)
self.assertTrue(
torch.allclose(
outputs.encoder_hidden_states[-1][0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE
)
)
expected_slice_hidden_state = torch.tensor(
[[1.0581, 1.2275, 1.2000], [1.1901, 1.2925, 1.2861], [1.1578, 1.2558, 1.3212]]
).to(torch_device)
self.assertTrue(
torch.allclose(
outputs.pixel_decoder_hidden_states[0][0, 0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE
)
)
expected_slice_hidden_state = torch.tensor(
[[3.0711, -1.1855, -5.1095], [3.5536, -3.2710, -5.2052], [2.6020, -4.3605, -4.1422]]
).to(torch_device)
self.assertTrue(
torch.allclose(
outputs.transformer_decoder_class_predictions[0, :3, :3], expected_slice_hidden_state, atol=TOLERANCE
)
)
def test_inference_universal_segmentation_head(self):
model = OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
processor = self.default_processor
image = prepare_img()
inputs = processor(image, ["semantic"], return_tensors="pt").to(torch_device)
inputs_shape = inputs["pixel_values"].shape
# check size
self.assertEqual(inputs_shape, (1, 3, 512, 682))
with torch.no_grad():
outputs = model(**inputs)
# masks_queries_logits
masks_queries_logits = outputs.masks_queries_logits
self.assertEqual(
masks_queries_logits.shape,
(1, model.config.num_queries, inputs_shape[-2] // 4, (inputs_shape[-1] + 2) // 4),
)
expected_slice = [[[3.1215, 4.1250, 4.1106], [2.8183, 3.4623, 3.5512], [2.4550, 2.9841, 3.5081]]]
expected_slice = torch.tensor(expected_slice).to(torch_device)
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
# class_queries_logits
class_queries_logits = outputs.class_queries_logits
self.assertEqual(
class_queries_logits.shape,
(1, model.config.num_queries, model.config.num_labels + 1),
)
expected_slice = torch.tensor(
[[3.0711, -1.1855, -5.1095], [3.5536, -3.2710, -5.2052], [2.6020, -4.3605, -4.1422]]
).to(torch_device)
self.assertTrue(torch.allclose(class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
def test_with_segmentation_maps_and_loss(self):
dummy_model = OneFormerForUniversalSegmentation.from_pretrained(self.model_checkpoints)
processor = self.default_processor
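        # presumably so the processor emits num_queries - text_encoder_n_ctx text entries,
        # mirroring the text_inputs sizing used in OneFormerModelTester above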
processor.image_processor.num_text = dummy_model.config.num_queries - dummy_model.config.text_encoder_n_ctx
dummy_model.config.is_training = True
model = OneFormerForUniversalSegmentation(dummy_model.config).to(torch_device).eval()
del dummy_model
inputs = processor(
[np.zeros((3, 512, 640)), np.zeros((3, 512, 640))],
["semantic", "semantic"],
segmentation_maps=[np.zeros((384, 384)).astype(np.float32), np.zeros((384, 384)).astype(np.float32)],
return_tensors="pt",
)
inputs["pixel_values"] = inputs["pixel_values"].to(torch_device)
inputs["task_inputs"] = inputs["task_inputs"].to(torch_device)
inputs["text_inputs"] = inputs["text_inputs"].to(torch_device)
inputs["mask_labels"] = [el.to(torch_device) for el in inputs["mask_labels"]]
inputs["class_labels"] = [el.to(torch_device) for el in inputs["class_labels"]]
with torch.no_grad():
outputs = model(**inputs)
self.assertTrue(outputs.loss is not None)
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tempfile
import unittest
import numpy as np
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from ...test_feature_extraction_common import prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from transformers import CLIPTokenizer, OneFormerImageProcessor, OneFormerProcessor
from transformers.models.oneformer.image_processing_oneformer import binary_mask_to_rle
from transformers.models.oneformer.modeling_oneformer import OneFormerForUniversalSegmentationOutput
if is_vision_available():
from PIL import Image
def prepare_metadata(class_info_file, repo_path="shi-labs/oneformer_demo"):
with open(hf_hub_download(repo_path, class_info_file, repo_type="dataset"), "r") as f:
class_info = json.load(f)
metadata = {}
class_names = []
thing_ids = []
for key, info in class_info.items():
metadata[key] = info["name"]
class_names.append(info["name"])
if info["isthing"]:
thing_ids.append(int(key))
metadata["thing_ids"] = thing_ids
metadata["class_names"] = class_names
return metadata
class OneFormerProcessorTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
min_resolution=30,
max_resolution=400,
size=None,
do_resize=True,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
num_labels=10,
reduce_labels=False,
ignore_index=255,
max_seq_length=77,
task_seq_length=77,
model_repo="shi-labs/oneformer_ade20k_swin_tiny",
class_info_file="ade20k_panoptic.json",
num_text=10,
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = {"shortest_edge": 32, "longest_edge": 1333} if size is None else size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.max_seq_length = max_seq_length
self.task_seq_length = task_seq_length
self.class_info_file = class_info_file
self.metadata = prepare_metadata(class_info_file)
self.num_text = num_text
self.model_repo = model_repo
        # smaller fixed sizes used to build the fake outputs for the post_process_* tests
self.batch_size = 2
self.num_queries = 10
self.num_classes = 10
self.height = 3
self.width = 4
self.num_labels = num_labels
self.reduce_labels = reduce_labels
self.ignore_index = ignore_index
def prepare_processor_dict(self):
image_processor_dict = {
"do_resize": self.do_resize,
"size": self.size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"num_labels": self.num_labels,
"reduce_labels": self.reduce_labels,
"ignore_index": self.ignore_index,
"class_info_file": self.class_info_file,
"metadata": self.metadata,
"num_text": self.num_text,
}
image_processor = OneFormerImageProcessor(**image_processor_dict)
tokenizer = CLIPTokenizer.from_pretrained(self.model_repo)
return {
"image_processor": image_processor,
"tokenizer": tokenizer,
"max_seq_length": self.max_seq_length,
"task_seq_length": self.task_seq_length,
}
def get_expected_values(self, image_inputs, batched=False):
"""
This function computes the expected height and width when providing images to OneFormerProcessor,
assuming do_resize is set to True with a scalar size. It also provides the expected sequence length
for the task_inputs and text_list_input.
"""
if not batched:
image = image_inputs[0]
if isinstance(image, Image.Image):
w, h = image.size
else:
h, w = image.shape[1], image.shape[2]
if w < h:
expected_height = int(self.size["shortest_edge"] * h / w)
expected_width = self.size["shortest_edge"]
elif w > h:
expected_height = self.size["shortest_edge"]
expected_width = int(self.size["shortest_edge"] * w / h)
else:
expected_height = self.size["shortest_edge"]
                expected_width = self.size["shortest_edge"]
            expected_sequence_length = self.max_seq_length
else:
expected_values = []
for image in image_inputs:
expected_height, expected_width, expected_sequence_length = self.get_expected_values([image])
expected_values.append((expected_height, expected_width, expected_sequence_length))
expected_height = max(expected_values, key=lambda item: item[0])[0]
expected_width = max(expected_values, key=lambda item: item[1])[1]
expected_sequence_length = self.max_seq_length
return expected_height, expected_width, expected_sequence_length
def get_fake_oneformer_outputs(self):
return OneFormerForUniversalSegmentationOutput(
# +1 for null class
class_queries_logits=torch.randn((self.batch_size, self.num_queries, self.num_classes + 1)),
masks_queries_logits=torch.randn((self.batch_size, self.num_queries, self.height, self.width)),
)
@require_torch
@require_vision
class OneFormerProcessingTest(unittest.TestCase):
processing_class = OneFormerProcessor if (is_vision_available() and is_torch_available()) else None
    # only for test_feature_extraction_common.test_feat_extract_to_json_string
feature_extraction_class = processing_class
def setUp(self):
self.processing_tester = OneFormerProcessorTester(self)
@property
def processor_dict(self):
return self.processing_tester.prepare_processor_dict()
def test_feat_extract_properties(self):
processor = self.processing_class(**self.processor_dict)
self.assertTrue(hasattr(processor, "image_processor"))
self.assertTrue(hasattr(processor, "tokenizer"))
self.assertTrue(hasattr(processor, "max_seq_length"))
self.assertTrue(hasattr(processor, "task_seq_length"))
def test_batch_feature(self):
pass
def test_call_pil(self):
# Initialize processor
processor = self.processing_class(**self.processor_dict)
# create random PIL images
image_inputs = prepare_image_inputs(self.processing_tester, equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values
expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values(
image_inputs
)
self.assertEqual(
encoded_images.shape,
(1, self.processing_tester.num_channels, expected_height, expected_width),
)
tokenized_task_inputs = processor(image_inputs[0], ["semantic"], return_tensors="pt").task_inputs
self.assertEqual(
tokenized_task_inputs.shape,
(1, expected_sequence_length),
)
# Test batched
expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values(
image_inputs, batched=True
)
encoded_images = processor(image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
self.processing_tester.batch_size,
self.processing_tester.num_channels,
expected_height,
expected_width,
),
)
tokenized_task_inputs = processor(
image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt"
).task_inputs
self.assertEqual(
tokenized_task_inputs.shape,
(self.processing_tester.batch_size, expected_sequence_length),
)
def test_call_numpy(self):
# Initialize processor
processor = self.processing_class(**self.processor_dict)
# create random numpy tensors
image_inputs = prepare_image_inputs(self.processing_tester, equal_resolution=False, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values
expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values(
image_inputs
)
self.assertEqual(
encoded_images.shape,
(1, self.processing_tester.num_channels, expected_height, expected_width),
)
tokenized_task_inputs = processor(image_inputs[0], ["semantic"], return_tensors="pt").task_inputs
self.assertEqual(
tokenized_task_inputs.shape,
(1, expected_sequence_length),
)
# Test batched
expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values(
image_inputs, batched=True
)
encoded_images = processor(image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
self.processing_tester.batch_size,
self.processing_tester.num_channels,
expected_height,
expected_width,
),
)
tokenized_task_inputs = processor(
image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt"
).task_inputs
self.assertEqual(
tokenized_task_inputs.shape,
(self.processing_tester.batch_size, expected_sequence_length),
)
def test_call_pytorch(self):
# Initialize processor
processor = self.processing_class(**self.processor_dict)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.processing_tester, equal_resolution=False, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = processor(image_inputs[0], ["semantic"], return_tensors="pt").pixel_values
expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values(
image_inputs
)
self.assertEqual(
encoded_images.shape,
(1, self.processing_tester.num_channels, expected_height, expected_width),
)
tokenized_task_inputs = processor(image_inputs[0], ["semantic"], return_tensors="pt").task_inputs
self.assertEqual(
tokenized_task_inputs.shape,
(1, expected_sequence_length),
)
# Test batched
expected_height, expected_width, expected_sequence_length = self.processing_tester.get_expected_values(
image_inputs, batched=True
)
encoded_images = processor(image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
self.processing_tester.batch_size,
self.processing_tester.num_channels,
expected_height,
expected_width,
),
)
tokenized_task_inputs = processor(
image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt"
).task_inputs
self.assertEqual(
tokenized_task_inputs.shape,
(self.processing_tester.batch_size, expected_sequence_length),
)
def test_equivalence_pad_and_create_pixel_mask(self):
# Initialize processors
processor_1 = self.processing_class(**self.processor_dict)
image_processor = OneFormerImageProcessor(
do_resize=False,
do_normalize=False,
do_rescale=False,
num_labels=self.processing_tester.num_classes,
class_info_file="ade20k_panoptic.json",
num_text=self.processing_tester.num_text,
)
tokenizer = CLIPTokenizer.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
processor_2 = self.processing_class(
image_processor=image_processor, tokenizer=tokenizer, max_seq_length=77, task_seq_length=77
)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.processing_tester, equal_resolution=False, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
        # Test whether encoding via `encode_inputs` and calling the processor directly return the same padded tensors
encoded_images_with_method = processor_1.encode_inputs(
image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt"
)
encoded_images = processor_2(image_inputs, ["semantic"] * len(image_inputs), return_tensors="pt")
self.assertTrue(
torch.allclose(encoded_images_with_method["pixel_values"], encoded_images["pixel_values"], atol=1e-4)
)
self.assertTrue(
torch.allclose(encoded_images_with_method["pixel_mask"], encoded_images["pixel_mask"], atol=1e-4)
)
def comm_get_processor_inputs(self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np"):
processor = self.processing_class(**self.processor_dict)
# prepare image and target
num_labels = self.processing_tester.num_labels
annotations = None
instance_id_to_semantic_id = None
image_inputs = prepare_image_inputs(self.processing_tester, equal_resolution=False)
if with_segmentation_maps:
high = num_labels
if is_instance_map:
labels_expanded = list(range(num_labels)) * 2
instance_id_to_semantic_id = {
instance_id: label_id for instance_id, label_id in enumerate(labels_expanded)
}
annotations = [
np.random.randint(0, high * 2, (img.size[1], img.size[0])).astype(np.uint8) for img in image_inputs
]
if segmentation_type == "pil":
annotations = [Image.fromarray(annotation) for annotation in annotations]
inputs = processor(
image_inputs,
["semantic"] * len(image_inputs),
annotations,
return_tensors="pt",
instance_id_to_semantic_id=instance_id_to_semantic_id,
pad_and_return_pixel_mask=True,
)
return inputs
def test_init_without_params(self):
pass
def test_feat_extract_from_and_save_pretrained(self):
feat_extract_first = self.feature_extraction_class(**self.processor_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
feat_extract_first.save_pretrained(tmpdirname)
check_json_file_has_correct_format(os.path.join(tmpdirname, "preprocessor_config.json"))
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
self.assertEqual(feat_extract_second.image_processor.to_dict(), feat_extract_first.image_processor.to_dict())
self.assertIsInstance(feat_extract_first.image_processor, OneFormerImageProcessor)
self.assertIsInstance(feat_extract_first.tokenizer, CLIPTokenizer)
def test_call_with_segmentation_maps(self):
def common(is_instance_map=False, segmentation_type=None):
inputs = self.comm_get_processor_inputs(
with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type
)
mask_labels = inputs["mask_labels"]
class_labels = inputs["class_labels"]
pixel_values = inputs["pixel_values"]
text_inputs = inputs["text_inputs"]
# check the batch_size
for mask_label, class_label, text_input in zip(mask_labels, class_labels, text_inputs):
self.assertEqual(mask_label.shape[0], class_label.shape[0])
                # this ensures padding has happened
self.assertEqual(mask_label.shape[1:], pixel_values.shape[2:])
self.assertEqual(text_input.shape[0], self.processing_tester.num_text)
common()
common(is_instance_map=True)
common(is_instance_map=False, segmentation_type="pil")
common(is_instance_map=True, segmentation_type="pil")
def test_integration_semantic_segmentation(self):
# load 2 images and corresponding panoptic annotations from the hub
dataset = load_dataset("nielsr/ade20k-panoptic-demo")
image1 = dataset["train"][0]["image"]
image2 = dataset["train"][1]["image"]
segments_info1 = dataset["train"][0]["segments_info"]
segments_info2 = dataset["train"][1]["segments_info"]
annotation1 = dataset["train"][0]["label"]
annotation2 = dataset["train"][1]["label"]
def rgb_to_id(color):
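            # the panoptic annotation encodes the segment id per pixel as an RGB color:
            # id = R + 256 * G + 256**2 * B (COCO panoptic-style encoding)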
if isinstance(color, np.ndarray) and len(color.shape) == 3:
if color.dtype == np.uint8:
color = color.astype(np.int32)
return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
def create_panoptic_map(annotation, segments_info):
annotation = np.array(annotation)
# convert RGB to segment IDs per pixel
# 0 is the "ignore" label, for which we don't need to make binary masks
panoptic_map = rgb_to_id(annotation)
# create mapping between segment IDs and semantic classes
inst2class = {segment["id"]: segment["category_id"] for segment in segments_info}
return panoptic_map, inst2class
panoptic_map1, inst2class1 = create_panoptic_map(annotation1, segments_info1)
panoptic_map2, inst2class2 = create_panoptic_map(annotation2, segments_info2)
image_processor = OneFormerImageProcessor(
reduce_labels=True,
ignore_index=0,
size=(512, 512),
class_info_file="ade20k_panoptic.json",
num_text=self.processing_tester.num_text,
)
tokenizer = CLIPTokenizer.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
processor = OneFormerProcessor(
image_processor=image_processor,
tokenizer=tokenizer,
max_seq_length=77,
task_seq_length=77,
)
# prepare the images and annotations
pixel_values_list = [np.moveaxis(np.array(image1), -1, 0), np.moveaxis(np.array(image2), -1, 0)]
inputs = processor.encode_inputs(
pixel_values_list,
["semantic", "semantic"],
[panoptic_map1, panoptic_map2],
instance_id_to_semantic_id=[inst2class1, inst2class2],
return_tensors="pt",
)
# verify the pixel values, task inputs, text inputs and pixel mask
self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 711))
self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 711))
self.assertEqual(inputs["task_inputs"].shape, (2, 77))
self.assertEqual(inputs["text_inputs"].shape, (2, self.processing_tester.num_text, 77))
# verify the class labels
self.assertEqual(len(inputs["class_labels"]), 2)
# fmt: off
expected_class_labels = torch.tensor([4, 17, 32, 42, 12, 3, 5, 0, 43, 96, 104, 31, 125, 138, 87, 149]) # noqa: E231
# fmt: on
self.assertTrue(torch.allclose(inputs["class_labels"][0], expected_class_labels))
# fmt: off
expected_class_labels = torch.tensor([19, 67, 82, 17, 12, 42, 3, 14, 5, 0, 115, 43, 8, 138, 125, 143]) # noqa: E231
# fmt: on
self.assertTrue(torch.allclose(inputs["class_labels"][1], expected_class_labels))
# verify the task inputs
self.assertEqual(len(inputs["task_inputs"]), 2)
self.assertEqual(inputs["task_inputs"][0].sum().item(), 141082)
self.assertEqual(inputs["task_inputs"][0].sum().item(), inputs["task_inputs"][1].sum().item())
# verify the text inputs
self.assertEqual(len(inputs["text_inputs"]), 2)
self.assertEqual(inputs["text_inputs"][0].sum().item(), 1095752)
self.assertEqual(inputs["text_inputs"][1].sum().item(), 1062468)
# verify the mask labels
self.assertEqual(len(inputs["mask_labels"]), 2)
self.assertEqual(inputs["mask_labels"][0].shape, (16, 512, 711))
self.assertEqual(inputs["mask_labels"][1].shape, (16, 512, 711))
self.assertEqual(inputs["mask_labels"][0].sum().item(), 315193.0)
self.assertEqual(inputs["mask_labels"][1].sum().item(), 350747.0)
def test_integration_instance_segmentation(self):
# load 2 images and corresponding panoptic annotations from the hub
dataset = load_dataset("nielsr/ade20k-panoptic-demo")
image1 = dataset["train"][0]["image"]
image2 = dataset["train"][1]["image"]
segments_info1 = dataset["train"][0]["segments_info"]
segments_info2 = dataset["train"][1]["segments_info"]
annotation1 = dataset["train"][0]["label"]
annotation2 = dataset["train"][1]["label"]
def rgb_to_id(color):
if isinstance(color, np.ndarray) and len(color.shape) == 3:
if color.dtype == np.uint8:
color = color.astype(np.int32)
return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
def create_panoptic_map(annotation, segments_info):
annotation = np.array(annotation)
# convert RGB to segment IDs per pixel
# 0 is the "ignore" label, for which we don't need to make binary masks
panoptic_map = rgb_to_id(annotation)
# create mapping between segment IDs and semantic classes
inst2class = {segment["id"]: segment["category_id"] for segment in segments_info}
return panoptic_map, inst2class
panoptic_map1, inst2class1 = create_panoptic_map(annotation1, segments_info1)
panoptic_map2, inst2class2 = create_panoptic_map(annotation2, segments_info2)
image_processor = OneFormerImageProcessor(
reduce_labels=True,
ignore_index=0,
size=(512, 512),
class_info_file="ade20k_panoptic.json",
num_text=self.processing_tester.num_text,
)
tokenizer = CLIPTokenizer.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
processor = OneFormerProcessor(
image_processor=image_processor,
tokenizer=tokenizer,
max_seq_length=77,
task_seq_length=77,
)
# prepare the images and annotations
pixel_values_list = [np.moveaxis(np.array(image1), -1, 0), np.moveaxis(np.array(image2), -1, 0)]
inputs = processor.encode_inputs(
pixel_values_list,
["instance", "instance"],
[panoptic_map1, panoptic_map2],
instance_id_to_semantic_id=[inst2class1, inst2class2],
return_tensors="pt",
)
# verify the pixel values, task inputs, text inputs and pixel mask
self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 711))
self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 711))
self.assertEqual(inputs["task_inputs"].shape, (2, 77))
self.assertEqual(inputs["text_inputs"].shape, (2, self.processing_tester.num_text, 77))
# verify the class labels
self.assertEqual(len(inputs["class_labels"]), 2)
# fmt: off
expected_class_labels = torch.tensor([32, 42, 42, 42, 42, 42, 42, 42, 32, 12, 12, 12, 12, 12, 42, 42, 12, 12, 12, 42, 12, 12, 12, 12, 12, 12, 12, 12, 12, 42, 42, 42, 12, 42, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 43, 43, 43, 43, 104, 43, 31, 125, 31, 125, 138, 87, 125, 149, 138, 125, 87, 87]) # noqa: E231
# fmt: on
self.assertTrue(torch.allclose(inputs["class_labels"][0], expected_class_labels))
# fmt: off
expected_class_labels = torch.tensor([19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 67, 82, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 12, 12, 42, 12, 12, 12, 12, 14, 12, 12, 12, 12, 12, 12, 12, 12, 14, 12, 12, 115, 43, 43, 115, 43, 43, 43, 8, 8, 8, 138, 138, 125, 143]) # noqa: E231
# fmt: on
self.assertTrue(torch.allclose(inputs["class_labels"][1], expected_class_labels))
# verify the task inputs
self.assertEqual(len(inputs["task_inputs"]), 2)
self.assertEqual(inputs["task_inputs"][0].sum().item(), 144985)
self.assertEqual(inputs["task_inputs"][0].sum().item(), inputs["task_inputs"][1].sum().item())
# verify the text inputs
self.assertEqual(len(inputs["text_inputs"]), 2)
self.assertEqual(inputs["text_inputs"][0].sum().item(), 1037040)
self.assertEqual(inputs["text_inputs"][1].sum().item(), 1044078)
# verify the mask labels
self.assertEqual(len(inputs["mask_labels"]), 2)
self.assertEqual(inputs["mask_labels"][0].shape, (73, 512, 711))
self.assertEqual(inputs["mask_labels"][1].shape, (57, 512, 711))
self.assertEqual(inputs["mask_labels"][0].sum().item(), 35040.0)
self.assertEqual(inputs["mask_labels"][1].sum().item(), 98228.0)
def test_integration_panoptic_segmentation(self):
# load 2 images and corresponding panoptic annotations from the hub
dataset = load_dataset("nielsr/ade20k-panoptic-demo")
image1 = dataset["train"][0]["image"]
image2 = dataset["train"][1]["image"]
segments_info1 = dataset["train"][0]["segments_info"]
segments_info2 = dataset["train"][1]["segments_info"]
annotation1 = dataset["train"][0]["label"]
annotation2 = dataset["train"][1]["label"]
def rgb_to_id(color):
if isinstance(color, np.ndarray) and len(color.shape) == 3:
if color.dtype == np.uint8:
color = color.astype(np.int32)
return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
def create_panoptic_map(annotation, segments_info):
annotation = np.array(annotation)
# convert RGB to segment IDs per pixel
# 0 is the "ignore" label, for which we don't need to make binary masks
panoptic_map = rgb_to_id(annotation)
# create mapping between segment IDs and semantic classes
inst2class = {segment["id"]: segment["category_id"] for segment in segments_info}
return panoptic_map, inst2class
panoptic_map1, inst2class1 = create_panoptic_map(annotation1, segments_info1)
panoptic_map2, inst2class2 = create_panoptic_map(annotation2, segments_info2)
image_processor = OneFormerImageProcessor(
reduce_labels=True,
ignore_index=0,
size=(512, 512),
class_info_file="ade20k_panoptic.json",
num_text=self.processing_tester.num_text,
)
tokenizer = CLIPTokenizer.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
processor = OneFormerProcessor(
image_processor=image_processor,
tokenizer=tokenizer,
max_seq_length=77,
task_seq_length=77,
)
# prepare the images and annotations
pixel_values_list = [np.moveaxis(np.array(image1), -1, 0), np.moveaxis(np.array(image2), -1, 0)]
inputs = processor.encode_inputs(
pixel_values_list,
["panoptic", "panoptic"],
[panoptic_map1, panoptic_map2],
instance_id_to_semantic_id=[inst2class1, inst2class2],
return_tensors="pt",
)
# verify the pixel values, task inputs, text inputs and pixel mask
self.assertEqual(inputs["pixel_values"].shape, (2, 3, 512, 711))
self.assertEqual(inputs["pixel_mask"].shape, (2, 512, 711))
self.assertEqual(inputs["task_inputs"].shape, (2, 77))
self.assertEqual(inputs["text_inputs"].shape, (2, self.processing_tester.num_text, 77))
# verify the class labels
self.assertEqual(len(inputs["class_labels"]), 2)
# fmt: off
expected_class_labels = torch.tensor([4, 17, 32, 42, 42, 42, 42, 42, 42, 42, 32, 12, 12, 12, 12, 12, 42, 42, 12, 12, 12, 42, 12, 12, 12, 12, 12, 3, 12, 12, 12, 12, 42, 42, 42, 12, 42, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 5, 12, 12, 12, 12, 12, 12, 12, 0, 43, 43, 43, 96, 43, 104, 43, 31, 125, 31, 125, 138, 87, 125, 149, 138, 125, 87, 87]) # noqa: E231
# fmt: on
self.assertTrue(torch.allclose(inputs["class_labels"][0], expected_class_labels))
# fmt: off
expected_class_labels = torch.tensor([19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 67, 82, 19, 19, 17, 19, 19, 19, 19, 19, 19, 19, 19, 19, 12, 12, 42, 12, 12, 12, 12, 3, 14, 12, 12, 12, 12, 12, 12, 12, 12, 14, 5, 12, 12, 0, 115, 43, 43, 115, 43, 43, 43, 8, 8, 8, 138, 138, 125, 143]) # noqa: E231
# fmt: on
self.assertTrue(torch.allclose(inputs["class_labels"][1], expected_class_labels))
# verify the task inputs
self.assertEqual(len(inputs["task_inputs"]), 2)
self.assertEqual(inputs["task_inputs"][0].sum().item(), 136240)
self.assertEqual(inputs["task_inputs"][0].sum().item(), inputs["task_inputs"][1].sum().item())
# verify the text inputs
self.assertEqual(len(inputs["text_inputs"]), 2)
self.assertEqual(inputs["text_inputs"][0].sum().item(), 1048653)
self.assertEqual(inputs["text_inputs"][1].sum().item(), 1067160)
# verify the mask labels
self.assertEqual(len(inputs["mask_labels"]), 2)
self.assertEqual(inputs["mask_labels"][0].shape, (79, 512, 711))
self.assertEqual(inputs["mask_labels"][1].shape, (61, 512, 711))
self.assertEqual(inputs["mask_labels"][0].sum().item(), 315193.0)
self.assertEqual(inputs["mask_labels"][1].sum().item(), 350747.0)
def test_binary_mask_to_rle(self):
fake_binary_mask = np.zeros((20, 50))
fake_binary_mask[0, 20:] = 1
fake_binary_mask[1, :15] = 1
fake_binary_mask[5, :10] = 1
rle = binary_mask_to_rle(fake_binary_mask)
self.assertEqual(len(rle), 4)
self.assertEqual(rle[0], 21)
self.assertEqual(rle[1], 45)
def test_post_process_semantic_segmentation(self):
image_processor = OneFormerImageProcessor(
reduce_labels=True,
ignore_index=0,
size=(512, 512),
class_info_file="ade20k_panoptic.json",
num_text=self.processing_tester.num_text,
)
tokenizer = CLIPTokenizer.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
processor = OneFormerProcessor(
image_processor=image_processor,
tokenizer=tokenizer,
max_seq_length=77,
task_seq_length=77,
)
outputs = self.processing_tester.get_fake_oneformer_outputs()
segmentation = processor.post_process_semantic_segmentation(outputs)
self.assertEqual(len(segmentation), self.processing_tester.batch_size)
self.assertEqual(
segmentation[0].shape,
(
self.processing_tester.height,
self.processing_tester.width,
),
)
target_sizes = [(1, 4) for i in range(self.processing_tester.batch_size)]
segmentation = processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)
self.assertEqual(segmentation[0].shape, target_sizes[0])
def test_post_process_instance_segmentation(self):
image_processor = OneFormerImageProcessor(
reduce_labels=True,
ignore_index=0,
size=(512, 512),
class_info_file="ade20k_panoptic.json",
num_text=self.processing_tester.num_text,
)
tokenizer = CLIPTokenizer.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
processor = OneFormerProcessor(
image_processor=image_processor,
tokenizer=tokenizer,
max_seq_length=77,
task_seq_length=77,
)
outputs = self.processing_tester.get_fake_oneformer_outputs()
segmentation = processor.post_process_instance_segmentation(outputs, threshold=0)
self.assertTrue(len(segmentation) == self.processing_tester.batch_size)
for el in segmentation:
self.assertTrue("segmentation" in el)
self.assertTrue("segments_info" in el)
self.assertEqual(type(el["segments_info"]), list)
self.assertEqual(el["segmentation"].shape, (self.processing_tester.height, self.processing_tester.width))
def test_post_process_panoptic_segmentation(self):
image_processor = OneFormerImageProcessor(
reduce_labels=True,
ignore_index=0,
size=(512, 512),
class_info_file="ade20k_panoptic.json",
num_text=self.processing_tester.num_text,
)
tokenizer = CLIPTokenizer.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
processor = OneFormerProcessor(
image_processor=image_processor,
tokenizer=tokenizer,
max_seq_length=77,
task_seq_length=77,
)
outputs = self.processing_tester.get_fake_oneformer_outputs()
segmentation = processor.post_process_panoptic_segmentation(outputs, threshold=0)
self.assertTrue(len(segmentation) == self.processing_tester.batch_size)
for el in segmentation:
self.assertTrue("segmentation" in el)
self.assertTrue("segments_info" in el)
self.assertEqual(type(el["segments_info"]), list)
self.assertEqual(el["segmentation"].shape, (self.processing_tester.height, self.processing_tester.width))
......@@ -124,6 +124,8 @@ src/transformers/models/mobilevit/modeling_tf_mobilevit.py
src/transformers/models/nat/configuration_nat.py
src/transformers/models/nat/modeling_nat.py
src/transformers/models/nezha/configuration_nezha.py
src/transformers/models/oneformer/configuration_oneformer.py
src/transformers/models/oneformer/modeling_oneformer.py
src/transformers/models/openai/configuration_openai.py
src/transformers/models/opt/configuration_opt.py
src/transformers/models/opt/modeling_opt.py
......