# llava_next.py (22.7 KB) - Newer / Older
1
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
2
3
4

import torch
import torch.nn as nn
5
from PIL import Image
6
from transformers import CLIPVisionConfig, LlavaNextConfig
7
8
9
10
11
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
12
from vllm.config import CacheConfig, MultiModalConfig
13
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
14
15
16
17
18
19
20
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
21
from vllm.model_executor.models.clip import CLIPVisionModel
22
23
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
24
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
25
from vllm.sequence import IntermediateTensors, SamplerOutput
26

27
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
28
                   get_clip_patch_grid_length, input_processor_for_clip)
29
from .interfaces import SupportsVision
30
31
from .llava import LlavaMultiModalProjector
from .utils import merge_vision_embeddings
32
33
34
35
36
37
38
39

# Logger scoped to this module's import path.
logger = init_logger(__name__)

_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}

40
41
42
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448

43
44
45

class LlavaNextImagePixelInputs(TypedDict):
    """Pixel-value image inputs consumed by
    :class:`LlavaNextForConditionalGeneration`."""

    type: Literal["pixel_values"]

    data: BatchedTensors
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.
    """

    image_sizes: NotRequired[torch.Tensor]
    """
    Shape: `(batch_size, 2)`

    This should be in `(height, width)` format.
    """


# Pixel values are currently the only supported image input type.
LlavaNextImageInputs = LlavaNextImagePixelInputs


# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
# NOTE: new_height and new_width are further incremented to properly invert the
# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def _get_llava_next_num_unpadded_features(
    height: int,
    width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> Tuple[int, int]:
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    aspect_ratio: float = width / height
    current_aspect_ratio: float = current_width / current_height
    if aspect_ratio > current_aspect_ratio:
        new_height = (height * current_width) // width
82
83
        if new_height % 2 == 1:
            new_height += 1
84
85
86
        current_height = new_height
    else:
        new_width = (width * current_height) // height
87
88
        if new_width % 2 == 1:
            new_width += 1
89
90
91
92
93
94
95
        current_width = new_width

    unpadded_features = current_height * current_width
    newline_features = current_height
    return (unpadded_features, newline_features)


96
97
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111
def get_llava_next_image_feature_size(
    hf_config: LlavaNextConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    """Return the number of image placeholder tokens for one input image.

    The total is the base tile's features (``num_patches ** 2``), plus the
    un-padded high-resolution grid features, plus one newline feature per
    grid row.

    Args:
        hf_config: The LLaVA-NeXT HuggingFace config.
        input_height: Original image height in pixels.
        input_width: Original image width in pixels.

    Raises:
        NotImplementedError: If the vision tower config is not CLIP.
    """
    vision_config = hf_config.vision_config

    if not isinstance(vision_config, CLIPVisionConfig):
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    num_patches = get_clip_patch_grid_length(
        image_size=vision_config.image_size,
        patch_size=vision_config.patch_size,
    )
    base_feature_size = num_patches * num_patches

    # Note: We follow the "wrong" width/height order
    # [ref: PR huggingface/transformers#31588]
    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
        image_size=(input_height, input_width),
        grid_pinpoints=hf_config.image_grid_pinpoints,
        patch_size=vision_config.image_size,
    )

    (unpadded_feature_size,
     newline_feature_size) = _get_llava_next_num_unpadded_features(
         input_height, input_width, num_patches, num_patch_height,
         num_patch_width)

    return unpadded_feature_size + newline_feature_size + base_feature_size


def get_max_llava_next_image_tokens(ctx: InputContext):
    """Return the image token count used as the per-image maximum.

    It is evaluated at the ``MAX_IMAGE_FEATURE_SIZE_HEIGHT`` x
    ``MAX_IMAGE_FEATURE_SIZE_WIDTH`` dummy resolution.
    """
    return get_llava_next_image_feature_size(
        ctx.get_hf_config(LlavaNextConfig),
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )


def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
    """Build the dummy sequence and image data registered with
    ``INPUT_REGISTRY``.

    The dummy image has the ``MAX_IMAGE_FEATURE_SIZE_*`` resolution, and the
    dummy sequence reserves the matching number of image tokens.

    Returns:
        A ``(seq_data, mm_data)`` tuple.

    Raises:
        NotImplementedError: If the vision tower config is not CLIP.
    """
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    if not isinstance(vision_config, CLIPVisionConfig):
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    image_feature_size = get_max_llava_next_image_tokens(ctx)

    seq_data = dummy_seq_data_for_clip(
        vision_config,
        seq_len,
        image_token_id=hf_config.image_token_index,
        image_feature_size_override=image_feature_size,
    )

    mm_data = dummy_image_for_clip(
        vision_config,
        image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )

    return seq_data, mm_data


def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
    """Prepare prompt inputs that carry an image.

    Inputs without image data are returned unchanged. Otherwise, the image
    feature size is computed from the image's original resolution and passed
    to the CLIP input processor.

    Raises:
        NotImplementedError: For tensor (embedding) image inputs, or if the
            vision tower config is not CLIP.
        TypeError: For any other image data type.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        # PIL reports size as (width, height).
        width, height = image_data.size

        image_feature_size = get_llava_next_image_feature_size(
            hf_config,
            input_height=height,
            input_width=width,
        )
    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
    """LLaVA-NeXT: a CLIP vision tower and multimodal projector whose image
    embeddings are merged into the input embeddings of a Llama model."""

    def __init__(self,
                 config: LlavaNextConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = CLIPVisionModel(config=config.vision_config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.quant_config = quant_config
        self.language_model = LlamaModel(config.text_config, cache_config,
                                         quant_config)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size,
            quant_config=quant_config)
        logit_scale = getattr(config, "logit_scale", 1.0)
        # The vocabulary belongs to the text model; read it from the text
        # config, consistent with `unpadded_vocab_size` above.
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.text_config.vocab_size,
            logit_scale)
        self.sampler = Sampler()

        # Learned embedding appended after every row of grid features when
        # an "unpad" merge strategy is used.
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check that ``data`` has shape ``(batch_size, 2)``."""
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected image sizes shape is batch dimension plus "
                f"{[2]}. You supplied {data.shape}.")

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check that each batch element's trailing dimensions are
        ``(3, image_size, image_size)``; the leading patch count may vary."""
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each batch element "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
        """Extract and validate ``pixel_values``/``image_sizes`` from kwargs.

        Returns ``None`` when no pixel values were passed.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        if not isinstance(image_sizes, torch.Tensor):
            raise ValueError("Incorrect type of image sizes. "
                             f"Got type: {type(image_sizes)}")

        return LlavaNextImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
            image_sizes=self._validate_image_sizes(image_sizes),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Drop the first token (``"default"``) or keep all (``"full"``)."""
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the vision tower and apply the feature-select strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values,
                                      self.config.vision_feature_layer)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        """Merge the per-tile embeddings of one image into one sequence.

        Args:
            image_size: Original ``(height, width)`` of the image.
            patch_embeddings: Projected embeddings of the base tile (index 0)
                followed by the grid tiles.
            strategy: ``"flat"``, or a strategy starting with ``"spatial"``;
                strategies containing ``"unpad"`` crop the grid padding and
                append ``image_newline`` after each row.

        Raises:
            ValueError: On an unknown strategy, or if the base tile's feature
                count does not match the vision config.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if not strategy.startswith("spatial"):
            raise ValueError(f"Unexpected patch merge strategy: {strategy}")

        height = width = (self.config.vision_config.image_size //
                          self.config.vision_config.patch_size)

        base_patch_embeds = patch_embeddings[0]
        if height * width != base_patch_embeds.shape[0]:
            raise ValueError(
                "The number of patches is not consistent with the "
                "image size.")

        if patch_embeddings.shape[0] > 1:
            other_patch_embeds = patch_embeddings[1:]

            # image_aspect_ratio == "anyres"
            # Note: We follow the "wrong" width/height order
            # [ref: PR huggingface/transformers#31588]
            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                image_size,
                self.config.image_grid_pinpoints,
                self.config.vision_config.image_size,
            )
            other_patch_embeds = other_patch_embeds.view(
                num_patch_height, num_patch_width, height, width, -1)

            if "unpad" in strategy:
                # (grid_h, grid_w, h, w, C) -> (C, grid_h * h, grid_w * w)
                other_patch_embeds = other_patch_embeds.permute(
                    4, 0, 2, 1, 3).contiguous().flatten(1, 2).flatten(2, 3)
                other_patch_embeds = unpad_image(other_patch_embeds,
                                                 image_size)
                # Append one newline embedding at the end of every row.
                newline = self.image_newline[:, None, None].expand(
                    *other_patch_embeds.shape[:-1], 1).to(
                        other_patch_embeds.device)
                other_patch_embeds = torch.cat((other_patch_embeds, newline),
                                               dim=-1)
                # (C, rows, cols + 1) -> (rows * (cols + 1), C)
                other_patch_embeds = other_patch_embeds.flatten(
                    1, 2).transpose(0, 1)
            else:
                # (grid_h, grid_w, h, w, C) -> (grid_h * h * grid_w * w, C)
                other_patch_embeds = other_patch_embeds.permute(
                    0, 2, 1, 3, 4).contiguous().flatten(0, 3)

            return torch.cat((base_patch_embeds, other_patch_embeds), dim=0)

        if "unpad" in strategy:
            return torch.cat(
                (base_patch_embeds,
                 self.image_newline[None].to(base_patch_embeds.device)),
                dim=0)

        return base_patch_embeds

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> BatchedTensors:
        """Encode and project every tile of every image.

        A batched tensor input yields a tensor shaped
        ``(batch, num_patches, ...)``; a list input (varying patch counts)
        yields a list with one tensor per image.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        # Run all tiles through the vision tower in one call, then split
        # the features back per image before projecting.
        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return [
            self.multi_modal_projector(image_features) for image_features in
            torch.split(stacked_image_features, num_patches_per_batch)
        ]

    def _process_image_input(
            self, image_input: LlavaNextImageInputs) -> BatchedTensors:
        """Return one merged embedding sequence per input image."""
        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            batch_size = len(image_input["data"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run forward pass for LLaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor expands
        the image token (denoted as `32000`) in place into as many image
        tokens as the image has features, resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by :func:`get_llava_next_image_feature_size`.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            intermediate_tensors: Forwarded unchanged to the language model.
            pixel_values: The pixels in each grid patch for each input image.
            image_sizes: The original `(height, width)` for each input image.

        Returns:
            The hidden states produced by the language model.

        See also:
            :class:`LlavaNextImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            # Overwrite the image-token positions with the image embeddings.
            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            input_ids = None
        else:
            inputs_embeds = None

        # Forward `intermediate_tensors` rather than dropping it, so the
        # language model receives what the caller passed.
        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            intermediate_tensors,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into this model's parameters.

        Names are rewritten with ``_KEYS_TO_MODIFY_MAPPING``. Language-model
        q/k/v and gate/up projections are loaded shard by shard into their
        fused parameters; all other weights use the parameter's
        ``weight_loader`` (or ``default_weight_loader``).
        """
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)