llava.py 28.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping, Sequence
6
from typing import Annotated, Final, Literal, Protocol, TypeAlias, TypeVar
7
8

import torch
9
import torch.nn as nn
10
11
12
13
14
15
16
17
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    LlavaConfig,
    PixtralVisionConfig,
    PretrainedConfig,
    SiglipVisionConfig,
)
18
from transformers.models.llava import LlavaProcessor
19
from transformers.models.pixtral import PixtralProcessor
20

21
from vllm.config import VllmConfig
22
from vllm.config.multimodal import BaseDummyOptions
23
from vllm.model_executor.layers.activation import get_act_fn
24
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
25
from vllm.model_executor.layers.quantization import QuantizationConfig
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.multimodal.cache import BaseMultiModalProcessorCache
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
42
    BaseDummyInputsBuilder,
43
44
45
46
47
48
49
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
50
from vllm.sequence import IntermediateTensors
51
from vllm.utils.tensor_schema import TensorSchema, TensorShape
52

53
from .clip import CLIPVisionModel
54
55
from .interfaces import (
    MultiModalEmbeddings,
56
    SupportsEagle3,
57
58
59
60
61
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .module_mapping import MultiModelKeys
62
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
63
from .siglip import SiglipVisionModel
64
65
66
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
67
    get_layer_index,
68
69
70
    init_vllm_registered_model,
    maybe_prefix,
)
71
from .vision import get_num_selected_vision_tokens, get_vision_encoder_info
72
73


74
class LlavaImagePixelInputs(TensorSchema):
    """
    Raw pixel inputs for LLaVA vision towers other than Pixtral-HF.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width

    Unlike the Pixtral-HF schema, every image here has the same spatial size,
    so the data is a single batched tensor. `LlavaForConditionalGeneration`
    binds both `h` and `w` to `vision_config.image_size` when it builds this
    schema.
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
88

89

90
class PixtralHFImagePixelInputs(TensorSchema):
    """
    Pixel inputs for the HF-format Pixtral vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels
        - h: Height
        - w: Width

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    For that reason `h` and `w` are declared as dynamic dimensions.
    """

    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"}),
    ]
107

108

109
class LlavaImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings.

    These are returned as-is by `_process_image_input`, i.e. they skip both
    the vision tower and the multimodal projector.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
119
120


121
122
123
# Union of every image-input schema accepted by LlavaForConditionalGeneration;
# the `type` field of each member distinguishes them at runtime.
LlavaImageInputs: TypeAlias = (
    LlavaImagePixelInputs | PixtralHFImagePixelInputs | LlavaImageEmbeddingInputs
)
"""Alias for supported LLaVA image input types."""
125
126


127
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision-tower features to the text hidden size.

    The forward pass computes ``linear_2(act(linear_1(x)))``, where
    ``linear_1`` is column-parallel (vision hidden size -> text hidden size)
    and ``linear_2`` is row-parallel (text hidden size -> text hidden size).

    Args:
        vision_hidden_size: Feature size produced by the vision tower.
        text_hidden_size: Hidden size of the language model.
        projector_hidden_act: Activation name passed to `get_act_fn`.
        multimodal_projector_bias: Whether both linear layers use a bias.
        quant_config: Optional quantization config forwarded to both layers.
        prefix: Parameter-name prefix; the layers are registered under
            ``{prefix}.linear_1`` and ``{prefix}.linear_2``.
    """

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project `image_features` into the language model's hidden space.

        Presumably `image_features` has `vision_hidden_size` features in its
        last dimension (not checked here).
        """
        # The parallel linear layers return a 2-tuple; only the first
        # element (the layer output) is used.
        hidden_states, _ = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


162
163
class LlavaLikeConfig(Protocol):
    """Structural type for the HF config attributes read by LLaVA helpers.

    Any config exposing these attributes (e.g. `LlavaConfig`) can be passed
    to the processing-info classes and vision-tower builders in this module.
    """

    # Config of the vision tower (CLIP, SigLIP or Pixtral in this module).
    vision_config: Final[PretrainedConfig]
    # Token id used as the image placeholder in the prompt.
    image_token_index: Final[int]
    # Strategy forwarded to the vision tower / token-count helpers.
    vision_feature_select_strategy: Final[str]
    # Layer index (or indices) of the vision tower whose features are used.
    vision_feature_layer: Final[int | list[int]]
167

168

169
170
171
172
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors used with LLaVA-style models.

    Only the single attribute that this module reads is declared here.
    """

    # Text placeholder marking where an image goes in the prompt; repeated
    # once per image when building dummy prompts.
    image_token: Final[str]


173
174
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Shared processing info for LLaVA-style models.

    Subclasses only need to supply `get_hf_processor`. Image-token counts
    are derived from the vision-encoder helper built from the HF config.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        """Return the model's HF config, fetched as a `LlavaConfig`."""
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        """Return the vision-encoder helper for this model's HF config."""
        return get_vision_encoder_info(self.get_hf_config())

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        """Return the HF processor; must be implemented by subclasses."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Only images are supported.

        The `None` limit presumably means "no fixed per-prompt maximum".
        """
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of placeholder tokens one image expands to.

        The vision encoder's raw token count for the given size is adjusted
        according to the config's `vision_feature_select_strategy`.
        """
        hf_config = self.get_hf_config()
        vision_encoder_info = self.get_vision_encoder_info()

        encoder_tokens = vision_encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return get_num_selected_vision_tokens(
            encoder_tokens,
            hf_config.vision_feature_select_strategy,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square image of the vision encoder's native size."""
        vision_encoder_info = self.get_vision_encoder_info()

        width = height = vision_encoder_info.get_image_size()
        return ImageSize(width=width, height=height)

    def get_max_image_tokens(self) -> int:
        """Return the token count for the size with the most features."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )

217
218
219
220
221

# Generic parameter for builders/processors that work with any
# BaseLlavaProcessingInfo subclass (LLaVA, Pixtral-HF, Mantis, ...).
_I = TypeVar(
    "_I",
    bound=BaseLlavaProcessingInfo,
)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompts and images for LLaVA-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's image token repeated once per image."""
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        """Return `mm_counts["image"]` dummy images at the largest size.

        The image size comes from `get_image_size_with_most_features`.
        Optional per-modality overrides are read from `mm_options["image"]`.
        `seq_len` and `mm_processor_kwargs` are not used here.
        """
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


253
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for standard (non-Pixtral) LLaVA checkpoints."""

    def get_hf_processor(self, **kwargs: object):
        """Return a `LlavaProcessor`, filling in `patch_size` if missing."""
        hf_processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
        # In case patch_size is omitted from `processor_config.json`
        # e.g. for E5-V: https://huggingface.co/royokong/e5-v
        if hf_processor.patch_size is None:
            patch_size = self.get_vision_encoder_info().get_patch_size()
            hf_processor.patch_size = patch_size
        return hf_processor
262
263


264
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor shared by LLaVA-style models.

    Each image placeholder token in the prompt is expanded into as many
    copies of the image token as the image produces features.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each single image token with the per-image token run.

        For precomputed embeddings the run length is the embedding's feature
        size; for raw images it is computed from the image's width/height.
        `hf_processor_mm_kwargs` and `out_mm_kwargs` are not used here.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


308
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for standard LLaVA checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Both pixel values and image embeddings are batched per image."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )


320
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for HF-format Pixtral checkpoints."""

    def get_hf_processor(self, **kwargs: object):
        """Return the HF `PixtralProcessor` built with `kwargs`."""
        return self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
323

324

325
class PixtralHFMultiModalProcessor(BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for HF-format Pixtral checkpoints.

    Pixtral images keep their own size, so pixel values are un-padded per
    image. Each image is expanded into a grid of image tokens with row-break
    tokens and a final image-end token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then crop each image back to its own size."""
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is not None:
            # Avoid padding since we need the output for each image to be
            # independent of other images for the cache to work correctly
            image_sizes = processed_outputs["image_sizes"]
            assert len(pixel_values) == len(image_sizes)

            # Keep only the top-left (h, w) region of each image, yielding a
            # list of differently-sized tensors.
            processed_outputs["pixel_values"] = [
                p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
            ]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Both pixel values and image embeddings are batched per image."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each image token into the Pixtral patch-grid token layout.

        Each of the `nrows` rows holds `ncols` image tokens followed by a
        break token; the last break token is replaced by the end token.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )

            tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
            tokens[-1] = image_end_id

            # Presumably marks only the `image_token_id` positions as slots
            # that receive image embeddings (break/end tokens stay as text).
            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]

403

404
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Pick the processing-info class based on the vision config type.

    Returns `PixtralHFProcessingInfo` for a Pixtral vision config and
    `LlavaProcessingInfo` otherwise.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)

    if isinstance(hf_config.vision_config, PixtralVisionConfig):
        return PixtralHFProcessingInfo(ctx)

    return LlavaProcessingInfo(ctx)


415
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor:
    """Build the multimodal processor matching the type of `info`.

    Raises:
        NotImplementedError: If `info` is neither a Pixtral-HF nor a LLaVA
            processing-info instance.
    """
    if isinstance(info, PixtralHFProcessingInfo):
        return PixtralHFMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
        )

    if isinstance(info, LlavaProcessingInfo):
        return LlavaMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
        )

    raise NotImplementedError(type(info))
436
437
438
439
440


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.
441

442
443
444
445
446
447
448
    Args:
        hf_config: Model config with vision feature layer(s).
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
449
        return get_layer_index(feature_layers, num_hidden_layers)
450
451
    # If we have multiple feature layers, initialize up to the deepest one
    elif isinstance(feature_layers, (list, tuple)):
452
        return max(get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
453
454
455
    raise TypeError(
        f"vision_layer_feature type: {type(feature_layers)} is not supported"
    )
456
457


458
459
def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: QuantizationConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
) -> CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel:
    """Build the vision tower matching `hf_config.vision_config`.

    The tower is truncated to the deepest layer needed by
    `vision_feature_layer` (see `_get_num_hidden_layers`).

    Args:
        hf_config: LLaVA-like config holding the vision config.
        quant_config: Optional quantization config for the tower.
        require_post_norm: Forwarded unchanged to the tower constructor.
        prefix: Parameter-name prefix for the tower's weights.

    Raises:
        NotImplementedError: If the vision config is not CLIP, SigLIP or
            Pixtral.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, PixtralVisionConfig):
        return PixtralHFVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


499
500
501
502
503
@MULTIMODAL_REGISTRY.register_processor(
    _build_llava_or_pixtral_hf_processor,
    info=_build_llava_or_pixtral_hf_info,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class LlavaForConditionalGeneration(
    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsEagle3
):
    """LLaVA-style model: vision tower + MLP projector + language model.

    Supports CLIP/SigLIP vision towers as well as the HF-format Pixtral
    vision tower (Pixtral-12B).
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for `modality` (images only).

        Raises:
            ValueError: For any modality not starting with "image".
        """
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Set which language-model layers expose auxiliary hidden states."""
        self.get_language_model().model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return default EAGLE-3 aux layers: an early, middle and late one."""
        num_layers = len(self.get_language_model().model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (
            config.text_config.architectures is None
            and config.text_config.model_type == "mistral"
        ):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (
            config.projector_hidden_act is None
            and config.vision_config.hidden_act == "gelu"
        ):
            config.projector_hidden_act = "gelu"

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config=quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = LlavaMultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=config.text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> LlavaImageInputs | None:
        """Pop image kwargs and wrap them in the matching input schema.

        `pixel_values` take precedence over `image_embeds`. Returns `None`
        when neither is present.

        Raises:
            ValueError: If `image_embeds` are given for a Pixtral-HF model.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if self.config.vision_config.model_type == "pixtral":
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=pixel_values,
                )

            # Non-Pixtral towers take square images of the configured size.
            expected_h = expected_w = self.config.vision_config.image_size
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                resolve_bindings={"h": expected_h, "w": expected_w},
            )

        if image_embeds is not None:
            if self.config.vision_config.model_type == "pixtral":
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel,
        pixel_values: torch.Tensor | list[torch.Tensor],
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Run the vision tower with the configured feature-select strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        return vision_tower(
            pixel_values,
            feature_select_strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(
        self,
        inputs: LlavaImagePixelInputs | PixtralHFImagePixelInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Encode the pixel values of `inputs` with this model's tower."""
        pixel_values = inputs["pixel_values"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn parsed image inputs into language-model embeddings.

        Precomputed embeddings are returned unchanged; pixel inputs go
        through the vision tower and the multimodal projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        # Per-image features (possibly of different lengths): project them in
        # one concatenated call, then split back along dim 0 per image.
        feature_sizes = [image_feature.shape[0] for image_feature in image_features]

        image_embeds = self.multi_modal_projector(torch.cat(image_features))
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings for `kwargs`, or `[]` if there are none."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and the attention metadata are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
                When given, `inputs_embeds` is ignored.
            inputs_embeds: Optional tensor of input embeddings.
            **kwargs: Accepted for interface compatibility; unused here.

        Info:
            [`LlavaImageInputs`][vllm.model_executor.models.llava.LlavaImageInputs]
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping post-v4.52 HF names first."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        # LLaVA's vision encoder outputs one token per patch without
        # spatial merging or pixel shuffle
        return num_image_tokens

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        # LLaVA's MLP projector outputs the same number of tokens
        # as it receives from the vision encoder (1:1 mapping)
        return num_vision_tokens

755

756
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis checkpoints (LLaVA-based)."""

    def get_hf_processor(self, **kwargs: object):
        """Return a `LlavaProcessor`, defaulting two kwargs from the config.

        `patch_size` defaults to the vision encoder's patch size and
        `vision_feature_select_strategy` to the HF config's value; explicitly
        passed kwargs take precedence.
        """
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        kwargs.setdefault("patch_size", vision_info.get_patch_size())
        kwargs.setdefault(
            "vision_feature_select_strategy",
            hf_config.vision_feature_select_strategy,
        )

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
768
769


770
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor that additionally wraps images in Mantis markers."""

    def apply(
        self,
        prompt: str | list[int],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """Process the prompt as LLaVA, then add Mantis image delimiters.

        After the standard LLaVA expansion, each run of image tokens is
        replaced by ``(image k: <Image><image>...</Image>)``. The placeholder
        ranges are then recomputed on the rewritten prompt.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]
        mm_hashes = result["mm_hashes"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            return "".join(
                [
                    f"(image {item_idx + 1}: <Image>",  # 7 tokens
                    "<image>" * num_image_tokens,
                    "</Image>)",  # 3 tokens
                ]
            )

        mantis_mm_repls = self._bind_and_group_updates(
            [
                PromptReplacement(
                    modality="image",
                    target=[image_token_id] * num_image_tokens,
                    replacement=get_replacement_mantis,
                )
            ],
            mm_item_counts,
        )

        prompt_ids, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
        )

        # Locate the image-token runs again in the rewritten prompt so the
        # placeholder ranges point at the new positions.
        orig_repls = self._get_mm_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls)
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
847
848
849
850


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """Mantis model: the LLaVA model with Mantis-specific prompt processing.

    The architecture is inherited unchanged; only the registered processor
    and processing info differ.
    """