# coding=utf-8
# Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ResNet model configuration"""

from collections import OrderedDict
from typing import Mapping

from packaging import version

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices


# Module-level logger obtained from the project's logging utilities.
logger = logging.get_logger(__name__)

# Maps a pretrained checkpoint identifier to the URL of its `config.json`.
RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "microsoft/resnet-50": "https://huggingface.co/microsoft/resnet-50/blob/main/config.json",
}


class ResNetConfig(BackboneConfigMixin, PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ResNetModel`]. It is used to instantiate a
    ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the ResNet
    [microsoft/resnet-50](https://huggingface.co/microsoft/resnet-50) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        embedding_size (`int`, *optional*, defaults to 64):
            Dimensionality (hidden size) for the embedding layer.
        hidden_sizes (`List[int]`, *optional*, defaults to `[256, 512, 1024, 2048]`):
            Dimensionality (hidden size) at each stage. The sequence is copied into a new list on assignment.
        depths (`List[int]`, *optional*, defaults to `[3, 4, 6, 3]`):
            Depth (number of layers) for each stage. The sequence is copied into a new list on assignment.
        layer_type (`str`, *optional*, defaults to `"bottleneck"`):
            The layer to use, it can be either `"basic"` (used for smaller models, like resnet-18 or resnet-34) or
            `"bottleneck"` (used for larger models like resnet-50 and above).
        hidden_act (`str`, *optional*, defaults to `"relu"`):
            The non-linear activation function in each block. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"`
            are supported.
        downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
            If `True`, the first stage will downsample the inputs using a `stride` of 2.
        out_features (`List[str]`, *optional*):
            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
            corresponding stages. If unset and `out_indices` is unset, will default to the last stage.
        out_indices (`List[int]`, *optional*):
            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
            If unset and `out_features` is unset, will default to the last stage.

    Raises:
        ValueError: If `layer_type` is not one of the entries in `layer_types`.

    Example:
    ```python
    >>> from transformers import ResNetConfig, ResNetModel

    >>> # Initializing a ResNet resnet-50 style configuration
    >>> configuration = ResNetConfig()

    >>> # Initializing a model (with random weights) from the resnet-50 style configuration
    >>> model = ResNetModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "resnet"
    # Accepted values for the `layer_type` argument.
    layer_types = ["basic", "bottleneck"]

    def __init__(
        self,
        num_channels=3,
        embedding_size=64,
        hidden_sizes=[256, 512, 1024, 2048],
        depths=[3, 4, 6, 3],
        layer_type="bottleneck",
        hidden_act="relu",
        downsample_in_first_stage=False,
        out_features=None,
        out_indices=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if layer_type not in self.layer_types:
            raise ValueError(f"layer_type={layer_type} is not one of {','.join(self.layer_types)}")
        self.num_channels = num_channels
        self.embedding_size = embedding_size
        # Store copies: the signature defaults are list objects shared by every call, so keeping a
        # reference would let an in-place edit of one config's lists leak into later instances.
        self.hidden_sizes = list(hidden_sizes)
        self.depths = list(depths)
        self.layer_type = layer_type
        self.hidden_act = hidden_act
        self.downsample_in_first_stage = downsample_in_first_stage
        # One "stem" entry followed by "stage1" ... "stageN", where N is the number of entries in `depths`.
        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
        )


class ResNetOnnxConfig(OnnxConfig):
    """ONNX configuration for ResNet: declares the model input axes and a validation tolerance."""

    # Parsed version "1.11". NOTE(review): presumably the minimum torch version required for export;
    # how it is enforced lives in `OnnxConfig`, which is not shown here -- confirm.
    torch_onnx_minimum_version = version.parse("1.11")

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """
        Return an ordered mapping from each model input name to a mapping of axis index to axis name.

        The only declared input is `pixel_values`, whose four axes are named `batch`, `num_channels`, `height` and
        `width` (in that order).
        """
        return OrderedDict(
            [
                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
            ]
        )

    @property
    def atol_for_validation(self) -> float:
        """
        Return the absolute tolerance `1e-3`.

        NOTE(review): presumably used when comparing exported-model outputs against the reference model; the consumer
        is not visible in this file -- confirm.
        """
        return 1e-3