llava.py 28.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping, Sequence
6
from typing import Annotated, Final, Literal, Protocol, TypeAlias, TypeVar
7
8

import torch
9
import torch.nn as nn
10
11
12
13
14
15
16
17
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    LlavaConfig,
    PixtralVisionConfig,
    PretrainedConfig,
    SiglipVisionConfig,
)
18
from transformers.models.llava import LlavaProcessor
19
from transformers.models.pixtral import PixtralProcessor
20

21
from vllm.config import VllmConfig
22
from vllm.config.multimodal import BaseDummyOptions
23
from vllm.inputs import MultiModalDataDict, MultiModalInput, mm_input
24
from vllm.model_executor.layers.activation import get_act_fn
25
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
26
from vllm.model_executor.layers.quantization import QuantizationConfig
27
from vllm.multimodal import MULTIMODAL_REGISTRY
28
from vllm.multimodal.cache import BaseMultiModalProcessorCache
29
30
31
32
33
34
35
36
37
38
39
from vllm.multimodal.inputs import (
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
40
    BaseDummyInputsBuilder,
41
42
43
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
44
    ProcessorInputs,
45
46
47
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
48
    TimingContext,
49
)
50
from vllm.sequence import IntermediateTensors
51
from vllm.utils.tensor_schema import TensorSchema, TensorShape
52

53
from .clip import CLIPVisionModel
54
55
from .interfaces import (
    MultiModalEmbeddings,
56
    SupportsEagle,
57
    SupportsEagle3,
58
59
60
61
62
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .module_mapping import MultiModelKeys
63
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
64
from .siglip import SiglipVisionModel
65
66
67
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
68
    get_layer_index,
69
70
71
    init_vllm_registered_model,
    maybe_prefix,
)
72
from .vision import get_num_selected_vision_tokens, get_vision_encoder_info
73
74


75
class LlavaImagePixelInputs(TensorSchema):
    """
    Schema for raw pixel inputs fed to a CLIP/SigLIP-style vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width

    `pixel_values` is annotated as a single batched tensor; unlike
    `PixtralHFImagePixelInputs`, no dimension is declared dynamic here.
    """

    # Discriminator used to tell the image-input schemas apart.
    type: Literal["pixel_values"] = "pixel_values"

    # Channel dimension is fixed to 3 by the shape annotation.
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
89

90

91
class PixtralHFImagePixelInputs(TensorSchema):
    """
    Schema for raw pixel inputs fed to the Pixtral-HF vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels
        - h: Height
        - w: Width

    `h` and `w` are declared dynamic, i.e. they may differ per image.
    When they do, the data is passed as a list of tensors instead of a
    single batched tensor.
    """

    # Discriminator used to tell the image-input schemas apart.
    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"

    # Either one batched tensor or one tensor per image.
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"}),
    ]
108

109

110
class LlavaImageEmbeddingInputs(TensorSchema):
    """
    Schema for precomputed image embeddings that bypass the vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    # Discriminator used to tell the image-input schemas apart.
    type: Literal["image_embeds"] = "image_embeds"

    # Embedding tensor with one (ifs, hs) slab per image.
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
120
121


122
123
124
# Union of the three image-input schemas defined above.
LlavaImageInputs: TypeAlias = (
    LlavaImagePixelInputs
    | PixtralHFImagePixelInputs
    | LlavaImageEmbeddingInputs
)
"""Alias for supported LLaVA image input types."""
126
127


128
class LlavaMultiModalProjector(nn.Module):
    """Two-layer projector: linear -> activation -> linear.

    The first layer maps `vision_hidden_size` to `text_hidden_size`; the
    second keeps `text_hidden_size`.
    """

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build both linear layers and look up the activation by name.

        Args:
            vision_hidden_size: Input feature size of the first layer.
            text_hidden_size: Output size of both layers.
            projector_hidden_act: Activation name passed to `get_act_fn`.
            multimodal_projector_bias: Whether both layers use a bias.
            quant_config: Quantization config forwarded to both layers.
            prefix: Name prefix under which the layers are registered.
        """
        super().__init__()

        # Settings that are identical for the two linear layers.
        common = dict(bias=multimodal_projector_bias, quant_config=quant_config)

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_1",
            **common,
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_2",
            **common,
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project `image_features` through both layers.

        Each linear layer returns a pair; only the first element is used.
        """
        projected, _ = self.linear_1(image_features)
        activated = self.act(projected)
        output, _ = self.linear_2(activated)
        return output


163
164
class LlavaLikeConfig(Protocol):
    """Structural type for the config attributes this module reads."""

    # Config of the vision encoder.
    vision_config: Final[PretrainedConfig]
    # Token id that marks image positions in the prompt.
    image_token_index: Final[int]
    # Strategy name forwarded to the vision tower / token counting helpers.
    vision_feature_select_strategy: Final[str]
    # One layer index, or several, to take vision features from.
    vision_feature_layer: Final[int | list[int]]
168

169

170
171
172
173
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors that expose an image token string."""

    # Placeholder string for an image in the text prompt.
    image_token: Final[str]


174
175
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by the LLaVA-style processors in this module.

    Subclasses supply the concrete HF processor via `get_hf_processor`.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        """Return the HF config, requested from the context as `LlavaConfig`."""
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        """Return the vision encoder info derived from the HF config."""
        config = self.get_hf_config()
        return get_vision_encoder_info(config)

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        """Return the HF processor; must be implemented by subclasses."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Only the "image" modality is listed, with a limit of `None`."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for an image of the given size.

        The encoder's token count is passed, together with the configured
        feature select strategy, to `get_num_selected_vision_tokens`.
        """
        config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        encoder_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        strategy = config.vision_feature_select_strategy
        return get_num_selected_vision_tokens(encoder_tokens, strategy)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square `ImageSize` whose side is the encoder image size."""
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        """Return the token count for `get_image_size_with_most_features()`."""
        max_width, max_height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=max_width,
            image_height=max_height,
        )

218
219
220
221
222

# Type variable for processing-info classes derived from
# BaseLlavaProcessingInfo; it parameterizes the generic dummy-input builder
# and multimodal processor classes in this module.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy text and dummy image data for LLaVA-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's image token repeated once per image.

        A missing "image" entry in `mm_counts` counts as zero images.
        """
        image_count = mm_counts.get("image", 0)
        token = self.info.get_hf_processor().image_token
        return token * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy images sized by `get_image_size_with_most_features`.

        `seq_len` is not used here. The "image" entry of `mm_options`, if
        any, is forwarded as `overrides` to `_get_dummy_images`.
        """
        image_count = mm_counts.get("image", 0)
        max_width, max_height = self.info.get_image_size_with_most_features()
        overrides = mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=max_width,
            height=max_height,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


253
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info backed by the HF `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return the `LlavaProcessor`, filling in `patch_size` if unset."""
        processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
        if processor.patch_size is not None:
            return processor

        # In case patch_size is omitted from `processor_config.json`
        # e.g. for E5-V: https://huggingface.co/royokong/e5-v
        processor.patch_size = self.get_vision_encoder_info().get_patch_size()
        return processor
262
263


264
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Expands each image token into one token per image feature."""

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe the multimodal fields; must be implemented by subclasses."""
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return one replacement that targets the single image token id.

        `hf_processor_mm_kwargs` and `out_mm_kwargs` are not used here.
        """
        image_token_id = self.info.get_hf_config().image_token_index

        def expand_image_token(item_idx: int):
            # Items may be precomputed embeddings or images to be processed.
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if not isinstance(images, ImageEmbeddingItems):
                size = images.get_image_size(item_idx)
                token_count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )
            else:
                token_count = images.get_feature_size(item_idx)

            return [image_token_id] * token_count

        update = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=expand_image_token,
        )
        return [update]


308
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for LLaVA checkpoints using `LlavaProcessor`."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare both supported fields as batched over the image modality.

        Neither argument is consulted; the mapping is fixed.
        """
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }


320
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info backed by the HF `PixtralProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return the `PixtralProcessor` obtained from the context."""
        processor = self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
        return processor
323

324

325
class PixtralHFMultiModalProcessor(BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for Pixtral checkpoints in the HF format."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then crop each image to its own size."""
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        padded_pixels = outputs.get("pixel_values")
        if padded_pixels is None:
            return outputs

        # Avoid padding since we need the output for each image to be
        # independent of other images for the cache to work correctly
        image_sizes = outputs["image_sizes"]
        assert len(padded_pixels) == len(image_sizes)

        cropped = []
        for pixels, (height, width) in zip(padded_pixels, image_sizes):
            cropped.append(pixels[:, :height, :width])
        outputs["pixel_values"] = cropped

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare both fields as batched over the image modality.

        Neither argument is consulted; the mapping is fixed.
        """
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return one replacement that lays image tokens out as a patch grid.

        Every grid row is followed by the image-break token, except the
        last row, which is followed by the image-end token instead.
        `out_mm_kwargs` is not used here.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def build_grid_tokens(item_idx: int):
            size = mm_items.get_items("image", ImageProcessorItems).get_image_size(
                item_idx
            )
            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            row = [image_token_id] * ncols + [image_break_id]
            tokens = row * nrows
            # The final break token becomes the end-of-image token.
            tokens[-1] = image_end_id

            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        update = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=build_grid_tokens,
        )
        return [update]

403

404
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Pick the processing-info class from the type of the vision config.

    A `PixtralVisionConfig` selects `PixtralHFProcessingInfo`; any other
    vision config selects `LlavaProcessingInfo`.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    info_cls = (
        PixtralHFProcessingInfo
        if isinstance(vision_config, PixtralVisionConfig)
        else LlavaProcessingInfo
    )
    return info_cls(ctx)


415
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor:
    """Build the processor that matches the type of `info`.

    Raises:
        NotImplementedError: If `info` is neither a
            `PixtralHFProcessingInfo` nor a `LlavaProcessingInfo`.
    """
    # Checked in order: Pixtral-HF first, then plain LLaVA.
    processor_by_info = (
        (PixtralHFProcessingInfo, PixtralHFMultiModalProcessor),
        (LlavaProcessingInfo, LlavaMultiModalProcessor),
    )

    for info_cls, processor_cls in processor_by_info:
        if isinstance(info, info_cls):
            return processor_cls(
                info,
                dummy_inputs,  # type: ignore
                cache=cache,
            )

    raise NotImplementedError(type(info))
436
437
438
439
440


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The result of `get_layer_index` for a single feature layer, or the
        largest such result when several feature layers are configured.

    Raises:
        ValueError: If `vision_feature_layer` is an empty list or tuple.
        TypeError: If `vision_feature_layer` is neither an int nor a
            list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers

    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return get_layer_index(feature_layers, num_hidden_layers)

    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # `max()` would raise an opaque ValueError on an empty iterable;
            # raise the same exception type with an actionable message.
            raise ValueError(
                "vision_feature_layer must contain at least one layer index"
            )
        return max(get_layer_index(idx, num_hidden_layers) for idx in feature_layers)

    raise TypeError(
        f"vision_feature_layer type: {type(feature_layers)} is not supported"
    )
456
457


458
459
def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: QuantizationConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
) -> CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel:
    """Instantiate the vision tower that matches `hf_config.vision_config`.

    Args:
        hf_config: Model config holding the vision config and feature layers.
        quant_config: Quantization config forwarded to the tower.
        require_post_norm: Forwarded unchanged to the tower constructor.
        prefix: Name prefix forwarded to the tower constructor.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral vision config.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    # Checked in order; the first matching config type wins.
    tower_by_config = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )

    for config_cls, tower_cls in tower_by_config:
        if isinstance(vision_config, config_cls):
            return tower_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    raise NotImplementedError(f"Unsupported vision config: {type(vision_config)}")


499
500
501
502
503
@MULTIMODAL_REGISTRY.register_processor(
    _build_llava_or_pixtral_hf_processor,
    info=_build_llava_or_pixtral_hf_info,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class LlavaForConditionalGeneration(
    nn.Module,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
    SupportsEagle,
    SupportsEagle3,
):
    """LLaVA model: vision tower, multimodal projector and language model."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image item.

        Raises:
            ValueError: If `modality` does not start with "image".
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")

        return "<image>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, the projector and the language model."""
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        text_config = config.text_config
        self.configure_mm_token_handling(
            vocab_size=text_config.vocab_size,
            mm_token_ids=[config.image_token_index],
        )

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        missing_arch = text_config.architectures is None
        if missing_arch and text_config.model_type == "mistral":
            text_config.architectures = ["MistralForCausalLM"]

        missing_act = config.projector_hidden_act is None
        if missing_act and config.vision_config.hidden_act == "gelu":
            config.projector_hidden_act = "gelu"

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config=quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = LlavaMultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> LlavaImageInputs | None:
        """Wrap `pixel_values` / `image_embeds` kwargs in an input schema.

        Returns `None` when neither key is present. `pixel_values` takes
        precedence over `image_embeds` when both are given.

        Raises:
            ValueError: If `image_embeds` is given for a Pixtral vision tower.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        vision_config = self.config.vision_config
        is_pixtral = vision_config.model_type == "pixtral"

        if pixel_values is not None:
            if is_pixtral:
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=pixel_values,
                )

            # Height and width are both bound to the configured image size.
            side = vision_config.image_size
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                resolve_bindings={"h": side, "w": side},
            )

        if image_embeds is not None:
            if is_pixtral:
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel,
        pixel_values: torch.Tensor | list[torch.Tensor],
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Run `vision_tower` with the configured feature select strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        strategy = self.config.vision_feature_select_strategy
        return vision_tower(pixel_values, feature_select_strategy=strategy)

    def _process_image_pixels(
        self,
        inputs: LlavaImagePixelInputs | PixtralHFImagePixelInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Extract `pixel_values` and encode them with `self.vision_tower`."""
        return self._image_pixels_to_features(
            self.vision_tower, inputs["pixel_values"]
        )

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn a validated image input into image embeddings.

        Precomputed embeddings are returned as-is. Pixel inputs go through
        the vision tower and then the projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        features = self._process_image_pixels(image_input)

        if isinstance(features, torch.Tensor):
            return self.multi_modal_projector(features)

        # Several feature tensors: project them in one concatenated call,
        # then split the result back into per-tensor chunks.
        sizes = [feature.shape[0] for feature in features]
        projected = self.multi_modal_projector(torch.cat(features))
        return torch.split(projected, sizes)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings, or an empty list without image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        return [] if image_input is None else self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.

        Info:
            [`LlavaImageInputs`][vllm.model_executor.models.llava.LlavaImageInputs]
        """
        # Input embeddings are dropped whenever intermediate tensors are given.
        embeds = None if intermediate_tensors is not None else inputs_embeds

        return self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logits computation to the language model."""
        logits = self.language_model.compute_logits(hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, renaming checkpoint prefixes via `hf_to_vllm_mapper`."""
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        """Return `num_image_tokens` unchanged."""
        # LLaVA's vision encoder outputs one token per patch without
        # spatial merging or pixel shuffle
        return num_image_tokens

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        """Return `num_vision_tokens` unchanged."""
        # LLaVA's MLP projector outputs the same number of tokens
        # as it receives from the vision encoder (1:1 mapping)
        return num_vision_tokens

758

759
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis, built on the HF `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return a `LlavaProcessor` with Mantis defaults filled in.

        `patch_size` and `vision_feature_select_strategy` are only set when
        the caller did not pass them.
        """
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        defaults = {
            "patch_size": vision_info.get_patch_size(),
            "vision_feature_select_strategy": (
                hf_config.vision_feature_select_strategy
            ),
        }
        for key, value in defaults.items():
            kwargs.setdefault(key, value)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
771
772


773
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor that rewrites image placeholders in the Mantis format."""

    def apply(
        self,
        inputs: ProcessorInputs,
        timing_ctx: TimingContext,
    ) -> MultiModalInput:
        """Run the parent processor, then wrap each image token run.

        Each run of image tokens in the parent's prompt is replaced by
        `(image N: <Image>` + image tokens + `</Image>)`, after which the
        placeholder positions are located again in the rewritten prompt.
        """
        image_token_id = self.info.get_hf_config().image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        base_result = super().apply(inputs, timing_ctx)

        item_counts = inputs.mm_data_items.get_all_counts()
        mm_kwargs = base_result["mm_kwargs"]
        mm_hashes = base_result["mm_hashes"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            opening = f"(image {item_idx + 1}: <Image>"  # 7 tokens
            closing = "</Image>)"  # 3 tokens
            return opening + "<image>" * num_image_tokens + closing

        mantis_update = PromptReplacement(
            modality="image",
            target=[image_token_id] * num_image_tokens,
            replacement=get_replacement_mantis,
        )
        mantis_updates = self._bind_and_group_updates([mantis_update], item_counts)

        prompt_ids, _ = self._apply_prompt_updates(
            base_result["prompt_token_ids"],
            mantis_updates,
        )

        # Locate the image placeholders in the rewritten prompt using the
        # parent's prompt updates.
        original_updates = self._get_mm_prompt_updates(
            inputs.mm_data_items,
            inputs.hf_processor_mm_kwargs,
            mm_kwargs,
        )
        placeholders = self._find_mm_placeholders(prompt_ids, original_updates)
        self._validate_mm_placeholders(placeholders, item_counts)

        placeholder_ranges = {}
        for modality, found in placeholders.items():
            placeholder_ranges[modality] = [item.to_range() for item in found]

        return mm_input(
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=placeholder_ranges,
        )
840
841
842
843


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
844
845
846
847
848
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model registered with the Mantis processor and processing info."""