# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Vision Transformer(VIT) model."""

import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron.training import get_args
from megatron.legacy.model.transformer import ParallelTransformer
from megatron.legacy.model.utils import (
    get_linear_layer,
    init_method_normal,
    scaled_init_method_normal,
)
from megatron.legacy.model.module import MegatronModule

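# Number of learnable class tokens prepended to the patch sequence.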
CLASS_TOKEN_LENGTH = 8

class VitMlpHead(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example the start of the
    sequence) and apply a linear transformation, a tanh, and a final linear
    projection to the number of classes.

    Args:
        config: model configuration object, stored on the module.
        hidden_size: hidden size of the incoming hidden states.
        num_classes: number of output classes; the output layer's bias is
            initialized to -10.
    """

    def __init__(self, config, hidden_size, num_classes):
        super(VitMlpHead, self).__init__()
        self.config = config
        self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.dense_out = torch.nn.Linear(hidden_size, num_classes)
        torch.nn.init.constant_(self.dense_out.bias, -10)

    def forward(self, hidden_states):
        # hidden_states: [b, 1, h]
        dense_in_result = self.dense_in(hidden_states)
        tanh_result = torch.tanh(dense_in_result)
        dense_out_result = self.dense_out(tanh_result)
        return dense_out_result


def isPerfectSquare(x):
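    """Return True if x is a non-negative perfect square."""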
    if(x >= 0):
        sr = math.sqrt(x)
        return (int(sr) * int(sr) == x)
    return False


def twod_interpolate_position_embeddings_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
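    """``load_state_dict`` pre-hook that resizes position embeddings.

    If the checkpoint's position-embedding table was built for a different
    image or patch size, the patch-grid portion is bilinearly interpolated to
    the current (img_h // patch_dim) x (img_w // patch_dim) grid; class-token
    embeddings are kept, dropped, or zero-initialized so the result matches
    args.class_token_present.
    """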

    args = get_args()
    num_patches_per_dim_h = args.img_h // args.patch_dim
    num_patches_per_dim_w = args.img_w // args.patch_dim
    num_patches = num_patches_per_dim_h * num_patches_per_dim_w
    hidden_size = args.hidden_size

    key = prefix + "weight"

    assert key in state_dict
    if key in state_dict:
        input_param = state_dict[key]

        input_seq_len = input_param.shape[0]
        assert(isPerfectSquare(input_seq_len) or isPerfectSquare(input_seq_len - CLASS_TOKEN_LENGTH))
        input_has_class_token = not isPerfectSquare(input_seq_len)
        num_tok_input = input_seq_len - CLASS_TOKEN_LENGTH if input_has_class_token else input_seq_len
        num_tok_output = num_patches
        output_has_class_token = args.class_token_present

        # update input_param and load it to state_dict[key]
        if input_has_class_token:
            input_param_tok = input_param[:CLASS_TOKEN_LENGTH, :]
            input_param_grid = input_param[CLASS_TOKEN_LENGTH:, :]
        else:
            input_param_tok = torch.zeros(CLASS_TOKEN_LENGTH, hidden_size)
            input_param_grid = input_param

        assert input_param.shape[1] == hidden_size

        if num_tok_input != num_tok_output:

            gs_input = int(math.sqrt(num_tok_input))
            gs_new = (num_patches_per_dim_h, num_patches_per_dim_w)

            input_param_grid = input_param_grid.transpose(0, 1).contiguous()
            input_param_grid = input_param_grid.reshape(
                (1, -1, gs_input, gs_input)
            )
            input_param_grid = input_param_grid.float()
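            # Resizing is done in fp32; the result is cast back to fp16 afterwards.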
            scale_factor = (gs_new[0] / gs_input, gs_new[1] / gs_input)

            input_param_grid = F.interpolate(
                input_param_grid, scale_factor=scale_factor, mode="bilinear"
            )

            input_param_grid = input_param_grid.half()
            input_param_grid = input_param_grid.reshape((-1, num_tok_output))
            input_param_grid = input_param_grid.transpose(0, 1).contiguous()

            assert input_param_grid.shape[1] == hidden_size

        input_param = input_param_grid
        assert (
            input_param.shape[0] == num_tok_output
            and input_param.shape[1] == hidden_size
        )

        if output_has_class_token:
            input_param = torch.cat((input_param_tok, input_param), dim=0)

        state_dict[key] = input_param


class VitBackbone(MegatronModule):
    """Vision Transformer Model."""

    def __init__(self,
                 config,
                 pre_process=True,
                 post_process=True,
                 class_token=True,
                 single_token_output=False,
                 post_layer_norm=True,
                 drop_path_rate=0.0):
        super(VitBackbone, self).__init__(share_embeddings_and_output_weights=False)
        args = get_args()
        self.config = config

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

        self.pre_process = pre_process
        self.post_process = post_process
        self.class_token = class_token
        self.post_layer_norm = post_layer_norm
        self.hidden_size = args.hidden_size
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.micro_batch_size = args.micro_batch_size
        self.single_token_output = single_token_output
        self.drop_path_rate = drop_path_rate

        assert self.img_h % self.patch_dim == 0
        assert self.img_w % self.patch_dim == 0
        self.num_patches_per_dim_h = self.img_h // self.patch_dim
        self.num_patches_per_dim_w = self.img_w // self.patch_dim
        self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w
        self.seq_length = self.num_patches + (CLASS_TOKEN_LENGTH if self.class_token else 0)
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
        self.input_tensor = None
        self.position_ids = None

        if self.pre_process:
            # cls_token
            if self.class_token:
                self.cls_token = torch.nn.Parameter(
                    torch.randn(1, CLASS_TOKEN_LENGTH, self.hidden_size)
                )
                torch.nn.init.zeros_(self.cls_token)
            self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()

            # Linear encoder
            self.linear_encoder = torch.nn.Linear(
                self.flatten_dim, self.hidden_size
            )

            # embedding
            self.position_embeddings = torch.nn.Embedding(
                self.seq_length, self.hidden_size
            )
            init_method_normal(args.init_method_std)(
                self.position_embeddings.weight
            )

            args.class_token_present = self.class_token
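            # On checkpoint load, interpolate position embeddings to the current patch grid.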
            self.position_embeddings._register_load_state_dict_pre_hook(
                twod_interpolate_position_embeddings_hook
            )

            self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)

        # Transformer
        self.transformer = ParallelTransformer(
            config,
            model_type=args.model_type,
            pre_process=self.pre_process,
            post_process=self.post_process,
            post_layer_norm=self.post_layer_norm,
            drop_path_rate=self.drop_path_rate
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""
        self.transformer.set_input_tensor(input_tensor)

    def forward(self, input):

        if self.pre_process:
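            # Patchify: [b, c, H, W] -> [b, (H/p) * (W/p), p * p * c]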
            rearranged_input = einops.rearrange(
                input,
                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                p1=self.patch_dim,
                p2=self.patch_dim,
            )

            assert rearranged_input.dtype == torch.half
            encoder_output = self.linear_encoder(rearranged_input)

            concatenated_tokens = encoder_output
            if self.class_token:
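                # Prepend the learnable class tokens to the patch embeddings.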
                cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
                concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)

            token_embeddings = concatenated_tokens + \
                    self.position_embeddings(self.position_ids[:, :concatenated_tokens.shape[1]])
            # [b, s, h] => [s, b, h]
            token_embeddings = token_embeddings.transpose(0, 1).contiguous()
            hidden_states = self.embedding_dropout(token_embeddings)
        else:
            hidden_states = input

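        # Run the transformer stack; ViT passes no attention mask (None).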
        hidden_states = self.transformer(hidden_states, None)

        if self.post_process:
            # [s b h] => [b s h]
            if self.single_token_output:
                hidden_states = hidden_states[0]
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()

        return hidden_states
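

# Minimal usage sketch (illustration only; not executed by this module). It
# assumes Megatron's global args/config have been initialized as in the
# training scripts and that `images` is an fp16 tensor of shape
# [b, c, img_h, img_w]:
#
#   backbone = VitBackbone(config, single_token_output=True)
#   head = VitMlpHead(config, args.hidden_size, num_classes=1000)
#   hidden = backbone(images)   # [b, hidden_size] when single_token_output=True
#   logits = head(hidden)       # [b, num_classes]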