# File: intern_vit.py (15 KB)
# NOTE: pasted from a web "blame" view ("Newer Older"); stray line numbers remain.
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
10
from collections.abc import Iterable
11
from functools import partial
12
13
14
15
16
17

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

18
19
20
21
22
23
24
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
)
25
from vllm.model_executor.layers.activation import get_act_fn
26
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
27
from vllm.model_executor.layers.conv import Conv2dLayer
28
from vllm.model_executor.layers.layernorm import RMSNorm
29
30
31
32
33
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
34
from vllm.model_executor.layers.quantization import QuantizationConfig
35
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
# Blame annotation (web view): the following line was committed by zhuwenwen.
36
import vllm.envs as envs
37

38

39
from .vision import run_dp_sharded_vision_model
40
41

# Lookup from ``config.norm_type`` to the normalization layer class.
# Callers in this file instantiate the chosen class as ``cls(dim, eps=...)``.
NORM2FN = dict(
    rms_norm=RMSNorm,
    layer_norm=nn.LayerNorm,
)


class InternVisionEmbeddings(nn.Module):
    """Patch + class-token embeddings with learned absolute positions.

    The learned position table holds one class-token slot followed by
    ``(image_size // patch_size) ** 2`` patch slots laid out on a square grid.
    When the runtime patch grid differs from that square grid, the patch part
    of the table is bicubically resized to match.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        self.patch_embedding = Conv2dLayer(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_positions, self.embed_dim)
        )

    def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int):
        """Resize the patch position table to an ``H x W`` patch grid.

        Args:
            pos_embed: patch-token position embeddings (class slot excluded),
                laid out as a square grid of side ``image_size // patch_size``.
            H: target patch-grid height.
            W: target patch-grid width.

        Returns:
            Tensor of shape ``(1, H * W, dim)`` in ``pos_embed``'s original dtype.
        """
        target_dtype = pos_embed.dtype
        grid = self.image_size // self.patch_size
        # (1, grid*grid, C) -> (1, C, grid, grid): the embedding dim becomes the
        # channel axis so F.interpolate resizes only the spatial grid.
        # Interpolation runs on a float32 copy; the result is cast back below.
        pos_embed = pos_embed.float().reshape(1, grid, grid, -1).permute(0, 3, 1, 2)
        pos_embed = F.interpolate(
            pos_embed, size=(H, W), mode="bicubic", align_corners=False
        )
        # (1, C, H, W) -> (1, H*W, C)
        return pos_embed.reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)

    def _get_position_embedding(self, H: int, W: int) -> torch.Tensor:
        """Return ``[class_slot, patch_slots...]`` position embeddings for ``H x W``."""
        position_embedding = self.position_embedding
        grid = self.image_size // self.patch_size
        # The stored table is only usable as-is for the exact square grid it
        # was built for. Comparing H * W against num_patches alone would
        # wrongly skip the resize for non-square grids of equal area.
        if H == grid and W == grid:
            return position_embedding

        return torch.cat(
            [
                position_embedding[:, :1, :],
                self._get_pos_embed(position_embedding[:, 1:, :], H, W),
            ],
            dim=1,
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed images into ``[class_token, patch_tokens...]`` sequences.

        Args:
            pixel_values: image batch fed to the patch convolution; it must
                have 3 input channels (presumably ``(batch, 3, height, width)``).

        Returns:
            ``(batch, 1 + h * w, embed_dim)`` embeddings, where ``h`` and ``w``
            are the patch-grid dimensions produced by the convolution.
        """
        target_dtype = self.patch_embedding.weight.dtype
        # shape = [batch, embed_dim, grid_height, grid_width]
        patch_embeds = self.patch_embedding(pixel_values.to(target_dtype))
        batch_size, _, height, width = patch_embeds.shape
        # -> [batch, grid_height * grid_width, embed_dim]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        position_embedding = self._get_position_embedding(height, width)
        embeddings = embeddings + position_embedding.to(target_dtype)
        return embeddings


115
116
117
118
119
120
121
122
123
124
125
class InternVisionPatchModel(nn.Module):
    """Embedding-only InternViT front end.

    Holds just the patch/class/position embeddings (no encoder stack). It
    either embeds raw pixels or passes precomputed embeddings through.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embeddings = InternVisionEmbeddings(config)

    def get_input_embeddings(self):
        """Return the embedding submodule."""
        return self.embeddings

    def forward(
        self,
        pixel_values: torch.Tensor | None = None,
        pixel_embeds: torch.Tensor | None = None,
    ) -> torch.FloatTensor:
        """Produce embeddings from pixels, or return given embeddings as-is.

        Args:
            pixel_values: 4-D image batch to embed.
            pixel_embeds: precomputed embeddings; used in preference to
                ``pixel_values`` when both are given.

        Raises:
            ValueError: if neither input is given, or ``pixel_values`` is not 4-D.
        """
        # Precomputed embeddings take priority over raw pixels.
        if pixel_embeds is not None:
            return pixel_embeds
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values or pixel_embeds")
        if pixel_values.ndim != 4:
            raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
        return self.embeddings(pixel_values)


143
class InternParallelAttention(nn.Module):
    """Tensor-parallel multi-head self-attention for the InternViT encoder.

    The head count may be padded with ``num_dummy_heads`` so that it divides
    the tensor-parallel world size. With ``use_data_parallel=True``, TP is
    disabled and every rank holds the full weights.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_dummy_heads: int = 0,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.num_heads * self.head_dim != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        if use_data_parallel:
            # Single logical shard: no TP communication.
            self.tp_size = 1
            self.tp_rank = 0
        else:
            self.tp_size = get_tensor_model_parallel_world_size()
            self.tp_rank = get_tensor_model_parallel_rank()

        # Dummy heads pad the head count so TP works for common GPU counts.
        padded_heads = self.num_heads + num_dummy_heads
        self.dummy_dim = padded_heads * self.head_dim
        self.num_heads_per_partition = divide(padded_heads, self.tp_size)

        self.scale = self.head_dim**-0.5

        self.qkv = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            padded_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
            disable_tp=use_data_parallel,
        )

        self.qk_normalization = config.qk_normalization
        if self.qk_normalization:
            # Normalized width is the padded dim. var_hidden_size=embed_dim
            # presumably restricts the variance to the real (non-dummy)
            # channels; confirm against RMSNorm.
            norm_kwargs = dict(eps=config.layer_norm_eps, var_hidden_size=self.embed_dim)
            self.q_norm = RMSNorm(self.dummy_dim, **norm_kwargs)
            self.k_norm = RMSNorm(self.dummy_dim, **norm_kwargs)

        self.proj = RowParallelLinear(
            self.dummy_dim,
            self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
            disable_tp=use_data_parallel,
        )

        self.attn = MMEncoderAttention(
            self.num_heads_per_partition, self.head_dim, self.scale
        )

    def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
        """Apply q/k RMSNorm across all ranks' heads, then re-shard.

        Under TP each rank only holds a slice of the heads. The shards are
        all-gathered, normalized, and then split along the last dim, keeping
        this rank's piece.
        """
        is_sharded = self.tp_size > 1
        if is_sharded:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q = self.q_norm(q)
        k = self.k_norm(k)
        if is_sharded:
            q = split_tensor_along_last_dim(q, num_partitions=self.tp_size)[
                self.tp_rank
            ]
            k = split_tensor_along_last_dim(k, num_partitions=self.tp_size)[
                self.tp_rank
            ]
        return q, k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run self-attention on ``x`` and return the projected output.

        Args:
            x: hidden states; must be 3-D (presumably ``(batch, seq, embed_dim)``).
        """
        # The unpack enforces a 3-D input; the sizes themselves are not used.
        _batch, _seq_len, _ = x.shape
        qkv, _ = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        if self.qk_normalization:
            q, k = self._apply_qk_norm(q, k)

        projected, _ = self.proj(self.attn(q, k, v))
        return projected
239
240
241


class InternMLP(nn.Module):
    """Feed-forward block: fc1 -> activation -> fc2.

    ``fc1`` is column-parallel and ``fc2`` row-parallel. With
    ``use_data_parallel=True``, both are built with TP disabled.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()

        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # Options shared by both projections.
        shared = dict(
            bias=True,
            quant_config=quant_config,
            disable_tp=use_data_parallel,
        )
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            prefix=f"{prefix}.fc1",
            **shared,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            prefix=f"{prefix}.fc2",
            **shared,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the two projections with the activation in between."""
        expanded, _ = self.fc1(hidden_states)
        reduced, _ = self.fc2(self.activation_fn(expanded))
        return reduced


class InternVisionEncoderLayer(nn.Module):
    """One pre-norm transformer layer with learnable per-channel residual scales.

    Computes ``x + attn(norm1(x)) * ls1`` followed by
    ``x + mlp(norm2(x)) * ls2``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_dummy_heads: int = 0,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_cls: type[InternParallelAttention] = InternParallelAttention,
    ) -> None:
        super().__init__()

        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type
        self.attn_cls = attn_cls

        self.attn = self._init_attn(
            config,
            quant_config,
            num_dummy_heads=num_dummy_heads,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel,
        )

        self.mlp = InternMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            use_data_parallel=use_data_parallel,
        )

        norm_cls = NORM2FN[self.norm_type]
        self.norm1 = norm_cls(self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = norm_cls(self.embed_dim, eps=config.layer_norm_eps)

        # Layer-scale vectors, initialised to initializer_factor everywhere.
        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))

    def _init_attn(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        *,
        num_dummy_heads: int,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        """Build the attention module via ``self.attn_cls``.

        Attention TP is turned off (the module gets ``use_data_parallel=True``)
        when the padded head count does not divide the TP world size.
        """
        tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
        padded_heads = config.num_attention_heads + num_dummy_heads
        splits_evenly = padded_heads % tp_size == 0
        return self.attn_cls(
            config,
            quant_config=quant_config,
            num_dummy_heads=num_dummy_heads,
            prefix=prefix,
            use_data_parallel=use_data_parallel or not splits_evenly,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Apply the attention and MLP residual sub-blocks in order."""
        attn_out = self.attn(self.norm1(hidden_states))
        hidden_states = hidden_states + attn_out * self.ls1
        mlp_out = self.mlp(self.norm2(hidden_states))
        hidden_states = hidden_states + mlp_out * self.ls2
        return hidden_states


class InternVisionEncoder(nn.Module):
    """Stack of ``layer_cls`` transformer layers applied sequentially.

    The depth is ``config.num_hidden_layers`` unless
    ``num_hidden_layers_override`` is provided.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
        use_data_parallel: bool = False,
        layer_cls: type[InternVisionEncoderLayer] = InternVisionEncoderLayer,
    ):
        super().__init__()

        self.config = config
        self.layer_cls = layer_cls

        depth = (
            config.num_hidden_layers
            if num_hidden_layers_override is None
            else num_hidden_layers_override
        )

        built_layers = []
        for idx in range(depth):
            built_layers.append(
                self.layer_cls(
                    config,
                    quant_config,
                    num_dummy_heads=num_dummy_heads,
                    prefix=f"{prefix}.layers.{idx}",
                    use_data_parallel=use_data_parallel,
                )
            )
        self.layers = nn.ModuleList(built_layers)

    def forward(self, inputs_embeds: torch.Tensor):
        """Feed ``inputs_embeds`` through every layer and return the result."""
        out = inputs_embeds
        for layer in self.layers:
            out = layer(out)
        return out


class InternVisionModel(nn.Module):
    """InternViT vision tower: embeddings followed by the encoder stack.

    With ``use_data_parallel=True``, the encoder is run through
    ``run_dp_sharded_vision_model`` instead of being called directly.
    """

    # The fused ``qkv`` module is declared as packed from a single source
    # module of the same name.
    packed_modules_mapping = {
        "qkv": ["qkv"],
    }

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()

        self.config = config
        self.use_data_parallel = use_data_parallel

        self.embeddings = InternVisionEmbeddings(config)
        self.encoder = InternVisionEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            num_dummy_heads=num_dummy_heads,
            prefix=f"{prefix}.encoder",
            use_data_parallel=use_data_parallel,
        )

    def get_input_embeddings(self):
        """Return the embedding submodule."""
        return self.embeddings

    def forward(
        self,
        pixel_values: torch.Tensor | None = None,
        pixel_embeds: torch.Tensor | None = None,
    ) -> torch.FloatTensor:
        """Encode images (or precomputed embeddings) with the vision encoder.

        Args:
            pixel_values: 4-D image batch to embed and encode.
            pixel_embeds: precomputed embeddings; used in preference to
                ``pixel_values`` when both are given.

        Raises:
            ValueError: if neither input is given, or ``pixel_values`` is not 4-D.
        """
        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        elif pixel_values is None:
            raise ValueError("You have to specify pixel_values or pixel_embeds")
        elif pixel_values.ndim != 4:
            raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
        else:
            hidden_states = self.embeddings(pixel_values)

        if self.use_data_parallel:
            return run_dp_sharded_vision_model(hidden_states, self.encoder)
        return self.encoder(inputs_embeds=hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into matching parameters by name.

        Each parameter's own ``weight_loader`` is used when present, otherwise
        ``default_weight_loader``. A name with no matching parameter raises
        ``KeyError``.

        Returns:
            The set of parameter names that were loaded.
        """
        by_name = dict(self.named_parameters())
        loaded: set[str] = set()
        for weight_name, tensor in weights:
            target = by_name[weight_name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, tensor)
            loaded.add(weight_name)
        return loaded