"vscode:/vscode.git/clone" did not exist on "695e7adcd22c25b859a6d4b3af99617aaf425708"
llava.py 15.4 KB
Newer Older
1
from functools import cached_property
2
3
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)
4
5

import torch
6
import torch.nn as nn
7
from PIL import Image
8
from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
9
10

from vllm.attention import AttentionMetadata
11
from vllm.config import CacheConfig, MultiModalConfig
12
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
13
from vllm.model_executor.layers.activation import get_act_fn
14
from vllm.model_executor.layers.quantization import QuantizationConfig
15
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
16
from vllm.model_executor.sampling_metadata import SamplingMetadata
17
from vllm.multimodal import MULTIMODAL_REGISTRY
18
from vllm.sequence import IntermediateTensors
19
from vllm.utils import is_list_of
20

21
22
23
from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_max_clip_image_tokens,
                   input_processor_for_clip)
24
from .interfaces import SupportsMultiModal, SupportsPP
25
26
27
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                     input_processor_for_siglip)
28
29
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                    merge_multimodal_embeddings)
30
31


32
33
34
class LlavaImagePixelInputs(TypedDict):
    """Image input given as raw pixels, to be run through the vision tower."""

    # Discriminator tag for this variant of `LlavaImageInputs`.
    type: Literal["pixel_values"]
    # Shape: `(batch_size * num_images, num_channels, height, width)`
    data: torch.Tensor


class LlavaImageEmbeddingInputs(TypedDict):
    """Image input given as embeddings that the model uses directly."""

    # Discriminator tag for this variant of `LlavaImageInputs`.
    type: Literal["image_embeds"]
    # Shape: `(batch_size * num_images, image_feature_size, hidden_size)`.
    # `hidden_size` must match the hidden size of language model backbone.
    data: torch.Tensor


# Tagged union of the two accepted image-input forms; branch on `["type"]`.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision features into the text hidden space.

    Computes ``linear_2(act(linear_1(x)))``. ``linear_1`` maps
    ``vision_hidden_size -> text_hidden_size`` and ``linear_2`` maps
    ``text_hidden_size -> text_hidden_size``; both layers have a bias.

    Args:
        vision_hidden_size: Feature size produced by the vision tower.
        text_hidden_size: Hidden size of the language model.
        projector_hidden_act: Activation name, resolved via ``get_act_fn``.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        # Attribute names are left untouched; presumably they must line up
        # with checkpoint parameter names for weight loading — TODO confirm.
        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` (last dim = vision hidden size)."""
        expanded = self.linear_1(image_features)
        activated = self.act(expanded)
        return self.linear_2(activated)


72
73
74
75
76
def get_max_llava_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens per image given to the LM.

    The count starts from the vision encoder's maximum token count (CLIP or
    SigLIP, chosen by the type of ``vision_config``). It is reduced by one when
    ``vision_feature_select_strategy`` is ``"default"``, matching the first
    feature that ``"default"`` selection discards in the model. It is left
    as-is for ``"full"``.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
        ValueError: If the feature-select strategy is not ``"default"`` or
            ``"full"``.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        encoder_tokens = get_max_clip_image_tokens(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        encoder_tokens = get_max_siglip_image_tokens(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy not in ("default", "full"):
        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    # "default" drops the first encoder token; "full" keeps every token.
    return encoder_tokens - 1 if strategy == "default" else encoder_tokens
91
92


93
94
def dummy_data_for_llava(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    """Build dummy sequence data and dummy images for LLaVA.

    Creates ``mm_counts["image"]`` dummy images with the CLIP or SigLIP helper
    that matches the vision config type. The placeholder token is the
    config's ``image_token_index``. The feature-size override passed to the
    sequence builder is ``get_max_llava_image_tokens(ctx)``.

    Returns:
        A ``(seq_data, mm_data)`` pair.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size = get_max_llava_image_tokens(ctx)

    # Pick the encoder-specific builders once, then call them uniformly.
    if isinstance(vision_config, CLIPVisionConfig):
        seq_builder, image_builder = (dummy_seq_data_for_clip,
                                      dummy_image_for_clip)
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_builder, image_builder = (dummy_seq_data_for_siglip,
                                      dummy_image_for_siglip)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    seq_data = seq_builder(
        vision_config,
        seq_len,
        num_images,
        image_token_id=hf_config.image_token_index,
        image_feature_size_override=image_feature_size,
    )
    mm_data = image_builder(vision_config, num_images)
    return seq_data, mm_data


128
129
130
131
132
133
134
135
136
def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
    """Insert image placeholder tokens into ``llm_inputs`` for LLaVA.

    If ``multi_modal_data`` is missing or has no ``"image"`` entry, the
    inputs are returned unchanged. Otherwise the image feature size is
    derived from the image data and the work is handed to the CLIP or SigLIP
    input processor.

    Accepted ``multi_modal_data["image"]`` forms:
      * one PIL image -> ``get_max_llava_image_tokens(ctx)``;
      * a list of PIL images -> that same value, once per image;
      * a 3-D tensor -> the size of its second dimension;
      * a list of tensors -> each tensor's second-dimension size.

    Raises:
        TypeError: For any other image data type.
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        image_feature_size = get_max_llava_image_tokens(ctx)
    elif is_list_of(image_data, Image.Image):
        tokens_per_image = get_max_llava_image_tokens(ctx)
        image_feature_size = [tokens_per_image] * len(image_data)
    elif isinstance(image_data, torch.Tensor):
        # Unpacking (not indexing) keeps the check that the tensor is 3-D.
        _, image_feature_size, _ = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        encoder_processor = input_processor_for_clip
    elif isinstance(vision_config, SiglipVisionConfig):
        encoder_processor = input_processor_for_siglip
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return encoder_processor(
        model_config,
        vision_config,
        llm_inputs,
        image_token_id=hf_config.image_token_index,
        image_feature_size_override=image_feature_size,
    )


def _init_vision_tower(hf_config: LlavaConfig):
    """Build the vision encoder, truncated at ``vision_feature_layer``.

    A negative ``vision_feature_layer`` counts back from the last encoder
    layer, so ``-1`` keeps all ``num_hidden_layers``. A non-negative value is
    a 0-based layer index. In both cases the encoder keeps that layer and
    every layer before it, so the tower's output is the selected feature
    layer.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    vision_config = hf_config.vision_config

    feature_layer = hf_config.vision_feature_layer
    if feature_layer < 0:
        # Turn a Python-style negative index into an absolute layer index.
        feature_layer += vision_config.num_hidden_layers
    num_hidden_layers = feature_layer + 1

    if isinstance(vision_config, CLIPVisionConfig):
        tower_cls = CLIPVisionModel
    elif isinstance(vision_config, SiglipVisionConfig):
        tower_cls = SiglipVisionModel
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return tower_cls(
        vision_config,
        num_hidden_layers_override=num_hidden_layers,
    )


197
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA model: vision tower + MLP projector + vLLM language model.

    Image features from the (truncated) vision tower are projected to the
    language model's hidden size. They are then merged into the token
    embeddings at the positions of ``config.image_token_index``.
    """

    def __init__(self,
                 config: LlavaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # Submodules are registered in this fixed order:
        # vision tower, projector, language model.
        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """The language model's own sampler if it has one, else a new one."""
        return (self.language_model.sampler
                if hasattr(self.language_model, "sampler") else Sampler())

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that ``data`` is ``(batch, 3, image_size, image_size)``.

        Returns ``data`` unchanged; raises ``ValueError`` on a mismatch of
        any dimension after the first.
        """
        side = self.config.vision_config.image_size
        expected_dims = (3, side, side)
        if tuple(data.shape[1:]) == expected_dims:
            return data

        expected_expr = ("batch_size", *map(str, expected_dims))
        raise ValueError(
            f"The expected shape of pixel values is {expected_expr}. "
            f"You supplied {tuple(data.shape)}.")

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract image input from ``kwargs``.

        ``pixel_values`` takes precedence over ``image_embeds``. Returns
        ``None`` when neither is given. Both accept a tensor or a list, which
        is flattened/concatenated with ``flatten_bn``.

        Raises:
            ValueError: On a non-tensor, non-list input or bad pixel shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
            flat_pixels = flatten_bn(pixel_values, concat=True)
            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flat_pixels),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        return None

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the configured feature-select strategy.

        ``"default"`` drops the first token along dim 1; ``"full"`` keeps
        every token. Any other strategy raises ``ValueError``.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "full":
            return image_features
        if strategy == "default":
            return image_features[:, 1:]

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode pixels with ``vision_tower`` and select feature tokens."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        tower_output = vision_tower(pixel_values)
        return self._select_image_features(
            tower_output,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Run the model's vision tower over pixel-value inputs."""
        assert self.vision_tower is not None
        return self._image_pixels_to_features(self.vision_tower,
                                              inputs["data"])

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Turn an image input into embeddings for the language model.

        Embedding inputs are returned as-is, with no projector applied.
        Pixel inputs go through the vision tower and then the projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        tower_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(tower_features)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        ``input_ids`` already contains one placeholder token for every image
        embedding that will be spliced in. As a result ``positions`` and
        ``attn_metadata`` already match the final sequence.

        For example, the prompt
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`
        tokenizes to
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
        To reserve KV-cache space, the input processor repeats the image
        token (`32000`) before the model ever sees the prompt:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.
        Adding 575 copies gives 576 (24 * 24) image tokens in total. That is
        the number of image tokens the visual encoder outputs and the
        language model consumes.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Forwarded unchanged to the language model.
            kv_caches: Forwarded unchanged to the language model.
            attn_metadata: Forwarded unchanged to the language model.
            intermediate_tensors: If not None, ``input_ids`` is discarded and
                no image input is parsed from ``kwargs``.
            **kwargs: May carry ``pixel_values`` or ``image_embeds``.

        See also:
            :class:`LlavaImageInputs`
        """
        inputs_embeds = None
        if intermediate_tensors is not None:
            # Presumably a later pipeline-parallel stage: the hidden states
            # come in via `intermediate_tensors` — TODO confirm.
            input_ids = None
        else:
            image_input = self._parse_and_validate_image_input(**kwargs)
            if image_input is not None:
                vision_embeddings = self._process_image_input(image_input)
                token_embeds = self.language_model.model.get_input_embeddings(
                    input_ids)
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, token_embeds, vision_embeddings,
                    self.config.image_token_index)
                # The merged embeddings now fully describe the input.
                input_ids = None

        return self.language_model.model(input_ids,
                                         positions,
                                         kv_caches,
                                         attn_metadata,
                                         intermediate_tensors,
                                         inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` pairs via ``AutoWeightsLoader``."""
        AutoWeightsLoader(self).load_weights(weights)