qwen2_5_vl.py 23.6 KB
Newer Older
Mick's avatar
Mick committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
import logging
from functools import lru_cache, partial
from typing import Iterable, List, Optional, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.activations import ACT2FN
34
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
35
    Qwen2_5_VLConfig,
36
37
    Qwen2_5_VLVisionConfig,
)
38
39
40
41
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VisionPatchEmbed,
    Qwen2_5_VisionRotaryEmbedding,
)
Mick's avatar
Mick committed
42
43
44

from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.attention.vision import VisionAttention
45
from sglang.srt.layers.layernorm import RMSNorm
Mick's avatar
Mick committed
46
47
48
49
50
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
51
from sglang.srt.managers.mm_utils import (
52
    MultiModalityDataPaddingPatternMultimodalTokens,
53
    general_mm_embed_routine,
54
)
Mick's avatar
Mick committed
55
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
Mick's avatar
Mick committed
56
57
58
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
59
from sglang.srt.utils import add_prefix
Mick's avatar
Mick committed
60
61
62
63
64
65
66
67
68
69
70
71
72

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


class Qwen2_5_VLMLP(nn.Module):
    """Gated MLP used inside each Qwen2.5-VL vision block.

    ``forward`` computes ``down_proj(act(gate_proj(x)) * up_proj(x))``, where
    ``gate_proj``/``up_proj`` are ``ColumnParallelLinear`` and ``down_proj`` is
    a ``RowParallelLinear``.

    Args:
        in_features: Input (and output) feature size.
        hidden_features: Intermediate feature size. It is passed straight to
            the linear layers, so despite the ``None`` default it presumably
            must be provided by callers.
        bias: Whether the three projections carry a bias term.
        hidden_act: Key into ``ACT2FN`` selecting the gate activation.
        quant_config: Optional quantization config forwarded to the linears.
        prefix: Parameter-name prefix used for weight loading.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        bias: bool = True,
        hidden_act="silu",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.gate_proj = ColumnParallelLinear(
            in_features,
            hidden_features,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("gate_proj", prefix),
        )
        self.up_proj = ColumnParallelLinear(
            in_features,
            hidden_features,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            hidden_features,
            in_features,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        self.act = ACT2FN[hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x`` (last dimension ``in_features``)."""
        # The parallel linears return (output, bias); the second slot is unused.
        x_parallel_gate, _ = self.gate_proj(x)
        x_parallel_gate = self.act(x_parallel_gate)
        x_parallel_up, _ = self.up_proj(x)
        x_parallel = x_parallel_gate * x_parallel_up
        x, _ = self.down_proj(x_parallel)
        return x


class Qwen2_5_VisionBlock(nn.Module):
    """One Qwen2.5-VL vision transformer layer: pre-norm attention + gated MLP.

    Args:
        dim: Hidden size of the block.
        intermediate_dim: Intermediate size of the gated MLP.
        num_heads: Number of attention heads.
        hidden_act: Activation used by the MLP gate.
        norm_layer: Accepted for interface compatibility but unused; both
            norms are always ``RMSNorm(dim, eps=1e-6)``.
        attn_implementation: HF-style attention implementation name, mapped
            onto ``VisionAttention`` settings via ``_ATTN_BACKENDS``.
        quant_config: Optional quantization config forwarded to submodules.
        prefix: Parameter-name prefix used for weight loading.
        num_dummy_heads: Forwarded unchanged to ``VisionAttention``.

    Raises:
        ValueError: If ``attn_implementation`` is not one of the supported names.
    """

    # attn_implementation -> (softmax_in_single_precision, qkv_backend, flatten_batch)
    _ATTN_BACKENDS = {
        None: (False, None, True),
        "sdpa": (False, "sdpa", True),
        "flash_attention_2": (False, "triton_attn", True),
        "eager": (True, "sdpa", True),
        "flash_attention_3": (False, "fa3", True),
    }

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        num_heads: int,
        hidden_act="silu",
        norm_layer: Optional[Type[nn.Module]] = None,
        attn_implementation: Optional[str] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        num_dummy_heads: int = 0,
    ) -> None:
        super().__init__()
        # ``norm_layer`` is intentionally ignored: this block always uses RMSNorm.
        self.norm1 = RMSNorm(dim, eps=1e-6)
        self.norm2 = RMSNorm(dim, eps=1e-6)

        # An unknown name previously fell through the if/elif chain and failed
        # later with an UnboundLocalError; report it explicitly instead.
        try:
            (
                softmax_in_single_precision,
                qkv_backend,
                flatten_batch,
            ) = self._ATTN_BACKENDS[attn_implementation]
        except KeyError:
            supported = [k for k in self._ATTN_BACKENDS if k is not None]
            raise ValueError(
                f"Unsupported attn_implementation {attn_implementation!r}; "
                f"expected None or one of {supported}"
            ) from None

        self.attn = VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            use_qkv_parallel=True,
            rotary_embed="normal",
            proj_bias=True,
            qkv_backend=qkv_backend,
            softmax_in_single_precision=softmax_in_single_precision,
            flatten_batch=flatten_batch,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
            num_dummy_heads=num_dummy_heads,
        )
        self.mlp = Qwen2_5_VLMLP(
            dim,
            intermediate_dim,
            hidden_act=hidden_act,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention + MLP on ``x`` of shape ``[S, B, H]``.

        Args:
            x: Input tokens, ``[S, B, H]``.
            cu_seqlens: Cumulative sequence boundaries forwarded to attention.
            position_embeddings: Rotary (cos, sin) forwarded to attention.

        Returns:
            Tensor of shape ``[S, B, H]``.
        """
        S, B, H = x.shape
        # norm1: flatten to 2D -> [S*B, H], then reshape back
        x2d = x.reshape(-1, H)
        hidden_states = self.norm1(x2d).reshape(S, B, H)

        # Attention expects [B, S, H]
        hidden_states = rearrange(hidden_states, "s b h -> b s h")
        attn = self.attn(
            hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        attn = rearrange(attn, "b s h -> s b h")

        # norm2 with fused residual-add (also 2D): yields the normalised sum for
        # the MLP and the un-normalised sum (x + attn) for the final residual.
        attn2d = attn.reshape(-1, H)
        x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d)
        x_norm = x_norm_2d.reshape(S, B, H)
        x_after_add = x_after_add_2d.reshape(S, B, H)

        # MLP and final residual
        mlp_out = self.mlp(x_norm)
        x = x_after_add + mlp_out
        return x


class Qwen2_5_VisionPatchMerger(nn.Module):
    """Merge groups of ``spatial_merge_size**2`` patch tokens into one token.

    Tokens are RMS-normalised, consecutive groups of ``spatial_merge_size**2``
    rows are concatenated into one row of size ``hidden_size``, and a two-layer
    MLP (``hidden_size -> hidden_size -> dim`` with GELU) projects the result.

    Args:
        dim: Output feature size (``vision_config.out_hidden_size`` when built
            by the vision transformer in this file).
        context_dim: Feature size of the incoming vision tokens.
        spatial_merge_size: Side length of the square patch group being merged.
        quant_config: Optional quantization config forwarded to the linears.
        prefix: Parameter-name prefix used for weight loading.
    """

    def __init__(
        self,
        dim: int,
        context_dim: int,
        spatial_merge_size: int = 2,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = RMSNorm(context_dim, eps=1e-6)
        self.mlp = nn.ModuleList(
            [
                ColumnParallelLinear(
                    self.hidden_size,
                    self.hidden_size,
                    bias=True,
                    quant_config=quant_config,
                    prefix=add_prefix("mlp.0", prefix),
                ),
                nn.GELU(),
                RowParallelLinear(
                    self.hidden_size,
                    dim,
                    bias=True,
                    quant_config=quant_config,
                    prefix=add_prefix("mlp.2", prefix),
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise, group and project the vision tokens in ``x``.

        Args:
            x: Tokens whose last dimension is ``context_dim``. The vision
                transformer passes ``[S, 1, context_dim]``; any leading
                dimensions are flattened into one token axis.

        Returns:
            ``[num_tokens // spatial_merge_size**2, dim]`` tensor.
        """
        context_dim = x.shape[-1]
        x2d = self.ln_q(x.reshape(-1, context_dim))  # RMSNorm expects 2D
        x2d = x2d.view(-1, self.hidden_size)  # group into spatial_merge_unit
        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
        # Parallel linears return (output, bias); the bias slot is unused.
        x_parallel, _ = mlp_fc1(x2d)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out


class Qwen2_5_VisionTransformer(nn.Module):
    """Qwen2.5-VL vision encoder.

    Pipeline: patch embedding -> 2D rotary position embedding -> ``depth``
    transformer blocks -> patch merger projecting to
    ``vision_config.out_hidden_size``. Blocks use window attention, except those
    listed in ``vision_config.fullatt_block_indexes``, which use full attention.

    Args:
        vision_config: HF vision config for Qwen2.5-VL.
        norm_eps: Epsilon of the ``norm_layer`` handed to each block (see the
            NOTE in ``__init__``).
        quant_config: Optional quantization config forwarded to submodules.
        prefix: Parameter-name prefix used for weight loading.
    """

    def __init__(
        self,
        vision_config: Qwen2_5_VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        patch_size: int = vision_config.patch_size
        temporal_patch_size: int = vision_config.temporal_patch_size
        spatial_merge_size: int = vision_config.spatial_merge_size
        self.spatial_merge_size = spatial_merge_size
        self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
        in_channels: int = vision_config.in_channels
        hidden_size: int = vision_config.hidden_size
        depth: int = vision_config.depth
        num_heads: int = vision_config.num_heads
        self.fullatt_block_indexes = vision_config.fullatt_block_indexes
        self.window_size = vision_config.window_size
        self.patch_size = vision_config.patch_size
        mlp_hidden_size: int = vision_config.intermediate_size
        self.patch_embed = Qwen2_5_VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            embed_dim=hidden_size,
        )

        # NOTE: Qwen2_5_VisionBlock does not use ``norm_layer`` (it builds
        # RMSNorm(eps=1e-6) itself), so ``norm_eps`` does not reach the blocks.
        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = hidden_size // num_heads
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList(
            [
                Qwen2_5_VisionBlock(
                    dim=hidden_size,
                    intermediate_dim=mlp_hidden_size,
                    num_heads=num_heads,
                    hidden_act=vision_config.hidden_act,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=add_prefix(f"blocks.{i}", prefix),
                )
                for i in range(depth)
            ]
        )
        self.merger = Qwen2_5_VisionPatchMerger(
            dim=vision_config.out_hidden_size,
            context_dim=hidden_size,
            spatial_merge_size=spatial_merge_size,
            quant_config=quant_config,
            prefix=add_prefix("merger", prefix),
        )

    def get_window_index(self, grid_thw):
        """Compute the window-major ordering of merged tokens.

        Args:
            grid_thw: Iterable of (t, h, w) patch-grid sizes, one per item.

        Returns:
            ``(window_index, cu_window_seqlens)``: ``window_index`` is a 1-D
            tensor permuting merged-token indices (all items concatenated) so
            each attention window is contiguous; ``cu_window_seqlens`` is a list
            of cumulative window boundaries measured in *patches*. Windows that
            are entirely padding contribute repeated boundaries.
        """
        cu_window_seqlens: list = [0]
        window_index_id = 0
        # Window side length measured in merged tokens.
        vit_merger_window_size = (
            self.window_size // self.spatial_merge_size // self.patch_size
        )
        window_index: list = []
        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.spatial_merge_size,
                grid_w // self.spatial_merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
                grid_t, llm_grid_h, llm_grid_w
            )
            # Pad each frame up to a whole number of windows; -100 marks padding.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            cu_seqlens_tmp = (
                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            )
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            # int() accepts both 0-d tensors and plain ints.
            window_index_id += int(grid_t * llm_grid_h * llm_grid_w)
        window_index = torch.cat(window_index, dim=0)
        return window_index, cu_window_seqlens

    @property
    def dtype(self) -> torch.dtype:
        """Parameter dtype of the encoder (taken from the patch embedding)."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the encoder (taken from the first block's MLP)."""
        return self.blocks[0].mlp.gate_proj.weight.device

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Return per-patch 2D rotary angles (row and column) for ``grid_thw``.

        Patches are enumerated in the same spatial-merge-grouped order as the
        patch embedding output, repeated ``t`` times per item.
        """
        pos_ids = []
        for i in range(grid_thw.size(0)):
            t, h, w = grid_thw[i].tolist()
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()

            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Encode flattened patches into merged visual embeddings.

        Args:
            x: Flattened patch pixels, one row per patch (layout as expected by
                ``Qwen2_5_VisionPatchEmbed``).
            grid_thw: ``[num_items, 3]`` tensor of (t, h, w) patch-grid sizes.

        Returns:
            One row per merged token, in the original (non-windowed) order.
        """
        # patchify
        x = x.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)

        # compute position embedding (rotary angles per patch)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=x.device,
            dtype=torch.int32,
        )
        # Collapse repeated boundaries produced by all-padding windows.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Move window_index to the same device as x before using it to index x
        window_index = window_index.to(device=x.device)

        # Keep the rotary *angles* in float32. They grow with the grid size, and
        # rounding them to bf16/fp16 before cos/sin shifts large positions by a
        # sizeable fraction of a radian. Only cos/sin are cast to x.dtype below.
        rotary_pos_emb = rotary_pos_emb.to(device=x.device, dtype=torch.float32)

        # Reorder tokens (in groups of spatial_merge_unit) so each attention
        # window is contiguous; the rotary angles follow the same permutation.
        seq_len, _ = x.size()
        x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        x = x[window_index, :, :]
        x = x.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
        )
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (
            emb.cos().to(device=x.device, dtype=x.dtype),
            emb.sin().to(device=x.device, dtype=x.dtype),
        )

        # Full-attention boundaries: one segment per temporal frame (h*w patches,
        # repeated t times per item), as in the HF reference implementation.
        # get_window_index keeps each frame's windows contiguous, so these
        # boundaries line up with the reordered tokens. A single leading 0 is
        # added (the previous code prepended it twice).
        frame_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        )
        cu_seqlens = F.pad(
            frame_seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0), "constant", 0
        ).to(device=x.device)

        # transformers: blocks expect [S, B, H]; use B = 1.
        x = x.unsqueeze(1)
        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            x = blk(
                x, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings
            )

        # adapter
        x = self.merger(x)

        # Undo the window reordering (one row per merged token).
        reverse_indices = torch.argsort(window_index)
        x = x[reverse_indices, :]

        return x


# Memoised wrapper around get_processor (LRU cache, 128 entries, untyped keys),
# so repeated calls with the same arguments reuse the processor object.
cached_get_processor = lru_cache(maxsize=128)(get_processor)


class Qwen2_5_VLForConditionalGeneration(nn.Module):
    """Qwen2.5-VL: vision encoder (``visual``) feeding a Qwen2 language model.

    Multimodal embeddings are merged into the text embeddings by
    ``general_mm_embed_routine``; the output is either logits or pooled
    embeddings depending on ``get_embedding``.
    """

    # BitsAndBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: Qwen2_5_VLConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.visual = Qwen2_5_VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            # NOTE: Qwen2_5-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
            quant_config=quant_config,
            prefix=add_prefix("visual", prefix),
        )

        self.model = Qwen2Model(
            config,
            quant_config,
            prefix=add_prefix("model", prefix),
        )

        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        # rope_scaling may be missing or None; treat that as "mrope disabled"
        # instead of failing with a TypeError on the membership test.
        rope_scaling = getattr(self.config, "rope_scaling", None) or {}
        self.is_mrope_enabled = "mrope_section" in rope_scaling

        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        """Pad ``input_ids`` for multimodal tokens using the shared padding pattern."""
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Encode the image items of a batch with the vision encoder."""
        # in qwen-vl, last dim is the same
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
            self.visual.dtype
        )
        image_grid_thw = torch.cat([item.image_grid_thw for item in items], dim=0)
        assert pixel_values.dim() == 2, pixel_values.dim()
        assert image_grid_thw.dim() == 2, image_grid_thw.dim()
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        return image_embeds

    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Encode the video items of a batch with the vision encoder."""
        # in qwen-vl, last dim is the same
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
            self.visual.dtype
        )
        video_grid_thw = torch.cat([item.video_grid_thw for item in items], dim=0)
        assert pixel_values.dim() == 2, pixel_values.dim()
        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
        video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
        return video_embeds

    def get_input_embeddings(self):
        """Return the language model's token embedding module."""
        return self.model.embed_tokens

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ):
        """Run forward pass for Qwen2_5-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch. When mrope is enabled (the default for the open-source
                Qwen2-VL models), this argument is replaced by
                ``forward_batch.mrope_positions``, which has shape
                ``(3, seq_len)``; otherwise it has shape ``(seq_len,)``.
            forward_batch: Batch metadata.
            get_embedding: Return pooled embeddings instead of logits.
        """
        if self.is_mrope_enabled:
            positions = forward_batch.mrope_positions

        if (
            not forward_batch.forward_mode.is_decode()
            and forward_batch.contains_image_inputs()
        ):
            if self.is_mrope_enabled:
                assert positions.ndim == 2 and positions.size(0) == 3, (
                    "multimodal section rotary embedding requires "
                    f"(3, seq_len) positions, but got {positions.size()}"
                )

        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.model,
            multimodal_model=self,
            positions=positions,
        )

        if not get_embedding:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )
        else:
            return self.pooler(hidden_states, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``(name, tensor)`` pairs into this model.

        Non-visual ``q/k/v_proj`` and ``gate/up_proj`` shards are routed into
        the fused ``qkv_proj`` / ``gate_up_proj`` parameters. Visual weights are
        loaded one-to-one, except that ``attn.qkv.`` is renamed to
        ``attn.qkv_proj.`` to match ``VisionAttention``.

        Raises:
            KeyError: If a non-skipped checkpoint tensor has no matching
                parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            ("gate_up_proj", "up_proj", 1),
            ("gate_up_proj", "gate_proj", 0),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Visual weights are never fused by this mapping; they are
                # handled by the else-branch below.
                if "visual" in name:
                    continue
                name = name.replace(weight_name, param_name)

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "visual" in name:
                    # adapt to VisionAttention
                    name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                try:
                    param = params_dict[name]
                except KeyError:
                    # Report the offending name instead of dumping every
                    # parameter name to stdout.
                    logger.error(
                        "Checkpoint weight %r has no matching parameter in %s",
                        name,
                        type(self).__name__,
                    )
                    raise

                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


# Model class(es) exported by this module. Presumably collected by SGLang's
# model loader/registry to map the HF architecture name to this class — verify
# against the loader.
EntryClass = [Qwen2_5_VLForConditionalGeneration]