llava.py 31.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping, Sequence
6
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
7
                    Union, cast)
8
9

import torch
10
import torch.nn as nn
11
12
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
13
                          SiglipVisionConfig)
14
from transformers.models.llava import LlavaProcessor
15
from transformers.models.pixtral import PixtralProcessor
16

17
from vllm.config import VllmConfig
18
from vllm.inputs import InputProcessingContext
19
from vllm.model_executor.layers.activation import get_act_fn
20
21
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
22
from vllm.model_executor.layers.quantization import QuantizationConfig
23
from vllm.model_executor.sampling_metadata import SamplingMetadata
24
from vllm.multimodal import MULTIMODAL_REGISTRY
25
from vllm.multimodal.cache import BaseMultiModalProcessorCache
26
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
27
28
                                    MultiModalInputs, MultiModalKwargsItems,
                                    MultiModalUUIDDict)
29
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
30
                                   ImageSize, MultiModalDataItems)
31
from vllm.multimodal.processing import (BaseMultiModalProcessor,
32
33
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate, PromptUpdateDetails)
34
from vllm.multimodal.profiling import BaseDummyInputsBuilder
35
from vllm.sequence import IntermediateTensors
36
from vllm.utils.jsontree import json_map_leaves
37
from vllm.utils.tensor_schema import TensorSchema, TensorShape
38

39
from .clip import CLIPVisionModel
40
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
41
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
42
from .siglip import SiglipVisionModel
43
44
45
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)
46
from .vision import get_vision_encoder_info
47
48


49
class LlavaImagePixelInputs(TensorSchema):
    """
    Pixel-value inputs for the CLIP / SigLIP LLaVA vision towers.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width

    Unlike `PixtralHFImagePixelInputs`, `pixel_values` must be a single
    batched tensor, so every image shares the same `h` and `w`.
    `LlavaForConditionalGeneration` binds both to
    `vision_config.image_size` when it builds this schema.
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
62

63

64
class PixtralHFImagePixelInputs(TensorSchema):
    """
    Pixel-value inputs for the HF-format Pixtral vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels
        - h: Height
        - w: Width

    `h` and `w` are dynamic. Images may have different sizes, in which case
    `pixel_values` is a list of per-image tensors rather than one batched
    tensor.
    """

    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
    pixel_values: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"}),
    ]
79

80

81
class LlavaImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings, fed directly to the language model.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
90
91


92
93
# Any validated image-input schema accepted by the LLaVA model below.
LlavaImageInputs = Union[
    LlavaImagePixelInputs,
    PixtralHFImagePixelInputs,
    LlavaImageEmbeddingInputs,
]
94
95


96
97
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision features into the text hidden size.

    Computes ``linear_2(act(linear_1(x)))``. ``linear_1`` maps
    ``vision_hidden_size -> text_hidden_size``, and ``linear_2`` keeps
    ``text_hidden_size``.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # Sub-modules are registered in the order linear_1, act, linear_2.
        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` into the language-model hidden space."""
        # Each parallel linear layer returns a pair; only the first element
        # (the layer output) is used.
        projected, _ = self.linear_1(image_features)
        projected, _ = self.linear_2(self.act(projected))
        return projected


126
127
class LlavaLikeConfig(Protocol):
    """Structural type for the HF configs consumed by the LLaVA helpers."""

    # Config of the vision encoder (CLIP, SigLIP or Pixtral in this file).
    vision_config: Final[PretrainedConfig]
    # Token id of the image placeholder in the prompt.
    image_token_index: Final[int]
    # "default" or "full"; see `_apply_feature_select_strategy`.
    vision_feature_select_strategy: Final[str]
    # Signed vision-encoder layer index, or a list of them.
    vision_feature_layer: Final[Union[int, list[int]]]
131

132

133
134
135
136
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors exposing an image token string."""

    # Text form of the image placeholder (repeated by the dummy builder).
    image_token: Final[str]


137
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by the LLaVA-family multimodal processors."""

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        """Return encoder-specific info for the configured vision tower."""
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Only images are supported; no per-prompt count is fixed here.
        return {"image": None}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Adjust the encoder's token count for the feature-select strategy.

        "full" keeps every encoder token. "default" drops one token, which
        matches `_select_image_features` slicing off the first feature.

        Raises:
            NotImplementedError: For any other strategy.
        """
        if strategy == "full":
            return encoder_num_image_tokens
        if strategy == "default":
            return encoder_num_image_tokens - 1

        raise NotImplementedError(
            f"Unexpected feature select strategy: {strategy!r}")

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Number of placeholder tokens an image of the given size needs."""
        hf_config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        strategy = hf_config.vision_feature_select_strategy
        encoder_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return self._apply_feature_select_strategy(strategy, encoder_tokens)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Square image whose side is the encoder's native image size."""
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        """Token count for the image size that yields the most features."""
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=width,
                                         image_height=height)

195
196
197
198
199
200

# Processing-info type parameter for the generic LLaVA builders/processors.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy text and images for the LLaVA family."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        image_count = mm_counts.get("image", 0)
        placeholder = self.info.get_hf_processor().image_token
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return the requested number of largest-size dummy images."""
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)
        return {"image": dummy_images}


227
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for non-Pixtral LLaVA checkpoints."""

    def get_hf_processor(self, **kwargs: object):
        processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)

        # Some checkpoints omit `patch_size` from `processor_config.json`
        # (e.g. E5-V: https://huggingface.co/royokong/e5-v), so fall back to
        # the vision encoder's patch size.
        if processor.patch_size is None:
            encoder_info = self.get_vision_encoder_info()
            processor.patch_size = encoder_info.get_patch_size()

        return processor
237
238


239
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Shared prompt-update logic for LLaVA-style processors.

    Each image placeholder token is replaced by as many copies of itself as
    the corresponding image contributes features.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.info.get_hf_config().image_token_index

        def expand_image_token(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            # Precomputed embeddings already know their feature count.
            if isinstance(images, ImageEmbeddingItems):
                return [image_token_id] * images.get_feature_size(item_idx)

            size = images.get_image_size(item_idx)
            token_count = self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
            )
            return [image_token_id] * token_count

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=expand_image_token,
        )
        return [replacement]


283
284
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for non-Pixtral LLaVA checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Raw pixels and precomputed embeddings are both batched per image.
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }


297
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for LLaVA configs whose vision tower is Pixtral-HF."""

    def get_hf_processor(self, **kwargs: object):
        """Return the HF `PixtralProcessor`, forwarding ``kwargs``."""
        processor_cls = PixtralProcessor
        return self.ctx.get_hf_processor(processor_cls, **kwargs)
301

302

303
304
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for HF-format Pixtral checkpoints."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then crop each image back to its own size."""
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        pixel_values = outputs.get("pixel_values")
        if pixel_values is None:
            return outputs

        # Avoid padding: the output for each image must be independent of
        # the other images for the cache to work correctly.
        image_sizes = outputs["image_sizes"]
        assert len(pixel_values) == len(image_sizes)

        outputs["pixel_values"] = [
            pixels[:, :height, :width]
            for pixels, (height, width) in zip(pixel_values, image_sizes)
        ]
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each image token into a row/column grid of tokens.

        Each patch row becomes ``ncols`` image tokens followed by a break
        token. The final break token is replaced with the end-of-image
        token.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        break_id = vocab[processor.image_break_token]
        token_id = hf_config.image_token_index
        end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def build_image_tokens(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            size = images.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            row = [token_id] * ncols + [break_id]
            grid = row * nrows
            grid[-1] = end_id

            return PromptUpdateDetails.select_token_id(grid, token_id)

        return [
            PromptReplacement(
                modality="image",
                target=[token_id],
                replacement=build_image_tokens,
            ),
        ]

383

384
385
386
387
388
389
390
391
392
393
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
    """Pick the processing-info class that matches the vision tower type."""
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config

    info_cls = (PixtralHFProcessingInfo
                if isinstance(vision_config, PixtralVisionConfig) else
                LlavaProcessingInfo)
    return info_cls(ctx)


394
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor:
    """Create the multimodal processor that matches the type of ``info``.

    Raises:
        NotImplementedError: If ``info`` is neither Pixtral-HF nor LLaVA
            processing info.
    """
    # Checked in order; the first matching info type wins.
    processor_for_info = (
        (PixtralHFProcessingInfo, PixtralHFMultiModalProcessor),
        (LlavaProcessingInfo, LlavaMultiModalProcessor),
    )
    for info_cls, processor_cls in processor_for_info:
        if isinstance(info, info_cls):
            return processor_cls(
                info,
                dummy_inputs,  # type: ignore
                cache=cache,
            )

    raise NotImplementedError(type(info))
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The number of encoder layers needed to reach the deepest requested
        feature layer.

    Raises:
        ValueError: If `vision_feature_layer` is an empty list/tuple.
        TypeError: If `vision_feature_layer` is neither an int nor a
            list/tuple of ints.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        # `max()` of an empty sequence gives an opaque error; say what's wrong.
        if not feature_layers:
            raise ValueError("vision_feature_layer must not be empty")
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given a signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Non-negative indices are returned unchanged. Negative indices count
    from the end, so ``-1`` maps to ``num_hidden_layers``.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    return (feature_layer_index if feature_layer_index >= 0 else
            num_hidden_layers + feature_layer_index + 1)
449
450
451
452
453
454
455


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Build the vision tower matching ``hf_config.vision_config``.

    The tower is initialized only up to the deepest layer referenced by
    ``hf_config.vision_feature_layer``.

    Raises:
        NotImplementedError: If the vision config is not CLIP, SigLIP or
            Pixtral.
    """
    vision_config = hf_config.vision_config
    depth = _get_num_hidden_layers(hf_config)

    tower_for_config = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )
    for config_cls, tower_cls in tower_for_config:
        if isinstance(vision_config, config_cls):
            return tower_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=depth,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    raise NotImplementedError(
        f"Unsupported vision config: {type(vision_config)}")


492
493
494
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-family model: vision tower, multimodal projector and text LM.

    The vision tower and projector are only created when the multimodal
    config allows images. Otherwise both are ``None``, and their weights are
    skipped in `load_weights`.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        if multimodal_config.get_limit_per_prompt("image"):
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"))
            self.multi_modal_projector = LlavaMultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=config.text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"))
        else:
            self.vision_tower = None
            self.multi_modal_projector = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Pop ``pixel_values`` / ``image_embeds`` from ``kwargs`` and wrap
        them in the matching schema.

        Returns ``None`` when neither is present. If both are present,
        ``pixel_values`` takes precedence.

        Raises:
            ValueError: On an unexpected input type, or on ``image_embeds``
                for a Pixtral vision tower.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if self.config.vision_config.model_type == "pixtral":
                # Pixtral images may differ in size, so keep them unconcatenated.
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=flatten_bn(pixel_values),
                )

            expected_h = expected_w = self.config.vision_config.image_size
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=flatten_bn(pixel_values, concat=True),
                resolve_bindings={
                    "h": expected_h,
                    "w": expected_w
                },
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            if self.config.vision_config.model_type == "pixtral":
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        def select_features(leaf: torch.Tensor):
            return self._select_image_features(
                leaf,
                strategy=self.config.vision_feature_select_strategy,
            )

        # The tower may return a single tensor or a nested structure of
        # tensors; apply the feature selection to every tensor leaf.
        return cast(
            Union[torch.Tensor, tuple[torch.Tensor, ...]],
            json_map_leaves(select_features, image_features),
        )

    def _process_image_pixels(
        self,
        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Turn validated image inputs into language-model embeddings."""
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        # Both modules are created (or skipped) together in `__init__`.
        assert self.vision_tower is not None
        assert self.multi_modal_projector is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        # Per-image features of varying length: project them in one batch,
        # then split back into per-image tensors.
        feature_sizes = [
            image_feature.shape[0] for image_feature in image_features
        ]

        image_embeds = self.multi_modal_projector(torch.cat(image_features))
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and place any multimodal embeddings at the
        image-token positions."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if (multimodal_embeddings is not None
                and len(multimodal_embeddings) != 0):
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                self.config.image_token_index,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is that `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, the input processor expands the image
        token (denoted as `32000`) into repeated placeholder tokens before
        the model is called, resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        For LLaVA-1.5 the single image token becomes 576 (24 * 24) image
        tokens. This matches the number of image tokens fed to the language
        model, i.e. the number of tokens output by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position ids matching `input_ids`.
            intermediate_tensors: Tensors from the previous pipeline-parallel
                stage. When given, `inputs_embeds` is ignored.
            inputs_embeds: Precomputed input embeddings. When `None` (and
                there are no intermediate tensors), they are computed here
                from `input_ids` and any image inputs.
            **kwargs: Image inputs, namely `pixel_values` (the pixels in each
                input image) or `image_embeds`. They are consumed by
                `_parse_and_validate_image_input`.

        Info:
            [LlavaImageInputs][]
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        # Vision weights are ignored when the vision modules were not built.
        skip_prefixes = []
        if self.vision_tower is None and self.multi_modal_projector is None:
            skip_prefixes.extend(["vision_tower.", "multi_modal_projector."])

        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
774
775


776
777
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis checkpoints.

    Mantis reuses the stock Hugging Face ``LlavaProcessor``; this class only
    fills in processor defaults taken from the vision encoder and HF config.
    """

    def get_hf_processor(self, **kwargs: object):
        """Return the HF ``LlavaProcessor`` configured for this model.

        ``patch_size`` defaults to the vision encoder's patch size and
        ``vision_feature_select_strategy`` defaults to the value in the HF
        config. Values supplied by the caller in ``kwargs`` take precedence.
        """
        config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        strategy = config.vision_feature_select_strategy
        defaults = {
            "patch_size": encoder_info.get_patch_size(),
            "vision_feature_select_strategy": strategy,
        }
        # setdefault keeps any caller-provided value.
        for key, value in defaults.items():
            kwargs.setdefault(key, value)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
789
790


791
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """Multi-modal processor for Mantis checkpoints.

    Runs the regular LLaVA processing via ``super().apply`` and then wraps
    every run of image tokens in the ``(image N: <Image>...</Image>)``
    markers used by Mantis. It then re-finds the image placeholders in the
    rewritten prompt.
    """

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """Process a prompt and its multi-modal data into Mantis inputs.

        Args:
            prompt: Text prompt or a list of prompt token ids.
            mm_data: Raw multi-modal data, passed to ``super().apply``.
            hf_processor_mm_kwargs: Passed to ``super().apply`` and to
                ``_get_mm_prompt_updates``.
            tokenization_kwargs: Passed to ``super().apply``.
            mm_uuids: Passed to ``super().apply``.

        Returns:
            ``MultiModalInputs`` built from the marker-wrapped prompt, the
            parent's ``mm_kwargs``/``mm_hashes``, and the image placeholder
            ranges located in the wrapped prompt.
        """
        config = self.info.get_hf_config()
        img_tok = config.image_token_index

        # Assume that it doesn't depend on the image size
        n_img_toks = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        base = super().apply(prompt,
                             mm_data,
                             hf_processor_mm_kwargs,
                             tokenization_kwargs,
                             mm_uuids=mm_uuids)

        items = self._to_mm_items(mm_data)
        counts = items.get_all_counts()
        processed_kwargs = base["mm_kwargs"]
        hashes = base["mm_hashes"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def _mantis_wrap(idx: int):
            header = f"(image {idx + 1}: <Image>"  # 7 tokens
            footer = "</Image>)"  # 3 tokens
            return header + "<image>" * n_img_toks + footer

        wrap_updates = self._bind_and_group_updates(
            [
                PromptReplacement(
                    modality="image",
                    target=[img_tok] * n_img_toks,
                    replacement=_mantis_wrap,
                ),
            ],
            counts,
        )

        new_ids, new_prompt, _ = self._apply_prompt_updates(
            base["prompt_token_ids"],
            wrap_updates,
        )

        # The markers shift the image tokens, so search the wrapped prompt
        # again using the parent's own prompt updates.
        base_updates = self._get_mm_prompt_updates(
            items,
            hf_processor_mm_kwargs,
            processed_kwargs,
        )
        found = self._find_mm_placeholders(new_ids, base_updates)
        self._validate_mm_placeholders(found, counts)

        ranges = {}
        for modality, placeholders in found.items():
            ranges[modality] = [p.to_range() for p in placeholders]

        return MultiModalInputs(
            type="multimodal",
            prompt=new_prompt,
            prompt_token_ids=new_ids,
            mm_kwargs=processed_kwargs,
            mm_hashes=hashes,
            mm_placeholders=ranges,
        )
864
865
866
867


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model class used for Mantis checkpoints.

    This subclass adds no code of its own. It exists so the Mantis
    processor, processing info and dummy-input builder passed to the
    decorator above are registered for this architecture name.
    """