# mlp_mixer.py (9.46 KB) -- last committed by zhanggzh
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MLP Mixer models for KerasCV.

Reference:
  - [MLP-Mixer: An all-MLP Architecture for Vision](https://arxiv.org/abs/2105.01601)
"""

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend
from tensorflow.keras import layers

from keras_cv.models import utils


def MLPBlock(mlp_dim, name=None):
    """Builds a two-layer MLP: Dense -> GELU -> Dense.

    The second Dense layer projects back to the size of the input's last
    axis, so the block preserves the shape of the tensor it is applied to.

    Args:
      mlp_dim: integer, number of units in the first (hidden) Dense layer.
      name: string, prefix used for the names of the layers in this block.
        When `None`, a unique `mlp_block_<n>` prefix is generated.

    Returns:
      a function that applies the MLP block to an input Tensor.
    """
    prefix = (
        name if name is not None else f"mlp_block_{backend.get_uid('mlp_block')}"
    )

    def apply(x):
        # The output width is read from the input so the block can sit inside
        # a residual connection.
        hidden = layers.Dense(mlp_dim, name=f"{prefix}_dense_1")
        gelu = layers.Activation("gelu", name=f"{prefix}_gelu")
        project = layers.Dense(x.shape[-1], name=f"{prefix}_dense_2")
        return project(gelu(hidden(x)))

    return apply


def MixerBlock(tokens_mlp_dim, channels_mlp_dim, name=None):
    """A mixer block: token mixing then channel mixing, each with a residual
    connection.

    Args:
      tokens_mlp_dim: integer, number of units to be present in the MLP block
        dealing with tokens.
      channels_mlp_dim: integer, number of units to be present in the MLP block
        dealing with channels.
      name: string, block label. When `None`, a unique `mixer_block_<n>` label
        is generated.

    Returns:
      a function that takes an input Tensor representing a mixer block.
    """
    if name is None:
        # Use a dedicated counter so auto-generated mixer block names do not
        # share (and skip) indices with auto-generated `mlp_block_<n>` names.
        name = f"mixer_block_{backend.get_uid('mixer_block')}"

    def apply(x):
        # Token mixing. `Permute((2, 1))` swaps the two non-batch axes so the
        # MLP's Dense layers act along the token axis, then the axes are
        # swapped back before the residual add.
        y = layers.LayerNormalization()(x)
        y = layers.Permute((2, 1))(y)

        y = MLPBlock(tokens_mlp_dim, name=f"{name}_token_mixing")(y)
        y = layers.Permute((2, 1))(y)
        x = layers.Add()([x, y])

        # Channel mixing. The MLP acts on the last axis directly.
        y = layers.LayerNormalization()(x)
        y = MLPBlock(channels_mlp_dim, name=f"{name}_channel_mixing")(y)
        return layers.Add()([x, y])

    return apply


def MLPMixer(
    input_shape,
    patch_size,
    num_blocks,
    hidden_dim,
    tokens_mlp_dim,
    channels_mlp_dim,
    include_rescaling,
    include_top,
    classes=None,
    input_tensor=None,
    weights=None,
    pooling=None,
    classifier_activation="softmax",
    name=None,
    **kwargs,
):
    """Instantiates the MLP Mixer architecture.

    Reference:
    - [MLP-Mixer: An all-MLP Architecture for Vision (NeurIPS 2021)](https://arxiv.org/abs/2105.01601)

    This function returns a Keras MLP Mixer model.

    For transfer learning use cases, make sure to read the
    [guide to transfer learning & fine-tuning](https://keras.io/guides/transfer_learning/).

    Note that the `input_shape` should be fully divisible by the `patch_size`.

    Args:
      input_shape: tuple denoting the input shape, (224, 224, 3) for example.
        Height and width must be equal.
      patch_size: tuple denoting the size of the patches to be extracted
        from the inputs ((16, 16) for example). Both entries must be equal.
      num_blocks: number of mixer blocks.
      hidden_dim: dimension to which the patches will be linearly projected.
      tokens_mlp_dim: dimension of the MLP block responsible for tokens.
      channels_mlp_dim: dimension of the MLP block responsible for channels.
      include_rescaling: whether or not to Rescale the inputs.
        If set to True, inputs will be passed through a
        `Rescaling(1/255.0)` layer.
      include_top: whether to include the fully-connected
        layer at the top of the network.  If True, classes must be provided.
      classes: optional number of classes to classify images
        into, only to be specified if `include_top` is True.
      input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
        to use as image input for the model.
      weights: one of `None` (random initialization), or a pretrained
        weight file path.
      pooling: optional pooling mode for feature extraction
        when `include_top` is `False` (ignored when `include_top` is True).
        - `None` means that the output of the model will be
            the 3D tensor output of the final layer normalization
            (one feature vector per patch).
        - `avg` means that global average pooling
            will be applied over the patches, and thus
            the output of the model will be a 2D tensor.
        - `max` means that global max pooling will
            be applied.
      classifier_activation: A `str` or callable. The activation function to use
        on the "top" layer. Ignored unless `include_top=True`. Set
        `classifier_activation=None` to return the logits of the "top" layer.
        When loading pretrained weights, `classifier_activation` can only
        be `None` or `"softmax"`.
      name: (Optional) name to pass to the model.  Defaults to `None`.
      **kwargs: forwarded to the `keras.Model` constructor.

    Returns:
      A `keras.Model` instance.

    Raises:
      ValueError: if `weights` points to a missing file, if `include_top` is
        True without `classes`, if `input_shape` or `patch_size` is not a
        tuple of the expected length, if the resolution or the patch size is
        not square, or if the resolution is not divisible by the patch size.
    """
    if weights and not tf.io.gfile.exists(weights):
        raise ValueError(
            "The `weights` argument should be either "
            "`None` or the path to the weights file to be loaded. "
            f"Weights file not found at location: {weights}"
        )

    if include_top and not classes:
        raise ValueError(
            "If `include_top` is True, "
            "you should specify `classes`. "
            f"Received: classes={classes}"
        )

    # Both arguments must be tuples, so reject the call as soon as either one
    # is not (the previous `and` only fired when both were wrong).
    if not (isinstance(input_shape, tuple) and isinstance(patch_size, tuple)):
        raise ValueError("`input_shape` and `patch_size` both need to be tuple.")

    if len(input_shape) != 3:
        raise ValueError(
            "`input_shape` needs to contain dimensions for three"
            " axes: height, width, and channel ((224, 224, 3) for example)."
        )

    if len(patch_size) != 2:
        raise ValueError(
            "`patch_size` needs to contain dimensions for two"
            " spatial axes: height, and width ((16, 16) for example)."
        )

    if input_shape[0] != input_shape[1]:
        raise ValueError("Non-uniform resolutions are not supported.")

    if patch_size[0] != patch_size[1]:
        raise ValueError("Non-uniform patch sizes are not supported.")

    # Height/width and both patch sides are equal at this point, so checking
    # one axis is sufficient.
    if input_shape[0] % patch_size[0] != 0:
        raise ValueError("Input resolution should be divisible by the patch size.")

    inputs = utils.parse_model_inputs(input_shape, input_tensor)

    x = inputs
    if include_rescaling:
        x = layers.Rescaling(1 / 255.0)(x)

    # A convolution whose kernel and stride both equal the patch size cuts the
    # image into non-overlapping patches and linearly projects each of them to
    # `hidden_dim` channels in a single step.
    x = layers.Conv2D(
        filters=hidden_dim,
        kernel_size=patch_size,
        strides=patch_size,
        padding="valid",
        name="patchify_and_projection",
    )(x)
    # Flatten the spatial grid of patches into a single axis.
    x = layers.Reshape((x.shape[1] * x.shape[2], x.shape[3]))(x)

    for i in range(num_blocks):
        x = MixerBlock(tokens_mlp_dim, channels_mlp_dim, name=f"mixer_block_{i}")(x)

    x = layers.LayerNormalization()(x)

    if include_top:
        x = layers.GlobalAveragePooling1D(name="avg_pool")(x)
        x = layers.Dense(classes, activation=classifier_activation, name="predictions")(
            x
        )

    elif pooling == "avg":
        x = layers.GlobalAveragePooling1D(name="avg_pool")(x)
    elif pooling == "max":
        x = layers.GlobalMaxPooling1D(name="max_pool")(x)

    model = keras.Model(inputs, x, name=name, **kwargs)

    if weights is not None:
        model.load_weights(weights)
    return model


def MLPMixerB16(
    input_shape,
    patch_size,
    include_rescaling,
    include_top,
    classes=None,
    input_tensor=None,
    weights=None,
    pooling=None,
    name="mlp_mixer_b16",
    **kwargs,
):
    """Instantiates the MLPMixerB16 variant of `MLPMixer`.

    Fixes the architecture to 12 mixer blocks, a hidden dimension of 768, a
    token-mixing MLP dimension of 384 and a channel-mixing MLP dimension of
    3072. The patch size is not fixed here; it is taken from `patch_size`.

    All arguments, including any extra `**kwargs`, are forwarded to `MLPMixer`
    unchanged; see that function for their meaning.

    Returns:
      A `keras.Model` instance.
    """
    # Hyperparameters that define this variant.
    variant_config = {
        "num_blocks": 12,
        "hidden_dim": 768,
        "tokens_mlp_dim": 384,
        "channels_mlp_dim": 3072,
    }
    return MLPMixer(
        input_shape,
        patch_size,
        include_rescaling=include_rescaling,
        include_top=include_top,
        classes=classes,
        input_tensor=input_tensor,
        weights=weights,
        pooling=pooling,
        name=name,
        **variant_config,
        **kwargs,
    )


def MLPMixerB32(
    input_shape,
    patch_size,
    include_rescaling,
    include_top,
    classes=None,
    input_tensor=None,
    weights=None,
    pooling=None,
    name="mlp_mixer_b32",
    **kwargs,
):
    """Instantiates the MLPMixerB32 variant of `MLPMixer`.

    Fixes the architecture to 12 mixer blocks, a hidden dimension of 768, a
    token-mixing MLP dimension of 384 and a channel-mixing MLP dimension of
    3072. The patch size is not fixed here; it is taken from `patch_size`.

    All arguments, including any extra `**kwargs`, are forwarded to `MLPMixer`
    unchanged; see that function for their meaning.

    Returns:
      A `keras.Model` instance.
    """
    # Hyperparameters that define this variant.
    variant_config = {
        "num_blocks": 12,
        "hidden_dim": 768,
        "tokens_mlp_dim": 384,
        "channels_mlp_dim": 3072,
    }
    return MLPMixer(
        input_shape,
        patch_size,
        include_rescaling=include_rescaling,
        include_top=include_top,
        classes=classes,
        input_tensor=input_tensor,
        weights=weights,
        pooling=pooling,
        name=name,
        **variant_config,
        **kwargs,
    )


def MLPMixerL16(
    input_shape,
    patch_size,
    include_rescaling,
    include_top,
    classes=None,
    input_tensor=None,
    weights=None,
    pooling=None,
    name="mlp_mixer_l16",
    **kwargs,
):
    """Instantiates the MLPMixerL16 variant of `MLPMixer`.

    Fixes the architecture to 24 mixer blocks, a hidden dimension of 1024, a
    token-mixing MLP dimension of 512 and a channel-mixing MLP dimension of
    4096. The patch size is not fixed here; it is taken from `patch_size`.

    All arguments, including any extra `**kwargs`, are forwarded to `MLPMixer`
    unchanged; see that function for their meaning.

    Returns:
      A `keras.Model` instance.
    """
    # Hyperparameters that define this variant.
    variant_config = {
        "num_blocks": 24,
        "hidden_dim": 1024,
        "tokens_mlp_dim": 512,
        "channels_mlp_dim": 4096,
    }
    return MLPMixer(
        input_shape,
        patch_size,
        include_rescaling=include_rescaling,
        include_top=include_top,
        classes=classes,
        input_tensor=input_tensor,
        weights=weights,
        pooling=pooling,
        name=name,
        **variant_config,
        **kwargs,
    )