llava.py 15.3 KB
Newer Older
1
import itertools
2
3
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)
4
5

import torch
6
import torch.nn as nn
7
from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
8
9

from vllm.attention import AttentionMetadata
10
from vllm.config import CacheConfig, MultiModalConfig
11
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
12
from vllm.model_executor.layers.activation import get_act_fn
13
from vllm.model_executor.layers.quantization import QuantizationConfig
14
from vllm.model_executor.layers.sampler import SamplerOutput
15
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
16
from vllm.model_executor.sampling_metadata import SamplingMetadata
17
from vllm.multimodal import MULTIMODAL_REGISTRY
18
from vllm.sequence import IntermediateTensors
19

20
21
22
from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_max_clip_image_tokens,
                   input_processor_for_clip)
23
from .interfaces import SupportsMultiModal
24
25
26
27
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                     input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
28
                    merge_multimodal_embeddings)
29
30


31
32
33
class LlavaImagePixelInputs(TypedDict):
    """Image input supplied as raw pixel values.

    The ``type`` field is the literal tag ``"pixel_values"``.
    """

    type: Literal["pixel_values"]

    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
35
36
37
38
39


class LlavaImageEmbeddingInputs(TypedDict):
    """Image input supplied as already-computed embeddings.

    The ``type`` field is the literal tag ``"image_embeds"``.
    """

    type: Literal["image_embeds"]

    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Either form of image input accepted by the model; the members are told
# apart by the literal value of their "type" field.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision features to the text hidden size.

    Applies ``linear_1`` -> activation -> ``linear_2``. The last dimension
    goes from ``vision_hidden_size`` to ``text_hidden_size``.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        """Build the projector.

        Args:
            vision_hidden_size: Size of the incoming feature dimension.
            text_hidden_size: Size of the outgoing feature dimension.
            projector_hidden_act: Activation name passed to ``get_act_fn``.
        """
        super().__init__()

        # The attribute names (and their creation order) define the
        # parameter names exposed by ``named_parameters()``, so they are
        # kept exactly as-is.
        self.linear_1 = nn.Linear(in_features=vision_hidden_size,
                                  out_features=text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(in_features=text_hidden_size,
                                  out_features=text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` through both linear layers."""
        projected = self.linear_1(image_features)
        return self.linear_2(self.act(projected))


71
72
73
74
75
def get_max_llava_image_tokens(ctx: InputContext) -> int:
    """Return the maximum number of image tokens for one image.

    The count comes from the vision encoder helper matching the configured
    vision backbone, then is adjusted for the feature select strategy:
    ``"full"`` keeps every token while ``"default"`` keeps one fewer
    (matching ``_select_image_features``, which drops the first feature).

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
        ValueError: If the select strategy is not ``"default"`` or ``"full"``.
    """
    llava_config = ctx.get_hf_config(LlavaConfig)
    vision_config = llava_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        encoder_tokens = get_max_clip_image_tokens(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        encoder_tokens = get_max_siglip_image_tokens(vision_config)
    else:
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    strategy = llava_config.vision_feature_select_strategy
    if strategy not in ("default", "full"):
        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    return encoder_tokens - 1 if strategy == "default" else encoder_tokens
90
91


92
93
def dummy_data_for_llava(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    """Build dummy sequence data and dummy image data for LLaVA.

    Args:
        ctx: Input context used to look up the ``LlavaConfig``.
        seq_len: Sequence length forwarded to the dummy sequence builder.
        mm_counts: Per-modality item counts; only ``"image"`` is read.

    Returns:
        A ``(seq_data, mm_data)`` tuple produced by the CLIP or SigLIP dummy
        helpers, depending on the vision config type.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    llava_config = ctx.get_hf_config(LlavaConfig)
    vision_config = llava_config.vision_config
    num_images = mm_counts["image"]

    # Use the LLaVA-level token count (which accounts for the feature select
    # strategy) rather than the raw encoder count.
    image_feature_size = get_max_llava_image_tokens(ctx)

    # Both backbones take identical arguments, so pick the helper pair first.
    if isinstance(vision_config, CLIPVisionConfig):
        build_seq_data = dummy_seq_data_for_clip
        build_image = dummy_image_for_clip
    elif isinstance(vision_config, SiglipVisionConfig):
        build_seq_data = dummy_seq_data_for_siglip
        build_image = dummy_image_for_siglip
    else:
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    seq_data = build_seq_data(
        vision_config,
        seq_len,
        num_images,
        image_token_id=llava_config.image_token_index,
        image_feature_size_override=image_feature_size,
    )
    mm_data = build_image(vision_config, num_images)
    return seq_data, mm_data


127
128
129
130
131
132
133
134
135
def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
    """Run the vision-backbone input processor on inputs that carry an image.

    Inputs without ``multi_modal_data`` or without an ``"image"`` entry are
    returned unchanged. Otherwise the CLIP or SigLIP input processor is
    called with LLaVA's image token id and image feature size.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    mm_data = llm_inputs.get("multi_modal_data")
    has_image = mm_data is not None and "image" in mm_data
    if not has_image:
        return llm_inputs

    model_config = ctx.model_config
    llava_config = ctx.get_hf_config(LlavaConfig)
    vision_config = llava_config.vision_config

    # Use the LLaVA-level token count (which accounts for the feature select
    # strategy) rather than the raw encoder count.
    image_feature_size = get_max_llava_image_tokens(ctx)

    # Both backbones take identical arguments, so pick the processor first.
    if isinstance(vision_config, CLIPVisionConfig):
        process_inputs = input_processor_for_clip
    elif isinstance(vision_config, SiglipVisionConfig):
        process_inputs = input_processor_for_siglip
    else:
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    return process_inputs(
        model_config,
        vision_config,
        llm_inputs,
        image_token_id=llava_config.image_token_index,
        image_feature_size_override=image_feature_size,
    )


def _init_vision_tower(hf_config: LlavaConfig):
    """Create the vision tower, truncated at the configured feature layer.

    ``hf_config.vision_feature_layer`` is turned into a layer count: a
    non-negative value ``i`` keeps ``i + 1`` layers, while a negative value
    counts from the end (e.g. ``-2`` with 24 hidden layers keeps 23).

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    feature_layer = hf_config.vision_feature_layer
    if feature_layer >= 0:
        layers_to_build = feature_layer + 1
    else:
        layers_to_build = (vision_config.num_hidden_layers + feature_layer +
                           1)

    if isinstance(vision_config, CLIPVisionConfig):
        tower_cls = CLIPVisionModel
    elif isinstance(vision_config, SiglipVisionConfig):
        tower_cls = SiglipVisionModel
    else:
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    return tower_cls(
        vision_config,
        num_hidden_layers_override=layers_to_build,
    )


185
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
    """LLaVA model: vision tower + MLP projector + language model backbone.

    Image inputs arrive through ``forward``'s keyword arguments, either as
    ``pixel_values`` (run through the vision tower and the projector) or as
    ``image_embeds`` (used as-is), and are merged into the text embeddings
    at the positions of the image placeholder token.
    """

    def __init__(self,
                 config: LlavaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        """Build the three sub-modules from ``config``.

        Args:
            config: HF LLaVA config; its ``vision_config`` and
                ``text_config`` size the sub-modules.
            multimodal_config: Stored on the instance as-is.
            cache_config: Forwarded to the language model initializer.
            quant_config: Forwarded to the language model initializer.
        """
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every dim after the first is ``(3, H, W)``.

        ``H`` and ``W`` both equal ``config.vision_config.image_size``.

        Returns:
            ``data`` unchanged when the shape is valid.

        Raises:
            ValueError: If the trailing dimensions do not match.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract the image input (if any) from ``forward``'s kwargs.

        ``pixel_values`` takes precedence over ``image_embeds`` when both
        are given.

        Returns:
            ``None`` when neither key is present, otherwise the matching
            tagged input dict.

        Raises:
            ValueError: If the supplied value is not a ``torch.Tensor``, or
                if the pixel values have an unexpected shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Remove the N dimension until multiple images are supported.
            pixel_values = pixel_values.squeeze(1)

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            # Remove the N dimension until multiple images are supported.
            image_embeds = image_embeds.squeeze(1)

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the feature select strategy to the vision tower output.

        ``"default"`` drops the first entry along dim 1; ``"full"`` keeps
        everything.

        Raises:
            ValueError: If ``strategy`` is neither of the above.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run ``vision_tower`` on the pixels and select the features."""

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Turn pixel inputs into (unprojected) vision features."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Return embeddings ready to merge into the text embeddings.

        Precomputed embeddings are returned untouched; pixel inputs go
        through the vision tower and then the multi-modal projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        # `_process_image_pixels` asserts that the vision tower exists.
        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Passed through unchanged to the language model.
            kv_caches: Passed through unchanged to the language model.
            attn_metadata: Passed through unchanged to the language model.
            intermediate_tensors: Accepted for interface compatibility but
                currently not forwarded (see the note in the body).
            **kwargs: May carry the image input under one of two keys:
                `pixel_values` (the pixels in each input image) or
                `image_embeds` (precomputed image embeddings).

        Returns:
            The hidden states produced by the language model backbone.

        See also:
            :class:`LlavaImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            # The merged embeddings are passed instead of the token ids.
            input_ids = None
        else:
            inputs_embeds = None

        # NOTE(review): `intermediate_tensors` is not forwarded here; `None`
        # is passed in its position. Confirm this is intended before relying
        # on that argument.
        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into the three sub-modules.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs for the whole model.
        """
        # prepare weight iterators for components
        # (`weights` may be a one-shot iterator, so it is tee'd to give each
        # component its own pass over the same stream)
        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_tower")
        self.vision_tower.load_weights(vit_weights)

        # load mlp projector
        # NOTE(review): the lookup below uses names from `named_parameters()`
        # of the projector, so `filter_weights` presumably strips the
        # "multi_modal_projector" prefix - confirm in `.utils`.
        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)