# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
from typing import Iterable, List, Optional, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.activations import ACT2FN
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLConfig,
    Qwen2_5_VLVisionConfig,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VisionPatchEmbed,
    Qwen2_5_VisionRotaryEmbedding,
)

from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.mm_utils import (
    MultiModalityDataPaddingPatternMultimodalTokens,
    general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
from sglang.srt.utils import add_prefix

# Logger for this module, keyed by the module's import name.
logger = logging.getLogger(name=__name__)


class Qwen2_5_VLMLP(nn.Module):
    """Gated feed-forward network used inside each vision block.

    Computes ``down_proj(act(gate_proj(x)) * up_proj(x))``, where ``act`` is
    looked up by name in ``transformers``' ``ACT2FN`` table. ``gate_proj`` and
    ``up_proj`` are ``ColumnParallelLinear`` layers and ``down_proj`` is a
    ``RowParallelLinear`` layer. Each returns a tuple whose first element is
    the projected tensor.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        bias: bool = True,
        hidden_act="silu",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the three projections.

        Args:
            in_features: Width of the input and output features.
            hidden_features: Width of the intermediate (gated) activation. It is
                passed straight to the linear layers, so callers are expected
                to provide it.
            bias: Whether every projection carries a bias.
            hidden_act: Key into ``ACT2FN`` selecting the gate activation.
            quant_config: Optional quantization config forwarded to each layer.
            prefix: Parameter-name prefix; each layer appends its own name.
        """
        super().__init__()
        self.gate_proj = ColumnParallelLinear(
            in_features,
            hidden_features,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("gate_proj", prefix),
        )
        self.up_proj = ColumnParallelLinear(
            in_features,
            hidden_features,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            hidden_features,
            in_features,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        self.act = ACT2FN[hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x`` and return the ``down_proj`` output."""
        x_parallel_gate, _ = self.gate_proj(x)
        x_parallel_gate = self.act(x_parallel_gate)
        x_parallel_up, _ = self.up_proj(x)
        x_parallel = x_parallel_gate * x_parallel_up
        x, _ = self.down_proj(x_parallel)
        return x


class Qwen2_5_VisionBlock(nn.Module):
    """Pre-norm vision transformer block: RMSNorm -> attention -> RMSNorm -> MLP.

    Both sub-layers are wrapped in residual connections. The input is laid out
    sequence-first (``s b ...``). It is rearranged to batch-first for
    ``VisionAttention`` and back afterwards.
    """

    # attn_implementation -> (use_context_forward, softmax_in_single_precision,
    #                         flatten_batch) passed to VisionAttention.
    _ATTN_OPTIONS = {
        "sdpa": (False, False, True),
        "flash_attention_2": (True, False, True),
        "eager": (False, True, True),
    }

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        num_heads: int,
        hidden_act="silu",
        norm_layer: Optional[Type[nn.Module]] = None,
        attn_implementation: Optional[str] = "sdpa",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the norms, attention and MLP of one block.

        Args:
            dim: Hidden size of the block.
            intermediate_dim: Hidden size of the gated MLP.
            num_heads: Number of attention heads.
            hidden_act: Activation name for the MLP gate.
            norm_layer: Accepted for interface compatibility but unused. The
                block always uses ``Qwen2RMSNorm(dim, eps=1e-6)``.
            attn_implementation: One of ``"sdpa"``, ``"flash_attention_2"``
                or ``"eager"``.
            quant_config: Optional quantization config for the linear layers.
            prefix: Parameter-name prefix.

        Raises:
            ValueError: if ``attn_implementation`` is not a supported value.
                Previously this surfaced as an ``UnboundLocalError``.
        """
        super().__init__()
        self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
        self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)

        try:
            (
                use_context_forward,
                softmax_in_single_precision,
                flatten_batch,
            ) = self._ATTN_OPTIONS[attn_implementation]
        except KeyError:
            raise ValueError(
                f"Unsupported attn_implementation {attn_implementation!r}; "
                f"expected one of {sorted(self._ATTN_OPTIONS)}"
            ) from None

        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=True,
            use_context_forward=use_context_forward,
            softmax_in_single_precision=softmax_in_single_precision,
            flatten_batch=flatten_batch,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )
        self.mlp = Qwen2_5_VLMLP(
            dim,
            intermediate_dim,
            hidden_act=hidden_act,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Run the block.

        Args:
            x: Sequence-first hidden states (``s b ...``).
            cu_seqlens: Cumulative sequence boundaries that limit which tokens
                attend to each other, forwarded to ``VisionAttention``.
            position_embeddings: Rotary embeddings forwarded to attention.

        Returns:
            Hidden states with the same layout as ``x``.
        """
        hidden_states = self.norm1(x)
        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
        attn = self.attn(
            hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        attn = rearrange(attn, "b s ... -> s b ...")
        x = x + attn
        norm2 = self.norm2(x)
        mlp = self.mlp(norm2)
        x = x + mlp
        return x


class Qwen2_5_VisionPatchMerger(nn.Module):
    """Fuse groups of neighbouring patch tokens and project them to ``dim``.

    Each output row is built from ``spatial_merge_size ** 2`` consecutive input
    tokens of width ``context_dim``. These are RMS-normalized, concatenated
    via ``view(-1, hidden_size)``, and passed through Linear -> GELU -> Linear.
    """

    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the norm and the two-layer projection.

        Args:
            dim: Output width (the language model's hidden size at the call
                site in this file).
            context_dim: Width of each incoming patch token.
            spatial_merge_size: Side length of the square patch group fused
                into one output token.
            quant_config: Optional quantization config for the linear layers.
            prefix: Parameter-name prefix. The linear layers are named
                ``mlp.0`` and ``mlp.2`` to match checkpoint keys.
        """
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
        self.mlp = nn.ModuleList(
            [
                ColumnParallelLinear(
                    self.hidden_size,
                    self.hidden_size,
                    bias=True,
                    quant_config=quant_config,
                    prefix=add_prefix("mlp.0", prefix),
                ),
                nn.GELU(),
                RowParallelLinear(
                    self.hidden_size,
                    dim,
                    bias=True,
                    quant_config=quant_config,
                    prefix=add_prefix("mlp.2", prefix),
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize, group and project ``x``; returns ``(-1, dim)`` rows."""
        x = self.ln_q(x)
        # Concatenate every spatial_merge_size**2 consecutive tokens into one row.
        x = x.view(-1, self.hidden_size)

        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
        x_parallel, _ = mlp_fc1(x)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out


class Qwen2_5_VisionTransformer(nn.Module):
    """Qwen2.5-VL vision encoder.

    Pipeline: patch embedding -> transformer blocks -> patch merger.

    Blocks whose index is in ``vision_config.fullatt_block_indexes`` attend
    within each temporal frame of each item (``cu_seqlens``). All other blocks
    attend only within local windows (``cu_window_seqlens`` from
    :meth:`get_window_index`).
    """

    def __init__(
        self,
        vision_config: Qwen2_5_VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the encoder from ``vision_config``.

        Args:
            vision_config: HF vision config (patch/merge/window sizes, depth,
                heads, widths, full-attention block indexes).
            norm_eps: Epsilon for the ``nn.LayerNorm`` factory handed to each
                block (which does not use it; see ``Qwen2_5_VisionBlock``).
            quant_config: Optional quantization config for the linear layers.
            prefix: Parameter-name prefix.
        """
        super().__init__()

        patch_size: int = vision_config.patch_size
        temporal_patch_size: int = vision_config.temporal_patch_size
        spatial_merge_size: int = vision_config.spatial_merge_size
        self.spatial_merge_size = spatial_merge_size
        # Number of consecutive patch tokens fused into one merged token.
        self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
        in_channels: int = vision_config.in_channels
        hidden_size: int = vision_config.hidden_size
        depth: int = vision_config.depth
        num_heads: int = vision_config.num_heads
        self.fullatt_block_indexes = vision_config.fullatt_block_indexes
        self.window_size = vision_config.window_size
        self.patch_size = vision_config.patch_size
        mlp_hidden_size: int = vision_config.intermediate_size
        self.patch_embed = Qwen2_5_VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=hidden_size,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = hidden_size // num_heads
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList(
            [
                Qwen2_5_VisionBlock(
                    dim=hidden_size,
                    intermediate_dim=mlp_hidden_size,
                    num_heads=num_heads,
                    hidden_act=vision_config.hidden_act,
                    norm_layer=norm_layer,
                    attn_implementation="sdpa",
                    quant_config=quant_config,
                    prefix=add_prefix(f"blocks.{i}", prefix),
                )
                for i in range(depth)
            ]
        )
        self.merger = Qwen2_5_VisionPatchMerger(
            dim=vision_config.out_hidden_size,
            context_dim=hidden_size,
            spatial_merge_size=spatial_merge_size,
            quant_config=quant_config,
            prefix=add_prefix("merger", prefix),
        )

    def get_window_index(self, grid_thw):
        """Compute the window permutation and window boundaries.

        Each frame's grid of merged tokens (``h // spatial_merge_size`` by
        ``w // spatial_merge_size``) is padded with ``-100`` and tiled into
        square windows of ``vit_merger_window_size`` merged tokens per side.
        When a side is already a multiple of the window size, a full extra
        window of padding is added. Padding entries are dropped afterwards.

        Args:
            grid_thw: Rows of ``(t, h, w)`` patch-grid sizes. The ``.item()``
                call below expects tensor elements.

        Returns:
            window_index: 1-D tensor of merged-token indices in window order,
                offset so indices of later items follow earlier ones.
            cu_window_seqlens: Python list of cumulative window boundaries in
                patch tokens, starting at 0. Empty (padding-only) windows
                produce repeated boundaries.
        """
        cu_window_seqlens: list = [0]
        window_index_id = 0
        vit_merger_window_size = (
            self.window_size // self.spatial_merge_size // self.patch_size
        )
        window_index: list = []
        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.spatial_merge_size,
                grid_w // self.spatial_merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
                grid_t, llm_grid_h, llm_grid_w
            )
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            # Real (non-padding) merged tokens per window.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            # Boundaries are expressed in patch tokens, hence * spatial_merge_unit.
            cu_seqlens_tmp = (
                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            )
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)
        return window_index, cu_window_seqlens

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding weights; inputs are cast to it."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the first block's MLP weights; inputs are moved to it."""
        return self.blocks[0].mlp.gate_proj.weight.device

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Build 2-D rotary position embeddings for every patch token.

        For each item, (row, col) patch coordinates are generated and reordered
        so that each ``spatial_merge_size`` x ``spatial_merge_size`` group is
        contiguous. They are repeated ``t`` times and then used to index a
        rotary table sized by the largest ``h``/``w`` in ``grid_thw``.
        """
        pos_ids = []
        for i in range(grid_thw.size(0)):
            t, h, w = grid_thw[i].tolist()
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Encode flattened patch pixels into merged visual embeddings.

        Args:
            x: Flattened patch pixels for all items. ``patch_embed`` turns them
                into a 2-D ``(seq_len, hidden)`` tensor (see ``x.size()`` below).
            grid_thw: Integer tensor with one ``(t, h, w)`` patch-grid row per
                item.

        Returns:
            One merged embedding per ``spatial_merge_unit`` patch tokens, in
            the original (non-windowed) token order.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=x.device,
            dtype=torch.int32,
        )
        # Padding-only windows repeat a boundary; collapse the duplicates.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        seq_len, _ = x.size()

        # Permute tokens (in groups of spatial_merge_unit) into window order so
        # every attention window is a contiguous span.
        x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        x = x[window_index, :, :]
        x = x.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
        )
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # compute cu_seqlens for the full-attention blocks. Each temporal frame
        # of each item is its own h*w-token sequence, as in the reference HF
        # implementation. The window permutation above keeps frames contiguous.
        # Exactly one leading 0; int32 and on x's device, like cu_window_seqlens.
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0).to(device=x.device)

        # transformers
        x = x.unsqueeze(1)
        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            x = blk(
                x, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings
            )

        # adapter
        x = self.merger(x)

        # Undo the window permutation so outputs follow the input token order.
        reverse_indices = torch.argsort(window_index)
        x = x[reverse_indices, :]

        return x


# Memoized ``get_processor``: an LRU cache holding up to 128 distinct argument
# combinations (arguments must therefore be hashable).
cached_get_processor = lru_cache(maxsize=128)(get_processor)


class Qwen2_5_VLForConditionalGeneration(nn.Module):
    """Qwen2.5-VL: a vision encoder (``self.visual``) feeding a Qwen2 LM.

    Image features come from :meth:`get_image_feature`. ``general_mm_embed_routine``
    splices them into the language model's input, and the resulting hidden
    states go either to the logits processor or to the pooler.
    """

    # BitsAndBytes-specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name: (stacked_param_name, shard_index)
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: Qwen2_5_VLConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the vision tower, language model, LM head and output heads.

        Args:
            config: Full Qwen2.5-VL config (text fields plus ``vision_config``).
            quant_config: Optional quantization config.
            prefix: Parameter-name prefix.
        """
        super().__init__()

        self.config = config
        self.visual = Qwen2_5_VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            # NOTE: Qwen2_5-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
            quant_config=quant_config,
            prefix=add_prefix("visual", prefix),
        )

        self.model = Qwen2Model(
            config,
            quant_config,
            prefix=add_prefix("model", prefix),
        )

        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        # ``rope_scaling`` may be None in configs without RoPE scaling. Treat
        # that as "mrope disabled" rather than failing with a TypeError.
        self.is_mrope_enabled = "mrope_section" in (self.config.rope_scaling or {})

        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        """Pad the image-token spans of ``input_ids`` for ``mm_inputs``."""
        # Only the image token id is used as the padding pattern here.
        im_token_id: int = mm_inputs.im_token_id
        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Run the vision tower over all image items and return their embeddings.

        ``pixel_values`` and ``image_grid_thws`` of every item are concatenated
        along dim 0; both must be 2-D after concatenation.
        """
        # in qwen-vl, last dim is the same
        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
            self.visual.dtype
        )
        image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
        assert pixel_values.dim() == 2, pixel_values.dim()
        assert image_grid_thws.dim() == 2, image_grid_thws.dim()
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
        return image_embeds

    def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
        """Run the vision tower over video pixels using ``video_grid_thw``."""
        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
        video_embeds = self.visual(
            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
        )
        return video_embeds

    def get_input_embeddings(self):
        """Return the language model's token embedding module."""
        return self.model.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        """Run the forward pass for Qwen2.5-VL.

        Args:
            input_ids: Flattened (concatenated) input ids of the whole batch.
            positions: Flattened (concatenated) position ids of the whole batch.
                **NOTE**: when mrope is enabled (the default for the
                open-source Qwen2-VL family), this argument is ignored and
                ``forward_batch.mrope_positions`` is used instead. Its shape is
                ``(3, seq_len)``; otherwise positions are ``(seq_len,)``.
            forward_batch: Batch metadata (forward mode, multimodal inputs, ...).
            get_embedding: If True, return pooled embeddings instead of logits.

        Returns:
            The logits-processor output, or the pooler output when
            ``get_embedding`` is True.
        """
        if self.is_mrope_enabled:
            positions = forward_batch.mrope_positions

        if (
            not forward_batch.forward_mode.is_decode()
            and forward_batch.contains_image_inputs()
        ):
            if self.is_mrope_enabled:
                assert positions.ndim == 2 and positions.size(0) == 3, (
                    "multimodal section rotary embedding requires "
                    f"(3, seq_len) positions, but got {positions.size()}"
                )

        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.model,
            image_data_embedding_func=self.get_image_feature,
            positions=positions,
        )

        if not get_embedding:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )
        else:
            return self.pooler(hidden_states, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` checkpoint weights into this module.

        Language-model q/k/v and gate/up projections are loaded as shards of
        the fused ``qkv_proj`` / ``gate_up_proj`` parameters. Vision-tower
        weights bypass that mapping. Their ``attn.qkv.`` names are renamed to
        ``attn.qkv_proj.`` to match ``VisionAttention``.

        Raises:
            KeyError: if a checkpoint weight has no matching parameter. Extra
                ``.bias`` tensors (e.g. from GPTQ checkpoints) are skipped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            ("gate_up_proj", "up_proj", 1),
            ("gate_up_proj", "gate_proj", 0),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if "visual" in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "visual" in name:
                    # adapt to VisionAttention
                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    logger.error(
                        "No parameter matches checkpoint weight %r; "
                        "available parameters: %s",
                        name,
                        sorted(params_dict),
                    )
                    raise KeyError(name)
                param = params_dict[name]

                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


# Model classes exported by this module (presumably picked up by sglang's
# model registry via the ``EntryClass`` name — confirm against the loader).
EntryClass = [
    Qwen2_5_VLForConditionalGeneration,
]