llava.py 32.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import abstractmethod
4
from collections.abc import Iterable, Mapping, Sequence
5
from functools import cached_property
6
7
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
                    TypeVar, Union, cast)
8
9

import torch
10
import torch.nn as nn
11
from packaging.version import Version
12
13
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
14
                          SiglipVisionConfig)
15
from transformers import __version__ as TRANSFORMERS_VERSION
16
from transformers.models.llava import LlavaProcessor
17
from transformers.models.pixtral import PixtralProcessor
18

19
from vllm.config import VllmConfig
20
from vllm.inputs import InputProcessingContext
21
from vllm.jsontree import json_map_leaves
22
from vllm.model_executor.layers.activation import get_act_fn
23
24
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
25
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
26
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
27
from vllm.model_executor.sampling_metadata import SamplingMetadata
28
from vllm.multimodal import MULTIMODAL_REGISTRY
29
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
30
                                    MultiModalInputs, MultiModalKwargs)
31
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
32
                                   ImageSize, MultiModalDataItems)
33
from vllm.multimodal.processing import (BaseMultiModalProcessor,
34
                                        BaseProcessingInfo, ProcessingCache,
35
36
                                        PromptReplacement, PromptUpdate,
                                        PromptUpdateDetails)
37
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
38
from vllm.sequence import IntermediateTensors
39

40
from .clip import CLIPVisionModel
41
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
42
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
43
from .siglip import SiglipVisionModel
44
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
45
                    maybe_prefix, merge_multimodal_embeddings)
46
from .vision import get_vision_encoder_info
47
48


49
50
class LlavaImagePixelInputs(TypedDict):
    """Raw pixel values for LLaVA models with a fixed-resolution encoder."""

    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Every image has the same `height` and `width` (the vision encoder's
    configured `image_size`), so the data is always one batched tensor.
    Variable-size images are only supported via
    `PixtralHFImagePixelInputs`.
    """
58

59
60
61
62
63
64
65
66
67
68
69

class PixtralHFImagePixelInputs(TypedDict):
    """Raw pixel values for Pixtral-HF, which accepts variable image sizes."""

    type: Literal["pixel_values_pixtral"]
    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Because `height` and `width` can change from one image (or batch) to the
    next, the images may arrive as a list of tensors instead of a single
    batched tensor.
    """

70
71
72
73

class LlavaImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings that are used without the vision tower."""

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    The trailing `hidden_size` dimension has to equal the hidden size of the
    language model backbone.
    """


80
81
# Every image input variant accepted by LlavaForConditionalGeneration.
LlavaImageInputs = Union[
    LlavaImagePixelInputs,
    PixtralHFImagePixelInputs,
    LlavaImageEmbeddingInputs,
]
82
83


84
85
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision features to the text hidden size.

    The layers are a column-parallel linear (vision -> text hidden size), an
    activation, and a row-parallel linear (text -> text hidden size).
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project `image_features` into the language model's hidden size."""
        # Each parallel linear returns a pair; only the first item is used.
        projected, _ = self.linear_1(image_features)
        projected = self.act(projected)
        output, _unused = self.linear_2(projected)
        return output


114
115
class LlavaLikeConfig(Protocol):
    """Structural type for HF configs shaped like `LlavaConfig`.

    Declares the LLaVA-specific attributes that are read through this
    protocol: the vision sub-config, the image placeholder token id, and the
    feature-selection settings.
    """

    vision_config: Final[PretrainedConfig]
    image_token_index: Final[int]
    vision_feature_select_strategy: Final[str]
    vision_feature_layer: Final[Union[int, list[int]]]
119

120

121
122
123
124
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors that expose an image placeholder."""

    image_token: Final[str]


125
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Shared processing info for LLaVA-style models.

    Subclasses supply the HF processor. Token counts come from the vision
    encoder, adjusted for the configured feature select strategy.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Only the image modality is supported; no explicit count is given.
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        max_image_tokens = self.get_max_image_tokens()
        return {"image": max_image_tokens}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Adjust the encoder's token count for the feature select strategy.

        `"default"` drops one token per image (the model slices off the first
        feature), while `"full"` keeps all of them.

        Raises:
            NotImplementedError: For any other strategy.
        """
        if strategy not in ("default", "full"):
            msg = f"Unexpected feature select strategy: {strategy!r}"
            raise NotImplementedError(msg)

        if strategy == "full":
            return encoder_num_image_tokens
        return encoder_num_image_tokens - 1

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        hf_config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        strategy = hf_config.vision_feature_select_strategy
        encoder_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return self._apply_feature_select_strategy(strategy, encoder_tokens)

    def get_image_size_with_most_features(self) -> ImageSize:
        # The encoder works on square images of a single fixed size.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        width, height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=width,
                                         image_height=height)

190
191
192
193
194
195

_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds profiling inputs with one placeholder per max-size image."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)

        image_token = self.info.get_hf_processor().image_token
        width, height = self.info.get_image_size_with_most_features()
        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=num_images)

        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data={"image": dummy_images},
        )


221
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for standard LLaVA checkpoints (`LlavaProcessor`)."""

    def get_hf_processor(self, **kwargs: object):
        processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)

        # Some checkpoints omit `patch_size` from `processor_config.json`
        # (e.g. E5-V: https://huggingface.co/royokong/e5-v); fall back to the
        # vision encoder's patch size in that case.
        if processor.patch_size is None:
            encoder_info = self.get_vision_encoder_info()
            processor.patch_size = encoder_info.get_patch_size()

        return processor
231
232


233
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Common multimodal processor logic for LLaVA-style models.

    Each image placeholder token in the prompt is replaced by as many copies
    of itself as the image occupies in the language model input.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.info.get_hf_config().image_token_index

        def replacement_for(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                # Precomputed embeddings already determine their length.
                return [image_token_id] * images.get_feature_size(item_idx)

            size = images.get_image_size(item_idx)
            num_tokens = self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
            )
            return [image_token_id] * num_tokens

        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=replacement_for,
        )
        return [image_replacement]


277
278
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for standard LLaVA checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Pixel values and image embeddings use the same per-image batching.
        image_batched = MultiModalFieldConfig.batched
        return {
            "pixel_values": image_batched("image"),
            "image_embeds": image_batched("image"),
        }


291
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for Pixtral checkpoints in the HF (LLaVA) format."""

    def get_hf_processor(self, **kwargs: object):
        processor_cls = PixtralProcessor
        return self.ctx.get_hf_processor(processor_cls, **kwargs)
295

296

297
298
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for Pixtral-HF.

    Each image expands into a grid of image tokens. A break token follows
    every row, and the last break is replaced by an end-of-image token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is None:
            return processed_outputs

        # Before/after https://github.com/huggingface/transformers/pull/35122
        legacy_layout = Version(TRANSFORMERS_VERSION) <= Version("4.48.3")
        if legacy_layout:
            images = mm_data["images"]
            assert isinstance(images, list)

            # Legacy output is (1, num_images, C, H, W); unwrap it to
            # (num_images, C, H, W).
            assert (isinstance(pixel_values, list)
                    and len(pixel_values) == 1)
            assert (isinstance(pixel_values[0], list)
                    and len(pixel_values[0]) == len(images))
            processed_outputs["pixel_values"] = pixel_values[0]
        else:
            # Crop every image back to its own size. Without padding, each
            # image's output does not depend on the other images, which the
            # processing cache needs to work correctly.
            image_sizes = processed_outputs["image_sizes"]
            assert len(pixel_values) == len(image_sizes)
            processed_outputs["pixel_values"] = [
                pixel[:, :height, :width]
                for pixel, (height, width) in zip(pixel_values, image_sizes)
            ]

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        image_batched = MultiModalFieldConfig.batched
        return {
            "pixel_values": image_batched("image"),
            "image_embeds": image_batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(vision_config)

        def replacement_for(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            size = images.get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            # One row is `ncols` image tokens plus a trailing break token.
            row = [image_token_id] * ncols + [image_break_id]
            tokens = row * nrows
            tokens[-1] = image_end_id

            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=replacement_for,
        )
        return [image_replacement]

391

392
393
394
395
396
397
398
399
400
401
def _build_llava_or_pixtral_hf_info(
        ctx: InputProcessingContext) -> BaseLlavaProcessingInfo:
    """Choose the processing-info class from the vision config's type.

    Pixtral vision configs get `PixtralHFProcessingInfo`; every other config
    gets `LlavaProcessingInfo`.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    is_pixtral = isinstance(vision_config, PixtralVisionConfig)
    info_cls = PixtralHFProcessingInfo if is_pixtral else LlavaProcessingInfo
    return info_cls(ctx)


402
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
    enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
    """Instantiate the multimodal processor that matches `info`'s type.

    Raises:
        NotImplementedError: If `info` is neither a Pixtral-HF nor a LLaVA
            processing info.
    """
    processor_cls: type[BaseMultiModalProcessor]
    if isinstance(info, PixtralHFProcessingInfo):
        processor_cls = PixtralHFMultiModalProcessor
    elif isinstance(info, LlavaProcessingInfo):
        processor_cls = LlavaMultiModalProcessor
    else:
        raise NotImplementedError(type(info))

    return processor_cls(
        info,
        dummy_inputs,  # type: ignore
        cache=cache,
        enable_sanity_checks=enable_sanity_checks,
    )
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448


def _get_num_hidden_layers(hf_config: "LlavaLikeConfig") -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The number of encoder layers needed to reach the deepest layer listed
        in `hf_config.vision_feature_layer`.

    Raises:
        TypeError: If `vision_feature_layer` is neither an int nor a
            list/tuple of ints.
        ValueError: If `vision_feature_layer` is an empty list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        # Fail with a clear message rather than max()'s generic one.
        if not feature_layers:
            raise ValueError("vision_feature_layer must not be empty")
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given a signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
            Negative values count from the end, so ``-1`` maps to all
            ``num_hidden_layers`` layers.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.

    Returns:
        The number of encoder layers to initialize.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index
460
461
462
463
464
465
466


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Build the vision encoder that matches ``hf_config.vision_config``.

    Only the layers up to the deepest feature layer in
    ``hf_config.vision_feature_layer`` are instantiated.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral vision config.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    tower_cls: type[Union[CLIPVisionModel, SiglipVisionModel,
                          PixtralHFVisionModel]]
    if isinstance(vision_config, CLIPVisionConfig):
        tower_cls = CLIPVisionModel
    elif isinstance(vision_config, SiglipVisionConfig):
        tower_cls = SiglipVisionModel
    elif isinstance(vision_config, PixtralVisionConfig):
        tower_cls = PixtralHFVisionModel
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return tower_cls(
        vision_config,
        quant_config=quant_config,
        num_hidden_layers_override=num_hidden_layers,
        require_post_norm=require_post_norm,
        prefix=prefix,
    )


503
504
505
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-style vision-language model.

    A vision tower (CLIP, SigLIP or Pixtral) encodes the images, and
    `LlavaMultiModalProjector` maps the features to the language model's
    hidden size. The result is merged into the text embeddings at the image
    placeholder positions before the language model runs.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        text_config = config.text_config
        if (text_config.architectures is None
                and text_config.model_type == "mistral"):
            text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Reuse the language model's sampler when it defines one.
        if not hasattr(self.language_model, "sampler"):
            return get_sampler()
        return self.language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Ensure every image is `(3, image_size, image_size)`."""
        side = self.config.vision_config.image_size
        expected_dims = (3, side, side)

        if tuple(data.shape[1:]) == expected_dims:
            return data

        expected_expr = ("batch_size", *map(str, expected_dims))
        raise ValueError(
            f"The expected shape of pixel values is {expected_expr}. "
            f"You supplied {tuple(data.shape)}.")

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if self.config.vision_config.model_type == "pixtral":
                # Pixtral images may differ in size: keep them unconcatenated.
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=flatten_bn(pixel_values),
                )

            batched_pixels = flatten_bn(pixel_values, concat=True)
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=self._validate_pixel_values(batched_pixels),
            )

        # Only `image_embeds` can be set at this point.
        if not isinstance(image_embeds, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image embeddings. "
                             f"Got type: {type(image_embeds)}")

        if self.config.vision_config.model_type == "pixtral":
            raise ValueError("Pixtral-HF does not support image_embeds.")

        return LlavaImageEmbeddingInputs(
            type="image_embeds",
            data=flatten_bn(image_embeds, concat=True),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "full":
            return image_features
        if strategy == "default":
            # Drop the first feature of every image (presumably the CLS
            # token of the encoder).
            return image_features[:, 1:]

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        def apply_strategy(leaf: torch.Tensor):
            return self._select_image_features(
                leaf,
                strategy=self.config.vision_feature_select_strategy,
            )

        selected = json_map_leaves(apply_strategy, image_features)
        return cast(Union[torch.Tensor, tuple[torch.Tensor, ...]], selected)

    def _process_image_pixels(
        self,
        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        assert self.vision_tower is not None

        return self._image_pixels_to_features(self.vision_tower,
                                              inputs["pixel_values"])

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        if image_input["type"] == "image_embeds":
            # Precomputed embeddings are passed through without projection.
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        # Per-image features (e.g. variable-size Pixtral images): project
        # them together in one call, then split back per image.
        sizes = [features.shape[0] for features in image_features]
        projected = self.multi_modal_projector(torch.cat(image_features))
        return torch.split(projected, sizes)

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return inputs_embeds

        # Write the image embeddings into the image placeholder positions.
        return merge_multimodal_embeddings(
            input_ids,
            inputs_embeds,
            multimodal_embeddings,
            self.config.image_token_index,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language model on the text and (optional) image inputs.

        `input_ids` already reserve one position per image feature. For
        example, with LLaVA-1.5 the prompt
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`
        is tokenized as

        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`,

        where `32000` is the image token. Before the model runs, the input
        processor expands that single image token into 576 (24 * 24) copies,
        which matches the number of features the visual encoder emits:

        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        Reserving the positions up front keeps KV-cache space, `positions`
        and `attn_metadata` consistent with `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices forwarded to the language model.
            intermediate_tensors: Intermediate tensors (presumably from a
                previous pipeline-parallel stage). When given,
                `inputs_embeds` is ignored.
            inputs_embeds: Precomputed input embeddings. When omitted, and
                no `intermediate_tensors` are given, they are computed here
                from `input_ids` and the image inputs.
            **kwargs: Image inputs such as `pixel_values` (the pixels of each
                input image) or `image_embeds`.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this branch is for v0 compatibility.
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        return self.language_model.model(input_ids,
                                          positions,
                                          intermediate_tensors,
                                          inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        return AutoWeightsLoader(self).load_weights(weights)
780
781


class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis checkpoints run through the Llava model.

    Mantis reuses the stock HF ``LlavaProcessor``. This class only adjusts
    the keyword arguments used to construct that processor.
    """

    def get_hf_processor(self, **kwargs: object):
        """Return an HF ``LlavaProcessor`` configured for Mantis.

        Both defaults below are applied with ``setdefault``, so values passed
        by the caller in ``kwargs`` always take precedence:

        - ``patch_size`` defaults to the vision encoder's patch size.
        - ``vision_feature_select_strategy`` defaults to ``None`` on
          transformers < 4.48, to offset an upstream miscount of additional
          image tokens. On newer versions it defaults to the value from the
          HF config.
        """
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        kwargs.setdefault("patch_size", vision_info.get_patch_size())

        if Version(TRANSFORMERS_VERSION) < Version("4.48"):
            # BUG: num_additional_image_tokens = 0 but treated as 1,
            # so we set vision_feature_select_strategy to None to offset this
            kwargs.setdefault("vision_feature_select_strategy", None)
        else:
            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
            kwargs.setdefault(
                "vision_feature_select_strategy",
                hf_config.vision_feature_select_strategy,
            )

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)


class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """Llava multimodal processor with Mantis-style image prompts.

    It runs the regular Llava processing first. It then wraps every image's
    placeholder run in the text format used by Mantis' ``MLlavaProcessor``.
    """

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """Process ``prompt`` and ``mm_data`` into Mantis-formatted inputs.

        Steps:

        1. Run the parent ``apply`` to get Llava-processed token ids,
           ``mm_kwargs`` and ``mm_hashes``.
        2. Replace each run of ``num_image_tokens`` image tokens with
           ``(image N: <Image>`` + image tokens + ``</Image>)``.
        3. Locate the image placeholders again in the rewritten token ids.

        Args:
            prompt: Text prompt or pre-tokenized prompt ids.
            mm_data: Multimodal data for the prompt.
            hf_processor_mm_kwargs: Extra kwargs for the HF processor.
            return_mm_hashes: Forwarded to the parent ``apply``.

        Returns:
            ``MultiModalInputs`` holding the rewritten prompt and token ids,
            the parent's ``mm_kwargs``/``mm_hashes``, and placeholder ranges
            computed on the rewritten ids.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                               return_mm_hashes)

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]
        mm_hashes = result["mm_hashes"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            """Build the Mantis text for image ``item_idx`` (1-based label)."""
            return "".join([
                f"(image {item_idx+1}: <Image>",  # 7 tokens
                "<image>" * num_image_tokens,
                "</Image>)",  # 3 tokens
            ])

        mantis_mm_repls = self._bind_and_group_updates([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * num_image_tokens,
                replacement=get_replacement_mantis,
            )
        ])

        # The third return value is not needed here.
        prompt_ids, prompt, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
            mm_item_counts,
        )

        # The Mantis wrapper text shifts token positions, so the image
        # placeholders are located again in the rewritten prompt_ids, using
        # the original (non-Mantis) prompt updates as the search pattern.
        unbound_orig_repls = self._get_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        orig_repls = self._bind_and_group_updates(unbound_orig_repls)

        mm_placeholders = self._find_mm_placeholders(
            orig_repls,
            prompt_ids,
            mm_item_counts,
        )
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
                                        info=MantisProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """Llava model registered with the Mantis multimodal processor.

    The model code is inherited unchanged from
    ``LlavaForConditionalGeneration``. Only the registered processor and
    processing info differ.
    """