llava.py 23.5 KB
Newer Older
1
from functools import cached_property
2
from types import MethodType
3
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
4
                    Tuple, TypedDict, Union)
5
6

import torch
7
import torch.nn as nn
8
9
10
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
                          ProcessorMixin, SiglipVisionConfig)
11
from transformers.models.llava import LlavaProcessor
12
from transformers.models.pixtral import PixtralProcessor
13
14

from vllm.attention import AttentionMetadata
15
from vllm.config import VllmConfig
16
from vllm.inputs import InputContext
17
from vllm.model_executor.layers.activation import get_act_fn
18
19
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
20
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
21
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
22
from vllm.model_executor.sampling_metadata import SamplingMetadata
23
from vllm.multimodal import MULTIMODAL_REGISTRY
24
from vllm.multimodal.inputs import NestedTensors
25
from vllm.multimodal.processing import (BaseMultiModalProcessor,
26
                                        MultiModalDataItems, ProcessorInputs,
27
                                        PromptReplacement)
28
from vllm.sequence import IntermediateTensors
29

30
from .clip import (CLIPVisionModel, dummy_image_for_clip,
31
                   get_max_clip_image_tokens)
32
from .interfaces import SupportsMultiModal, SupportsPP
33
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
34
35
                      get_max_pixtral_hf_image_tokens,
                      get_pixtral_hf_image_feature_size)
36
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
37
                     get_max_siglip_image_tokens)
38
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
39
                    maybe_prefix, merge_multimodal_embeddings)
40
41


42
43
class LlavaImagePixelInputs(TypedDict):
    """Raw image pixels that still have to go through the vision tower."""

    type: Literal["pixel_values"]

    # Shape: `(batch_size * num_images, num_channels, height, width)`.
    # `height`/`width` may differ between images (e.g. on the Pixtral path),
    # in which case a list of per-image tensors is passed instead of a single
    # batched tensor.
    data: Union[torch.Tensor, List[torch.Tensor]]


class LlavaImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings, used as-is without the vision tower or
    the projector."""

    type: Literal["image_embeds"]

    # Shape: `(batch_size * num_images, image_feature_size, hidden_size)`.
    # `hidden_size` must match the hidden size of the language model backbone.
    data: torch.Tensor


# Either kind of image input accepted by the LLaVA model.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


65
66
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision features to the text hidden size.

    ``linear_1`` (column-parallel) goes from ``vision_hidden_size`` to
    ``text_hidden_size``, the ``projector_hidden_act`` activation is applied,
    and ``linear_2`` (row-parallel) maps ``text_hidden_size`` onto itself.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # Registration order (linear_1, act, linear_2) is kept as-is.
        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return a pair; only the first element
        # (the layer output) is used here.
        projected, _ = self.linear_1(image_features)
        projected, _ = self.linear_2(self.act(projected))
        return projected


94
95
96
97
98
def get_max_llava_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens produced for one image.

    The raw count comes from the helper matching the vision encoder config.
    It is then adjusted for ``vision_feature_select_strategy``: ``"default"``
    drops the first feature position (one fewer token) and ``"full"`` keeps
    all of them.

    Raises:
        NotImplementedError: If the vision config type is not supported.
        ValueError: If the feature selection strategy is unknown.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    # Checked in order; the first matching config type wins.
    token_counters = (
        (CLIPVisionConfig, get_max_clip_image_tokens),
        (SiglipVisionConfig, get_max_siglip_image_tokens),
        (PixtralVisionConfig, get_max_pixtral_hf_image_tokens),
    )
    for config_cls, count_tokens in token_counters:
        if isinstance(vision_config, config_cls):
            num_image_tokens = count_tokens(vision_config)
            break
    else:
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "full":
        return num_image_tokens
    if strategy == "default":
        # Mirrors `image_features[:, 1:]` in the model's feature selection.
        return num_image_tokens - 1
    raise ValueError(f"Unexpected select feature strategy: {strategy}")
115
116


117
class LlavaMultiModalProcessor(BaseMultiModalProcessor):
    """Multi-modal processor for LLaVA and HF-format Pixtral checkpoints.

    Wraps the HF ``LlavaProcessor`` / ``PixtralProcessor`` and describes how
    each image placeholder token in the prompt is expanded.
    """

    def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
        """Wrap the image processor so its outputs carry ``is_pixtral``.

        A marker attribute on ``hf_processor`` prevents the same processor
        from being wrapped twice.
        """
        already_patched = getattr(hf_processor, "__is_patched__", False)
        if already_patched:
            return

        image_processor = hf_processor.image_processor  # type: ignore
        wrapped_preprocess = image_processor.preprocess

        def preprocess(_bound_self, *args, **kwargs):
            # `_bound_self` is the image processor (bound via MethodType);
            # the captured original method is already bound, so it is unused.
            batch = wrapped_preprocess(*args, **kwargs)
            batch["is_pixtral"] = torch.tensor(True)
            return batch

        image_processor.preprocess = MethodType(preprocess, image_processor)
        hf_processor.__is_patched__ = True  # type: ignore

    def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
        """Fetch the HF processor, patching it first if it is Pixtral's."""
        processor = self.ctx.get_hf_processor(
            (LlavaProcessor, PixtralProcessor))
        if isinstance(processor, PixtralProcessor):
            self._patch_pixtral_processor(processor)
        return processor

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_inputs: BatchFeature,
        mm_processor_kwargs: Mapping[str, object],
    ) -> list[PromptReplacement]:
        """Describe how each image token in the prompt is expanded.

        For Pixtral the expansion depends on each image's size (rows of image
        tokens separated by break tokens, ending with an end token). For the
        other encoders every image expands to the same fixed token count.
        """
        hf_config = self.ctx.get_hf_config(LlavaConfig)
        image_token_id = hf_config.image_token_index

        processor = self._get_hf_processor()
        if not isinstance(processor, PixtralProcessor):
            max_image_tokens = get_max_llava_image_tokens(self.ctx)
            return [
                PromptReplacement(
                    modality="image",
                    target=[image_token_id],
                    replacement=[image_token_id] * max_image_tokens,
                )
            ]

        image_token = processor.image_token
        image_break_token = processor.image_break_token
        image_end_token = processor.image_end_token

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)

        def get_replacement_pixtral(item_idx: int):
            image_size = mm_items.get_image_size(item_idx)
            ncols, nrows = get_pixtral_hf_image_feature_size(
                vision_config,
                image_width=image_size.width,
                image_height=image_size.height,
            )
            # Each row is `ncols` image tokens followed by a break token; the
            # final break token is swapped for the end-of-image token.
            row = [image_token] * ncols + [image_break_token]
            tokens = row * nrows
            tokens[-1] = image_end_token
            return "".join(tokens)

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement_pixtral,
            ),
        ]

    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Build a dummy prompt with one image token per dummy image."""
        hf_config = self.ctx.get_hf_config(LlavaConfig)
        vision_config = hf_config.vision_config
        num_images = mm_counts["image"]

        # Checked in order; the first matching config type wins.
        dummy_image_builders = (
            (CLIPVisionConfig, dummy_image_for_clip),
            (SiglipVisionConfig, dummy_image_for_siglip),
            (PixtralVisionConfig, dummy_image_for_pixtral_hf),
        )
        for config_cls, build_dummy in dummy_image_builders:
            if isinstance(vision_config, config_cls):
                data = build_dummy(vision_config, num_images)
                break
        else:
            msg = f"Unsupported vision config: {type(vision_config)}"
            raise NotImplementedError(msg)

        image_token = self._get_hf_processor().image_token
        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data=data,
            mm_processor_kwargs={},
        )
223
224


225
226
class LlavaLikeConfig(Protocol):
    """Structural type for configs accepted by the LLaVA vision-tower helpers.

    Declares only the attributes that ``init_vision_tower_for_llava`` and
    ``_get_num_hidden_layers`` read.
    """

    # Config of the vision encoder; its type selects the tower implementation.
    vision_config: PretrainedConfig
    # Encoder layer(s) whose output is used as image features; negative values
    # count from the end.
    vision_feature_layer: Union[int, List[int]]


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The number of encoder layers needed to produce the deepest requested
        feature layer.

    Raises:
        ValueError: If ``vision_feature_layer`` is an empty list/tuple.
        TypeError: If ``vision_feature_layer`` is neither an int nor a
            list/tuple of ints.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # Same exception type `max()` would raise, with a clearer message.
            raise ValueError("vision_feature_layer must not be empty")
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given a signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Feature layers index the encoder's hidden states as a sequence of
    ``num_hidden_layers + 1`` entries. Entry 0 is the embedding output and
    entry ``k`` is the output after ``k`` encoder layers; this matches Hugging
    Face's ``output_hidden_states``. So ``-1`` needs every layer, ``-2`` all
    but one, and a non-negative ``k`` needs exactly ``k`` layers.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.

    Returns:
        The number of encoder layers that must be run.
    """
    if feature_layer_index < 0:
        # -1 -> num_hidden_layers, -2 -> num_hidden_layers - 1, ...
        return num_hidden_layers + feature_layer_index + 1
    # A non-negative index is already a layer count: entry k comes after k
    # layers. (Returning k + 1 here would disagree with the negative branch,
    # where index -(num_hidden_layers + 1), the same entry as 0, maps to 0.)
    return feature_layer_index
262
263
264
265
266
267
268


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
):
    """Build the vision encoder that matches ``hf_config.vision_config``.

    The layer count from ``_get_num_hidden_layers`` is passed to the tower as
    ``num_hidden_layers_override``.

    Raises:
        NotImplementedError: If the vision config is not a CLIP, SigLIP or
            Pixtral vision config.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        tower_cls = CLIPVisionModel
    elif isinstance(vision_config, SiglipVisionConfig):
        tower_cls = SiglipVisionModel
    elif isinstance(vision_config, PixtralVisionConfig):
        tower_cls = PixtralHFVisionModel
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    # All supported towers share the same constructor keywords.
    return tower_cls(
        vision_config,
        quant_config=quant_config,
        num_hidden_layers_override=num_hidden_layers,
        require_post_norm=require_post_norm,
        prefix=prefix,
    )


305
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-style vision-language model.

    Images are encoded by ``vision_tower`` (chosen from
    ``config.vision_config``) and mapped to the text hidden size by
    ``multi_modal_projector``. The result is merged into the language model's
    input embeddings at the positions of ``config.image_token_index``.
    """

    # BitandBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Prefer the language model's own sampler when it defines one.
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every image is ``(3, image_size, image_size)``.

        Raises:
            ValueError: If the per-image dimensions do not match.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Pop image inputs from ``kwargs`` and wrap them in a typed dict.

        Returns:
            ``None`` when neither ``pixel_values`` nor ``image_embeds`` is
            present, otherwise the matching ``LlavaImageInputs`` variant.

        Raises:
            ValueError: If an input has an unexpected type or shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        # Added by LlavaMultiModalProcessor's Pixtral patch; False otherwise.
        is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False]))
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            assert isinstance(is_pixtral, torch.Tensor)
            if is_pixtral.any():
                # Pixtral images may differ in size, so they are neither
                # concatenated into one tensor nor size-validated. Instead the
                # (possibly nested) input is flattened into a flat list of
                # per-image tensors made of the last three dimensions.

                def flatten_to_3d_tensors(item):
                    if isinstance(item, torch.Tensor):
                        if item.dim() >= 3:
                            # `reshape` (unlike `view`) also accepts
                            # non-contiguous tensors.
                            return list(item.reshape(-1, *item.shape[-3:]))
                        raise ValueError(
                            f"Unexpected tensor dimension: {item.dim()}")
                    if isinstance(item, list):
                        return [
                            t for subitem in item
                            for t in flatten_to_3d_tensors(subitem)
                        ]
                    raise ValueError(f"Unexpected type: {type(item)}")

                return LlavaImagePixelInputs(
                    type="pixel_values",
                    data=flatten_to_3d_tensors(pixel_values),
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply ``vision_feature_select_strategy`` to the tower output.

        ``"default"`` drops the first position along dim 1; ``"full"`` keeps
        everything.

        Raises:
            ValueError: If the strategy is unknown.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode pixels with ``vision_tower`` and select the features."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Run the vision tower on pixel inputs (before projection)."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Turn parsed image inputs into language-model-sized embeddings.

        Precomputed embeddings are returned unchanged; pixels go through the
        vision tower and then the projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return image embeddings for the inputs in ``kwargs``, or ``None``
        when no image input is present."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``, merging in ``multimodal_embeddings`` (if any)
        at the positions of ``config.image_token_index``."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_index)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor expands
        the image token into additional image tokens (denoted as `32000`),
        resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Positions of the tokens in `input_ids`.
            kv_caches: KV cache tensors, forwarded to the language model.
            attn_metadata: Attention metadata, forwarded to the language
                model.
            intermediate_tensors: Activations from a previous pipeline stage.
                When given, image inputs are ignored and `inputs_embeds` is
                reset to `None`.
            inputs_embeds: Precomputed input embeddings. When `None` (and no
                `intermediate_tensors`), they are computed here from
                `input_ids` and any image inputs in `kwargs`.
            **kwargs: Image inputs, i.e. `pixel_values` (the pixels in each
                input image) or `image_embeds`, plus an optional `is_pixtral`
                flag.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this condition is for v0 compatibility.
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
585
586


587
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
    """LLaVA processor variant that uses Mantis' ``MLlavaProcessor``."""

    def _get_hf_processor(self) -> ProcessorMixin:
        """Load ``MLlavaProcessor`` for the configured tokenizer.

        Raises:
            ModuleNotFoundError: If the optional ``mantis`` package is not
                installed.
        """
        try:
            from mantis.models.mllava import MLlavaProcessor
        except ModuleNotFoundError as exc:
            install_hint = ("You need to `pip install "
                            "git+https://github.com/TIGER-AI-Lab/Mantis.git` "
                            "to use this model")
            raise ModuleNotFoundError(install_hint) from exc

        tokenizer_name = self.ctx.model_config.tokenizer
        processor = MLlavaProcessor.from_pretrained(tokenizer_name)
        assert isinstance(processor, ProcessorMixin)
        return processor


# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """Mantis checkpoint served with the LLaVA architecture.

    Only the registered multi-modal processor differs from
    ``LlavaForConditionalGeneration``; all model code is inherited unchanged.
    """