# llava.py (29.5 KB)
1
from abc import abstractmethod
2
from functools import cached_property
3
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
4
                    Protocol, Set, Tuple, TypedDict, TypeVar, Union)
5
6

import torch
7
import torch.nn as nn
8
from packaging.version import Version
9
10
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
11
                          SiglipVisionConfig)
12
from transformers import __version__ as TRANSFORMERS_VERSION
13
from transformers.models.llava import LlavaProcessor
14
from transformers.models.pixtral import PixtralProcessor
15
16

from vllm.attention import AttentionMetadata
17
from vllm.config import VllmConfig
18
from vllm.inputs import InputProcessingContext
19
from vllm.model_executor.layers.activation import get_act_fn
20
21
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
22
from vllm.model_executor.layers.quantization import QuantizationConfig
# (GitLab blame annotation: change by Joe Runde)
23
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
24
from vllm.model_executor.sampling_metadata import SamplingMetadata
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
27
                                    MultiModalInputs, MultiModalKwargs,
28
                                    NestedTensors)
29
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
30
                                   ImageSize, MultiModalDataItems)
31
from vllm.multimodal.processing import (BaseMultiModalProcessor,
32
33
34
                                        BaseProcessingInfo, ProcessingCache,
                                        PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
35
from vllm.sequence import IntermediateTensors
36

37
from .clip import CLIPVisionModel
38
from .interfaces import SupportsMultiModal, SupportsPP
39
40
41
from .pixtral import (PixtralHFVisionModel,
                      get_pixtral_hf_image_feature_grid_size)
from .siglip import SiglipVisionModel
42
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
43
                    maybe_prefix, merge_multimodal_embeddings)
44
from .vision import get_vision_encoder_info
45
46


47
48
class LlavaImagePixelInputs(TypedDict):
    # Tag marking this payload as raw image pixels.
    type: Literal["pixel_values"]

    # Shape: `(batch_size * num_images, num_channels, height, width)`.
    # If `height` or `width` differ between images, the images cannot be
    # stacked into one tensor and are passed as a list of tensors instead.
    data: Union[torch.Tensor, List[torch.Tensor]]


class LlavaImageEmbeddingInputs(TypedDict):
    # Tag marking this payload as precomputed image embeddings.
    type: Literal["image_embeds"]

    # Shape: `(batch_size * num_images, image_feature_size, hidden_size)`.
    # `hidden_size` must match the hidden size of the language model
    # backbone.
    data: torch.Tensor


# Image input accepted by the model: either raw pixels or embeddings.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


70
71
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision features to the text hidden size.

    Computes ``linear_2(act(linear_1(x)))``: ``linear_1`` is a
    column-parallel layer from ``vision_hidden_size`` to
    ``text_hidden_size`` and ``linear_2`` a row-parallel layer from
    ``text_hidden_size`` to ``text_hidden_size``; both have a bias.

    Args:
        vision_hidden_size: Feature size produced by the vision tower.
        text_hidden_size: Hidden size of the language model, used as both
            the intermediate and the output width.
        projector_hidden_act: Activation name passed to ``get_act_fn``.
        quant_config: Optional quantization config for both linear layers.
        prefix: Parameter-name prefix for the sub-layers.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # Submodule names must stay `linear_1` / `act` / `linear_2`; they
        # form the parameter names used by the checkpoint.
        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` into the language model's space."""
        # Each parallel linear layer returns a 2-tuple; only the first
        # element (the layer output) is used.
        projected, _unused = self.linear_1(image_features)
        projected = self.act(projected)
        projected, _unused = self.linear_2(projected)
        return projected


99
100
class LlavaLikeConfig(Protocol):
    """Structural type for configs of LLaVA-style models.

    Any config exposing these attributes can be used with the helpers in
    this module, whatever its concrete class.
    """

    # Config of the vision encoder (CLIP, SigLIP or Pixtral are handled by
    # `init_vision_tower_for_llava`).
    vision_config: Final[PretrainedConfig]
    # Token id of the image placeholder in the prompt.
    image_token_index: Final[int]
    # Feature selection strategy; "default" and "full" are handled here.
    vision_feature_select_strategy: Final[str]
    # Vision encoder layer index (or indices) to take features from;
    # negative values count from the last layer.
    vision_feature_layer: Final[Union[int, list[int]]]
104

105

106
107
108
109
class LlavaLikeProcessor(Protocol):
    """Structural type for HF processors used by LLaVA-style models."""

    # Text of the image placeholder token that appears in prompts.
    image_token: Final[str]


110
class BaseLlavaProcessingInfo(BaseProcessingInfo):
    """Processing info shared by LLaVA-style models.

    Subclasses provide the concrete HF processor through
    ``get_hf_processor``. Image token counts are derived from the vision
    encoder and the config's feature selection strategy.
    """

    def get_hf_config(self) -> LlavaLikeConfig:
        return self.ctx.get_hf_config(LlavaConfig)

    def get_vision_encoder_info(self):
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    @abstractmethod
    def get_hf_processor(self) -> LlavaLikeProcessor:
        raise NotImplementedError

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Only images are supported; no explicit per-prompt count is set
        # here (`None`).
        return {"image": None}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        # The per-image maximum does not depend on `seq_len`.
        max_tokens = self.get_max_image_tokens()
        return {"image": max_tokens}

    def _apply_feature_select_strategy(
        self,
        strategy: str,
        encoder_num_image_tokens: int,
    ) -> int:
        """Convert the encoder's token count into the count fed to the LM.

        ``"default"`` drops one token and ``"full"`` keeps all of them;
        any other strategy raises ``NotImplementedError``.
        """
        if strategy == "full":
            return encoder_num_image_tokens
        if strategy == "default":
            return encoder_num_image_tokens - 1

        raise NotImplementedError(
            f"Unexpected feature select strategy: {strategy!r}")

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of placeholder tokens for one image."""
        strategy = self.get_hf_config().vision_feature_select_strategy
        encoder_tokens = self.get_vision_encoder_info().get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return self._apply_feature_select_strategy(strategy, encoder_tokens)

    def get_image_size_with_most_features(self) -> ImageSize:
        # A square image whose side is the encoder's configured image size.
        side = self.get_vision_encoder_info().get_image_size()
        return ImageSize(width=side, height=side)

    def get_max_image_tokens(self) -> int:
        widest, tallest = self.get_image_size_with_most_features()
        return self.get_num_image_tokens(
            image_width=widest,
            image_height=tallest,
        )

171
172
173
174
175
176

# Type variable over processing-info classes of LLaVA-style models.
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)


class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
    """Builds dummy prompts and images for profiling LLaVA-style models."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return one image token per requested image plus matching images.

        Each dummy image has the size that yields the most image features.
        ``seq_len`` is not used.
        """
        image_count = mm_counts.get("image", 0)

        image_token = self.info.get_hf_processor().image_token
        width, height = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)

        return ProcessorInputs(
            prompt_text=image_token * image_count,
            mm_data={"image": dummy_images},
        )


202
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for standard LLaVA models using `LlavaProcessor`."""

    def get_hf_processor(self):
        processor_cls = LlavaProcessor
        return self.ctx.get_hf_processor(processor_cls)


208
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
    """Multimodal processor base for LLaVA-style models.

    Every image placeholder token in the prompt is replaced by as many image
    tokens as that image contributes. Subclasses declare how the HF outputs
    map to per-image fields via ``_get_mm_fields_config``.
    """

    # Re-declared (copied from BaseMultiModalProcessor) as abstract here.
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        raise NotImplementedError

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if not isinstance(images, ImageEmbeddingItems):
                size = images.get_image_size(item_idx)
                token_count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )
            else:
                # Precomputed embeddings report their own feature length.
                token_count = images.get_feature_size(item_idx)

            return [image_token_id] * token_count

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


252
253
class LlavaMultiModalProcessor(
        BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
    """Multimodal processor for standard LLaVA models."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Both pixel values and precomputed embeddings are batched per image.
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }


266
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
    """Processing info for Pixtral checkpoints in HF (LLaVA-style) format."""

    def get_hf_processor(self, **kwargs: object):
        """Return the HF ``PixtralProcessor``.

        Keyword arguments are forwarded to ``ctx.get_hf_processor``.
        ``PixtralHFMultiModalProcessor._get_prompt_replacements`` calls this
        method with ``**hf_processor_mm_kwargs``; without accepting them,
        any non-empty processor kwargs would raise ``TypeError``. With no
        kwargs the behavior is unchanged.
        """
        return self.ctx.get_hf_processor(PixtralProcessor, **kwargs)

271

272
273
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):
    """Multimodal processor for Pixtral checkpoints in HF format.

    Each image expands into ``nrows`` rows of ``ncols`` image tokens. Every
    row ends with an image-break token, except the last row, which ends with
    an image-end token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is None:
            return processed_outputs

        images = mm_data["images"]
        assert isinstance(images, list)

        # The HF output nests the images one level too deep:
        # (1, num_images, C, H, W) -> unwrap to (num_images, C, H, W).
        assert (isinstance(pixel_values, list) and len(pixel_values) == 1)
        per_image_pixels = pixel_values[0]
        assert (isinstance(per_image_pixels, list)
                and len(per_image_pixels) == len(images))

        processed_outputs["pixel_values"] = per_image_pixels
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Both pixel values and precomputed embeddings are batched per image.
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        hf_config = self.info.get_hf_config()
        vocab = self.info.get_tokenizer().get_vocab()

        image_break_id = vocab[processor.image_break_token]
        image_token_id = hf_config.image_token_index
        image_end_id = vocab[processor.image_end_token]

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            size = images.get_image_size(item_idx)

            ncols, nrows = get_pixtral_hf_image_feature_grid_size(
                vision_config,
                image_width=size.width,
                image_height=size.height,
            )

            row = [image_token_id] * ncols + [image_break_id]
            grid = row * nrows
            # The final row is terminated by the end token, not a break.
            grid[-1] = image_end_id
            return grid

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]

353

354
355
356
357
358
359
360
361
362
363
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext,
) -> BaseLlavaProcessingInfo:
    """Pick the processing-info class from the vision encoder config type.

    Pixtral vision configs get ``PixtralHFProcessingInfo``; everything else
    gets ``LlavaProcessingInfo``.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config
    is_pixtral = isinstance(vision_config, PixtralVisionConfig)
    info_cls = PixtralHFProcessingInfo if is_pixtral else LlavaProcessingInfo
    return info_cls(ctx)


364
def _build_llava_or_pixtral_hf_processor(
    info: _I,
    dummy_inputs: BaseDummyInputsBuilder[_I],
    *,
    cache: Optional[ProcessingCache] = None,
    enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
    """Instantiate the multimodal processor that matches ``info``'s type.

    Raises:
        NotImplementedError: If ``info`` is neither Pixtral-HF nor LLaVA
            processing info.
    """
    # Checked in order; the first matching info class wins.
    processor_by_info = (
        (PixtralHFProcessingInfo, PixtralHFMultiModalProcessor),
        (LlavaProcessingInfo, LlavaMultiModalProcessor),
    )
    for info_cls, processor_cls in processor_by_info:
        if isinstance(info, info_cls):
            return processor_cls(
                info,
                dummy_inputs,  # type: ignore
                cache=cache,
                enable_sanity_checks=enable_sanity_checks,
            )

    raise NotImplementedError(type(info))
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine how many visual-encoder layers need to be initialized.

    The result is the layer count required to reach the deepest feature
    layer named by ``hf_config.vision_feature_layer``.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        Number of leading encoder layers to initialize.

    Raises:
        ValueError: If ``vision_feature_layer`` is an empty list/tuple.
        TypeError: If ``vision_feature_layer`` is neither an int nor a
            list/tuple of ints.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers

    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)

    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        # `max()` of an empty sequence would raise an opaque ValueError.
        if not feature_layers:
            raise ValueError("vision_feature_layer must not be empty")
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)

    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given an signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index + 1
422
423
424
425
426
427
428


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
):
    """Build the vision encoder of a LLaVA-style model.

    Only the layers up to the deepest requested feature layer are built
    (passed as ``num_hidden_layers_override``).

    Args:
        hf_config: Config providing ``vision_config`` and the feature
            layer(s).
        quant_config: Optional quantization config for the encoder.
        require_post_norm: Forwarded unchanged to the encoder constructor.
        prefix: Parameter-name prefix for the encoder.

    Raises:
        NotImplementedError: If ``vision_config`` is not a CLIP, SigLIP or
            Pixtral vision config.
    """
    vision_config = hf_config.vision_config

    # Computed before dispatching, so invalid feature layers fail first.
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    encoder_by_config = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )
    for config_cls, encoder_cls in encoder_by_config:
        if isinstance(vision_config, config_cls):
            return encoder_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


465
466
467
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
                                        info=_build_llava_or_pixtral_hf_info,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-style vision-language model.

    Images are encoded by ``vision_tower``, mapped to the text hidden size by
    ``multi_modal_projector`` and merged into the token embeddings at the
    image placeholder positions before ``language_model`` runs. Pixtral
    checkpoints in HF format are served by this class too.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        text_config = config.text_config
        if (text_config.architectures is None
                and text_config.model_type == "mistral"):
            text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Prefer the language model's own sampler when it defines one.
        if not hasattr(self.language_model, "sampler"):
            return get_sampler()
        return self.language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Require each image in ``data`` to be ``(3, image_size, image_size)``."""
        side = self.config.vision_config.image_size
        expected_dims = (3, side, side)

        if tuple(data.shape[1:]) == expected_dims:
            return data

        expected_expr = ("batch_size", *map(str, expected_dims))
        raise ValueError(
            f"The expected shape of pixel values is {expected_expr}. "
            f"You supplied {tuple(data.shape)}.")

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract image inputs from ``kwargs``; pixels win over embeddings.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds``
        is given.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if self.config.vision_config.model_type == "pixtral":
                # Pixtral: flatten without concatenating and skip the
                # fixed-size shape check (image sizes may vary).
                return LlavaImagePixelInputs(
                    type="pixel_values",
                    data=flatten_bn(pixel_values),
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        return None

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "full":
            return image_features
        if strategy == "default":
            # Drop the first token of every image (presumably the CLS token).
            return image_features[:, 1:]

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        features = vision_tower(pixel_values)
        return self._select_image_features(
            features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        tower = self.vision_tower
        assert tower is not None
        return self._image_pixels_to_features(tower, inputs["data"])

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        # Precomputed embeddings must already match the LM hidden size, so
        # they bypass the vision tower and the projector.
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        return self.multi_modal_projector(
            self._process_image_pixels(image_input))

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        text_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return text_embeds
        # Merge the image embeddings in at the `image_token_index` positions.
        return merge_multimodal_embeddings(input_ids, text_embeds,
                                           multimodal_embeddings,
                                           self.config.image_token_index)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        ``input_ids`` already contain one placeholder token per image
        feature, so ``positions`` and ``attn_metadata`` are consistent with
        the sequence the language model sees after the image embeddings are
        merged in.

        For example, the text prompt
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`
        is tokenized to
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`,
        where `32000` is the image token. To reserve space in the KV cache,
        the input processor expands that single image token in place into
        576 (24 * 24) copies, one per image token output by the visual
        encoder:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Positions matching ``input_ids``.
            kv_caches: KV cache tensors passed to the language model.
            attn_metadata: Attention metadata passed to the language model.
            intermediate_tensors: Intermediate tensors for pipeline
                parallelism; when given, no input embeddings are computed.
            inputs_embeds: Precomputed input embeddings; when omitted (and
                no ``intermediate_tensors``), they are built from
                ``input_ids`` and the image inputs in ``kwargs``.
            **kwargs: Image inputs, e.g. ``pixel_values`` (the pixels in
                each input image) or ``image_embeds``.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this condition is for v0 compatibility.
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        return self.language_model.model(input_ids,
                                         positions,
                                         kv_caches,
                                         attn_metadata,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        lm = self.language_model
        return lm.compute_logits(hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        lm = self.language_model
        return lm.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        return AutoWeightsLoader(self).load_weights(weights)
720
721


722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis, built on the HF `LlavaProcessor`."""

    def get_hf_processor(self):
        hf_config = self.get_hf_config()
        vision_info = self.get_vision_encoder_info()

        # BUG (transformers < 4.48): num_additional_image_tokens = 0 but
        # treated as 1, so we set vision_feature_select_strategy to None to
        # offset this.
        # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150  # noqa: E501
        has_token_count_bug = (
            Version(TRANSFORMERS_VERSION) < Version("4.48"))
        vision_feature_select_strategy = (
            None if has_token_count_bug else
            hf_config.vision_feature_select_strategy)

        return self.ctx.get_hf_processor(
            LlavaProcessor,
            patch_size=vision_info.get_patch_size(),
            vision_feature_select_strategy=vision_feature_select_strategy,
        )


743
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """Multimodal processor for Mantis models.

    Runs the regular LLaVA processing, then wraps each expanded image
    placeholder run in Mantis' ``(image N: <Image>...</Image>)`` template.
    This reimplements MLlavaProcessor from
    https://github.com/TIGER-AI-Lab/Mantis.git.
    """

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        image_token_id = self.info.get_hf_config().image_token_index

        # Assume that it doesn't depend on the image size
        tokens_per_image = self.info.get_num_image_tokens(
            image_width=-1,
            image_height=-1,
        )

        base_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)

        mm_items = self._to_mm_items(mm_data)
        item_counts = mm_items.get_all_counts()
        mm_kwargs = base_inputs["mm_kwargs"]

        def wrap_image(item_idx: int):
            # Mantis numbers images starting from 1.
            pieces = (
                f"(image {item_idx+1}: <Image>",  # 7 tokens
                "<image>" * tokens_per_image,
                "</Image>)",  # 3 tokens
            )
            return "".join(pieces)

        # Step 1: rewrite each expanded placeholder run into the Mantis
        # template.
        mantis_repls = self._bind_and_group_repls([
            PromptReplacement(
                modality="image",
                target=[image_token_id] * tokens_per_image,
                replacement=wrap_image,
            )
        ])
        prompt_ids, prompt, _ = self._apply_prompt_replacements(
            base_inputs["prompt_token_ids"],
            mantis_repls,
            item_counts,
        )

        # Step 2: locate the image placeholders in the rewritten prompt with
        # the standard LLaVA replacements and validate them.
        orig_repls = self._bind_and_group_repls(
            self._get_prompt_replacements(
                mm_items,
                hf_processor_mm_kwargs,
                mm_kwargs,
            ))
        mm_placeholders = self._find_mm_placeholders(
            orig_repls,
            prompt_ids,
            item_counts,
        )
        self._validate_mm_placeholders(mm_placeholders, item_counts)

        placeholder_ranges = {
            modality: [item.to_range() for item in items]
            for modality, items in mm_placeholders.items()
        }

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=placeholder_ranges,
        )
816
817
818
819


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
                                        info=MantisProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """Mantis: the LLaVA architecture with Mantis-specific prompt processing.

    Select it via
    ``--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'``.
    """