qwen2_vl.py 20.9 KB
Newer Older
Yineng Zhang's avatar
Yineng Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
25
import logging
Yineng Zhang's avatar
Yineng Zhang committed
26
from functools import lru_cache, partial
27
from typing import Iterable, List, Optional, Tuple, Type, TypedDict
Yineng Zhang's avatar
Yineng Zhang committed
28
29
30
31

import torch
import torch.nn as nn
import torch.nn.functional as F
32
from einops import rearrange
Mick's avatar
Mick committed
33
34
from transformers import Qwen2VLConfig
from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
Yineng Zhang's avatar
Yineng Zhang committed
35
36

from sglang.srt.hf_transformers_utils import get_processor
Yineng Zhang's avatar
Yineng Zhang committed
37
from sglang.srt.layers.activation import QuickGELU
Mick's avatar
Mick committed
38
from sglang.srt.layers.attention.vision import VisionAttention
Yineng Zhang's avatar
Yineng Zhang committed
39
40
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
41
from sglang.srt.layers.pooler import Pooler, PoolingType
Yineng Zhang's avatar
Yineng Zhang committed
42
from sglang.srt.layers.quantization.base_config import QuantizationConfig
43
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
44
from sglang.srt.managers.mm_utils import (
45
    MultiModalityDataPaddingPatternMultimodalTokens,
46
    general_mm_embed_routine,
47
)
Mick's avatar
Mick committed
48
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
Yineng Zhang's avatar
Yineng Zhang committed
49
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
50
from sglang.srt.model_loader.weight_utils import default_weight_loader
Yineng Zhang's avatar
Yineng Zhang committed
51
from sglang.srt.models.qwen2 import Qwen2Model
52
from sglang.srt.utils import add_prefix
Yineng Zhang's avatar
Yineng Zhang committed
53

54
logger = logging.getLogger(__name__)
Yineng Zhang's avatar
Yineng Zhang committed
55

Mick's avatar
Mick committed
56

Yineng Zhang's avatar
Yineng Zhang committed
57
58
59
60
61
# === Vision Inputs === #


class Qwen2VLImageInputs(TypedDict):
    """Preprocessed image inputs consumed by the Qwen2-VL vision encoder."""

    # Flattened image patches.
    # Shape: `(num_patches, num_channels * patch_size * patch_size)`
    pixel_values: torch.Tensor

    # Per-image patch grid sizes, one `(grid_t, grid_h, grid_w)` row per image.
    # Shape: `(num_images, 3)`
    image_grid_thw: torch.Tensor


class Qwen2VLVideoInputs(TypedDict):
    """Preprocessed video inputs consumed by the Qwen2-VL vision encoder."""

    # Flattened spatio-temporal video patches.
    # Shape: `(num_patches,
    #          num_channels * temporal_patch_size * patch_size * patch_size)`
    pixel_values_videos: torch.Tensor

    # Per-video patch grid sizes, one `(grid_t, grid_h, grid_w)` row per video.
    # Shape: `(num_videos, 3)`
    video_grid_thw: torch.Tensor


# === Vision Encoder === #


class Qwen2VisionMLP(nn.Module):
    """Feed-forward sub-layer of a vision block.

    The hidden projection is column-parallel and the output projection is
    row-parallel, so the intermediate dimension is sharded across
    tensor-parallel ranks while the interface stays unsharded.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int = None,
        act_layer: Type[nn.Module] = QuickGELU,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            quant_config=quant_config,
            prefix=add_prefix("fc1", prefix),
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            quant_config=quant_config,
            prefix=add_prefix("fc2", prefix),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Parallel linear layers return (output, bias-or-None); only the
        # output tensor is used here.
        hidden, _ = self.fc1(x)
        out, _ = self.fc2(self.act(hidden))
        return out


class Qwen2VisionBlock(nn.Module):
    """One pre-norm vision transformer block: attention + MLP, each with a
    residual connection.

    Args:
        dim: embedding dimension of the block.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden size as a multiple of ``dim``.
        act_layer: activation module class for the MLP.
        norm_layer: LayerNorm factory; defaults to ``nn.LayerNorm(eps=1e-6)``.
        attn_implementation: one of ``"sdpa"``, ``"flash_attention_2"``,
            ``"eager"``.
        quant_config: optional quantization configuration.
        prefix: parameter-name prefix for weight loading.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: Type[nn.Module] = QuickGELU,
        norm_layer: Type[nn.Module] = None,
        attn_implementation: Optional[str] = "sdpa",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        # Map the requested attention implementation onto a backend plus
        # its softmax precision ("eager" upcasts the softmax to fp32).
        if attn_implementation == "sdpa":
            qkv_backend = "sdpa"
            softmax_in_single_precision = False
        elif attn_implementation == "flash_attention_2":
            qkv_backend = "triton_attn"
            softmax_in_single_precision = False
        elif attn_implementation == "eager":
            qkv_backend = "sdpa"
            softmax_in_single_precision = True
        else:
            # Bug fix: the original chain had no else branch, so an unknown
            # value left qkv_backend/softmax_in_single_precision unbound and
            # crashed below with UnboundLocalError. Fail fast with a clear
            # message instead.
            raise ValueError(
                f"Unsupported attn_implementation: {attn_implementation!r}; "
                "expected 'sdpa', 'flash_attention_2', or 'eager'."
            )

        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=True,
            qkv_backend=qkv_backend,
            softmax_in_single_precision=softmax_in_single_precision,
            flatten_batch=True,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )
        self.mlp = Qwen2VisionMLP(
            dim,
            mlp_hidden_dim,
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention and MLP with pre-norm residual connections.

        ``x`` is in (seq, batch, ...) layout; the attention module expects
        (batch, seq, ...), hence the rearranges around the call.
        """
        hidden_states = self.norm1(x)
        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
        attn = self.attn(
            hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        attn = rearrange(attn, "b s ... -> s b ...")
        x = x + attn
        x = x + self.mlp(self.norm2(x))
        return x


class Qwen2VisionPatchEmbed(nn.Module):
    """Linear projection of flattened spatio-temporal patches.

    Each input row holds one flattened
    (in_chans, temporal_patch_size, patch_size, patch_size) patch; a
    bias-free Conv3d whose stride equals its kernel acts as the patch
    projection, so every patch is projected exactly once.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_chans: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        kernel = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=kernel, stride=kernel, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unflatten each row back into a 3D patch, project, then drop the
        # singleton spatial dims: (L, C) -> (L, embed_dim).
        num_patches, _ = x.shape
        patches = x.view(
            num_patches, -1, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        return self.proj(patches).view(num_patches, self.embed_dim)


class Qwen2VisionPatchMerger(nn.Module):
    """Merges neighboring patch embeddings and projects to the LM width.

    ``spatial_merge_size ** 2`` adjacent patch features are concatenated
    (via the flat reshape in ``forward``) and passed through
    LayerNorm -> Linear -> GELU -> Linear to produce ``d_model``-sized
    embeddings for the language model.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Type[nn.Module] = None,
        spatial_merge_size: int = 2,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.ln_q = norm_layer(context_dim)
        # ModuleList (not Sequential): parallel linears return tuples, so the
        # forward pass unpacks them explicitly.
        self.mlp = nn.ModuleList(
            [
                ColumnParallelLinear(
                    self.hidden_size,
                    self.hidden_size,
                    bias=True,
                    quant_config=quant_config,
                    prefix=add_prefix("mlp.0", prefix),
                ),
                nn.GELU(),
                RowParallelLinear(
                    self.hidden_size,
                    d_model,
                    bias=True,
                    quant_config=quant_config,
                    prefix=add_prefix("mlp.2", prefix),
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        fc_up, activation, fc_down = self.mlp
        # Normalize, then fold spatial_merge_size**2 neighbors into one row.
        merged = self.ln_q(x).view(-1, self.hidden_size)
        hidden, _ = fc_up(merged)
        out, _ = fc_down(activation(hidden))
        return out


class Qwen2VisionRotaryEmbedding(nn.Module):
    """Rotary position-embedding frequency table for the vision encoder.

    Caches a ``(cached_len, dim // 2)`` table of ``pos * inv_freq`` outer
    products and grows it lazily; the cache length is doubled on growth so
    repeated small increases don't trigger recomputation.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Extend the cached frequency table to cover at least ``seqlen``."""
        if seqlen > self._seq_len_cached:
            # Over-allocate 2x so subsequent small growths hit the cache.
            seqlen *= 2
            self._seq_len_cached = seqlen
            # Fix: the original re-derived `inv_freq` here with exactly the
            # same formula as __init__ on every cache growth; that redundant
            # recomputation is removed — the registered buffer is reused.
            seq = torch.arange(
                seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
            )
            self._freqs_cached = torch.outer(seq, self.inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the first ``seqlen`` rows of the frequency table."""
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]


class Qwen2VisionTransformer(nn.Module):
    """Qwen2-VL vision encoder.

    Pipeline: patch embedding -> `depth` rotary-position attention blocks
    -> patch merger that projects merged patches to the LM hidden size.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        patch_size: int = vision_config.patch_size
        temporal_patch_size: int = vision_config.temporal_patch_size
        spatial_merge_size: int = vision_config.spatial_merge_size
        in_chans: int = vision_config.in_chans
        hidden_size: int = vision_config.hidden_size
        embed_dim: int = vision_config.embed_dim
        depth: int = vision_config.depth
        num_heads: int = vision_config.num_heads
        mlp_ratio: float = vision_config.mlp_ratio

        self.spatial_merge_size = spatial_merge_size

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = embed_dim // num_heads
        # Rotary table over half the head dim; forward() later duplicates it
        # (cat of rotary_pos_emb with itself) to cover the full head dim.
        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList(
            [
                Qwen2VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_layer=norm_layer,
                    attn_implementation="sdpa",
                    quant_config=quant_config,
                    prefix=add_prefix(f"blocks.{i}", prefix),
                )
                for i in range(depth)
            ]
        )
        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
            prefix=add_prefix("merger", prefix),
        )

    @property
    def dtype(self) -> torch.dtype:
        # dtype of the patch projection weight == dtype inputs must match.
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.blocks[0].mlp.fc2.weight.device

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Build per-patch rotary embeddings from (t, h, w) grids.

        For each image/video, h/w position ids are reordered so that the
        spatial_merge_size x spatial_merge_size windows merged later are
        contiguous, then repeated over the temporal dimension.
        """
        pos_ids = []
        for i in range(grid_thw.size(0)):
            t, h, w = grid_thw[i].tolist()
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            # Reorder (h, w) so each merge window's patches are adjacent.
            hpos_ids = (
                hpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            wpos_ids = (
                wpos_ids.reshape(
                    h // self.spatial_merge_size,
                    self.spatial_merge_size,
                    w // self.spatial_merge_size,
                    self.spatial_merge_size,
                )
                .permute(0, 2, 1, 3)
                .flatten()
            )
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        # Gather per-axis frequencies and concatenate h- and w-axis halves.
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Encode flattened patches into merged vision embeddings.

        Args:
            x: `(num_patches, patch_dim)` flattened patches; cast to the
                encoder's device/dtype before projection.
            grid_thw: `(num_images_or_videos, 3)` rows of (t, h, w).
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        # Duplicate to full head dim and precompute cos/sin once for all blocks.
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # compute cu_seqlens: cumulative patch counts, one frame (h*w patches)
        # per sequence, prefixed with 0 for varlen attention.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # transformers; blocks run in (seq, batch=1, dim) layout.
        x = x.unsqueeze(1)
        for blk in self.blocks:
            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)

        # adapter
        x = self.merger(x)
        return x


# Process-wide memoized HF processor factory (lru_cache keyed by arguments),
# so repeated requests don't re-load the processor.
cached_get_processor = lru_cache(get_processor)


425
class Qwen2VLForConditionalGeneration(nn.Module):
    """Qwen2-VL: a Qwen2 language backbone fronted by a vision transformer.

    Image/video patches are encoded by ``self.visual``; their embeddings are
    spliced into the token stream by ``general_mm_embed_routine`` and decoded
    by the Qwen2 model, optionally using mrope (multimodal section rotary)
    positions.
    """

    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: Qwen2VLConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.visual = Qwen2VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            # NOTE: Qwen2-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
            quant_config=quant_config,
            prefix=add_prefix("visual", prefix),
        )

        self.model = Qwen2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )

        if config.tie_word_embeddings:
            # Reuse the input embedding matrix as the output projection.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )

        # Fix: guard against configs where rope_scaling is None — the bare
        # `in` test raised TypeError instead of disabling mrope.
        self.is_mrope_enabled = "mrope_section" in (self.config.rope_scaling or {})
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        """Expand multimodal placeholder tokens to match embedding lengths."""
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Encode all image items in one batched vision-tower call."""
        # in qwen-vl, last dim is the same
        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
            self.visual.dtype
        )
        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
        assert pixel_values.dim() == 2, pixel_values.dim()
        assert image_grid_thw.dim() == 2, image_grid_thw.dim()
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        return image_embeds

    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Encode all video items in one batched vision-tower call."""
        # in qwen-vl, last dim is the same
        pixel_values = torch.cat(
            [item.pixel_values_videos for item in items], dim=0
        ).type(self.visual.dtype)
        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
        assert pixel_values.dim() == 2, pixel_values.dim()
        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
        video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
        return video_embeds

    def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
        """Encode a single preprocessed video input dict."""
        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
        video_embeds = self.visual(
            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
        )
        return video_embeds

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,).
                (Use input_metadata.mrope_positions to replace it)
            forward_batch: batch metadata including forward mode and
                multimodal inputs.
            get_embedding: if True, return pooled embeddings instead of logits.
        """
        if self.is_mrope_enabled:
            positions = forward_batch.mrope_positions

        # Only validate mrope positions when this batch actually carries
        # image inputs in a non-decode pass.
        if not (
            forward_batch.forward_mode.is_decode()
            or not forward_batch.contains_image_inputs()
        ):
            if self.is_mrope_enabled:
                assert positions.ndim == 2 and positions.size(0) == 3, (
                    "multimodal section rotary embedding requires "
                    f"(3, seq_len) positions, but got {positions.size()}"
                )
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.model,
            multimodal_model=self,
            positions=positions,
        )

        if get_embedding:
            return self.pooler(hidden_states, forward_batch)
        else:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, fusing q/k/v and gate/up shards and
        remapping vision attention names to this implementation's layout."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "up_proj", 1),
            ("gate_up_proj", "gate_proj", 0),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # Tied embeddings: lm_head shares weights with embed_tokens.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "visual" in name:
                    # adapt to VisionAttention
                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")

                try:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                except KeyError:
                    # Fix: use the module logger (lazy %-args) instead of a
                    # bare print() so the failure reaches server logs.
                    logger.error(
                        "Weight %r not found in model; available params: %s",
                        name,
                        list(params_dict.keys()),
                    )
                    raise

                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


# Entry point consumed by sglang's model-loader registry.
EntryClass = Qwen2VLForConditionalGeneration