llava.py 31.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping, Sequence
6
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
7
                    Union, cast)
8
9

import torch
10
import torch.nn as nn
11
12
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
13
                          SiglipVisionConfig)
14
from transformers.models.llava import LlavaProcessor
15
from transformers.models.pixtral import PixtralProcessor
16

17
from vllm.config import VllmConfig
18
from vllm.inputs import InputProcessingContext
19
from vllm.model_executor.layers.activation import get_act_fn
20
21
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
22
from vllm.model_executor.layers.quantization import QuantizationConfig
23
from vllm.multimodal import MULTIMODAL_REGISTRY
24
from vllm.multimodal.cache import BaseMultiModalProcessorCache
25
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
26
27
                                    MultiModalInputs, MultiModalKwargsItems,
                                    MultiModalUUIDDict)
28
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
29
                                   ImageSize, MultiModalDataItems)
30
from vllm.multimodal.processing import (BaseMultiModalProcessor,
31
32
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate, PromptUpdateDetails)
33
from vllm.multimodal.profiling import BaseDummyInputsBuilder
34
from vllm.sequence import IntermediateTensors
35
from vllm.utils.jsontree import json_map_leaves
36
from vllm.utils.tensor_schema import TensorSchema, TensorShape
37

38
from .clip import CLIPVisionModel
39
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
40
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
41
from .siglip import SiglipVisionModel
42
43
44
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)
45
from .vision import get_vision_encoder_info
46
47


48
class LlavaImagePixelInputs(TensorSchema):
    """
    Pixel inputs for LLaVA models with a fixed-resolution vision encoder.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width

    All images are concatenated into a single batched tensor. `h` and `w`
    are expected to equal the vision encoder's configured image size:
    `LlavaForConditionalGeneration._parse_and_validate_image_input` binds
    both of them to `vision_config.image_size`. Variable-size images are
    only supported through `PixtralHFImagePixelInputs`.
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
61

62

63
class PixtralHFImagePixelInputs(TensorSchema):
    """
    Pixel inputs for Pixtral vision encoders loaded in HF format.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels
        - h: Height (may differ between images)
        - w: Width (may differ between images)

    Because `h` and `w` are dynamic, images of different sizes cannot be
    stacked. In that case `pixel_values` is a list of per-image tensors
    instead of one batched tensor.
    """

    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
    pixel_values: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"}),
    ]
78

79

80
class LlavaImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings, used as-is without running the vision
    tower or the projector.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
89
90


91
92
# Every image input form accepted by `LlavaForConditionalGeneration`.
LlavaImageInputs = Union[
    LlavaImagePixelInputs,
    PixtralHFImagePixelInputs,
    LlavaImageEmbeddingInputs,
]
93
94


95
96
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision features to the text hidden size.

    Computes ``linear_2(act(linear_1(x)))``. ``linear_1`` goes from
    ``vision_hidden_size`` to ``text_hidden_size``. ``linear_2`` keeps
    ``text_hidden_size``.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # Each parallel linear layer returns a 2-tuple; only the first
        # element (the layer output) is used here.
        projected, _ = self.linear_1(image_features)
        activated = self.act(projected)
        output, _ = self.linear_2(activated)
        return output


125
126
class LlavaLikeConfig(Protocol):
    """Structural type for HF configs that expose LLaVA-style vision fields.

    Any config object with these attributes (e.g. `LlavaConfig`) can be
    passed to the LLaVA helpers in this module.
    """

    vision_config: Final[PretrainedConfig]
    image_token_index: Final[int]
    vision_feature_select_strategy: Final[str]
    # A single encoder layer index, or a list of them.
    vision_feature_layer: Final[Union[int, list[int]]]
130

131

132
133
134
135
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors that expose an image placeholder."""

    # Placeholder string; the dummy-input builder repeats it once per image.
    image_token: Final[str]


136
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by LLaVA-style models.

    Subclasses must implement :meth:`get_hf_processor`. Everything else is
    derived from the HF config and the vision encoder info.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        return get_vision_encoder_info(self.get_hf_config())

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Only images are supported. `None` sets no explicit per-prompt limit
        # here (presumably "unlimited"; confirm against BaseProcessingInfo).
        return {"image": None}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Return how many encoder tokens remain after feature selection.

        ``"default"`` drops exactly one token and ``"full"`` keeps all of
        them.

        Raises:
            NotImplementedError: For any other strategy.
        """
        token_adjustment = {"default": -1, "full": 0}
        if strategy not in token_adjustment:
            msg = f"Unexpected feature select strategy: {strategy!r}"
            raise NotImplementedError(msg)
        return encoder_num_image_tokens + token_adjustment[strategy]

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens an image of this size yields."""
        config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        strategy = config.vision_feature_select_strategy
        encoder_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return self._apply_feature_select_strategy(strategy, encoder_tokens)

    def get_image_size_with_most_features(self) -> ImageSize:
        # Both sides are set to the encoder's configured image size.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        widest, tallest = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=widest,
            image_height=tallest,
        )

194
195
196
197
198
199

# Type variable for processing-info classes derived from
# BaseLlavaProcessingInfo.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy text and image inputs for LLaVA-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # One placeholder token per requested image.
        image_count = mm_counts.get("image", 0)
        placeholder = self.info.get_hf_processor().image_token
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        image_count = mm_counts.get("image", 0)

        # Use the image size that yields the most features.
        width, height = self.info.get_image_size_with_most_features()
        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)
        return {"image": dummy_images}


226
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for standard LLaVA checkpoints (`LlavaProcessor`)."""

    def get_hf_processor(self, **kwargs: object):
        processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)

        # Some checkpoints omit `patch_size` from `processor_config.json`
        # (e.g. E5-V: https://huggingface.co/royokong/e5-v). In that case,
        # fall back to the vision encoder's patch size.
        if processor.patch_size is None:
            encoder_info = self.get_vision_encoder_info()
            processor.patch_size = encoder_info.get_patch_size()

        return processor
236
237


238
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Shared multimodal processor logic for LLaVA-style models.

    Each image placeholder token in the prompt is replaced by one copy of
    the token per feature the image contributes.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            # Precomputed embeddings already know their feature count.
            if isinstance(items, ImageEmbeddingItems):
                return [image_token_id] * items.get_feature_size(item_idx)

            # Raw images: derive the count from the image size.
            size = items.get_image_size(item_idx)
            token_count = self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
            )
            return [image_token_id] * token_count

        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [image_replacement]


282
283
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for standard LLaVA checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Pixel values and precomputed embeddings are both batched per image.
        return {
            field_name: MultiModalFieldConfig.batched("image")
            for field_name in ("pixel_values", "image_embeds")
        }


296
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for Pixtral checkpoints in HF format."""

    def get_hf_processor(self, **kwargs: object):
        processor_cls = PixtralProcessor
        return self.ctx.get_hf_processor(processor_cls, **kwargs)
300

301

302
303
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for Pixtral checkpoints in HF format.

    Each image expands into ``nrows`` rows of ``ncols`` image tokens. Every
    row ends with a break token, except the last, which ends with an
    end-of-image token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        pixel_values = outputs.get("pixel_values")
        if pixel_values is None:
            return outputs

        # Crop each image back to its own size instead of keeping the shared
        # padding. Each image's output must not depend on the other images,
        # or the cache will not work correctly.
        image_sizes = outputs["image_sizes"]
        assert len(pixel_values) == len(image_sizes)

        outputs["pixel_values"] = [
            pixels[:, :height, :width]
            for pixels, (height, width) in zip(pixel_values, image_sizes)
        ]
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return {
            field_name: MultiModalFieldConfig.batched("image")
            for field_name in ("pixel_values", "image_embeds")
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            items = mm_items.get_items("image", ImageProcessorItems)
            size = items.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            # One row: `ncols` image tokens followed by a break token.
            row = [image_token_id] * ncols + [image_break_id]
            tokens = row * nrows
            # The final break is replaced by the end-of-image token.
            tokens[-1] = image_end_id

            # Select only the `image_token_id` positions, not break/end.
            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [image_replacement]

382

383
384
385
386
387
388
389
390
391
392
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
    """Choose the processing info class from the vision encoder config.

    Pixtral vision configs get `PixtralHFProcessingInfo`; every other config
    gets `LlavaProcessingInfo`.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    info_cls = (PixtralHFProcessingInfo
                if isinstance(vision_config, PixtralVisionConfig) else
                LlavaProcessingInfo)
    return info_cls(ctx)


393
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor:
    """Instantiate the multimodal processor that matches ``info``'s type.

    Raises:
        NotImplementedError: If ``info`` is neither a Pixtral-HF nor a LLaVA
            processing info instance.
    """
    # Checked in order with isinstance, so subclasses of LlavaProcessingInfo
    # also get the LLaVA processor.
    processor_by_info = (
        (PixtralHFProcessingInfo, PixtralHFMultiModalProcessor),
        (LlavaProcessingInfo, LlavaMultiModalProcessor),
    )
    for info_cls, processor_cls in processor_by_info:
        if isinstance(info, info_cls):
            return processor_cls(
                info,
                dummy_inputs,  # type: ignore
                cache=cache,
            )

    raise NotImplementedError(type(info))
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The number of encoder layers needed to compute the requested feature
        layer, or the deepest one when several are requested.

    Raises:
        TypeError: If ``vision_feature_layer`` is neither an int nor a
            list/tuple.
        ValueError: If ``vision_feature_layer`` is an empty list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # max() would raise an opaque "empty sequence" ValueError here.
            raise ValueError("vision_feature_layer must not be empty")
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
437
    """Given a signed vision feature layer, get the number of hidden layers
438
439
440
441
442
443
444
445
446
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
447
    return feature_layer_index
448
449
450
451
452
453
454


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Build the vision tower that matches ``hf_config.vision_config``.

    Only the encoder layers up to the deepest requested feature layer are
    built.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral vision config.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    tower_by_config = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )
    for config_cls, model_cls in tower_by_config:
        if isinstance(vision_config, config_cls):
            return model_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


491
492
493
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-style vision-language model.

    A vision tower (CLIP, SigLIP or Pixtral) encodes images, and
    `LlavaMultiModalProjector` maps them into the text embedding space. The
    result is merged into the language model's input embeddings at the
    positions of `config.image_token_index`.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Renamed prefixes used by checkpoints saved after
            # transformers v4.52.
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return "<image>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        text_config = config.text_config
        if (text_config.architectures is None
                and text_config.model_type == "mistral"):
            text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        if multimodal_config.get_limit_per_prompt("image"):
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = LlavaMultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )
        else:
            # The per-prompt image limit is falsy (e.g. 0), so the vision
            # modules are not built.
            self.vision_tower = None
            self.multi_modal_projector = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Pop image inputs from ``kwargs`` and wrap them in a schema.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds``
        is given. Pixel values take precedence when both are present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is None:
            # Only precomputed embeddings were supplied.
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            if self.config.vision_config.model_type == "pixtral":
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        vision_config = self.config.vision_config
        if vision_config.model_type == "pixtral":
            # Images may differ in size, so they are not concatenated.
            return PixtralHFImagePixelInputs(
                type="pixel_values_pixtral",
                pixel_values=flatten_bn(pixel_values),
            )

        # Fixed-resolution encoders: both sides must equal image_size.
        image_side = vision_config.image_size
        return LlavaImagePixelInputs(
            type="pixel_values",
            pixel_values=flatten_bn(pixel_values, concat=True),
            resolve_bindings={
                "h": image_side,
                "w": image_side
            },
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "full":
            return image_features
        if strategy == "default":
            # Drop the first position along dimension 1.
            return image_features[:, 1:]
        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Encode pixels and apply the configured feature select strategy
        to every tensor leaf of the tower output."""

        def select_features(leaf: torch.Tensor):
            return self._select_image_features(
                leaf,
                strategy=self.config.vision_feature_select_strategy,
            )

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        encoded = vision_tower(pixel_values)
        selected = json_map_leaves(select_features, encoded)
        return cast(Union[torch.Tensor, tuple[torch.Tensor, ...]], selected)

    def _process_image_pixels(
        self,
        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        assert self.vision_tower is not None
        return self._image_pixels_to_features(self.vision_tower,
                                              inputs["pixel_values"])

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Turn parsed image inputs into language-model-space embeddings."""
        if image_input["type"] == "image_embeds":
            # Precomputed embeddings are used as-is.
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        # A sequence of per-image feature tensors: project them in one
        # concatenated call, then split back into per-image chunks.
        chunk_sizes = [features.shape[0] for features in image_features]
        projected = self.multi_modal_projector(torch.cat(image_features))
        return torch.split(projected, chunk_sizes)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        # Replace the image placeholder positions with the image embeddings.
        return merge_multimodal_embeddings(
            input_ids,
            inputs_embeds,
            multimodal_embeddings,
            self.config.image_token_index,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        The `input_ids` given here already contain positions reserved for
        the image embeddings that will be inserted.

        For example, take the text prompt
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        The tokenizer produces:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in the KV cache, the input processor inserts extra
        image placeholder tokens (`32000`) before the model sees the prompt:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        It inserts 575 tokens. Together with the original image token, that
        gives 576 (24 * 24) image tokens. This matches the number of image
        tokens the visual encoder outputs and the language model consumes.

        As a result, `positions` and `attn_metadata` stay consistent with
        `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.

        Info:
            [`LlavaImageInputs`][vllm.model_executor.models.llava.LlavaImageInputs]
        """
        if intermediate_tensors is not None:
            # Intermediate tensors take precedence over any embeddings.
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this branch is for v0 compatibility.
            inputs_embeds = self.get_input_embeddings(
                input_ids, self.get_multimodal_embeddings(**kwargs))
            input_ids = None

        return self.language_model.model(input_ids,
                                         positions,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        # Checkpoint weights for vision modules that were never built are
        # skipped instead of being loaded.
        vision_modules_absent = (self.vision_tower is None
                                 and self.multi_modal_projector is None)
        skip_prefixes = (["vision_tower.", "multi_modal_projector."]
                         if vision_modules_absent else [])

        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
773
774


775
776
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis checkpoints.

    Builds a ``LlavaProcessor`` whose ``patch_size`` and
    ``vision_feature_select_strategy`` fall back to values derived from the
    vision encoder info and the HF config.
    """

    def get_hf_processor(self, **kwargs: object):
        """Return the HF ``LlavaProcessor`` for this model.

        Caller-supplied ``patch_size`` / ``vision_feature_select_strategy``
        take precedence. Missing ones are filled from the vision encoder info
        and the HF config, respectively.

        Args:
            **kwargs: Keyword arguments forwarded to
                ``self.ctx.get_hf_processor`` after defaults are filled in.
        """
        config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        fallbacks = {
            "patch_size": encoder_info.get_patch_size(),
            "vision_feature_select_strategy":
            config.vision_feature_select_strategy,
        }
        # setdefault never overrides a value the caller passed explicitly.
        for option, value in fallbacks.items():
            kwargs.setdefault(option, value)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
788
789


790
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """Multimodal processor for Mantis checkpoints.

    Runs the regular Llava processing pipeline first. It then rewrites each
    image's run of placeholder tokens into the Mantis-specific
    ``(image N: <Image>...</Image>)`` form and recomputes the placeholder
    ranges against the rewritten prompt.
    """

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """Process ``prompt`` and ``mm_data``, then apply Mantis image labels.

        Args:
            prompt: Text prompt or already-tokenized prompt ids.
            mm_data: Multimodal input data.
            hf_processor_mm_kwargs: Extra kwargs for the HF processor.
            tokenization_kwargs: Optional kwargs forwarded to the base
                ``apply``.
            mm_uuids: Optional per-item identifiers forwarded to the base
                ``apply``.

        Returns:
            A ``MultiModalInputs`` with ``mm_kwargs`` and ``mm_hashes`` taken
            from the base ``apply`` result. Its prompt, prompt token ids and
            placeholder ranges reflect the Mantis rewrite.
        """
        config = self.info.get_hf_config()
        image_token_id = config.image_token_index

        # The token count per image is queried with dummy (-1) dimensions,
        # i.e. it is assumed not to depend on the image size.
        tokens_per_image = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        base_inputs = super().apply(prompt,
                                    mm_data,
                                    hf_processor_mm_kwargs,
                                    tokenization_kwargs,
                                    mm_uuids=mm_uuids)

        items = self._to_mm_items(mm_data)
        item_counts = items.get_all_counts()
        mm_kwargs = base_inputs["mm_kwargs"]
        mm_hashes = base_inputs["mm_hashes"]

        # Reimplements the prompt formatting of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git: each image's placeholder
        # run is wrapped in a label carrying its 1-based index.
        def wrap_image(item_idx: int):
            return (f"(image {item_idx+1}: <Image>"  # noted upstream as 7 tokens
                    + "<image>" * tokens_per_image
                    + "</Image>)")  # noted upstream as 3 tokens

        mantis_updates = self._bind_and_group_updates(
            [
                PromptReplacement(
                    modality="image",
                    target=[image_token_id] * tokens_per_image,
                    replacement=wrap_image,
                ),
            ],
            item_counts,
        )

        new_prompt_ids, new_prompt, _ = self._apply_prompt_updates(
            base_inputs["prompt_token_ids"],
            mantis_updates,
        )

        # Locate the image placeholders inside the rewritten prompt using the
        # regular (non-Mantis) prompt updates. Then validate them against the
        # per-modality item counts.
        base_updates = self._get_mm_prompt_updates(
            items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        found = self._find_mm_placeholders(new_prompt_ids, base_updates)
        self._validate_mm_placeholders(found, item_counts)

        placeholder_ranges = {}
        for modality, placeholders in found.items():
            placeholder_ranges[modality] = [p.to_range() for p in placeholders]

        return MultiModalInputs(
            type="multimodal",
            prompt=new_prompt,
            prompt_token_ids=new_prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=placeholder_ranges,
        )
863
864
865
866


# To use this model, select this architecture explicitly with
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """Llava model variant for Mantis checkpoints.

    The model code is inherited unchanged from
    ``LlavaForConditionalGeneration``. Only the registered multimodal processor
    (``MantisMultiModalProcessor``) and processing info
    (``MantisProcessingInfo``) differ. Dummy inputs still come from
    ``LlavaDummyInputsBuilder``.
    """