# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vision Transformer(VIT) model."""

import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import (
    get_linear_layer,
    init_method_normal,
    scaled_init_method_normal,
)
from megatron.model.module import MegatronModule

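# Number of learnable class tokens prepended to the patch token sequence.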
CLASS_TOKEN_LENGTH = 8

class VitMlpHead(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence), apply a linear transformation followed by a tanh, and project
    to the number of classes with a second linear layer.

    Arguments:
        hidden_size: hidden size
        num_classes: number of output classes produced by the final linear layer.
    """

    def __init__(self, hidden_size, num_classes):
        super(VitMlpHead, self).__init__()
        self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.dense_out = torch.nn.Linear(hidden_size, num_classes)
        torch.nn.init.constant_(self.dense_out.bias, -10)

    def forward(self, hidden_states):
        # hidden_states: [b, 1, h]
        dense_in_result = self.dense_in(hidden_states)
        tanh_result = torch.tanh(dense_in_result)
        dense_out_result = self.dense_out(tanh_result)
        return dense_out_result
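
# Illustrative usage sketch (shapes and sizes below are assumptions, not taken
# from this file):
#
#     head = VitMlpHead(hidden_size=768, num_classes=1000)
#     pooled = torch.randn(4, 1, 768)   # [b, 1, h] pooled class-token states
#     logits = head(pooled)             # -> [4, 1, 1000]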


def isPerfectSquare(x):
    if x >= 0:
        sr = math.sqrt(x)
        return int(sr) * int(sr) == x
    return False


def twod_interpolate_position_embeddings_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):

    args = get_args()
    num_patches_per_dim_h = args.img_h // args.patch_dim
    num_patches_per_dim_w = args.img_w // args.patch_dim
    num_patches = num_patches_per_dim_h * num_patches_per_dim_w
    hidden_size = args.hidden_size

    key = prefix + "weight"

    assert key in state_dict
    if key in state_dict:
        input_param = state_dict[key]

        input_seq_len = input_param.shape[0]
        assert(isPerfectSquare(input_seq_len) or isPerfectSquare(input_seq_len - CLASS_TOKEN_LENGTH))
        input_has_class_token = not isPerfectSquare(input_seq_len)
        num_tok_input = input_seq_len - CLASS_TOKEN_LENGTH if input_has_class_token else input_seq_len
        num_tok_output = num_patches
        output_has_class_token = args.class_token_present

        # update input_param and load it to state_dict[key]
        if input_has_class_token:
            input_param_tok = input_param[:CLASS_TOKEN_LENGTH, :]
            input_param_grid = input_param[CLASS_TOKEN_LENGTH:, :]
        else:
            input_param_tok = torch.zeros(CLASS_TOKEN_LENGTH, hidden_size)
            input_param_grid = input_param

        assert input_param.shape[1] == hidden_size

        if num_tok_input != num_tok_output:
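            # The checkpoint's patch grid differs from the current model's
            # (e.g. a different image or patch size), so bilinearly resize the
            # grid of position embeddings to the new shape.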

            gs_input = int(math.sqrt(num_tok_input))
            gs_new = (num_patches_per_dim_h, num_patches_per_dim_w)

            input_param_grid = input_param_grid.transpose(0, 1).contiguous()
            input_param_grid = input_param_grid.reshape(
                (1, -1, gs_input, gs_input)
            )
            input_param_grid = input_param_grid.float()
            scale_factor = (gs_new[0] / gs_input, gs_new[1] / gs_input)

            input_param_grid = F.interpolate(
                input_param_grid, scale_factor=scale_factor, mode="bilinear"
            )

            input_param_grid = input_param_grid.half()
            input_param_grid = input_param_grid.reshape((-1, num_tok_output))
            input_param_grid = input_param_grid.transpose(0, 1).contiguous()

            assert input_param_grid.shape[1] == hidden_size

        input_param = input_param_grid
        assert (
            input_param.shape[0] == num_tok_output
            and input_param.shape[1] == hidden_size
        )

        if output_has_class_token:
            input_param = torch.cat((input_param_tok, input_param), dim=0)

        state_dict[key] = input_param
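
# The hook above is registered on the position embedding table via
# _register_load_state_dict_pre_hook (see VitBackbone.__init__), so a
# checkpoint trained at one resolution can be loaded into a model built for
# another: the patch-grid portion is resized and the class-token rows are
# kept or added as needed.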


class VitBackbone(MegatronModule):
    """Vision Transformer Model."""

    def __init__(self,
                 pre_process=True,
                 post_process=True,
                 class_token=True,
                 single_token_output=False,
                 post_layer_norm=True,
                 drop_path_rate=0.0):
        super(VitBackbone, self).__init__(share_word_embeddings=False)
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        if args.init_method_xavier_uniform:
            self.init_method = torch.nn.init.xavier_uniform_
            self.scaled_init_method = torch.nn.init.xavier_uniform_
        else:
            self.init_method = init_method_normal(args.init_method_std)
            self.scaled_init_method = scaled_init_method_normal(
                args.init_method_std, args.num_layers
            )

        self.pre_process = pre_process
        self.post_process = post_process
        self.class_token = class_token
        self.post_layer_norm = post_layer_norm
        self.hidden_size = args.hidden_size
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.micro_batch_size = args.micro_batch_size
        self.single_token_output = single_token_output
        self.drop_path_rate = drop_path_rate

        assert self.img_h % self.patch_dim == 0
        assert self.img_w % self.patch_dim == 0
        self.num_patches_per_dim_h = self.img_h // self.patch_dim
        self.num_patches_per_dim_w = self.img_w // self.patch_dim
        self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w
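        # Transformer sequence length: patch tokens plus the prepended class tokens.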
        self.seq_length = self.num_patches + (CLASS_TOKEN_LENGTH if self.class_token else 0)
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
        self.input_tensor = None
        self.position_ids = None

        if self.pre_process:
            # Class tokens prepended to the patch token sequence.
            if self.class_token:
                self.cls_token = torch.nn.Parameter(
                    torch.randn(1, CLASS_TOKEN_LENGTH, self.hidden_size)
                )
                torch.nn.init.zeros_(self.cls_token)
            self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
            
            # Linear encoder
            self.linear_encoder = torch.nn.Linear(
                self.flatten_dim, self.hidden_size
            )

            # Learned position embeddings over the full token sequence.
            self.position_embeddings = torch.nn.Embedding(
                self.seq_length, self.hidden_size
            )
            init_method_normal(args.init_method_std)(
                self.position_embeddings.weight
            )

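            # Record whether this model prepends class tokens so the
            # position-embedding interpolation hook can match the layout.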
            args.class_token_present = self.class_token
            self.position_embeddings._register_load_state_dict_pre_hook(
                twod_interpolate_position_embeddings_hook
            )

            self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)

        # Transformer
        self.transformer = ParallelTransformer(
            self.init_method,
            self.scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process,
            post_layer_norm=self.post_layer_norm,
            drop_path_rate=self.drop_path_rate
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.transformer.set_input_tensor(input_tensor)

    def forward(self, input):

        if self.pre_process:
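            # Patchify: [b, c, H, W] -> [b, num_patches, patch_dim * patch_dim * c].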
            rearranged_input = einops.rearrange(
                input,
                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                p1=self.patch_dim,
                p2=self.patch_dim,
            )

            assert rearranged_input.dtype == torch.half
            encoder_output = self.linear_encoder(rearranged_input)

            concatenated_tokens = encoder_output
            if self.class_token:
                cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
                concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)

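            # Add learned position embeddings, sliced to the actual token count.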
            token_embeddings = concatenated_tokens + \
                    self.position_embeddings(self.position_ids[:, :concatenated_tokens.shape[1]])
            hidden_states = self.embedding_dropout(token_embeddings)
        else:
            hidden_states = input

        hidden_states = self.transformer(hidden_states, None)

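        # Optionally keep only the first class token (e.g. for a classification head).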
        if self.single_token_output:
            hidden_states = hidden_states[:,0,:]

        return hidden_states
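
# Minimal usage sketch, assuming Megatron's global args (img_h, img_w,
# patch_dim, hidden_size, num_channels, micro_batch_size, ...) have already
# been initialized and the model runs in fp16 on GPU, as forward() expects:
#
#     backbone = VitBackbone(pre_process=True, post_process=True).half().cuda()
#     images = torch.randn(2, 3, args.img_h, args.img_w).half().cuda()
#     hidden_states = backbone(images)  # [2, seq_length, hidden_size]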