# llava.py - LLaVA model definition (vision tower, projector, language model).
1
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol,
                    Tuple, TypedDict, Union)

import torch
import torch.nn as nn
from PIL import Image
from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
                          PretrainedConfig, SiglipVisionConfig)

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_max_clip_image_tokens,
                   input_processor_for_clip)
from .interfaces import SupportsMultiModal, SupportsPP
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
                      dummy_seq_data_for_pixtral_hf,
                      get_max_pixtral_hf_image_tokens,
                      input_processor_for_pixtral_hf)
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                     input_processor_for_siglip)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                    merge_multimodal_embeddings)
36
37


38
39
class LlavaImagePixelInputs(TypedDict):
    """Image input supplied as raw pixel values (tagged `"pixel_values"`)."""

    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """
47
48
49
50
51


class LlavaImageEmbeddingInputs(TypedDict):
    """Image input supplied as embeddings (tagged `"image_embeds"`)."""

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Either form of image input; the literal "type" key ("pixel_values" or
# "image_embeds") tells the two TypedDicts apart.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP projecting vision features to the text hidden size.

    Computes `linear_2(act(linear_1(x)))`, where `linear_1` maps
    `vision_hidden_size -> text_hidden_size` and `linear_2` maps
    `text_hidden_size -> text_hidden_size`.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        """Build the projector.

        Args:
            vision_hidden_size: Size of the last dimension of the input
                features.
            text_hidden_size: Size of the last dimension of the output.
            projector_hidden_act: Activation name resolved via `get_act_fn`.
        """
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project `image_features` and return the result."""
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


83
84
85
86
87
def get_max_llava_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens for the configured model.

    The count comes from the helper matching the vision config type (CLIP,
    SigLIP or Pixtral-HF) and is then adjusted for
    `vision_feature_select_strategy`:

    * `"default"`: one fewer than the helper's count (the same strategy
      drops the first feature position in `_select_image_features`).
    * `"full"`: the helper's count unchanged.

    Raises:
        NotImplementedError: If the vision config type is not supported.
        ValueError: If the select strategy is neither `"default"` nor
            `"full"`.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_image_tokens = get_max_clip_image_tokens(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_image_tokens = get_max_siglip_image_tokens(vision_config)
    elif isinstance(vision_config, PixtralVisionConfig):
        num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        return num_image_tokens - 1
    elif strategy == "full":
        return num_image_tokens
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")
104
105


106
107
def dummy_data_for_llava(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    """Build dummy sequence and image data for the configured vision encoder.

    Args:
        ctx: Input context used to look up the `LlavaConfig`.
        seq_len: Sequence length forwarded to the encoder-specific helper.
        mm_counts: Per-modality item counts; only `mm_counts["image"]` is
            read.

    Returns:
        A `DummyData` built from the helper's sequence data, the dummy image
        data and the placeholder ranges.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size = get_max_llava_image_tokens(ctx)

    # Every supported encoder exposes the same pair of helpers with matching
    # signatures, so pick the pair once and share the call below.
    if isinstance(vision_config, CLIPVisionConfig):
        make_seq_data = dummy_seq_data_for_clip
        make_image = dummy_image_for_clip
    elif isinstance(vision_config, SiglipVisionConfig):
        make_seq_data = dummy_seq_data_for_siglip
        make_image = dummy_image_for_siglip
    elif isinstance(vision_config, PixtralVisionConfig):
        make_seq_data = dummy_seq_data_for_pixtral_hf
        make_image = dummy_image_for_pixtral_hf
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    seq_data, ranges = make_seq_data(
        vision_config,
        seq_len,
        num_images,
        image_token_id=hf_config.image_token_index,
        image_feature_size_override=image_feature_size,
    )

    mm_data = make_image(vision_config, num_images)
    return DummyData(seq_data, mm_data, ranges)


152
153
def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Dispatch input processing to the matching vision-encoder processor.

    Inputs without `multi_modal_data["image"]` are returned unchanged.
    Otherwise the per-image feature size is derived from the image data:

    * a single PIL image: the model's maximum image token count;
    * a list of PIL images: that count repeated once per image;
    * a 3-D tensor: its second dimension;
    * a list of tensors: the second dimension of each item.

    Raises:
        TypeError: If the image data has none of the forms above.
        ValueError: If a tensor input is not 3-D (raised by the unpacking).
        NotImplementedError: If the vision config type is not supported.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        image_feature_size = get_max_llava_image_tokens(ctx)
    elif is_list_of(image_data, Image.Image):
        image_feature_size = [get_max_llava_image_tokens(ctx)
                              ] * len(image_data)
    elif isinstance(image_data, torch.Tensor):
        # Unpacking (rather than indexing) keeps the implicit check that the
        # tensor has exactly three dimensions.
        _, image_feature_size, _ = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, PixtralVisionConfig):
        # We ignore image_feature_size_override since we have non-uniform
        # image sizes for Pixtral
        return input_processor_for_pixtral_hf(
            model_config,
            vision_config,
            inputs,
            image_token_id=hf_config.image_token_index,
        )

    if isinstance(vision_config, CLIPVisionConfig):
        process_inputs = input_processor_for_clip
    elif isinstance(vision_config, SiglipVisionConfig):
        process_inputs = input_processor_for_siglip
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return process_inputs(
        model_config,
        vision_config,
        inputs,
        image_token_id=hf_config.image_token_index,
        image_feature_size_override=image_feature_size,
    )


204
205
206
207
208
209
210
211
212
213
class LlavaLikeConfig(Protocol):
    """Structural type for configs accepted by `init_vision_tower_for_llava`.

    Any object exposing these two attributes satisfies the protocol.
    """
    # Config of the vision encoder; its concrete type selects the tower class.
    vision_config: PretrainedConfig
    # Index of the encoder layer whose output is used; may be negative, in
    # which case it counts back from `vision_config.num_hidden_layers`.
    vision_feature_layer: int


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Instantiate the vision tower matching `hf_config.vision_config`.

    The tower is built only up to the layer selected by
    `hf_config.vision_feature_layer`, so later layers are never created.

    Args:
        hf_config: Config exposing `vision_config` and
            `vision_feature_layer`.
        quant_config: Quantization config forwarded to the tower.
        require_post_norm: Forwarded unchanged to the tower constructor.
        prefix: Module-name prefix forwarded to the tower constructor.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        # Negative indices count back from the last layer (-1 == all layers).
        num_hidden_layers = (vision_config.num_hidden_layers +
                             vision_feature_layer + 1)
    else:
        num_hidden_layers = vision_feature_layer + 1

    # All supported towers share the same constructor signature.
    if isinstance(vision_config, CLIPVisionConfig):
        tower_cls = CLIPVisionModel
    elif isinstance(vision_config, SiglipVisionConfig):
        tower_cls = SiglipVisionModel
    elif isinstance(vision_config, PixtralVisionConfig):
        tower_cls = PixtralHFVisionModel
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return tower_cls(
        vision_config,
        quant_config=quant_config,
        num_hidden_layers_override=num_hidden_layers,
        require_post_norm=require_post_norm,
        prefix=prefix,
    )


255
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA model: vision tower + multimodal projector + language model.

    Image inputs (pixel values or precomputed embeddings) are turned into
    embeddings and merged into the text embeddings at the positions of the
    image placeholder token before the language model runs.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, projector and language model.

        Args:
            vllm_config: Source of the HF config, the quantization config
                and the multimodal config.
            prefix: Name prefix of this module inside a parent model; it is
                prepended to the prefixes of the submodules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        def scoped(name: str) -> str:
            # Qualify a submodule name with this module's own prefix so the
            # full name stays correct when nested in a parent model.
            return f"{prefix}.{name}" if prefix else name

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=scoped("vision_tower"))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_vllm_registered_model(
            config.text_config,
            vllm_config=vllm_config,
            prefix=scoped("language_model"))

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Sampler of the language model if it has one, else a default."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that `data` has shape `(N, 3, image_size, image_size)`.

        Returns `data` unchanged.

        Raises:
            ValueError: If the trailing dimensions do not match.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _validate_image_sizes(
        self,
        images: List[torch.Tensor],
        sizes: Union[torch.Tensor, List[torch.Tensor]],
    ) -> List[torch.Tensor]:
        """Check each image against its expected `(height, width)`.

        Args:
            images: Flat list of image tensors whose last three dimensions
                are `(channels, height, width)`.
            sizes: One tensor, or a list of tensors, each viewed as rows of
                `(height, width)` pairs, in the same order as `images`.

        Returns:
            `images` unchanged.

        Raises:
            ValueError: If the number of images and size pairs differ, or an
                image has the wrong spatial size or channel count.
        """
        if not isinstance(sizes, list):
            sizes = [sizes]

        total_images = sum(size.numel() // 2 for size in sizes)
        if total_images != len(images):
            raise ValueError("Mismatch in number of images. "
                             f"Expected {total_images}, got {len(images)}")

        # The count check above guarantees one image per (height, width)
        # pair, so indexing `images` below cannot run out of range.
        img_idx = 0
        for size in sizes:
            # Flatten the size tensor to a list of (height, width) pairs
            for expected_h, expected_w in size.view(-1, 2).tolist():
                img = images[img_idx]
                if img.shape[-2:] != (expected_h, expected_w):
                    raise ValueError(
                        "Image size mismatch. Expected "
                        f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
                if img.shape[-3] != 3:
                    raise ValueError("Image channel mismatch. Expected 3, "
                                     f"got {img.shape[-3]}")
                img_idx += 1
        return images

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract and validate the image input from the forward kwargs.

        Pops `pixel_values`, `image_sizes` and `image_embeds`. Pixel values
        take precedence over embeddings when both are given.

        Returns:
            `None` when neither pixel values nor embeddings are present,
            otherwise the matching `LlavaImageInputs` variant.

        Raises:
            ValueError: If an input has an unexpected type or shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Case for models like PixtralHF that have dynamic image sizes
            # so we need to produce a list of tensors
            if image_sizes is not None:

                def flatten_to_3d_tensors(item):
                    if isinstance(item, torch.Tensor):
                        if item.dim() >= 3:
                            # Collapse all leading dims, then split into one
                            # (C, H, W) tensor per image.
                            return list(item.view(-1, *item.shape[-3:]))
                        else:
                            raise ValueError(
                                f"Unexpected tensor dimension: {item.dim()}")
                    elif isinstance(item, list):
                        return [
                            t for subitem in item
                            for t in flatten_to_3d_tensors(subitem)
                        ]
                    else:
                        raise ValueError(f"Unexpected type: {type(item)}")

                # Flatten the (possibly nested) batch into one flat list of
                # per-image 3-D tensors
                images = flatten_to_3d_tensors(pixel_values)

                return LlavaImagePixelInputs(
                    type="pixel_values",
                    data=self._validate_image_sizes(images, image_sizes),
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the feature select strategy to the vision tower output.

        `"default"` drops the first position along dimension 1; `"full"`
        returns the features unchanged.

        Raises:
            ValueError: If `strategy` is neither of the above.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, List[torch.Tensor]],
    ) -> torch.Tensor:
        """Run the vision tower and apply the configured select strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Turn validated pixel inputs into vision features."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Return embeddings for `image_input`.

        Precomputed embeddings are returned as-is (the projector is not
        applied); pixel inputs go through the vision tower and then the
        multimodal projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Forwarded unchanged to the language model.
            kv_caches: Forwarded unchanged to the language model.
            attn_metadata: Forwarded unchanged to the language model.
            intermediate_tensors: When given, no embeddings are computed
                here and these tensors are forwarded to the language model.
            **kwargs: May carry `pixel_values` (the pixels in each input
                image), `image_sizes` and `image_embeds`.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        else:
            image_input = self._parse_and_validate_image_input(**kwargs)
            vision_embeddings = (None if image_input is None else
                                 self._process_image_input(image_input))

            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)
            if vision_embeddings is not None:
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, vision_embeddings,
                    self.config.image_token_index)

        # always pass the input via `inputs_embeds`
        # to make sure the computation graph is consistent
        # for `torch.compile` integration
        input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load `(name, tensor)` weight pairs through `AutoWeightsLoader`."""
        loader = AutoWeightsLoader(self)
        loader.load_weights(weights)