# llava.py (21 KB)
1
from functools import cached_property
2
3
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol,
                    Tuple, TypedDict, Union)
4
5

import torch
6
import torch.nn as nn
7
from PIL import Image
8
from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
9
                          PretrainedConfig, SiglipVisionConfig)
10
11

from vllm.attention import AttentionMetadata
12
from vllm.config import CacheConfig, MultiModalConfig
13
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
14
from vllm.model_executor.layers.activation import get_act_fn
15
from vllm.model_executor.layers.quantization import QuantizationConfig
16
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
17
from vllm.model_executor.sampling_metadata import SamplingMetadata
18
from vllm.multimodal import MULTIMODAL_REGISTRY
19
from vllm.sequence import IntermediateTensors
20
from vllm.utils import is_list_of
21

22
23
24
from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_max_clip_image_tokens,
                   input_processor_for_clip)
25
from .interfaces import SupportsMultiModal, SupportsPP
26
27
28
29
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
                      dummy_seq_data_for_pixtral_hf,
                      get_max_pixtral_hf_image_tokens,
                      input_processor_for_pixtral_hf)
30
31
32
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                     input_processor_for_siglip)
33
34
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                    merge_multimodal_embeddings)
35
36


37
38
class LlavaImagePixelInputs(TypedDict):
    """Image input supplied as raw pixel values.

    The `type` field is always the literal `"pixel_values"`.
    """

    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """
46
47
48
49
50


class LlavaImageEmbeddingInputs(TypedDict):
    """Image input supplied as precomputed image embeddings.

    The `type` field is always the literal `"image_embeds"`.
    """

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Union of the two accepted image input forms; the `type` field
# ("pixel_values" or "image_embeds") tells them apart.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision features to the text hidden size.

    Applies `linear_1` (vision_hidden_size -> text_hidden_size), the
    configured activation, then `linear_2` (text_hidden_size ->
    text_hidden_size). Both linear layers use a bias.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        """Build the projector.

        Args:
            vision_hidden_size: Input feature size of `linear_1`.
            text_hidden_size: Output feature size of both linear layers.
            projector_hidden_act: Activation name resolved via `get_act_fn`.
        """
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project `image_features` through linear_1 -> act -> linear_2."""
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


82
83
84
85
86
def get_max_llava_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens for the configured model.

    The count is taken from the helper that matches the type of
    `hf_config.vision_config` (CLIP, SigLIP or Pixtral-HF) and then adjusted
    by `hf_config.vision_feature_select_strategy`:

    * `"default"`: one token fewer than the vision encoder reports.
    * `"full"`: the count is used unchanged.

    Raises:
        NotImplementedError: If the vision config type is not supported.
        ValueError: If the feature select strategy is not recognised.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_image_tokens = get_max_clip_image_tokens(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_image_tokens = get_max_siglip_image_tokens(vision_config)
    elif isinstance(vision_config, PixtralVisionConfig):
        num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        return num_image_tokens - 1
    elif strategy == "full":
        return num_image_tokens
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")
103
104


105
106
def dummy_data_for_llava(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    """Build dummy sequence data and dummy image data for the model.

    The helpers matching the type of `hf_config.vision_config` are used, and
    the image feature size is overridden with the value returned by
    `get_max_llava_image_tokens`.

    Args:
        ctx: Input context providing the HF config.
        seq_len: Sequence length forwarded to the dummy sequence helper.
        mm_counts: Per-modality item counts; only `mm_counts["image"]` is
            read.

    Returns:
        A `(seq_data, mm_data)` tuple as produced by the selected helpers.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size = get_max_llava_image_tokens(ctx)

    # Pick the pair of dummy-data factories for this vision encoder; the
    # three encoders share the same call signature.
    if isinstance(vision_config, CLIPVisionConfig):
        make_seq_data = dummy_seq_data_for_clip
        make_mm_data = dummy_image_for_clip
    elif isinstance(vision_config, SiglipVisionConfig):
        make_seq_data = dummy_seq_data_for_siglip
        make_mm_data = dummy_image_for_siglip
    elif isinstance(vision_config, PixtralVisionConfig):
        make_seq_data = dummy_seq_data_for_pixtral_hf
        make_mm_data = dummy_image_for_pixtral_hf
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    seq_data = make_seq_data(
        vision_config,
        seq_len,
        num_images,
        image_token_id=hf_config.image_token_index,
        image_feature_size_override=image_feature_size,
    )

    mm_data = make_mm_data(vision_config, num_images)
    return seq_data, mm_data


151
152
def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Run the input processor that matches the model's vision encoder.

    Inputs without image data are returned unchanged. Otherwise the image
    feature size is derived from the image data and passed on as an override
    to the CLIP / SigLIP processors:

    * a single PIL image: the maximum LLaVA image token count;
    * a list of PIL images: that count repeated once per image;
    * a 3-D tensor: its second dimension;
    * a list of tensors: the second dimension of each tensor.

    Raises:
        TypeError: If the image data is none of the forms above.
        NotImplementedError: If the vision config type is not supported.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        image_feature_size = get_max_llava_image_tokens(ctx)
    elif is_list_of(image_data, Image.Image):
        image_feature_size = [get_max_llava_image_tokens(ctx)
                              ] * len(image_data)
    elif isinstance(image_data, torch.Tensor):
        # Three-way unpacking also rejects tensors that are not 3-D.
        _, image_feature_size, _ = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, PixtralVisionConfig):
        # We ignore image_feature_size_override since we have non-uniform
        # image sizes for Pixtral
        return input_processor_for_pixtral_hf(
            model_config,
            vision_config,
            inputs,
            image_token_id=hf_config.image_token_index,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


203
204
205
206
207
208
209
210
211
212
class LlavaLikeConfig(Protocol):
    """Structural type of configs read by `init_vision_tower_for_llava`."""
    # Vision encoder config; its concrete type selects the vision model.
    vision_config: PretrainedConfig
    # Index of the vision layer to use; negative values count from the end.
    vision_feature_layer: int


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
    """Create the vision tower matching `hf_config.vision_config`.

    Only the layers up to and including `hf_config.vision_feature_layer` are
    requested from the vision model; a negative index counts from the last
    layer (so `-1` keeps every layer).

    Args:
        hf_config: Config exposing `vision_config` and
            `vision_feature_layer`.
        quant_config: Forwarded to the vision model constructor.
        require_post_norm: Forwarded to the vision model constructor.
        prefix: Forwarded to the vision model constructor.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        num_hidden_layers = (vision_config.num_hidden_layers +
                             vision_feature_layer + 1)
    else:
        num_hidden_layers = vision_feature_layer + 1

    # All three vision models take the same constructor arguments.
    if isinstance(vision_config, CLIPVisionConfig):
        vision_model_cls = CLIPVisionModel
    elif isinstance(vision_config, SiglipVisionConfig):
        vision_model_cls = SiglipVisionModel
    elif isinstance(vision_config, PixtralVisionConfig):
        vision_model_cls = PixtralHFVisionModel
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return vision_model_cls(
        vision_config,
        quant_config=quant_config,
        num_hidden_layers_override=num_hidden_layers,
        require_post_norm=require_post_norm,
        prefix=prefix,
    )


254
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-style model: vision tower, projector and language model.

    Image inputs arrive either as pixel values, which are encoded by the
    vision tower and passed through `multi_modal_projector`, or as
    precomputed image embeddings, which are used as-is. The image embeddings
    are merged into the text embeddings at `config.image_token_index`.
    """

    def __init__(self,
                 config: LlavaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        """Build the vision tower, the projector and the language model.

        Note that `config` may be modified in place (see the Pixtral
        special cases below).
        """
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix="vision_tower")
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_vllm_registered_model(
            config.text_config,
            cache_config,
            quant_config,
            prefix="language_model")

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Sampler of the language model if it has one, else a new one."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return Sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that `data` has shape `(N, 3, image_size, image_size)`.

        Returns `data` unchanged; raises `ValueError` on a mismatch.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _validate_image_sizes(
        self,
        images: List[torch.Tensor],
        sizes: Union[torch.Tensor, List[torch.Tensor]],
    ) -> List[torch.Tensor]:
        """Check each image against its expected `(height, width)`.

        Args:
            images: Flat list of image tensors whose last three dimensions
                are `(channels, height, width)`.
            sizes: One tensor, or a list of tensors, holding
                `(height, width)` pairs in the same order as `images`.

        Returns:
            `images`, unchanged.

        Raises:
            ValueError: If the number of images differs from the number of
                size pairs, if an image's height/width differs from its
                pair, or if an image does not have 3 channels.
        """
        if not isinstance(sizes, list):
            sizes = [sizes]

        total_images = sum(size.numel() // 2 for size in sizes)
        if total_images != len(images):
            raise ValueError("Mismatch in number of images. "
                             f"Expected {total_images}, got {len(images)}")

        # The count check above guarantees there is exactly one image per
        # (height, width) pair, so `img_idx` never runs past `images`.
        img_idx = 0
        for size in sizes:
            # Flatten the size tensor to a list of (height, width) pairs
            for expected_h, expected_w in size.view(-1, 2).tolist():
                img = images[img_idx]
                if img.shape[-2:] != (expected_h, expected_w):
                    raise ValueError(
                        "Image size mismatch. Expected "
                        f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
                if img.shape[-3] != 3:
                    raise ValueError("Image channel mismatch. Expected 3, "
                                     f"got {img.shape[-3]}")
                img_idx += 1
        return images

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract and validate the image input from the forward kwargs.

        Reads `pixel_values`, `image_sizes` and `image_embeds`. Pixel values
        take precedence over embeddings when both are given.

        Returns:
            `None` when neither pixel values nor embeddings are present,
            otherwise the matching `LlavaImageInputs` dict.

        Raises:
            ValueError: If an input has an unexpected type or shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Case for models like PixtralHF that have dynamic image sizes
            # so we need to produce a list of tensors
            if image_sizes is not None:

                def flatten_to_3d_tensors(item):
                    if isinstance(item, torch.Tensor):
                        if item.dim() >= 3:
                            # Collapse all leading dims, then split along
                            # the first one into individual 3-D images.
                            return list(item.view(-1, *item.shape[-3:]))
                        else:
                            raise ValueError(
                                f"Unexpected tensor dimension: {item.dim()}")
                    elif isinstance(item, list):
                        return [
                            t for subitem in item
                            for t in flatten_to_3d_tensors(subitem)
                        ]
                    else:
                        raise ValueError(f"Unexpected type: {type(item)}")

                # Flatten the (possibly nested) batch into one flat list of
                # 3-D image tensors
                images = flatten_to_3d_tensors(pixel_values)

                return LlavaImagePixelInputs(
                    type="pixel_values",
                    data=self._validate_image_sizes(images, image_sizes),
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the feature select strategy to the vision features.

        `"default"` drops the first entry along dimension 1; `"full"` keeps
        everything. Any other value raises `ValueError`.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, List[torch.Tensor]],
    ) -> torch.Tensor:
        """Encode pixel values and apply the feature select strategy."""

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Run the vision tower on the pixel values in `inputs`."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Turn an image input into embeddings for the language model.

        Precomputed embeddings are returned as-is (the projector is not
        applied); pixel values go through the vision tower and then the
        projector.
        """

        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Forwarded unchanged to the language model.
            kv_caches: Forwarded unchanged to the language model.
            attn_metadata: Forwarded unchanged to the language model.
            intermediate_tensors: Forwarded to the language model. When this
                is not `None`, image inputs are not parsed and no input
                embeddings are computed here.
            **kwargs: May carry `pixel_values` (the pixels in each input
                image), `image_sizes` and `image_embeds`.

        Returns:
            The output of the language model's inner `model` call.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        else:
            image_input = self._parse_and_validate_image_input(**kwargs)
            vision_embeddings = (None if image_input is None else
                                 self._process_image_input(image_input))

            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)
            if vision_embeddings is not None:
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, vision_embeddings,
                    self.config.image_token_index)

        # always pass the input via `inputs_embeds`
        # to make sure the computation graph is consistent
        # for `torch.compile` integration
        input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load `(name, tensor)` weight pairs using `AutoWeightsLoader`."""
        loader = AutoWeightsLoader(self)
        loader.load_weights(weights)