llava.py 21.6 KB
Newer Older
1
from functools import cached_property
2
3
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol,
                    Tuple, TypedDict, Union)
4
5

import torch
6
import torch.nn as nn
7
from PIL import Image
8
from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
9
                          PretrainedConfig, SiglipVisionConfig)
10
11

from vllm.attention import AttentionMetadata
12
from vllm.config import VllmConfig
13
14
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext)
15
from vllm.model_executor.layers.activation import get_act_fn
16
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
17
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
18
from vllm.model_executor.sampling_metadata import SamplingMetadata
19
from vllm.multimodal import MULTIMODAL_REGISTRY
20
from vllm.multimodal.base import NestedTensors
21
from vllm.sequence import IntermediateTensors
22
from vllm.utils import is_list_of
23

24
25
26
from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_max_clip_image_tokens,
                   input_processor_for_clip)
27
from .interfaces import SupportsMultiModal, SupportsPP
28
29
30
31
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
                      dummy_seq_data_for_pixtral_hf,
                      get_max_pixtral_hf_image_tokens,
                      input_processor_for_pixtral_hf)
32
33
34
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                     input_processor_for_siglip)
35
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
36
                    maybe_prefix, merge_multimodal_embeddings)
37
38


39
40
class LlavaImagePixelInputs(TypedDict):
    """Image input given as raw pixel values, tagged `"pixel_values"`."""

    type: Literal["pixel_values"]

    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    When `height` or `width` differs between batch items or images, the
    data is passed as a list of tensors rather than one batched tensor.
    """
48
49
50
51
52


class LlavaImageEmbeddingInputs(TypedDict):
    """Image input given as precomputed embeddings, tagged `"image_embeds"`."""

    type: Literal["image_embeds"]

    data: torch.Tensor
    """
    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must equal the hidden size of the language model backbone.
    """


# Tagged union of the two accepted image input forms; the `type` key tells
# raw pixel values apart from precomputed embeddings.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP projecting vision features to the text hidden size.

    Args:
        vision_hidden_size: Input feature width of the first linear layer.
        text_hidden_size: Output width of both linear layers.
        projector_hidden_act: Name of the activation applied between the
            two layers; resolved through `get_act_fn`.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        # Attribute names and creation order are kept as-is so that the
        # registered submodule names stay the same.
        self.linear_1 = nn.Linear(in_features=vision_hidden_size,
                                  out_features=text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(in_features=text_hidden_size,
                                  out_features=text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Apply `linear_1`, the activation, then `linear_2`."""
        projected = self.act(self.linear_1(image_features))
        return self.linear_2(projected)


84
85
86
87
88
def get_max_llava_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens for the configured model.

    The count comes from the helper matching the vision config class
    (CLIP, SigLIP or Pixtral-HF) and is then adjusted by
    `vision_feature_select_strategy`: `"full"` keeps the count unchanged,
    `"default"` returns one fewer token.

    Raises:
        NotImplementedError: If the vision config class is not supported.
        ValueError: If the feature select strategy is not recognised.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    # Checked in order; the first matching config class wins.
    token_counters = (
        (CLIPVisionConfig, get_max_clip_image_tokens),
        (SiglipVisionConfig, get_max_siglip_image_tokens),
        (PixtralVisionConfig, get_max_pixtral_hf_image_tokens),
    )
    for config_cls, count_tokens in token_counters:
        if isinstance(vision_config, config_cls):
            num_image_tokens = count_tokens(vision_config)
            break
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "full":
        return num_image_tokens
    if strategy == "default":
        return num_image_tokens - 1

    raise ValueError(f"Unexpected select feature strategy: {strategy}")
105
106


107
108
def dummy_data_for_llava(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    """Build dummy sequence and image data for the configured vision tower.

    Args:
        ctx: Input context used to fetch the HF `LlavaConfig`.
        seq_len: Sequence length forwarded to the dummy sequence builder.
        mm_counts: Per-modality item counts; only `"image"` is read.

    Returns:
        A `DummyData` holding the dummy sequence data, the dummy image data
        and the placeholder ranges reported by the sequence builder.

    Raises:
        NotImplementedError: If the vision config class is not supported.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size = get_max_llava_image_tokens(ctx)

    # (config class, dummy sequence builder, dummy image builder), checked
    # in order.
    backends = (
        (CLIPVisionConfig, dummy_seq_data_for_clip, dummy_image_for_clip),
        (SiglipVisionConfig, dummy_seq_data_for_siglip,
         dummy_image_for_siglip),
        (PixtralVisionConfig, dummy_seq_data_for_pixtral_hf,
         dummy_image_for_pixtral_hf),
    )
    for config_cls, build_seq_data, build_images in backends:
        if not isinstance(vision_config, config_cls):
            continue

        seq_data, ranges = build_seq_data(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
        mm_data = build_images(vision_config, num_images)
        return DummyData(seq_data, mm_data, ranges)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


153
154
def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Dispatch input processing to the vision-tower-specific processor.

    Inputs without `"image"` multi-modal data are returned unchanged.
    Otherwise the per-image feature size is derived from the image data
    (PIL images use the model maximum; tensors use their second dimension)
    and passed to the CLIP or SigLIP processor. The Pixtral-HF processor is
    called without a feature size override.

    Raises:
        TypeError: If the image data is not a PIL image, a tensor, or a
            list of either.
        NotImplementedError: If the vision config class is not supported.
    """
    mm_data = inputs.get("multi_modal_data")
    if mm_data is None or "image" not in mm_data:
        return inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    image_data = mm_data["image"]
    if isinstance(image_data, Image.Image):
        image_feature_size = get_max_llava_image_tokens(ctx)
    elif is_list_of(image_data, Image.Image):
        max_tokens = get_max_llava_image_tokens(ctx)
        image_feature_size = [max_tokens] * len(image_data)
    elif isinstance(image_data, torch.Tensor):
        # Unpacking requires exactly three dimensions; presumably
        # (num_images, image_feature_size, hidden_size) embeddings.
        _, image_feature_size, _ = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [embeds.shape[1] for embeds in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    # CLIP and SigLIP take the same arguments, feature size included.
    fixed_size_processors = (
        (CLIPVisionConfig, input_processor_for_clip),
        (SiglipVisionConfig, input_processor_for_siglip),
    )
    for config_cls, process in fixed_size_processors:
        if isinstance(vision_config, config_cls):
            return process(
                model_config,
                vision_config,
                inputs,
                image_token_id=hf_config.image_token_index,
                image_feature_size_override=image_feature_size,
            )

    if isinstance(vision_config, PixtralVisionConfig):
        # We ignore image_feature_size_override since we have non-uniform
        # image sizes for Pixtral
        return input_processor_for_pixtral_hf(
            model_config,
            vision_config,
            inputs,
            image_token_id=hf_config.image_token_index,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


205
206
207
208
209
210
211
212
213
214
class LlavaLikeConfig(Protocol):
    """Structural type for configs accepted by `init_vision_tower_for_llava`.

    Any config object exposing these two attributes satisfies the protocol.
    """
    # Config of the vision encoder; its concrete class selects which vision
    # tower implementation is built.
    vision_config: PretrainedConfig
    # Index of the vision layer whose output is used as the image feature;
    # a negative value counts back from the last hidden layer.
    vision_feature_layer: int


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
):
    """Build the vision tower matching `hf_config.vision_config`.

    The tower is only initialized up to the layer selected by
    `hf_config.vision_feature_layer`; a negative index counts back from the
    total number of hidden layers.

    Args:
        hf_config: Config providing `vision_config` and
            `vision_feature_layer`.
        quant_config: Quantization config forwarded to the tower.
        require_post_norm: Forwarded unchanged to the tower constructor.
        prefix: Prefix forwarded unchanged to the tower constructor.

    Raises:
        NotImplementedError: If the vision config class is not supported.
    """
    vision_config = hf_config.vision_config

    # Translate the feature layer index into a layer count.
    feature_layer = hf_config.vision_feature_layer
    num_hidden_layers = feature_layer + 1
    if feature_layer < 0:
        num_hidden_layers += vision_config.num_hidden_layers

    # Checked in order; the first matching config class wins.
    tower_classes = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
        (PixtralVisionConfig, PixtralHFVisionModel),
    )
    for config_cls, tower_cls in tower_classes:
        if isinstance(vision_config, config_cls):
            return tower_cls(
                vision_config,
                quant_config=quant_config,
                num_hidden_layers_override=num_hidden_layers,
                require_post_norm=require_post_norm,
                prefix=prefix,
            )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


256
@MULTIMODAL_REGISTRY.register_image_input_mapper()
257
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
258
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
259
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
260
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
261

262
    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, the projector and the language model.

        Args:
            vllm_config: Engine config; the HF config, quantization config
                and multimodal config are read from it.
            prefix: Prefix prepended (via `maybe_prefix`) to the submodule
                names passed to the vision tower and the language model.
        """
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.multimodal_config = model_config.multimodal_config

        text_config = config.text_config
        vision_config = config.vision_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        missing_architectures = text_config.architectures is None
        if missing_architectures and text_config.model_type == "mistral":
            text_config.architectures = ["MistralForCausalLM"]
        missing_projector_act = config.projector_hidden_act is None
        if missing_projector_act and vision_config.hidden_act == "gelu":
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=vision_config.hidden_size,
            text_hidden_size=text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
        )
        self.language_model = init_vllm_registered_model(
            text_config,
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # Re-export the language model's factory under the same name on
        # this wrapper.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

Joe Runde's avatar
Joe Runde committed
305
        return get_sampler()
306

307
308
309
310
311
312
313
    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
314
            raise ValueError(
315
316
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")
317
318
319

        return data

320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
    def _validate_image_sizes(self, images: List[torch.Tensor],
                              sizes: List[torch.Tensor]) -> List[torch.Tensor]:
        if not isinstance(sizes, list):
            sizes = [sizes]

        total_images = sum(size.numel() // 2 for size in sizes)
        if total_images != len(images):
            raise ValueError("Mismatch in number of images. "
                             f"Expected {total_images}, got {len(images)}")
        img_idx = 0
        for size in sizes:
            # Flatten the size tensor to a list of (height, width) pairs
            size = size.view(-1, 2).tolist()
            for expected_h, expected_w in size:
                if img_idx >= len(images):
                    raise ValueError("Ran out of images before sizes. "
                                     f"{img_idx} >= {len(images)}")
                img = images[img_idx]
                if img.shape[-2:] != (expected_h, expected_w):
                    raise ValueError(
                        "Image size mismatch. Expected "
                        f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
                if img.shape[-3] != 3:
                    raise ValueError("Image channel mismatch. Expected 3, "
                                     f"got {img.shape[-3]}")
                img_idx += 1
        return images

348
    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract and validate the image input from the forward kwargs.

        Reads `pixel_values`, `image_sizes` and `image_embeds` from
        `kwargs`. Pixel values take precedence over embeddings when both
        are given.

        Returns:
            `None` when neither pixel values nor embeddings are present;
            otherwise a `LlavaImagePixelInputs` or
            `LlavaImageEmbeddingInputs` dict.

        Raises:
            ValueError: If an input is neither a tensor nor a list, if a
                tensor with fewer than 3 dimensions is given together with
                `image_sizes`, or if shape validation fails.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Case for models like PixtralHF that have dynamic image sizes
            # so we need to produce a list of tensors
            if image_sizes is not None:

                def flatten_to_3d_tensors(item):
                    if isinstance(item, torch.Tensor):
                        if item.dim() < 3:
                            raise ValueError(
                                f"Unexpected tensor dimension: {item.dim()}")
                        # Collapse all leading dims into one and split
                        # along it; `reshape` also handles non-contiguous
                        # tensors, which `view` rejects.
                        return list(item.reshape(-1, *item.shape[-3:]))
                    if isinstance(item, list):
                        return [
                            t for subitem in item
                            for t in flatten_to_3d_tensors(subitem)
                        ]
                    raise ValueError(f"Unexpected type: {type(item)}")

                # Flatten the (possibly nested) batch into one flat list of
                # 3D image tensors
                images = flatten_to_3d_tensors(pixel_values)

                return LlavaImagePixelInputs(
                    type="pixel_values",
                    data=self._validate_image_sizes(images, image_sizes),
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        # Reaching this point means `pixel_values` is None, so the early
        # return above guarantees that `image_embeds` is not None.
        if not isinstance(image_embeds, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image embeddings. "
                             f"Got type: {type(image_embeds)}")

        return LlavaImageEmbeddingInputs(
            type="image_embeds",
            data=flatten_bn(image_embeds, concat=True),
        )
407
408
409
410
411
412
413
414
415
416
417

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

418
419
    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower and apply the feature select strategy.

        Args:
            vision_tower: Vision encoder called on `pixel_values`.
            pixel_values: Image pixels passed straight to the tower.

        Returns:
            The tower output filtered by
            `config.vision_feature_select_strategy`.
        """
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        tower_output = vision_tower(pixel_values)

        select_strategy = self.config.vision_feature_select_strategy
        return self._select_image_features(tower_output,
                                           strategy=select_strategy)

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Encode the pixel values in `inputs["data"]` with the vision tower."""
        tower = self.vision_tower
        assert tower is not None

        return self._image_pixels_to_features(tower, inputs["data"])

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Turn a parsed image input into embeddings for the language model.

        Precomputed embeddings (`type == "image_embeds"`) are returned
        as-is; pixel inputs go through the vision tower and then the
        multi-modal projector.
        """
        if image_input["type"] != "image_embeds":
            assert self.vision_tower is not None
            vision_features = self._process_image_pixels(image_input)
            return self.multi_modal_projector(vision_features)

        return image_input["data"]

452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
    def process_mm_inputs(self, **kwargs):
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        vision_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed `input_ids`, merging in vision embeddings when provided.

        Args:
            input_ids: Token ids embedded by the language model.
            vision_embeddings: Optional embeddings merged via
                `merge_multimodal_embeddings`, keyed on
                `config.image_token_index`.
        """
        text_embeds = self.language_model.get_input_embeddings(input_ids)
        if vision_embeddings is None:
            return text_embeds

        return merge_multimodal_embeddings(input_ids, text_embeds,
                                           vision_embeddings,
                                           self.config.image_token_index)

471
472
473
474
475
476
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, placeholder tokens have to be inserted
        before the prompt reaches the model, so the input processor adds
        extra image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        575 tokens are inserted so that, together with the original image
        token, there are 576 (24 * 24) image tokens in total, which is the
        number of image tokens outputted by the visual encoder and inputted
        to the language model.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Passed through to the language model.
            kv_caches: Passed through to the language model.
            attn_metadata: Passed through to the language model.
            intermediate_tensors: When given, `inputs_embeds` is discarded
                and no multi-modal processing happens in this call.
            inputs_embeds: Precomputed input embeddings. When omitted (and
                `intermediate_tensors` is None), they are computed from
                `input_ids` and the image kwargs.
            **kwargs: Image inputs such as `pixel_values`, `image_sizes`
                or `image_embeds`.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is None:
            if inputs_embeds is None:
                vision_embeddings = self.process_mm_inputs(**kwargs)
                # always pass the input via `inputs_embeds`
                # to make sure the computation graph is consistent
                inputs_embeds = self.get_input_embeddings(
                    input_ids, vision_embeddings)
                input_ids = None
        else:
            inputs_embeds = None

        return self.language_model.model(input_ids,
                                         positions,
                                         kv_caches,
                                         attn_metadata,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

535
536
537
538
539
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the wrapped language model."""
        language_model = self.language_model
        return language_model.compute_logits(hidden_states, sampling_metadata)
542
543
544
545
546
547

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the wrapped language model."""
        language_model = self.language_model
        return language_model.sample(logits, sampling_metadata)
549

550
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load `(name, tensor)` weight pairs into this model.

        Loading is handed to an `AutoWeightsLoader` built around this
        module. Nothing is returned.
        """
        AutoWeightsLoader(self).load_weights(weights)