llava.py 20.8 KB
Newer Older
1
from functools import cached_property
2
3
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol,
                    Tuple, TypedDict, Union)
4
5

import torch
6
import torch.nn as nn
7
from PIL import Image
8
from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
9
                          PretrainedConfig, SiglipVisionConfig)
10
11

from vllm.attention import AttentionMetadata
12
from vllm.config import CacheConfig, MultiModalConfig
13
from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
14
from vllm.model_executor.layers.activation import get_act_fn
15
from vllm.model_executor.layers.quantization import QuantizationConfig
16
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
17
from vllm.model_executor.sampling_metadata import SamplingMetadata
18
from vllm.multimodal import MULTIMODAL_REGISTRY
19
from vllm.sequence import IntermediateTensors
20
from vllm.utils import is_list_of
21

22
23
24
from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_max_clip_image_tokens,
                   input_processor_for_clip)
25
from .interfaces import SupportsMultiModal, SupportsPP
26
27
28
29
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
                      dummy_seq_data_for_pixtral_hf,
                      get_max_pixtral_hf_image_tokens,
                      input_processor_for_pixtral_hf)
30
31
32
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                     input_processor_for_siglip)
33
34
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                    merge_multimodal_embeddings)
35
36


37
38
class LlavaImagePixelInputs(TypedDict):
    """Image input supplied as raw pixel values."""

    # Discriminator tag; always the literal string "pixel_values".
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """
46
47
48
49
50


class LlavaImageEmbeddingInputs(TypedDict):
    """Image input supplied as precomputed image embeddings."""

    # Discriminator tag; always the literal string "image_embeds".
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Union of the supported image input variants; the `type` field of each
# TypedDict ("pixel_values" or "image_embeds") tells them apart.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision features to the text hidden size.

    The first linear layer projects from `vision_hidden_size` to
    `text_hidden_size`, the configured activation is applied, and a second
    linear layer maps `text_hidden_size` to `text_hidden_size`.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        """Create the projector layers.

        Args:
            vision_hidden_size: Input feature size of the first linear layer.
            text_hidden_size: Output feature size of both linear layers.
            projector_hidden_act: Activation name resolved via `get_act_fn`.
        """
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Apply `linear_1`, the activation, then `linear_2`."""
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


82
83
84
85
86
def get_max_llava_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens for the configured model.

    The base count comes from the helper matching the vision config type
    (CLIP, SigLIP or Pixtral-HF) and is then adjusted according to
    `vision_feature_select_strategy`.

    Args:
        ctx: Input context used to look up the HF `LlavaConfig`.

    Returns:
        The base count minus one for the `"default"` strategy, or the base
        count unchanged for the `"full"` strategy.

    Raises:
        NotImplementedError: If the vision config type is unsupported.
        ValueError: If the feature select strategy is not recognised.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_image_tokens = get_max_clip_image_tokens(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_image_tokens = get_max_siglip_image_tokens(vision_config)
    elif isinstance(vision_config, PixtralVisionConfig):
        num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        # Mirrors `_select_image_features`, which drops the first feature
        # along dim 1 for the "default" strategy.
        return num_image_tokens - 1
    elif strategy == "full":
        return num_image_tokens
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")
103
104


105
106
def dummy_data_for_llava(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    """Build dummy sequence data and dummy image data for the model.

    Dispatches on the type of the vision config (CLIP, SigLIP or
    Pixtral-HF) and delegates to the matching encoder-specific helpers.

    Args:
        ctx: Input context used to look up the HF `LlavaConfig`.
        seq_len: Forwarded to the encoder-specific dummy sequence builder.
        mm_counts: Per-modality item counts; only the `"image"` entry is
            read.

    Returns:
        A `(seq_data, mm_data)` tuple produced by the encoder-specific
        helpers.

    Raises:
        NotImplementedError: If the vision config type is unsupported.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    # Use the model-level token count so that the feature select strategy
    # is taken into account rather than the encoder's own default.
    image_feature_size = get_max_llava_image_tokens(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(vision_config, num_images)
        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_siglip(vision_config, num_images)
        return seq_data, mm_data
    elif isinstance(vision_config, PixtralVisionConfig):
        seq_data = dummy_seq_data_for_pixtral_hf(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


151
152
def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Run the encoder-specific input processor for image inputs.

    Inputs without image data are returned unchanged. Otherwise the
    per-image feature size is derived from the image data and the call is
    delegated to the processor matching the vision config type.

    Args:
        ctx: Input context providing the model config and HF `LlavaConfig`.
        inputs: Decoder-only inputs; `multi_modal_data["image"]` may be a
            PIL image, a list of PIL images, a tensor, or a list of tensors.

    Returns:
        `inputs` itself when there is no image data, otherwise the result of
        the encoder-specific input processor.

    Raises:
        TypeError: If the image data has an unsupported type.
        NotImplementedError: If the vision config type is unsupported.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        image_feature_size = get_max_llava_image_tokens(ctx)
    elif is_list_of(image_data, Image.Image):
        image_feature_size = [get_max_llava_image_tokens(ctx)
                              ] * len(image_data)
    elif isinstance(image_data, torch.Tensor):
        # Tensor input must be 3-D (the unpacking enforces this); only the
        # middle dimension, the per-image feature size, is needed here.
        _, image_feature_size, _ = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, PixtralVisionConfig):
        # We ignore image_feature_size_override since we have non-uniform
        # image sizes for Pixtral
        return input_processor_for_pixtral_hf(
            model_config,
            vision_config,
            inputs,
            image_token_id=hf_config.image_token_index,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


203
204
205
206
207
208
209
210
211
212
213
class LlavaLikeConfig(Protocol):
    """Structural type for configs read by `init_vision_tower_for_llava`."""
    # Vision encoder config; its concrete type selects the vision model.
    vision_config: PretrainedConfig
    # Index of the vision layer used for features; may be negative, in
    # which case it is counted from the end of the encoder's layers.
    vision_feature_layer: int


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
):
    """Create the vision tower matching `hf_config.vision_config`.

    The tower is built only up to the layer selected by
    `hf_config.vision_feature_layer`.

    Args:
        hf_config: Config exposing `vision_config` and
            `vision_feature_layer`.
        quant_config: Quantization config forwarded to the vision model.
        require_post_norm: Forwarded unchanged to the vision model.

    Returns:
        A `CLIPVisionModel`, `SiglipVisionModel` or `PixtralHFVisionModel`.

    Raises:
        NotImplementedError: If the vision config type is unsupported.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        # Negative indices count from the end, e.g. -1 keeps every layer.
        num_hidden_layers = (vision_config.num_hidden_layers +
                             vision_feature_layer + 1)
    else:
        num_hidden_layers = vision_feature_layer + 1

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
        )
    elif isinstance(vision_config, PixtralVisionConfig):
        return PixtralHFVisionModel(
            vision_config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


250
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA model: vision tower, multi-modal projector and language model.

    Image inputs are encoded by the vision tower, projected to the text
    hidden size, and merged into the text embeddings at the positions of the
    image placeholder token before the language model is run.
    """

    def __init__(self,
                 config: LlavaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        """Build the vision tower, the projector and the language model.

        Note that `config` may be mutated in place to fill in missing
        fields (see the Pixtral special cases below).
        """
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config, quant_config, require_post_norm=False)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Return the language model's sampler, or a new `Sampler`."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return Sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that `data` has shape `(batch_size, 3, h, w)`.

        `h` and `w` both equal `config.vision_config.image_size`.

        Raises:
            ValueError: If the trailing dimensions do not match.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _validate_image_sizes(
        self,
        images: List[torch.Tensor],
        sizes: Union[torch.Tensor, List[torch.Tensor]],
    ) -> List[torch.Tensor]:
        """Check each image against its expected `(height, width)`.

        Args:
            images: Flat list of image tensors whose last three dimensions
                are `(channels, height, width)`.
            sizes: A tensor, or list of tensors, that flattens to
                `(height, width)` pairs in the same order as `images`.

        Returns:
            `images`, unchanged.

        Raises:
            ValueError: On a count, size or channel mismatch.
        """
        if not isinstance(sizes, list):
            sizes = [sizes]

        total_images = sum(size.numel() // 2 for size in sizes)
        if total_images != len(images):
            raise ValueError("Mismatch in number of images. "
                             f"Expected {total_images}, got {len(images)}")

        # The count check above guarantees `img_idx` stays within bounds.
        img_idx = 0
        for size in sizes:
            # Flatten the size tensor to a list of (height, width) pairs
            for expected_h, expected_w in size.view(-1, 2).tolist():
                img = images[img_idx]
                if img.shape[-2:] != (expected_h, expected_w):
                    raise ValueError(
                        "Image size mismatch. Expected "
                        f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
                if img.shape[-3] != 3:
                    raise ValueError("Image channel mismatch. Expected 3, "
                                     f"got {img.shape[-3]}")
                img_idx += 1
        return images

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract and validate image inputs from the forward kwargs.

        Reads the `pixel_values`, `image_sizes` and `image_embeds` keys.

        Returns:
            `None` when neither pixel values nor embeddings are given,
            otherwise the matching `LlavaImageInputs` variant. Pixel values
            take precedence over embeddings when both are present.

        Raises:
            ValueError: If an input has an unexpected type or shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Case for models like PixtralHF that have dynamic image sizes
            # so we need to produce a list of tensors
            if image_sizes is not None:

                def flatten_to_3d_tensors(item):
                    if isinstance(item, torch.Tensor):
                        if item.dim() >= 3:
                            # Collapse all leading dims into one and split
                            # along it, yielding (C, H, W) tensors.
                            return list(item.view(-1, *item.shape[-3:]))
                        else:
                            raise ValueError(
                                f"Unexpected tensor dimension: {item.dim()}")
                    elif isinstance(item, list):
                        return [
                            t for subitem in item
                            for t in flatten_to_3d_tensors(subitem)
                        ]
                    else:
                        raise ValueError(f"Unexpected type: {type(item)}")

                # Flatten the (possibly nested) batched images into a flat
                # list of 3-D image tensors
                images = flatten_to_3d_tensors(pixel_values)

                return LlavaImagePixelInputs(
                    type="pixel_values",
                    data=self._validate_image_sizes(images, image_sizes),
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Select features per strategy.

        `"default"` drops the first entry along dim 1; `"full"` keeps all.

        Raises:
            ValueError: If `strategy` is not recognised.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: Union[torch.Tensor, List[torch.Tensor]],
    ) -> torch.Tensor:
        """Run the vision tower and apply the feature select strategy."""

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Encode the pixel values in `inputs` with the vision tower."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Turn an image input into embeddings for the language model.

        Precomputed embeddings are returned as-is (the projector is not
        applied); pixel values are encoded and then projected.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Forwarded unchanged to the language model.
            kv_caches: Forwarded unchanged to the language model.
            attn_metadata: Forwarded unchanged to the language model.
            intermediate_tensors: When given, no input ids or embeddings are
                passed to the language model; these tensors are forwarded
                instead.
            **kwargs: May carry `pixel_values` (the pixels in each input
                image), `image_sizes` and `image_embeds`.

        Returns:
            Whatever the wrapped language model returns.

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            # always pass the input via `inputs_embeds`
            # to make sure the computation graph is consistent
            image_input = self._parse_and_validate_image_input(**kwargs)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            if image_input is not None:
                vision_embeddings = self._process_image_input(image_input)
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, vision_embeddings,
                    self.config.image_token_index)

            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load `(name, tensor)` weight pairs via `AutoWeightsLoader`."""
        loader = AutoWeightsLoader(self)
        loader.load_weights(weights)