"vscode:/vscode.git/clone" did not exist on "1d7c29f5fecab930fbb28bf59f1bc4510abe335b"
llava.py 28.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Final, Literal, Protocol, TypeAlias, TypeVar

import torch
import torch.nn as nn
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    LlavaConfig,
    PixtralVisionConfig,
    PretrainedConfig,
    SiglipVisionConfig,
)
from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    mm_inputs,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
    MultiModalUUIDItems,
)
from vllm.multimodal.processing import (
    BaseDummyInputsBuilder,
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .clip import CLIPVisionModel
from .interfaces import (
    MultiModalEmbeddings,
    SupportsEagle3,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .module_mapping import MultiModelKeys
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    get_layer_index,
    init_vllm_registered_model,
    maybe_prefix,
)
from .vision import get_num_selected_vision_tokens, get_vision_encoder_info
73
74


75
class LlavaImagePixelInputs(TensorSchema):
    """
    Pixel values for the CLIP/SigLIP-style LLaVA vision towers.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width

    Unlike the Pixtral variant, every image has the same resolution: when
    `LlavaForConditionalGeneration` validates the input it binds `h` and `w`
    to `vision_config.image_size`, so the data is always one batched tensor.
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
89

90

91
class PixtralHFImagePixelInputs(TensorSchema):
    """
    Pixel values for Pixtral checkpoints in the HF format.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels
        - h: Height
        - w: Width

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        # `h` and `w` are dynamic because Pixtral keeps each image's own size
        TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"}),
    ]
108

109

110
class LlavaImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings, passed to the language model as-is
    (they skip both the vision tower and the projector).

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
120
121


122
123
124
LlavaImageInputs: TypeAlias = (
    LlavaImagePixelInputs | PixtralHFImagePixelInputs | LlavaImageEmbeddingInputs
)
"""Alias for supported LLaVA image input types."""
126
127


128
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision features to the text hidden size.

    `linear_1` is column-parallel and `linear_2` is row-parallel, with the
    configured activation applied in between.
    """

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            vision_hidden_size: Feature size produced by the vision tower.
            text_hidden_size: Hidden size of the language model; used as both
                the intermediate and the output width of the MLP.
            projector_hidden_act: Activation name, resolved with `get_act_fn`.
            multimodal_projector_bias: Whether both linear layers use a bias.
            quant_config: Optional quantization config for both layers.
            prefix: Parameter-name prefix for the two linear submodules.
        """
        super().__init__()

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project `image_features` into the language model's hidden space."""
        # The parallel linear layers return a pair; only the first element
        # (the layer output) is used here.
        hidden_states, _ = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


163
164
class LlavaLikeConfig(Protocol):
    """Structural type for the HF config fields read by LLaVA-style models."""

    # Config of the vision encoder (CLIP, SigLIP or Pixtral in this file).
    vision_config: Final[PretrainedConfig]
    # Token id that marks where image embeddings are placed in the prompt.
    image_token_index: Final[int]
    # Passed to `get_num_selected_vision_tokens` and to the vision tower.
    vision_feature_select_strategy: Final[str]
    # A single layer index, or several (the deepest one sets the tower depth).
    vision_feature_layer: Final[int | list[int]]
168

169

170
171
172
173
class LlavaLikeProcessor(Protocol):
    """Structural type for the HF processors used by LLaVA-style models.

    The only attribute this module needs is the image placeholder text.
    """

    # String form of the image placeholder token (repeated per dummy image).
    image_token: Final[str]


174
175
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by LLaVA-style models.

    Subclasses supply the concrete HF processor via `get_hf_processor`;
    token counting is delegated to the vision encoder info.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        """Return the model's HF config, loaded as a `LlavaConfig`."""
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        """Return the encoder-specific helper for the configured vision tower."""
        return get_vision_encoder_info(self.get_hf_config())

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        """Return the HF processor; must be implemented by subclasses."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Only images are supported; `None` means no fixed count is set here."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Number of image tokens fed to the LM for an image of the given size.

        The encoder's raw token count is adjusted by the configured
        `vision_feature_select_strategy`.
        """
        hf_config = self.get_hf_config()
        vision_encoder_info = self.get_vision_encoder_info()

        return get_num_selected_vision_tokens(
            vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
            hf_config.vision_feature_select_strategy,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the square image size reported by the vision encoder."""
        vision_encoder_info = self.get_vision_encoder_info()
        width = height = vision_encoder_info.get_image_size()
        return ImageSize(width=width, height=height)

    def get_max_image_tokens(self) -> int:
        """Image-token count for the size returned by
        `get_image_size_with_most_features`."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )

218
219
220
221
222

# Generic parameter for the processing-info type; bounded so that the dummy
# input builder and processors can rely on the LLaVA processing-info API.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy text and images for LLaVA-style models.

    The text holds one image placeholder per image. The images use the size
    that yields the most image features.
    """

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images sized for the maximum number of features.

        `seq_len` and `mm_processor_kwargs` are not used by this
        implementation.
        """
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()

        # Optional image-specific overrides supplied by the caller
        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


254
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for standard LLaVA checkpoints (`LlavaProcessor`)."""

    def get_hf_processor(self, **kwargs: object):
        """Return the HF `LlavaProcessor`, filling in a missing `patch_size`."""
        hf_processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
        # In case patch_size is omitted from `processor_config.json`
        # e.g. for E5-V: https://huggingface.co/royokong/e5-v
        if hf_processor.patch_size is None:
            patch_size = self.get_vision_encoder_info().get_patch_size()
            hf_processor.patch_size = patch_size
        return hf_processor
263
264


265
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Shared prompt-update logic for LLaVA-style multimodal processors.

    Each image token in the prompt is expanded into one image token per
    feature that the corresponding image contributes.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with the image's full run of image tokens."""
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            # Precomputed embeddings carry their own feature count; raw
            # images derive it from their size.
            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


309
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for standard LLaVA checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Both pixel values and image embeddings are batched per image."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )


321
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for Pixtral checkpoints in the HF format."""

    def get_hf_processor(self, **kwargs: object):
        """Return the HF `PixtralProcessor`."""
        return self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
324

325

326
class PixtralHFMultiModalProcessor(BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for Pixtral checkpoints in the HF format.

    Pixel values are kept per image at the image's own size. Each image is
    expanded into a grid of image tokens with a break token after every row,
    and the final break token is replaced by an end token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then crop each image to its reported size."""
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is not None:
            # Avoid padding since we need the output for each image to be
            # independent of other images for the cache to work correctly
            image_sizes = processed_outputs["image_sizes"]
            assert len(pixel_values) == len(image_sizes)

            processed_outputs["pixel_values"] = [
                p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
            ]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Both pixel values and image embeddings are batched per image."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with the image's row/break/end token grid."""
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )

            # One row of `ncols` image tokens followed by a break token, for
            # every row; the trailing break becomes the image-end token.
            tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
            tokens[-1] = image_end_id

            # Only the image tokens (not break/end) receive image embeddings
            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]

404

405
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Choose Pixtral-HF or LLaVA processing info from the vision config type."""
    hf_config = ctx.get_hf_config(LlavaConfig)

    if isinstance(hf_config.vision_config, PixtralVisionConfig):
        return PixtralHFProcessingInfo(ctx)

    return LlavaProcessingInfo(ctx)


416
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor:
    """Instantiate the multimodal processor matching the type of `info`.

    Raises:
        NotImplementedError: If `info` is neither Pixtral-HF nor LLaVA
            processing info.
    """
    if isinstance(info, PixtralHFProcessingInfo):
        return PixtralHFMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
        )

    if isinstance(info, LlavaProcessingInfo):
        return LlavaMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
        )

    raise NotImplementedError(type(info))
437
438
439
440
441


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The layer count needed to reach the (deepest) feature layer, as
        computed by `get_layer_index`.

    Raises:
        ValueError: If `vision_feature_layer` is an empty list/tuple.
        TypeError: If `vision_feature_layer` is neither an int nor a
            list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    elif isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # `max()` would otherwise fail with an opaque "empty iterable" error
            raise ValueError("vision_feature_layer must not be empty")
        return max(get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(
        f"vision_feature_layer type: {type(feature_layers)} is not supported"
    )
457
458


459
460
def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: QuantizationConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
) -> CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel:
    """Build the vision tower that matches `hf_config.vision_config`.

    The tower is truncated to the deepest layer named in
    `hf_config.vision_feature_layer`.

    Args:
        hf_config: LLaVA-style config with the vision config and feature
            layer selection.
        quant_config: Optional quantization config for the tower.
        require_post_norm: Forwarded unchanged to the tower constructor.
        prefix: Parameter-name prefix for the tower.

    Raises:
        NotImplementedError: If the vision config is not CLIP, SigLIP or
            Pixtral.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, PixtralVisionConfig):
        return PixtralHFVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


500
501
502
503
504
@MULTIMODAL_REGISTRY.register_processor(
    _build_llava_or_pixtral_hf_processor,
    info=_build_llava_or_pixtral_hf_info,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class LlavaForConditionalGeneration(
    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsEagle3
):
    """LLaVA-style vision-language model.

    A vision tower (CLIP, SigLIP or Pixtral) encodes images, a
    `LlavaMultiModalProjector` maps the features to the text hidden size,
    and a registered vLLM language model consumes the resulting embeddings.
    """

    # Fused module name -> the separate checkpoint modules packed into it
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image modality.

        Raises:
            ValueError: If `modality` does not start with "image".
        """
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Set which language-model layers expose auxiliary hidden states."""
        self.get_language_model().model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Default Eagle3 aux layers: an early, a middle and a late layer."""
        num_layers = len(self.get_language_model().model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (
            config.text_config.architectures is None
            and config.text_config.model_type == "mistral"
        ):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (
            config.projector_hidden_act is None
            and config.vision_config.hidden_act == "gelu"
        ):
            config.projector_hidden_act = "gelu"

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config=quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = LlavaMultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=config.text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> LlavaImageInputs | None:
        """Pop image inputs from `kwargs` and wrap them in the matching schema.

        `pixel_values` takes precedence over `image_embeds` when both are
        given. Returns `None` when neither is present.

        Raises:
            ValueError: If `image_embeds` are given for a Pixtral model.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if self.config.vision_config.model_type == "pixtral":
                # Pixtral keeps per-image sizes, so no h/w binding is applied
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=pixel_values,
                )

            expected_h = expected_w = self.config.vision_config.image_size
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                resolve_bindings={"h": expected_h, "w": expected_w},
            )

        if image_embeds is not None:
            if self.config.vision_config.model_type == "pixtral":
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel,
        pixel_values: torch.Tensor | list[torch.Tensor],
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Encode pixel values with `vision_tower` using the configured
        feature selection strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        return vision_tower(
            pixel_values,
            feature_select_strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(
        self,
        inputs: LlavaImagePixelInputs | PixtralHFImagePixelInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Run the model's vision tower on validated pixel inputs."""
        pixel_values = inputs["pixel_values"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn validated image inputs into language-model embeddings.

        Precomputed `image_embeds` are returned unchanged. Pixel inputs go
        through the vision tower and the projector. When the tower yields one
        tensor per image, the tensors are projected as a single concatenated
        batch and then split back into per-image tensors.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        # Concatenate along dim 0 so the projector runs once, then restore
        # the per-image split from each feature tensor's length.
        feature_sizes = [image_feature.shape[0] for image_feature in image_features]

        image_embeds = self.multi_modal_projector(torch.cat(image_features))
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings for the given inputs, or `[]` if none."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor expands
        the image token (denoted as `32000`) into additional copies, resulting
        in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and attention metadata are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.

        Info:
            [`LlavaImageInputs`][vllm.model_executor.models.llava.LlavaImageInputs]
        """
        # When intermediate tensors are supplied (presumably a non-first
        # pipeline stage), any inputs_embeds are ignored.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping HF names via `hf_to_vllm_mapper`."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        # LLaVA's vision encoder outputs one token per patch without
        # spatial merging or pixel shuffle
        return num_image_tokens

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        # LLaVA's MLP projector outputs the same number of tokens
        # as it receives from the vision encoder (1:1 mapping)
        return num_vision_tokens

756

757
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis, built on the HF `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return a `LlavaProcessor` with Mantis defaults.

        `patch_size` and `vision_feature_select_strategy` default to the
        vision encoder's and the HF config's values; explicit kwargs win.
        """
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        kwargs.setdefault("patch_size", vision_info.get_patch_size())
        kwargs.setdefault(
            "vision_feature_select_strategy",
            hf_config.vision_feature_select_strategy,
        )

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
769
770


771
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor that additionally wraps each image in Mantis markers."""

    def apply(
        self,
        prompt: str | list[int],
        mm_items: MultiModalDataItems,
        mm_uuid_items: MultiModalUUIDItems | None = None,
        hf_processor_mm_kwargs: Mapping[str, object] | None = None,
        tokenization_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalInputs:
        """Process like LLaVA, then wrap each image's tokens Mantis-style.

        Each expanded run of image tokens is rewritten as
        `(image {n}: <Image>` + image tokens + `</Image>)`. Placeholders are
        then located again in the rewritten prompt and validated.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(
            prompt,
            mm_items,
            mm_uuid_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]
        mm_hashes = result["mm_hashes"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            return "".join(
                [
                    f"(image {item_idx + 1}: <Image>",  # 7 tokens
                    "<image>" * num_image_tokens,
                    "</Image>)",  # 3 tokens
                ]
            )

        mantis_mm_repls = self._bind_and_group_updates(
            [
                PromptReplacement(
                    modality="image",
                    target=[image_token_id] * num_image_tokens,
                    replacement=get_replacement_mantis,
                )
            ],
            mm_item_counts,
        )

        prompt_ids, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
        )

        # Re-locate the original image placeholders inside the rewritten prompt
        orig_repls = self._get_mm_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls)
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return mm_inputs(
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
847
848
849
850


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """Mantis model: the LLaVA model with Mantis-specific prompt processing."""