"vscode:/vscode.git/clone" did not exist on "afe9eb408ee1191cd57a68d46b6ce2860b1b41e1"
llava.py 28.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping, Sequence
6
from typing import Annotated, Final, Literal, Protocol, TypeAlias, TypeVar
7
8

import torch
9
import torch.nn as nn
10
11
12
13
14
15
16
17
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    LlavaConfig,
    PixtralVisionConfig,
    PretrainedConfig,
    SiglipVisionConfig,
)
18
from transformers.models.llava import LlavaProcessor
19
from transformers.models.pixtral import PixtralProcessor
20

21
from vllm.config import VllmConfig
22
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
23
from vllm.model_executor.layers.activation import get_act_fn
24
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
25
from vllm.model_executor.layers.quantization import QuantizationConfig
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.multimodal.cache import BaseMultiModalProcessorCache
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
42
    BaseDummyInputsBuilder,
43
44
45
46
47
48
49
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
50
from vllm.sequence import IntermediateTensors
51
from vllm.utils.tensor_schema import TensorSchema, TensorShape
52

53
from .clip import CLIPVisionModel
54
55
from .interfaces import (
    MultiModalEmbeddings,
56
    SupportsEagle3,
57
58
59
60
61
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .module_mapping import MultiModelKeys
62
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
63
from .siglip import SiglipVisionModel
64
65
66
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
67
    get_layer_index,
68
69
70
    init_vllm_registered_model,
    maybe_prefix,
)
71
from .vision import get_num_selected_vision_tokens, get_vision_encoder_info
72
73


74
class LlavaImagePixelInputs(TensorSchema):
    """
    Raw pixel inputs for LLaVA vision towers that take one fixed resolution.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width

    `pixel_values` is always a single batched tensor here. When the model
    validates these inputs it binds `h` and `w` to the vision config's
    `image_size`.
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
88

89

90
class PixtralHFImagePixelInputs(TensorSchema):
    """
    Raw pixel inputs for the Pixtral-HF vision encoder.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels
        - h: Height
        - w: Width

    `h` and `w` are dynamic: images may differ in size, in which case
    `pixel_values` is a list of per-image tensors rather than one batched
    tensor.
    """

    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"}),
    ]
107

108

109
class LlavaImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings, used as-is without running the vision
    tower or the projector.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
119
120


121
122
123
# Every accepted image input format; the variants are told apart by their
# `type` field.
LlavaImageInputs: TypeAlias = (
    LlavaImagePixelInputs
    | PixtralHFImagePixelInputs
    | LlavaImageEmbeddingInputs
)
124
125


126
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision features to the text hidden size.

    The first layer is column-parallel and the second row-parallel, with the
    configured activation between them.
    """

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        # Submodules are registered in the same order as the checkpoint
        # layout: linear_1, act, linear_2.
        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # Both parallel linear layers return a 2-tuple; only the first
        # element (the layer output) is used here.
        projected, _ = self.linear_1(image_features)
        activated = self.act(projected)
        output, _ = self.linear_2(activated)
        return output


161
162
class LlavaLikeConfig(Protocol):
    """Structural type for the HF config attributes LLaVA-family code reads."""

    # Config of the vision encoder (CLIP, SigLIP or Pixtral in this module).
    vision_config: Final[PretrainedConfig]
    # Token id of the image placeholder in the prompt.
    image_token_index: Final[int]
    # Strategy forwarded to the vision tower when selecting features.
    vision_feature_select_strategy: Final[str]
    # Encoder layer index (or several indices) to take features from.
    vision_feature_layer: Final[int | list[int]]
166

167

168
169
170
171
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors that expose an image placeholder."""

    # Text placeholder for a single image (repeated to build dummy prompts).
    image_token: Final[str]


172
173
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing metadata shared by LLaVA-family models.

    Subclasses supply the concrete HF processor through `get_hf_processor`.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # `None` presumably means no fixed per-prompt cap on images.
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return how many placeholder tokens an image of this size needs.

        The encoder's raw token count is passed through
        `get_num_selected_vision_tokens` together with the config's
        `vision_feature_select_strategy`.
        """
        select_strategy = self.get_hf_config().vision_feature_select_strategy
        encoder_info = self.get_vision_encoder_info()

        encoder_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return get_num_selected_vision_tokens(encoder_tokens, select_strategy)

    def get_image_size_with_most_features(self) -> ImageSize:
        # The encoder takes a single square input size.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=width, image_height=height)

216
217
218
219
220

# Processing-info type parameter for builders and processors that are
# generic over the concrete LLaVA processing info.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompts and images for LLaVA-family models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # One image placeholder per requested dummy image.
        image_count = mm_counts.get("image", 0)
        placeholder = self.info.get_hf_processor().image_token
        return placeholder * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        image_count = mm_counts.get("image", 0)
        # Use the largest-feature image size so dummy inputs cover the max.
        width, height = self.info.get_image_size_with_most_features()

        overrides = None
        if mm_options:
            overrides = mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


251
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for standard LLaVA checkpoints using `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)

        # Some checkpoints omit patch_size from `processor_config.json`
        # (e.g. E5-V: https://huggingface.co/royokong/e5-v); fall back to
        # the vision encoder's patch size in that case.
        if processor.patch_size is None:
            encoder_patch_size = self.get_vision_encoder_info().get_patch_size()
            processor.patch_size = encoder_patch_size

        return processor
260
261


262
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor base that expands each image token in place."""

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            # Precomputed embeddings already know their own length.
            if isinstance(images, ImageEmbeddingItems):
                return [image_token_id] * images.get_feature_size(item_idx)

            size = images.get_image_size(item_idx)
            token_count = self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
            )
            return [image_token_id] * token_count

        # Each single image token is replaced by a run of image tokens.
        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


306
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for standard LLaVA checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Both pixel inputs and precomputed embeddings are batched per image.
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }


318
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for HF-format Pixtral checkpoints."""

    def get_hf_processor(self, **kwargs: object):
        processor_cls = PixtralProcessor
        return self.ctx.get_hf_processor(processor_cls, **kwargs)
321

322

323
class PixtralHFMultiModalProcessor(BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for HF-format Pixtral checkpoints.

    Each image becomes a grid of image tokens. Every row ends with an
    image-break token, except the last row, which ends with an image-end
    token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        pixel_values = outputs.get("pixel_values")
        if pixel_values is None:
            return outputs

        # Avoid padding since we need the output for each image to be
        # independent of other images for the cache to work correctly:
        # crop every image back to its own (height, width).
        image_sizes = outputs["image_sizes"]
        assert len(pixel_values) == len(image_sizes)

        outputs["pixel_values"] = [
            image[:, :height, :width]
            for image, (height, width) in zip(pixel_values, image_sizes)
        ]
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            size = images.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            # Rows of image tokens separated by break tokens; the final break
            # token is swapped for the end-of-image token.
            row = [image_token_id] * ncols + [image_break_id]
            tokens = row * nrows
            tokens[-1] = image_end_id

            # Only the image-token positions receive embeddings.
            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]

401

402
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Choose the processing-info class based on the vision config type.

    Pixtral vision configs get `PixtralHFProcessingInfo`; everything else
    gets `LlavaProcessingInfo`.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    is_pixtral = isinstance(vision_config, PixtralVisionConfig)

    info_cls = PixtralHFProcessingInfo if is_pixtral else LlavaProcessingInfo
    return info_cls(ctx)


413
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor:
    """Create the multimodal processor that matches `info`'s concrete type.

    Raises:
        NotImplementedError: If `info` is neither Pixtral-HF nor LLaVA info.
    """
    if isinstance(info, PixtralHFProcessingInfo):
        processor_cls = PixtralHFMultiModalProcessor
    elif isinstance(info, LlavaProcessingInfo):
        processor_cls = LlavaMultiModalProcessor
    else:
        raise NotImplementedError(type(info))

    return processor_cls(info, dummy_inputs, cache=cache)  # type: ignore
434
435
436
437
438


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.
439

440
441
442
443
444
445
446
    Args:
        hf_config: Model config with vision feature layer(s).
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
447
        return get_layer_index(feature_layers, num_hidden_layers)
448
449
    # If we have multiple feature layers, initialize up to the deepest one
    elif isinstance(feature_layers, (list, tuple)):
450
        return max(get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
451
452
453
    raise TypeError(
        f"vision_layer_feature type: {type(feature_layers)} is not supported"
    )
454
455


456
457
def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: QuantizationConfig | None,
    multimodal_config: MultiModalConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
) -> CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel:
    """Build the vision tower class that matches `hf_config.vision_config`.

    Raises:
        NotImplementedError: If the vision config is not CLIP, SigLIP or
            Pixtral.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    tower_cls: (
        type[CLIPVisionModel] | type[SiglipVisionModel] | type[PixtralHFVisionModel]
    )
    if isinstance(vision_config, CLIPVisionConfig):
        tower_cls = CLIPVisionModel
    elif isinstance(vision_config, SiglipVisionConfig):
        tower_cls = SiglipVisionModel
    elif isinstance(vision_config, PixtralVisionConfig):
        tower_cls = PixtralHFVisionModel
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    # All three tower classes share this constructor signature.
    return tower_cls(
        vision_config,
        quant_config=quant_config,
        multimodal_config=multimodal_config,
        num_hidden_layers_override=num_hidden_layers,
        require_post_norm=require_post_norm,
        prefix=prefix,
    )


501
502
503
504
505
@MULTIMODAL_REGISTRY.register_processor(
    _build_llava_or_pixtral_hf_processor,
    info=_build_llava_or_pixtral_hf_info,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class LlavaForConditionalGeneration(
    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsEagle3
):
    """LLaVA-family model: vision tower, MLP projector, then language model.

    The vision tower may be CLIP, SigLIP or Pixtral-HF, depending on the
    checkpoint's `vision_config`. The projector maps the tower's features to
    the text hidden size. The registered text backbone then consumes the
    resulting embeddings.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # Checkpoints saved after transformers v4.52 use new module names; map
    # them back onto this module's attribute layout.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return "<image>"

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        text_backbone = self.get_language_model().model
        text_backbone.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        # An early layer, the middle layer, and the third-from-last layer.
        depth = len(self.get_language_model().model.layers)
        return (2, depth // 2, depth - 3)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format,
        # which fill in fields that may be left as None:
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        text_config = config.text_config
        if text_config.architectures is None and text_config.model_type == "mistral":
            text_config.architectures = ["MistralForCausalLM"]
        if (
            config.projector_hidden_act is None
            and config.vision_config.hidden_act == "gelu"
        ):
            config.projector_hidden_act = "gelu"

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config=quant_config,
                multimodal_config=multimodal_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = LlavaMultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> LlavaImageInputs | None:
        """Wrap raw image kwargs in the matching schema, or return None.

        Pixel values take precedence over image embeddings. Pixtral-HF
        checkpoints do not accept precomputed embeddings.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        is_pixtral = self.config.vision_config.model_type == "pixtral"

        if pixel_values is not None:
            if is_pixtral:
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=pixel_values,
                )

            # Non-Pixtral towers expect square inputs of the configured size.
            side = self.config.vision_config.image_size
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                resolve_bindings={"h": side, "w": side},
            )

        # Only image embeddings are left at this point.
        if is_pixtral:
            raise ValueError("Pixtral-HF does not support image_embeds.")

        return LlavaImageEmbeddingInputs(
            type="image_embeds",
            data=image_embeds,
        )

    def _image_pixels_to_features(
        self,
        vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel,
        pixel_values: torch.Tensor | list[torch.Tensor],
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        strategy = self.config.vision_feature_select_strategy
        return vision_tower(pixel_values, feature_select_strategy=strategy)

    def _process_image_pixels(
        self,
        inputs: LlavaImagePixelInputs | PixtralHFImagePixelInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        return self._image_pixels_to_features(
            self.vision_tower, inputs["pixel_values"]
        )

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn validated image inputs into language-model embeddings."""
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        features = self._process_image_pixels(image_input)
        if isinstance(features, torch.Tensor):
            return self.multi_modal_projector(features)

        # Per-image features of varying length: project them in one call,
        # then split the result back into per-image chunks.
        lengths = [feature.shape[0] for feature in features]
        projected = self.multi_modal_projector(torch.cat(features))
        return torch.split(projected, lengths)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for LLaVA-1.5.

        The `input_ids` passed in already hold one position for every image
        embedding that will be inserted.

        For example, the text prompt
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`
        tokenizes to:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in the KV cache, the input processor expands the
        image token (`32000`) into placeholder tokens before they reach the
        model, giving:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        Here 575 tokens are added. Together with the original image token,
        that makes 576 (24 * 24) image tokens, which is the number of image
        tokens the visual encoder outputs and the language model receives.

        As a result, `positions` and `attn_metadata` stay consistent with
        `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
                When provided, `inputs_embeds` is ignored.
            inputs_embeds: Optional tensor of input embeddings.

        Info:
            [`LlavaImageInputs`][vllm.model_executor.models.llava.LlavaImageInputs]
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # Checkpoint names are renamed through hf_to_vllm_mapper while loading.
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        # LLaVA's vision encoder outputs one token per patch without
        # spatial merging or pixel shuffle
        return num_image_tokens

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        # LLaVA's MLP projector outputs the same number of tokens
        # as it receives from the vision encoder (1:1 mapping)
        return num_vision_tokens

758

759
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis checkpoints.

    Fills in `patch_size` and `vision_feature_select_strategy` from the
    model configs unless the caller already supplied them.
    """

    def get_hf_processor(self, **kwargs: object):
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        fallbacks = {
            "patch_size": vision_info.get_patch_size(),
            "vision_feature_select_strategy": (
                hf_config.vision_feature_select_strategy
            ),
        }
        for key, value in fallbacks.items():
            kwargs.setdefault(key, value)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
771
772


773
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """Processor for Mantis checkpoints.

    Runs the regular LLaVA processing first. It then wraps each expanded
    run of image tokens in Mantis' `(image N: <Image>...</Image>)` text
    template, reimplementing MLlavaProcessor from
    https://github.com/TIGER-AI-Lab/Mantis.git.
    """

    def apply(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        image_token_id = self.info.get_hf_config().image_token_index

        # Assume that it doesn't depend on the image size, so placeholder
        # dimensions of -1 are passed.
        tokens_per_image = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        base = super().apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        mm_items = self._to_mm_items(mm_data)
        item_counts = mm_items.get_all_counts()
        mm_kwargs = base["mm_kwargs"]
        mm_hashes = base["mm_hashes"]

        def wrap_in_mantis_template(item_idx: int):
            header = f"(image {item_idx + 1}: <Image>"  # 7 tokens
            footer = "</Image>)"  # 3 tokens
            return header + "<image>" * tokens_per_image + footer

        mantis_updates = self._bind_and_group_updates(
            [
                PromptReplacement(
                    modality="image",
                    target=[image_token_id] * tokens_per_image,
                    replacement=wrap_in_mantis_template,
                )
            ],
            item_counts,
        )

        prompt_ids, _ = self._apply_prompt_updates(
            base["prompt_token_ids"],
            mantis_updates,
        )

        # Find the image placeholders again in the rewritten prompt, using
        # the regular LLaVA prompt updates.
        llava_updates = self._get_mm_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        placeholders = self._find_mm_placeholders(prompt_ids, llava_updates)
        self._validate_mm_placeholders(placeholders, item_counts)

        placeholder_ranges = {
            modality: [entry.to_range() for entry in entries]
            for modality, entries in placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=placeholder_ranges,
        )
851
852
853
854


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
855
856
857
858
859
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model registered with the Mantis multimodal processor.

    Select it with
    `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`.
    """