Unverified Commit e20faa6f authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Add BeitForSemanticSegmentation (#14096)



* Add first draft

* Make forward pass work

* Improve conversion script

* Add notebook that checks if it works

* Add BeitForSemanticSegmentation to the tests

* More improvements

* Make BeitForSemanticSegmentation consistent with Segformer

* Small bug fix

* Add BeitForSemanticSegmentation to docs

* Make sure model doesn't output hidden states when the user doesn't want to

* Make it possible to convert the large model

* Fix issue

* Fix conversion script for large model

* Add auxiliary_head option to semantic segmentation model

* Apply suggestions from @sgugger's review

* Apply suggestions from code review

* Fix failing test
Co-authored-by: default avatarLysandre <lysandre.debut@reseau.eseo.fr>
parent 8b325781
......@@ -98,6 +98,13 @@ BeitForImageClassification
:members: forward
BeitForSemanticSegmentation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BeitForSemanticSegmentation
:members: forward
FlaxBeitModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -638,6 +638,7 @@ if is_torch_available():
"BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"BeitForImageClassification",
"BeitForMaskedImageModeling",
"BeitForSemanticSegmentation",
"BeitModel",
"BeitPreTrainedModel",
]
......@@ -2483,6 +2484,7 @@ if TYPE_CHECKING:
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
BeitForImageClassification,
BeitForMaskedImageModeling,
BeitForSemanticSegmentation,
BeitModel,
BeitPreTrainedModel,
)
......
......@@ -33,6 +33,7 @@ if is_torch_available():
"BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"BeitForImageClassification",
"BeitForMaskedImageModeling",
"BeitForSemanticSegmentation",
"BeitModel",
"BeitPreTrainedModel",
]
......@@ -57,6 +58,7 @@ if TYPE_CHECKING:
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
BeitForImageClassification,
BeitForMaskedImageModeling,
BeitForSemanticSegmentation,
BeitModel,
BeitPreTrainedModel,
)
......
......@@ -78,6 +78,20 @@ class BeitConfig(PretrainedConfig):
use_mean_pooling (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
CLS token, before applying the classification head.
out_indices (:obj:`List[int]`, `optional`, defaults to :obj:`[3, 5, 7, 11]`):
Indices of the feature maps to use for semantic segmentation.
pool_scales (:obj:`Tuple[int]`, `optional`, defaults to :obj:`[1, 2, 3, 6]`):
Pooling scales used in Pooling Pyramid Module applied on the last feature map.
use_auxiliary_head (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to use an auxiliary head during training.
auxiliary_loss_weight (:obj:`float`, `optional`, defaults to 0.4):
Weight of the cross-entropy loss of the auxiliary head.
auxiliary_channels (:obj:`int`, `optional`, defaults to 256):
Number of channels to use in the auxiliary head.
auxiliary_num_convs (:obj:`int`, `optional`, defaults to 1):
Number of convolutional layers to use in the auxiliary head.
auxiliary_concat_input (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to concatenate the output of the auxiliary head with the input before the classification layer.
Example::
......@@ -117,6 +131,13 @@ class BeitConfig(PretrainedConfig):
layer_scale_init_value=0.1,
drop_path_rate=0.1,
use_mean_pooling=True,
out_indices=[3, 5, 7, 11],
pool_scales=[1, 2, 3, 6],
use_auxiliary_head=True,
auxiliary_loss_weight=0.4,
auxiliary_channels=256,
auxiliary_num_convs=1,
auxiliary_concat_input=False,
**kwargs
):
super().__init__(**kwargs)
......@@ -142,3 +163,12 @@ class BeitConfig(PretrainedConfig):
self.layer_scale_init_value = layer_scale_init_value
self.drop_path_rate = drop_path_rate
self.use_mean_pooling = use_mean_pooling
# decode head attributes (semantic segmentation)
self.out_indices = out_indices
self.pool_scales = pool_scales
# auxiliary head attributes (semantic segmentation)
self.use_auxiliary_head = use_auxiliary_head
self.auxiliary_loss_weight = auxiliary_loss_weight
self.auxiliary_channels = auxiliary_channels
self.auxiliary_num_convs = auxiliary_num_convs
self.auxiliary_concat_input = auxiliary_concat_input
......@@ -20,11 +20,18 @@ import json
from pathlib import Path
import torch
from datasets import load_dataset
from PIL import Image
import requests
from huggingface_hub import cached_download, hf_hub_url
from transformers import BeitConfig, BeitFeatureExtractor, BeitForImageClassification, BeitForMaskedImageModeling
from transformers import (
BeitConfig,
BeitFeatureExtractor,
BeitForImageClassification,
BeitForMaskedImageModeling,
BeitForSemanticSegmentation,
)
from transformers.utils import logging
......@@ -33,27 +40,33 @@ logger = logging.get_logger(__name__)
# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, has_lm_head=False):
def create_rename_keys(config, has_lm_head=False, is_semantic=False):
prefix = "backbone." if is_semantic else ""
rename_keys = []
for i in range(config.num_hidden_layers):
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
rename_keys.append((f"blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
rename_keys.append((f"blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
rename_keys.append((f"blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight"))
rename_keys.append((f"blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias"))
rename_keys.append((f"blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
rename_keys.append((f"blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
rename_keys.append(
(f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
)
rename_keys.append(
(f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
)
rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
# projection layer + position embeddings
rename_keys.extend(
[
("cls_token", "beit.embeddings.cls_token"),
("patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
("patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
(f"{prefix}cls_token", "beit.embeddings.cls_token"),
(f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
(f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
]
)
......@@ -74,6 +87,16 @@ def create_rename_keys(config, has_lm_head=False):
("norm.bias", "layernorm.bias"),
]
)
elif is_semantic:
# semantic segmentation classification heads
rename_keys.extend(
[
("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
]
)
else:
# layernorm + classification head
rename_keys.extend(
......@@ -89,45 +112,45 @@ def create_rename_keys(config, has_lm_head=False):
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config, has_lm_head=False):
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
for i in range(config.num_hidden_layers):
prefix = "beit."
prefix = "backbone." if is_semantic else ""
# queries, keys and values
in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
q_bias = state_dict.pop(f"blocks.{i}.attn.q_bias")
v_bias = state_dict.pop(f"blocks.{i}.attn.v_bias")
in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
: config.hidden_size, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = q_bias
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
config.hidden_size : config.hidden_size * 2, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
-config.hidden_size :, :
]
state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = v_bias
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
# gamma_1 and gamma_2
# we call them lambda because otherwise they are renamed when using .from_pretrained
gamma_1 = state_dict.pop(f"blocks.{i}.gamma_1")
gamma_2 = state_dict.pop(f"blocks.{i}.gamma_2")
gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
state_dict[f"{prefix}encoder.layer.{i}.lambda_1"] = gamma_1
state_dict[f"{prefix}encoder.layer.{i}.lambda_2"] = gamma_2
state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2
# relative_position bias table + index
if not has_lm_head:
# each layer has its own relative position bias
table = state_dict.pop(f"blocks.{i}.attn.relative_position_bias_table")
index = state_dict.pop(f"blocks.{i}.attn.relative_position_index")
table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")
state_dict[
f"{prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
] = table
state_dict[
f"{prefix}encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
] = index
......@@ -152,6 +175,7 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
# define default BEiT configuration
config = BeitConfig()
has_lm_head = False
is_semantic = False
repo_id = "datasets/huggingface/label-files"
# set config parameters based on URL
if checkpoint_url[-9:-4] == "pt22k":
......@@ -185,8 +209,19 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
config.image_size = 384
if "512" in checkpoint_url:
config.image_size = 512
elif "ade20k" in checkpoint_url:
# fine-tuning
config.use_relative_position_bias = True
config.num_labels = 150
filename = "ade20k-id2label.json"
id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
config.image_size = 640
is_semantic = True
else:
raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k' or 'to1k'")
raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")
# size of the architecture
if "base" in checkpoint_url:
......@@ -196,27 +231,48 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
config.intermediate_size = 4096
config.num_hidden_layers = 24
config.num_attention_heads = 16
if "ade20k" in checkpoint_url:
config.image_size = 640
config.out_indices = [7, 11, 15, 23]
else:
raise ValueError("Should either find 'base' or 'large' in checkpoint URL")
# load state_dict of original model, remove and rename some keys
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)["model"]
rename_keys = create_rename_keys(config, has_lm_head=has_lm_head)
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]
rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head)
read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
if is_semantic:
# add prefix to decoder keys
for key, val in state_dict.copy().items():
val = state_dict.pop(key)
if key.startswith("backbone.fpn"):
key = key.replace("backbone.fpn", "fpn")
state_dict[key] = val
# load HuggingFace model
if checkpoint_url[-9:-4] == "pt22k":
model = BeitForMaskedImageModeling(config)
elif "ade20k" in checkpoint_url:
model = BeitForSemanticSegmentation(config)
else:
model = BeitForImageClassification(config)
model.eval()
model.load_state_dict(state_dict)
# Check outputs on an image
feature_extractor = BeitFeatureExtractor(size=config.image_size, resample=Image.BILINEAR, do_center_crop=False)
encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
if is_semantic:
feature_extractor = BeitFeatureExtractor(size=config.image_size, do_center_crop=False)
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
image = Image.open(ds[0]["file"])
else:
feature_extractor = BeitFeatureExtractor(size=config.image_size, resample=Image.BILINEAR, do_center_crop=False)
image = prepare_img()
encoding = feature_extractor(images=image, return_tensors="pt")
pixel_values = encoding["pixel_values"]
outputs = model(pixel_values)
......@@ -257,15 +313,39 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
expected_shape = (1, 150, 160, 160)
expected_logits = torch.tensor(
[
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
]
)
elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
expected_shape = (1, 150, 160, 160)
expected_logits = torch.tensor(
[
[[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
[[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
[[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
]
)
else:
raise ValueError("Can't verify logits as model is not supported")
assert logits.shape == expected_shape, "Shape of logits not as expected"
print("Shape of logits:", logits.shape)
if not has_lm_head:
print("Predicted class idx:", logits.argmax(-1).item())
assert torch.allclose(logits[0, :3], expected_logits, atol=1e-3), "First elements of logits not as expected"
assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected"
if is_semantic:
assert torch.allclose(
logits[0, :3, :3, :3], expected_logits, atol=1e-3
), "First elements of logits not as expected"
else:
print("Predicted class idx:", logits.argmax(-1).item())
assert torch.allclose(
logits[0, :3], expected_logits, atol=1e-3
), "First elements of logits not as expected"
assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected"
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
......
......@@ -163,6 +163,7 @@ class PatchEmbeddings(nn.Module):
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
x = self.projection(pixel_values).flatten(2).transpose(1, 2)
return x
......@@ -499,7 +500,7 @@ class BeitPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
......@@ -851,3 +852,354 @@ class BeitForImageClassification(BeitPreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class BeitConvModule(nn.Module):
"""
A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
"""
def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=False, dilation=1):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
padding=padding,
bias=bias,
dilation=dilation,
)
self.bn = nn.BatchNorm2d(out_channels)
self.activation = nn.ReLU()
def forward(self, input):
output = self.conv(input)
output = self.bn(output)
output = self.activation(output)
return output
class BeitPyramidPoolingModule(nn.ModuleList):
"""
Pyramid Pooling Module (PPM) used in PSPNet.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module.
in_channels (int): Input channels.
channels (int): Channels after modules, before conv_seg.
align_corners (bool): align_corners argument of F.interpolate.
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
"""
def __init__(self, pool_scales, in_channels, channels, align_corners):
super().__init__()
self.pool_scales = pool_scales
self.align_corners = align_corners
self.in_channels = in_channels
self.channels = channels
for pool_scale in pool_scales:
self.append(
nn.Sequential(
nn.AdaptiveAvgPool2d(pool_scale),
BeitConvModule(self.in_channels, self.channels, kernel_size=1),
)
)
def forward(self, x):
ppm_outs = []
for ppm in self:
ppm_out = ppm(x)
upsampled_ppm_out = nn.functional.interpolate(
ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
)
ppm_outs.append(upsampled_ppm_out)
return ppm_outs
class BeitUperHead(nn.Module):
"""
Unified Perceptual Parsing for Scene Understanding. This head is the implementation of `UPerNet
<https://arxiv.org/abs/1807.10221>`_.
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
"""
def __init__(self, config):
super().__init__()
self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768]
self.channels = config.hidden_size
self.align_corners = False
self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
# PSP Module
self.psp_modules = BeitPyramidPoolingModule(
self.pool_scales,
self.in_channels[-1],
self.channels,
align_corners=self.align_corners,
)
self.bottleneck = BeitConvModule(
self.in_channels[-1] + len(self.pool_scales) * self.channels,
self.channels,
kernel_size=3,
padding=1,
)
# FPN Module
self.lateral_convs = nn.ModuleList()
self.fpn_convs = nn.ModuleList()
for in_channels in self.in_channels[:-1]: # skip the top layer
l_conv = BeitConvModule(in_channels, self.channels, kernel_size=1)
fpn_conv = BeitConvModule(self.channels, self.channels, kernel_size=3, padding=1)
self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)
self.fpn_bottleneck = BeitConvModule(
len(self.in_channels) * self.channels,
self.channels,
kernel_size=3,
padding=1,
)
def psp_forward(self, inputs):
x = inputs[-1]
psp_outs = [x]
psp_outs.extend(self.psp_modules(x))
psp_outs = torch.cat(psp_outs, dim=1)
output = self.bottleneck(psp_outs)
return output
def forward(self, encoder_hidden_states):
# build laterals
laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
laterals.append(self.psp_forward(encoder_hidden_states))
# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
prev_shape = laterals[i - 1].shape[2:]
laterals[i - 1] += nn.functional.interpolate(
laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
)
# build outputs
fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
# append psp feature
fpn_outs.append(laterals[-1])
for i in range(used_backbone_levels - 1, 0, -1):
fpn_outs[i] = nn.functional.interpolate(
fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
)
fpn_outs = torch.cat(fpn_outs, dim=1)
output = self.fpn_bottleneck(fpn_outs)
output = self.classifier(output)
return output
class BeitFCNHead(nn.Module):
"""
Fully Convolution Networks for Semantic Segmentation. This head is implemented of `FCNNet
<https://arxiv.org/abs/1411.4038>`_.
Args:
config (BeitConfig): Configuration.
in_channels
kernel_size (int): The kernel size for convs in the head. Default: 3.
dilation (int): The dilation rate for convs in the head. Default: 1.
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
"""
def __init__(self, config, in_index=2, kernel_size=3, dilation=1):
super().__init__()
self.in_channels = config.hidden_size
self.channels = config.auxiliary_channels
self.num_convs = config.auxiliary_num_convs
self.concat_input = config.auxiliary_concat_input
self.in_index = in_index
conv_padding = (kernel_size // 2) * dilation
convs = []
convs.append(
BeitConvModule(
self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
)
)
for i in range(self.num_convs - 1):
convs.append(
BeitConvModule(
self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
)
)
if self.num_convs == 0:
self.convs = nn.Identity()
else:
self.convs = nn.Sequential(*convs)
if self.concat_input:
self.conv_cat = BeitConvModule(
self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
)
self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
def forward(self, encoder_hidden_states):
# just take the relevant feature maps
hidden_states = encoder_hidden_states[self.in_index]
output = self.convs(hidden_states)
if self.concat_input:
output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
output = self.classifier(output)
return output
@add_start_docstrings(
"""
Beit Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
""",
BEIT_START_DOCSTRING,
)
class BeitForSemanticSegmentation(BeitPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.beit = BeitModel(config, add_pooling_layer=False)
# FPNs
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
nn.BatchNorm2d(config.hidden_size),
nn.GELU(),
nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
)
self.fpn2 = nn.Sequential(
nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
)
self.fpn3 = nn.Identity()
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
# Semantic segmentation head(s)
self.decode_head = BeitUperHead(config)
self.auxiliary_head = BeitFCNHead(config) if config.use_auxiliary_head else None
self.init_weights()
def compute_loss(self, logits, auxiliary_logits, labels):
# upsample logits to the images' original size
upsampled_logits = nn.functional.interpolate(
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
)
if auxiliary_logits is not None:
upsampled_auxiliary_logits = nn.functional.interpolate(
auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
)
# compute weighted loss
loss_fct = CrossEntropyLoss(ignore_index=255)
main_loss = loss_fct(upsampled_logits, labels)
auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
return loss
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values=None,
head_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, height, width)`, `optional`):
Ground truth semantic segmentation maps for computing the loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels > 1`, a classification loss is computed
(Cross-Entropy).
Returns:
Examples::
>>> from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation
>>> from PIL import Image
>>> import requests
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-finetuned-ade-640-640')
>>> model = BeitForSemanticSegmentation.from_pretrained('microsoft/beit-base-finetuned-ade-640-640')
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> # logits are of shape (batch_size, num_labels, height/4, width/4)
>>> logits = outputs.logits
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
outputs = self.beit(
pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=True, # we need the intermediate hidden states
return_dict=return_dict,
)
encoder_hidden_states = outputs.hidden_states if return_dict else outputs[2]
# only keep certain features, and reshape
# note that we do +1 as the encoder_hidden_states also includes the initial embeddings
features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
batch_size = pixel_values.shape[0]
patch_resolution = self.config.image_size // self.config.patch_size
features = [
x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features
]
# apply FPNs
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
for i in range(len(features)):
features[i] = ops[i](features[i])
logits = self.decode_head(features)
auxiliary_logits = None
if self.auxiliary_head is not None:
auxiliary_logits = self.auxiliary_head(features)
loss = None
if labels is not None:
if self.config.num_labels == 1:
raise ValueError("The number of labels should be greater than one")
else:
loss = self.compute_loss(logits, auxiliary_logits, labels)
if not return_dict:
if output_hidden_states:
output = (logits,) + outputs[2:]
else:
output = (logits,) + outputs[3:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
......@@ -606,6 +606,11 @@ class BeitForMaskedImageModeling:
requires_backends(cls, ["torch"])
class BeitForSemanticSegmentation:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BeitModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
......
......@@ -18,6 +18,8 @@
import inspect
import unittest
from datasets import load_dataset
from transformers import BeitConfig
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
from transformers.models.auto import get_values
......@@ -31,7 +33,13 @@ if is_torch_available():
import torch
from torch import nn
from transformers import MODEL_MAPPING, BeitForImageClassification, BeitForMaskedImageModeling, BeitModel
from transformers import (
MODEL_MAPPING,
BeitForImageClassification,
BeitForMaskedImageModeling,
BeitForSemanticSegmentation,
BeitModel,
)
from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
......@@ -53,7 +61,7 @@ class BeitModelTester:
is_training=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=5,
num_hidden_layers=4,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
......@@ -63,6 +71,7 @@ class BeitModelTester:
initializer_range=0.02,
num_labels=3,
scope=None,
out_indices=[0, 1, 2, 3],
):
self.parent = parent
self.vocab_size = 100
......@@ -82,6 +91,7 @@ class BeitModelTester:
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.scope = scope
self.out_indices = out_indices
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
......@@ -109,6 +119,7 @@ class BeitModelTester:
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
out_indices=self.out_indices,
)
def create_and_check_model(self, config, pixel_values, labels):
......@@ -160,7 +171,9 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
"""
all_model_classes = (
(BeitModel, BeitForImageClassification, BeitForMaskedImageModeling) if is_torch_available() else ()
(BeitModel, BeitForImageClassification, BeitForMaskedImageModeling, BeitForSemanticSegmentation)
if is_torch_available()
else ()
)
test_pruning = False
......@@ -212,11 +225,14 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
config.return_dict = True
for model_class in self.all_model_classes:
if model_class in get_values(MODEL_MAPPING):
continue
# we don't test BeitForMaskedImageModeling
if model_class.__name__ == "BeitForMaskedImageModeling":
if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]:
continue
# TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
elif model_class.__name__ == "BeitForSemanticSegmentation":
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long()
model = model_class(config)
model.to(torch_device)
model.train()
......@@ -233,11 +249,17 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
config.return_dict = True
for model_class in self.all_model_classes:
if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
continue
# we don't test BeitForMaskedImageModeling
if model_class.__name__ == "BeitForMaskedImageModeling":
if (
model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]
or not model_class.supports_gradient_checkpointing
):
continue
# TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
elif model_class.__name__ == "BeitForSemanticSegmentation":
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long()
model = model_class(config)
model.to(torch_device)
model.train()
......@@ -298,7 +320,8 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
......@@ -316,15 +339,9 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self.assertEqual(out_len + 1, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self_attentions = outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
......@@ -472,3 +489,32 @@ class BeitModelIntegrationTest(unittest.TestCase):
expected_class_idx = 2396
self.assertEqual(logits.argmax(-1).item(), expected_class_idx)
@slow
def test_inference_semantic_segmentation(self):
model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
model = model.to(torch_device)
feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False)
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
image = Image.open(ds[0]["file"])
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
# forward pass
outputs = model(**inputs)
logits = outputs.logits
# verify the logits
expected_shape = torch.Size((1, 150, 160, 160))
self.assertEqual(logits.shape, expected_shape)
expected_slice = torch.tensor(
[
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
]
).to(torch_device)
self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
......@@ -88,7 +88,7 @@ if is_torch_fx_available():
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
for key in configs_no_init.__dict__.keys():
if "_range" in key or "_std" in key or "initializer_factor" in key:
if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
setattr(configs_no_init, key, 1e-10)
return configs_no_init
......
......@@ -102,6 +102,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
# models to ignore for model xxx mapping
"SegformerDecodeHead",
"SegformerForSemanticSegmentation",
"BeitForSemanticSegmentation",
"FlaxBeitForMaskedImageModeling",
"BeitForMaskedImageModeling",
"CLIPTextModel",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment