llava.py 33.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import abstractmethod
4
from collections.abc import Iterable, Mapping, Sequence
5
from functools import cached_property
6
7
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
                    TypeVar, Union, cast)
8
9

import torch
10
import torch.nn as nn
11
from packaging.version import Version
12
13
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
14
                          SiglipVisionConfig)
15
from transformers import __version__ as TRANSFORMERS_VERSION
16
from transformers.models.llava import LlavaProcessor
17
from transformers.models.pixtral import PixtralProcessor
18

19
from vllm.config import VllmConfig
20
from vllm.inputs import InputProcessingContext
21
from vllm.jsontree import json_map_leaves
22
from vllm.model_executor.layers.activation import get_act_fn
23
24
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
25
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
26
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
27
from vllm.model_executor.sampling_metadata import SamplingMetadata
28
from vllm.multimodal import MULTIMODAL_REGISTRY
29
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
30
                                    MultiModalInputs, MultiModalKwargs)
31
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
32
                                   ImageSize, MultiModalDataItems)
33
from vllm.multimodal.processing import (BaseMultiModalProcessor,
34
                                        BaseProcessingInfo, ProcessingCache,
35
                                        PromptReplacement, PromptUpdate)
36
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
37
from vllm.sequence import IntermediateTensors
38

39
from .clip import CLIPVisionModel
40
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
41
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
42
from .siglip import SiglipVisionModel
43
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
44
                    maybe_prefix, merge_multimodal_embeddings)
45
46
from .vision import (get_vision_encoder_info, scatter_patch_features,
                     select_patch_features)
47
48


49
50
class LlavaImagePixelInputs(TypedDict):
    """Pixel inputs for the fixed-resolution LLaVA vision towers."""

    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    """
    Shape: `(batch_size * num_images, 3, image_size, image_size)`

    All images are concatenated into a single batched tensor, and every
    image must have the fixed resolution `vision_config.image_size`
    (checked by `LlavaForConditionalGeneration._validate_pixel_values`).
    Variable-resolution inputs use `PixtralHFImagePixelInputs` instead.
    """
58

59
60
61
62
63
64
65
66
67
68
69
70

class PixtralHFImagePixelInputs(TypedDict):
    """Pixel inputs for Pixtral-HF, whose images may differ in resolution."""

    type: Literal["pixel_values_pixtral"]

    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Because `height` and `width` may differ from image to image, the data
    is given as a list of per-image tensors rather than one batched tensor
    in that case.
    """

    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
    """
    Boolean mask marking which image embeddings correspond to patch tokens
    (the remaining positions hold the row-break / end-of-image tokens).

    Shape: `(batch_size * num_images, num_embeds)`
    """

78
79
80
81

class LlavaImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings, used as-is (no vision tower/projector)."""

    type: Literal["image_embeds"]

    data: torch.Tensor
    """
    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` has to equal the hidden size of the language model
    backbone.
    """


88
89
# Every image-input format accepted by `LlavaForConditionalGeneration`.
LlavaImageInputs = Union[
    LlavaImagePixelInputs,
    PixtralHFImagePixelInputs,
    LlavaImageEmbeddingInputs,
]
90
91


92
93
class LlavaMultiModalProjector(nn.Module):
    """Project vision-tower features into the language model's hidden size.

    The projection is `linear_1 -> activation -> linear_2`, mapping
    `vision_hidden_size -> text_hidden_size -> text_hidden_size`.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # Both parallel linear layers return a pair; only the first element
        # (the layer output) is used.
        projected, _ = self.linear_1(image_features)
        activated = self.act(projected)
        output, _ = self.linear_2(activated)
        return output


122
123
class LlavaLikeConfig(Protocol):
    """Structural type for the HF config attributes read by LLaVA-style code."""

    # Config of the vision encoder (CLIP, SigLIP or Pixtral in this file).
    vision_config: Final[PretrainedConfig]
    # Token id used as the image placeholder in the prompt.
    image_token_index: Final[int]
    # "default" drops the first vision token; "full" keeps every token.
    vision_feature_select_strategy: Final[str]
    # Encoder layer index (or indices) whose features are used.
    vision_feature_layer: Final[Union[int, list[int]]]
127

128

129
130
131
132
class LlavaLikeProcessor(Protocol):
    """Structural type for the HF processor attributes read by this file."""

    # Placeholder string standing for one image in the prompt text.
    image_token: Final[str]


133
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by the LLaVA-style models in this file.

    Subclasses only have to provide :meth:`get_hf_processor`.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # `None`: no fixed limit on the number of images.
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        max_image_tokens = self.get_max_image_tokens()
        return {"image": max_image_tokens}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Adjust the encoder token count for the feature-select strategy."""
        if strategy == "full":
            return encoder_num_image_tokens
        if strategy == "default":
            # One leading token is dropped from the encoder output.
            return encoder_num_image_tokens - 1

        raise NotImplementedError(
            f"Unexpected feature select strategy: {strategy!r}")

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        hf_config = self.get_hf_config()
        vision_encoder_info = self.get_vision_encoder_info()

        encoder_tokens = vision_encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return self._apply_feature_select_strategy(
            hf_config.vision_feature_select_strategy,
            encoder_tokens,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        # A square image whose side is the encoder's image size.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        largest = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=largest.width,
            image_height=largest.height,
        )

198
199
200
201
202
203

# Type variable for the processing-info class that parameterizes the generic
# dummy-input builder and multimodal processor classes defined below.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Build dummy processor inputs (for profiling) from max-size images."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)

        image_token = self.info.get_hf_processor().image_token
        largest = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(
            width=largest.width,
            height=largest.height,
            num_images=num_images,
        )

        # One placeholder token per dummy image.
        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data={"image": dummy_images},
        )


229
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for plain LLaVA checkpoints (``LlavaProcessor``)."""

    def get_hf_processor(self, **kwargs: object):
        processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)

        # Some checkpoints leave `patch_size` out of `processor_config.json`
        # (e.g. E5-V: https://huggingface.co/royokong/e5-v); fall back to the
        # vision encoder's patch size in that case.
        if processor.patch_size is None:
            encoder_info = self.get_vision_encoder_info()
            processor.patch_size = encoder_info.get_patch_size()

        return processor
239
240


241
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Shared prompt updates for LLaVA-style processors.

    Every image placeholder token is expanded into as many placeholder tokens
    as the image contributes features.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.info.get_hf_config().image_token_index

        def num_tokens_for(item_idx: int) -> int:
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            # Precomputed embeddings already know their feature size.
            if isinstance(images, ImageEmbeddingItems):
                return images.get_feature_size(item_idx)

            size = images.get_image_size(item_idx)
            return self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
            )

        def get_replacement(item_idx: int):
            return [image_token_id] * num_tokens_for(item_idx)

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


285
286
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for plain LLaVA checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Pixel values and precomputed embeddings are both batched per image.
        return {
            field: MultiModalFieldConfig.batched("image")
            for field in ("pixel_values", "image_embeds")
        }


299
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for Pixtral checkpoints in the HF (LLaVA) format."""

    def get_hf_processor(self, **kwargs: object):
        processor_cls = PixtralProcessor
        return self.ctx.get_hf_processor(processor_cls, **kwargs)
303

304

305
306
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for Pixtral-HF.

    Each image covers a grid of patches.
    - The prompt gets one image token per patch.
    - A break token follows each row.
    - An end token replaces the final break.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values = outputs.get("pixel_values")
        if pixel_values is None:
            return outputs

        # Before/after https://github.com/huggingface/transformers/pull/35122
        if Version(TRANSFORMERS_VERSION) <= Version("4.48.3"):
            images = mm_data["images"]
            assert isinstance(images, list)

            # Original output: (1, num_images, C, H, W)
            # New output: (num_images, C, H, W)
            assert (isinstance(pixel_values, list)
                    and len(pixel_values) == 1)
            assert (isinstance(pixel_values[0], list)
                    and len(pixel_values[0]) == len(images))

            outputs["pixel_values"] = pixel_values[0]
        else:
            # Avoid padding since we need the output for each image to be
            # independent of other images for the cache to work correctly
            image_sizes = outputs["image_sizes"]
            assert len(pixel_values) == len(image_sizes)

            outputs["pixel_values"] = [
                image[:, :height, :width]
                for image, (height, width) in zip(pixel_values, image_sizes)
            ]

        vision_config = self.info.get_hf_config().vision_config
        assert isinstance(vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(vision_config)

        embed_is_patch = []
        for image in outputs["pixel_values"]:
            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=image.shape[-1],
                image_height=image.shape[-2],
            )
            # `ncols` patch positions per row, then one non-patch position
            # for that row's break / end token.
            row_mask = [True] * ncols + [False]
            embed_is_patch.append(torch.tensor(row_mask * nrows))
        outputs["embed_is_patch"] = embed_is_patch

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return {
            field: MultiModalFieldConfig.batched("image")
            for field in ("pixel_values", "embed_is_patch", "image_embeds")
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(vision_config)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            size = images.get_image_size(item_idx)
            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            row = [image_token_id] * ncols + [image_break_id]
            tokens = row * nrows
            # The final row's break token becomes the end-of-image token.
            tokens[-1] = image_end_id
            return tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]

417

418
419
420
421
422
423
424
425
426
427
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Pick the processing-info class that matches the vision config type."""
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    is_pixtral = isinstance(vision_config, PixtralVisionConfig)
    info_cls = PixtralHFProcessingInfo if is_pixtral else LlavaProcessingInfo
    return info_cls(ctx)


428
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
    enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
    """Create the multimodal processor that matches the type of ``info``.

    Raises:
        NotImplementedError: If ``info`` is neither a Pixtral-HF nor a LLaVA
            processing info.
    """
    # Checked in order: the Pixtral-HF info class is tested first.
    candidates = (
        (PixtralHFProcessingInfo, PixtralHFMultiModalProcessor),
        (LlavaProcessingInfo, LlavaMultiModalProcessor),
    )
    for info_cls, processor_cls in candidates:
        if isinstance(info, info_cls):
            return processor_cls(
                info,
                dummy_inputs,  # type: ignore
                cache=cache,
                enable_sanity_checks=enable_sanity_checks,
            )

    raise NotImplementedError(type(info))
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474


def _get_num_hidden_layers(hf_config: "LlavaLikeConfig") -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The number of encoder layers required so that every requested
        feature layer gets computed.

    Raises:
        TypeError: If ``vision_feature_layer`` is neither an int nor a
            list/tuple of ints.
        ValueError: If ``vision_feature_layer`` is an empty list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # Previously surfaced as an opaque `max()` error.
            raise ValueError("vision_feature_layer must not be empty")
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given a signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Negative indices count from the end, so ``-1`` maps to
    ``num_hidden_layers`` (the whole encoder) and ``-2`` to one layer fewer.
    Non-negative indices are returned unchanged.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.

    Returns:
        The number of encoder layers that must be initialized.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index
486
487
488
489
490
491
492


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Build the vision tower matching ``hf_config.vision_config``.

    Only the encoder layers up to the deepest requested feature layer are
    instantiated.

    Raises:
        TypeError: If ``vision_feature_layer`` has an unsupported type.
        NotImplementedError: If the vision config type is not supported.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    tower_classes = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )
    for config_cls, model_cls in tower_classes:
        if isinstance(vision_config, config_cls):
            return model_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


529
530
531
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-style model: a vision tower and a projector in front of a
    registered language model. Pixtral checkpoints in the HF format are
    handled by the same class."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        text_config = config.text_config
        if (text_config.architectures is None
                and text_config.model_type == "mistral"):
            text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Reuse the language model's sampler when it has one.
        if not hasattr(self.language_model, "sampler"):
            return get_sampler()
        return self.language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every image is `(3, image_size, image_size)`."""
        side = self.config.vision_config.image_size
        expected_dims = (3, side, side)
        actual_dims = tuple(data.shape[1:])

        if actual_dims == expected_dims:
            return data

        expected_expr = ("batch_size", *map(str, expected_dims))
        raise ValueError(
            f"The expected shape of pixel values is {expected_expr}. "
            f"You supplied {tuple(data.shape)}.")

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Turn raw multimodal kwargs into one of the `LlavaImageInputs`.

        Returns `None` when neither pixel values nor embeddings are given.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        is_pixtral = self.config.vision_config.model_type == "pixtral"

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not is_pixtral:
                return LlavaImagePixelInputs(
                    type="pixel_values",
                    pixel_values=self._validate_pixel_values(
                        flatten_bn(pixel_values, concat=True)),
                )

            embed_is_patch = kwargs.pop("embed_is_patch")
            if not isinstance(embed_is_patch, (torch.Tensor, list)):
                raise ValueError("Incorrect type of embed_is_patch. "
                                 f"Got type: {type(embed_is_patch)}")

            return PixtralHFImagePixelInputs(
                type="pixel_values_pixtral",
                pixel_values=flatten_bn(pixel_values),
                embed_is_patch=flatten_bn(embed_is_patch),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            if is_pixtral:
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "full":
            return image_features
        if strategy == "default":
            # Drop the first token along dimension 1.
            return image_features[:, 1:]

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)
        strategy = self.config.vision_feature_select_strategy

        # Apply the select strategy to every tensor leaf of the output.
        selected = json_map_leaves(
            lambda leaf: self._select_image_features(leaf, strategy=strategy),
            image_features,
        )
        return cast(Union[torch.Tensor, tuple[torch.Tensor, ...]], selected)

    def _process_image_pixels(
        self,
        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        assert self.vision_tower is not None
        return self._image_pixels_to_features(self.vision_tower,
                                              inputs["pixel_values"])

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        if image_input["type"] == "image_embeds":
            # Precomputed embeddings skip the vision tower and projector.
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        # Per-image features of different lengths: project them in a single
        # concatenated call, then split back along dim 0.
        feature_sizes = [feature.shape[0] for feature in image_features]
        projected = self.multi_modal_projector(torch.cat(image_features))
        return torch.split(projected, feature_sizes)

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        image_features = self._process_image_input(image_input)

        if image_input["type"] == "pixel_values_pixtral":
            # Pixtral-HF: combine the features with the `embed_is_patch`
            # mask via `scatter_patch_features`.
            return scatter_patch_features(
                image_features,
                image_input["embed_is_patch"],
            )

        # The path is used for pixtral (V0 only) and llava (V0/V1)
        return image_features

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return inputs_embeds

        return merge_multimodal_embeddings(
            input_ids,
            inputs_embeds,
            select_patch_features(multimodal_embeddings),
            self.config.image_token_index,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        `input_ids` already reserves the positions of the image embeddings
        that will be inserted.

        For example, take the text prompt
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        The tokenizer produces
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        Space in the KV cache has to be reserved before the model runs.
        The input processor therefore expands the image token (`32000`)
        into extra image tokens, giving
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        It inserts 575 tokens, so together with the original image token
        there are 576 (24 * 24) image tokens. That matches the number of
        image tokens the visual encoder outputs and the language model
        receives.

        As a result, `positions` and `attn_metadata` line up with
        `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Positions passed through to the language model.
            intermediate_tensors: Intermediate tensors for pipeline
                parallelism. When given, ``inputs_embeds`` is reset to
                ``None``.
            inputs_embeds: Precomputed input embeddings. In V1 they are
                always supplied by the model runner.
            **kwargs: Multimodal inputs such as ``pixel_values``,
                ``image_embeds`` and ``embed_is_patch``.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this condition is for v0 compatibility.
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        return self.language_model.model(input_ids,
                                         positions,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        return AutoWeightsLoader(self).load_weights(weights)
823
824


825
826
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis checkpoints, built on the stock LlavaProcessor."""

    def get_hf_processor(self, **kwargs: object):
        """Return a ``LlavaProcessor`` configured for Mantis.

        ``patch_size`` and ``vision_feature_select_strategy`` are filled in
        only when the caller did not supply them explicitly.
        """
        config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        kwargs.setdefault("patch_size", encoder_info.get_patch_size())

        has_upstream_fix = Version(TRANSFORMERS_VERSION) >= Version("4.48")
        if has_upstream_fix:
            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
            select_strategy = config.vision_feature_select_strategy
        else:
            # BUG (transformers < 4.48): num_additional_image_tokens = 0 but
            # treated as 1, so the select strategy is left as None to offset it.
            select_strategy = None
        kwargs.setdefault("vision_feature_select_strategy", select_strategy)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)


847
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """Multimodal processor that formats image prompts the way Mantis expects.

    Each image's placeholder tokens are wrapped in a textual header/footer
    (``(image N: <Image>`` ... ``</Image>)``), reimplementing
    ``MLlavaProcessor`` from https://github.com/TIGER-AI-Lab/Mantis.git.
    """

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """Run LLaVA processing, then rewrite image placeholders for Mantis.

        Args:
            prompt: Text prompt or pre-tokenized prompt ids.
            mm_data: Raw multimodal data for the prompt.
            hf_processor_mm_kwargs: Extra kwargs for the HF processor.
            return_mm_hashes: Forwarded to the base processor; any hashes it
                returns are propagated into the result.

        Returns:
            ``MultiModalInputs`` whose prompt contains the Mantis image
            wrappers and whose placeholder ranges are recomputed against the
            rewritten token ids.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                               return_mm_hashes)

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]
        # Carry over any hashes the base processor computed; previously they
        # were dropped when the result was rebuilt below, so the
        # `return_mm_hashes` request was silently ignored.
        mm_hashes = result.get("mm_hashes")

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            return "".join([
                f"(image {item_idx+1}: <Image>",  # 7 tokens
                "<image>" * num_image_tokens,
                "</Image>)",  # 3 tokens
            ])

        # Pass 1: replace each run of `num_image_tokens` image-token ids in
        # the base result with the Mantis-wrapped text for that image.
        mantis_mm_repls = self._bind_and_group_updates([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * num_image_tokens,
                replacement=get_replacement_mantis,
            )
        ])

        prompt_ids, prompt, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
            mm_item_counts,
        )

        # Pass 2: locate placeholders in the rewritten ids using the original
        # (LLaVA) prompt updates, so ranges cover only the image tokens and
        # not the surrounding Mantis header/footer text.
        unbound_orig_repls = self._get_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        orig_repls = self._bind_and_group_updates(unbound_orig_repls)

        mm_placeholders = self._find_mm_placeholders(
            orig_repls,
            prompt_ids,
            mm_item_counts,
        )
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )


# Mantis checkpoints are not auto-detected; enable this class with
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(
    MantisMultiModalProcessor,
    info=MantisProcessingInfo,
    dummy_inputs=LlavaDummyInputsBuilder,
)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """Mantis model: LLaVA weights and forward pass, Mantis prompt processing.

    All model behaviour is inherited unchanged from
    ``LlavaForConditionalGeneration``; only the registered multimodal
    processor (``MantisMultiModalProcessor``) differs.
    """