vit_model.py 6.73 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
"""Vision Transformer(VIT) model."""
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122

import math
import einops
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import (
    get_linear_layer,
    init_method_normal,
    scaled_init_method_normal,
)
from .module import MegatronModule


class VitMlpHead(MegatronModule):
    """Classification head applied to a single pooled token.

    Selects the hidden state at one sequence position, passes it through a
    hidden-size linear layer followed by a tanh, and projects the result to
    ``num_classes`` outputs.

    Arguments:
        hidden_size: width of the incoming hidden states and of the
            intermediate linear layer.
        num_classes: number of output features of the final projection.
    """

    def __init__(self, hidden_size, num_classes):
        super().__init__()
        self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
        self.dense_out = torch.nn.Linear(hidden_size, num_classes)
        # The output bias starts at a constant -10 instead of the default
        # torch.nn.Linear initialisation.
        torch.nn.init.constant_(self.dense_out.bias, -10)

    def forward(self, hidden_states, sequence_index=0):
        """Pool one token and map it to ``num_classes`` outputs.

        Arguments:
            hidden_states: tensor laid out as [b, s, h].
            sequence_index: position along the sequence axis of the token
                to pool (defaults to the first one).

        Returns:
            The output of ``dense_out`` for the selected token.
        """
        pooled = hidden_states[:, sequence_index, :]
        return self.dense_out(torch.tanh(self.dense_in(pooled)))


def twod_interpolate_position_embeddings_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Resize checkpointed 2-D position embeddings to the current patch grid.

    Written with the ``load_state_dict`` pre-hook signature. The hook looks up
    ``prefix + "weight"`` in ``state_dict``. When its number of rows differs
    from the sequence length implied by ``args.img_dim`` and
    ``args.patch_dim`` (``(img_dim // patch_dim) ** 2 + 1``), the first row is
    kept unchanged, the remaining rows are viewed as a square grid, bilinearly
    interpolated to the new grid size, and the result is written back into
    ``state_dict`` in place. The stored tensor keeps the dtype it had in the
    checkpoint.

    Arguments:
        state_dict: mapping being loaded; modified in place when a resize
            is needed.
        prefix: key prefix of the module the hook is attached to.
        local_metadata, strict, missing_keys, unexpected_keys, error_msgs:
            part of the pre-hook signature; not used here.

    Raises:
        AssertionError: if the weight is missing, its hidden size does not
            match ``args.hidden_size``, or the checkpoint grid is not a
            non-empty square.
    """
    args = get_args()
    num_patches_per_dim = args.img_dim // args.patch_dim
    seq_length = num_patches_per_dim ** 2 + 1
    hidden_size = args.hidden_size

    key = prefix + "weight"
    assert key in state_dict, (
        "position embedding weight '{}' not found in state_dict".format(key)
    )
    if key not in state_dict:
        # Only reachable when assertions are disabled (python -O); leave the
        # state dict untouched in that case.
        return

    input_param = state_dict[key]
    assert input_param.shape[1] == hidden_size, (
        "position embedding hidden size {} does not match "
        "args.hidden_size {}".format(input_param.shape[1], hidden_size)
    )
    if input_param.shape[0] == seq_length:
        # Checkpoint already matches the current sequence length.
        return

    num_tok_input = input_param.shape[0] - 1
    gs_input = int(math.sqrt(num_tok_input))
    # seq_length - 1 == num_patches_per_dim ** 2, so the target grid side is
    # exactly the number of patches per dimension.
    gs_new = num_patches_per_dim
    assert gs_input > 0 and gs_input * gs_input == num_tok_input, (
        "cannot interpolate position embeddings: {} grid positions do not "
        "form a non-empty square grid".format(num_tok_input)
    )

    # Row 0 is carried over as is; only the remaining grid rows are resized.
    input_param_tok = input_param[:1, :]
    input_param_grid = input_param[1:, :]

    # [gs_input * gs_input, hidden] -> [1, hidden, gs_input, gs_input]
    input_param_grid = input_param_grid.transpose(0, 1).contiguous()
    input_param_grid = input_param_grid.reshape((1, -1, gs_input, gs_input))

    # Interpolate in fp32. The target size is passed explicitly so the output
    # grid does not depend on floating-point rounding of a scale factor.
    input_param_grid = F.interpolate(
        input_param_grid.float(),
        size=(gs_new, gs_new),
        mode="bilinear",
        align_corners=False,
    )

    # Cast back to the checkpoint dtype so it matches the untouched first row.
    input_param_grid = input_param_grid.to(input_param.dtype)

    # [1, hidden, gs_new, gs_new] -> [gs_new * gs_new, hidden]
    input_param_grid = input_param_grid.reshape((-1, gs_new * gs_new))
    input_param_grid = input_param_grid.transpose(0, 1).contiguous()
    assert input_param_grid.shape[1] == hidden_size

    input_param = torch.cat((input_param_tok, input_param_grid), dim=0)
    assert (
        input_param.shape[0] == seq_length
        and input_param.shape[1] == hidden_size
    )

    state_dict[key] = input_param


class VitModel(MegatronModule):
    """Vision Transformer Model.

    Splits an image into square patches, linearly embeds each flattened
    patch, prepends a learned class token, adds learned position embeddings
    and runs the sequence through a ``ParallelTransformer``. The output at
    sequence position 0 (the class token) is fed to a classification head.

    Model dimensions and initialisation settings are read from ``get_args()``
    (``hidden_size``, ``patch_dim``, ``img_dim``, ``num_channels``,
    ``num_layers``, ``hidden_dropout``, ``init_method_std``,
    ``init_method_xavier_uniform``, ``fp16_lm_cross_entropy``).

    Arguments:
        num_classes: number of outputs produced by the classification head.
        finetune: when False the head is a ``VitMlpHead`` stored as
            ``mlp_head``; when True it is a single linear layer stored as
            ``class_head``, built by ``get_linear_layer`` with
            ``torch.nn.init.zeros_`` as its init method.
    """

    def __init__(self, num_classes, finetune=False):
        super(VitModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        if args.init_method_xavier_uniform:
            self.init_method = torch.nn.init.xavier_uniform_
            self.scaled_init_method = torch.nn.init.xavier_uniform_
        else:
            self.init_method = init_method_normal(args.init_method_std)
            self.scaled_init_method = scaled_init_method_normal(
                args.init_method_std, args.num_layers
            )

        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.patch_dim = args.patch_dim
        self.img_dim = args.img_dim
        self.finetune = finetune

        # The image must tile exactly into patch_dim x patch_dim patches.
        assert self.img_dim % self.patch_dim == 0
        self.num_patches_per_dim = self.img_dim // self.patch_dim
        self.num_patches = self.num_patches_per_dim ** 2
        # One extra sequence position for the class token.
        self.seq_length = self.num_patches + 1
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels

        # cls_token: learned parameter that starts at zero.
        self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
        torch.nn.init.zeros_(self.cls_token)

        # Linear encoder: flattened patch -> hidden_size.
        self.linear_encoder = torch.nn.Linear(
            self.flatten_dim, self.hidden_size
        )

        # Learned position embeddings, one row per sequence position. They
        # are always normal-initialised, independent of the xavier switch.
        self.position_embeddings = torch.nn.Embedding(
            self.seq_length, self.hidden_size
        )
        init_method_normal(args.init_method_std)(
            self.position_embeddings.weight
        )
        # Shape [1, seq_length]; kept as a plain attribute (not a parameter
        # or buffer) and created on the current CUDA device.
        self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()

        # Lets checkpoints with a different number of positions be loaded:
        # the hook interpolates the stored embeddings before they are copied.
        self.position_embeddings._register_load_state_dict_pre_hook(
            twod_interpolate_position_embeddings_hook
        )

        self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)

        # Transformer
        self.transformer = ParallelTransformer(
            self.init_method, self.scaled_init_method
        )

        # Classification head
        if not self.finetune:
            self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
        else:
            self.class_head = get_linear_layer(
                self.hidden_size, num_classes, torch.nn.init.zeros_
            )

    def forward(self, x):
        """Classify a batch of images.

        Arguments:
            x: image batch laid out as [b, c, H, W], where H and W are
                multiples of ``patch_dim``. The dtype must be ``torch.half``
                (asserted).

        Returns:
            The output of the classification head for the class-token
            position.
        """
        # [b, c, H, W] -> [b, num_patches, patch_dim * patch_dim * c]
        x = einops.rearrange(
            x,
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=self.patch_dim,
            p2=self.patch_dim,
        )

        assert x.dtype == torch.half
        x = self.linear_encoder(x)
        # Prepend the class token to every sequence in the batch.
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.position_embeddings(self.position_ids)
        x = self.embedding_dropout(x)
        # NOTE(review): second argument is presumably the attention mask
        # (none is used here) - confirm against ParallelTransformer.forward.
        x = self.transformer(x, None)

        if not self.finetune:
            # VitMlpHead pools sequence position 0 by default.
            x = self.mlp_head(x)
        else:
            x = self.class_head(x[:, 0, :])

        return x