llava.py 15.2 KB
Newer Older
1
import itertools
2
3
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)
4
5

import torch
6
import torch.nn as nn
7
from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
8
9

from vllm.attention import AttentionMetadata
10
from vllm.config import CacheConfig, MultiModalConfig
11
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
12
from vllm.model_executor.layers.activation import get_act_fn
13
from vllm.model_executor.layers.quantization import QuantizationConfig
14
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
15
from vllm.model_executor.sampling_metadata import SamplingMetadata
16
from vllm.multimodal import MULTIMODAL_REGISTRY
17
from vllm.sequence import IntermediateTensors, SamplerOutput
18

19
20
21
from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_max_clip_image_tokens,
                   input_processor_for_clip)
22
from .interfaces import SupportsMultiModal
23
24
25
26
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                     input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
27
                    merge_multimodal_embeddings)
28
29


30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class LlavaImagePixelInputs(TypedDict):
    """Raw image pixels that still have to go through the vision tower.

    ``data`` has shape ``(batch_size, num_channels, height, width)``.
    """
    type: Literal["pixel_values"]
    data: torch.Tensor


class LlavaImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings that skip the vision tower and projector.

    ``data`` has shape ``(batch_size, image_feature_size, hidden_size)``;
    ``hidden_size`` must equal the hidden size of the language model backbone.
    """
    type: Literal["image_embeds"]
    data: torch.Tensor


# Either accepted image-input form; the ``type`` key tells them apart.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision-tower features into the text space.

    The computation is ``linear_2(act(linear_1(x)))``.

    Args:
        vision_hidden_size: Feature size produced by the vision tower.
        text_hidden_size: Hidden size of the language model backbone; both
            linear layers output this size.
        projector_hidden_act: Name of the activation, resolved with
            ``get_act_fn``.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        # NOTE: keep the attribute names ``linear_1`` / ``linear_2``. The
        # model's weight loader looks projector parameters up by name in
        # ``named_parameters()``.
        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` (last dim ``vision_hidden_size``) to
        the text hidden size; leading dimensions are preserved."""
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


70
71
72
73
74
def get_max_llava_image_tokens(ctx: InputContext):
    """Return how many image tokens a single image contributes.

    This is the vision encoder's token count for the configured CLIP or
    SigLIP tower. It is reduced by one when
    ``vision_feature_select_strategy == "default"``, because that strategy
    drops the first feature (``image_features[:, 1:]``).

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
        ValueError: If the feature-select strategy is not ``"default"`` or
            ``"full"``.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_image_tokens = get_max_clip_image_tokens(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_image_tokens = get_max_siglip_image_tokens(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        return num_image_tokens - 1
    elif strategy == "full":
        return num_image_tokens
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")
89
90


91
92
def dummy_data_for_llava(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    """Build placeholder sequence data and images for ``mm_counts["image"]``
    images.

    This function is registered through ``INPUT_REGISTRY.register_dummy_data``
    (presumably used for memory profiling; confirm against the registry).
    The work is delegated to the CLIP or SigLIP dummy-data helpers. Each
    image's token count is overridden with ``get_max_llava_image_tokens``.

    Returns:
        A ``(seq_data, mm_data)`` tuple.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    # Image tokens per image, after the feature-select strategy is applied.
    image_feature_size = get_max_llava_image_tokens(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(vision_config, num_images)
        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            num_images,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_siglip(vision_config, num_images)
        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


126
127
128
129
130
131
132
133
134
def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
    """Expand the image placeholder in a prompt to the full image-token count.

    Inputs without image data are returned unchanged. Otherwise the work is
    delegated to the CLIP or SigLIP input processor. That processor receives
    ``hf_config.image_token_index`` as the placeholder id and
    ``get_max_llava_image_tokens(ctx)`` as the number of tokens per image.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    image_feature_size = get_max_llava_image_tokens(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def _init_vision_tower(hf_config: LlavaConfig):
    """Build the vision encoder, truncated after the layer LLaVA reads from.

    ``hf_config.vision_feature_layer`` may be negative, in which case it
    counts from the end (``-1`` keeps every layer). For example, with 24
    encoder layers and ``vision_feature_layer == -2``, the tower is built
    with 23 layers.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer.
    # Layers after it are never used, so they are not instantiated.
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        num_hidden_layers = (vision_config.num_hidden_layers +
                             vision_feature_layer + 1)
    else:
        num_hidden_layers = vision_feature_layer + 1

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


184
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
    """LLaVA model: a CLIP/SigLIP vision tower, an MLP projector and a
    language model backbone.

    Image inputs are turned into embeddings in the backbone's hidden space.
    They are merged into the token embeddings at the positions of
    ``config.image_token_index`` (via ``merge_multimodal_embeddings``). The
    backbone then runs on the merged embeddings.
    """

    def __init__(self,
                 config: LlavaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Return ``data`` unchanged after checking its trailing dimensions.

        Raises:
            ValueError: If ``data.shape[1:]`` is not
                ``(3, image_size, image_size)``.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract a typed image input from the forward ``kwargs``.

        Pops ``"pixel_values"`` and ``"image_embeds"``. If both are present,
        ``pixel_values`` wins.

        Returns:
            ``None`` when neither key is present, otherwise the typed input.

        Raises:
            ValueError: If the value is not a ``torch.Tensor``, or if pixel
                values have the wrong shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Remove the N dimension until multiple images are supported.
            # (``squeeze(1)`` only drops dim 1 when its size is 1.)
            pixel_values = pixel_values.squeeze(1)

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            # Remove the N dimension until multiple images are supported.
            image_embeds = image_embeds.squeeze(1)

            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the HF ``vision_feature_select_strategy``.

        ``"default"`` drops the first feature along dim 1. ``"full"`` keeps
        every feature.

        Raises:
            ValueError: For any other strategy.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run ``vision_tower`` on ``pixel_values`` and select features."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Encode pixel inputs with the vision tower (before projection)."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Turn an image input into embeddings in the backbone's hidden space.

        Precomputed ``image_embeds`` are returned as-is. Pixel inputs go
        through the vision tower and then the projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Positions aligned with `input_ids`.
            kv_caches: KV cache tensors, passed through to the backbone.
            attn_metadata: Attention metadata, passed through to the backbone.
            intermediate_tensors: Passed through to the language model
                backbone. (It was previously discarded in favour of a
                hard-coded `None`.)
            **kwargs: May contain `pixel_values` (the pixels in each input
                image) or `image_embeds`.

        Returns:
            Whatever the language model backbone returns (its hidden states).

        See also:
            :class:`LlavaImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            # The backbone consumes the merged embeddings instead of ids.
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model backbone."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model backbone."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Route checkpoint weights to the vision tower, projector and LM.

        Each component receives only the weights that ``filter_weights``
        selects for its prefix. Projector weights are matched by name against
        ``multi_modal_projector.named_parameters()``. An unknown name raises
        ``KeyError``.
        """
        # prepare weight iterators for components
        # NOTE: ``itertools.tee`` buffers every item one branch has consumed
        # until the other branches consume it too, so loading the vision
        # tower first may hold the remaining weight stream in memory.
        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_tower")
        self.vision_tower.load_weights(vit_weights)

        # load mlp projector
        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)