llava.py 34 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import abstractmethod
4
from collections.abc import Iterable, Mapping, Sequence
5
from functools import cached_property
6
7
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
                    TypeVar, Union, cast)
8
9

import torch
10
import torch.nn as nn
11
from packaging.version import Version
12
13
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
14
                          SiglipVisionConfig)
15
from transformers import __version__ as TRANSFORMERS_VERSION
16
from transformers.models.llava import LlavaProcessor
17
from transformers.models.pixtral import PixtralProcessor
18

19
from vllm.config import VllmConfig
20
from vllm.inputs import InputProcessingContext
21
from vllm.jsontree import json_map_leaves
22
from vllm.model_executor.layers.activation import get_act_fn
23
24
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
25
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
26
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
27
from vllm.model_executor.sampling_metadata import SamplingMetadata
28
from vllm.multimodal import MULTIMODAL_REGISTRY
29
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
30
                                    MultiModalInputs, MultiModalKwargs)
31
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
32
                                   ImageSize, MultiModalDataItems)
33
from vllm.multimodal.processing import (BaseMultiModalProcessor,
34
                                        BaseProcessingInfo, ProcessingCache,
35
                                        PromptReplacement, PromptUpdate)
36
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
37
from vllm.sequence import IntermediateTensors
38
from vllm.utils import flatten_2d_lists
39

40
from .clip import CLIPVisionModel
41
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
42
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
43
from .siglip import SiglipVisionModel
44
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
45
                    maybe_prefix, merge_multimodal_embeddings)
46
47
from .vision import (get_vision_encoder_info, scatter_patch_features,
                     select_patch_features)
48
49


50
51
class LlavaImagePixelInputs(TypedDict):
    # Discriminator used to route this payload through the vision tower.
    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Unlike the Pixtral-HF variant, every image must have the same fixed
    `height` and `width` (`vision_config.image_size`), so the data is always
    a single batched tensor, never a list.
    """
59

60
61
62
63
64
65
66
67
68
69
70
71

class PixtralHFImagePixelInputs(TypedDict):
    # Discriminator checked when deciding whether to scatter patch features.
    type: Literal["pixel_values_pixtral"]

    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    `height` and `width` can vary from image to image; when they do, the
    images are supplied as a list of tensors rather than one stacked tensor.
    """

    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
    """
    Boolean mask marking which of an image's embedding positions hold patch
    tokens (as opposed to the row-break / end tokens).

    Shape: `(batch_size, num_images, num_embeds)`
    """

79
80
81
82

class LlavaImageEmbeddingInputs(TypedDict):
    # Discriminator: precomputed embeddings are returned without running the
    # vision tower or the projector.
    type: Literal["image_embeds"]
    data: torch.Tensor
    """
    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    The trailing `hidden_size` has to equal the hidden size of the
    language-model backbone.
    """


89
90
# Any of the image payloads the model accepts; the `type` key tells them apart.
LlavaImageInputs = Union[
    LlavaImagePixelInputs,
    PixtralHFImagePixelInputs,
    LlavaImageEmbeddingInputs,
]
91
92


93
94
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision-encoder features into the text space.

    Computes ``linear_2(act(linear_1(x)))``, where ``linear_1`` is a
    column-parallel layer and ``linear_2`` a row-parallel one.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # vision width -> text width
        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        # text width -> text width
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=multimodal_projector_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # Each parallel linear layer returns a 2-tuple; only the first
        # element (the layer output) is used here.
        projected, _ = self.linear_1(image_features)
        activated = self.act(projected)
        output, _ = self.linear_2(activated)
        return output


123
124
class LlavaLikeConfig(Protocol):
    """Structural type for the HF config fields this module reads."""

    # Config of the vision encoder (CLIP, SigLIP or Pixtral in this file).
    vision_config: Final[PretrainedConfig]
    # Token id marking where image features go in the prompt.
    image_token_index: Final[int]
    # "default" drops the first feature position; "full" keeps all of them.
    vision_feature_select_strategy: Final[str]
    # A single (possibly negative) encoder layer index, or several of them.
    vision_feature_layer: Final[Union[int, list[int]]]
128

129

130
131
132
133
class LlavaLikeProcessor(Protocol):
    """Structural type for the HF processor attribute this module reads."""

    # Text placeholder for one image in the prompt.
    image_token: Final[str]


134
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by LLaVA-style models.

    Subclasses supply :meth:`get_hf_processor`. Image token counts are
    derived from the vision encoder plus the feature-select strategy.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # No upper bound on the number of images per prompt.
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        return {"image": self.get_max_image_tokens()}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Adjust the encoder's token count for the feature-select strategy.

        ``"full"`` keeps every token and ``"default"`` drops one. Any other
        strategy raises :class:`NotImplementedError`.
        """
        if strategy == "full":
            return encoder_num_image_tokens
        if strategy == "default":
            return encoder_num_image_tokens - 1

        raise NotImplementedError(
            f"Unexpected feature select strategy: {strategy!r}")

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        hf_config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        encoder_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return self._apply_feature_select_strategy(
            hf_config.vision_feature_select_strategy,
            encoder_tokens,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        # A square image at the encoder's native resolution.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        largest = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=largest.width,
            image_height=largest.height,
        )

199
200
201
202
203
204

# Type variable for generic builders/processors parametrised by a
# LLaVA-style processing-info class.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds profiling inputs from images of the largest supported size."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)

        image_token = self.info.get_hf_processor().image_token
        size = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(width=size.width,
                                              height=size.height,
                                              num_images=num_images)

        # One placeholder token per dummy image.
        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data={"image": dummy_images},
        )


230
class LlavaProcessingInfo(BaseLlavaProcessingInfo):

    def get_hf_processor(self, **kwargs: object):
        """Load ``LlavaProcessor``, filling in ``patch_size`` if it is unset."""
        processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)

        # Some checkpoints leave `patch_size` out of `processor_config.json`
        # (e.g. E5-V: https://huggingface.co/royokong/e5-v); fall back to the
        # vision encoder's patch size.
        if processor.patch_size is None:
            processor.patch_size = (
                self.get_vision_encoder_info().get_patch_size())

        return processor
240
241


242
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Expands each image placeholder token into one token per feature."""

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            # Precomputed embeddings already know their length; raw images
            # need it derived from their size.
            if not isinstance(images, ImageEmbeddingItems):
                size = images.get_image_size(item_idx)
                count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )
            else:
                count = images.get_feature_size(item_idx)

            return [image_token_id] * count

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [replacement]


286
287
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Pixel values and precomputed embeddings are both batched per image.
        per_image = MultiModalFieldConfig.batched
        return {
            "pixel_values": per_image("image"),
            "image_embeds": per_image("image"),
        }


300
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for Pixtral checkpoints in the HF LLaVA format."""

    def get_hf_processor(self, **kwargs: object):
        processor = self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
        return processor
304

305

306
307
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for Pixtral checkpoints in the HF LLaVA format.

    Each image is expanded into ``nrows`` rows of ``ncols`` image tokens.
    Every row ends in a break token, except the last, which ends in an end
    token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is None:
            return processed_outputs

        # Before/after https://github.com/huggingface/transformers/pull/35122
        if Version(TRANSFORMERS_VERSION) > Version("4.48.3"):
            # Avoid padding since we need the output for each image to be
            # independent of other images for the cache to work correctly
            image_sizes = processed_outputs["image_sizes"]
            assert len(pixel_values) == len(image_sizes)

            processed_outputs["pixel_values"] = [
                pv[:, :height, :width]
                for pv, (height, width) in zip(pixel_values, image_sizes)
            ]
        else:
            images = mm_data["images"]
            assert isinstance(images, list)

            # Original output: (1, num_images, C, H, W)
            # New output: (num_images, C, H, W)
            assert (isinstance(pixel_values, list)
                    and len(pixel_values) == 1)
            assert (isinstance(pixel_values[0], list)
                    and len(pixel_values[0]) == len(images))

            processed_outputs["pixel_values"] = pixel_values[0]

        vision_config = self.info.get_hf_config().vision_config
        assert isinstance(vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(vision_config)

        # Mirror the token layout built in `_get_prompt_updates`: for each
        # row, `ncols` patch positions followed by one break/end position.
        embed_is_patch = []
        for pv in processed_outputs["pixel_values"]:
            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=pv.shape[-1],
                image_height=pv.shape[-2],
            )
            row_mask = [True] * ncols + [False]
            embed_is_patch.append(torch.tensor(row_mask * nrows))
        processed_outputs["embed_is_patch"] = embed_is_patch

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        fields = ("pixel_values", "embed_is_patch", "image_embeds")
        return {
            name: MultiModalFieldConfig.batched("image")
            for name in fields
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(vision_config)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            size = images.get_image_size(item_idx)
            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            # Each row is `ncols` image tokens plus a break token; the very
            # last break token is swapped for the end token.
            row = [image_token_id] * ncols + [image_break_id]
            tokens = row * nrows
            tokens[-1] = image_end_id
            return tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]

418

419
420
421
422
423
424
425
426
427
428
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Choose the processing-info class based on the vision encoder config."""
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    is_pixtral = isinstance(vision_config, PixtralVisionConfig)
    info_cls = PixtralHFProcessingInfo if is_pixtral else LlavaProcessingInfo
    return info_cls(ctx)


429
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
    enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
    """Instantiate the multimodal processor matching the type of ``info``.

    Raises:
        NotImplementedError: If ``info`` is neither Pixtral-HF nor LLaVA
            processing info.
    """
    processor_cls: type[BaseMultiModalProcessor]
    if isinstance(info, PixtralHFProcessingInfo):
        processor_cls = PixtralHFMultiModalProcessor
    elif isinstance(info, LlavaProcessingInfo):
        processor_cls = LlavaMultiModalProcessor
    else:
        raise NotImplementedError(type(info))

    return processor_cls(
        info,
        dummy_inputs,  # type: ignore
        cache=cache,
        enable_sanity_checks=enable_sanity_checks,
    )
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.
    
    Args:
        hf_config: Model config with vision feature layer(s).
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    elif isinstance(feature_layers, (list, tuple)):
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_layer_feature type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
476
    """Given a signed vision feature layer, get the number of hidden layers
477
478
479
480
481
482
483
484
485
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
486
    return feature_layer_index
487
488
489
490
491
492
493


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Build the vision encoder matching ``hf_config.vision_config``.

    The encoder is truncated to the deepest layer named by
    ``hf_config.vision_feature_layer``.

    Raises:
        NotImplementedError: If the vision config is not CLIP, SigLIP or
            Pixtral.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    if isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    if isinstance(vision_config, PixtralVisionConfig):
        return PixtralHFVisionModel(
            vision_config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
        )

    raise NotImplementedError(
        f"Unsupported vision config: {type(vision_config)}")


530
531
532
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-style model: vision tower, multimodal projector, language model.

    The same class also serves Pixtral checkpoints in the HF LLaVA format.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        text_config = config.text_config
        if (text_config.architectures is None
                and text_config.model_type == "mistral"):
            text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"))
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Prefer the language model's own sampler when it defines one.
        if not hasattr(self.language_model, "sampler"):
            return get_sampler()
        return self.language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every image has shape ``(3, image_size, image_size)``."""
        size = self.config.vision_config.image_size
        expected_dims = (3, size, size)

        if tuple(data.shape[1:]) == expected_dims:
            return data

        expected_expr = ("batch_size", *map(str, expected_dims))
        raise ValueError(
            f"The expected shape of pixel values is {expected_expr}. "
            f"You supplied {tuple(data.shape)}.")

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        is_pixtral = self.config.vision_config.model_type == "pixtral"

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not is_pixtral:
                return LlavaImagePixelInputs(
                    type="pixel_values",
                    pixel_values=self._validate_pixel_values(
                        flatten_bn(pixel_values, concat=True)),
                )

            embed_is_patch = kwargs.pop("embed_is_patch")
            if not isinstance(embed_is_patch, (torch.Tensor, list)):
                raise ValueError("Incorrect type of embed_is_patch. "
                                 f"Got type: {type(embed_is_patch)}")

            return PixtralHFImagePixelInputs(
                type="pixel_values_pixtral",
                pixel_values=flatten_bn(pixel_values),
                embed_is_patch=embed_is_patch,
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            if is_pixtral:
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        # Neither pixel values nor embeddings were supplied.
        return None

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "full":
            return image_features
        if strategy == "default":
            return image_features[:, 1:]

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)
        strategy = self.config.vision_feature_select_strategy

        def select_features(leaf: torch.Tensor):
            return self._select_image_features(leaf, strategy=strategy)

        selected = json_map_leaves(select_features, image_features)
        return cast(Union[torch.Tensor, tuple[torch.Tensor, ...]], selected)

    def _process_image_pixels(
        self,
        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        assert self.vision_tower is not None
        return self._image_pixels_to_features(self.vision_tower,
                                              inputs["pixel_values"])

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        # Precomputed embeddings skip both the vision tower and the projector.
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        # Per-image features of differing lengths: project them in a single
        # concatenated pass, then split back into per-image chunks.
        feature_sizes = [feature.shape[0] for feature in image_features]
        projected = self.multi_modal_projector(torch.cat(image_features))
        return torch.split(projected, feature_sizes)

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        vision_embeddings = self._process_image_input(image_input)

        # The path is used for pixtral (V0 only) and llava (V0/V1)
        if (kwargs.get("v0_path", False)
                or image_input["type"] != "pixel_values_pixtral"):
            return vision_embeddings

        return flatten_2d_lists(
            scatter_patch_features(embeds, is_patch)
            for embeds, is_patch in zip(vision_embeddings,
                                        image_input["embed_is_patch"]))

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return inputs_embeds

        return merge_multimodal_embeddings(
            input_ids,
            inputs_embeds,
            select_patch_features(multimodal_embeddings),
            self.config.image_token_index,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        `input_ids` already contains positions reserved for the image
        embeddings that are inserted later.

        For example, the prompt
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`
        tokenizes to:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in the KV cache, the input processor expands the
        image token (`32000`) before the prompt reaches the model:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        575 tokens are inserted so that, together with the original image
        token, there are 576 (24 * 24) image tokens. That is the number of
        image tokens the visual encoder outputs and the language model takes
        as input.

        This keeps `positions` and `attn_metadata` consistent with
        `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this condition is for v0 compatibility.
            kwargs["v0_path"] = True
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        return self.language_model.model(input_ids,
                                         positions,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        return AutoWeightsLoader(self).load_weights(weights)
825
826


827
828
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis, which reuses the HF ``LlavaProcessor``."""

    def get_hf_processor(self, **kwargs: object):
        """Return a ``LlavaProcessor`` configured for Mantis checkpoints.

        ``patch_size`` and ``vision_feature_select_strategy`` are filled in
        only when the caller has not already supplied them.
        """
        hf_config = self.get_hf_config()
        patch_size = self.get_vision_encoder_info().get_patch_size()
        kwargs.setdefault("patch_size", patch_size)

        if Version(TRANSFORMERS_VERSION) >= Version("4.48"):
            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
            # so the strategy from the model config can be used as-is.
            select_strategy = hf_config.vision_feature_select_strategy
        else:
            # BUG (transformers < 4.48): num_additional_image_tokens = 0 but
            # treated as 1, so we set vision_feature_select_strategy to None
            # to offset this.
            select_strategy = None
        kwargs.setdefault("vision_feature_select_strategy", select_strategy)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
847
848


849
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """Multimodal processor for Mantis models.

    Runs the regular LLaVA processing first. It then rewrites each image's
    run of image tokens into the Mantis ``(image N: <Image>...</Image>)``
    layout, mirroring ``MLlavaProcessor`` from
    https://github.com/TIGER-AI-Lab/Mantis.git.
    """

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """Process ``prompt`` and ``mm_data`` into Mantis-formatted inputs.

        Args:
            prompt: Text prompt or pre-tokenized prompt ids.
            mm_data: Multimodal data (images) referenced by the prompt.
            hf_processor_mm_kwargs: Extra kwargs forwarded to the HF
                processor.
            return_mm_hashes: Forwarded to the base ``apply``. Any hashes it
                returns are carried through to the result.

        Returns:
            ``MultiModalInputs`` whose prompt contains the Mantis image
            wrappers. Its placeholder ranges are located within that
            rewritten prompt.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                               return_mm_hashes)

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]
        # Preserve the hashes computed by the base processor. Rebuilding the
        # inputs without them would silently discard what the caller asked
        # for via ``return_mm_hashes``.
        mm_hashes = result.get("mm_hashes")

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            return "".join([
                f"(image {item_idx+1}: <Image>",  # 7 tokens
                "<image>" * num_image_tokens,
                "</Image>)",  # 3 tokens
            ])

        # Each image's run of image tokens (already expanded by the base
        # processor) is replaced by its Mantis-wrapped form.
        mantis_mm_repls = self._bind_and_group_updates([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * num_image_tokens,
                replacement=get_replacement_mantis,
            )
        ])

        prompt_ids, prompt, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
            mm_item_counts,
        )

        # Re-locate the placeholders inside the rewritten prompt using the
        # original (non-Mantis) prompt updates.
        unbound_orig_repls = self._get_prompt_updates(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        orig_repls = self._bind_and_group_updates(unbound_orig_repls)

        mm_placeholders = self._find_mm_placeholders(
            orig_repls,
            prompt_ids,
            mm_item_counts,
        )
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
923
924
925
926


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
                                        info=MantisProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model registered with the Mantis multimodal processor.

    All model behavior is inherited unchanged from
    ``LlavaForConditionalGeneration``. The subclass exists so that the
    decorator can register ``MantisMultiModalProcessor``,
    ``MantisProcessingInfo`` and ``LlavaDummyInputsBuilder`` for it.
    """