"vllm/vscode:/vscode.git/clone" did not exist on "36e4acd02a955f71ebb7b220cbfae4a4379bc57b"
llava.py 28.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
from collections.abc import Iterable, Mapping, Sequence
6
from typing import Annotated, Final, Literal, Protocol, TypeAlias, TypeVar
7
8

import torch
9
import torch.nn as nn
10
11
12
13
14
15
16
17
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    LlavaConfig,
    PixtralVisionConfig,
    PretrainedConfig,
    SiglipVisionConfig,
)
18
from transformers.models.llava import LlavaProcessor
19
from transformers.models.pixtral import PixtralProcessor
20

21
from vllm.config import VllmConfig
22
from vllm.config.multimodal import BaseDummyOptions
23
from vllm.model_executor.layers.activation import get_act_fn
24
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
25
from vllm.model_executor.layers.quantization import QuantizationConfig
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.multimodal.cache import BaseMultiModalProcessorCache
28
29
30
31
32
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
33
    mm_inputs,
34
35
36
37
38
39
40
41
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
42
    BaseDummyInputsBuilder,
43
44
45
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
46
    ProcessorInputs,
47
48
49
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
50
    TimingContext,
51
)
52
from vllm.sequence import IntermediateTensors
53
from vllm.utils.tensor_schema import TensorSchema, TensorShape
54

55
from .clip import CLIPVisionModel
56
57
from .interfaces import (
    MultiModalEmbeddings,
58
    SupportsEagle,
59
    SupportsEagle3,
60
61
62
63
64
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
)
from .module_mapping import MultiModelKeys
65
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
66
from .siglip import SiglipVisionModel
67
68
69
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
70
    get_layer_index,
71
72
73
    init_vllm_registered_model,
    maybe_prefix,
)
74
from .vision import get_num_selected_vision_tokens, get_vision_encoder_info
75
76


77
class LlavaImagePixelInputs(TensorSchema):
    """
    Pixel-value image inputs tagged with `type="pixel_values"`.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width

    NOTE(review): `pixel_values` is annotated as a single `torch.Tensor`
    here (the Pixtral-HF schema in this file is the one that also admits a
    list of tensors), and `_parse_and_validate_image_input` binds `h` and
    `w` to `vision_config.image_size`. The earlier wording that `height`
    or `width` may differ per image, with data passed as a list, therefore
    does not appear to apply to this schema - confirm.
    """

    # Discriminator used to tell the image-input schemas apart.
    type: Literal["pixel_values"] = "pixel_values"
    # Channel dimension is pinned to 3; "bn", "h" and "w" are symbolic.
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
91

92

93
class PixtralHFImagePixelInputs(TensorSchema):
    """
    Pixel-value image inputs tagged with `type="pixel_values_pixtral"`.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels
        - h: Height
        - w: Width

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    # Discriminator used to tell the image-input schemas apart.
    type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
    # Either one batched tensor or one tensor per image; "h" and "w" are
    # declared dynamic so they may vary between entries.
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "c", "h", "w", dynamic_dims={"h", "w"}),
    ]
110

111

112
class LlavaImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings tagged with `type="image_embeds"`.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """

    # Discriminator used to tell the image-input schemas apart.
    type: Literal["image_embeds"] = "image_embeds"
    # Embeddings are returned as-is by `_process_image_input` (no projector).
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
122
123


124
125
126
# Union of the three image-input schemas defined above; the `type` field of
# each member is what distinguishes them at runtime.
LlavaImageInputs: TypeAlias = (
    LlavaImagePixelInputs
    | PixtralHFImagePixelInputs
    | LlavaImageEmbeddingInputs
)
"""Alias for supported LLaVA image input types."""
128
129


130
class LlavaMultiModalProjector(nn.Module):
    """Projector applied to vision-tower features before the language model.

    Runs `linear_1`, then the configured activation, then `linear_2`.
    """

    def __init__(
        self,
        vision_hidden_size: int,
        text_hidden_size: int,
        projector_hidden_act: str,
        multimodal_projector_bias: bool,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the two linear layers and the activation.

        Args:
            vision_hidden_size: Input width of `linear_1`.
            text_hidden_size: Output width of `linear_1`, and both the input
                and output width of `linear_2`.
            projector_hidden_act: Activation name passed to `get_act_fn`.
            multimodal_projector_bias: Whether both linear layers use a bias.
            quant_config: Quantization config forwarded to both layers.
            prefix: Prefix from which the layer prefixes are derived.
        """
        super().__init__()

        # Options that are identical for both linear layers.
        shared_kwargs = {
            "bias": multimodal_projector_bias,
            "quant_config": quant_config,
        }

        # Assignment order (linear_1, act, linear_2) is kept unchanged so
        # the submodules are registered in the same order as before.
        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_1",
            **shared_kwargs,
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_2",
            **shared_kwargs,
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project `image_features` through both layers.

        Each linear layer returns a 2-tuple; only its first element is used.
        """
        projected, _ = self.linear_1(image_features)
        activated = self.act(projected)
        output, _ = self.linear_2(activated)
        return output


165
166
class LlavaLikeConfig(Protocol):
    """Structural type for the HF config attributes read in this file."""

    # Vision encoder sub-config (checked against CLIP/SigLIP/Pixtral types).
    vision_config: Final[PretrainedConfig]
    # Token id of the image placeholder in the prompt.
    image_token_index: Final[int]
    # Strategy name forwarded to the vision tower / token counting helper.
    vision_feature_select_strategy: Final[str]
    # One layer index, or several, from which vision features are taken.
    vision_feature_layer: Final[int | list[int]]
170

171

172
173
174
175
class LlavaLikeProcessor(Protocol):
    """Structural type for the HF processor attribute read in this file."""

    # Text form of the image placeholder (repeated to build dummy prompts).
    image_token: Final[str]


176
177
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by the LLaVA-style processors in this file."""

    def get_hf_config(self) -> LlavaLikeConfig:
        """Return the HF config obtained from the context as `LlavaConfig`."""
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        """Return the vision-encoder info object for the current HF config."""
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        """Return the HF processor; must be provided by subclasses."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits: only "image", with limit `None`."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of selected vision tokens for one image.

        The encoder's token count for the given size is passed through
        `get_num_selected_vision_tokens` together with the configured
        feature-select strategy.
        """
        hf_config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        encoder_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        strategy = hf_config.vision_feature_select_strategy
        return get_num_selected_vision_tokens(encoder_tokens, strategy)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square `ImageSize` whose side is the encoder image size."""
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        """Return the token count for the size with the most features."""
        max_width, max_height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=max_width,
            image_height=max_height,
        )

220
221
222
223
224

# Type variable for the processing-info class that the dummy-inputs builder
# and the multimodal processors in this file are parameterized with; bounded
# to BaseLlavaProcessingInfo.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompt text and dummy images for LLaVA-style models."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's image token repeated once per image.

        Args:
            mm_counts: Item count per modality; a missing "image" key is
                treated as zero images.
        """
        image_count = mm_counts.get("image", 0)
        hf_processor = self.info.get_hf_processor()
        return hf_processor.image_token * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy image data sized for the maximum feature count.

        Args:
            seq_len: Not read by this implementation.
            mm_counts: Item count per modality; a missing "image" key is
                treated as zero images.
            mm_options: Per-modality dummy options; the "image" entry (or
                `None` when absent) is forwarded as `overrides`.
        """
        image_count = mm_counts.get("image", 0)
        max_width, max_height = self.info.get_image_size_with_most_features()
        overrides = mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=max_width,
            height=max_height,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


255
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info backed by the HF `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return the `LlavaProcessor`, filling in `patch_size` if unset."""
        processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)

        # In case patch_size is omitted from `processor_config.json`
        # e.g. for E5-V: https://huggingface.co/royokong/e5-v
        if processor.patch_size is None:
            encoder_info = self.get_vision_encoder_info()
            processor.patch_size = encoder_info.get_patch_size()

        return processor
264
265


266
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Shared prompt-update logic for LLaVA-style multimodal processors."""

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field config; must be provided by subclasses."""
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with one token per image feature.

        Returns a single `PromptReplacement` targeting `[image_token_id]`
        whose replacement is computed lazily per item.
        """
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            # Items may be precomputed embeddings or images to be processed.
            items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(items, ImageEmbeddingItems):
                # Embeddings already know their own feature size.
                token_count = items.get_feature_size(item_idx)
            else:
                size = items.get_image_size(item_idx)
                token_count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )

            return [image_token_id] * token_count

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [replacement]


310
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for LLaVA models using `LlavaProcessingInfo`."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `pixel_values` and `image_embeds` as batched image fields.

        Neither argument is read by this implementation.
        """
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }


322
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info backed by the HF `PixtralProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return the `PixtralProcessor` obtained from the context."""
        processor = self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
        return processor
325

326

327
class PixtralHFMultiModalProcessor(BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for Pixtral models in the HF format."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Call the parent HF processor and crop each image to its own size.

        When the output contains `pixel_values`, every entry is sliced to
        the `(h, w)` recorded for it in `image_sizes`.
        """
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        padded_images = outputs.get("pixel_values")
        if padded_images is None:
            # Nothing to crop when no pixel values were produced.
            return outputs

        # Avoid padding since we need the output for each image to be
        # independent of other images for the cache to work correctly
        sizes = outputs["image_sizes"]
        assert len(padded_images) == len(sizes)

        outputs["pixel_values"] = [
            image[:, :height, :width]
            for image, (height, width) in zip(padded_images, sizes)
        ]
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `pixel_values` and `image_embeds` as batched image fields.

        Neither argument is read by this implementation.
        """
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each image token into a grid of image/break/end tokens.

        For every image, each patch row contributes `ncols` image tokens
        followed by a break token; the very last token is then overwritten
        with the end token.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        break_id = vocab[hf_processor.image_break_token]
        image_token_id = hf_config.image_token_index
        end_id = vocab[hf_processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            size = mm_items.get_items("image", ImageProcessorItems).get_image_size(
                item_idx
            )

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            row = [image_token_id] * ncols + [break_id]
            tokens = row * nrows
            # The final row ends with the end token instead of a break token.
            tokens[-1] = end_id

            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [replacement]

405

406
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Pick the processing-info class from the vision config type.

    Returns `PixtralHFProcessingInfo` when the config's `vision_config` is a
    `PixtralVisionConfig`, otherwise `LlavaProcessingInfo`.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    is_pixtral = isinstance(vision_config, PixtralVisionConfig)

    info_cls = PixtralHFProcessingInfo if is_pixtral else LlavaProcessingInfo
    return info_cls(ctx)


417
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor:
    """Build the multimodal processor that matches the type of `info`.

    Raises:
        NotImplementedError: If `info` is neither a `PixtralHFProcessingInfo`
            nor a `LlavaProcessingInfo`; the argument is `type(info)`.
    """
    # Checked in order; the Pixtral pairing is tested first, as before.
    processor_by_info = (
        (PixtralHFProcessingInfo, PixtralHFMultiModalProcessor),
        (LlavaProcessingInfo, LlavaMultiModalProcessor),
    )

    for info_cls, processor_cls in processor_by_info:
        if isinstance(info, info_cls):
            return processor_cls(
                info,
                dummy_inputs,  # type: ignore
                cache=cache,
            )

    raise NotImplementedError(type(info))
438
439
440
441
442


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The result of `get_layer_index` for a single feature layer, or the
        largest such result when several feature layers are configured.

    Raises:
        TypeError: If `vision_feature_layer` is neither an int nor a
            list/tuple.
        ValueError: If `vision_feature_layer` is an empty list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers

    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return get_layer_index(feature_layers, num_hidden_layers)

    if not isinstance(feature_layers, (list, tuple)):
        # The message names the actual config attribute so that users can
        # find the offending field.
        raise TypeError(
            f"vision_feature_layer type: {type(feature_layers)} is not supported"
        )

    # An empty sequence would otherwise surface as an opaque error from
    # `max()`; the exception type (ValueError) is unchanged.
    if not feature_layers:
        raise ValueError(
            "vision_feature_layer must contain at least one layer index"
        )

    # If we have multiple feature layers, initialize up to the deepest one
    return max(get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
458
459


460
461
def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: QuantizationConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
) -> CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel:
    """Create the vision tower matching the type of `hf_config.vision_config`.

    Args:
        hf_config: Model config providing `vision_config` and the vision
            feature layer(s).
        quant_config: Quantization config forwarded to the tower.
        require_post_norm: Forwarded to the tower unchanged.
        prefix: Prefix forwarded to the tower unchanged.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral vision config.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    # Config type -> tower class, checked in this order. All three towers
    # are constructed with the same arguments.
    tower_by_config = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )

    for config_cls, tower_cls in tower_by_config:
        if isinstance(vision_config, config_cls):
            return tower_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


501
502
503
504
505
@MULTIMODAL_REGISTRY.register_processor(
    _build_llava_or_pixtral_hf_processor,
    info=_build_llava_or_pixtral_hf_info,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class LlavaForConditionalGeneration(
    nn.Module,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
    SupportsEagle,
    SupportsEagle3,
):
    """LLaVA model: vision tower + projector + registered language model."""

    # Fused module name -> names of the checkpoint modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # Prefix rewrites applied to checkpoint weight names in `load_weights`.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an image item.

        Raises:
            ValueError: If `modality` does not start with "image".
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")

        return "<image>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, projector and language model.

        Note that the HF config object is patched in place for the
        Pixtral-12B special cases below before any submodule is created.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = vllm_config.model_config.multimodal_config

        self.configure_mm_token_handling(
            vocab_size=config.text_config.vocab_size,
            mm_token_ids=[config.image_token_index],
        )

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        text_config = config.text_config
        if text_config.architectures is None and text_config.model_type == "mistral":
            text_config.architectures = ["MistralForCausalLM"]

        projector_act_missing = config.projector_hidden_act is None
        if projector_act_missing and config.vision_config.hidden_act == "gelu":
            config.projector_hidden_act = "gelu"

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config=quant_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = LlavaMultiModalProjector(
                vision_hidden_size=config.vision_config.hidden_size,
                text_hidden_size=config.text_config.hidden_size,
                projector_hidden_act=config.projector_hidden_act,
                multimodal_projector_bias=config.multimodal_projector_bias,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        # Expose the language model's factory under this model's name.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> LlavaImageInputs | None:
        """Wrap `pixel_values` / `image_embeds` kwargs in an input schema.

        `pixel_values` takes precedence when both are given.

        Returns:
            `None` when neither keyword is present (or both are `None`).

        Raises:
            ValueError: If only `image_embeds` is given for a Pixtral model.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        vision_config = self.config.vision_config
        is_pixtral = vision_config.model_type == "pixtral"

        if pixel_values is not None:
            if is_pixtral:
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=pixel_values,
                )

            # Non-Pixtral towers expect square images of the configured size.
            expected_side = vision_config.image_size
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                resolve_bindings={"h": expected_side, "w": expected_side},
            )

        # From here on `image_embeds` is necessarily not None.
        if is_pixtral:
            raise ValueError("Pixtral-HF does not support image_embeds.")

        return LlavaImageEmbeddingInputs(
            type="image_embeds",
            data=image_embeds,
        )

    def _image_pixels_to_features(
        self,
        vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel,
        pixel_values: torch.Tensor | list[torch.Tensor],
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Run `vision_tower` on `pixel_values` with the configured strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        strategy = self.config.vision_feature_select_strategy
        return vision_tower(pixel_values, feature_select_strategy=strategy)

    def _process_image_pixels(
        self,
        inputs: LlavaImagePixelInputs | PixtralHFImagePixelInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Extract `pixel_values` from `inputs` and run the vision tower."""
        return self._image_pixels_to_features(
            self.vision_tower, inputs["pixel_values"]
        )

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn a validated image input into language-model embeddings.

        Precomputed embeddings are returned untouched. Pixel inputs go
        through the vision tower and then the projector; when the tower
        yields several tensors they are concatenated for a single projector
        call and split back by their first-dimension sizes.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        features = self._process_image_pixels(image_input)

        if isinstance(features, torch.Tensor):
            return self.multi_modal_projector(features)

        sizes = [feature.shape[0] for feature in features]
        projected = self.multi_modal_projector(torch.cat(features))
        return torch.split(projected, sizes)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings for the given kwargs, or `[]` if none."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is not None:
            return self._process_image_input(image_input)

        return []

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for the input tokens.
            intermediate_tensors: Intermediate tensors from prior forward pass.
                When given, `inputs_embeds` is ignored.
            inputs_embeds: Optional tensor of input embeddings.
            **kwargs: Not read by this method.

        Info:
            [`LlavaImageInputs`][vllm.model_executor.models.llava.LlavaImageInputs]
        """
        # Embeddings are dropped whenever intermediate tensors are supplied.
        embeds = None if intermediate_tensors is not None else inputs_embeds

        return self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        logits = self.language_model.compute_logits(hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load `weights` via `AutoWeightsLoader` using `hf_to_vllm_mapper`."""
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        """Return `num_image_tokens` unchanged."""
        # LLaVA's vision encoder outputs one token per patch without
        # spatial merging or pixel shuffle
        return num_image_tokens

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        """Return `num_vision_tokens` unchanged."""
        # LLaVA's MLP projector outputs the same number of tokens
        # as it receives from the vision encoder (1:1 mapping)
        return num_vision_tokens

760

761
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis, built on the HF `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return a `LlavaProcessor` with defaults filled in from the config.

        `patch_size` and `vision_feature_select_strategy` are only set when
        the caller did not pass them explicitly.
        """
        hf_config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        defaults = (
            ("patch_size", encoder_info.get_patch_size()),
            (
                "vision_feature_select_strategy",
                hf_config.vision_feature_select_strategy,
            ),
        )
        for key, value in defaults:
            kwargs.setdefault(key, value)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
773
774


775
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor that rewrites image placeholders in Mantis format."""

    def apply(
        self,
        inputs: ProcessorInputs,
        timing_ctx: TimingContext,
    ) -> MultiModalInputs:
        """Run the parent processor, then wrap each image span Mantis-style.

        Each run of `num_image_tokens` image tokens produced by the parent is
        replaced with `(image N: <Image>` + image tokens + `</Image>)`, and
        the placeholder ranges are recomputed on the rewritten prompt.
        """
        image_token_id = self.info.get_hf_config().image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(inputs, timing_ctx)

        item_counts = inputs.mm_data_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]
        mm_hashes = result["mm_hashes"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            opening = f"(image {item_idx + 1}: <Image>"  # 7 tokens
            image_tokens = "<image>" * num_image_tokens
            closing = "</Image>)"  # 3 tokens
            return opening + image_tokens + closing

        mantis_update = PromptReplacement(
            modality="image",
            target=[image_token_id] * num_image_tokens,
            replacement=get_replacement_mantis,
        )
        mantis_mm_repls = self._bind_and_group_updates(
            [mantis_update],
            item_counts,
        )

        prompt_ids, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
        )

        # Locate the image placeholders again in the rewritten prompt using
        # the original (non-Mantis) prompt updates.
        orig_repls = self._get_mm_prompt_updates(
            inputs.mm_data_items,
            inputs.hf_processor_mm_kwargs,
            mm_kwargs,
        )
        found_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls)
        self._validate_mm_placeholders(found_placeholders, item_counts)

        placeholder_ranges = {}
        for modality, placeholders in found_placeholders.items():
            placeholder_ranges[modality] = [
                placeholder.to_range() for placeholder in placeholders
            ]

        return mm_inputs(
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=placeholder_ranges,
        )
842
843
844
845


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model registered with the Mantis processor and processing info.

    Adds no behavior of its own beyond the processor registration.
    """