lfm2_vl.py 29.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal

import torch
import torch.nn as nn
from transformers import BatchFeature
from transformers.activations import ACT2FN
from transformers.models.lfm2_vl import Lfm2VlProcessor
from transformers.models.lfm2_vl.configuration_lfm2_vl import Lfm2VlConfig
from transformers.models.lfm2_vl.image_processing_lfm2_vl_fast import (
    Lfm2VlImageProcessorFast,
    find_closest_aspect_ratio,
    round_by_factor,
)

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.mamba.mamba_utils import (
25
26
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
27
28
29
30
31
32
33
34
35
36
37
38
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
39
    BaseDummyInputsBuilder,
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (
    IsHybrid,
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
55
from .lfm2_siglip2 import Siglip2Model
56
57
58
59
60
61
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
)
62
from .vision import is_vit_use_data_parallel
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178


class Lfm2VLImagePixelInputs(TensorSchema):
    """
    Schema for the image tensors passed to the LFM2-VL vision path.

    Dimensions:
        - b: Number of images (`num_patches` holds one entry per image)
        - bn: Total number of entries along the first axis of `pixel_values`
          and `spatial_shapes`, i.e. one entry per tile across all images
          (`num_patches` is used elsewhere in this file to group these
          entries per image)
        - d: Length of the second axis of `pixel_values`; only the first
          `height * width` positions of each entry are read, the rest is
          treated as padding
        - fd: Size of the last axis of `pixel_values` (per-position features)
    """

    # Discriminator tag for this input variant.
    type: Literal["pixel_values"] = "pixel_values"
    # Per-tile patch sequences, padded along the second axis.
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", "d", "fd")]
    # Per-tile `(height, width)` pairs, in patch units.
    spatial_shapes: Annotated[torch.Tensor, TensorShape("bn", 2)]
    # Per-image number of first-axis entries in the two tensors above.
    num_patches: Annotated[torch.Tensor, TensorShape("b")]


# Pixel values are the only supported image input variant, so the alias is
# simply the schema class itself.
LFM2VLImageInputs = Lfm2VLImagePixelInputs


class Lfm2VLProcessingInfo(BaseProcessingInfo):
    """Image tiling and placeholder-token bookkeeping for LFM2-VL inputs."""

    def get_hf_config(self):
        """Return the HF config, requested from the context as `Lfm2VlConfig`."""
        return self.ctx.get_hf_config(Lfm2VlConfig)

    def get_hf_processor(self, **kwargs):
        """Return the `Lfm2VlProcessor`, forwarding `kwargs` to the context."""
        return self.ctx.get_hf_processor(Lfm2VlProcessor, **kwargs)

    def get_image_processor(self, **kwargs: object) -> Lfm2VlImageProcessorFast:
        """Return the image processor attached to the HF processor."""
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Declare `image` as the only modality; its limit is `None`."""
        return {"image": None}

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square image whose area matches the single-tile pixel budget.

        The budget is
        `max_image_tokens * encoder_patch_size**2 * downsample_factor**2`;
        the side length is the integer part of its square root.
        """
        processor = self.get_image_processor()
        max_image_tokens = processor.max_image_tokens
        encoder_patch_size = processor.encoder_patch_size
        downsample_factor = processor.downsample_factor
        max_pixels = max_image_tokens * (encoder_patch_size**2) * (downsample_factor**2)
        side = int(math.sqrt(max_pixels))
        return ImageSize(width=side, height=side)

    def _is_image_too_large(
        self,
        height: int,
        width: int,
        max_image_tokens: int,
        encoder_patch_size: int,
        downsample_factor: int,
        max_pixels_tolerance: float,
    ) -> bool:
        """Check if the image is too large to be processed as one tile.

        Both sides are first rounded to a multiple of
        `encoder_patch_size * downsample_factor` (and clamped to at least
        `encoder_patch_size`); the resulting area is then compared against
        the single-tile pixel budget scaled by `max_pixels_tolerance`.
        """
        total_factor = encoder_patch_size * downsample_factor

        h_bar = max(encoder_patch_size, round_by_factor(height, total_factor))
        w_bar = max(encoder_patch_size, round_by_factor(width, total_factor))
        return (
            h_bar * w_bar
            > max_image_tokens
            * encoder_patch_size**2
            * downsample_factor**2
            * max_pixels_tolerance
        )

    def smart_resize(
        self,
        height: int,
        width: int,
        downsample_factor: int,
        min_image_tokens: int,
        max_image_tokens: int,
        encoder_patch_size: int,
    ) -> tuple[int, int]:
        """Snap an image size to the patch grid within the pixel budget.

        Both sides become multiples of `encoder_patch_size * downsample_factor`.
        If the rounded area exceeds the maximum budget the image is scaled
        down (rounding down); if it is below the minimum budget it is scaled
        up (rounding up).

        Returns:
            `(width, height)` of the resized image. Note the order: width
            comes first even though `height` is the first parameter.
        """
        total_factor = encoder_patch_size * downsample_factor
        smart_resize_min_pixels = (
            min_image_tokens * encoder_patch_size**2 * downsample_factor**2
        )
        smart_resize_max_pixels = (
            max_image_tokens * encoder_patch_size**2 * downsample_factor**2
        )

        h_bar = max(total_factor, round_by_factor(height, total_factor))
        w_bar = max(total_factor, round_by_factor(width, total_factor))

        if h_bar * w_bar > smart_resize_max_pixels:
            beta = math.sqrt((height * width) / smart_resize_max_pixels)
            h_bar = max(
                total_factor, math.floor(height / beta / total_factor) * total_factor
            )
            w_bar = max(
                total_factor, math.floor(width / beta / total_factor) * total_factor
            )
        elif h_bar * w_bar < smart_resize_min_pixels:
            beta = math.sqrt(smart_resize_min_pixels / (height * width))
            h_bar = math.ceil(height * beta / total_factor) * total_factor
            w_bar = math.ceil(width * beta / total_factor) * total_factor

        return w_bar, h_bar

    def _target_ratios(self, min_tiles: int, max_tiles: int) -> list[tuple[int, int]]:
        """List the `(w, h)` grids with `min_tiles <= w * h <= max_tiles`.

        Duplicates are removed and the result is sorted by tile count.
        """
        # The construction order is kept as-is: it determines the relative
        # order of grids with equal tile counts after the stable sort below.
        ratios = [
            (w, h)
            for n in range(min_tiles, max_tiles + 1)
            for w in range(1, n + 1)
            for h in range(1, n + 1)
            if min_tiles <= w * h <= max_tiles
        ]
        return sorted(set(ratios), key=lambda x: x[0] * x[1])

    def _get_grid_layout(
        self,
        height: int,
        width: int,
        min_tiles: int,
        max_tiles: int,
        tile_size: int,
    ) -> tuple[int, int, int]:
        """Pick the tile grid whose aspect ratio best matches the image.

        Returns:
            `(grid_width, grid_height, grid_width * grid_height)`.
        """
        aspect_ratio = width / height
        target_ratios = self._target_ratios(min_tiles, max_tiles)
        # find best matching grid configuration
        grid_width, grid_height = find_closest_aspect_ratio(
            aspect_ratio, target_ratios, width, height, tile_size
        )
        total_patches = grid_width * grid_height
        return grid_width, grid_height, total_patches

    def _get_image_feature_grid_size(
        self,
        image_width: int,
        image_height: int,
        processor: Lfm2VlProcessor,
        mm_kwargs: Mapping[str, object],
    ) -> tuple[int, int, int]:
        """Compute the tile grid and total tile count for one image.

        Processor settings are read from the merged `mm_kwargs` first and
        fall back to the image processor's attributes.

        Returns:
            `(grid_width, grid_height, total_patches)`. When the grid has
            more than one tile, `total_patches` includes one extra entry for
            the thumbnail.
        """
        image_processor: Lfm2VlImageProcessorFast = processor.image_processor

        mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
        downsample_factor = mm_kwargs.get(
            "downsample_factor", image_processor.downsample_factor
        )
        encoder_patch_size = mm_kwargs.get(
            "encoder_patch_size", image_processor.encoder_patch_size
        )
        max_pixels_tolerance = mm_kwargs.get(
            "max_pixels_tolerance", image_processor.max_pixels_tolerance
        )
        min_tiles = mm_kwargs.get("min_tiles", image_processor.min_tiles)
        max_tiles = mm_kwargs.get("max_tiles", image_processor.max_tiles)
        max_image_tokens = mm_kwargs.get(
            "max_image_tokens", image_processor.max_image_tokens
        )
        tile_size = mm_kwargs.get("tile_size", image_processor.tile_size)

        # Splitting is disabled only when exactly one tile is allowed.
        do_image_splitting = not (min_tiles == max_tiles == 1)
        is_image_large = self._is_image_too_large(
            height=image_height,
            width=image_width,
            max_image_tokens=max_image_tokens,
            encoder_patch_size=encoder_patch_size,
            downsample_factor=downsample_factor,
            max_pixels_tolerance=max_pixels_tolerance,
        )

        # Big image will be cropped into patches and small images are just resized
        if is_image_large and do_image_splitting:
            grid_width, grid_height, total_patches = self._get_grid_layout(
                image_height,
                image_width,
                min_tiles=min_tiles,
                max_tiles=max_tiles,
                tile_size=tile_size,
            )
        else:
            grid_width = grid_height = total_patches = 1

        if grid_width * grid_height != 1:  # Thumbnail
            total_patches += 1

        return grid_width, grid_height, total_patches

    def get_num_patches(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Lfm2VlProcessor,
        mm_kwargs: Mapping[str, object],
    ) -> int:
        """Return the total tile count (thumbnail included) for one image."""
        _, _, total_patches = self._get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
            mm_kwargs=mm_kwargs,
        )
        return total_patches

    def get_image_repl(
        self,
        image_width: int,
        image_height: int,
        spatial_shapes: torch.Tensor,
        processor: Lfm2VlProcessor,
        mm_kwargs: Mapping[str, object],
    ) -> str:
        """Build the placeholder text that replaces one image token.

        For a multi-tile grid the text is, in order: the image start token,
        one `<|img_row_R_col_C|>` marker plus image tokens per tile
        (row-major, 1-based), an optional thumbnail token plus its image
        tokens, and the image end token. For a single tile it is just the
        image tokens between the start and end tokens.
        """
        grid_placeholder = "<|img_row_{n_h}_col_{n_w}|>"
        image_token = processor.image_token
        image_start_token = processor.image_start_token
        image_end_token = processor.image_end_token
        image_thumbnail_token = processor.image_thumbnail_token

        num_thumbnail_tokens, num_tokens_per_tile = self.get_num_image_tokens(
            spatial_shapes=spatial_shapes,
            processor=processor,
            mm_kwargs=mm_kwargs,
        )
        tile_img_placeholder = grid_placeholder + (image_token * num_tokens_per_tile)

        grid_w, grid_h, _ = self._get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
            mm_kwargs=mm_kwargs,
        )

        if grid_w > 1 or grid_h > 1:
            tiles_placeholder: list[str] = [
                tile_img_placeholder.format(n_h=i + 1, n_w=j + 1)
                for i in range(grid_h)
                for j in range(grid_w)
            ]

            if num_thumbnail_tokens > 0:
                tiles_placeholder.append(
                    image_thumbnail_token + (image_token * num_thumbnail_tokens)
                )
        else:
            tiles_placeholder = [image_token * num_thumbnail_tokens]

        placeholder = "".join(
            itertools.chain([image_start_token], tiles_placeholder, [image_end_token])
        )
        return placeholder

    def get_num_image_tokens(
        self,
        *,
        spatial_shapes: torch.Tensor,
        processor: Lfm2VlProcessor,
        mm_kwargs: Mapping[str, object],
    ) -> tuple[int, int]:
        """Return `(num_thumbnail_tokens, num_tokens_per_tile)` for one image.

        The thumbnail count is derived from the last row of `spatial_shapes`
        (its product divided by `downsample_factor**2`); the per-tile count
        is derived from `tile_size`, `encoder_patch_size` and
        `downsample_factor` only.
        """
        image_processor: Lfm2VlImageProcessorFast = processor.image_processor

        mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
        downsample_factor = mm_kwargs.get(
            "downsample_factor", image_processor.downsample_factor
        )
        encoder_patch_size = mm_kwargs.get(
            "encoder_patch_size", image_processor.encoder_patch_size
        )
        tile_size = mm_kwargs.get("tile_size", image_processor.tile_size)

        # Convert to a plain int so the result matches the declared return
        # type and can be used directly for `str * n` by callers.
        num_thumbnail_tokens = int(spatial_shapes[-1].prod()) // (downsample_factor**2)
        num_patches_tile = tile_size // encoder_patch_size
        dwn_num_patches_tile = math.ceil(num_patches_tile / downsample_factor)
        num_tiles_tokens = dwn_num_patches_tile * dwn_num_patches_tile

        return num_thumbnail_tokens, num_tiles_tokens


class Lfm2VLDummyInputsBuilder(BaseDummyInputsBuilder[Lfm2VLProcessingInfo]):
    """Builds dummy text and image inputs for LFM2-VL."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the image token repeated once per requested image."""
        num_images = mm_counts.get("image", 0)
        processor = self.info.get_hf_processor()
        image_token = processor.image_token
        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images sized by `get_image_size_with_most_features`.

        `seq_len` and `mm_processor_kwargs` are accepted for interface
        compatibility but are not read here. Any `"image"` entry in
        `mm_options` is forwarded as overrides to `_get_dummy_images`.
        """
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            ),
        }


class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]):
    """Multimodal processor wiring the HF LFM2-VL processor into vLLM."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and attach a per-image `num_patches` tensor.

        Prompts without images are tokenized directly with the tokenizer
        instead of going through the HF processor.
        """
        # Text-only input not supported in composite processor
        if not (images := mm_data.get("images", [])):
            prompt_ids = self.info.get_tokenizer().encode(
                prompt, add_special_tokens=False
            )
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt,
            mm_data,
            mm_kwargs,
            tok_kwargs,
        )

        # Re-parse the raw images only to obtain their sizes; the tile count
        # per image is recomputed from those sizes rather than read back
        # from the HF processor output.
        mm_items = self.info.parse_mm_data({"image": images}, validate=False)
        parsed_images = mm_items.get_items("image", ImageProcessorItems)
        image_sizes = [
            parsed_images.get_image_size(i) for i in range(len(parsed_images))
        ]
        hf_processor = self.info.get_hf_processor(**mm_kwargs)

        num_patches = [
            self.info.get_num_patches(
                image_width=size.width,
                image_height=size.height,
                processor=hf_processor,
                mm_kwargs=mm_kwargs,
            )
            for size in image_sizes
        ]
        processed_outputs["num_patches"] = torch.tensor(num_patches)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each output tensor is split across images.

        `pixel_values` and `spatial_shapes` are flat along their first axis
        and partitioned by `num_patches`; `num_patches` itself has one entry
        per image. `spatial_shapes` and `num_patches` are kept on CPU.
        """
        num_patches = hf_inputs.get("num_patches", torch.empty(0))

        return dict[str, MultiModalFieldConfig](
            pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
            spatial_shapes=MultiModalFieldConfig.flat_from_sizes(
                "image", num_patches, keep_on_cpu=True
            ),
            num_patches=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptReplacement]:
        """Replace each image token with the full per-image placeholder text."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = hf_processor.image_token

        def get_image_replacement_lfm2vl(item_idx: int):
            # Build the placeholder for image `item_idx` from its size and
            # the spatial shapes produced for it by the HF processor.
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)
            out_item = out_mm_kwargs["image"][item_idx]
            spatial_shapes = out_item["spatial_shapes"].data
            assert isinstance(spatial_shapes, torch.Tensor)
            image_repl = self.info.get_image_repl(
                image_width=image_size.width,
                image_height=image_size.height,
                spatial_shapes=spatial_shapes,
                processor=hf_processor,
                mm_kwargs=hf_processor_mm_kwargs,
            )
            # The replacement text is the whole placeholder; `image_token`
            # is passed as the text selected for embeddings.
            return PromptUpdateDetails.select_text(
                image_repl,
                embed_text=image_token,
            )

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_image_replacement_lfm2vl,
            )
        ]


class Lfm2VLMultiModalProjector(nn.Module):
    """Downsamples packed vision tokens and projects them to the text width.

    Each `factor x factor` neighbourhood of vision tokens is concatenated
    into one vector, optionally layer-normalized, then passed through a
    two-layer MLP.
    """

    def __init__(
        self,
        config: Lfm2VlConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.use_data_parallel = is_vit_use_data_parallel()

        in_channels = config.vision_config.hidden_size * (config.downsample_factor**2)
        self.factor = config.downsample_factor
        self.projector_use_layernorm = config.projector_use_layernorm
        if self.projector_use_layernorm:
            self.layer_norm = nn.LayerNorm(in_channels)
        self.linear_1 = nn.Linear(
            in_channels,
            config.projector_hidden_size,
            bias=config.projector_bias,
        )
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(
            config.projector_hidden_size,
            config.text_config.hidden_size,
            bias=config.projector_bias,
        )

    def forward(
        self,
        vision_features_packed: torch.Tensor,
        spatial_shapes: torch.Tensor,
    ) -> torch.Tensor:
        """Project packed vision features without materializing padded tensors.

        Args:
            vision_features_packed: (total_tokens, hidden_size) packed in tile order.
            spatial_shapes: (num_tiles, 2) on CPU (height, width) per tile.

        Returns:
            projected_packed: (total_projected_tokens, text_hidden_size)

        Raises:
            ValueError: If a non-empty tile's height or width is not
                divisible by the downsample factor.
        """
        assert spatial_shapes.device.type == "cpu", (
            "Expected `spatial_shapes` on CPU to avoid device-to-host sync in "
            "variable-length packing."
        )
        factor = self.factor
        device = vision_features_packed.device
        hidden_size = vision_features_packed.shape[-1]

        spatial_shapes_list: list[list[int]] = spatial_shapes.tolist()
        lengths_list = [h * w for h, w in spatial_shapes_list]

        gather_idx_parts: list[torch.Tensor] = []
        offset = 0

        # Row/column offsets inside one `factor x factor` block, flattened
        # in row-major order. They are the same for every tile.
        dh = torch.arange(factor, dtype=torch.int64)
        dw = torch.arange(factor, dtype=torch.int64)
        dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
        dh_flat = dh_grid.reshape(-1)
        dw_flat = dw_grid.reshape(-1)

        for (height, width), length in zip(spatial_shapes_list, lengths_list):
            if length <= 0:
                continue
            if height % factor != 0 or width % factor != 0:
                raise ValueError(
                    "spatial_shapes must be divisible by downsample_factor: "
                    f"got ({height}, {width}) with factor={factor}."
                )
            height_out = height // factor
            width_out = width // factor

            rows_out = torch.arange(height_out, dtype=torch.int64)
            cols_out = torch.arange(width_out, dtype=torch.int64)
            rr, cc = torch.meshgrid(rows_out, cols_out, indexing="ij")
            rr = rr.reshape(-1)
            cc = cc.reshape(-1)

            # For each output position (rr, cc), list the flat indices of
            # its `factor * factor` source tokens within this tile.
            token_idx = (rr[:, None] * factor + dh_flat[None, :]) * width + (
                cc[:, None] * factor + dw_flat[None, :]
            )
            # Shift by the tile's start position in the packed sequence.
            gather_idx_parts.append(token_idx.reshape(-1) + offset)
            offset += length

        if gather_idx_parts:
            gather_idx = torch.cat(gather_idx_parts).to(device=device)
            gathered = vision_features_packed.index_select(0, gather_idx)
            unshuffled = gathered.reshape(-1, factor * factor * hidden_size)
        else:
            unshuffled = vision_features_packed.new_empty(
                (0, factor * factor * hidden_size)
            )

        if self.projector_use_layernorm:
            unshuffled = self.layer_norm(unshuffled)
        hidden_states = self.linear_1(unshuffled)
        hidden_states = self.act(hidden_states)
        projected_packed = self.linear_2(hidden_states)
        return projected_packed
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613


@MULTIMODAL_REGISTRY.register_processor(
    Lfm2VLMultiModalProcessor,
    info=Lfm2VLProcessingInfo,
    dummy_inputs=Lfm2VLDummyInputsBuilder,
)
class Lfm2VLForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, IsHybrid
):
    """LFM2-VL: a SigLIP2 vision tower and projector feeding a language model."""

    merge_by_field_config = True

    # Maps HF checkpoint weight prefixes to this module's attribute layout.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image item.

        Raises:
            ValueError: If `modality` does not start with `"image"`.
        """
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, ...]:
        """Return the short-conv state dtype(s) for the given config."""
        return MambaStateDtypeCalculator.short_conv_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int]]:
        """Calculate shapes for LFM2's convolutional cache.

        Args:
            vllm_config: vLLM config

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
        """
        parallel_config = vllm_config.parallel_config
        hf_language_config = vllm_config.model_config.hf_config.text_config

        return MambaStateShapeCalculator.short_conv_state_shape(
            tp_world_size=parallel_config.tensor_parallel_size,
            intermediate_size=hf_language_config.hidden_size,
            conv_kernel=hf_language_config.conv_L_cache,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
        """Return the short-conv state copy function(s)."""
        return MambaStateCopyFuncCalculator.short_conv_state_copy_func()

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
        super().__init__()
        config: Lfm2VlConfig = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        vision_config = config.vision_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vllm_config = vllm_config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        with self._mark_tower_model(vllm_config, "image"):
            if vision_config.model_type == "siglip2_vision_model":
                self.vision_tower = Siglip2Model(
                    config=vision_config,
                    quant_config=quant_config,
                    prefix=maybe_prefix(prefix, "vision_tower"),
                )
            else:
                raise ValueError(
                    f"Unsupported visual tokenizer type: {vision_config.model_type}"
                )

            self.multi_modal_projector = Lfm2VLMultiModalProjector(
                config=config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        with self._mark_language_model(vllm_config):
            # NOTE(review): the prefix here is "language" while the attribute
            # and the weight mapper use "language_model" - confirm this is
            # intentional.
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language"),
                architectures=config.text_config.architectures,
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> LFM2VLImageInputs | None:
        """Build the image input schema from kwargs, or `None` without pixels."""
        pixel_values = kwargs.pop("pixel_values", None)
        spatial_shapes = kwargs.pop("spatial_shapes", None)
        num_patches = kwargs.pop("num_patches", None)
        if pixel_values is None:
            return None

        return LFM2VLImageInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            spatial_shapes=spatial_shapes,
            num_patches=num_patches,
        )

    def image_pixels_to_features(
        self,
        pixel_values: torch.FloatTensor,
        spatial_shapes: torch.Tensor,
    ) -> list[torch.Tensor]:
        """Encode padded tile pixels into projected per-tile features.

        The valid prefix of every tile is packed into one sequence, run
        through the vision tower and the projector, and the result is split
        back into one tensor per tile.

        Args:
            pixel_values: Padded per-tile patch sequences; only the first
                `height * width` positions of each tile are used.
            spatial_shapes: `(height, width)` per tile, on CPU.

        Returns:
            One projected feature tensor per row of `spatial_shapes`, or an
            empty list when no tile contains any token.

        Raises:
            ValueError: If a non-empty tile's height or width is not
                divisible by the projector's downsample factor.
        """
        assert spatial_shapes.device.type == "cpu", (
            "Expected `spatial_shapes` on CPU to avoid device-to-host sync in "
            "variable-length packing."
        )

        pixel_values = pixel_values.to(
            dtype=self.vision_tower.vision_model.embeddings.patch_embedding.weight.dtype
        )  # fp16 compatibility

        # LFM2-VL's HF processor pads patch sequences with trailing zeros.
        # Pack patch tokens upfront so the vision tower runs entirely unpadded.
        spatial_shapes_list: list[list[int]] = spatial_shapes.tolist()
        lengths_list = [h * w for h, w in spatial_shapes_list]
        total_tokens = int(sum(lengths_list))
        lengths_cpu = (spatial_shapes[:, 0] * spatial_shapes[:, 1]).to(
            dtype=torch.int32
        )
        max_seqlen = (
            lengths_cpu.max().reshape(1)
            if lengths_cpu.numel()
            else torch.tensor([0], dtype=torch.int32)
        )

        if total_tokens == 0:
            return []

        packed_pixel_values = pixel_values.new_empty(
            (total_tokens, pixel_values.shape[-1])
        )
        offset = 0
        for i, length in enumerate(lengths_list):
            if length <= 0:
                continue
            packed_pixel_values[offset : offset + length].copy_(
                pixel_values[i, :length]
            )
            offset += length
        packed_pixel_values = packed_pixel_values.unsqueeze(0)

        # Cumulative sequence boundaries of the packed tiles, starting at 0.
        lengths = torch.tensor(
            lengths_list, dtype=torch.int32, device=pixel_values.device
        )
        cu_seqlens = torch.zeros(
            lengths.shape[0] + 1,
            dtype=torch.int32,
            device=pixel_values.device,
        )
        cu_seqlens[1:] = torch.cumsum(lengths, dim=0)

        with set_forward_context(None, self.vllm_config):
            vision_outputs = self.vision_tower(
                pixel_values_packed=packed_pixel_values,
                spatial_shapes=spatial_shapes,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            )
        # Accept either an output object with `last_hidden_state` or the
        # hidden states themselves.
        image_outputs_packed = getattr(
            vision_outputs, "last_hidden_state", vision_outputs
        )
        vision_features_packed = image_outputs_packed[0]

        # Number of projected tokens each tile contributes after
        # downsampling; used below to split the packed projector output.
        factor = self.multi_modal_projector.factor
        projected_lengths_list: list[int] = []
        for (height, width), length in zip(spatial_shapes_list, lengths_list):
            if length <= 0:
                projected_lengths_list.append(0)
                continue
            if height % factor != 0 or width % factor != 0:
                raise ValueError(
                    "spatial_shapes must be divisible by downsample_factor: "
                    f"got ({height}, {width}) with factor={factor}."
                )
            projected_lengths_list.append((height // factor) * (width // factor))

        projected_packed = self.multi_modal_projector(
            vision_features_packed=vision_features_packed,
            spatial_shapes=spatial_shapes,
        )

        image_features: list[torch.Tensor] = []
        offset = 0
        for out_len in projected_lengths_list:
            image_features.append(projected_packed[offset : offset + out_len])
            offset += out_len

        return image_features

    def _process_image_input(
        self,
        image_input: LFM2VLImageInputs,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Return one concatenated feature tensor per image."""
        pixel_values = image_input["pixel_values"]
        spatial_shapes = image_input["spatial_shapes"]
        num_patches = image_input["num_patches"]

        image_features = self.image_pixels_to_features(
            pixel_values,
            spatial_shapes=spatial_shapes,
        )

        # Group patches by image - num_patches is on CPU (keep_on_cpu=True)
        # so .tolist() is instant with no DtoH sync
        num_patches_list = num_patches.tolist()
        batched_features: list[torch.Tensor] = []
        patch_idx = 0
        for count in num_patches_list:
            # Slice the list of patch tensors for this image
            image_patches = image_features[patch_idx : patch_idx + count]
            # Concatenate patches for this image
            batched_features.append(torch.cat(image_patches, dim=0))
            patch_idx += count

        return batched_features

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or an empty list without image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model; extra `kwargs` are ignored here."""
        # When intermediate tensors are supplied, input embeddings are
        # dropped before calling the language model.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, renaming HF prefixes via `hf_to_vllm_mapper`."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower",
        )