intern_vit.py 18.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
10
from collections.abc import Iterable
11
from functools import partial
12
from typing import Optional
13
14
15
16
17
18

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

19
from vllm.attention.layer import MultiHeadAttention
20
21
22
23
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather)
24
25
26
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
27
                                               QKVParallelLinear,
28
                                               ReplicatedLinear,
29
30
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
31
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
32
from vllm.multimodal.utils import run_dp_sharded_vision_model
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61

# Lookup table from ``config.norm_type`` to the normalization layer class
# instantiated for ``norm1``/``norm2`` in each encoder layer.
NORM2FN: dict[str, type[nn.Module]] = {
    'rms_norm': RMSNorm,
    'layer_norm': nn.LayerNorm,
}


class InternVisionEmbeddings(nn.Module):
    """Patch embeddings plus a class token and learned absolute positions.

    A convolution with kernel and stride ``patch_size`` turns the image into
    a grid of patch embeddings. A learned class embedding is prepended, and a
    learned position table with ``num_patches + 1`` entries is added. The
    table is stored for the native square grid of side
    ``image_size // patch_size``. Any other grid shape gets a bicubically
    resized copy of the patch part of the table.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        self.patch_embedding = nn.Conv2d(in_channels=3,
                                         out_channels=self.embed_dim,
                                         kernel_size=self.patch_size,
                                         stride=self.patch_size)

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_positions, self.embed_dim))

    def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int):
        """Resize the patch part of the position table to an ``H x W`` grid.

        Args:
            pos_embed: Patch entries of the position table (class-token entry
                excluded). They are reshaped here into a square grid of side
                ``image_size // patch_size``.
            H: Target grid height, in patches.
            W: Target grid width, in patches.

        Returns:
            A ``(1, H * W, dim)`` tensor in ``pos_embed``'s original dtype.
        """
        target_dtype = pos_embed.dtype
        # Interpolate in float32, then cast back to the original dtype below.
        pos_embed = pos_embed.float().reshape(
            1, self.image_size // self.patch_size,
            self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
        pos_embed = F.interpolate(pos_embed,
                                  size=(H, W),
                                  mode='bicubic',
                                  align_corners=False)
        # (1, dim, H, W) -> (1, dim, H*W) -> (1, H*W, dim)
        return pos_embed.reshape(1, -1, H * W).permute(0, 2,
                                                       1).to(target_dtype)

    def _get_position_embedding(self, H: int, W: int) -> torch.Tensor:
        """Return the position table matching an ``H x W`` patch grid.

        The learned table is returned unchanged only when the grid is exactly
        the native square grid. Comparing ``H * W`` with ``num_patches``
        alone would also accept non-square grids with the same patch count
        (e.g. 16x64 against a 32x32 table). Those grids would then silently
        receive positions laid out for a different shape, so they are
        interpolated instead.
        """
        position_embedding = self.position_embedding
        grid_size = self.image_size // self.patch_size
        if H == grid_size and W == grid_size:
            return position_embedding

        return torch.cat(
            [
                # The class-token position is kept as-is; only the patch
                # positions are resized.
                position_embedding[:, :1, :],
                self._get_pos_embed(position_embedding[:, 1:, :], H, W),
            ],
            dim=1,
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            pixel_values: Image batch, assumed ``(batch, 3, height, width)``.
                It is cast to the patch-embedding weight dtype.

        Returns:
            ``(batch, 1 + grid_h * grid_w, embed_dim)`` embeddings in the
            patch-embedding weight dtype.
        """
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            target_dtype))  # shape = [batch, embed_dim, grid_h, grid_w]
        batch_size, _, height, width = patch_embeds.shape
        # [batch, embed_dim, grid_h, grid_w] -> [batch, grid_h*grid_w, embed_dim]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        class_embeds = self.class_embedding.expand(batch_size, 1,
                                                   -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        position_embedding = self._get_position_embedding(height, width)
        embeddings = embeddings + position_embedding.to(target_dtype)
        return embeddings


class InternVisionPatchModel(nn.Module):
    """Embedding-only front end of the InternViT vision tower.

    Wraps ``InternVisionEmbeddings`` and runs no encoder layers. Either raw
    pixels or precomputed embeddings may be supplied.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embeddings = InternVisionEmbeddings(config)

    def get_input_embeddings(self):
        return self.embeddings

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_embeds: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Return patch embeddings for ``pixel_values``, or ``pixel_embeds``.

        Raises:
            ValueError: If neither input is given, or if ``pixel_values`` is
                not 4-D.
        """
        # Precomputed embeddings take priority and are passed through as-is.
        if pixel_embeds is not None:
            return pixel_embeds

        if pixel_values is None:
            raise ValueError(
                'You have to specify pixel_values or pixel_embeds')

        if pixel_values.ndim != 4:
            raise ValueError(
                f'wrong pixel_values size: {pixel_values.shape}')

        return self.embeddings(pixel_values)


class InternParallelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Tensor-parallel variant built on vLLM's parallel linear layers. With
    ``use_data_parallel`` the projections are replicated instead, and the
    module behaves as a single TP partition. Optional dummy heads pad the
    head count so it divides the TP world size.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_dummy_heads: int = 0,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        """Build the attention projections.

        Args:
            config: Vision config. Reads ``hidden_size``,
                ``num_attention_heads``, ``qkv_bias``, ``qk_normalization``
                and ``layer_norm_eps``.
            quant_config: Optional quantization config for the projections.
            num_dummy_heads: Extra padding heads appended to the real heads.
            prefix: Parameter-name prefix for the projection layers.
            use_data_parallel: Replicate the projections (TP size 1) instead
                of sharding them across tensor-parallel ranks.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads '
                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).')

        # In data-parallel mode every rank holds the whole layer, so it acts
        # as a single partition.
        self.tp_size = (1 if use_data_parallel else
                        get_tensor_model_parallel_world_size())
        self.tp_rank = (0 if use_data_parallel else
                        get_tensor_model_parallel_rank())

        # Additional dummy heads are used to enable TP for common GPU counts.
        self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
        self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads,
                                              self.tp_size)

        self.scale = self.head_dim**-0.5
        if use_data_parallel:
            # Size the fused projection by ``dummy_dim`` so each of q/k/v
            # matches q_norm/k_norm and ``proj`` (both built over
            # ``dummy_dim``). It must also match ``self.attn``, which uses
            # ``num_heads + num_dummy_heads`` heads when tp_size == 1. This is
            # identical to ``3 * head_dim * num_heads`` when there are no
            # dummy heads.
            self.qkv = ReplicatedLinear(
                self.embed_dim,
                3 * self.dummy_dim,
                bias=config.qkv_bias,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv",
            )
        else:
            self.qkv = QKVParallelLinear(
                self.embed_dim,
                self.head_dim,
                num_dummy_heads + self.num_heads,
                bias=config.qkv_bias,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv",
            )

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            # Norm weights span all (real + dummy) heads, while the variance
            # is taken over the real ``embed_dim`` width only.
            self.q_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)
            self.k_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)

        if use_data_parallel:
            self.proj = ReplicatedLinear(
                self.dummy_dim,
                self.embed_dim,
                quant_config=quant_config,
                prefix=f"{prefix}.proj",
            )
        else:
            self.proj = RowParallelLinear(
                self.dummy_dim,
                self.embed_dim,
                quant_config=quant_config,
                prefix=f"{prefix}.proj",
            )

        self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                       self.head_dim, self.scale)

    def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
        """Apply RMS q/k normalization across the full head dimension.

        ``q_norm``/``k_norm`` are built over the full ``dummy_dim`` width. So
        under TP the local shards are all-gathered first, normalized, and
        then split again, with this rank keeping only its own slice.
        """
        if self.tp_size > 1:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q = self.q_norm(q)
        k = self.k_norm(k)
        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
        return q, k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run self-attention on ``x``.

        ``x`` is presumably ``(batch, seq_len, embed_dim)``; confirm against
        the callers. The output comes from ``proj`` and has ``embed_dim``
        features.
        """
        qkv, _ = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        if self.qk_normalization:
            q, k = self._apply_qk_norm(q, k)

        out = self.attn(q, k, v)
        out, _ = self.proj(out)
        return out


class InternSdpaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Unsharded fallback used when the (real + dummy) head count does not
    divide the TP world size. It uses plain ``nn.Linear`` projections and
    vLLM's ``MultiHeadAttention``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        *,
        num_dummy_heads: int = 0,
    ) -> None:
        """Build the projections.

        Args:
            config: Vision config. Reads ``hidden_size``,
                ``num_attention_heads``, ``qkv_bias``, ``qk_normalization``
                and ``layer_norm_eps``.
            num_dummy_heads: Extra padding heads appended to the real heads.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads '
                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).')

        # Additional dummy heads are used to enable TP for common GPU counts.
        # Every tensor below is sized for all heads, real plus dummy.
        self.total_num_heads = self.num_heads + num_dummy_heads
        self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim

        self.scale = self.head_dim**-0.5
        self.qkv = nn.Linear(self.embed_dim,
                             3 * self.dummy_dim,
                             bias=config.qkv_bias)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            # Norm weights span all heads; the variance is taken over the
            # real ``embed_dim`` width only.
            self.q_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)
            self.k_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)

        self.proj = nn.Linear(self.dummy_dim, self.embed_dim)

        # Use unified MultiHeadAttention with automatic backend selection.
        # It must see every head produced by ``qkv`` (dummy heads included),
        # otherwise its per-head reshape cannot match ``dummy_dim``.
        self.attn = MultiHeadAttention(self.total_num_heads, self.head_dim,
                                       self.scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run self-attention on a ``(batch, seq_len, embed_dim)`` input.

        Returns a tensor projected back to ``embed_dim`` features.
        """
        B, N, _ = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        # Each chunk has ``dummy_dim`` features, i.e. total_num_heads heads.
        # Viewing with only ``num_heads`` would fail when dummy heads exist.
        heads_shape = (B, N, self.total_num_heads, self.head_dim)
        q = q.view(heads_shape)
        k = k.view(heads_shape)
        v = v.view(heads_shape)

        if self.qk_normalization:
            # The norms operate on the flattened (heads * head_dim) width.
            q = self.q_norm(q.flatten(-2, -1)).view(heads_shape)
            k = self.k_norm(k.flatten(-2, -1)).view(heads_shape)

        # Use unified MultiHeadAttention with automatic backend selection
        x = self.attn(q, k, v)

        x = self.proj(x)
        return x


class InternMLP(nn.Module):
    """Two-layer feed-forward block: ``fc2(act(fc1(x)))``."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        """Build the up (``fc1``) and down (``fc2``) projections.

        Args:
            config: Vision config. Reads ``hidden_act``, ``hidden_size`` and
                ``intermediate_size``.
            quant_config: Optional quantization config for both projections.
            prefix: Parameter-name prefix for the projections.
            use_data_parallel: Use ``ReplicatedLinear`` for both projections
                instead of the ColumnParallelLinear / RowParallelLinear pair.
        """
        super().__init__()

        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        if use_data_parallel:
            up_cls, down_cls = ReplicatedLinear, ReplicatedLinear
        else:
            up_cls, down_cls = ColumnParallelLinear, RowParallelLinear

        self.fc1 = up_cls(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = down_cls(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The vLLM linear layers return (output, bias); only the output is used.
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected


class InternVisionEncoderLayer(nn.Module):
    """Pre-norm transformer block with learned per-channel residual scales.

    ``forward`` computes ``h = h + attn(norm1(h)) * ls1`` and then
    ``h = h + mlp(norm2(h)) * ls2``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_dummy_heads: int = 0,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()

        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type

        self.attn = self._init_attn(config,
                                    quant_config,
                                    num_dummy_heads=num_dummy_heads,
                                    prefix=f"{prefix}.attn",
                                    use_data_parallel=use_data_parallel)

        self.mlp = InternMLP(config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.mlp",
                             use_data_parallel=use_data_parallel)

        norm_cls = NORM2FN[self.norm_type]
        self.norm1 = norm_cls(self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = norm_cls(self.embed_dim, eps=config.layer_norm_eps)

        # Two independent tensors: the residual scales must not share storage.
        self.ls1 = nn.Parameter(config.initializer_factor *
                                torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor *
                                torch.ones(self.embed_dim))

    def _init_attn(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        num_dummy_heads: int,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        """Pick the attention implementation for the current parallel setup.

        Returns ``InternParallelAttention`` when the padded head count
        divides the effective TP size. Otherwise falls back to the unsharded
        ``InternSdpaAttention``.
        """
        # Data-parallel ranks hold the whole layer, i.e. an effective TP size
        # of 1.
        if use_data_parallel:
            tp_size = 1
        else:
            tp_size = get_tensor_model_parallel_world_size()

        padded_heads = config.num_attention_heads + num_dummy_heads
        if padded_heads % tp_size != 0:
            return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)

        return InternParallelAttention(config,
                                       quant_config=quant_config,
                                       num_dummy_heads=num_dummy_heads,
                                       prefix=prefix,
                                       use_data_parallel=use_data_parallel)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        attn_out = self.attn(self.norm1(hidden_states))
        hidden_states = hidden_states + attn_out * self.ls1

        mlp_out = self.mlp(self.norm2(hidden_states))
        hidden_states = hidden_states + mlp_out * self.ls2

        return hidden_states


class InternVisionEncoder(nn.Module):
    """Stack of ``InternVisionEncoderLayer`` blocks applied in sequence."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        """Build the layer stack.

        Args:
            config: Vision config; ``num_hidden_layers`` sets the depth
                unless overridden.
            quant_config: Optional quantization config passed to each layer.
            num_hidden_layers_override: If not None, build this many layers
                instead of ``config.num_hidden_layers``.
            num_dummy_heads: Padding heads passed to each layer's attention.
            prefix: Parameter-name prefix; layers get ``{prefix}.layers.{i}``.
            use_data_parallel: Forwarded to every layer.
        """
        super().__init__()

        self.config = config

        depth = (config.num_hidden_layers
                 if num_hidden_layers_override is None else
                 num_hidden_layers_override)

        blocks = [
            InternVisionEncoderLayer(config,
                                     quant_config,
                                     num_dummy_heads=num_dummy_heads,
                                     prefix=f"{prefix}.layers.{idx}",
                                     use_data_parallel=use_data_parallel)
            for idx in range(depth)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, inputs_embeds: torch.Tensor):
        """Feed ``inputs_embeds`` through every layer and return the result."""
        out = inputs_embeds
        for block in self.layers:
            out = block(out)
        return out


class InternVisionModel(nn.Module):
    """InternViT vision tower: patch/position embeddings plus encoder stack."""

    packed_modules_mapping = {
        "qkv": ["qkv"],
    }

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()

        self.config = config
        self.use_data_parallel = use_data_parallel

        self.embeddings = InternVisionEmbeddings(config)
        self.encoder = InternVisionEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            num_dummy_heads=num_dummy_heads,
            prefix=f"{prefix}.encoder",
            use_data_parallel=use_data_parallel,
        )

    def get_input_embeddings(self):
        return self.embeddings

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_embeds: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Encode images (or precomputed embeddings) with the vision tower.

        ``pixel_embeds`` takes priority over ``pixel_values`` when both are
        given.

        Raises:
            ValueError: If neither input is given, or if ``pixel_values`` is
                not 4-D.
        """
        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        elif pixel_values is None:
            raise ValueError(
                'You have to specify pixel_values or pixel_embeds')
        elif pixel_values.ndim != 4:
            raise ValueError(
                f'wrong pixel_values size: {pixel_values.shape}')
        else:
            hidden_states = self.embeddings(pixel_values)

        if self.use_data_parallel:
            # Delegates to run_dp_sharded_vision_model. Presumably it shards
            # the batch across ranks and gathers the outputs; verify against
            # vllm.multimodal.utils.
            return run_dp_sharded_vision_model(hidden_states, self.encoder)
        return self.encoder(inputs_embeds=hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into matching parameters.

        Each parameter's own ``weight_loader`` is used when present, falling
        back to ``default_weight_loader``. An unknown name raises
        ``KeyError``.

        Returns:
            The set of parameter names that were loaded.
        """
        named_params = dict(self.named_parameters())
        seen: set[str] = set()
        for param_name, tensor in weights:
            target = named_params[param_name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, tensor)
            seen.add(param_name)
        return seen