"include/ck/utility/get_id.hpp" did not exist on "05e046654c9a226444091806a418a77fe0e4a4c2"
vit_model.py 13.9 KB
Newer Older
huaerkl's avatar
v1.0  
huaerkl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vision Transformer(VIT) model."""

import math
import einops
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from megatron.model.transformer import ParallelTransformer
from megatron.enums import AttnMaskType
from .module import MegatronModule, fp32_to_float16
from megatron.model.utils import (
    get_linear_layer,
    init_method_normal,
    scaled_init_method_normal,
)
from .module import MegatronModule

from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec
from megatron.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm
from megatron.model.module import float16_to_fp32
from .language_model import EmbeddingPipe
from .transformer import ParallelTransformerLayerPipe
from .language_model import parallel_lm_logits


class VitMlpHead(MegatronModule):
    """Classification head applied to a single pooled token.

    Picks the hidden state of one token (by default the first token of
    the sequence), runs it through a linear layer followed by a tanh, and
    projects the result to ``num_classes`` outputs.

    Arguments:
        hidden_size: width of the incoming hidden states and of the
            intermediate linear layer.
        num_classes: number of output features of the final projection.
    """

    def __init__(self, hidden_size, num_classes):
        super().__init__()
        self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
        self.dense_out = torch.nn.Linear(hidden_size, num_classes)
        # The output bias is overwritten with a constant -10 instead of
        # keeping the default torch.nn.Linear initialization.
        torch.nn.init.constant_(self.dense_out.bias, -10)

    def forward(self, hidden_states, sequence_index=0):
        """Return the head output for the token at ``sequence_index``.

        Arguments:
            hidden_states: activations laid out as [b, s, h].
            sequence_index: position (along dim 1) of the token to pool.
        """
        pooled_token = hidden_states[:, sequence_index, :]
        return self.dense_out(torch.tanh(self.dense_in(pooled_token)))


def twod_interpolate_position_embeddings_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Resize checkpointed 2D position embeddings to the configured grid.

    Intended to be registered as a ``load_state_dict`` pre-hook on the
    position-embedding module. When the checkpointed table
    (``state_dict[prefix + "weight"]``) has a different number of rows
    than ``(args.img_dim // args.patch_dim) ** 2 + 1``, the first row is
    kept unchanged and the remaining rows are treated as a square grid,
    bilinearly interpolated to the new grid size, and written back into
    ``state_dict`` in place.

    Arguments:
        state_dict: checkpoint dictionary; updated in place when resizing
            is needed.
        prefix: key prefix of the module this hook is attached to.
        local_metadata, strict, missing_keys, unexpected_keys, error_msgs:
            part of the pre-hook signature; not used here.

    Raises:
        AssertionError: if the weight is missing, its hidden dimension
            does not match ``args.hidden_size``, or the checkpointed
            grid is not a perfect square.
    """
    args = get_args()
    num_patches_per_dim = args.img_dim // args.patch_dim
    seq_length = num_patches_per_dim ** 2 + 1
    hidden_size = args.hidden_size

    key = prefix + "weight"
    assert key in state_dict
    if key not in state_dict:
        # Only reachable when assertions are stripped (python -O); keep
        # the original "skip silently" behavior in that mode.
        return

    input_param = state_dict[key]
    assert input_param.shape[1] == hidden_size
    if input_param.shape[0] == seq_length:
        # Checkpoint already matches the configured sequence length.
        return

    # Row 0 is kept as-is; rows 1.. form the gs_input x gs_input grid.
    input_param_tok = input_param[:1, :]
    input_param_grid = input_param[1:, :]

    num_tok_input = input_param_grid.shape[0]
    gs_input = int(math.sqrt(num_tok_input))
    # Without this check a non-square grid could still reshape "successfully"
    # (when the element count happens to divide) and silently load garbage.
    assert gs_input * gs_input == num_tok_input, (
        "position embedding grid in checkpoint is not square: "
        "{} entries".format(num_tok_input)
    )
    # seq_length - 1 is num_patches_per_dim ** 2 by construction.
    gs_new = num_patches_per_dim

    # [n, h] -> [1, h, gs_input, gs_input]; interpolate in fp32.
    grid = input_param_grid.transpose(0, 1).contiguous()
    grid = grid.reshape((1, -1, gs_input, gs_input)).float()

    # Request the target size directly: with scale_factor the output size
    # is a floored float product and is not guaranteed to land on gs_new.
    grid = F.interpolate(
        grid, size=(gs_new, gs_new), mode="bilinear", align_corners=False
    )

    # Return to the checkpoint's own dtype (fp16, bf16 or fp32) so the
    # concatenation below does not mix dtypes.
    grid = grid.to(input_param.dtype)
    grid = grid.reshape((-1, gs_new * gs_new))
    grid = grid.transpose(0, 1).contiguous()

    assert grid.shape[1] == hidden_size
    resized_param = torch.cat((input_param_tok, grid), dim=0)
    assert (
        resized_param.shape[0] == seq_length
        and resized_param.shape[1] == hidden_size
    )

    state_dict[key] = resized_param


class VitModel(MegatronModule):
    """Vision Transformer Model.

    Splits an image into square patches, linearly embeds each patch,
    prepends a learnable class token, adds learned position embeddings,
    runs the sequence through ``ParallelTransformer`` and applies a
    classification head.

    Arguments:
        num_classes: number of outputs of the classification head.
        finetune: when False a ``VitMlpHead`` is used; when True a single
            zero-initialized linear layer is applied to the first token.
    """

    def __init__(self, num_classes, finetune=False):
        super().__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

        # Weight initializers handed to the transformer.
        if args.init_method_xavier_uniform:
            self.init_method = torch.nn.init.xavier_uniform_
            self.scaled_init_method = torch.nn.init.xavier_uniform_
        else:
            std = args.init_method_std
            self.init_method = init_method_normal(std)
            self.scaled_init_method = scaled_init_method_normal(
                std, args.num_layers
            )

        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.patch_dim = args.patch_dim
        self.img_dim = args.img_dim
        self.finetune = finetune

        # The image must tile exactly into patch_dim x patch_dim patches.
        assert self.img_dim % self.patch_dim == 0
        self.num_patches_per_dim = self.img_dim // self.patch_dim
        self.num_patches = self.num_patches_per_dim ** 2
        # One extra position for the class token.
        self.seq_length = self.num_patches + 1
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels

        # Class token: allocated with randn, then zeroed in place.
        self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
        torch.nn.init.zeros_(self.cls_token)

        # Per-patch linear projection from flattened pixels to hidden size.
        self.linear_encoder = torch.nn.Linear(
            self.flatten_dim, self.hidden_size
        )

        # Learned position embeddings, one row per sequence position.
        self.position_embeddings = torch.nn.Embedding(
            self.seq_length, self.hidden_size
        )
        position_init = init_method_normal(args.init_method_std)
        position_init(self.position_embeddings.weight)
        # Shape [1, seq_length], moved to the GPU at construction time.
        self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()

        # Resize checkpointed position embeddings at load time when the
        # checkpoint's row count differs from seq_length.
        self.position_embeddings._register_load_state_dict_pre_hook(
            twod_interpolate_position_embeddings_hook
        )

        self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)

        # Transformer
        self.transformer = ParallelTransformer(
            self.init_method, self.scaled_init_method
        )

        # Classification head
        if self.finetune:
            self.class_head = get_linear_layer(
                self.hidden_size, num_classes, torch.nn.init.zeros_
            )
        else:
            self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.transformer.set_input_tensor(input_tensor)

    def forward(self, x):
        """Run the model on a batch of images laid out as [b, c, H, W].

        The input must already be ``torch.half``; this is asserted after
        the patch rearrangement.
        """
        # [b, c, H, W] -> [b, num_patches, patch_dim * patch_dim * c]
        patches = einops.rearrange(
            x,
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=self.patch_dim,
            p2=self.patch_dim,
        )

        assert patches.dtype == torch.half
        tokens = self.linear_encoder(patches)

        # Prepend one class token per batch element.
        batch_size = tokens.shape[0]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        tokens = tokens + self.position_embeddings(self.position_ids)
        tokens = self.embedding_dropout(tokens)
        hidden = self.transformer(tokens, None)

        if self.finetune:
            return self.class_head(hidden[:, 0, :])
        return self.mlp_head(hidden)


class PregrocessPipe(MegatronModule):
    """Patch-embedding stage used as the first module of ``VitModelPipe``.

    Turns a batch of images into a token sequence: patches are flattened
    and linearly projected, a learnable class token is prepended, learned
    position embeddings are added and dropout is applied.
    """

    def __init__(self):
        super().__init__()
        args = get_args()
        self.hidden_size = args.hidden_size
        self.patch_dim = args.patch_dim
        self.img_dim = args.img_dim

        # The image must tile exactly into patch_dim x patch_dim patches.
        assert self.img_dim % self.patch_dim == 0
        self.num_patches_per_dim = self.img_dim // self.patch_dim
        self.num_patches = self.num_patches_per_dim ** 2
        # One extra position for the class token.
        self.seq_length = self.num_patches + 1
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels

        # Class token: allocated with randn, then zeroed in place.
        self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
        torch.nn.init.zeros_(self.cls_token)

        # Per-patch linear projection from flattened pixels to hidden size.
        self.linear_encoder = torch.nn.Linear(
            self.flatten_dim, self.hidden_size
        )

        # Learned position embeddings.
        # NOTE: unlike VitModel, this stage keeps torch.nn.Embedding's
        # default initialization and does not register
        # twod_interpolate_position_embeddings_hook as a load pre-hook.
        self.position_embeddings = torch.nn.Embedding(
            self.seq_length, self.hidden_size
        )
        # Shape [1, seq_length], moved to the GPU at construction time.
        self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()

        self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)

    def forward(self, x):
        """Embed images laid out as [b, c, H, W] into [b, seq_length, hidden].

        The input must already be ``torch.half``; this is asserted after
        the patch rearrangement.
        """
        # [b, c, H, W] -> [b, num_patches, patch_dim * patch_dim * c]
        patches = einops.rearrange(
            x,
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=self.patch_dim,
            p2=self.patch_dim,
        )

        assert patches.dtype == torch.half
        tokens = self.linear_encoder(patches)

        # Prepend one class token per batch element.
        batch_size = tokens.shape[0]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        tokens = tokens + self.position_embeddings(self.position_ids)
        return self.embedding_dropout(tokens)

class VitModelPipe(PipelineModule,MegatronModule):
    """Vision Transformer expressed as a DeepSpeed pipeline of layer specs.

    The constructor builds an ordered list of specs (dtype cast, patch
    embedding, layout change, transformer layers, layout restore, final
    layernorm, classification head) and passes it to ``PipelineModule``.

    Arguments:
        num_classes: number of outputs of the classification head.
        finetune: when False a ``VitMlpHead`` is appended; when True the
            first token is selected and fed to a zero-initialized linear
            layer.
        attn_mask_type: forwarded to every transformer layer as
            ``self_attn_mask_type``; also decides the ``is_prefix`` flag
            given to ``get_cross_entropy``.

    NOTE(review): the helper closures below keep their original names
    (``_to_float16``, ``undo``) and the layout changes stay anonymous
    lambdas on purpose - type-based partitioning may look at callable
    names; confirm against DeepSpeed before renaming any of them.
    """
    def __init__(self, num_classes, finetune=False, attn_mask_type: AttnMaskType = AttnMaskType.causal):

        args = get_args()

        # Weight initializers handed to every transformer layer.
        if args.init_method_xavier_uniform:
            init_method = scaled_init_method = torch.nn.init.xavier_uniform_
        else:
            init_method = init_method_normal(args.init_method_std)
            scaled_init_method = scaled_init_method_normal(
                args.init_method_std, args.num_layers
            )

        def _to_float16(inputs):
            # Cast incoming fp32 data to the configured half-precision type.
            if args.fp16:
                return fp32_to_float16(inputs, lambda v: v.half())
            if args.bf16:
                return fp32_to_float16(inputs, lambda v: v.bfloat16())
            return inputs

        # Input cast followed by the embedding stage.
        # NOTE(review): 'embed' is declared as a tied spec although it is
        # the only spec using that key in this constructor - confirm intent.
        self.specs = [_to_float16, TiedLayerSpec('embed', PregrocessPipe)]

        # Switch the leading two dimensions before the transformer layers
        # (optionally upcasting to fp32 for the residual connection).
        # NOTE(review): the non-causal branches index x[0] / x[1:], i.e.
        # they treat the incoming activation as a tuple (comment inherited
        # from EmbeddingPipe, which "returns attention mask as well"),
        # while PregrocessPipe.forward in this file returns a single
        # tensor - confirm which pretrain_causal_attention setting is
        # intended here.
        tensor_only = getattr(args, 'pretrain_causal_attention', False)
        if args.fp32_residual_connection and tensor_only:
            self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
        elif args.fp32_residual_connection:
            self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous().float(), *x[1:]))
        elif tensor_only:
            self.specs.append(lambda x: x.transpose(0, 1).contiguous())
        else:
            self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:]))

        # One spec per transformer layer.
        # TODO: Change naming of class from GPT to something that encapsulate prefix lm.
        self.specs.extend(
            LayerSpec(ParallelTransformerLayerPipe,
                      init_method=init_method,
                      output_layer_init_method=scaled_init_method,
                      layer_number=layer_idx,
                      self_attn_mask_type=attn_mask_type)
            for layer_idx in range(args.num_layers)
        )

        # Undo data format change
        def undo(x):
            if not getattr(args, 'pretrain_causal_attention', False):
                # Drop everything but the hidden states.
                x = x[0]
            return x.transpose(0, 1).contiguous()
        self.specs.append(undo)

        # Final layernorm after transformer layers
        self.specs.append(
            LayerSpec(LayerNorm, args.hidden_size, eps=args.layernorm_epsilon)
        )

        # Classification head
        if finetune:
            # Select the first token, then apply a zero-initialized linear layer.
            self.specs.append(lambda x: x[:, 0, :])
            self.specs.append(get_linear_layer(args.hidden_size, num_classes, torch.nn.init.zeros_))
        else:
            self.specs.append(VitMlpHead(args.hidden_size, num_classes))

        # The output is left in half precision: no float16_to_fp32 spec
        # is appended at the end of the pipeline.

        # Activation checkpointing interval (0 disables it).
        interval = args.checkpoint_num_layers if args.checkpoint_activations else 0

        from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
        topo = PipeModelDataParallelTopology(
            num_pp=mpu.get_pipeline_model_parallel_world_size(),
            num_mp=mpu.get_tensor_model_parallel_world_size(),
            num_dp=mpu.get_data_parallel_world_size(),
        )

        # here one can extend the regex to include more layers to be counted towards partitioning,
        # e.g. 'type:transformer|embedding' will add up all the transformer blocks and also the first
        # and last embedding layers and then partition that transformers+2 layers - so to get a good
        # balance you may want to use less transformer layers
        #
        # caveat emptor: the current implementation of PP fails unless each stage has at least one
        # transformer layer
        partition_method = (
            args.pp_partition_method
            if args.pp_partition_method is not None
            else 'type:transformer'
        )

        super().__init__(layers=self.specs,
                         loss_fn=get_cross_entropy(is_prefix=attn_mask_type is AttnMaskType.prefix),
                         topology=topo,
                         activation_checkpoint_interval=interval,
                         partition_method=partition_method)


def get_cross_entropy(is_prefix: bool):
    """Build the loss function handed to the pipeline module.

    Arguments:
        is_prefix: accepted for interface compatibility with callers;
            the returned loss function does not depend on it.

    Returns:
        A callable ``CrossEntropy(output, labels)`` that evaluates
        ``mpu.vocab_parallel_cross_entropy`` on the fp32 output, sums the
        resulting losses and divides by ``labels.shape[0]``.
    """
    def CrossEntropy(output, labels):
        # The loss needs no global arguments, so none are looked up on
        # this per-step path.
        losses = mpu.vocab_parallel_cross_entropy(
            output.contiguous().float(), labels
        )
        return torch.sum(losses) / labels.shape[0]
    return CrossEntropy