# vit.py — Vision Transformer (ViT) models and building blocks.
# (Header residue from a repository blame view; commit author: アマデウス.)
import math
from typing import Callable

import torch
from colossalai import nn as col_nn
6
from colossalai.nn.layer.utils import CheckpointModule
# (Blame-view residue — commit by アマデウス, original lines 7-73 — collapsed.)
from colossalai.registry import LAYERS, MODELS
from torch import dtype, nn

# Names exported by ``from ... import *``: the model class first, then one
# registered preset factory per configuration, grouped by model size.
__all__ = [
    'VisionTransformer',
    # 32x32-input presets
    'vit_lite_depth7_patch4_32',
    'vit_tiny_patch4_32',
    # tiny
    'vit_tiny_patch16_224', 'vit_tiny_patch16_384',
    # small
    'vit_small_patch16_224', 'vit_small_patch16_384',
    'vit_small_patch32_224', 'vit_small_patch32_384',
    # base
    'vit_base_patch16_224', 'vit_base_patch16_384',
    'vit_base_patch32_224', 'vit_base_patch32_384',
    # large
    'vit_large_patch16_224', 'vit_large_patch16_384',
    'vit_large_patch32_224', 'vit_large_patch32_384',
]

# Initializer presets, indexed as ``_init_rules[init_method][group]``.
# ``init_method`` is 'torch' or 'jax'; ``group`` is 'embed', 'transformer' or
# 'head'. Each inner mapping is splatted as keyword arguments into the
# matching colossalai layer constructor by the modules below.
_init_rules = {
    'torch': {
        'embed': {
            'weight_initializer': col_nn.init.kaiming_uniform_(a=math.sqrt(5)),
            'bias_initializer': col_nn.init.xavier_uniform_(a=1, scale=1),
            'position_embed_initializer': col_nn.init.zeros_(),
        },
        'transformer': {
            'weight_initializer': col_nn.init.kaiming_uniform_(a=math.sqrt(5)),
            'bias_initializer': col_nn.init.xavier_uniform_(a=1, scale=1),
        },
        'head': {
            'weight_initializer': col_nn.init.kaiming_uniform_(a=math.sqrt(5)),
            'bias_initializer': col_nn.init.xavier_uniform_(a=1, scale=1),
        },
    },
    'jax': {
        'embed': {
            'weight_initializer': col_nn.init.lecun_normal_(),
            'bias_initializer': col_nn.init.zeros_(),
            'position_embed_initializer': col_nn.init.trunc_normal_(std=.02),
        },
        'transformer': {
            'weight_initializer': col_nn.init.xavier_uniform_(),
            'bias_initializer': col_nn.init.normal_(std=1e-6),
        },
        # The jax preset zero-initializes every head parameter.
        'head': {
            'weight_initializer': col_nn.init.zeros_(),
            'bias_initializer': col_nn.init.zeros_(),
        },
    },
}


@LAYERS.register_module
class ViTEmbedding(nn.Module):
    """Patch-embedding stem of the Vision Transformer, followed by dropout.

    Args:
        img_size: Input image side length, forwarded to ``col_nn.PatchEmbedding``.
        patch_size: Patch side length, forwarded to ``col_nn.PatchEmbedding``.
        in_chans: Number of input channels.
        embedding_dim: Width of each patch embedding.
        dropout: Dropout probability applied to the embedding output.
        dtype: Optional parameter dtype for the patch embedding.
        flatten: Forwarded unchanged to ``col_nn.PatchEmbedding``.
        init_method: Key into ``_init_rules`` (``'torch'`` or ``'jax'``)
            selecting the weight / bias / position-embedding initializers.
    """

    def __init__(self,
                 img_size: int,
                 patch_size: int,
                 in_chans: int,
                 embedding_dim: int,
                 dropout: float,
                 dtype: dtype = None,
                 flatten: bool = True,
                 init_method: str = 'torch'):
        super().__init__()
        self.patch_embed = col_nn.PatchEmbedding(img_size,
                                                 patch_size,
                                                 in_chans,
                                                 embedding_dim,
                                                 dtype=dtype,
                                                 flatten=flatten,
                                                 **_init_rules[init_method]['embed'])
        self.dropout = col_nn.Dropout(dropout)

    def forward(self, x):
        """Embed the image batch ``x`` into patch tokens and apply dropout."""
        # NOTE(review): whether PatchEmbedding also prepends a class token and
        # adds position embeddings is defined in colossalai, not here; the
        # 'position_embed_initializer' entry in _init_rules suggests it does.
        x = self.patch_embed(x)
        x = self.dropout(x)
        return x


@LAYERS.register_module
class ViTSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection.

    Args:
        dim: Model width; the fused projection maps ``dim -> 3 * dim``.
        num_heads: Number of attention heads; each head gets
            ``dim // num_heads`` channels.
        attention_dropout: Dropout probability on the attention weights.
        dropout: Dropout probability on the output projection.
        bias: Whether the QKV projection and the output projection carry a
            bias term.
        dtype: Optional parameter dtype for the projections.
        init_method: Key into ``_init_rules`` selecting initializers.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 attention_dropout: float,
                 dropout: float,
                 bias: bool = True,
                 dtype: dtype = None,
                 init_method: str = 'torch'):
        super().__init__()
        self.attention_head_size = dim // num_heads
        self.query_key_value = col_nn.Linear(dim,
                                             3 * dim,
                                             dtype=dtype,
                                             bias=bias,
                                             **_init_rules[init_method]['transformer'])
        self.attention_dropout = col_nn.Dropout(attention_dropout)
        # Honour ``bias`` here as well: this projection used to hard-code
        # ``bias=True`` and silently ignored ``bias=False``.
        self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=bias, **_init_rules[init_method]['transformer'])
        self.dropout = col_nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply self-attention to ``x``.

        ``x`` must be rank 3 (the 4-D ``permute`` below requires it);
        presumably laid out as (batch, seq_len, dim) — TODO confirm.
        """
        qkv = self.query_key_value(x)
        # Head count is re-derived from the projection's actual output width
        # rather than stored; presumably so it tracks the local width if
        # col_nn.Linear partitions its output — unconfirmed from here.
        all_head_size = qkv.shape[-1] // 3
        num_attention_heads = all_head_size // self.attention_head_size
        # Split into heads; each head's slice holds its q, k and v back to back.
        new_qkv_shape = qkv.shape[:-1] + (num_attention_heads, 3 * self.attention_head_size)
        qkv = qkv.view(new_qkv_shape)
        qkv = qkv.permute((0, 2, 1, 3))
        q, k, v = torch.chunk(qkv, 3, dim=-1)

        # Scaled dot-product attention weights.
        x = torch.matmul(q, k.transpose(-1, -2))
        x = x / math.sqrt(self.attention_head_size)
        x = self.softmax(x)
        x = self.attention_dropout(x)

        # Weighted sum of values, then merge the heads back into one axis.
        x = torch.matmul(x, v)
        x = x.transpose(1, 2)
        new_context_layer_shape = x.size()[:-2] + (all_head_size, )
        x = x.reshape(new_context_layer_shape)

        x = self.dense(x)
        x = self.dropout(x)
        return x


@LAYERS.register_module
class ViTMLP(nn.Module):
    """Two-layer feed-forward network: expand, activate, dropout, project back.

    Args:
        dim: Input and output width.
        mlp_ratio: Hidden width is ``mlp_ratio * dim``.
        activation: Callable applied after the first projection.
        dropout: Dropout probability used after each projection.
        dtype: Optional parameter dtype for both projections.
        bias: Whether both projections carry a bias term.
        init_method: Key into ``_init_rules`` selecting initializers.
    """

    def __init__(self,
                 dim: int,
                 mlp_ratio: int,
                 activation: Callable,
                 dropout: float,
                 dtype: dtype = None,
                 bias: bool = True,
                 init_method: str = 'torch'):
        super().__init__()
        self.dense_1 = col_nn.Linear(dim,
                                     mlp_ratio * dim,
                                     dtype=dtype,
                                     bias=bias,
                                     **_init_rules[init_method]['transformer'])
        self.activation = activation
        self.dropout_1 = col_nn.Dropout(dropout)
        self.dense_2 = col_nn.Linear(mlp_ratio * dim,
                                     dim,
                                     dtype=dtype,
                                     bias=bias,
                                     **_init_rules[init_method]['transformer'])
        self.dropout_2 = col_nn.Dropout(dropout)

    def forward(self, x):
        """Run ``x`` through expand -> activation -> dropout -> project -> dropout."""
        x = self.dense_1(x)
        x = self.activation(x)
        x = self.dropout_1(x)
        x = self.dense_2(x)
        x = self.dropout_2(x)
        return x


@LAYERS.register_module
class ViTHead(nn.Module):
    """Classification head applied to the first token of the sequence.

    Args:
        dim: Width of the incoming tokens.
        num_classes: Number of classifier outputs.
        representation_size: If truthy, an extra ``col_nn.Linear`` of this
            width is applied before the classifier; otherwise the classifier
            reads the ``dim``-wide token directly.
        dtype: Optional parameter dtype.
        bias: Whether the linear layers carry a bias term.
        init_method: Key into ``_init_rules`` selecting the 'head'
            initializers (used for both the representation layer and the
            classifier).
    """

    def __init__(self,
                 dim: int,
                 num_classes: int,
                 representation_size: int = None,
                 dtype: dtype = None,
                 bias: bool = True,
                 init_method: str = 'torch'):
        super().__init__()
        if representation_size:
            self.representation = col_nn.Linear(dim,
                                                representation_size,
                                                bias=bias,
                                                dtype=dtype,
                                                **_init_rules[init_method]['head'])
        else:
            self.representation = None
            # Without a representation layer the classifier consumes `dim` features.
            representation_size = dim

        self.dense = col_nn.Classifier(representation_size,
                                       num_classes,
                                       dtype=dtype,
                                       bias=bias,
                                       **_init_rules[init_method]['head'])

    def forward(self, x):
        """Classify using token 0 of ``x`` (presumably the class token — confirm)."""
        x = x[:, 0]
        if self.representation is not None:
            x = self.representation(x)
        x = self.dense(x)
        return x


@LAYERS.register_module
class ViTBlock(CheckpointModule):
    """Pre-norm Transformer encoder block.

    Computes ``x + drop_path(attn(norm1(x)))`` followed by
    ``x + drop_path(mlp(norm2(x)))``.

    Args:
        dim: Model width.
        num_heads: Attention heads.
        mlp_ratio: MLP hidden width is ``mlp_ratio * dim``.
        activation: MLP activation function.
        attention_dropout: Dropout on attention weights.
        dropout: Dropout on the attention output and inside the MLP.
        drop_path: Stochastic-depth rate; ``0.`` replaces DropPath with
            ``nn.Identity``.
        layernorm_epsilon: ``eps`` of both LayerNorms.
        dtype: Optional parameter dtype.
        bias: Whether the linear layers carry a bias term.
        checkpoint: Passed to ``CheckpointModule.__init__`` (presumably enables
            activation checkpointing — see CheckpointModule).
        init_method: Key into ``_init_rules`` selecting initializers.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: int,
                 activation: Callable,
                 attention_dropout: float = 0.,
                 dropout: float = 0.,
                 drop_path: float = 0.,
                 layernorm_epsilon: float = 1e-6,
                 dtype: dtype = None,
                 bias: bool = True,
                 checkpoint: bool = False,
                 init_method: str = 'torch'):
        super().__init__(checkpoint)
        self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
        self.attn = ViTSelfAttention(dim=dim,
                                     num_heads=num_heads,
                                     attention_dropout=attention_dropout,
                                     dropout=dropout,
                                     bias=bias,
                                     dtype=dtype,
                                     init_method=init_method)
        # A zero rate skips DropPath entirely.
        self.drop_path = col_nn.DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
        self.mlp = ViTMLP(dim=dim,
                          mlp_ratio=mlp_ratio,
                          activation=activation,
                          dropout=dropout,
                          dtype=dtype,
                          bias=bias,
                          init_method=init_method)

    def _forward(self, x):
        """Block computation; presumably invoked by CheckpointModule's forward."""
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


@MODELS.register_module
class VisionTransformer(nn.Module):
    """Vision Transformer: patch embedding, ``depth`` encoder blocks, a final
    LayerNorm and a classification head, chained in one ``nn.Sequential``.

    Args:
        img_size: Input image side length, forwarded to ``ViTEmbedding``.
        patch_size: Patch side length, forwarded to ``ViTEmbedding``.
        in_chans: Number of input channels.
        num_classes: Output size of the classification head.
        depth: Number of ``ViTBlock`` layers.
        num_heads: Attention heads per block.
        dim: Model width.
        mlp_ratio: MLP hidden width is ``mlp_ratio * dim``.
        attention_dropout: Dropout on attention weights in every block.
        dropout: Dropout used in the embedding, the attention output and the MLP.
        drop_path: Maximum stochastic-depth rate; per-block rates are
            ``torch.linspace(0, drop_path, depth)``.
        layernorm_epsilon: ``eps`` for every LayerNorm, including the two
            inside each encoder block.
        activation: MLP activation function.
        representation_size: If truthy, width of an extra projection in the head.
        dtype: Optional parameter dtype.
        bias: Forwarded to the encoder blocks and the head.
        checkpoint: Forwarded to every ``ViTBlock``.
        init_method: ``'torch'`` or ``'jax'`` initializer preset.
    """

    def __init__(self,
                 img_size: int = 224,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 num_classes: int = 1000,
                 depth: int = 12,
                 num_heads: int = 12,
                 dim: int = 768,
                 mlp_ratio: int = 4,
                 attention_dropout: float = 0.,
                 dropout: float = 0.1,
                 drop_path: float = 0.,
                 layernorm_epsilon: float = 1e-6,
                 activation: Callable = nn.functional.gelu,
                 representation_size: int = None,
                 dtype: dtype = None,
                 bias: bool = True,
                 checkpoint: bool = False,
                 init_method: str = 'torch'):
        super().__init__()

        embed = ViTEmbedding(img_size=img_size,
                             patch_size=patch_size,
                             in_chans=in_chans,
                             embedding_dim=dim,
                             dropout=dropout,
                             dtype=dtype,
                             init_method=init_method)

        # stochastic depth decay rule: drop-path rate grows linearly with depth
        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
        blocks = [
            ViTBlock(
                dim=dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                attention_dropout=attention_dropout,
                dropout=dropout,
                drop_path=dpr[i],
                # Previously not forwarded, so a custom epsilon only reached
                # the final norm while block norms stayed at ViTBlock's default.
                layernorm_epsilon=layernorm_epsilon,
                activation=activation,
                dtype=dtype,
                bias=bias,
                checkpoint=checkpoint,
                init_method=init_method,
            ) for i in range(depth)
        ]

        norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)

        head = ViTHead(dim=dim,
                       num_classes=num_classes,
                       representation_size=representation_size,
                       dtype=dtype,
                       bias=bias,
                       init_method=init_method)

        self.layers = nn.Sequential(
            embed,
            *blocks,
            norm,
            head,
        )

    def forward(self, x):
        """Run the full embedding -> blocks -> norm -> head pipeline on ``x``."""
        x = self.layers(x)
        return x


def _create_vit_model(**model_kwargs):
    """Build a ``VisionTransformer`` from the given keyword arguments."""
    return VisionTransformer(**model_kwargs)


@MODELS.register_module
def vit_lite_depth7_patch4_32(**kwargs):
    """Lightweight 7-layer ViT preset for 32x32 inputs with 4x4 patches.

    ``num_classes`` defaults to 10 but may be overridden (e.g.
    ``num_classes=100``). Passing any other preset key (``img_size``,
    ``dim``, ...) still raises ``TypeError``. Remaining kwargs are forwarded to
    ``VisionTransformer``.
    """
    model_kwargs = dict(img_size=32, patch_size=4, dim=256, depth=7, num_heads=4, mlp_ratio=2, **kwargs)
    # num_classes used to sit inside the dict() call alongside **kwargs, so
    # supplying it raised "multiple values"; setdefault makes it a real default.
    model_kwargs.setdefault('num_classes', 10)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_tiny_patch4_32(**kwargs):
    """Tiny 6-layer ViT preset for 32x32 inputs with 4x4 patches.

    ``num_classes`` defaults to 10 but may be overridden (e.g.
    ``num_classes=100``). Passing any other preset key still raises
    ``TypeError``. Remaining kwargs are forwarded to ``VisionTransformer``.
    """
    model_kwargs = dict(img_size=32, patch_size=4, dim=512, depth=6, num_heads=8, mlp_ratio=1, **kwargs)
    # num_classes used to sit inside the dict() call alongside **kwargs, so
    # supplying it raised "multiple values"; setdefault makes it a real default.
    model_kwargs.setdefault('num_classes', 10)
    return _create_vit_model(**model_kwargs)


@MODELS.register_module
def vit_tiny_patch16_224(**kwargs):
    """Tiny ViT preset: 224px input, 16px patches, dim 192, 12 layers, 3 heads.

    Extra kwargs go to ``VisionTransformer``; repeating a preset key raises
    ``TypeError``.
    """
    preset = dict(img_size=224, patch_size=16, dim=192, depth=12, num_heads=3, mlp_ratio=4)
    return _create_vit_model(**preset, **kwargs)


@MODELS.register_module
def vit_tiny_patch16_384(**kwargs):
    """Tiny ViT preset: 384px input, 16px patches, dim 192, 12 layers, 3 heads.

    Extra kwargs go to ``VisionTransformer``; repeating a preset key raises
    ``TypeError``.
    """
    preset = dict(img_size=384, patch_size=16, dim=192, depth=12, num_heads=3, mlp_ratio=4)
    return _create_vit_model(**preset, **kwargs)


@MODELS.register_module
def vit_small_patch16_224(**kwargs):
    """Small ViT preset: 224px input, 16px patches, dim 384, 12 layers, 6 heads.

    Extra kwargs go to ``VisionTransformer``; repeating a preset key raises
    ``TypeError``.
    """
    preset = dict(img_size=224, patch_size=16, dim=384, depth=12, num_heads=6, mlp_ratio=4)
    return _create_vit_model(**preset, **kwargs)


@MODELS.register_module
def vit_small_patch16_384(**kwargs):
    """Small ViT preset: 384px input, 16px patches, dim 384, 12 layers, 6 heads.

    Extra kwargs go to ``VisionTransformer``; repeating a preset key raises
    ``TypeError``.
    """
    preset = dict(img_size=384, patch_size=16, dim=384, depth=12, num_heads=6, mlp_ratio=4)
    return _create_vit_model(**preset, **kwargs)


@MODELS.register_module
def vit_small_patch32_224(**kwargs):
    """Small ViT preset: 224px input, 32px patches, dim 384, 12 layers, 6 heads.

    Extra kwargs go to ``VisionTransformer``; repeating a preset key raises
    ``TypeError``.
    """
    preset = dict(img_size=224, patch_size=32, dim=384, depth=12, num_heads=6, mlp_ratio=4)
    return _create_vit_model(**preset, **kwargs)


@MODELS.register_module
def vit_small_patch32_384(**kwargs):
    """Small ViT preset: 384px input, 32px patches, dim 384, 12 layers, 6 heads.

    Extra kwargs go to ``VisionTransformer``; repeating a preset key raises
    ``TypeError``.
    """
    preset = dict(img_size=384, patch_size=32, dim=384, depth=12, num_heads=6, mlp_ratio=4)
    return _create_vit_model(**preset, **kwargs)


@MODELS.register_module
def vit_base_patch16_224(**kwargs):
    """Base ViT preset: 224px input, 16px patches, dim 768, 12 layers, 12 heads.

    Extra kwargs go to ``VisionTransformer``; repeating a preset key raises
    ``TypeError``.
    """
    preset = dict(img_size=224, patch_size=16, dim=768, depth=12, num_heads=12, mlp_ratio=4)
    return _create_vit_model(**preset, **kwargs)


@MODELS.register_module
def vit_base_patch16_384(**kwargs):
    """Base ViT preset: 384px input, 16px patches, dim 768, 12 layers, 12 heads.

    Extra kwargs go to ``VisionTransformer``; repeating a preset key raises
    ``TypeError``.
    """
    preset = dict(img_size=384, patch_size=16, dim=768, depth=12, num_heads=12, mlp_ratio=4)
    return _create_vit_model(**preset, **kwargs)


@MODELS.register_module
def vit_base_patch32_224(**kwargs):
    """Base ViT preset: 224px input, 32px patches, dim 768, 12 layers, 12 heads.

    Extra kwargs go to ``VisionTransformer``; repeating a preset key raises
    ``TypeError``.
    """
    preset = dict(img_size=224, patch_size=32, dim=768, depth=12, num_heads=12, mlp_ratio=4)
    return _create_vit_model(**preset, **kwargs)


@MODELS.register_module
def vit_base_patch32_384(**kwargs):
    """Base ViT preset: 384px input, 32px patches, dim 768, 12 layers, 12 heads.

    Extra kwargs go to ``VisionTransformer``; repeating a preset key raises
    ``TypeError``.
    """
    preset = dict(img_size=384, patch_size=32, dim=768, depth=12, num_heads=12, mlp_ratio=4)
    return _create_vit_model(**preset, **kwargs)


@MODELS.register_module
def vit_large_patch16_224(**kwargs):
    """Large ViT preset: 224px input, 16px patches, dim 1024, 24 layers, 16 heads.

    Extra kwargs go to ``VisionTransformer``; repeating a preset key raises
    ``TypeError``.
    """
    preset = dict(img_size=224, patch_size=16, dim=1024, depth=24, num_heads=16, mlp_ratio=4)
    return _create_vit_model(**preset, **kwargs)


@MODELS.register_module
def vit_large_patch16_384(**kwargs):
    """Large ViT preset: 384px input, 16px patches, dim 1024, 24 layers, 16 heads.

    Extra kwargs go to ``VisionTransformer``; repeating a preset key raises
    ``TypeError``.
    """
    preset = dict(img_size=384, patch_size=16, dim=1024, depth=24, num_heads=16, mlp_ratio=4)
    return _create_vit_model(**preset, **kwargs)


@MODELS.register_module
def vit_large_patch32_224(**kwargs):
    """Large ViT preset: 224px input, 32px patches, dim 1024, 24 layers, 16 heads.

    Extra kwargs go to ``VisionTransformer``; repeating a preset key raises
    ``TypeError``.
    """
    preset = dict(img_size=224, patch_size=32, dim=1024, depth=24, num_heads=16, mlp_ratio=4)
    return _create_vit_model(**preset, **kwargs)


@MODELS.register_module
def vit_large_patch32_384(**kwargs):
    """Large ViT preset: 384px input, 32px patches, dim 1024, 24 layers, 16 heads.

    Extra kwargs go to ``VisionTransformer``; repeating a preset key raises
    ``TypeError``.
    """
    preset = dict(img_size=384, patch_size=32, dim=1024, depth=24, num_heads=16, mlp_ratio=4)
    return _create_vit_model(**preset, **kwargs)