# SPDX-License-Identifier: Apache-2.0

# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Iterable
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

# Lookup table from the ``norm_type`` string of the vision config to the
# normalization layer class that the encoder layers instantiate.
NORM2FN = {
    'rms_norm': RMSNorm,
    'layer_norm': nn.LayerNorm,
}


class InternVisionEmbeddings(nn.Module):
    """Turn pixel values into a token sequence for the vision encoder.

    The image is projected to patch tokens by a strided ``nn.Conv2d``, a
    learned class token is prepended, and a learned position embedding is
    added. When the patch grid of the input differs from the grid implied by
    ``config.image_size`` / ``config.patch_size``, the patch part of the
    position embedding is resampled bicubically to the input grid.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        # Non-overlapping patches: kernel size and stride are both the
        # patch size.
        self.patch_embedding = nn.Conv2d(in_channels=3,
                                         out_channels=self.embed_dim,
                                         kernel_size=self.patch_size,
                                         stride=self.patch_size)

        self.num_patches = (self.image_size // self.patch_size)**2
        # One extra position for the class token.
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_positions, self.embed_dim))

    def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int):
        """Resample patch position embeddings to an ``H x W`` grid.

        Args:
            pos_embed: Patch position embeddings (class-token entry
                excluded), one row per cell of the square configured grid.
            H: Target grid height in patches.
            W: Target grid width in patches.

        Returns:
            Tensor of shape ``(1, H * W, embed_dim)`` in the dtype of
            ``pos_embed``.
        """
        target_dtype = pos_embed.dtype
        side = self.image_size // self.patch_size
        # Interpolation runs in float32 on a channels-first 2D layout.
        grid = pos_embed.float().reshape(1, side, side, -1)
        grid = grid.permute(0, 3, 1, 2)
        grid = F.interpolate(grid,
                             size=(H, W),
                             mode='bicubic',
                             align_corners=False)
        # Back to (1, tokens, channels) and to the original dtype.
        return grid.reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)

    def _get_position_embedding(self, H: int, W: int) -> torch.Tensor:
        """Return the position embedding matching an ``H x W`` patch grid.

        The stored embedding is returned unchanged only when the grid equals
        the configured square grid. Comparing just ``H * W`` with
        ``num_patches`` would also accept a non-square grid with the same
        patch count and assign positions from the wrong layout, so both
        dimensions are checked.
        """
        position_embedding = self.position_embedding
        side = self.image_size // self.patch_size
        if H == side and W == side:
            return position_embedding

        class_pos = position_embedding[:, :1, :]
        patch_pos = self._get_pos_embed(position_embedding[:, 1:, :], H, W)
        return torch.cat([class_pos, patch_pos], dim=1)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed ``pixel_values`` into ``(batch, 1 + H * W, embed_dim)``.

        ``H`` and ``W`` are the height and width of the patch grid produced
        by the convolution.
        """
        target_dtype = self.patch_embedding.weight.dtype
        # Conv2d output: [batch, embed_dim, grid_height, grid_width]
        patch_embeds = self.patch_embedding(pixel_values.to(target_dtype))
        batch_size, _, height, width = patch_embeds.shape
        # -> [batch, grid_height * grid_width, embed_dim]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        class_embeds = self.class_embedding.expand(batch_size, 1,
                                                   -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        position_embedding = self._get_position_embedding(height, width)
        embeddings = embeddings + position_embedding.to(target_dtype)
        return embeddings


98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
class InternVisionPatchModel(nn.Module):
    """Embedding-only vision model: runs the patch embeddings, no encoder."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embeddings = InternVisionEmbeddings(config)

    def get_input_embeddings(self):
        """Return the embedding module."""
        return self.embeddings

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_embeds: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Return patch embeddings for the given input.

        Args:
            pixel_values: 4-D image tensor to embed.
            pixel_embeds: Already-computed embeddings; returned unchanged
                and given priority over ``pixel_values``.

        Raises:
            ValueError: If both arguments are ``None``, or ``pixel_values``
                is used and is not 4-dimensional.
        """
        # Precomputed embeddings win over raw pixels.
        if pixel_embeds is not None:
            return pixel_embeds

        if pixel_values is None:
            raise ValueError(
                'You have to specify pixel_values or pixel_embeds')

        if pixel_values.ndim != 4:
            raise ValueError(
                f'wrong pixel_values size: {pixel_values.shape}')

        return self.embeddings(pixel_values)


129
class InternParallelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Tensor-parallel variant: the fused QKV projection and the output
    projection are vLLM parallel linear layers, and each rank attends over
    ``num_heads_per_partition`` heads.

    Args:
        config: Vision config; reads ``hidden_size``,
            ``num_attention_heads``, ``qkv_bias``, ``qk_normalization`` and
            ``layer_norm_eps``.
        quant_config: Optional quantization config forwarded to the linear
            layers.
        num_dummy_heads: Extra heads added on top of
            ``config.num_attention_heads`` so that the total head count is
            divisible by the tensor-parallel world size.
        prefix: Parameter-name prefix forwarded to the linear layers.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by
            ``num_attention_heads``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads '
                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).')

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        # Additional dummy heads are used to enable TP for common GPU counts.
        total_heads = num_dummy_heads + self.num_heads
        self.dummy_dim = total_heads * self.head_dim
        self.num_heads_per_partition = divide(total_heads, self.tp_size)

        self.scale = self.head_dim**-0.5
        self.qkv = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            total_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
        )

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            # The norms are sized for the full (all-rank) q/k width.
            self.q_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)
            self.k_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)

        self.proj = RowParallelLinear(
            self.dummy_dim,
            self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
        )

        self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                       self.head_dim, self.scale)

    def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
        """Apply ``q_norm`` / ``k_norm`` and return the normalized pair.

        With more than one tensor-parallel rank, q and k are first gathered
        across ranks (the norms are built for the full ``dummy_dim`` width),
        normalized, and then split along the last dimension again so that
        only this rank's partition is returned.
        """
        if self.tp_size > 1:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q = self.q_norm(q)
        k = self.k_norm(k)
        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
        return q, k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run self-attention over ``x`` and project the result back."""
        # The parallel linear layers return a tuple; only the first element
        # (the output tensor) is used here.
        qkv, _ = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        if self.qk_normalization:
            q, k = self._apply_qk_norm(q, k)

        out = self.attn(q, k, v)
        out, _ = self.proj(out)
        return out
214
215


216
217
218
class InternSdpaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

219
220
221
222
223
224
    def __init__(
        self,
        config: PretrainedConfig,
        *,
        num_dummy_heads: int = 0,
    ) -> None:
225
        super().__init__()
226

227
228
229
230
231
232
233
234
235
236
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads '
                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).')

237
238
239
        # Additional dummy heads are used to enable TP for common GPU counts.
        self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim

240
241
        self.scale = self.head_dim**-0.5
        self.qkv = nn.Linear(self.embed_dim,
242
                             3 * self.dummy_dim,
243
244
245
246
247
                             bias=config.qkv_bias)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
248
249
250
251
252
253
            self.q_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)
            self.k_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)
254

255
        self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
256

257
    def forward(self, x: torch.Tensor) -> torch.Tensor:
258
259
260
261
262
263
264
265
266
267
        B, N, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, N, self.num_heads, self.head_dim)
        k = k.view(B, N, self.num_heads, self.head_dim)
        v = v.view(B, N, self.num_heads, self.head_dim)

        if self.qk_normalization:
            B_, N_, H_, D_ = q.shape
268
269
            q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
            k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
270
271
272
273
274
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
275
        x = x.transpose(1, 2).reshape(B, N, -1)
276
277
278
279
280

        x = self.proj(x)
        return x


281
282
class InternMLP(nn.Module):
    """Two-layer feed-forward block of the vision encoder.

    ``fc1`` (column-parallel) expands ``hidden_size`` to
    ``intermediate_size``, the activation named by ``config.hidden_act`` is
    applied, and ``fc2`` (row-parallel) projects back to ``hidden_size``.
    Both layers use a bias.

    Args:
        config: Vision config; reads ``hidden_act``, ``hidden_size`` and
            ``intermediate_size``.
        quant_config: Optional quantization config forwarded to both layers.
        prefix: Parameter-name prefix forwarded to both layers.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1")
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1`` -> activation -> ``fc2`` to ``hidden_states``."""
        # The parallel linear layers return a tuple; only the first element
        # (the output tensor) is used here.
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected


class InternVisionEncoderLayer(nn.Module):
    """One pre-norm transformer block with per-channel residual scaling.

    Computes ``x + attn(norm1(x)) * ls1`` followed by
    ``x + mlp(norm2(x)) * ls2``. The norm class is picked from ``NORM2FN``
    by ``config.norm_type``; ``ls1`` / ``ls2`` are learnable vectors
    initialised to ``config.initializer_factor``.

    Args:
        config: Vision config.
        quant_config: Optional quantization config forwarded to the
            sub-modules that accept one.
        num_dummy_heads: Extra attention heads forwarded to the attention
            module.
        prefix: Parameter-name prefix of this layer.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type

        self.attn = self._init_attn(config,
                                    quant_config,
                                    num_dummy_heads=num_dummy_heads,
                                    prefix=f"{prefix}.attn")

        self.mlp = InternMLP(config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.mlp")

        norm_cls = NORM2FN[self.norm_type]
        self.norm1 = norm_cls(self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = norm_cls(self.embed_dim, eps=config.layer_norm_eps)

        # Learnable per-channel scales applied to each residual branch.
        self.ls1 = nn.Parameter(config.initializer_factor *
                                torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor *
                                torch.ones(self.embed_dim))

    def _init_attn(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        num_dummy_heads: int,
        prefix: str = "",
    ):
        """Build the attention module for this layer.

        ``InternParallelAttention`` is used when the total head count
        (real + dummy) is divisible by the tensor-parallel world size.
        Otherwise ``InternSdpaAttention`` is returned, which receives
        neither ``quant_config`` nor ``prefix``.
        """
        # fallback to sdpa attention if tp unavailable
        tp_size = get_tensor_model_parallel_world_size()
        total_heads = config.num_attention_heads + num_dummy_heads

        if total_heads % tp_size != 0:
            return InternSdpaAttention(config,
                                       num_dummy_heads=num_dummy_heads)

        return InternParallelAttention(config,
                                       quant_config=quant_config,
                                       num_dummy_heads=num_dummy_heads,
                                       prefix=prefix)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Apply the attention and MLP residual branches in order."""
        attn_out = self.attn(self.norm1(hidden_states))
        hidden_states = hidden_states + attn_out * self.ls1

        mlp_out = self.mlp(self.norm2(hidden_states))
        hidden_states = hidden_states + mlp_out * self.ls2

        return hidden_states


class InternVisionEncoder(nn.Module):
    """Stack of ``InternVisionEncoderLayer`` blocks applied in sequence.

    Args:
        config: Vision config; ``num_hidden_layers`` sets the depth unless
            overridden.
        quant_config: Optional quantization config forwarded to each layer.
        num_hidden_layers_override: When not ``None``, the number of layers
            to build instead of ``config.num_hidden_layers``.
        num_dummy_heads: Extra attention heads forwarded to each layer.
        prefix: Parameter-name prefix; layer ``i`` gets
            ``"{prefix}.layers.{i}"``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        num_hidden_layers = (config.num_hidden_layers
                             if num_hidden_layers_override is None else
                             num_hidden_layers_override)

        self.layers = nn.ModuleList([
            InternVisionEncoderLayer(config,
                                     quant_config,
                                     num_dummy_heads=num_dummy_heads,
                                     prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(self, inputs_embeds: torch.Tensor):
        """Pass ``inputs_embeds`` through every layer and return the result."""
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        return hidden_states


class InternVisionModel(nn.Module):
    """InternViT vision tower: patch embeddings followed by the encoder.

    Args:
        config: Vision config.
        quant_config: Optional quantization config forwarded to the encoder.
        num_hidden_layers_override: When not ``None``, the number of encoder
            layers to build instead of ``config.num_hidden_layers``.
        num_dummy_heads: Extra attention heads forwarded to the encoder.
        prefix: Parameter-name prefix; the encoder gets
            ``"{prefix}.encoder"``.
    """

    # NOTE(review): presumably read by vLLM's packed-module handling
    # (e.g. quantization / LoRA) -- confirm against the consumers.
    packed_modules_mapping = {
        "qkv": ["qkv"],
    }

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        self.embeddings = InternVisionEmbeddings(config)
        self.encoder = InternVisionEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            num_dummy_heads=num_dummy_heads,
            prefix=f"{prefix}.encoder",
        )

    def get_input_embeddings(self):
        """Return the embedding module."""
        return self.embeddings

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_embeds: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Encode an image, or precomputed embeddings, with the encoder.

        Args:
            pixel_values: 4-D image tensor; embedded before encoding.
            pixel_embeds: Already-computed embeddings; take priority over
                ``pixel_values`` and skip the embedding step.

        Raises:
            ValueError: If both arguments are ``None``, or ``pixel_values``
                is used and is not 4-dimensional.
        """
        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        elif pixel_values is None:
            raise ValueError(
                'You have to specify pixel_values or pixel_embeds')
        elif pixel_values.ndim != 4:
            raise ValueError(
                f'wrong pixel_values size: {pixel_values.shape}')
        else:
            hidden_states = self.embeddings(pixel_values)

        return self.encoder(inputs_embeds=hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` pairs into the parameters of that name.

        A parameter's own ``weight_loader`` attribute is used when it has
        one, otherwise ``default_weight_loader``. Names are looked up
        verbatim in ``named_parameters()``, so an unknown name raises
        ``KeyError``.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params