Unverified commit cdddfbff authored by Alara Dirik, committed by GitHub

Add ConvNeXT V2 (#21679)

* Add ConvNeXt V2 to transformers
* The TF model is split out of this PR so its remaining issues can be fixed separately
parent 6c2ad00c
...@@ -78,6 +78,7 @@ specific language governing permissions and limitations under the License.
1. **[Conditional DETR](model_doc/conditional_detr)** (from Microsoft Research Asia) released with the paper [Conditional DETR for Fast Training Convergence](https://arxiv.org/abs/2108.06152) by Depu Meng, Xiaokang Chen, Zejia Fan, Gang Zeng, Houqiang Li, Yuhui Yuan, Lei Sun, Jingdong Wang.
1. **[ConvBERT](model_doc/convbert)** (from YituTech) released with the paper [ConvBERT: Improving BERT with Span-based Dynamic Convolution](https://arxiv.org/abs/2008.02496) by Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan.
1. **[ConvNeXT](model_doc/convnext)** (from Facebook AI) released with the paper [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) by Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie.
1. **[ConvNeXTV2](model_doc/convnextv2)** (from Facebook AI) released with the paper [ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808) by Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon, Saining Xie.
1. **[CPM](model_doc/cpm)** (from Tsinghua University) released with the paper [CPM: A Large-scale Generative Chinese Pre-trained Language Model](https://arxiv.org/abs/2012.00413) by Zhengyan Zhang, Xu Han, Hao Zhou, Pei Ke, Yuxian Gu, Deming Ye, Yujia Qin, Yusheng Su, Haozhe Ji, Jian Guan, Fanchao Qi, Xiaozhi Wang, Yanan Zheng, Guoyang Zeng, Huanqi Cao, Shengqi Chen, Daixuan Li, Zhenbo Sun, Zhiyuan Liu, Minlie Huang, Wentao Han, Jie Tang, Juanzi Li, Xiaoyan Zhu, Maosong Sun.
1. **[CTRL](model_doc/ctrl)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
1. **[CvT](model_doc/cvt)** (from Microsoft) released with the paper [CvT: Introducing Convolutions to Vision Transformers](https://arxiv.org/abs/2103.15808) by Haiping Wu, Bin Xiao, Noel Codella, Mengchen Liu, Xiyang Dai, Lu Yuan, Lei Zhang.
...
...@@ -241,6 +241,7 @@ _import_structure = {
"models.conditional_detr": ["CONDITIONAL_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConditionalDetrConfig"],
"models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"],
"models.convnext": ["CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvNextConfig"],
"models.convnextv2": ["CONVNEXTV2_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvNextV2Config"],
"models.cpm": [], "models.cpm": [],
"models.ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig", "CTRLTokenizer"], "models.ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig", "CTRLTokenizer"],
"models.cvt": ["CVT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CvtConfig"], "models.cvt": ["CVT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CvtConfig"],
...@@ -1300,6 +1301,15 @@ else: ...@@ -1300,6 +1301,15 @@ else:
"ConvNextPreTrainedModel", "ConvNextPreTrainedModel",
] ]
) )
_import_structure["models.convnextv2"].extend(
[
"CONVNEXTV2_PRETRAINED_MODEL_ARCHIVE_LIST",
"ConvNextV2Backbone",
"ConvNextV2ForImageClassification",
"ConvNextV2Model",
"ConvNextV2PreTrainedModel",
]
)
_import_structure["models.ctrl"].extend( _import_structure["models.ctrl"].extend(
[ [
"CTRL_PRETRAINED_MODEL_ARCHIVE_LIST", "CTRL_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -3851,6 +3861,7 @@ if TYPE_CHECKING: ...@@ -3851,6 +3861,7 @@ if TYPE_CHECKING:
from .models.conditional_detr import CONDITIONAL_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, ConditionalDetrConfig from .models.conditional_detr import CONDITIONAL_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, ConditionalDetrConfig
from .models.convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig, ConvBertTokenizer from .models.convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig, ConvBertTokenizer
from .models.convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig from .models.convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig
from .models.convnextv2 import CONVNEXTV2_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextV2Config
from .models.ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig, CTRLTokenizer
from .models.cvt import CVT_PRETRAINED_CONFIG_ARCHIVE_MAP, CvtConfig
from .models.data2vec import (
...@@ -4777,6 +4788,13 @@ if TYPE_CHECKING:
ConvNextModel,
ConvNextPreTrainedModel,
)
from .models.convnextv2 import (
CONVNEXTV2_PRETRAINED_MODEL_ARCHIVE_LIST,
ConvNextV2Backbone,
ConvNextV2ForImageClassification,
ConvNextV2Model,
ConvNextV2PreTrainedModel,
)
from .models.ctrl import (
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
CTRLForSequenceClassification,
...
...@@ -48,6 +48,7 @@ from . import (
conditional_detr,
convbert,
convnext,
convnextv2,
cpm,
ctrl,
cvt,
...
...@@ -57,6 +57,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("conditional_detr", "ConditionalDetrConfig"),
("convbert", "ConvBertConfig"),
("convnext", "ConvNextConfig"),
("convnextv2", "ConvNextV2Config"),
("ctrl", "CTRLConfig"), ("ctrl", "CTRLConfig"),
("cvt", "CvtConfig"), ("cvt", "CvtConfig"),
("data2vec-audio", "Data2VecAudioConfig"), ("data2vec-audio", "Data2VecAudioConfig"),
...@@ -236,6 +237,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict( ...@@ -236,6 +237,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
("conditional_detr", "CONDITIONAL_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("conditional_detr", "CONDITIONAL_DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("convbert", "CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convnextv2", "CONVNEXTV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("ctrl", "CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("cvt", "CVT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("cvt", "CVT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("data2vec-audio", "DATA2VEC_AUDIO_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("data2vec-audio", "DATA2VEC_AUDIO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
...@@ -405,6 +407,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ...@@ -405,6 +407,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("conditional_detr", "Conditional DETR"), ("conditional_detr", "Conditional DETR"),
("convbert", "ConvBERT"), ("convbert", "ConvBERT"),
("convnext", "ConvNeXT"), ("convnext", "ConvNeXT"),
("convnextv2", "ConvNeXTV2"),
("cpm", "CPM"), ("cpm", "CPM"),
("ctrl", "CTRL"), ("ctrl", "CTRL"),
("cvt", "CvT"), ("cvt", "CvT"),
......
...@@ -48,6 +48,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict( ...@@ -48,6 +48,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
("clipseg", "ViTImageProcessor"), ("clipseg", "ViTImageProcessor"),
("conditional_detr", "ConditionalDetrImageProcessor"), ("conditional_detr", "ConditionalDetrImageProcessor"),
("convnext", "ConvNextImageProcessor"), ("convnext", "ConvNextImageProcessor"),
("convnextv2", "ConvNextImageProcessor"),
("cvt", "ConvNextImageProcessor"), ("cvt", "ConvNextImageProcessor"),
("data2vec-vision", "BeitImageProcessor"), ("data2vec-vision", "BeitImageProcessor"),
("deformable_detr", "DeformableDetrImageProcessor"), ("deformable_detr", "DeformableDetrImageProcessor"),
......
...@@ -56,6 +56,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ...@@ -56,6 +56,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("conditional_detr", "ConditionalDetrModel"), ("conditional_detr", "ConditionalDetrModel"),
("convbert", "ConvBertModel"), ("convbert", "ConvBertModel"),
("convnext", "ConvNextModel"), ("convnext", "ConvNextModel"),
("convnextv2", "ConvNextV2Model"),
("ctrl", "CTRLModel"), ("ctrl", "CTRLModel"),
("cvt", "CvtModel"), ("cvt", "CvtModel"),
("data2vec-audio", "Data2VecAudioModel"), ("data2vec-audio", "Data2VecAudioModel"),
...@@ -410,6 +411,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -410,6 +411,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("beit", "BeitForImageClassification"), ("beit", "BeitForImageClassification"),
("bit", "BitForImageClassification"), ("bit", "BitForImageClassification"),
("convnext", "ConvNextForImageClassification"), ("convnext", "ConvNextForImageClassification"),
("convnextv2", "ConvNextV2ForImageClassification"),
("cvt", "CvtForImageClassification"), ("cvt", "CvtForImageClassification"),
("data2vec-vision", "Data2VecVisionForImageClassification"), ("data2vec-vision", "Data2VecVisionForImageClassification"),
("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")), ("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")),
...@@ -938,6 +940,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict( ...@@ -938,6 +940,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
# Backbone mapping # Backbone mapping
("bit", "BitBackbone"), ("bit", "BitBackbone"),
("convnext", "ConvNextBackbone"), ("convnext", "ConvNextBackbone"),
("convnextv2", "ConvNextV2Backbone"),
("dinat", "DinatBackbone"), ("dinat", "DinatBackbone"),
("maskformer-swin", "MaskFormerSwinBackbone"), ("maskformer-swin", "MaskFormerSwinBackbone"),
("nat", "NatBackbone"), ("nat", "NatBackbone"),
......
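With the auto mappings above in place, the new checkpoints resolve through the generic Auto classes. A minimal sketch of the intended usage, assuming the `facebook/convnextv2-tiny-1k-224` checkpoint referenced elsewhere in this PR is available on the Hub:

```python
import requests
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# the "convnextv2" entries added above let the Auto classes dispatch to the new model
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("facebook/convnextv2-tiny-1k-224")
model = AutoModelForImageClassification.from_pretrained("facebook/convnextv2-tiny-1k-224")

inputs = processor(images=image, return_tensors="pt")
logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])
```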
...@@ -61,7 +61,7 @@ class ConvNextImageProcessor(BaseImageProcessor):
be matched to `int(size["shortest_edge"]/crop_pct)`, after which the image is cropped to
`(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`. Can
be overriden by `size` in the `preprocess` method.
crop_pct (`float` *optional*, defaults to 224 / 256):
Percentage of the image to crop. Only has an effect if `do_resize` is `True` and size < 384. Can be
overriden by `crop_pct` in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
...
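For reference, this hunk corrects the documented default from 244 / 256 to 224 / 256. With `size["shortest_edge"] = 224` and `crop_pct = 224 / 256`, the image is first resized so its shortest edge becomes `int(224 / (224 / 256)) = 256` pixels and is then center-cropped to 224 x 224.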
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
# rely on isort to merge the imports
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {
"configuration_convnextv2": [
"CONVNEXTV2_PRETRAINED_CONFIG_ARCHIVE_MAP",
"ConvNextV2Config",
]
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_convnextv2"] = [
"CONVNEXTV2_PRETRAINED_MODEL_ARCHIVE_LIST",
"ConvNextV2ForImageClassification",
"ConvNextV2Model",
"ConvNextV2PreTrainedModel",
"ConvNextV2Backbone",
]
if TYPE_CHECKING:
from .configuration_convnextv2 import (
CONVNEXTV2_PRETRAINED_CONFIG_ARCHIVE_MAP,
ConvNextV2Config,
)
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_convnextv2 import (
CONVNEXTV2_PRETRAINED_MODEL_ARCHIVE_LIST,
ConvNextV2Backbone,
ConvNextV2ForImageClassification,
ConvNextV2Model,
ConvNextV2PreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
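The new `__init__.py` follows the library's lazy-module pattern: the public names are declared in `_import_structure`, the PyTorch classes are only registered when `torch` is importable, and the actual submodules are loaded on first attribute access. A minimal sketch of what this enables, assuming a working PyTorch install:

```python
from transformers import ConvNextV2Config, ConvNextV2Model

# both names resolve lazily through the import structure entries added in this PR;
# the configuration and modeling modules are only imported on first access
config = ConvNextV2Config()
model = ConvNextV2Model(config)
print(type(model).__name__)  # ConvNextV2Model
```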
# coding=utf-8
# Copyright 2023 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ConvNeXTV2 model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
CONVNEXTV2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/convnextv2-tiny-1k-224": "https://huggingface.co/facebook/convnextv2-tiny-1k-224/resolve/main/config.json",
}
class ConvNextV2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`ConvNextV2Model`]. It is used to instantiate a
ConvNeXTV2 model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the ConvNeXTV2
[facebook/convnextv2-tiny-1k-224](https://huggingface.co/facebook/convnextv2-tiny-1k-224) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
patch_size (`int`, *optional*, defaults to 4):
Patch size to use in the patch embedding layer.
num_stages (`int`, *optional*, defaults to 4):
The number of stages in the model.
hidden_sizes (`List[int]`, *optional*, defaults to `[96, 192, 384, 768]`):
Dimensionality (hidden size) at each stage.
depths (`List[int]`, *optional*, defaults to `[3, 3, 9, 3]`):
Depth (number of blocks) for each stage.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in each block. If string, `"gelu"`, `"relu"`,
`"selu"` and `"gelu_new"` are supported.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
drop_path_rate (`float`, *optional*, defaults to 0.0):
The drop rate for stochastic depth.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). Will default to the last stage if unset.
Example:
```python
>>> from transformers import ConvNextV2Config, ConvNextV2Model
>>> # Initializing a ConvNeXTV2 convnextv2-tiny-1k-224 style configuration
>>> configuration = ConvNextV2Config()
>>> # Initializing a model (with random weights) from the convnextv2-tiny-1k-224 style configuration
>>> model = ConvNextV2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "convnextv2"
def __init__(
self,
num_channels=3,
patch_size=4,
num_stages=4,
hidden_sizes=None,
depths=None,
hidden_act="gelu",
initializer_range=0.02,
layer_norm_eps=1e-12,
drop_path_rate=0.0,
image_size=224,
out_features=None,
**kwargs,
):
super().__init__(**kwargs)
self.num_channels = num_channels
self.patch_size = patch_size
self.num_stages = num_stages
self.hidden_sizes = [96, 192, 384, 768] if hidden_sizes is None else hidden_sizes
self.depths = [3, 3, 9, 3] if depths is None else depths
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.drop_path_rate = drop_path_rate
self.image_size = image_size
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
if out_features is not None:
if not isinstance(out_features, list):
raise ValueError("out_features should be a list")
for feature in out_features:
if feature not in self.stage_names:
raise ValueError(
f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
)
self.out_features = out_features
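Because `out_features` is validated against the generated `stage_names`, a backbone-style configuration can be sketched as follows (the chosen stages are only illustrative):

```python
from transformers import ConvNextV2Config

# stage_names is built as ["stem", "stage1", ...], so with the default depths of
# [3, 3, 9, 3] the valid feature names are "stem" and "stage1" through "stage4"
config = ConvNextV2Config(out_features=["stage2", "stage3", "stage4"])
print(config.stage_names)   # ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
print(config.out_features)  # ['stage2', 'stage3', 'stage4']
```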
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ConvNeXTV2 checkpoints from the original repository.
URL: https://github.com/facebookresearch/ConvNeXt"""
import argparse
import json
import os
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import ConvNextImageProcessor, ConvNextV2Config, ConvNextV2ForImageClassification
from transformers.image_utils import PILImageResampling
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_convnextv2_config(checkpoint_url):
config = ConvNextV2Config()
if "atto" in checkpoint_url:
depths = [2, 2, 6, 2]
hidden_sizes = [40, 80, 160, 320]
if "femto" in checkpoint_url:
depths = [2, 2, 6, 2]
hidden_sizes = [48, 96, 192, 384]
if "pico" in checkpoint_url:
depths = [2, 2, 6, 2]
hidden_sizes = [64, 128, 256, 512]
if "nano" in checkpoint_url:
depths = [2, 2, 8, 2]
hidden_sizes = [80, 160, 320, 640]
if "tiny" in checkpoint_url:
depths = [3, 3, 9, 3]
hidden_sizes = [96, 192, 384, 768]
if "base" in checkpoint_url:
depths = [3, 3, 27, 3]
hidden_sizes = [128, 256, 512, 1024]
if "large" in checkpoint_url:
depths = [3, 3, 27, 3]
hidden_sizes = [192, 384, 768, 1536]
if "huge" in checkpoint_url:
depths = [3, 3, 27, 3]
hidden_sizes = [352, 704, 1408, 2816]
num_labels = 1000
filename = "imagenet-1k-id2label.json"
expected_shape = (1, 1000)
repo_id = "huggingface/label-files"
config.num_labels = num_labels
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
config.hidden_sizes = hidden_sizes
config.depths = depths
return config, expected_shape
def rename_key(name):
if "downsample_layers.0.0" in name:
name = name.replace("downsample_layers.0.0", "embeddings.patch_embeddings")
if "downsample_layers.0.1" in name:
name = name.replace("downsample_layers.0.1", "embeddings.norm") # we rename to layernorm later on
if "downsample_layers.1.0" in name:
name = name.replace("downsample_layers.1.0", "stages.1.downsampling_layer.0")
if "downsample_layers.1.1" in name:
name = name.replace("downsample_layers.1.1", "stages.1.downsampling_layer.1")
if "downsample_layers.2.0" in name:
name = name.replace("downsample_layers.2.0", "stages.2.downsampling_layer.0")
if "downsample_layers.2.1" in name:
name = name.replace("downsample_layers.2.1", "stages.2.downsampling_layer.1")
if "downsample_layers.3.0" in name:
name = name.replace("downsample_layers.3.0", "stages.3.downsampling_layer.0")
if "downsample_layers.3.1" in name:
name = name.replace("downsample_layers.3.1", "stages.3.downsampling_layer.1")
if "stages" in name and "downsampling_layer" not in name:
# stages.0.0. for instance should be renamed to stages.0.layers.0.
name = name[: len("stages.0")] + ".layers" + name[len("stages.0") :]
if "stages" in name:
name = name.replace("stages", "encoder.stages")
if "norm" in name:
name = name.replace("norm", "layernorm")
if "head" in name:
name = name.replace("head", "classifier")
return name
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
def convert_preprocessor(checkpoint_url):
if "224" in checkpoint_url:
size = 224
crop_pct = 224 / 256
elif "384" in checkpoint_url:
size = 384
crop_pct = None
else:
size = 512
crop_pct = None
return ConvNextImageProcessor(
size=size,
crop_pct=crop_pct,
image_mean=[0.485, 0.456, 0.406],
image_std=[0.229, 0.224, 0.225],
resample=PILImageResampling.BICUBIC,
)
@torch.no_grad()
def convert_convnextv2_checkpoint(checkpoint_url, pytorch_dump_folder_path, save_model, push_to_hub):
"""
Copy/paste/tweak model's weights to our ConvNeXTV2 structure.
"""
print("Downloading original model from checkpoint...")
# define ConvNeXTV2 configuration based on URL
config, expected_shape = get_convnextv2_config(checkpoint_url)
# load original state_dict from URL
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)["model"]
print("Converting model parameters...")
# rename keys
for key in state_dict.copy().keys():
val = state_dict.pop(key)
state_dict[rename_key(key)] = val
# add prefix to all keys except the classifier head
for key in state_dict.copy().keys():
val = state_dict.pop(key)
if not key.startswith("classifier"):
key = "convnextv2." + key
state_dict[key] = val
# load HuggingFace model
model = ConvNextV2ForImageClassification(config)
model.load_state_dict(state_dict)
model.eval()
# Check outputs on an image, prepared by ConvNextImageProcessor
preprocessor = convert_preprocessor(checkpoint_url)
inputs = preprocessor(images=prepare_img(), return_tensors="pt")
logits = model(**inputs).logits
# note: the logits below were obtained without center cropping
if checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt":
expected_logits = torch.tensor([-0.3930, 0.1747, -0.5246, 0.4177, 0.4295])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt":
expected_logits = torch.tensor([-0.1727, -0.5341, -0.7818, -0.4745, -0.6566])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt":
expected_logits = torch.tensor([-0.0333, 0.1563, -0.9137, 0.1054, 0.0381])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt":
expected_logits = torch.tensor([-0.1744, -0.1555, -0.0713, 0.0950, -0.1431])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt":
expected_logits = torch.tensor([0.9996, 0.1966, -0.4386, -0.3472, 0.6661])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt":
expected_logits = torch.tensor([-0.2553, -0.6708, -0.1359, 0.2518, -0.2488])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt":
expected_logits = torch.tensor([-0.0673, -0.5627, -0.3753, -0.2722, 0.0178])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt":
expected_logits = torch.tensor([-0.6377, -0.7458, -0.2150, 0.1184, -0.0597])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt":
expected_logits = torch.tensor([1.0799, 0.2322, -0.8860, 1.0219, 0.6231])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt":
expected_logits = torch.tensor([0.3766, 0.4917, -1.1426, 0.9942, 0.6024])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt":
expected_logits = torch.tensor([0.4220, -0.6919, -0.4317, -0.2881, -0.6609])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt":
expected_logits = torch.tensor([0.1082, -0.8286, -0.5095, 0.4681, -0.8085])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt":
expected_logits = torch.tensor([-0.2419, -0.6221, 0.2176, -0.0980, -0.7527])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt":
expected_logits = torch.tensor([0.0391, -0.4371, 0.3786, 0.1251, -0.2784])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt":
expected_logits = torch.tensor([-0.0504, 0.5636, -0.1729, -0.6507, -0.3949])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt":
expected_logits = torch.tensor([0.3560, 0.9486, 0.3149, -0.2667, -0.5138])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt":
expected_logits = torch.tensor([-0.2469, -0.4550, -0.5853, -0.0810, 0.0309])
elif checkpoint_url == "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt":
expected_logits = torch.tensor([-0.3090, 0.0802, -0.0682, -0.1979, -0.2826])
else:
raise ValueError(f"Unknown URL: {checkpoint_url}")
assert torch.allclose(logits[0, :5], expected_logits, atol=1e-3)
assert logits.shape == expected_shape
print("Model outputs match the original results!")
if save_model:
print("Saving model to local...")
# Create folder to save model
if not os.path.isdir(pytorch_dump_folder_path):
os.mkdir(pytorch_dump_folder_path)
model.save_pretrained(pytorch_dump_folder_path)
preprocessor.save_pretrained(pytorch_dump_folder_path)
model_name = "convnextv2"
if "atto" in checkpoint_url:
model_name += "-atto"
if "femto" in checkpoint_url:
model_name += "-femto"
if "pico" in checkpoint_url:
model_name += "-pico"
if "nano" in checkpoint_url:
model_name += "-nano"
elif "tiny" in checkpoint_url:
model_name += "-tiny"
elif "base" in checkpoint_url:
model_name += "-base"
elif "large" in checkpoint_url:
model_name += "-large"
elif "huge" in checkpoint_url:
model_name += "-huge"
if "22k" in checkpoint_url and "1k" not in checkpoint_url:
model_name += "-22k"
elif "22k" in checkpoint_url and "1k" in checkpoint_url:
model_name += "-22k-1k"
elif "1k" in checkpoint_url:
model_name += "-1k"
if "224" in checkpoint_url:
model_name += "-224"
elif "384" in checkpoint_url:
model_name += "-384"
elif "512" in checkpoint_url:
model_name += "-512"
if push_to_hub:
print(f"Pushing {model_name} to the hub...")
model.push_to_hub(model_name)
preprocessor.push_to_hub(model_name)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--checkpoint_url",
default="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt",
type=str,
help="URL of the original ConvNeXTV2 checkpoint you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default="model",
type=str,
help="Path to the output PyTorch model directory.",
)
parser.add_argument("--save_model", action="store_true", help="Save model to local")
parser.add_argument("--push_to_hub", action="store_true", help="Push model and image preprocessor to the hub")
args = parser.parse_args()
convert_convnextv2_checkpoint(
args.checkpoint_url, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub
)
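The converter is meant to be run as a script (here assumed to live at `src/transformers/models/convnextv2/convert_convnextv2_to_pytorch.py`). A typical invocation, using the default atto checkpoint URL defined above and an illustrative output directory, would be `python convert_convnextv2_to_pytorch.py --checkpoint_url https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt --pytorch_dump_folder_path ./convnextv2-atto-1k-224 --save_model`, with `--push_to_hub` added to upload the converted model and preprocessor.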
This diff is collapsed.
...@@ -1803,6 +1803,37 @@ class ConvNextPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
CONVNEXTV2_PRETRAINED_MODEL_ARCHIVE_LIST = None
class ConvNextV2Backbone(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ConvNextV2ForImageClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ConvNextV2Model(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ConvNextV2PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
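These dummy objects keep `ConvNextV2Model` and the related classes importable in an environment without PyTorch while failing loudly on use. A rough sketch of that behavior, assuming `torch` is not installed:

```python
from transformers import ConvNextV2Model
from transformers.utils import is_torch_available

if not is_torch_available():
    try:
        # the dummy __init__ calls requires_backends(self, ["torch"]), which raises
        # an ImportError with an installation hint when the torch backend is missing
        ConvNextV2Model()
    except ImportError as err:
        print(err)
```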
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch ConvNextV2 model. """
import inspect
import unittest
from transformers import ConvNextV2Config
from transformers.models.auto import get_values
from transformers.models.auto.modeling_auto import MODEL_FOR_BACKBONE_MAPPING_NAMES, MODEL_MAPPING_NAMES
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
from transformers import ConvNextV2Backbone, ConvNextV2ForImageClassification, ConvNextV2Model
from transformers.models.convnextv2.modeling_convnextv2 import CONVNEXTV2_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
from PIL import Image
from transformers import AutoImageProcessor
class ConvNextV2ModelTester:
def __init__(
self,
parent,
batch_size=13,
image_size=32,
num_channels=3,
num_stages=4,
hidden_sizes=[10, 20, 30, 40],
depths=[2, 2, 3, 2],
is_training=True,
use_labels=True,
intermediate_size=37,
hidden_act="gelu",
num_labels=10,
initializer_range=0.02,
out_features=["stage2", "stage3", "stage4"],
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.num_channels = num_channels
self.num_stages = num_stages
self.hidden_sizes = hidden_sizes
self.depths = depths
self.is_training = is_training
self.use_labels = use_labels
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.num_labels = num_labels
self.initializer_range = initializer_range
self.out_features = out_features
self.scope = scope
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.num_labels)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return ConvNextV2Config(
num_channels=self.num_channels,
hidden_sizes=self.hidden_sizes,
depths=self.depths,
num_stages=self.num_stages,
hidden_act=self.hidden_act,
is_decoder=False,
initializer_range=self.initializer_range,
out_features=self.out_features,
num_labels=self.num_labels,
)
def create_and_check_model(self, config, pixel_values, labels):
model = ConvNextV2Model(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# expected last hidden states: B, C, H // 32, W // 32
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32),
)
def create_and_check_for_image_classification(self, config, pixel_values, labels):
model = ConvNextV2ForImageClassification(config)
model.to(torch_device)
model.eval()
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_backbone(self, config, pixel_values, labels):
model = ConvNextV2Backbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify hidden states
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
# verify channels
self.parent.assertEqual(len(model.channels), len(config.out_features))
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
# verify backbone works with out_features=None
config.out_features = None
model = ConvNextV2Backbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1])
# verify channels
self.parent.assertEqual(len(model.channels), 1)
self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
def prepare_config_and_inputs_with_labels(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values, "labels": labels}
return config, inputs_dict
@require_torch
class ConvNextV2ModelTest(ModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as ConvNextV2 does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (
(
ConvNextV2Model,
ConvNextV2ForImageClassification,
ConvNextV2Backbone,
)
if is_torch_available()
else ()
)
fx_compatible = False
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
def setUp(self):
self.model_tester = ConvNextV2ModelTester(self)
self.config_tester = ConfigTester(self, config_class=ConvNextV2Config, has_text_modality=False, hidden_size=37)
def test_config(self):
self.create_and_test_config_common_properties()
self.config_tester.create_and_test_config_to_json_string()
self.config_tester.create_and_test_config_to_json_file()
self.config_tester.create_and_test_config_from_and_save_pretrained()
self.config_tester.create_and_test_config_with_num_labels()
self.config_tester.check_config_can_be_init_without_params()
self.config_tester.check_config_arguments_init()
def create_and_test_config_common_properties(self):
return
@unittest.skip(reason="ConvNextV2 does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="ConvNextV2 does not support input and output embeddings")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="ConvNextV2 does not use feedforward chunking")
def test_feed_forward_chunking(self):
pass
def test_training(self):
if not self.model_tester.is_training:
return
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_with_labels()
config.return_dict = True
if model_class.__name__ in [
*get_values(MODEL_MAPPING_NAMES),
*get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES),
]:
continue
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_training_gradient_checkpointing(self):
if not self.model_tester.is_training:
return
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_with_labels()
config.use_cache = False
config.return_dict = True
if (
model_class.__name__
in [*get_values(MODEL_MAPPING_NAMES), *get_values(MODEL_FOR_BACKBONE_MAPPING_NAMES)]
or not model_class.supports_gradient_checkpointing
):
continue
model = model_class(config)
model.to(torch_device)
model.gradient_checkpointing_enable()
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_stages = self.model_tester.num_stages
self.assertEqual(len(hidden_states), expected_num_stages + 1)
# ConvNextV2's feature maps are of shape (batch_size, num_channels, height, width)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[self.model_tester.image_size // 4, self.model_tester.image_size // 4],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in CONVNEXTV2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = ConvNextV2Model.from_pretrained(model_name)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_torch
@require_vision
class ConvNextV2ModelIntegrationTest(unittest.TestCase):
@cached_property
def default_image_processor(self):
return AutoImageProcessor.from_pretrained("facebook/convnextv2-tiny-1k-224") if is_vision_available() else None
@slow
def test_inference_image_classification_head(self):
model = ConvNextV2ForImageClassification.from_pretrained("facebook/convnextv2-tiny-1k-224").to(torch_device)
preprocessor = self.default_image_processor
image = prepare_img()
inputs = preprocessor(images=image, return_tensors="pt").to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs)
# verify the logits
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([-0.3083, -0.3040, -0.4344]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
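As a usage note, the `@slow`-marked tests above (loading a pretrained checkpoint and checking the logits) only run when the `RUN_SLOW=1` environment variable is set, e.g. `RUN_SLOW=1 python -m pytest tests/models/convnextv2/test_modeling_convnextv2.py` (the path is assumed from the repository's usual test layout).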
...@@ -820,6 +820,7 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
"AutoBackbone",
"BitBackbone",
"ConvNextBackbone",
"ConvNextV2Backbone",
"DinatBackbone", "DinatBackbone",
"MaskFormerSwinBackbone", "MaskFormerSwinBackbone",
"MaskFormerSwinConfig", "MaskFormerSwinConfig",
......