llava.py 28.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping, Sequence
6
from typing import Annotated, Final, Literal, Protocol, TypeAlias, TypeVar
7
8

import torch
9
import torch.nn as nn
10
11
12
13
14
15
16
17
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    LlavaConfig,
    PixtralVisionConfig,
    PretrainedConfig,
    SiglipVisionConfig,
)
18
from transformers.models.llava import LlavaProcessor
19
from transformers.models.pixtral import PixtralProcessor
20

21
from vllm.config import VllmConfig
22
from vllm.config.multimodal import BaseDummyOptions
23
from vllm.model_executor.layers.activation import get_act_fn
24
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
25
from vllm.model_executor.layers.quantization import QuantizationConfig
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.multimodal.cache import BaseMultiModalProcessorCache
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
42
    BaseDummyInputsBuilder,
43
44
45
46
47
48
49
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
50
from vllm.sequence import IntermediateTensors
51
from vllm.utils.tensor_schema import TensorSchema, TensorShape
52

53
from .clip import CLIPVisionModel
54
55
from .interfaces import (
    MultiModalEmbeddings,
56
    SupportsEagle3,
57
58
59
60
61
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .module_mapping import MultiModelKeys
62
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
63
from .siglip import SiglipVisionModel
64
65
66
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
67
    get_layer_index,
68
69
70
    init_vllm_registered_model,
    maybe_prefix,
)
71
from .vision import get_num_selected_vision_tokens, get_vision_encoder_info
72
73


74
class LlavaImagePixelInputs(TensorSchema):
    """
    Pixel-value inputs for LLaVA models with a CLIP or SigLIP vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width

    Unlike `PixtralHFImagePixelInputs`, `h` and `w` are not declared as
    dynamic dimensions here, so every image in the batch is expected to share
    the same spatial size.
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
88

89

90
class PixtralHFImagePixelInputs(TensorSchema):
    """
    Pixel-value inputs for the HF-format Pixtral vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels
        - h: Height
        - w: Width

    `h` and `w` are dynamic: images may differ in size, in which case
    `pixel_values` is a list of per-image tensors instead of one batched
    tensor.
    """

    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"}),
    ]
107

108

109
class LlavaImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings that skip the vision tower and projector.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
119
120


121
122
123
# Union of every image-input schema accepted by LLaVA-family models; the
# `type` field of each member discriminates between them at runtime.
LlavaImageInputs: TypeAlias = (
    LlavaImagePixelInputs
    | PixtralHFImagePixelInputs
    | LlavaImageEmbeddingInputs
)
"""Alias for supported LLaVA image input types."""
125
126


127
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision-tower features to the text hidden size.

    `linear_1` (column-parallel) maps vision_hidden_size -> text_hidden_size,
    the configured activation is applied, then `linear_2` (row-parallel) maps
    text_hidden_size -> text_hidden_size.
    """

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        # Submodules are registered in the order linear_1, act, linear_2.
        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_1",
            bias=multimodal_projector_bias,
            quant_config=quant_config,
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_2",
            bias=multimodal_projector_bias,
            quant_config=quant_config,
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # Both parallel linear layers return a 2-tuple; only the first
        # element (the layer output) is used here.
        projected, _unused = self.linear_1(image_features)
        activated = self.act(projected)
        output, _unused = self.linear_2(activated)
        return output


162
163
class LlavaLikeConfig(Protocol):
    """Structural type for HF configs shaped like `LlavaConfig`.

    Declares the attributes that helpers typed against this protocol read.
    """

    vision_config: Final[PretrainedConfig]
    image_token_index: Final[int]
    vision_feature_select_strategy: Final[str]
    vision_feature_layer: Final[int | list[int]]
167

168

169
170
171
172
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors that expose an image placeholder."""

    image_token: Final[str]


173
174
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Shared processing-info logic for LLaVA-style models.

    Subclasses must provide the concrete HF processor via `get_hf_processor`.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Only images are supported; no numeric limit is declared (None).
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return how many placeholder tokens one image of this size uses.

        The vision encoder's token count is adjusted according to the
        config's `vision_feature_select_strategy`.
        """
        hf_config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        raw_token_count = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        strategy = hf_config.vision_feature_select_strategy
        return get_num_selected_vision_tokens(raw_token_count, strategy)

    def get_image_size_with_most_features(self) -> ImageSize:
        # Width and height are both set to the encoder's configured size.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=width, image_height=height)

217
218
219
220
221

# Type variable for processing-info classes derived from BaseLlavaProcessingInfo;
# it parameterizes the generic dummy-input builder and processors below.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompt text and images for LLaVA-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # One processor image token per requested image.
        image_token = self.info.get_hf_processor().image_token
        return image_token * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        # Dummy images use the size reported to produce the most features.
        width, height = self.info.get_image_size_with_most_features()
        overrides = None if not mm_options else mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=overrides,
        )
        return {"image": dummy_images}


252
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for standard LLaVA checkpoints (`LlavaProcessor`)."""

    def get_hf_processor(self, **kwargs: object):
        processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)

        # Some checkpoints omit `patch_size` from `processor_config.json`
        # (e.g. E5-V: https://huggingface.co/royokong/e5-v); fall back to the
        # vision encoder's patch size in that case.
        if processor.patch_size is None:
            processor.patch_size = self.get_vision_encoder_info().get_patch_size()

        return processor
261
262


263
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Common multimodal processor for LLaVA-style models.

    Every image placeholder token in the prompt is expanded into as many
    image tokens as the corresponding image occupies.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            # Raw images derive their token count from their pixel size;
            # precomputed embeddings report their own feature size.
            if not isinstance(images, ImageEmbeddingItems):
                size = images.get_image_size(item_idx)
                token_count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )
            else:
                token_count = images.get_feature_size(item_idx)

            return [image_token_id] * token_count

        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [image_replacement]


307
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for standard LLaVA checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Raw pixels and precomputed embeddings are both batched per image.
        return {
            field_name: MultiModalFieldConfig.batched("image")
            for field_name in ("pixel_values", "image_embeds")
        }


319
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for Pixtral checkpoints in the HF (LLaVA) format."""

    def get_hf_processor(self, **kwargs: object):
        processor = self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
        return processor
322

323

324
class PixtralHFMultiModalProcessor(BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for HF-format Pixtral checkpoints.

    Each image expands into rows of image tokens. Every row ends with an
    image-break token, except the final row, which ends with an image-end
    token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        pixel_values = outputs.get("pixel_values")
        if pixel_values is None:
            return outputs

        # Crop each image back to its own size so its tensor does not depend
        # on the other images in the batch; the cache relies on this.
        image_sizes = outputs["image_sizes"]
        assert len(pixel_values) == len(image_sizes)

        outputs["pixel_values"] = [
            pixels[:, :height, :width]
            for pixels, (height, width) in zip(pixel_values, image_sizes)
        ]
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Raw pixels and precomputed embeddings are both batched per image.
        return {
            field_name: MultiModalFieldConfig.batched("image")
            for field_name in ("pixel_values", "image_embeds")
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            items = mm_items.get_items("image", ImageProcessorItems)
            size = items.get_image_size(item_idx)
            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            # `ncols` image tokens plus one break token per row; the very
            # last break token is replaced by the end-of-image token.
            row = [image_token_id] * ncols + [image_break_id]
            tokens = row * nrows
            tokens[-1] = image_end_id

            # Select only the `image_token_id` positions of the replacement
            # (presumably the slots that receive image embeddings).
            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [image_replacement]

402

403
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Choose the processing-info class matching the checkpoint's vision tower.

    HF-format Pixtral checkpoints (Pixtral vision config) get
    `PixtralHFProcessingInfo`; everything else gets `LlavaProcessingInfo`.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    info_cls = (
        PixtralHFProcessingInfo
        if isinstance(vision_config, PixtralVisionConfig)
        else LlavaProcessingInfo
    )
    return info_cls(ctx)


414
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor:
    """Construct the multimodal processor that matches the type of `info`.

    Raises:
        NotImplementedError: If `info` is neither Pixtral-HF nor LLaVA
            processing info.
    """
    # Checked in order; the first matching info type wins.
    dispatch = (
        (PixtralHFProcessingInfo, PixtralHFMultiModalProcessor),
        (LlavaProcessingInfo, LlavaMultiModalProcessor),
    )
    for info_cls, processor_cls in dispatch:
        if isinstance(info, info_cls):
            return processor_cls(
                info,
                dummy_inputs,  # type: ignore
                cache=cache,
            )

    raise NotImplementedError(type(info))
435
436
437
438
439


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The value of `get_layer_index` for the single configured feature
        layer, or the maximum of those values when several layers are given.

    Raises:
        ValueError: If `vision_feature_layer` is an empty list/tuple.
        TypeError: If `vision_feature_layer` is neither an int nor a
            list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers

    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return get_layer_index(feature_layers, num_hidden_layers)

    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # max() would otherwise fail with an uninformative message.
            raise ValueError("vision_feature_layer must not be empty")
        return max(get_layer_index(idx, num_hidden_layers) for idx in feature_layers)

    raise TypeError(
        f"vision_feature_layer type: {type(feature_layers)} is not supported"
    )
455
456


457
458
def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: QuantizationConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
) -> CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel:
    """Build the vision tower that matches `hf_config.vision_config`.

    The tower is created with only as many layers as the deepest configured
    feature layer requires.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral vision config.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    tower_for_config = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )
    for config_cls, tower_cls in tower_for_config:
        if isinstance(vision_config, config_cls):
            return tower_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


498
499
500
501
502
@MULTIMODAL_REGISTRY.register_processor(
    _build_llava_or_pixtral_hf_processor,
    info=_build_llava_or_pixtral_hf_info,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class LlavaForConditionalGeneration(
    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsEagle3
):
    """LLaVA-style vision-language model: vision tower, MLP projector and LM.

    Also serves HF-format Pixtral checkpoints, whose `vision_config` is a
    `PixtralVisionConfig`.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return "<image>"

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        backbone = self.get_language_model().model
        backbone.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        # An early layer, the middle layer and a layer near the end.
        depth = len(self.get_language_model().model.layers)
        return (2, depth // 2, depth - 3)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        text_config = config.text_config
        if (
            text_config.architectures is None
            and text_config.model_type == "mistral"
        ):
            text_config.architectures = ["MistralForCausalLM"]
        if (
            config.projector_hidden_act is None
            and config.vision_config.hidden_act == "gelu"
        ):
            config.projector_hidden_act = "gelu"

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config=quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = LlavaMultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> LlavaImageInputs | None:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        is_pixtral = self.config.vision_config.model_type == "pixtral"

        if pixel_values is not None:
            if is_pixtral:
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=pixel_values,
                )

            # Non-Pixtral towers bind h and w to the configured image size.
            side = self.config.vision_config.image_size
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                resolve_bindings={"h": side, "w": side},
            )

        if image_embeds is not None:
            if is_pixtral:
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel,
        pixel_values: torch.Tensor | list[torch.Tensor],
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        # Feature-layer selection happens inside the vision tower itself, so
        # only the feature-selection strategy is forwarded here.
        strategy = self.config.vision_feature_select_strategy
        return vision_tower(pixel_values, feature_select_strategy=strategy)

    def _process_image_pixels(
        self,
        inputs: LlavaImagePixelInputs | PixtralHFImagePixelInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        return self._image_pixels_to_features(
            self.vision_tower, inputs["pixel_values"]
        )

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        # Precomputed embeddings bypass both the vision tower and projector.
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        features = self._process_image_pixels(image_input)
        if isinstance(features, torch.Tensor):
            return self.multi_modal_projector(features)

        # Per-image features of differing lengths: project them together in
        # a single call, then split the result back per image.
        lengths = [feature.shape[0] for feature in features]
        projected = self.multi_modal_projector(torch.cat(features))
        return torch.split(projected, lengths)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        return [] if image_input is None else self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for LLaVA-1.5.

        `input_ids` already contain a placeholder token at every position that
        will hold an image embedding.

        For example, the text prompt
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`
        is tokenized as
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in the KV cache, the input processor expands the
        single image token (`32000`) before the prompt reaches the model:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        575 tokens are inserted so that, together with the original image
        token, there are 576 (24 * 24) image tokens in total. This matches
        the number of image tokens the visual encoder outputs for the
        language model.

        As a result, `positions` and `attn_metadata` stay consistent with
        `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.

        Info:
            [`LlavaImageInputs`][vllm.model_executor.models.llava.LlavaImageInputs]
        """
        # When intermediate tensors are supplied (presumably on a later
        # pipeline stage), any input embeddings are discarded.
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        # Delegates to the wrapped language model.
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # Checkpoint names are rewritten through `hf_to_vllm_mapper`.
        weights_loader = AutoWeightsLoader(self)
        return weights_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """Return the module prefixes of the language model, connector and
        vision tower."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        # The vision encoder emits one token per patch; there is no spatial
        # merging or pixel shuffle.
        return num_image_tokens

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        # The MLP projector maps tokens 1:1, so the count is unchanged.
        return num_vision_tokens

754

755
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis, which reuses `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        hf_config = self.get_hf_config()
        patch_size = self.get_vision_encoder_info().get_patch_size()

        # Fill in processor settings from the model config unless the caller
        # already provided them.
        fallbacks = {
            "patch_size": patch_size,
            "vision_feature_select_strategy": (
                hf_config.vision_feature_select_strategy
            ),
        }
        for key, value in fallbacks.items():
            kwargs.setdefault(key, value)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
767
768


769
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor that wraps each expanded image in Mantis-style markers.

    After the standard LLaVA expansion, every run of image tokens is rewritten
    to `(image N: <Image>` + image tokens + `</Image>)`. This re-implements
    Mantis' MLlavaProcessor.
    """

    def apply(
        self,
        prompt: str | list[int],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        image_token_id = self.info.get_hf_config().image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]
        mm_hashes = result["mm_hashes"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            header = f"(image {item_idx + 1}: <Image>"  # 7 tokens
            footer = "</Image>)"  # 3 tokens
            return header + "<image>" * num_image_tokens + footer

        mantis_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id] * num_image_tokens,
            replacement=get_replacement_mantis,
        )
        mantis_mm_repls = self._bind_and_group_updates(
            [mantis_replacement],
            mm_item_counts,
        )

        prompt_ids, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
        )

        # Re-locate the original image placeholders inside the rewritten
        # prompt and check that one exists per image item.
        orig_repls = self._get_mm_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls)
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [placeholder.to_range() for placeholder in found]
            for modality, found in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
846
847
848
849


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """Mantis model: the LLaVA architecture with Mantis prompt processing.

    All model code is inherited from `LlavaForConditionalGeneration`; only
    the registered multimodal processor and processing info differ.
    """