# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vision Transformer(VIT) model."""

import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import (
    get_linear_layer,
    init_method_normal,
    scaled_init_method_normal,
)
from megatron.model.module import MegatronModule

# Number of learnable class tokens prepended to the patch sequence.
CLASS_TOKEN_LENGTH = 8

class VitMlpHead(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence) and add a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, num_classes):
        super(VitMlpHead, self).__init__()
        self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.dense_out = torch.nn.Linear(hidden_size, num_classes)
        torch.nn.init.constant_(self.dense_out.bias, -10)

    def forward(self, hidden_states):
        # hidden_states: [b, 1, h]
        dense_in_result = self.dense_in(hidden_states)
        tanh_result = torch.tanh(dense_in_result)
        dense_out_result = self.dense_out(tanh_result)
        return dense_out_result


def isPerfectSquare(x):
    if x >= 0:
        sr = int(math.sqrt(x))
        return sr * sr == x
    return False


def twod_interpolate_position_embeddings_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
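    """load_state_dict pre-hook that resizes ViT position embeddings.

    When a checkpoint was trained with a different image or patch size,
    the patch-grid portion of the position embedding table is bilinearly
    interpolated to the current grid; class-token embeddings, if present,
    are carried over unchanged.
    """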

    args = get_args()
    num_patches_per_dim_h = args.img_h // args.patch_dim
    num_patches_per_dim_w = args.img_w // args.patch_dim
    num_patches = num_patches_per_dim_h * num_patches_per_dim_w
    hidden_size = args.hidden_size

    key = prefix + "weight"

    assert key in state_dict
    if key in state_dict:
        input_param = state_dict[key]

        input_seq_len = input_param.shape[0]
        assert isPerfectSquare(input_seq_len) or \
            isPerfectSquare(input_seq_len - CLASS_TOKEN_LENGTH)
        input_has_class_token = not isPerfectSquare(input_seq_len)
        num_tok_input = input_seq_len - CLASS_TOKEN_LENGTH if input_has_class_token else input_seq_len
        num_tok_output = num_patches
        output_has_class_token = args.class_token_present

        # update input_param and load it to state_dict[key]
        if input_has_class_token:
            input_param_tok = input_param[:CLASS_TOKEN_LENGTH, :]
            input_param_grid = input_param[CLASS_TOKEN_LENGTH:, :]
        else:
            # Match the checkpoint's dtype so the torch.cat below cannot
            # fail on mixed dtypes.
            input_param_tok = torch.zeros(
                CLASS_TOKEN_LENGTH, hidden_size, dtype=input_param.dtype
            )
            input_param_grid = input_param
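        # At this point: input_param_grid is [num_tok_input, h] and
        # input_param_tok is [CLASS_TOKEN_LENGTH, h].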

        assert input_param.shape[1] == hidden_size
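
        # If the checkpoint's patch grid differs from the current model's,
        # resize the table by bilinear interpolation:
        # [s, h] -> [1, h, gs_input, gs_input] -> [1, h, gs_new_h, gs_new_w] -> [s', h].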
        if num_tok_input != num_tok_output:
            gs_input = int(math.sqrt(num_tok_input))
            gs_new = (num_patches_per_dim_h, num_patches_per_dim_w)

            # Interpolate in fp32 for accuracy; cast back afterwards.
            orig_dtype = input_param_grid.dtype
            input_param_grid = input_param_grid.transpose(0, 1).contiguous()
            input_param_grid = input_param_grid.reshape(
                (1, -1, gs_input, gs_input)
            )
            input_param_grid = input_param_grid.float()
            scale_factor = (gs_new[0] / gs_input, gs_new[1] / gs_input)

            input_param_grid = F.interpolate(
                input_param_grid, scale_factor=scale_factor, mode="bilinear"
            )

            input_param_grid = input_param_grid.to(orig_dtype)
            input_param_grid = input_param_grid.reshape((-1, num_tok_output))
            input_param_grid = input_param_grid.transpose(0, 1).contiguous()

            assert input_param_grid.shape[1] == hidden_size

        input_param = input_param_grid
        assert (
            input_param.shape[0] == num_tok_output
            and input_param.shape[1] == hidden_size
        )

        if output_has_class_token:
            input_param = torch.cat((input_param_tok, input_param), dim=0)

        state_dict[key] = input_param


class VitBackbone(MegatronModule):
    """Vision Transformer Model."""

    def __init__(self,
                 pre_process=True,
                 post_process=True,
                 class_token=True,
                 single_token_output=False,
                 drop_path_rate=0.0):
        super(VitBackbone, self).__init__(share_word_embeddings=False)
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        if args.init_method_xavier_uniform:
            self.init_method = torch.nn.init.xavier_uniform_
            self.scaled_init_method = torch.nn.init.xavier_uniform_
        else:
            self.init_method = init_method_normal(args.init_method_std)
            self.scaled_init_method = scaled_init_method_normal(
                args.init_method_std, args.num_layers
            )

        self.pre_process = pre_process
        self.post_process = post_process
        self.class_token = class_token
        self.hidden_size = args.hidden_size
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.micro_batch_size = args.micro_batch_size
        self.single_token_output = single_token_output
        self.drop_path_rate = drop_path_rate

        assert self.img_h % self.patch_dim == 0
        assert self.img_w % self.patch_dim == 0
        self.num_patches_per_dim_h = self.img_h // self.patch_dim
        self.num_patches_per_dim_w = self.img_w // self.patch_dim
        self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w
        self.seq_length = self.num_patches + (CLASS_TOKEN_LENGTH if self.class_token else 0)
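        # For example: a 224x224 image with 16x16 patches gives a 14x14 grid,
        # i.e. 196 patch tokens, plus 8 class tokens -> seq_length = 204.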
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
        self.input_tensor = None
        self.position_ids = None

        if self.pre_process:
            # class tokens: CLASS_TOKEN_LENGTH learnable embeddings, zero-initialized
            if self.class_token:
                self.cls_token = torch.nn.Parameter(
                    torch.randn(1, CLASS_TOKEN_LENGTH, self.hidden_size)
                )
                torch.nn.init.zeros_(self.cls_token)
            self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
            
            # Linear encoder
            self.linear_encoder = torch.nn.Linear(
                self.flatten_dim, self.hidden_size
            )

            # embedding
            self.position_embeddings = torch.nn.Embedding(
                self.seq_length, self.hidden_size
            )
            init_method_normal(args.init_method_std)(
                self.position_embeddings.weight
            )

            # Record on args so twod_interpolate_position_embeddings_hook
            # knows whether the loaded table must include class tokens.
            args.class_token_present = self.class_token
            self.position_embeddings._register_load_state_dict_pre_hook(
                twod_interpolate_position_embeddings_hook
            )

            self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)

        # Transformer
        self.transformer = ParallelTransformer(
            self.init_method,
            self.scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process,
            drop_path_rate=self.drop_path_rate
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.transformer.set_input_tensor(input_tensor)

    def forward(self, input):
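        """input: [b, c, img_h, img_w] pixel batch."""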

        if self.pre_process:
            rearranged_input = einops.rearrange(
                input,
                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                p1=self.patch_dim,
                p2=self.patch_dim,
            )

            # The linear patch encoder expects fp16 activations.
            assert rearranged_input.dtype == torch.half
            encoder_output = self.linear_encoder(rearranged_input)

            # Prepend class tokens: [b, n, h] -> [b, CLASS_TOKEN_LENGTH + n, h].
            concatenated_tokens = encoder_output
            if self.class_token:
                cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
                concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)

            token_embeddings = concatenated_tokens + \
                    self.position_embeddings(self.position_ids[:, :concatenated_tokens.shape[1]])
            hidden_states = self.embedding_dropout(token_embeddings)
        else:
            hidden_states = input

        # attention_mask=None: full bidirectional attention over all tokens.
        hidden_states = self.transformer(hidden_states, None)

        if self.single_token_output:
            # Pool only the first class token: [b, s, h] -> [b, h].
            hidden_states = hidden_states[:, 0, :]

        return hidden_states
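

# A minimal usage sketch (hypothetical; assumes Megatron args such as img_h,
# img_w, patch_dim, hidden_size, and num_channels are already initialized):
#
#   model = VitBackbone(pre_process=True, post_process=True).half().cuda()
#   images = torch.randn(4, 3, args.img_h, args.img_w,
#                        dtype=torch.half, device="cuda")
#   hidden = model(images)          # [4, seq_length, hidden_size]
#   head = VitMlpHead(args.hidden_size, num_classes=1000).half().cuda()
#   logits = head(hidden[:, 0, :])  # pool the first class token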