"vllm/vscode:/vscode.git/clone" did not exist on "d84d8f4429a5246a9d9f179b47fac7e13801710d"
llava.py 31.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping, Sequence
6
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
7
                    Union, cast)
8
9

import torch
10
import torch.nn as nn
11
12
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
13
                          SiglipVisionConfig)
14
from transformers.models.llava import LlavaProcessor
15
from transformers.models.pixtral import PixtralProcessor
16

17
from vllm.config import VllmConfig
18
from vllm.inputs import InputProcessingContext
19
from vllm.jsontree import json_map_leaves
20
from vllm.model_executor.layers.activation import get_act_fn
21
22
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
23
from vllm.model_executor.layers.quantization import QuantizationConfig
24
from vllm.model_executor.sampling_metadata import SamplingMetadata
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
27
                                    MultiModalInputs, MultiModalKwargs)
28
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
29
                                   ImageSize, MultiModalDataItems)
30
from vllm.multimodal.processing import (BaseMultiModalProcessor,
31
                                        BaseProcessingInfo, ProcessingCache,
32
33
                                        PromptReplacement, PromptUpdate,
                                        PromptUpdateDetails)
34
from vllm.multimodal.profiling import BaseDummyInputsBuilder
35
from vllm.sequence import IntermediateTensors
36
from vllm.utils.tensor_schema import TensorSchema, TensorShape
37

38
from .clip import CLIPVisionModel
39
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
40
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
41
from .siglip import SiglipVisionModel
42
43
44
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)
45
from .vision import get_vision_encoder_info
46
47


48
class LlavaImagePixelInputs(TensorSchema):
    """
    Pixel inputs for the standard (non-Pixtral) LLaVA vision towers.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width

    All images are expected at the vision tower's fixed input resolution,
    so the pixels are always concatenated into a single batched tensor
    (`LlavaForConditionalGeneration._parse_and_validate_image_input` binds
    `h` and `w` to `vision_config.image_size`).
    """
    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
61

62

63
class PixtralHFImagePixelInputs(TensorSchema):
    """
    Pixel inputs for Pixtral vision towers loaded from HF-format checkpoints.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels
        - h: Height
        - w: Width

    Pixtral keeps every image at its own resolution, so `h` and `w` may
    differ from image to image; when they do, `pixel_values` holds a list
    of per-image tensors instead of one batched tensor.
    """

    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
    pixel_values: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "c", "h", "w"),
    ]
77

78

79
class LlavaImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings, used as-is without running the vision
    tower or the projector.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[
        torch.Tensor,
        TensorShape("bn", "ifs", "hs"),
    ]
88
89


90
91
# Every image input form the model accepts; the `type` field of each member
# decides how the input is processed.
LlavaImageInputs = Union[
    LlavaImagePixelInputs,
    PixtralHFImagePixelInputs,
    LlavaImageEmbeddingInputs,
]
92
93


94
95
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision features into the text hidden space.

    Layout: ``linear_1`` (vision -> text, column parallel), activation,
    ``linear_2`` (text -> text, row parallel).
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # The attribute names are also used in the layer prefixes below and
        # presumably must match checkpoint weight names; keep them stable.
        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` to the language model hidden size."""
        # Each parallel linear layer returns a pair; only the first element
        # (the layer output) is used.
        projected, _unused = self.linear_1(image_features)
        activated = self.act(projected)
        output, _unused = self.linear_2(activated)
        return output


124
125
class LlavaLikeConfig(Protocol):
    """Structural type for HF configs shaped like ``LlavaConfig``.

    Lists the attributes the LLaVA helpers in this module rely on.
    """

    # Config of the vision encoder (e.g. CLIP, SigLIP or Pixtral; see
    # `init_vision_tower_for_llava`).
    vision_config: Final[PretrainedConfig]
    # Token id of the image placeholder in the prompt.
    image_token_index: Final[int]
    # Feature selection strategy: "default" or "full".
    vision_feature_select_strategy: Final[str]
    # Vision encoder layer(s) to take features from; negative values count
    # from the end.
    vision_feature_layer: Final[Union[int, list[int]]]
129

130

131
132
133
134
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors that expose an image placeholder."""

    # Placeholder text standing for one image in the prompt.
    image_token: Final[str]


135
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by LLaVA-style models.

    Subclasses provide :meth:`get_hf_processor`. Image token counts come
    from the vision encoder info, adjusted by the config's feature
    selection strategy.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # `None` means no fixed cap on the number of images.
        return {"image": None}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Adjust the encoder's token count for ``strategy``.

        Raises:
            NotImplementedError: If ``strategy`` is not "default" or "full".
        """
        if strategy not in ("default", "full"):
            msg = f"Unexpected feature select strategy: {strategy!r}"
            raise NotImplementedError(msg)

        # "default" discards the first feature; "full" keeps every feature.
        dropped = 1 if strategy == "default" else 0
        return encoder_num_image_tokens - dropped

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Number of placeholder tokens for an image of the given size."""
        strategy_source = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        raw_count = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return self._apply_feature_select_strategy(
            strategy_source.vision_feature_select_strategy,
            raw_count,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Square image at the vision encoder's native size."""
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        """Token count for the largest-feature image size."""
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=width,
                                         image_height=height)

193
194
195
196
197
198

# Type variable for processing-info classes derived from
# BaseLlavaProcessingInfo; lets the dummy-input builder and multimodal
# processors below be generic over the concrete info type.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompts and images for LLaVA-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder per requested image."""
        placeholder = self.info.get_hf_processor().image_token
        return placeholder * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return the requested number of dummy images at the largest size."""
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
        )
        return {"image": dummy_images}


225
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for LLaVA models served via the HF `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)

        # Some checkpoints leave `patch_size` out of `processor_config.json`
        # (e.g. E5-V: https://huggingface.co/royokong/e5-v); fall back to the
        # vision encoder's patch size then.
        if processor.patch_size is None:
            encoder_info = self.get_vision_encoder_info()
            processor.patch_size = encoder_info.get_patch_size()

        return processor
235
236


237
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Shared prompt-update logic for LLaVA-style multimodal processors.

    Every image placeholder token is expanded into one token per image
    feature.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        placeholder_id = self.info.get_hf_config().image_token_index

        def expand_placeholder(item_idx: int):
            items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            # Precomputed embeddings carry their own length; raw images are
            # sized through the processing info.
            if isinstance(items, ImageEmbeddingItems):
                feature_count = items.get_feature_size(item_idx)
            else:
                size = items.get_image_size(item_idx)
                feature_count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )

            return [placeholder_id] * feature_count

        image_replacement = PromptReplacement(
            modality="image",
            target=[placeholder_id],
            replacement=expand_placeholder,
        )
        return [image_replacement]


281
282
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for standard LLaVA checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Raw pixels and precomputed embeddings are both batched per image.
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }


295
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for Pixtral checkpoints in the HF LLaVA format."""

    def get_hf_processor(self, **kwargs: object):
        """Return the HF `PixtralProcessor`, forwarding ``kwargs``."""
        processor_cls = PixtralProcessor
        return self.ctx.get_hf_processor(processor_cls, **kwargs)
299

300

301
302
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for Pixtral checkpoints in the HF format.

    Images keep their own resolution. In the prompt, each image becomes rows
    of image tokens separated by break tokens and closed by an end token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        pixel_values = outputs.get("pixel_values")
        if pixel_values is None:
            return outputs

        # Crop away the processor's padding: each image's output must not
        # depend on the other images in the batch, otherwise the processing
        # cache would return wrong results.
        sizes = outputs["image_sizes"]
        assert len(pixel_values) == len(sizes)

        outputs["pixel_values"] = [
            image[:, :height, :width]
            for image, (height, width) in zip(pixel_values, sizes)
        ]
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        break_id = vocab[processor.image_break_token]
        placeholder_id = hf_config.image_token_index
        end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def layout_image_tokens(item_idx: int):
            items = mm_items.get_items("image", ImageProcessorItems)
            size = items.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            # A row is `ncols` image tokens plus a break token; the break
            # after the final row is swapped for the end token.
            single_row = [placeholder_id] * ncols + [break_id]
            layout = single_row * nrows
            layout[-1] = end_id

            # Presumably marks the image-token positions as the ones that
            # receive image embeddings (break/end tokens stay text).
            return PromptUpdateDetails.select_token_id(layout, placeholder_id)

        return [
            PromptReplacement(
                modality="image",
                target=[placeholder_id],
                replacement=layout_image_tokens,
            ),
        ]

381

382
383
384
385
386
387
388
389
390
391
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Pick the processing-info class matching the checkpoint's vision tower.

    Pixtral vision configs get :class:`PixtralHFProcessingInfo`; every other
    config gets :class:`LlavaProcessingInfo`.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    is_pixtral = isinstance(vision_config, PixtralVisionConfig)
    info_cls = PixtralHFProcessingInfo if is_pixtral else LlavaProcessingInfo
    return info_cls(ctx)


392
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
) -> BaseMultiModalProcessor:
    """Instantiate the multimodal processor that matches ``info``'s type.

    Raises:
        NotImplementedError: If ``info`` is neither Pixtral-HF nor LLaVA
            processing info.
    """
    # Pixtral-HF is checked first, then plain LLaVA (including subclasses).
    if isinstance(info, PixtralHFProcessingInfo):
        processor_cls = PixtralHFMultiModalProcessor
    elif isinstance(info, LlavaProcessingInfo):
        processor_cls = LlavaMultiModalProcessor
    else:
        raise NotImplementedError(type(info))

    return processor_cls(
        info,
        dummy_inputs,  # type: ignore
        cache=cache,
    )
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.
    
    Args:
        hf_config: Model config with vision feature layer(s).
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    elif isinstance(feature_layers, (list, tuple)):
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_layer_feature type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
436
    """Given a signed vision feature layer, get the number of hidden layers
437
438
439
440
441
442
443
444
445
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
446
    return feature_layer_index
447
448
449
450
451
452
453


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Build the vision tower that matches ``hf_config.vision_config``.

    The tower is built only up to the deepest layer referenced by
    ``hf_config.vision_feature_layer``.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral config.
    """
    vision_config = hf_config.vision_config

    # Build only as many layers as the deepest requested feature layer needs.
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    # Vision config class -> tower class, tried in this order.
    tower_classes = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )
    for config_cls, tower_cls in tower_classes:
        if isinstance(vision_config, config_cls):
            return tower_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


490
491
492
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-style vision-language model (also used for Pixtral-HF).

    The pipeline has three steps:

    * A vision tower (CLIP, SigLIP or Pixtral) encodes the images.
    * :class:`LlavaMultiModalProjector` maps the resulting features into
      the text hidden space.
    * The projected embeddings are merged into the text embeddings at the
      image placeholder positions before the language model runs.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for ``modality`` (images only).

        Raises:
            ValueError: For any modality that does not start with "image".
        """
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        # The vision tower and projector are only built when the per-prompt
        # image limit is truthy; otherwise both stay None and their weights
        # are skipped in `load_weights`.
        if multimodal_config.get_limit_per_prompt("image"):
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"))
            self.multi_modal_projector = LlavaMultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=config.text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"))
        else:
            self.vision_tower = None
            self.multi_modal_projector = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Pop image inputs from ``kwargs`` and wrap them in a schema.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds``
        is present. If both are present, ``pixel_values`` wins.

        Raises:
            ValueError: On inputs of an unexpected type, or when image
                embeddings are given for a Pixtral vision tower.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Pixtral images keep their own size, so they are not
            # concatenated into a single tensor.
            if self.config.vision_config.model_type == "pixtral":
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=flatten_bn(pixel_values),
                )

            # Other towers take square images at `vision_config.image_size`.
            expected_h = expected_w = self.config.vision_config.image_size
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=flatten_bn(pixel_values, concat=True),
                resolve_bindings={
                    "h": expected_h,
                    "w": expected_w
                },
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            if self.config.vision_config.model_type == "pixtral":
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the feature select strategy along dim 1.

        "default" drops the first feature, "full" keeps all of them.

        Raises:
            ValueError: For any other strategy.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Run the vision tower and apply the feature select strategy.

        The strategy is applied to every tensor leaf of the tower's output,
        so a single tensor and a sequence of tensors are both handled.
        """
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        def select_features(leaf: torch.Tensor):
            return self._select_image_features(
                leaf,
                strategy=self.config.vision_feature_select_strategy,
            )

        return cast(
            Union[torch.Tensor, tuple[torch.Tensor, ...]],
            json_map_leaves(select_features, image_features),
        )

    def _process_image_pixels(
        self,
        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Encode pixel inputs with the (required) vision tower."""
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Turn a parsed image input into language-model-space embeddings.

        Precomputed embeddings are returned unchanged. Pixel inputs are
        encoded and projected. When the encoder yields a sequence of
        per-image features, they are projected in one batch and split back
        into per-image tensors.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        feature_sizes = [
            image_feature.shape[0] for image_feature in image_features
        ]

        image_embeds = self.multi_modal_projector(torch.cat(image_features))
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings for ``kwargs``; ``[]`` when there are
        no image inputs."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and write image embeddings into the positions
        holding `config.image_token_index`."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if (multimodal_embeddings is not None
                and len(multimodal_embeddings) != 0):
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                self.config.image_token_index,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the forward pass for LLaVA-style models.

        One key thing to understand is that `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, the input processor expands each image
        token (denoted as `32000`) into one placeholder per image feature
        before the tokens reach the model, resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        For LLaVA-1.5, for example, there are 576 (24 * 24) image tokens in
        total: the original image token plus 575 inserted ones. That matches
        the number of image features the visual encoder outputs for the
        language model.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position ids matching `input_ids`.
            intermediate_tensors: Presumably hidden states handed over from
                a previous pipeline-parallel stage; when given,
                `inputs_embeds` is ignored.
            inputs_embeds: Precomputed input embeddings. When omitted (and
                `intermediate_tensors` is None), they are computed here from
                `input_ids` and the multimodal inputs in `kwargs`.
            **kwargs: Multimodal inputs such as `pixel_values` (the pixels in
                each input image) or `image_embeds`.

        Info:
            [LlavaImageInputs][]
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping HF names via
        `hf_to_vllm_mapper`; returns what the loader reports as loaded."""
        # When the vision tower and projector were not built (image limit of
        # zero), their checkpoint weights have nowhere to go.
        skip_prefixes = []
        if self.vision_tower is None and self.multi_modal_projector is None:
            skip_prefixes.extend(["vision_tower.", "multi_modal_projector."])

        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
772
773


774
775
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis checkpoints.

    Mantis is served through the stock ``LlavaProcessor``. This class only
    supplies processor arguments taken from the HF config and the vision
    encoder info.
    """

    def get_hf_processor(self, **kwargs: object):
        """Build a ``LlavaProcessor`` with Mantis-specific defaults.

        ``patch_size`` defaults to the vision encoder's patch size, and
        ``vision_feature_select_strategy`` defaults to the value stored in
        the HF config. Values supplied by the caller for either key win.
        """
        hf_config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        fallbacks = {
            "patch_size":
            encoder_info.get_patch_size(),
            "vision_feature_select_strategy":
            hf_config.vision_feature_select_strategy,
        }
        for key, value in fallbacks.items():
            kwargs.setdefault(key, value)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)


class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """Multimodal processor for Mantis checkpoints.

    The regular LLaVA processing runs first. Each run of image placeholder
    tokens is then replaced with Mantis' labelled form
    ``(image N: <Image><image>...<image></Image>)``, and the placeholder
    ranges are recomputed against the rewritten prompt.
    """

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """Process ``prompt`` and ``mm_data`` into Mantis-style inputs.

        Args:
            prompt: Text prompt or its token ids.
            mm_data: Raw multimodal data for the prompt.
            hf_processor_mm_kwargs: Extra kwargs for the HF processor.
            tokenization_kwargs: Optional extra tokenization kwargs.
            return_mm_hashes: Forwarded to the parent ``apply``.

        Returns:
            ``MultiModalInputs`` whose prompt carries the Mantis image
            labels. ``mm_kwargs`` and ``mm_hashes`` come from the parent
            result, and the placeholder ranges are located in the rewritten
            token ids.
        """
        image_token_id = self.info.get_hf_config().image_token_index

        # The per-image token count is assumed not to depend on the image
        # size, hence the dummy dimensions.
        tokens_per_image = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        base = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                             tokenization_kwargs, return_mm_hashes)

        items = self._to_mm_items(mm_data)
        item_counts = items.get_all_counts()
        mm_kwargs = base["mm_kwargs"]
        mm_hashes = base["mm_hashes"]

        # Re-implements the prompt format of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def label_image(item_idx: int):
            prefix = f"(image {item_idx+1}: <Image>"  # 7 tokens
            suffix = "</Image>)"  # 3 tokens
            return prefix + "<image>" * tokens_per_image + suffix

        mantis_updates = self._bind_and_group_updates([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * tokens_per_image,
                replacement=label_image,
            )
        ])

        new_prompt_ids, new_prompt, _ = self._apply_prompt_updates(
            base["prompt_token_ids"],
            mantis_updates,
            item_counts,
        )

        # Locate placeholders with the updates from _get_prompt_updates
        # (not overridden in this class), searched in the rewritten ids.
        default_updates = self._bind_and_group_updates(
            self._get_prompt_updates(
                items,
                hf_processor_mm_kwargs,
                mm_kwargs,
            ))

        placeholders = self._find_mm_placeholders(
            default_updates,
            new_prompt_ids,
            item_counts,
        )
        self._validate_mm_placeholders(placeholders, item_counts)

        placeholder_ranges = {}
        for modality, found in placeholders.items():
            placeholder_ranges[modality] = [ph.to_range() for ph in found]

        return MultiModalInputs(
            type="multimodal",
            prompt=new_prompt,
            prompt_token_ids=new_prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=placeholder_ranges,
        )


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model class registered with Mantis prompt processing.

    Nothing is overridden here. The subclass exists so that
    ``MantisMultiModalProcessor`` (with ``MantisProcessingInfo``) can be
    registered for it.
    """