# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Iterable
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

# Lookup from a normalization-type name to the layer class that implements
# it; indexed with `config.norm_type` when encoder layers build their norms.
NORM2FN = dict(
    rms_norm=RMSNorm,
    layer_norm=nn.LayerNorm,
)


class InternVisionEmbeddings(nn.Module):
    """Patch, class-token and position embeddings for the vision encoder.

    An image batch is cut into non-overlapping patches by a strided
    ``Conv2d``, a learned class token is prepended, and learned position
    embeddings are added. The position table is sized for the square grid
    ``(image_size // patch_size) ** 2``; for any other patch grid the patch
    part of the table is resized with bicubic interpolation.
    """

    def __init__(self, config: PretrainedConfig):
        """Create the learnable embedding parameters.

        Args:
            config: Provides ``hidden_size``, ``image_size`` and
                ``patch_size``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        # kernel_size == stride, so patches do not overlap.
        self.patch_embedding = nn.Conv2d(in_channels=3,
                                         out_channels=self.embed_dim,
                                         kernel_size=self.patch_size,
                                         stride=self.patch_size)

        self.num_patches = (self.image_size // self.patch_size)**2
        # One extra position for the class token.
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_positions, self.embed_dim))

    def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int):
        """Resize patch position embeddings to an ``H x W`` patch grid.

        Args:
            pos_embed: Patch position embeddings without the class-token
                slot, reshaped here to the native square grid.
            H: Target number of patch rows.
            W: Target number of patch columns.

        Returns:
            Tensor of shape ``(1, H * W, embed_dim)`` in the dtype of the
            input.
        """
        target_dtype = pos_embed.dtype
        grid_size = self.image_size // self.patch_size
        # Interpolate in float32 on a channels-first (1, C, grid, grid) map.
        pos_embed = pos_embed.float().reshape(1, grid_size, grid_size,
                                              -1).permute(0, 3, 1, 2)
        pos_embed = F.interpolate(pos_embed,
                                  size=(H, W),
                                  mode='bicubic',
                                  align_corners=False)
        # Back to (1, H * W, C) and the caller's dtype.
        return pos_embed.reshape(1, -1, H * W).permute(0, 2,
                                                       1).to(target_dtype)

    def _get_position_embedding(self, H: int, W: int) -> torch.Tensor:
        """Return position embeddings matching an ``H x W`` patch grid.

        The stored table is returned untouched only when the grid is exactly
        the native square grid. Comparing ``H * W`` with ``num_patches`` is
        not enough: a non-square grid with the same patch count would reuse
        embeddings laid out for a different geometry.

        Args:
            H: Number of patch rows.
            W: Number of patch columns.

        Returns:
            Tensor of shape ``(1, 1 + H * W, embed_dim)``; index 0 is the
            class-token position, which is never interpolated.
        """
        position_embedding = self.position_embedding
        grid_size = self.image_size // self.patch_size
        if H == grid_size and W == grid_size:
            return position_embedding

        return torch.cat(
            [
                position_embedding[:, :1, :],
                self._get_pos_embed(position_embedding[:, 1:, :], H, W),
            ],
            dim=1,
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            pixel_values: Image batch consumed by a ``Conv2d`` with three
                input channels.

        Returns:
            Tensor of shape ``(batch, 1 + num_patches_in_input, embed_dim)``.
        """
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            target_dtype))  # shape = [batch, embed_dim, grid_h, grid_w]
        batch_size, _, height, width = patch_embeds.shape
        # Flatten the patch grid into a sequence: [batch, patches, embed_dim]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        class_embeds = self.class_embedding.expand(batch_size, 1,
                                                   -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        position_embedding = self._get_position_embedding(height, width)
        embeddings = embeddings + position_embedding.to(target_dtype)
        return embeddings


99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
class InternVisionPatchModel(nn.Module):
    """Vision front-end that only produces embeddings (no encoder layers)."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embeddings = InternVisionEmbeddings(config)

    def get_input_embeddings(self):
        """Return the embedding sub-module."""
        return self.embeddings

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_embeds: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Return embeddings for the given input.

        Args:
            pixel_values: 4-D image tensor to embed. Only consulted when
                ``pixel_embeds`` is not supplied.
            pixel_embeds: Already-computed embeddings; returned unchanged
                when supplied.

        Raises:
            ValueError: If neither argument is given, or ``pixel_values`` is
                not 4-D.
        """
        # Pre-computed embeddings win over raw pixels.
        if pixel_embeds is not None:
            return pixel_embeds

        if pixel_values is None:
            raise ValueError(
                'You have to specify pixel_values or pixel_embeds')

        if pixel_values.ndim != 4:
            raise ValueError(
                f'wrong pixel_values size: {pixel_values.shape}')

        return self.embeddings(pixel_values)


130
class InternParallelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Tensor-parallel variant: the QKV projection and the output projection
    are built from vLLM's parallel linear layers and the heads (real plus
    dummy) are divided evenly across the tensor-parallel group.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        """Build the projections, optional Q/K norms and attention op.

        Args:
            config: Provides ``hidden_size``, ``num_attention_heads``,
                ``qkv_bias``, ``qk_normalization`` and ``layer_norm_eps``.
            quant_config: Forwarded to the parallel linear layers.
            num_dummy_heads: Extra heads added on top of
                ``num_attention_heads``.
            prefix: Name prefix forwarded to the linear layers.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads '
                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).')

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        # Additional dummy heads are used to enable TP for common GPU counts.
        self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
        self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads,
                                              self.tp_size)

        self.scale = self.head_dim**-0.5
        self.qkv = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            num_dummy_heads + self.num_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
        )

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            # The norms are sized for the full (un-sharded) head dimension.
            self.q_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)
            self.k_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)

        self.proj = RowParallelLinear(
            self.dummy_dim,
            self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
        )

        self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                       self.head_dim, self.scale)

    def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
        """Apply ``q_norm`` / ``k_norm`` and return the normalized pair.

        With tensor parallelism the tensors are first all-gathered so the
        norms (sized for ``dummy_dim``) see the full width, then split along
        the last dimension again and this rank's partition is kept.
        """
        if self.tp_size > 1:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q = self.q_norm(q)
        k = self.k_norm(k)
        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
        return q, k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run self-attention over ``x`` and project the result.

        Args:
            x: Input hidden states; assumed to be
                ``(batch, seq_len, hidden_size)`` — TODO confirm with callers.

        Returns:
            Output of the row-parallel projection.
        """
        # The parallel linear layers return a tuple; only the first element
        # (the output tensor) is used.
        qkv, _ = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        if self.qk_normalization:
            q, k = self._apply_qk_norm(q, k)

        out = self.attn(q, k, v)
        out, _ = self.proj(out)
        return out
215
216


217
218
219
class InternSdpaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Non-tensor-parallel variant built from plain ``nn.Linear`` layers. Dummy
    heads are counted everywhere a head count is needed, so the QKV
    projection, the attention op and the output projection agree on the
    width ``(num_heads + num_dummy_heads) * head_dim``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        *,
        num_dummy_heads: int = 0,
    ) -> None:
        """Build the projections, optional Q/K norms and attention op.

        Args:
            config: Provides ``hidden_size``, ``num_attention_heads``,
                ``qkv_bias``, ``qk_normalization`` and ``layer_norm_eps``.
            num_dummy_heads: Extra heads added on top of
                ``num_attention_heads``.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by
                ``num_attention_heads``.
        """
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads '
                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).')

        # Additional dummy heads are used to enable TP for common GPU counts.
        total_heads = num_dummy_heads + self.num_heads
        self.dummy_dim = total_heads * self.head_dim

        self.scale = self.head_dim**-0.5
        self.qkv = nn.Linear(self.embed_dim,
                             3 * self.dummy_dim,
                             bias=config.qkv_bias)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            self.q_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)
            self.k_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)

        self.proj = nn.Linear(self.dummy_dim, self.embed_dim)

        # Use unified MultiHeadAttention with automatic backend selection.
        # The head count includes dummy heads so the attention output has
        # width `dummy_dim`, which is what `proj` expects as input.
        self.attn = MultiHeadAttention(total_heads, self.head_dim, self.scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run self-attention over ``x`` and project the result.

        Args:
            x: Hidden states of shape ``(batch, seq_len, hidden_size)``.

        Returns:
            Output of the final linear projection.
        """
        B, N, _ = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        # Infer the head count from the projection width: it is
        # `num_heads + num_dummy_heads`, not `num_heads`, whenever dummy
        # heads are configured.
        q = q.view(B, N, -1, self.head_dim)
        k = k.view(B, N, -1, self.head_dim)
        v = v.view(B, N, -1, self.head_dim)

        if self.qk_normalization:
            head_shape = q.shape
            # The norms act on the flattened (heads * head_dim) width.
            q = self.q_norm(q.flatten(-2, -1)).view(head_shape)
            k = self.k_norm(k.flatten(-2, -1)).view(head_shape)

        # Use unified MultiHeadAttention with automatic backend selection
        x = self.attn(q, k, v)

        x = self.proj(x)
        return x


283
284
class InternMLP(nn.Module):
    """Two-layer feed-forward block: ``fc1`` -> activation -> ``fc2``.

    ``fc1`` is column-parallel and ``fc2`` is row-parallel.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the two projections and look up the activation.

        Args:
            config: Provides ``hidden_act``, ``hidden_size`` and
                ``intermediate_size``.
            quant_config: Forwarded to both linear layers.
            prefix: Name prefix forwarded to the linear layers.
        """
        super().__init__()

        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1")
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1``, the activation and ``fc2`` in sequence."""
        # The parallel linear layers return a tuple; only the first element
        # (the output tensor) is used.
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)

        return hidden_states


class InternVisionEncoderLayer(nn.Module):
    """One encoder layer: pre-norm attention and pre-norm MLP residuals.

    Each branch output is multiplied by a learned per-channel scale
    (``ls1`` / ``ls2``) before being added back to the residual stream.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        """Build attention, MLP, norms and branch scales.

        Args:
            config: Provides ``hidden_size``, ``intermediate_size``,
                ``norm_type``, ``layer_norm_eps`` and ``initializer_factor``
                (plus whatever the attention and MLP modules read).
            quant_config: Forwarded to the attention and MLP modules.
            num_dummy_heads: Forwarded to the attention module.
            prefix: Name prefix for this layer's sub-modules.
        """
        super().__init__()

        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type

        self.attn = self._init_attn(config,
                                    quant_config,
                                    num_dummy_heads=num_dummy_heads,
                                    prefix=f"{prefix}.attn")

        self.mlp = InternMLP(config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.mlp")
        # `norm_type` selects the norm class through the NORM2FN table.
        self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
                                             eps=config.layer_norm_eps)
        self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
                                             eps=config.layer_norm_eps)

        # Per-channel branch scales, initialised to `initializer_factor`.
        self.ls1 = nn.Parameter(config.initializer_factor *
                                torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor *
                                torch.ones(self.embed_dim))

    def _init_attn(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        num_dummy_heads: int,
        prefix: str = "",
    ):
        """Pick the attention implementation for this layer.

        The tensor-parallel implementation is used when the total head count
        (real plus dummy) divides evenly by the tensor-parallel world size;
        otherwise the non-parallel implementation is returned.
        """
        # fallback to sdpa attention if tp unavailable
        tp_size = get_tensor_model_parallel_world_size()
        num_heads = config.num_attention_heads

        if (num_heads + num_dummy_heads) % tp_size == 0:
            return InternParallelAttention(config,
                                           quant_config=quant_config,
                                           num_dummy_heads=num_dummy_heads,
                                           prefix=prefix)

        return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Apply the attention and MLP residual branches in order."""
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states)) * self.ls1

        hidden_states = hidden_states + self.mlp(
            self.norm2(hidden_states)) * self.ls2

        return hidden_states


class InternVisionEncoder(nn.Module):
    """Stack of ``InternVisionEncoderLayer`` modules applied in sequence."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ):
        """Build the layer stack.

        Args:
            config: Provides ``num_hidden_layers`` and everything each layer
                reads.
            quant_config: Forwarded to every layer.
            num_hidden_layers_override: When not ``None``, the number of
                layers to build instead of ``config.num_hidden_layers``.
            num_dummy_heads: Forwarded to every layer.
            prefix: Name prefix; layer ``i`` is named
                ``{prefix}.layers.{i}``.
        """
        super().__init__()

        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            InternVisionEncoderLayer(config,
                                     quant_config,
                                     num_dummy_heads=num_dummy_heads,
                                     prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(self, inputs_embeds: torch.Tensor):
        """Pass ``inputs_embeds`` through every layer and return the result."""
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        return hidden_states


class InternVisionModel(nn.Module):
    """Vision model: embeddings followed by the encoder layer stack."""

    packed_modules_mapping = {
        "qkv": ["qkv"],
    }

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        """Build the embedding module and the encoder.

        Args:
            config: Model configuration shared by all sub-modules.
            quant_config: Forwarded to the encoder.
            num_hidden_layers_override: Forwarded to the encoder.
            num_dummy_heads: Forwarded to the encoder.
            prefix: Name prefix; the encoder is named ``{prefix}.encoder``.
        """
        super().__init__()

        self.config = config

        self.embeddings = InternVisionEmbeddings(config)
        self.encoder = InternVisionEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            num_dummy_heads=num_dummy_heads,
            prefix=f"{prefix}.encoder",
        )

    def get_input_embeddings(self):
        """Return the embedding sub-module."""
        return self.embeddings

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_embeds: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Encode images or pre-computed embeddings.

        Args:
            pixel_values: 4-D image tensor to embed. Only consulted when
                ``pixel_embeds`` is not supplied.
            pixel_embeds: Already-computed embeddings; passed straight to
                the encoder when supplied.

        Returns:
            The encoder output.

        Raises:
            ValueError: If neither argument is given, or ``pixel_values`` is
                not 4-D.
        """
        if pixel_values is None and pixel_embeds is None:
            raise ValueError(
                'You have to specify pixel_values or pixel_embeds')

        if pixel_embeds is not None:
            # Pre-computed embeddings win over raw pixels.
            hidden_states = pixel_embeds
        elif pixel_values.ndim == 4:
            hidden_states = self.embeddings(pixel_values)
        else:
            raise ValueError(
                f'wrong pixel_values size: {pixel_values.shape}')

        encoder_outputs = self.encoder(inputs_embeds=hidden_states)

        return encoder_outputs

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Each parameter's own ``weight_loader`` attribute is used when it has
        one; otherwise ``default_weight_loader`` is used.

        Args:
            weights: ``(parameter_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.

        Raises:
            KeyError: If a name is not one of this module's parameters.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params