# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Final, Literal, Protocol, TypeAlias, TypeVar

import torch
import torch.nn as nn
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    LlavaConfig,
    PixtralVisionConfig,
    PretrainedConfig,
    SiglipVisionConfig,
)
from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    BaseDummyInputsBuilder,
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .clip import CLIPVisionModel
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .module_mapping import MultiModelKeys
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    get_layer_index,
    init_vllm_registered_model,
    maybe_prefix,
)
from .vision import get_num_selected_vision_tokens, get_vision_encoder_info
71
72


73
class LlavaImagePixelInputs(TensorSchema):
    """
    Pixel inputs for the non-Pixtral (CLIP / SigLIP) vision towers.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width

    The data is a single batched tensor. When validating the input, the model
    binds `h` and `w` to the vision encoder's configured `image_size`.
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
87

88

89
class PixtralHFImagePixelInputs(TensorSchema):
    """
    Pixel inputs for the Pixtral-HF vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels
        - h: Height
        - w: Width

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
    # `h` and `w` are dynamic so each image can keep its own resolution.
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"}),
    ]
106

107

108
class LlavaImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings passed directly to the language model.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
118
119


120
121
122
# Union of every validated image-input schema the LLaVA model accepts; the
# `type` field of each schema tells them apart at runtime.
LlavaImageInputs: TypeAlias = (
    LlavaImagePixelInputs
    | PixtralHFImagePixelInputs
    | LlavaImageEmbeddingInputs
)
123
124


125
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision features into the text hidden size.

    `linear_1` is a `ColumnParallelLinear` and `linear_2` a
    `RowParallelLinear`, with the configured activation applied in between.

    Args:
        vision_hidden_size: Hidden size of the vision encoder output.
        text_hidden_size: Hidden size of the language model.
        projector_hidden_act: Activation name resolved via `get_act_fn`.
        multimodal_projector_bias: Whether both linear layers use a bias.
        quant_config: Optional quantization config for both linear layers.
        prefix: Module-name prefix used to name the sub-layers.
    """

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project `image_features` to the language model's hidden size."""
        # The parallel linear layers return a tuple; only the first element
        # (the output) is used here.
        hidden_states, _ = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


160
161
class LlavaLikeConfig(Protocol):
    """Structural type for the HF config fields LLaVA-style code reads."""

    vision_config: Final[PretrainedConfig]
    image_token_index: Final[int]
    vision_feature_select_strategy: Final[str]
    vision_feature_layer: Final[int | list[int]]
165

166

167
168
169
170
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors that expose an image placeholder.

    Only `image_token` is required: it is the string repeated once per image
    when dummy text prompts are built.
    """

    image_token: Final[str]


171
172
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Shared processing info for LLaVA-style models.

    Subclasses supply `get_hf_processor`. Token counts and image sizes are
    derived from the HF config and its vision encoder info.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        return get_vision_encoder_info(self.get_hf_config())

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # No explicit per-prompt image limit is declared (presumably
        # "unbounded" — confirm against BaseProcessingInfo).
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of placeholder tokens for one image.

        Takes the vision encoder's token count for the given size and adjusts
        it for `vision_feature_select_strategy` via
        `get_num_selected_vision_tokens`.
        """
        hf_config = self.get_hf_config()
        vision_encoder_info = self.get_vision_encoder_info()

        return get_num_selected_vision_tokens(
            vision_encoder_info.get_num_image_tokens(
                image_width=image_width,
                image_height=image_height,
            ),
            hf_config.vision_feature_select_strategy,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the square image size reported by the vision encoder."""
        vision_encoder_info = self.get_vision_encoder_info()

        width = height = vision_encoder_info.get_image_size()
        return ImageSize(width=width, height=height)

    def get_max_image_tokens(self) -> int:
        """Return the token count for an image of the largest feature size."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
        )

215
216
217
218
219

# Type variable for the processing-info class consumed by the generic
# LLaVA dummy-input builder and multimodal processors.
_I = TypeVar(
    "_I",
    bound=BaseLlavaProcessingInfo,
)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy text and image inputs for LLaVA-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images sized to produce the most image features.

        Optional per-modality overrides from `mm_options["image"]` are
        forwarded to `_get_dummy_images`.
        """
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


250
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for standard LLaVA checkpoints (HF `LlavaProcessor`)."""

    def get_hf_processor(self, **kwargs: object):
        hf_processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
        # In case patch_size is omitted from `processor_config.json`
        # e.g. for E5-V: https://huggingface.co/royokong/e5-v
        if hf_processor.patch_size is None:
            patch_size = self.get_vision_encoder_info().get_patch_size()
            hf_processor.patch_size = patch_size
        return hf_processor
259
260


261
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Shared multimodal processor logic for LLaVA-style models.

    Each image token in the prompt is replaced by as many image tokens as the
    image contributes. For raw images that is the vision encoder's count; for
    precomputed embeddings it is the embedding's feature size.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            return [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


305
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for standard (non-Pixtral) LLaVA checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )


317
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for LLaVA checkpoints with a Pixtral vision tower."""

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
320

321

322
class PixtralHFMultiModalProcessor(BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for LLaVA checkpoints with a Pixtral-HF tower."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is not None:
            # Avoid padding since we need the output for each image to be
            # independent of other images for the cache to work correctly
            image_sizes = processed_outputs["image_sizes"]
            assert len(pixel_values) == len(image_sizes)

            # Crop each image back to its own (h, w), dropping batch padding.
            processed_outputs["pixel_values"] = [
                p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
            ]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )

            # Each row of patches is followed by a break token; the final
            # break is replaced by the end-of-image token.
            tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
            tokens[-1] = image_end_id

            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]

400

401
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Pick the processing-info class matching the checkpoint's vision tower.

    Returns `PixtralHFProcessingInfo` for a `PixtralVisionConfig` vision
    config, otherwise `LlavaProcessingInfo`.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)

    if isinstance(hf_config.vision_config, PixtralVisionConfig):
        return PixtralHFProcessingInfo(ctx)

    return LlavaProcessingInfo(ctx)


412
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor:
    """Build the multimodal processor that matches the given processing info.

    Raises:
        NotImplementedError: If `info` is neither a `PixtralHFProcessingInfo`
            nor a `LlavaProcessingInfo`.
    """
    if isinstance(info, PixtralHFProcessingInfo):
        return PixtralHFMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
        )

    if isinstance(info, LlavaProcessingInfo):
        return LlavaMultiModalProcessor(
            info,
            dummy_inputs,  # type: ignore
            cache=cache,
        )

    raise NotImplementedError(type(info))
433
434
435
436
437


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The layer index (resolved by `get_layer_index` against
        `vision_config.num_hidden_layers`) of the deepest requested
        feature layer.

    Raises:
        ValueError: If `vision_feature_layer` is an empty list/tuple.
        TypeError: If `vision_feature_layer` is neither an int nor a
            list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    elif isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # `max()` would otherwise fail with an opaque message.
            raise ValueError("vision_feature_layer must not be empty")
        return max(get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(
        f"vision_feature_layer type: {type(feature_layers)} is not supported"
    )
453
454


455
456
def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: QuantizationConfig | None,
    multimodal_config: MultiModalConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
) -> CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel:
    """Build the vision tower named by `hf_config.vision_config`.

    The tower is truncated to the deepest feature layer the model reads
    (see `_get_num_hidden_layers`).

    Raises:
        NotImplementedError: If the vision config is not CLIP, SigLIP or
            Pixtral.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )
    elif isinstance(vision_config, PixtralVisionConfig):
        return PixtralHFVisionModel(
            vision_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


500
501
502
503
504
@MULTIMODAL_REGISTRY.register_processor(
    _build_llava_or_pixtral_hf_processor,
    info=_build_llava_or_pixtral_hf_info,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class LlavaForConditionalGeneration(
    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
):
    """LLaVA-style model: vision tower + MLP projector + language model.

    The vision tower and projector are only built when the per-prompt image
    limit is non-zero. Otherwise both are `None`, and their weights are
    skipped in `load_weights`.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (
            config.text_config.architectures is None
            and config.text_config.model_type == "mistral"
        ):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (
            config.projector_hidden_act is None
            and config.vision_config.hidden_act == "gelu"
        ):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        if multimodal_config.get_limit_per_prompt("image"):
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config=quant_config,
                multimodal_config=multimodal_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = LlavaMultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=config.text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )
        else:
            self.vision_tower = None
            self.multi_modal_projector = None

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> LlavaImageInputs | None:
        """Pop image inputs from `kwargs` and wrap them in a typed schema.

        Returns `None` when neither `pixel_values` nor `image_embeds` is given.
        If both are given, `pixel_values` wins.

        Raises:
            ValueError: If `image_embeds` is given for a Pixtral vision tower.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if self.config.vision_config.model_type == "pixtral":
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=pixel_values,
                )

            # Non-Pixtral towers take fixed-size square images.
            expected_h = expected_w = self.config.vision_config.image_size
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                resolve_bindings={"h": expected_h, "w": expected_w},
            )

        if image_embeds is not None:
            if self.config.vision_config.model_type == "pixtral":
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel,
        pixel_values: torch.Tensor | list[torch.Tensor],
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        return vision_tower(
            pixel_values,
            feature_select_strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(
        self,
        inputs: LlavaImagePixelInputs | PixtralHFImagePixelInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn validated image inputs into language-model embeddings.

        Precomputed embeddings are returned unchanged. Pixel inputs are encoded
        by the vision tower and then passed through the projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        # Both modules are built together in __init__; assert both so the
        # projector is narrowed from `None` as well.
        assert self.vision_tower is not None
        assert self.multi_modal_projector is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        # The tower returned one tensor per image: project them in one batched
        # call, then split the result back per image.
        feature_sizes = [image_feature.shape[0] for image_feature in image_features]

        image_embeds = self.multi_modal_projector(torch.cat(image_features))
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor expands
        the image token into additional image tokens (denoted as `32000`),
        resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
            inputs_embeds: Optional tensor of input embeddings.

        Info:
            [`LlavaImageInputs`][vllm.model_executor.models.llava.LlavaImageInputs]
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # Without a vision tower/projector (image limit of 0), skip their
        # checkpoint weights instead of failing to find target modules.
        skip_prefixes = []
        if self.vision_tower is None and self.multi_modal_projector is None:
            skip_prefixes.extend(["vision_tower.", "multi_modal_projector."])

        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        # LLaVA's vision encoder outputs one token per patch without
        # spatial merging or pixel shuffle
        return num_image_tokens

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        # LLaVA's MLP projector outputs the same number of tokens
        # as it receives from the vision encoder (1:1 mapping)
        return num_vision_tokens

763

764
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis checkpoints (reuses HF `LlavaProcessor`)."""

    def get_hf_processor(self, **kwargs: object):
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        # Fill in processor settings from the model config unless the caller
        # already supplied them.
        kwargs.setdefault("patch_size", vision_info.get_patch_size())
        kwargs.setdefault(
            "vision_feature_select_strategy",
            hf_config.vision_feature_select_strategy,
        )

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
776
777


778
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor that adds Mantis' per-image prompt wrapper."""

    def apply(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """Run standard LLaVA processing, then wrap each image's tokens.

        Each run of `num_image_tokens` image tokens produced by the base
        processor is rewritten to
        `(image {k}: <Image>` + image tokens + `</Image>)`. Placeholder ranges
        are then recomputed on the rewritten prompt.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]
        mm_hashes = result["mm_hashes"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            return "".join(
                [
                    f"(image {item_idx + 1}: <Image>",  # 7 tokens
                    "<image>" * num_image_tokens,
                    "</Image>)",  # 3 tokens
                ]
            )

        mantis_mm_repls = self._bind_and_group_updates(
            [
                PromptReplacement(
                    modality="image",
                    target=[image_token_id] * num_image_tokens,
                    replacement=get_replacement_mantis,
                )
            ],
            mm_item_counts,
        )

        prompt_ids, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
        )

        # Re-locate the original image placeholders inside the new prompt.
        orig_repls = self._get_mm_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls)
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
856
857
858
859


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
860
861
862
863
864
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """Mantis model: the LLaVA model registered with the Mantis processor."""