llava.py 28.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping, Sequence
6
from typing import Annotated, Final, Literal, Protocol, TypeAlias, TypeVar
7
8

import torch
9
import torch.nn as nn
10
11
12
13
14
15
16
17
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    LlavaConfig,
    PixtralVisionConfig,
    PretrainedConfig,
    SiglipVisionConfig,
)
18
from transformers.models.llava import LlavaProcessor
19
from transformers.models.pixtral import PixtralProcessor
20

21
from vllm.config import VllmConfig
22
from vllm.config.multimodal import BaseDummyOptions
23
from vllm.model_executor.layers.activation import get_act_fn
24
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
25
from vllm.model_executor.layers.quantization import QuantizationConfig
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.multimodal.cache import BaseMultiModalProcessorCache
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
42
    BaseDummyInputsBuilder,
43
44
45
46
47
48
49
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
50
from vllm.sequence import IntermediateTensors
51
from vllm.utils.tensor_schema import TensorSchema, TensorShape
52

53
from .clip import CLIPVisionModel
54
55
from .interfaces import (
    MultiModalEmbeddings,
56
    SupportsEagle3,
57
58
59
60
61
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .module_mapping import MultiModelKeys
62
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
63
from .siglip import SiglipVisionModel
64
65
66
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
67
    get_layer_index,
68
69
70
    init_vllm_registered_model,
    maybe_prefix,
)
71
from .vision import get_num_selected_vision_tokens, get_vision_encoder_info
72
73


74
class LlavaImagePixelInputs(TensorSchema):
    """
    Pixel-value inputs for LLaVA models whose vision tower is not Pixtral.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width

    `LlavaForConditionalGeneration._parse_and_validate_image_input` binds
    `h` and `w` to the vision config's `image_size` when it builds this
    schema.
    """

    # NOTE(review): an earlier version of this docstring said that images of
    # differing height/width arrive as a list; the annotation below only
    # admits a single tensor - confirm which one is intended.
    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
88

89

90
class PixtralHFImagePixelInputs(TensorSchema):
    """
    Pixel-value inputs for the Pixtral-HF vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels
        - h: Height
        - w: Width

    `h` and `w` are declared dynamic, so they may differ per batch and image;
    in that case the data is passed as a list of tensors instead of a single
    batched tensor.
    """

    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"}),
    ]
107

108

109
class LlavaImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings that bypass the vision tower and projector.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    # Returned unchanged by `_process_image_input` when `type` is
    # "image_embeds".
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
119
120


121
122
123
# Union of every image-input schema the model accepts; each member carries a
# distinct literal `type` tag ("pixel_values", "pixel_values_pixtral",
# "image_embeds").
LlavaImageInputs: TypeAlias = (
    LlavaImagePixelInputs | PixtralHFImagePixelInputs | LlavaImageEmbeddingInputs
)
124
125


126
class LlavaMultiModalProjector(nn.Module):
    """Project vision-tower features into the language model's hidden size.

    The projector is `linear_1` (a `ColumnParallelLinear`), the configured
    activation, then `linear_2` (a `RowParallelLinear`).
    """

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the two projections and the activation between them.

        Args:
            vision_hidden_size: Input width of `linear_1`.
            text_hidden_size: Output width of `linear_1`, and both the input
                and output width of `linear_2`.
            projector_hidden_act: Activation name passed to `get_act_fn`.
            multimodal_projector_bias: Whether both linear layers use a bias.
            quant_config: Quantization config forwarded to both layers.
            prefix: Module-name prefix; the layers are registered as
                `{prefix}.linear_1` and `{prefix}.linear_2`.
        """
        super().__init__()

        # Both projections share the bias and quantization settings.
        shared_kwargs = {
            "bias": multimodal_projector_bias,
            "quant_config": quant_config,
        }

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_1",
            **shared_kwargs,
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_2",
            **shared_kwargs,
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Apply `linear_1`, the activation and `linear_2` in sequence."""
        # Each linear layer returns a pair; only the first element is used.
        projected, _ = self.linear_1(image_features)
        activated = self.act(projected)
        output, _ = self.linear_2(activated)
        return output


161
162
class LlavaLikeConfig(Protocol):
    """Structural type for the HF config attributes this module reads."""

    # Vision encoder sub-config; its concrete class selects the vision tower.
    vision_config: Final[PretrainedConfig]
    # Token id of the image placeholder in the prompt.
    image_token_index: Final[int]
    # Forwarded to `get_num_selected_vision_tokens` and to the vision tower.
    vision_feature_select_strategy: Final[str]
    # One layer index, or several; see `_get_num_hidden_layers`.
    vision_feature_layer: Final[int | list[int]]
166

167

168
169
170
171
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors that expose an `image_token` string."""

    # Placeholder text for one image; `LlavaDummyInputsBuilder.get_dummy_text`
    # repeats it once per image.
    image_token: Final[str]


172
173
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by the LLaVA and Pixtral-HF input processors.

    Subclasses supply the HF processor through `get_hf_processor`.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        """Return the HF config, requested from the context as `LlavaConfig`."""
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        """Return the vision-encoder info object built from the HF config."""
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        """Return the HF processor; must be implemented by subclasses."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits: only "image", mapped to `None`."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for an image of the given size.

        The vision encoder's token count is passed through
        `get_num_selected_vision_tokens` together with the config's
        `vision_feature_select_strategy`.
        """
        hf_config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        encoder_token_count = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return get_num_selected_vision_tokens(
            encoder_token_count,
            hf_config.vision_feature_select_strategy,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square `ImageSize` using the encoder's image size."""
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        """Return the token count for `get_image_size_with_most_features`."""
        max_width, max_height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=max_width,
            image_height=max_height,
        )

216
217
218
219
220

# Processing-info type parameter shared by the dummy-input builder, the base
# multimodal processor and the processor factory in this module.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Build dummy prompt text and dummy images for LLaVA-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's image token repeated once per image.

        Args:
            mm_counts: Item count per modality; a missing "image" entry
                counts as zero.
        """
        image_count = mm_counts.get("image", 0)
        hf_processor = self.info.get_hf_processor()
        return hf_processor.image_token * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images sized by `get_image_size_with_most_features`.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Item count per modality; a missing "image" entry
                counts as zero.
            mm_options: Optional per-modality options; the "image" entry,
                if any, is forwarded as `overrides`.
        """
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()

        # A missing or empty options mapping means no overrides.
        if mm_options:
            image_overrides = mm_options.get("image")
        else:
            image_overrides = None

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=image_overrides,
        )
        return {"image": dummy_images}


251
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for LLaVA models that use the HF `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return the `LlavaProcessor`, filling in `patch_size` when unset."""
        processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)

        # In case patch_size is omitted from `processor_config.json`
        # e.g. for E5-V: https://huggingface.co/royokong/e5-v
        if processor.patch_size is None:
            encoder_info = self.get_vision_encoder_info()
            processor.patch_size = encoder_info.get_patch_size()

        return processor
260
261


262
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Shared prompt-update logic for LLaVA-style multimodal processors."""

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field config per output key; subclasses implement it."""
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with one image token per image feature.

        Returns a single `PromptReplacement` for the "image" modality whose
        target is the config's image token id.
        """
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                # Embedding inputs report their own feature size.
                token_count = images.get_feature_size(item_idx)
            else:
                size = images.get_image_size(item_idx)
                token_count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )

            return [image_token_id] * token_count

        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [image_replacement]


306
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for LLaVA models that use `LlavaProcessor`."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Mark `pixel_values` and `image_embeds` as batched image fields."""
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }


318
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for LLaVA checkpoints that use `PixtralProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return the HF `PixtralProcessor` obtained from the context."""
        processor = self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
        return processor
321

322

323
class PixtralHFMultiModalProcessor(BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for Pixtral in the HF (LLaVA-style) format."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then crop each image to its own size.

        When the output has `pixel_values`, each entry is sliced to the
        `(h, w)` given by the matching row of `image_sizes`.
        """
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        pixel_values = outputs.get("pixel_values")
        if pixel_values is None:
            return outputs

        # Avoid padding since we need the output for each image to be
        # independent of other images for the cache to work correctly
        image_sizes = outputs["image_sizes"]
        assert len(pixel_values) == len(image_sizes)

        cropped = []
        for pixels, (height, width) in zip(pixel_values, image_sizes):
            cropped.append(pixels[:, :height, :width])
        outputs["pixel_values"] = cropped

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Mark `pixel_values` and `image_embeds` as batched image fields."""
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with a row-structured token grid.

        Every grid row is `ncols` image tokens followed by a break token;
        the final break token is replaced by the image-end token.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            image_items = mm_items.get_items("image", ImageProcessorItems)
            size = image_items.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            row = [image_token_id] * ncols + [image_break_id]
            tokens = row * nrows
            # The last row ends the image instead of breaking to a new row.
            tokens[-1] = image_end_id

            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [image_replacement]

401

402
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Pick the processing-info class from the type of the vision config.

    Returns `PixtralHFProcessingInfo` when the HF config's `vision_config`
    is a `PixtralVisionConfig`, otherwise `LlavaProcessingInfo`.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    uses_pixtral = isinstance(vision_config, PixtralVisionConfig)

    info_cls = PixtralHFProcessingInfo if uses_pixtral else LlavaProcessingInfo
    return info_cls(ctx)


413
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor:
    """Build the multimodal processor that matches the type of `info`.

    Args:
        info: Processing info returned by `_build_llava_or_pixtral_hf_info`.
        dummy_inputs: Dummy-input builder passed through to the processor.
        cache: Optional processor cache passed through to the processor.

    Raises:
        NotImplementedError: If `info` is neither a `PixtralHFProcessingInfo`
            nor a `LlavaProcessingInfo`.
    """
    # Tried in order: the Pixtral-HF info type is checked first.
    processor_by_info = (
        (PixtralHFProcessingInfo, PixtralHFMultiModalProcessor),
        (LlavaProcessingInfo, LlavaMultiModalProcessor),
    )

    for info_cls, processor_cls in processor_by_info:
        if isinstance(info, info_cls):
            return processor_cls(
                info,
                dummy_inputs,  # type: ignore
                cache=cache,
            )

    raise NotImplementedError(type(info))
434
435
436
437
438


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The result of `get_layer_index` for a single feature layer, or the
        maximum of those results when several feature layers are given.

    Raises:
        ValueError: If `vision_feature_layer` is an empty list or tuple.
        TypeError: If `vision_feature_layer` is not an int, list or tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers

    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return get_layer_index(feature_layers, num_hidden_layers)

    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # `max()` over an empty sequence raises an opaque ValueError;
            # raise the same exception type with an actionable message.
            raise ValueError(
                "vision_feature_layer must contain at least one layer index"
            )
        return max(get_layer_index(idx, num_hidden_layers) for idx in feature_layers)

    raise TypeError(
        f"vision_feature_layer type: {type(feature_layers)} is not supported"
    )
454
455


456
457
def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: QuantizationConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
) -> CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel:
    """Create the vision tower that matches the type of `vision_config`.

    Args:
        hf_config: Model config providing `vision_config` and the vision
            feature layer(s).
        quant_config: Quantization config forwarded to the tower.
        require_post_norm: Forwarded to the tower unchanged.
        prefix: Module-name prefix forwarded to the tower.

    Raises:
        NotImplementedError: If `vision_config` is not a CLIP, SigLIP or
            Pixtral vision config.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    # Tried in order; every tower takes the same constructor arguments.
    tower_by_config = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )

    for config_cls, tower_cls in tower_by_config:
        if isinstance(vision_config, config_cls):
            return tower_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    raise NotImplementedError(f"Unsupported vision config: {type(vision_config)}")


497
498
499
500
501
@MULTIMODAL_REGISTRY.register_processor(
    _build_llava_or_pixtral_hf_processor,
    info=_build_llava_or_pixtral_hf_info,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class LlavaForConditionalGeneration(
    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsEagle3
):
    """LLaVA: a vision tower and MLP projector in front of a language model."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for the `i`-th item of `modality`.

        Raises:
            ValueError: If `modality` does not start with "image".
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")

        return "<image>"

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Store `layers` on the inner language model."""
        inner_model = self.get_language_model().model
        inner_model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return `(2, n // 2, n - 3)` for a language model with `n` layers."""
        layer_count = len(self.get_language_model().model.layers)
        return (2, layer_count // 2, layer_count - 3)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, the projector and the language model.

        Args:
            vllm_config: Supplies the HF config, the quantization config and
                the multimodal config.
            prefix: Module-name prefix for the sub-modules.
        """
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        text_config = config.text_config
        if text_config.architectures is None and text_config.model_type == "mistral":
            text_config.architectures = ["MistralForCausalLM"]

        missing_projector_act = config.projector_hidden_act is None
        if missing_projector_act and config.vision_config.hidden_act == "gelu":
            config.projector_hidden_act = "gelu"

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config=quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = LlavaMultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=config.text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        # Re-export the language model's factory under the same name.
        language_model = self.language_model
        self.make_empty_intermediate_tensors = (
            language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> LlavaImageInputs | None:
        """Wrap the image kwargs in the matching input schema.

        `pixel_values` takes precedence over `image_embeds` when both are
        present.

        Returns:
            The schema instance, or `None` when neither key is present.

        Raises:
            ValueError: If `image_embeds` is given for a Pixtral vision tower.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        vision_config = self.config.vision_config
        is_pixtral = vision_config.model_type == "pixtral"

        if pixel_values is not None:
            if is_pixtral:
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=pixel_values,
                )

            # Height and width are both bound to the configured image size.
            image_size = vision_config.image_size
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                resolve_bindings={"h": image_size, "w": image_size},
            )

        # Only `image_embeds` can be set from here on.
        if is_pixtral:
            raise ValueError("Pixtral-HF does not support image_embeds.")

        return LlavaImageEmbeddingInputs(
            type="image_embeds",
            data=image_embeds,
        )

    def _image_pixels_to_features(
        self,
        vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel,
        pixel_values: torch.Tensor | list[torch.Tensor],
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Run `vision_tower` with the configured feature-select strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        select_strategy = self.config.vision_feature_select_strategy
        return vision_tower(
            pixel_values,
            feature_select_strategy=select_strategy,
        )

    def _process_image_pixels(
        self,
        inputs: LlavaImagePixelInputs | PixtralHFImagePixelInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Encode the `pixel_values` of `inputs` with this model's tower."""
        return self._image_pixels_to_features(
            self.vision_tower, inputs["pixel_values"]
        )

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn a validated image input into language-model embeddings.

        Precomputed embeddings are returned unchanged. Pixel inputs go
        through the vision tower and the projector; when the tower returns
        several feature tensors they are projected in one concatenated call
        and split back to their original first-dimension sizes.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        features = self._process_image_pixels(image_input)

        if isinstance(features, torch.Tensor):
            return self.multi_modal_projector(features)

        sizes_per_image = [feature.shape[0] for feature in features]

        projected = self.multi_modal_projector(torch.cat(features))
        return torch.split(projected, sizes_per_image)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings for `kwargs`, or `[]` without image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for LLaVA-1.5.

        Keep in mind that `input_ids` already accounts for the positions of
        the image embeddings that will be inserted.

        As an example, take the text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        The tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        Space has to be reserved in the KV cache, so placeholder tokens are
        inserted before the prompt reaches the model: the input processor
        adds further image tokens (denoted as `32000`), which gives:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        575 tokens are inserted so that, counting the original image token,
        the input holds 576 (24 * 24) image tokens in total. That is the
        number of image tokens fed to the language model, i.e. the number of
        image tokens produced by the visual encoder.

        This keeps `positions` and `attn_metadata` consistent with
        `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward
                pass. When given, `inputs_embeds` is ignored.
            inputs_embeds: Optional tensor of input embeddings.

        Info:
            [`LlavaImageInputs`][vllm.model_executor.models.llava.LlavaImageInputs]
        """
        # Input embeddings are dropped whenever intermediate tensors exist.
        embeds = None if intermediate_tensors is not None else inputs_embeds

        return self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        language_model = self.language_model
        return language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load `weights` with `AutoWeightsLoader` and the HF name mapper."""
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        module_prefixes = {
            "language_model": "language_model",
            "connector": "multi_modal_projector",
            "tower_model": "vision_tower",
        }
        return MultiModelKeys.from_string_field(**module_prefixes)

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        """Return `num_image_tokens` unchanged."""
        # LLaVA's vision encoder outputs one token per patch without
        # spatial merging or pixel shuffle
        return num_image_tokens

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        """Return `num_vision_tokens` unchanged."""
        # LLaVA's MLP projector outputs the same number of tokens
        # as it receives from the vision encoder (1:1 mapping)
        return num_vision_tokens

753

754
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis; supplies processor defaults from config."""

    def get_hf_processor(self, **kwargs: object):
        """Return a `LlavaProcessor`, defaulting two kwargs when not given.

        `patch_size` defaults to the vision encoder's patch size and
        `vision_feature_select_strategy` to the HF config's value.
        """
        hf_config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        defaults = {
            "patch_size": encoder_info.get_patch_size(),
            "vision_feature_select_strategy": (
                hf_config.vision_feature_select_strategy
            ),
        }
        # Caller-supplied values win over the defaults.
        for name, value in defaults.items():
            kwargs.setdefault(name, value)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
766
767


768
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor that wraps each image's tokens in Mantis markers."""

    def apply(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object] | None = None,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """Run the LLaVA processing, then rewrite image runs in Mantis form.

        Each run of image tokens in the parent's output is replaced by
        `(image N: <Image>` + image tokens + `</Image>)`. The placeholders
        are then located again in the rewritten prompt and validated.
        """
        image_token_id = self.info.get_hf_config().image_token_index

        # Assume that it doesn't depend on the image size
        tokens_per_image = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        base_result = super().apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = base_result["mm_kwargs"]
        mm_hashes = base_result["mm_hashes"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            opening = f"(image {item_idx + 1}: <Image>"  # 7 tokens
            image_tokens = "<image>" * tokens_per_image
            closing = "</Image>)"  # 3 tokens
            return opening + image_tokens + closing

        mantis_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id] * tokens_per_image,
            replacement=get_replacement_mantis,
        )
        mantis_mm_repls = self._bind_and_group_updates(
            [mantis_replacement],
            mm_item_counts,
        )

        prompt_ids, _ = self._apply_prompt_updates(
            base_result["prompt_token_ids"],
            mantis_mm_repls,
        )

        # Locate the image placeholders again in the rewritten prompt.
        orig_repls = self._get_mm_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls)
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {}
        for modality, placeholders in mm_placeholders.items():
            mm_placeholder_ranges[modality] = [
                placeholder.to_range() for placeholder in placeholders
            ]

        return MultiModalInputs(
            type="multimodal",
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
846
847
848
849


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model registered with the Mantis processor and processing info."""