llava.py 21.6 KB
Newer Older
1
from functools import cached_property
2
from types import MethodType
3
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
4
                    Tuple, TypedDict, Union)
5
6

import torch
7
import torch.nn as nn
8
9
10
11
12
from PIL.Image import Image
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                          PixtralVisionConfig, PretrainedConfig,
                          ProcessorMixin, SiglipVisionConfig)
from transformers.models.pixtral import PixtralProcessor
13
14

from vllm.attention import AttentionMetadata
15
from vllm.config import VllmConfig
16
from vllm.inputs import InputContext
17
from vllm.model_executor.layers.activation import get_act_fn
18
19
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
20
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
21
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
22
from vllm.model_executor.sampling_metadata import SamplingMetadata
23
from vllm.multimodal import MULTIMODAL_REGISTRY
24
25
26
27
28
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.processing import (InputProcessingContext,
                                        ModalityProcessingMetadata,
                                        MultiModalProcessingMetadata,
                                        MultiModalProcessor, PromptReplacement)
29
from vllm.sequence import IntermediateTensors
30

31
from .clip import (CLIPVisionModel, dummy_image_for_clip,
32
                   get_max_clip_image_tokens)
33
from .interfaces import SupportsMultiModal, SupportsPP
34
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
35
                      get_max_pixtral_hf_image_tokens)
36
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
37
                     get_max_siglip_image_tokens)
38
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
39
                    maybe_prefix, merge_multimodal_embeddings)
40
41


42
43
class LlavaImagePixelInputs(TypedDict):
    """Image input given as raw pixel values.

    Keys:
        type: Discriminator tag, always ``"pixel_values"``.
        data: Shape `(batch_size * num_images, num_channels, height, width)`.
            `height` or `width` may differ per batch and image, in which
            case the data is passed as a list of tensors instead of a
            single batched tensor.
    """

    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
51
52
53
54
55


class LlavaImageEmbeddingInputs(TypedDict):
    """Image input given as precomputed embeddings.

    Keys:
        type: Discriminator tag, always ``"image_embeds"``.
        data: Shape `(batch_size * num_images, image_feature_size,
            hidden_size)`. `hidden_size` must match the hidden size of the
            language model backbone.
    """

    type: Literal["image_embeds"]
    data: torch.Tensor


# Either accepted form of image input; the two members are told apart by
# their literal "type" key ("pixel_values" vs. "image_embeds").
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


65
66
class LlavaMultiModalProjector(nn.Module):
    """Project image features to the text hidden size.

    The projector is a linear layer, an activation, and a second linear
    layer, applied in that order.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the two linear layers and the activation between them.

        Args:
            vision_hidden_size: Input width of the first linear layer.
            text_hidden_size: Output width of both linear layers.
            projector_hidden_act: Activation name passed to `get_act_fn`.
            quant_config: Quantization config forwarded to both layers.
            prefix: Name prefix; each layer appends its own attribute name.
        """
        super().__init__()

        # Options shared by both linear layers.
        shared_options = dict(bias=True, quant_config=quant_config)

        self.linear_1 = ColumnParallelLinear(
            vision_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_1",
            **shared_options,
        )
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            text_hidden_size,
            text_hidden_size,
            prefix=f"{prefix}.linear_2",
            **shared_options,
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Run linear_1 -> activation -> linear_2 on `image_features`."""
        # Each linear layer returns a pair; only the first element is used.
        projected, _ = self.linear_1(image_features)
        activated = self.act(projected)
        output, _ = self.linear_2(activated)
        return output


94
95
96
97
98
def get_max_llava_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens for the configured model.

    The count comes from the helper matching the vision config type and is
    then adjusted by `vision_feature_select_strategy`: ``"full"`` keeps the
    count as is, ``"default"`` returns one fewer.

    Raises:
        NotImplementedError: If the vision config type is not supported.
        ValueError: If the feature select strategy is not recognized.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    # Checked in order; the first matching config type is used.
    token_counters = (
        (CLIPVisionConfig, get_max_clip_image_tokens),
        (SiglipVisionConfig, get_max_siglip_image_tokens),
        (PixtralVisionConfig, get_max_pixtral_hf_image_tokens),
    )
    for config_type, count_tokens in token_counters:
        if isinstance(vision_config, config_type):
            num_image_tokens = count_tokens(vision_config)
            break
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        return num_image_tokens - 1
    if strategy == "full":
        return num_image_tokens
    raise ValueError(f"Unexpected select feature strategy: {strategy}")
115
116


117
118
def dummy_mm_kwargs_for_llava(ctx: InputProcessingContext,
                              mm_counts: Mapping[str, int]):
    """Build dummy multimodal kwargs for `mm_counts["image"]` images.

    A dummy image payload is created by the helper matching the vision
    config type, run through the HF image processor, and returned together
    with an `is_pixtral` flag tensor.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    # Checked in order; the first matching config type is used.
    dummy_factories = (
        (CLIPVisionConfig, dummy_image_for_clip),
        (SiglipVisionConfig, dummy_image_for_siglip),
        (PixtralVisionConfig, dummy_image_for_pixtral_hf),
    )
    for config_type, make_dummy in dummy_factories:
        if isinstance(vision_config, config_type):
            data = make_dummy(vision_config, num_images)
            break
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    hf_processor = ctx.get_hf_processor()
    image_processor = hf_processor.image_processor  # type: ignore
    hf_inputs = image_processor.preprocess(data["image"], return_tensors="pt")

    # Flag whether the HF processor is the Pixtral one.
    pixtral_flag = torch.tensor(isinstance(hf_processor, PixtralProcessor))
    return MultiModalKwargs(**hf_inputs, is_pixtral=pixtral_flag)
142

143

144
145
def create_metadata_for_llava(
        ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
    """Create the processing metadata for the ``"image"`` modality.

    The metadata holds a single prompt replacement whose target and
    replacement unit are both the image token id from the HF config. The
    repeat count is supplied by a callback.
    """
    image_token_id = ctx.get_hf_config(LlavaConfig).image_token_index

    def get_repl_count(
        mm_items: list[Image],
        hf_inputs: BatchFeature,
        item_idx: int,
    ) -> int:
        # The arguments are unused: every item gets the same count,
        # recomputed from the context on each call.
        return get_max_llava_image_tokens(ctx)

    image_replacement = PromptReplacement(
        target=[image_token_id],
        repl_unit=[image_token_id],
        repl_count=get_repl_count,
    )
    image_metadata = ModalityProcessingMetadata(
        prompt_repls=[image_replacement])

    return {"image": image_metadata}
164

165

166
class LlavaProcessor(MultiModalProcessor):
    """Multimodal processor for LLaVA, with extra handling for Pixtral."""

    def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
        """Wrap the image processor's `preprocess` so its output also
        carries an ``"is_pixtral"`` entry set to a true tensor.

        The processor is marked after patching, so repeated calls on the
        same processor do nothing.
        """
        already_patched = getattr(hf_processor, "__is_patched__", False)
        if already_patched:
            return

        image_processor = hf_processor.image_processor  # type: ignore
        unpatched_preprocess = image_processor.preprocess

        def preprocess(_bound, *args, **kwargs):
            # `_bound` is the image processor supplied by MethodType; the
            # captured original is already bound, so it is not passed on.
            outputs = unpatched_preprocess(*args, **kwargs)
            outputs["is_pixtral"] = torch.tensor(True)
            return outputs

        image_processor.preprocess = MethodType(preprocess, image_processor)

        hf_processor.__is_patched__ = True  # type: ignore

    def _get_hf_processor(self) -> ProcessorMixin:
        """Return the context's HF processor, patching it first when it is
        a `PixtralProcessor`."""
        processor = self.ctx.get_hf_processor()

        if isinstance(processor, PixtralProcessor):
            self._patch_pixtral_processor(processor)

        return processor

    def _get_dummy_mm_kwargs(
        self,
        mm_counts: Mapping[str, int],
    ) -> MultiModalKwargs:
        """Delegate to `dummy_mm_kwargs_for_llava` with this processor's
        context."""
        return dummy_mm_kwargs_for_llava(self.ctx, mm_counts)
197
198


199
200
class LlavaLikeConfig(Protocol):
    """Structural type for configs that carry a vision config together
    with the vision feature layer selection."""

    # Configuration of the vision encoder.
    vision_config: PretrainedConfig

    # A single layer index, or a list of layer indices.
    vision_feature_layer: Union[int, List[int]]


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The layer count needed for the single feature layer, or for the
        deepest one when several feature layers are given.

    Raises:
        TypeError: If `vision_feature_layer` is neither an int nor a
            list/tuple.
        ValueError: If `vision_feature_layer` is an empty list/tuple.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers

    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)

    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # max() on an empty iterable raises an opaque ValueError;
            # raise the same exception type with a message naming the field.
            raise ValueError(
                "vision_feature_layer must contain at least one layer index")
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)

    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given an signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index + 1
236
237
238
239
240
241
242


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
):
    """Instantiate the vision tower matching `hf_config.vision_config`.

    Args:
        hf_config: Config providing `vision_config` and the feature layer
            selection.
        quant_config: Quantization config forwarded to the vision model.
        require_post_norm: Forwarded unchanged to the vision model.
        prefix: Name prefix forwarded to the vision model.

    Returns:
        A CLIP, SigLIP or Pixtral-HF vision model, chosen by config type.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature layer
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    # Checked in order; the first matching config type is used.
    tower_types = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )
    for config_type, tower_cls in tower_types:
        if isinstance(vision_config, config_type):
            return tower_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


279
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(lambda ctx: LlavaProcessor(
    ctx=ctx,
    metadata=create_metadata_for_llava(ctx),
))
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA model made of a vision tower, a multimodal projector and a
    language model.

    Image inputs arrive either as pixel values (run through the vision
    tower and the projector) or as precomputed image embeddings (used as
    is). The resulting embeddings are merged into the text embeddings at
    the positions of the image token.
    """

    # BitsAndBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, the projector and the language model.

        Args:
            vllm_config: Provides the HF config, the quantization config and
                the multimodal config.
            prefix: Name prefix for the submodules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        # The config is patched in place before the submodules read it.
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # Expose the language model's factory on this wrapper.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """The language model's sampler when it has one, otherwise the
        result of `get_sampler()`. Computed once and cached."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every dimension after the first is `(3, h, w)`, with
        `h == w == vision_config.image_size`.

        Returns:
            `data` unchanged.

        Raises:
            ValueError: If the trailing dimensions do not match.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract the image input from the model keyword arguments.

        Reads `pixel_values`, `is_pixtral` and `image_embeds` from `kwargs`.
        `pixel_values` takes precedence when both it and `image_embeds` are
        given.

        Returns:
            `None` when neither `pixel_values` nor `image_embeds` is given,
            otherwise the matching `LlavaImageInputs` member.

        Raises:
            ValueError: If `pixel_values` or `image_embeds` is neither a
                tensor nor a list, or (Pixtral path) if a nested item is
                not a tensor/list or a tensor has fewer than 3 dimensions.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False]))
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            assert isinstance(is_pixtral, torch.Tensor)
            if is_pixtral.any():

                def flatten_to_3d_tensors(item):
                    if isinstance(item, torch.Tensor):
                        if item.dim() >= 3:
                            # Collapse all leading dimensions into one and
                            # split along it into individual 3D tensors.
                            return list(item.view(-1, *item.shape[-3:]))
                        else:
                            raise ValueError(
                                f"Unexpected tensor dimension: {item.dim()}")
                    elif isinstance(item, list):
                        return [
                            t for subitem in item
                            for t in flatten_to_3d_tensors(subitem)
                        ]
                    else:
                        raise ValueError(f"Unexpected type: {type(item)}")

                # Flatten the (possibly nested) batched images into a flat
                # list of 3D tensors. No shape validation is done on this
                # path.
                return LlavaImagePixelInputs(
                    type="pixel_values",
                    data=flatten_to_3d_tensors(pixel_values),
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the feature select strategy to `image_features`.

        ``"default"`` drops the first entry along dimension 1; ``"full"``
        returns the features unchanged.

        Raises:
            ValueError: If `strategy` is neither of the two.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, List[torch.Tensor]],
    ) -> torch.Tensor:
        """Run `vision_tower` on `pixel_values` and apply the configured
        feature select strategy to its output."""

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Compute image features from the `data` of a pixel-value input."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Turn a parsed image input into embeddings.

        Precomputed embeddings are returned as is. Pixel values go through
        the vision tower and then the multimodal projector.
        """

        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return the image embeddings for the given keyword arguments, or
        `None` when they contain no image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed `input_ids` with the language model and, when
        `multimodal_embeddings` is given, merge it in via
        `merge_multimodal_embeddings` using the configured image token
        index."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_index)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Passed through to the language model.
            kv_caches: Passed through to the language model.
            attn_metadata: Passed through to the language model.
            intermediate_tensors: Passed through to the language model. When
                given, `inputs_embeds` is discarded and no embeddings are
                computed here.
            inputs_embeds: Precomputed input embeddings. When omitted (and
                `intermediate_tensors` is also omitted), they are computed
                from `input_ids` and the image inputs in `kwargs`.
            **kwargs: Image inputs, read as `pixel_values` (the pixels in
                each input image), `image_embeds` and `is_pixtral`.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            # The embeddings replace the token ids from here on.
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load `(name, tensor)` weight pairs through `AutoWeightsLoader`
        and return its result."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)