llava.py 31.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import abstractmethod
4
from collections.abc import Iterable, Mapping, Sequence
5
6
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
                    TypeVar, Union, cast)
7
8

import torch
9
import torch.nn as nn
10
from packaging.version import Version
11
12
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
13
                          SiglipVisionConfig)
14
from transformers import __version__ as TRANSFORMERS_VERSION
15
from transformers.models.llava import LlavaProcessor
16
from transformers.models.pixtral import PixtralProcessor
17

18
from vllm.config import VllmConfig
19
from vllm.inputs import InputProcessingContext
20
from vllm.jsontree import json_map_leaves
21
from vllm.model_executor.layers.activation import get_act_fn
22
23
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
24
from vllm.model_executor.layers.quantization import QuantizationConfig
25
from vllm.model_executor.sampling_metadata import SamplingMetadata
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
28
                                    MultiModalInputs, MultiModalKwargs)
29
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
30
                                   ImageSize, MultiModalDataItems)
31
from vllm.multimodal.processing import (BaseMultiModalProcessor,
32
                                        BaseProcessingInfo, ProcessingCache,
33
34
                                        PromptReplacement, PromptUpdate,
                                        PromptUpdateDetails)
35
from vllm.multimodal.profiling import BaseDummyInputsBuilder
36
from vllm.sequence import IntermediateTensors
37

38
from .clip import CLIPVisionModel
39
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
40
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
41
from .siglip import SiglipVisionModel
42
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
43
                    maybe_prefix, merge_multimodal_embeddings)
44
from .vision import get_vision_encoder_info
45
46


47
48
class LlavaImagePixelInputs(TypedDict):
    """Raw pixel inputs used when the vision tower is not Pixtral."""

    type: Literal["pixel_values"]
    pixel_values: torch.Tensor
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    This is always one batched tensor: the model checks every image against
    `(3, image_size, image_size)` taken from the vision config, so images of
    differing sizes are not accepted in this input form.
    """
56

57
58
59
60
61
62
63
64
65
66
67

class PixtralHFImagePixelInputs(TypedDict):
    """Raw pixel inputs used when the vision tower is Pixtral-HF."""

    type: Literal["pixel_values_pixtral"]
    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Images may differ in `height` or `width`; when they do, the data is
    given as a list of tensors rather than as one batched tensor.
    """

68
69
70
71

class LlavaImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings, returned as-is by the model's image
    processing step (no vision tower or projector is applied)."""

    type: Literal["image_embeds"]
    data: torch.Tensor
    """
    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


78
79
# Every image input form that `_parse_and_validate_image_input` can return:
# fixed-size pixels, Pixtral-HF pixels, or precomputed embeddings. The "type"
# key of each TypedDict tells the variants apart.
LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
                         LlavaImageEmbeddingInputs]
80
81


82
83
class LlavaMultiModalProjector(nn.Module):
    """Two-layer projector from the vision hidden size to the text hidden size.

    Features pass through `linear_1`, the configured activation, and then
    `linear_2`.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 multimodal_projector_bias: bool,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the two linear layers and the activation.

        Args:
            vision_hidden_size: Input size of `linear_1`.
            text_hidden_size: Output size of `linear_1`, and both the input
                and output size of `linear_2`.
            projector_hidden_act: Activation name passed to `get_act_fn`.
            multimodal_projector_bias: Whether both linear layers use a bias.
            quant_config: Quantization config forwarded to both layers.
            prefix: Module-name prefix; each layer appends its own name.
        """
        super().__init__()

        use_bias = multimodal_projector_bias

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project `image_features` through both layers and return the result."""
        # Each linear layer returns a pair; only the first element is used.
        projected, _ = self.linear_1(image_features)
        activated = self.act(projected)
        output, _ = self.linear_2(activated)
        return output


112
113
class LlavaLikeConfig(Protocol):
    """Structural type for the HF config attributes this file reads."""

    # Config of the vision encoder; its concrete class selects the tower.
    vision_config: Final[PretrainedConfig]
    # Token id of the image placeholder in the prompt.
    image_token_index: Final[int]
    # "default" or "full"; see the feature-selection helpers in this file.
    vision_feature_select_strategy: Final[str]
    # One (possibly negative) layer index, or several of them.
    vision_feature_layer: Final[Union[int, list[int]]]
117

118

119
120
121
122
class LlavaLikeProcessor(Protocol):
    """Structural type for the HF processors returned by `get_hf_processor`."""

    # Placeholder string for one image; the dummy-inputs builder repeats it
    # once per image to build the dummy prompt text.
    image_token: Final[str]


123
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing information shared by the LLaVA-style processors here.

    Subclasses must supply `get_hf_processor`.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        """Return the HF config, requested from the context as `LlavaConfig`."""
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        """Return the vision encoder info derived from the HF config."""
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    @abstractmethod
    def get_hf_processor(self, **kwargs: object) -> LlavaLikeProcessor:
        """Return the HF processor; must be implemented by subclasses."""
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the supported modalities: only "image", mapped to `None`."""
        return {"image": None}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Adjust the encoder token count for the feature-select strategy.

        Args:
            strategy: "full" keeps every token; "default" keeps one fewer.
            encoder_num_image_tokens: Token count reported by the encoder.

        Raises:
            NotImplementedError: If `strategy` is neither of the above.
        """
        if strategy == "full":
            return encoder_num_image_tokens
        if strategy == "default":
            return encoder_num_image_tokens - 1

        raise NotImplementedError(
            f"Unexpected feature select strategy: {strategy!r}")

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for an image of the given size."""
        hf_config = self.get_hf_config()
        vision_encoder_info = self.get_vision_encoder_info()

        encoder_num_image_tokens = vision_encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return self._apply_feature_select_strategy(
            hf_config.vision_feature_select_strategy,
            encoder_num_image_tokens,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square `ImageSize` whose side is the encoder's image size."""
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        """Return the token count for `get_image_size_with_most_features()`."""
        target_width, target_height = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(image_width=target_width,
                                         image_height=target_height)

181
182
183
184
185
186

# Type variable for processing-info classes; it parameterizes the dummy-inputs
# builder and the multimodal processors defined in this file.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompt text and dummy images for a given image count."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the processor's image token repeated once per image."""
        image_count = mm_counts.get("image", 0)
        image_token = self.info.get_hf_processor().image_token
        return image_token * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images at the size reported by
        `get_image_size_with_most_features`.

        `seq_len` is not used by this implementation.
        """
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)
        return {"image": dummy_images}


213
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info backed by the HF `LlavaProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return the `LlavaProcessor`, filling in `patch_size` if it is unset."""
        processor = self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
        if processor.patch_size is not None:
            return processor

        # In case patch_size is omitted from `processor_config.json`
        # e.g. for E5-V: https://huggingface.co/royokong/e5-v
        processor.patch_size = self.get_vision_encoder_info().get_patch_size()
        return processor
223
224


225
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Shared prompt-update logic for LLaVA-style processors."""

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field config; must be implemented by subclasses."""
        raise NotImplementedError

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with one token per image feature.

        The count comes from the embedding's feature size for embedding
        items, and otherwise from the image size via
        `self.info.get_num_image_tokens`.
        """
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                token_count = images.get_feature_size(item_idx)
            else:
                size = images.get_image_size(item_idx)
                token_count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )

            return [image_token_id] * token_count

        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [image_replacement]


269
270
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for models that use `LlavaProcessingInfo`."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Mark `pixel_values` and `image_embeds` as batched image fields."""
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }


283
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info backed by the HF `PixtralProcessor`."""

    def get_hf_processor(self, **kwargs: object):
        """Return the `PixtralProcessor` obtained from the context."""
        processor = self.ctx.get_hf_processor(PixtralProcessor, **kwargs)
        return processor
287

288

289
290
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for LLaVA models with a Pixtral-HF vision tower."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then normalize `pixel_values` per image.

        Depending on the installed transformers version, the nested output
        is unwrapped, or each image is cropped to its entry in
        `image_sizes`.
        """
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is None:
            return processed_outputs

        # Before/after https://github.com/huggingface/transformers/pull/35122
        if Version(TRANSFORMERS_VERSION) <= Version("4.48.3"):
            images = mm_data["images"]
            assert isinstance(images, list)

            # Original output: (1, num_images, C, H, W)
            # New output: (num_images, C, H, W)
            assert isinstance(pixel_values, list)
            assert len(pixel_values) == 1
            per_image_values = pixel_values[0]
            assert isinstance(per_image_values, list)
            assert len(per_image_values) == len(images)

            processed_outputs["pixel_values"] = per_image_values
        else:
            # Avoid padding since we need the output for each image to be
            # independent of other images for the cache to work correctly
            image_sizes = processed_outputs["image_sizes"]
            assert len(pixel_values) == len(image_sizes)

            cropped = []
            for padded, (height, width) in zip(pixel_values, image_sizes):
                cropped.append(padded[:, :height, :width])
            processed_outputs["pixel_values"] = cropped

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Mark `pixel_values` and `image_embeds` as batched image fields."""
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with a grid of image tokens.

        Each grid row is `ncols` image tokens followed by a break token;
        the last break token is replaced by the image-end token.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        assert isinstance(hf_config.vision_config, PixtralVisionConfig)
        encoder_info = PixtralHFEncoderInfo(hf_config)

        def get_replacement(item_idx: int):
            size = mm_items.get_items(
                "image", ImageProcessorItems).get_image_size(item_idx)

            ncols, nrows = encoder_info.get_patch_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            row = [image_token_id] * ncols + [image_break_id]
            tokens = row * nrows
            tokens[-1] = image_end_id

            return PromptUpdateDetails.select_token_id(tokens, image_token_id)

        image_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
        return [image_replacement]

382

383
384
385
386
387
388
389
390
391
392
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
    """Pick the processing-info class that matches the vision config.

    Returns `PixtralHFProcessingInfo` when the vision config is a
    `PixtralVisionConfig`, and `LlavaProcessingInfo` otherwise.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config

    uses_pixtral = isinstance(vision_config, PixtralVisionConfig)
    info_cls = PixtralHFProcessingInfo if uses_pixtral else LlavaProcessingInfo
    return info_cls(ctx)


393
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
) -> BaseMultiModalProcessor:
    """Build the multimodal processor that matches the type of `info`.

    Raises:
        NotImplementedError: If `info` is neither a `PixtralHFProcessingInfo`
            nor a `LlavaProcessingInfo`.
    """
    # The Pixtral check comes first, in the same order as the original code.
    if isinstance(info, PixtralHFProcessingInfo):
        processor_cls = PixtralHFMultiModalProcessor
    elif isinstance(info, LlavaProcessingInfo):
        processor_cls = LlavaMultiModalProcessor
    else:
        raise NotImplementedError(type(info))

    return processor_cls(
        info,
        dummy_inputs,  # type: ignore
        cache=cache,
    )
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The layer count needed for the single feature layer, or for the
        deepest one when several feature layers are configured.

    Raises:
        TypeError: If `vision_feature_layer` is not an int, list or tuple.
        ValueError: If `vision_feature_layer` is an empty list or tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers

    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)

    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # Fail with a clear message instead of the opaque error that
            # `max()` raises on an empty sequence.
            raise ValueError(
                "vision_feature_layer must contain at least one layer index")
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)

    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Convert a signed vision feature layer into the number of hidden
    layers needed to produce it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.

    Returns:
        `feature_layer_index` unchanged when it is non-negative; otherwise
        `num_hidden_layers + feature_layer_index + 1`, so that -1 maps to
        `num_hidden_layers`.
    """
    if feature_layer_index >= 0:
        return feature_layer_index
    return num_hidden_layers + feature_layer_index + 1
448
449
450
451
452
453
454


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Create the vision tower that matches `hf_config.vision_config`.

    Args:
        hf_config: Model config holding the vision config and feature layer(s).
        quant_config: Quantization config forwarded to the tower.
        require_post_norm: Forwarded to the tower unchanged.
        prefix: Module-name prefix forwarded to the tower.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral vision config.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    # Checked in this order; the first matching config class wins.
    tower_by_config = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )
    for config_cls, tower_cls in tower_by_config:
        if isinstance(vision_config, config_cls):
            return tower_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    raise NotImplementedError(
        f"Unsupported vision config: {type(vision_config)}")


491
492
493
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA model: a vision tower, a projector and a language model."""

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, the projector and the language model.

        Args:
            vllm_config: Supplies the HF config, the quantization config and
                the multimodal config.
            prefix: Module-name prefix applied to every submodule.
        """
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        text_config = config.text_config
        if (text_config.architectures is None
                and text_config.model_type == "mistral"):
            text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Return `data` if every image has shape `(3, image_size, image_size)`.

        Raises:
            ValueError: If the dimensions after the first do not match.
        """
        side = self.config.vision_config.image_size
        expected_dims = (3, side, side)

        if tuple(data.shape[1:]) == expected_dims:
            return data

        expected_expr = ("batch_size", *(str(dim) for dim in expected_dims))
        raise ValueError(
            f"The expected shape of pixel values is {expected_expr}. "
            f"You supplied {tuple(data.shape)}.")

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Turn the `pixel_values` / `image_embeds` kwargs into a typed input.

        `pixel_values` takes precedence when both are given. Returns `None`
        when neither is present.

        Raises:
            ValueError: If a value is neither a tensor nor a list, or if
                `image_embeds` is passed to a Pixtral vision tower.
        """
        pixels = kwargs.pop("pixel_values", None)
        embeds = kwargs.pop("image_embeds", None)

        if pixels is None and embeds is None:
            return None

        if pixels is not None:
            if not isinstance(pixels, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixels)}")

            if self.config.vision_config.model_type == "pixtral":
                return PixtralHFImagePixelInputs(
                    type="pixel_values_pixtral",
                    pixel_values=flatten_bn(pixels),
                )

            flat_pixels = flatten_bn(pixels, concat=True)
            return LlavaImagePixelInputs(
                type="pixel_values",
                pixel_values=self._validate_pixel_values(flat_pixels),
            )

        if embeds is not None:
            if not isinstance(embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(embeds)}")

            if self.config.vision_config.model_type == "pixtral":
                raise ValueError("Pixtral-HF does not support image_embeds.")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the feature-select strategy to `image_features`.

        "full" returns the input unchanged; "default" drops index 0 along
        dimension 1.

        Raises:
            ValueError: For any other strategy.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "full":
            return image_features
        if strategy == "default":
            return image_features[:, 1:]

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Run the vision tower and apply feature selection to each output."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        tower_output = vision_tower(pixel_values)
        strategy = self.config.vision_feature_select_strategy

        def apply_strategy(leaf: torch.Tensor):
            return self._select_image_features(leaf, strategy=strategy)

        selected = json_map_leaves(apply_strategy, tower_output)
        return cast(Union[torch.Tensor, tuple[torch.Tensor, ...]], selected)

    def _process_image_pixels(
        self,
        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Compute image features from the `pixel_values` entry of `inputs`."""
        assert self.vision_tower is not None

        return self._image_pixels_to_features(self.vision_tower,
                                              inputs["pixel_values"])

    def _process_image_input(
        self,
        image_input: LlavaImageInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Return projected image embeddings for `image_input`.

        Precomputed embeddings are returned unchanged. A single feature
        tensor is projected directly. A sequence of feature tensors is
        concatenated, projected once, and split back by each tensor's first
        dimension.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        if isinstance(image_features, torch.Tensor):
            return self.multi_modal_projector(image_features)

        feature_sizes = [feature.shape[0] for feature in image_features]

        projected = self.multi_modal_projector(torch.cat(image_features))
        return torch.split(projected, feature_sizes)

    def get_language_model(self) -> torch.nn.Module:
        """Return the language model submodule."""
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Return image embeddings for the image kwargs, or `None` if absent."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is not None:
            return self._process_image_input(image_input)
        return None

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed `input_ids`, merging in image embeddings when provided.

        The merge is keyed on `self.config.image_token_index`.
        """
        text_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return text_embeds

        return merge_multimodal_embeddings(
            input_ids,
            text_embeds,
            multimodal_embeddings,
            self.config.image_token_index,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Passed through to the language model.
            intermediate_tensors: When given, `inputs_embeds` is discarded
                and these are passed to the language model.
            inputs_embeds: Precomputed input embeddings. When `None` (and no
                intermediate tensors are given) they are computed here from
                `input_ids` and the image inputs in `kwargs`.
            **kwargs: May carry `pixel_values` or `image_embeds`.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            image_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      image_embeddings)
            input_ids = None

        return self.language_model.model(input_ids,
                                         positions,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        logits = self.language_model.compute_logits(hidden_states,
                                                    sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load `weights` through `AutoWeightsLoader` and return its result."""
        return AutoWeightsLoader(self).load_weights(weights)
757
758


759
760
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis, which pre-fills some processor kwargs."""

    def get_hf_processor(self, **kwargs: object):
        """Return a `LlavaProcessor` with Mantis-specific default kwargs.

        `patch_size` and `vision_feature_select_strategy` are set only when
        the caller has not already supplied them.
        """
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        kwargs.setdefault("patch_size", vision_info.get_patch_size())

        if Version(TRANSFORMERS_VERSION) < Version("4.48"):
            # BUG: num_additional_image_tokens = 0 but treated as 1,
            # so we set vision_feature_select_strategy to None to offset this
            default_strategy = None
        else:
            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
            default_strategy = hf_config.vision_feature_select_strategy
        kwargs.setdefault("vision_feature_select_strategy", default_strategy)

        return self.ctx.get_hf_processor(LlavaProcessor, **kwargs)
779
780


781
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor that rewrites image placeholders in Mantis format."""

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        return_mm_hashes: bool = False,
    ) -> MultiModalInputs:
        """Run the parent processing, then rewrite the image placeholders.

        Each run of image tokens produced by the parent `apply` is replaced
        with `(image N: <Image>` + image tokens + `</Image>)`. Placeholder
        ranges are then located again in the rewritten prompt.
        """
        image_token_id = self.info.get_hf_config().image_token_index

        # Assume that it doesn't depend on the image size
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                               return_mm_hashes)

        mm_items = self._to_mm_items(mm_data)
        mm_item_counts = mm_items.get_all_counts()
        mm_kwargs = result["mm_kwargs"]
        mm_hashes = result["mm_hashes"]

        # We reimplement the functionality of MLlavaProcessor from
        # https://github.com/TIGER-AI-Lab/Mantis.git
        def get_replacement_mantis(item_idx: int):
            header = f"(image {item_idx+1}: <Image>"  # 7 tokens
            footer = "</Image>)"  # 3 tokens
            return header + "<image>" * num_image_tokens + footer

        mantis_replacement = PromptReplacement(
            modality="image",
            target=[image_token_id] * num_image_tokens,
            replacement=get_replacement_mantis,
        )
        mantis_mm_repls = self._bind_and_group_updates([mantis_replacement])

        prompt_ids, prompt, _ = self._apply_prompt_updates(
            result["prompt_token_ids"],
            mantis_mm_repls,
            mm_item_counts,
        )

        # Locate the placeholders again, using the parent's prompt updates
        # against the rewritten token ids.
        orig_repls = self._bind_and_group_updates(
            self._get_prompt_updates(
                mm_items,
                hf_processor_mm_kwargs,
                mm_kwargs,
            ))

        mm_placeholders = self._find_mm_placeholders(
            orig_repls,
            prompt_ids,
            mm_item_counts,
        )
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {}
        for modality, placeholders in mm_placeholders.items():
            mm_placeholder_ranges[modality] = [
                placeholder.to_range() for placeholder in placeholders
            ]

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )
857
858
859
860


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
                                        info=MantisProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model registered with the Mantis processor and processing info.

    The model code itself is inherited unchanged.
    """