# intern_vit.py
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
7
from functools import partial
from typing import Iterable, Optional, Set, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
                              tensor_model_parallel_all_gather)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55

# Lookup from the config's `norm_type` string to the normalization layer
# class that the encoder layers instantiate.
NORM2FN = dict(
    rms_norm=RMSNorm,
    layer_norm=nn.LayerNorm,
)


class InternVisionEmbeddings(nn.Module):
    """Patch, class-token and position embeddings for the vision encoder.

    A strided convolution turns the image into one embedding per patch, a
    learned class token is prepended, and a learned position embedding is
    added. The position embedding is sized for the square grid implied by
    ``config.image_size // config.patch_size`` and is bicubically resized
    when the input produces a different patch grid.
    """

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        # kernel_size == stride, so patches do not overlap.
        self.patch_embedding = nn.Conv2d(in_channels=3,
                                         out_channels=self.embed_dim,
                                         kernel_size=self.patch_size,
                                         stride=self.patch_size)

        self.num_patches = (self.image_size // self.patch_size)**2
        # One extra position for the class token.
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(
            torch.randn(1, self.num_positions, self.embed_dim))

    def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int):
        """Resize the patch position embedding to an ``H`` x ``W`` grid.

        Args:
            pos_embed: Position embedding for the patches only (class token
                excluded), laid out for the native square grid.
            H: Target number of patch rows.
            W: Target number of patch columns.

        Returns:
            Tensor of shape ``(1, H * W, embed_dim)`` in the dtype of
            ``pos_embed``.
        """
        target_dtype = pos_embed.dtype
        # Interpolate in float32 and cast back to the original dtype at the
        # end; channels are moved to dim 1 as F.interpolate expects.
        pos_embed = pos_embed.float().reshape(
            1, self.image_size // self.patch_size,
            self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
        pos_embed = F.interpolate(pos_embed,
                                  size=(H, W),
                                  mode='bicubic',
                                  align_corners=False)
        return pos_embed.reshape(1, -1, H * W).permute(0, 2,
                                                       1).to(target_dtype)

    def _get_position_embedding(self, H: int, W: int) -> torch.Tensor:
        """Return the position embedding (class token included) for a grid.

        The stored parameter is returned untouched when the requested grid
        is the native one; otherwise the patch part is resized and the class
        token position is kept as is.
        """
        position_embedding = self.position_embedding
        grid_size = self.image_size // self.patch_size
        # Compare both sides, not the product H * W: a non-square grid with
        # the same number of patches still needs resampling, because the
        # stored embedding is laid out for the square grid.
        if H == grid_size and W == grid_size:
            return position_embedding

        return torch.cat(
            [
                position_embedding[:, :1, :],
                self._get_pos_embed(position_embedding[:, 1:, :], H, W),
            ],
            dim=1,
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            pixel_values: 4-D image batch with 3 channels in dim 1.

        Returns:
            Tensor of shape ``(batch, 1 + num_patches_in_input, embed_dim)``
            in the dtype of the patch convolution weights.
        """
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(
            target_dtype))  # shape = [*, channel, height, width]
        batch_size, _, height, width = patch_embeds.shape
        # (B, C, H, W) -> (B, H * W, C)
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        class_embeds = self.class_embedding.expand(batch_size, 1,
                                                   -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        position_embedding = self._get_position_embedding(height, width)
        embeddings = embeddings + position_embedding.to(target_dtype)
        return embeddings


95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
class InternVisionPatchModel(nn.Module):
    """Vision model that stops after the embedding stage (no encoder)."""

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        self.config = config
        self.embeddings = InternVisionEmbeddings(config)

    def get_input_embeddings(self):
        """Return the embedding module."""
        return self.embeddings

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_embeds: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Return patch embeddings.

        Precomputed ``pixel_embeds`` win over ``pixel_values`` and are passed
        through untouched; otherwise ``pixel_values`` must be 4-D and is run
        through the embedding module.

        Raises:
            ValueError: If neither input is given, or ``pixel_values`` is not
                4-D.
        """
        # Precomputed embeddings short-circuit everything else.
        if pixel_embeds is not None:
            return pixel_embeds

        if pixel_values is None:
            raise ValueError(
                'You have to specify pixel_values or pixel_embeds')

        if pixel_values.ndim != 4:
            raise ValueError(
                f'wrong pixel_values size: {pixel_values.shape}')

        return self.embeddings(pixel_values)


126
class InternParallelAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Tensor-parallel variant: the heads (real plus dummy) are divided evenly
    across the tensor-parallel ranks.

    Args:
        config: Vision config; reads ``hidden_size``,
            ``num_attention_heads``, ``qkv_bias``, ``qk_normalization`` and
            ``layer_norm_eps``.
        quant_config: Optional quantization config forwarded to the linear
            layers.
        num_dummy_heads: Extra heads added on top of the configured ones so
            that the total is divisible by the tensor-parallel size.
        prefix: Parameter-name prefix forwarded to the linear layers.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by
            ``num_attention_heads``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads '
                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).')

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        # Additional dummy heads are used to enable TP for common GPU counts.
        self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
        self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads,
                                              self.tp_size)

        self.scale = self.head_dim**-0.5
        self.qkv = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            num_dummy_heads + self.num_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv",
        )

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            # The norm weights span all heads (dummy_dim), while the variance
            # is computed over embed_dim via var_hidden_size.
            self.q_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)
            self.k_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)

        self.proj = RowParallelLinear(
            self.dummy_dim,
            self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.proj",
        )

        self.attn = MultiHeadAttention(self.num_heads_per_partition,
                                       self.head_dim, self.scale)

    def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
        """Apply the q/k RMS norms and return the normalized ``(q, k)``.

        With tensor parallelism the per-rank shards are first all-gathered
        so the norm sees the full last dimension, then split again and only
        this rank's slice is kept.
        """
        if self.tp_size > 1:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q = self.q_norm.forward_native(q)
        k = self.k_norm.forward_native(k)
        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
        return q, k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run attention over ``x`` and return the projected output."""
        # The unpack doubles as a check that x is 3-D.
        B, N, _ = x.shape
        # The parallel linear layers return a tuple; only the first element
        # (the output tensor) is used here.
        qkv, _ = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        if self.qk_normalization:
            q, k = self._apply_qk_norm(q, k)

        out = self.attn(q, k, v)
        out, _ = self.proj(out)
        return out
211
212


213
214
215
class InternSdpaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

216
217
218
219
220
221
    def __init__(
        self,
        config: PretrainedConfig,
        *,
        num_dummy_heads: int = 0,
    ) -> None:
222
        super().__init__()
223

224
225
226
227
228
229
230
231
232
233
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads '
                f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).')

234
235
236
        # Additional dummy heads are used to enable TP for common GPU counts.
        self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim

237
238
        self.scale = self.head_dim**-0.5
        self.qkv = nn.Linear(self.embed_dim,
239
                             3 * self.dummy_dim,
240
241
242
243
244
                             bias=config.qkv_bias)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
245
246
247
248
249
250
            self.q_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)
            self.k_norm = RMSNorm(self.dummy_dim,
                                  eps=config.layer_norm_eps,
                                  var_hidden_size=self.embed_dim)
251

252
        self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
253

254
    def forward(self, x: torch.Tensor) -> torch.Tensor:
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
        B, N, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, N, self.num_heads, self.head_dim)
        k = k.view(B, N, self.num_heads, self.head_dim)
        v = v.view(B, N, self.num_heads, self.head_dim)

        if self.qk_normalization:
            B_, N_, H_, D_ = q.shape
            q = self.q_norm.forward_native(q.flatten(-2,
                                                     -1)).view(B_, N_, H_, D_)
            k = self.k_norm.forward_native(k.flatten(-2,
                                                     -1)).view(B_, N_, H_, D_)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
274
        x = x.transpose(1, 2).reshape(B, N, -1)
275
276
277
278
279

        x = self.proj(x)
        return x


280
281
class InternMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2.

    Args:
        config: Vision config; reads ``hidden_act``, ``hidden_size`` and
            ``intermediate_size``.
        quant_config: Optional quantization config forwarded to both linear
            layers.
        prefix: Parameter-name prefix forwarded to both linear layers.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1")
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation and fc2 to ``hidden_states``."""
        # The parallel linear layers return a tuple; only the first element
        # (the output tensor) is used.
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)

        return hidden_states


class InternVisionEncoderLayer(nn.Module):
    """One encoder block: normalized attention and MLP with residual adds.

    Each branch output is multiplied by a learned per-channel scale
    (``ls1`` / ``ls2``) before being added back to the residual stream.

    Args:
        config: Vision config; reads ``hidden_size``, ``intermediate_size``,
            ``norm_type``, ``layer_norm_eps`` and ``initializer_factor``.
        quant_config: Optional quantization config forwarded to the
            attention and MLP sub-modules.
        num_dummy_heads: Extra attention heads forwarded to the attention
            module.
        prefix: Parameter-name prefix for the sub-modules.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type

        self.attn = self._init_attn(config,
                                    quant_config,
                                    num_dummy_heads=num_dummy_heads,
                                    prefix=f"{prefix}.attn")

        self.mlp = InternMLP(config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.mlp")
        self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
                                             eps=config.layer_norm_eps)
        self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
                                             eps=config.layer_norm_eps)

        # Per-channel scales for the attention and MLP branches.
        self.ls1 = nn.Parameter(config.initializer_factor *
                                torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor *
                                torch.ones(self.embed_dim))

    def _init_attn(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        *,
        num_dummy_heads: int,
        prefix: str = "",
    ):
        """Build the attention module for this layer.

        The tensor-parallel implementation is used when the total head count
        (real plus dummy) is divisible by the tensor-parallel world size;
        otherwise the SDPA implementation is used, which receives neither
        ``quant_config`` nor ``prefix``.
        """
        # fallback to sdpa attention if tp unavailable
        tp_size = get_tensor_model_parallel_world_size()
        num_heads = config.num_attention_heads

        if (num_heads + num_dummy_heads) % tp_size == 0:
            return InternParallelAttention(config,
                                           quant_config=quant_config,
                                           num_dummy_heads=num_dummy_heads,
                                           prefix=prefix)

        return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Apply the attention and MLP branches with scaled residual adds."""
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states)) * self.ls1

        hidden_states = hidden_states + self.mlp(
            self.norm2(hidden_states)) * self.ls2

        return hidden_states


class InternVisionEncoder(nn.Module):
    """Stack of ``InternVisionEncoderLayer`` modules applied in sequence.

    Args:
        config: Vision config; ``num_hidden_layers`` sets the depth unless
            overridden.
        quant_config: Optional quantization config forwarded to every layer.
        num_hidden_layers_override: When given, build this many layers
            instead of ``config.num_hidden_layers``.
        num_dummy_heads: Extra attention heads forwarded to every layer.
        prefix: Parameter-name prefix; layer ``i`` gets
            ``f"{prefix}.layers.{i}"``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ):
        super().__init__()

        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            InternVisionEncoderLayer(config,
                                     quant_config,
                                     num_dummy_heads=num_dummy_heads,
                                     prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(self, inputs_embeds: torch.Tensor):
        """Pass ``inputs_embeds`` through every layer and return the result."""
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        return hidden_states


class InternVisionModel(nn.Module):
    """Vision model: embeddings followed by the encoder stack.

    Args:
        config: Vision config shared by the embeddings and the encoder.
        quant_config: Optional quantization config forwarded to the encoder.
        num_hidden_layers_override: Optional encoder depth override
            forwarded to the encoder.
        num_dummy_heads: Extra attention heads forwarded to the encoder.
        prefix: Parameter-name prefix; the encoder gets
            ``f"{prefix}.encoder"``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        *,
        num_hidden_layers_override: Optional[int] = None,
        num_dummy_heads: int = 0,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        self.embeddings = InternVisionEmbeddings(config)
        self.encoder = InternVisionEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            num_dummy_heads=num_dummy_heads,
            prefix=f"{prefix}.encoder",
        )

    def get_input_embeddings(self):
        """Return the embedding module."""
        return self.embeddings

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_embeds: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Encode images or precomputed embeddings.

        ``pixel_embeds`` takes precedence and is fed to the encoder as is;
        otherwise ``pixel_values`` must be 4-D and is embedded first.

        Raises:
            ValueError: If neither input is given, or ``pixel_values`` is not
                4-D.
        """
        if pixel_values is None and pixel_embeds is None:
            raise ValueError(
                'You have to specify pixel_values or pixel_embeds')

        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        elif pixel_values is not None:
            if pixel_values.ndim == 4:
                hidden_states = self.embeddings(pixel_values)
            else:
                raise ValueError(
                    f'wrong pixel_values size: {pixel_values.shape}')

        encoder_outputs = self.encoder(inputs_embeds=hidden_states)

        return encoder_outputs

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load ``(name, tensor)`` pairs into this module's parameters.

        Each parameter's own ``weight_loader`` attribute is used when
        present, otherwise ``default_weight_loader``.

        Returns:
            The set of parameter names that were loaded.

        Raises:
            KeyError: If a name does not match any parameter of this module.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params