"vscode:/vscode.git/clone" did not exist on "2fecde742db1b08e402eb6b11cfc3d80f2ec8a21"
Unverified Commit e20faa6f authored by NielsRogge, committed by GitHub

Add BeitForSemanticSegmentation (#14096)



* Add first draft

* Make forward pass work

* Improve conversion script

* Add notebook that checks if it works

* Add BeitForSemanticSegmentation to the tests

* More improvements

* Make BeitForSemanticSegmentation consistent with Segformer

* Small bug fix

* Add BeitForSemanticSegmentation to docs

* Make sure model doesn't output hidden states when the user doesn't want to

* Make it possible to convert the large model

* Fix issue

* Fix conversion script for large model

* Add auxiliary_head option to semantic segmentation model

* Apply suggestions from @sgugger's review

* Apply suggestions from code review

* Fix failing test
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent 8b325781
@@ -98,6 +98,13 @@ BeitForImageClassification
:members: forward
BeitForSemanticSegmentation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BeitForSemanticSegmentation
:members: forward
FlaxBeitModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
@@ -638,6 +638,7 @@ if is_torch_available():
"BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"BeitForImageClassification",
"BeitForMaskedImageModeling",
"BeitForSemanticSegmentation",
"BeitModel",
"BeitPreTrainedModel",
]
@@ -2483,6 +2484,7 @@ if TYPE_CHECKING:
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
BeitForImageClassification,
BeitForMaskedImageModeling,
BeitForSemanticSegmentation,
BeitModel,
BeitPreTrainedModel,
)
......
@@ -33,6 +33,7 @@ if is_torch_available():
"BEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"BeitForImageClassification",
"BeitForMaskedImageModeling",
"BeitForSemanticSegmentation",
"BeitModel",
"BeitPreTrainedModel",
]
@@ -57,6 +58,7 @@ if TYPE_CHECKING:
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
BeitForImageClassification,
BeitForMaskedImageModeling,
BeitForSemanticSegmentation,
BeitModel,
BeitPreTrainedModel,
)
......
@@ -78,6 +78,20 @@ class BeitConfig(PretrainedConfig):
use_mean_pooling (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
CLS token, before applying the classification head.
out_indices (:obj:`List[int]`, `optional`, defaults to :obj:`[3, 5, 7, 11]`):
Indices of the feature maps to use for semantic segmentation.
pool_scales (:obj:`Tuple[int]`, `optional`, defaults to :obj:`[1, 2, 3, 6]`):
Pooling scales used in Pooling Pyramid Module applied on the last feature map.
use_auxiliary_head (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to use an auxiliary head during training.
auxiliary_loss_weight (:obj:`float`, `optional`, defaults to 0.4):
Weight of the cross-entropy loss of the auxiliary head.
auxiliary_channels (:obj:`int`, `optional`, defaults to 256):
Number of channels to use in the auxiliary head.
auxiliary_num_convs (:obj:`int`, `optional`, defaults to 1):
Number of convolutional layers to use in the auxiliary head.
auxiliary_concat_input (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to concatenate the output of the auxiliary head with the input before the classification layer.
Example::
@@ -117,6 +131,13 @@ class BeitConfig(PretrainedConfig):
layer_scale_init_value=0.1,
drop_path_rate=0.1,
use_mean_pooling=True,
out_indices=[3, 5, 7, 11],
pool_scales=[1, 2, 3, 6],
use_auxiliary_head=True,
auxiliary_loss_weight=0.4,
auxiliary_channels=256,
auxiliary_num_convs=1,
auxiliary_concat_input=False,
**kwargs
):
super().__init__(**kwargs)
@@ -142,3 +163,12 @@ class BeitConfig(PretrainedConfig):
self.layer_scale_init_value = layer_scale_init_value
self.drop_path_rate = drop_path_rate
self.use_mean_pooling = use_mean_pooling
# decode head attributes (semantic segmentation)
self.out_indices = out_indices
self.pool_scales = pool_scales
# auxiliary head attributes (semantic segmentation)
self.use_auxiliary_head = use_auxiliary_head
self.auxiliary_loss_weight = auxiliary_loss_weight
self.auxiliary_channels = auxiliary_channels
self.auxiliary_num_convs = auxiliary_num_convs
self.auxiliary_concat_input = auxiliary_concat_input
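To make the new options concrete, here is a minimal sketch (not part of the diff) that instantiates a segmentation model from a config using the attributes documented above; the values mirror the ADE20k setup used elsewhere in this commit:

```python
from transformers import BeitConfig, BeitForSemanticSegmentation

# Base-sized BEiT backbone with the UPerNet decode head and auxiliary FCN head,
# roughly matching the ADE20k fine-tuned checkpoints converted by this PR.
config = BeitConfig(
    image_size=640,
    num_labels=150,             # ADE20k has 150 semantic classes
    out_indices=[3, 5, 7, 11],  # hidden states fed to the decode head
    pool_scales=[1, 2, 3, 6],   # PPM pooling scales on the last feature map
    use_auxiliary_head=True,
    auxiliary_loss_weight=0.4,
)
model = BeitForSemanticSegmentation(config)
```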
@@ -20,11 +20,18 @@ import json
from pathlib import Path
import torch
from datasets import load_dataset
from PIL import Image
import requests
from huggingface_hub import cached_download, hf_hub_url
from transformers import (
BeitConfig,
BeitFeatureExtractor,
BeitForImageClassification,
BeitForMaskedImageModeling,
BeitForSemanticSegmentation,
)
from transformers.utils import logging
@@ -33,27 +40,33 @@ logger = logging.get_logger(__name__)
# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, has_lm_head=False, is_semantic=False):
prefix = "backbone." if is_semantic else ""
rename_keys = []
for i in range(config.num_hidden_layers):
# encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
rename_keys.append((f"{prefix}blocks.{i}.norm1.weight", f"beit.encoder.layer.{i}.layernorm_before.weight"))
rename_keys.append((f"{prefix}blocks.{i}.norm1.bias", f"beit.encoder.layer.{i}.layernorm_before.bias"))
rename_keys.append(
(f"{prefix}blocks.{i}.attn.proj.weight", f"beit.encoder.layer.{i}.attention.output.dense.weight")
)
rename_keys.append(
(f"{prefix}blocks.{i}.attn.proj.bias", f"beit.encoder.layer.{i}.attention.output.dense.bias")
)
rename_keys.append((f"{prefix}blocks.{i}.norm2.weight", f"beit.encoder.layer.{i}.layernorm_after.weight"))
rename_keys.append((f"{prefix}blocks.{i}.norm2.bias", f"beit.encoder.layer.{i}.layernorm_after.bias"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.weight", f"beit.encoder.layer.{i}.intermediate.dense.weight"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc1.bias", f"beit.encoder.layer.{i}.intermediate.dense.bias"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.weight", f"beit.encoder.layer.{i}.output.dense.weight"))
rename_keys.append((f"{prefix}blocks.{i}.mlp.fc2.bias", f"beit.encoder.layer.{i}.output.dense.bias"))
# projection layer + position embeddings
rename_keys.extend(
[
(f"{prefix}cls_token", "beit.embeddings.cls_token"),
(f"{prefix}patch_embed.proj.weight", "beit.embeddings.patch_embeddings.projection.weight"),
(f"{prefix}patch_embed.proj.bias", "beit.embeddings.patch_embeddings.projection.bias"),
]
)
@@ -74,6 +87,16 @@ def create_rename_keys(config, has_lm_head=False):
("norm.bias", "layernorm.bias"),
]
)
elif is_semantic:
# semantic segmentation classification heads
rename_keys.extend(
[
("decode_head.conv_seg.weight", "decode_head.classifier.weight"),
("decode_head.conv_seg.bias", "decode_head.classifier.bias"),
("auxiliary_head.conv_seg.weight", "auxiliary_head.classifier.weight"),
("auxiliary_head.conv_seg.bias", "auxiliary_head.classifier.bias"),
]
)
else:
# layernorm + classification head
rename_keys.extend(
@@ -89,45 +112,45 @@ def create_rename_keys(config, has_lm_head=False):
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config, has_lm_head=False, is_semantic=False):
for i in range(config.num_hidden_layers):
prefix = "backbone." if is_semantic else ""
# queries, keys and values
in_proj_weight = state_dict.pop(f"{prefix}blocks.{i}.attn.qkv.weight")
q_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.q_bias")
v_bias = state_dict.pop(f"{prefix}blocks.{i}.attn.v_bias")
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
: config.hidden_size, :
]
state_dict[f"beit.encoder.layer.{i}.attention.attention.query.bias"] = q_bias
state_dict[f"beit.encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
config.hidden_size : config.hidden_size * 2, :
]
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
-config.hidden_size :, :
]
state_dict[f"beit.encoder.layer.{i}.attention.attention.value.bias"] = v_bias
# gamma_1 and gamma_2
# we call them lambda because otherwise they are renamed when using .from_pretrained
gamma_1 = state_dict.pop(f"{prefix}blocks.{i}.gamma_1")
gamma_2 = state_dict.pop(f"{prefix}blocks.{i}.gamma_2")
state_dict[f"beit.encoder.layer.{i}.lambda_1"] = gamma_1
state_dict[f"beit.encoder.layer.{i}.lambda_2"] = gamma_2
# relative_position bias table + index
if not has_lm_head:
# each layer has its own relative position bias
table = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_bias_table")
index = state_dict.pop(f"{prefix}blocks.{i}.attn.relative_position_index")
state_dict[
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_bias_table"
] = table
state_dict[
f"beit.encoder.layer.{i}.attention.attention.relative_position_bias.relative_position_index"
] = index
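Note that the `rename_key` helper called in the next hunk is not part of this excerpt; presumably it simply moves a tensor from the old key to the new one, roughly like this sketch:

```python
def rename_key(state_dict, old, new):
    # pop the tensor stored under the old key and re-insert it under the new name
    val = state_dict.pop(old)
    state_dict[new] = val
```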
@@ -152,6 +175,7 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
# define default BEiT configuration
config = BeitConfig()
has_lm_head = False
is_semantic = False
repo_id = "datasets/huggingface/label-files"
# set config parameters based on URL
if checkpoint_url[-9:-4] == "pt22k":
@@ -185,8 +209,19 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
config.image_size = 384
if "512" in checkpoint_url:
config.image_size = 512
elif "ade20k" in checkpoint_url:
# fine-tuning
config.use_relative_position_bias = True
config.num_labels = 150
filename = "ade20k-id2label.json"
id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
config.image_size = 640
is_semantic = True
else:
raise ValueError("Checkpoint not supported, URL should either end with 'pt22k', 'ft22k', 'to1k' or 'ade20k'")
# size of the architecture
if "base" in checkpoint_url:
@@ -196,27 +231,48 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
config.intermediate_size = 4096
config.num_hidden_layers = 24
config.num_attention_heads = 16
if "ade20k" in checkpoint_url:
config.image_size = 640
config.out_indices = [7, 11, 15, 23]
else:
raise ValueError("Should either find 'base' or 'large' in checkpoint URL")
# load state_dict of original model, remove and rename some keys
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu", check_hash=True)
state_dict = state_dict["model"] if "ade20k" not in checkpoint_url else state_dict["state_dict"]
rename_keys = create_rename_keys(config, has_lm_head=has_lm_head, is_semantic=is_semantic)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
read_in_q_k_v(state_dict, config, has_lm_head=has_lm_head, is_semantic=is_semantic)
if is_semantic:
# add prefix to decoder keys
for key, val in state_dict.copy().items():
val = state_dict.pop(key)
if key.startswith("backbone.fpn"):
key = key.replace("backbone.fpn", "fpn")
state_dict[key] = val
# load HuggingFace model
if checkpoint_url[-9:-4] == "pt22k":
model = BeitForMaskedImageModeling(config)
elif "ade20k" in checkpoint_url:
model = BeitForSemanticSegmentation(config)
else:
model = BeitForImageClassification(config)
model.eval()
model.load_state_dict(state_dict)
# Check outputs on an image
if is_semantic:
feature_extractor = BeitFeatureExtractor(size=config.image_size, do_center_crop=False)
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
image = Image.open(ds[0]["file"])
else:
feature_extractor = BeitFeatureExtractor(size=config.image_size, resample=Image.BILINEAR, do_center_crop=False)
image = prepare_img()
encoding = feature_extractor(images=image, return_tensors="pt")
pixel_values = encoding["pixel_values"]
outputs = model(pixel_values)
@@ -257,15 +313,39 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
elif checkpoint_url[:-4].endswith("beit_large_patch16_512_pt22k_ft22kto1k"):
expected_logits = torch.tensor([-0.3062, 0.7261, 0.4852])
expected_class_idx = 761
elif checkpoint_url[:-4].endswith("beit_base_patch16_640_pt22k_ft22ktoade20k"):
expected_shape = (1, 150, 160, 160)
expected_logits = torch.tensor(
[
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
]
)
elif checkpoint_url[:-4].endswith("beit_large_patch16_640_pt22k_ft22ktoade20k"):
expected_shape = (1, 150, 160, 160)
expected_logits = torch.tensor(
[
[[-4.3305, -2.3049, -3.0161], [-2.9591, -1.5305, -2.2251], [-3.4198, -1.8004, -2.9062]],
[[-5.8922, -3.7435, -4.3978], [-4.2063, -2.7872, -3.4755], [-4.2791, -3.1874, -4.1681]],
[[0.9895, 4.3467, 4.7663], [4.2476, 5.6830, 6.1518], [4.5550, 6.2495, 6.5154]],
]
)
else:
raise ValueError("Can't verify logits as model is not supported")
assert logits.shape == expected_shape, "Shape of logits not as expected"
print("Shape of logits:", logits.shape)
if not has_lm_head:
if is_semantic:
assert torch.allclose(
logits[0, :3, :3, :3], expected_logits, atol=1e-3
), "First elements of logits not as expected"
else:
print("Predicted class idx:", logits.argmax(-1).item())
assert torch.allclose(
logits[0, :3], expected_logits, atol=1e-3
), "First elements of logits not as expected"
assert logits.argmax(-1).item() == expected_class_idx, "Predicted class index not as expected"
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
print(f"Saving model to {pytorch_dump_folder_path}")
......
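The tail of the script (the actual save calls) is cut off in this excerpt; assuming it saves the model to `pytorch_dump_folder_path` as the print statement suggests, the converted ADE20k checkpoint can then be reloaded like any other local checkpoint (the folder name below is a placeholder):

```python
from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation

# Reload the locally converted checkpoint; use whatever folder was passed as
# pytorch_dump_folder_path when running the conversion.
model = BeitForSemanticSegmentation.from_pretrained("./beit-base-finetuned-ade-640-640")
feature_extractor = BeitFeatureExtractor(size=model.config.image_size, do_center_crop=False)
```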
@@ -163,6 +163,7 @@ class PatchEmbeddings(nn.Module):
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
x = self.projection(pixel_values).flatten(2).transpose(1, 2)
return x
@@ -499,7 +500,7 @@ class BeitPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@@ -851,3 +852,354 @@ class BeitForImageClassification(BeitPreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class BeitConvModule(nn.Module):
"""
A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
"""
def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=False, dilation=1):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
padding=padding,
bias=bias,
dilation=dilation,
)
self.bn = nn.BatchNorm2d(out_channels)
self.activation = nn.ReLU()
def forward(self, input):
output = self.conv(input)
output = self.bn(output)
output = self.activation(output)
return output
class BeitPyramidPoolingModule(nn.ModuleList):
"""
Pyramid Pooling Module (PPM) used in PSPNet.
Args:
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
Module.
in_channels (int): Input channels.
channels (int): Channels after modules, before conv_seg.
align_corners (bool): align_corners argument of F.interpolate.
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
"""
def __init__(self, pool_scales, in_channels, channels, align_corners):
super().__init__()
self.pool_scales = pool_scales
self.align_corners = align_corners
self.in_channels = in_channels
self.channels = channels
for pool_scale in pool_scales:
self.append(
nn.Sequential(
nn.AdaptiveAvgPool2d(pool_scale),
BeitConvModule(self.in_channels, self.channels, kernel_size=1),
)
)
def forward(self, x):
ppm_outs = []
for ppm in self:
ppm_out = ppm(x)
upsampled_ppm_out = nn.functional.interpolate(
ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
)
ppm_outs.append(upsampled_ppm_out)
return ppm_outs
class BeitUperHead(nn.Module):
"""
Unified Perceptual Parsing for Scene Understanding. This head is the implementation of `UPerNet
<https://arxiv.org/abs/1807.10221>`_.
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
"""
def __init__(self, config):
super().__init__()
self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768]
self.channels = config.hidden_size
self.align_corners = False
self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
# PSP Module
self.psp_modules = BeitPyramidPoolingModule(
self.pool_scales,
self.in_channels[-1],
self.channels,
align_corners=self.align_corners,
)
self.bottleneck = BeitConvModule(
self.in_channels[-1] + len(self.pool_scales) * self.channels,
self.channels,
kernel_size=3,
padding=1,
)
# FPN Module
self.lateral_convs = nn.ModuleList()
self.fpn_convs = nn.ModuleList()
for in_channels in self.in_channels[:-1]: # skip the top layer
l_conv = BeitConvModule(in_channels, self.channels, kernel_size=1)
fpn_conv = BeitConvModule(self.channels, self.channels, kernel_size=3, padding=1)
self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)
self.fpn_bottleneck = BeitConvModule(
len(self.in_channels) * self.channels,
self.channels,
kernel_size=3,
padding=1,
)
def psp_forward(self, inputs):
x = inputs[-1]
psp_outs = [x]
psp_outs.extend(self.psp_modules(x))
psp_outs = torch.cat(psp_outs, dim=1)
output = self.bottleneck(psp_outs)
return output
def forward(self, encoder_hidden_states):
# build laterals
laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
laterals.append(self.psp_forward(encoder_hidden_states))
# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
prev_shape = laterals[i - 1].shape[2:]
laterals[i - 1] += nn.functional.interpolate(
laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
)
# build outputs
fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
# append psp feature
fpn_outs.append(laterals[-1])
for i in range(used_backbone_levels - 1, 0, -1):
fpn_outs[i] = nn.functional.interpolate(
fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
)
fpn_outs = torch.cat(fpn_outs, dim=1)
output = self.fpn_bottleneck(fpn_outs)
output = self.classifier(output)
return output
class BeitFCNHead(nn.Module):
"""
Fully Convolutional Networks for Semantic Segmentation. This head is the implementation of `FCNNet
<https://arxiv.org/abs/1411.4038>`_.
Args:
config (BeitConfig): Configuration.
in_channels
kernel_size (int): The kernel size for convs in the head. Default: 3.
dilation (int): The dilation rate for convs in the head. Default: 1.
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
"""
def __init__(self, config, in_index=2, kernel_size=3, dilation=1):
super().__init__()
self.in_channels = config.hidden_size
self.channels = config.auxiliary_channels
self.num_convs = config.auxiliary_num_convs
self.concat_input = config.auxiliary_concat_input
self.in_index = in_index
conv_padding = (kernel_size // 2) * dilation
convs = []
convs.append(
BeitConvModule(
self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
)
)
for i in range(self.num_convs - 1):
convs.append(
BeitConvModule(
self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
)
)
if self.num_convs == 0:
self.convs = nn.Identity()
else:
self.convs = nn.Sequential(*convs)
if self.concat_input:
self.conv_cat = BeitConvModule(
self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
)
self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
def forward(self, encoder_hidden_states):
# just take the relevant feature maps
hidden_states = encoder_hidden_states[self.in_index]
output = self.convs(hidden_states)
if self.concat_input:
output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
output = self.classifier(output)
return output
@add_start_docstrings(
"""
Beit Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
""",
BEIT_START_DOCSTRING,
)
class BeitForSemanticSegmentation(BeitPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.beit = BeitModel(config, add_pooling_layer=False)
# FPNs
self.fpn1 = nn.Sequential(
nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
nn.BatchNorm2d(config.hidden_size),
nn.GELU(),
nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
)
self.fpn2 = nn.Sequential(
nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
)
self.fpn3 = nn.Identity()
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
# Semantic segmentation head(s)
self.decode_head = BeitUperHead(config)
self.auxiliary_head = BeitFCNHead(config) if config.use_auxiliary_head else None
self.init_weights()
def compute_loss(self, logits, auxiliary_logits, labels):
# upsample logits to the images' original size
upsampled_logits = nn.functional.interpolate(
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
)
if auxiliary_logits is not None:
upsampled_auxiliary_logits = nn.functional.interpolate(
auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
)
# compute weighted loss
loss_fct = CrossEntropyLoss(ignore_index=255)
main_loss = loss_fct(upsampled_logits, labels)
auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
return loss
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values=None,
head_mask=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, height, width)`, `optional`):
Ground truth semantic segmentation maps for computing the loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels > 1`, a classification loss is computed
(Cross-Entropy).
Returns:
Examples::
>>> from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation
>>> from PIL import Image
>>> import requests
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-finetuned-ade-640-640')
>>> model = BeitForSemanticSegmentation.from_pretrained('microsoft/beit-base-finetuned-ade-640-640')
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> # logits are of shape (batch_size, num_labels, height/4, width/4)
>>> logits = outputs.logits
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
outputs = self.beit(
pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=True, # we need the intermediate hidden states
return_dict=return_dict,
)
encoder_hidden_states = outputs.hidden_states if return_dict else outputs[2]
# only keep certain features, and reshape
# note that we do +1 as the encoder_hidden_states also includes the initial embeddings
features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
batch_size = pixel_values.shape[0]
patch_resolution = self.config.image_size // self.config.patch_size
features = [
x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features
]
# apply FPNs
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
for i in range(len(features)):
features[i] = ops[i](features[i])
logits = self.decode_head(features)
auxiliary_logits = None
if self.auxiliary_head is not None:
auxiliary_logits = self.auxiliary_head(features)
loss = None
if labels is not None:
if self.config.num_labels == 1:
raise ValueError("The number of labels should be greater than one")
else:
loss = self.compute_loss(logits, auxiliary_logits, labels)
if not return_dict:
if output_hidden_states:
output = (logits,) + outputs[2:]
else:
output = (logits,) + outputs[3:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
)
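Since the new head returns logits at 1/4 of the input resolution (see the docstring example above), a typical post-processing step, sketched here rather than taken from this PR, is to upsample them back to the image size and take the per-pixel argmax, mirroring what `compute_loss` does for the labels:

```python
from torch import nn

# Continuing from the docstring example: `outputs.logits` has shape
# (batch_size, num_labels, height / 4, width / 4).
logits = outputs.logits

# Upsample to the original pixel resolution, then pick the most likely class per pixel.
upsampled_logits = nn.functional.interpolate(
    logits, size=(image.height, image.width), mode="bilinear", align_corners=False
)
segmentation_map = upsampled_logits.argmax(dim=1)  # (batch_size, height, width)
```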
@@ -606,6 +606,11 @@ class BeitForMaskedImageModeling:
requires_backends(cls, ["torch"])
class BeitForSemanticSegmentation:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BeitModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
......
@@ -18,6 +18,8 @@
import inspect
import unittest
from datasets import load_dataset
from transformers import BeitConfig
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
from transformers.models.auto import get_values
@@ -31,7 +33,13 @@ if is_torch_available():
import torch
from torch import nn
from transformers import (
MODEL_MAPPING,
BeitForImageClassification,
BeitForMaskedImageModeling,
BeitForSemanticSegmentation,
BeitModel,
)
from transformers.models.beit.modeling_beit import BEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
@@ -53,7 +61,7 @@ class BeitModelTester:
is_training=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=4,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
@@ -63,6 +71,7 @@ class BeitModelTester:
initializer_range=0.02,
num_labels=3,
scope=None,
out_indices=[0, 1, 2, 3],
):
self.parent = parent
self.vocab_size = 100
@@ -82,6 +91,7 @@ class BeitModelTester:
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.scope = scope
self.out_indices = out_indices
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -109,6 +119,7 @@ class BeitModelTester:
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
out_indices=self.out_indices,
)
def create_and_check_model(self, config, pixel_values, labels):
@@ -160,7 +171,9 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
"""
all_model_classes = (
(BeitModel, BeitForImageClassification, BeitForMaskedImageModeling, BeitForSemanticSegmentation)
if is_torch_available()
else ()
)
test_pruning = False
@@ -212,11 +225,14 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
config.return_dict = True
for model_class in self.all_model_classes:
# we don't test BeitForMaskedImageModeling
if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]:
continue
# TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
elif model_class.__name__ == "BeitForSemanticSegmentation":
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long()
model = model_class(config)
model.to(torch_device)
model.train()
@@ -233,11 +249,17 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
config.return_dict = True
for model_class in self.all_model_classes:
# we don't test BeitForMaskedImageModeling
if (
model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]
or not model_class.supports_gradient_checkpointing
):
continue
# TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
elif model_class.__name__ == "BeitForSemanticSegmentation":
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long()
model = model_class(config)
model.to(torch_device)
model.train()
@@ -298,7 +320,8 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
@@ -316,15 +339,9 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
self.assertEqual(out_len + 1, len(outputs))
self_attentions = outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
@@ -472,3 +489,32 @@ class BeitModelIntegrationTest(unittest.TestCase):
expected_class_idx = 2396
self.assertEqual(logits.argmax(-1).item(), expected_class_idx)
@slow
def test_inference_semantic_segmentation(self):
model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
model = model.to(torch_device)
feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False)
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
image = Image.open(ds[0]["file"])
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
# forward pass
outputs = model(**inputs)
logits = outputs.logits
# verify the logits
expected_shape = torch.Size((1, 150, 160, 160))
self.assertEqual(logits.shape, expected_shape)
expected_slice = torch.tensor(
[
[[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
[[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
[[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
]
).to(torch_device)
self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
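The training tests above build dummy segmentation labels of shape `(batch_size, height, width)`; the same pattern yields the combined loss, shown here as a sketch that assumes `model`, `inputs` and `torch_device` from the integration test just above:

```python
import torch

# Ground-truth maps are LongTensors of shape (batch_size, height, width) with class
# indices in [0, num_labels - 1]; index 255 is ignored by the cross-entropy loss.
labels = torch.zeros(1, 640, 640, dtype=torch.long).to(torch_device)

outputs = model(**inputs, labels=labels)
# loss = decode-head loss + auxiliary_loss_weight * auxiliary-head loss
print(outputs.loss)
```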
@@ -88,7 +88,7 @@ if is_torch_fx_available():
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
for key in configs_no_init.__dict__.keys():
if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
setattr(configs_no_init, key, 1e-10)
return configs_no_init
......
@@ -102,6 +102,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
# models to ignore for model xxx mapping
"SegformerDecodeHead",
"SegformerForSemanticSegmentation",
"BeitForSemanticSegmentation",
"FlaxBeitForMaskedImageModeling",
"BeitForMaskedImageModeling",
"CLIPTextModel",
......