# llava.py
1
2
import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
3
4

import torch
5
import torch.nn as nn
6
from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
7
8

from vllm.attention import AttentionMetadata
9
from vllm.config import CacheConfig, MultiModalConfig
10
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
11
from vllm.model_executor.layers.activation import get_act_fn
12
13
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
14
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
15
from vllm.model_executor.sampling_metadata import SamplingMetadata
16
from vllm.multimodal import MULTIMODAL_REGISTRY
17
from vllm.sequence import IntermediateTensors, SamplerOutput
18

19
20
21
from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_max_clip_image_tokens,
                   input_processor_for_clip)
22
from .interfaces import SupportsVision
23
24
25
26
27
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                     input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
                    merge_vision_embeddings)
28
29


30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class LlavaImagePixelInputs(TypedDict):
    """Raw image pixels to be encoded by the vision tower.

    ``data`` has shape ``(batch_size, num_channels, height, width)``.
    """

    type: Literal["pixel_values"]
    data: torch.Tensor


class LlavaImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings handed to the language model as-is.

    ``data`` has shape ``(batch_size, image_feature_size, hidden_size)``;
    ``hidden_size`` must match the hidden size of the language model
    backbone.
    """

    type: Literal["image_embeds"]
    data: torch.Tensor


# Tagged union of the accepted image inputs; consumers dispatch on "type".
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision-tower features to the LM hidden size.

    Computes ``linear_2(act(linear_1(x)))``; both linear layers use a bias.
    The attribute names ``linear_1``/``act``/``linear_2`` are part of the
    checkpoint layout, so they must not be renamed.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        """
        Args:
            vision_hidden_size: Last-dimension size of the incoming features.
            text_hidden_size: Hidden size of the language model backbone;
                used as both the intermediate and the output width.
            projector_hidden_act: Activation name, resolved via
                ``get_act_fn``.
        """
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` from ``vision_hidden_size`` to
        ``text_hidden_size`` along the last dimension."""
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


70
71
72
73
74
def get_max_llava_image_tokens(ctx: InputContext):
    """Return the number of image tokens the language model receives per image.

    The vision encoder's own token count is looked up for the configured
    vision tower (CLIP or SigLIP) and then adjusted by
    ``vision_feature_select_strategy``.

    Raises:
        NotImplementedError: If the vision config type is not supported.
        ValueError: If the feature select strategy is unknown.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_image_tokens = get_max_clip_image_tokens(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_image_tokens = get_max_siglip_image_tokens(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        # The "default" strategy drops the first feature position
        # (`image_features[:, 1:]` in the model), so one fewer token.
        return num_image_tokens - 1
    elif strategy == "full":
        return num_image_tokens
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")
89
90


91
92
93
94
def dummy_data_for_llava(ctx: InputContext, seq_len: int):
    """Build a dummy ``(seq_data, mm_data)`` pair of length ``seq_len``.

    The sequence reserves as many image-token placeholders as the language
    model will receive (see ``get_max_llava_image_tokens``), and the image is
    a dummy image sized for the configured vision tower.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    # The LLaVA token count can differ from the encoder's own count because
    # of the feature select strategy, so pass it as an explicit override.
    image_feature_size = get_max_llava_image_tokens(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(vision_config)
        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_siglip(vision_config)
        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


122
123
124
125
126
127
128
129
130
def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
    """Prepare ``llm_inputs`` that carry an image for the LLaVA model.

    Inputs without ``"image"`` multi-modal data are returned unchanged.
    Otherwise this delegates to the CLIP or SigLIP input processor, passing
    the LLaVA-specific image token count as an override.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    # May differ from the encoder's own count due to the select strategy.
    image_feature_size = get_max_llava_image_tokens(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def _init_vision_tower(hf_config: LlavaConfig):
    """Build the vision tower truncated at ``vision_feature_layer``.

    Only the encoder layers up to and including the feature layer are
    instantiated, so the tower's output already is the selected layer.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer.
    # A negative index counts from the end: -1 keeps all layers, -2 drops
    # the last one, and so on.
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        num_hidden_layers = (vision_config.num_hidden_layers +
                             vision_feature_layer + 1)
    else:
        num_hidden_layers = vision_feature_layer + 1

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


180
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
    """LLaVA: vision tower + MLP projector feeding a language model.

    Image features from ``vision_tower`` are mapped by
    ``multi_modal_projector`` to the language model's hidden size and
    merged into the text embeddings at the positions holding
    ``config.image_token_index``.
    """

    def __init__(self,
                 config: LlavaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check ``data`` is ``(batch_size, 3, image_size, image_size)``.

        Returns ``data`` unchanged.

        Raises:
            ValueError: If the non-batch dimensions do not match.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract image input from ``pixel_values`` / ``image_embeds``.

        Returns ``None`` when neither is given. ``pixel_values`` takes
        precedence when both are present.

        Raises:
            ValueError: If the supplied value is not a tensor, or pixel
                values have the wrong shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the feature select strategy along dimension 1.

        ``"default"`` drops the first position; ``"full"`` keeps all.

        Raises:
            ValueError: If ``strategy`` is unknown.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode pixels with ``vision_tower`` and select the features."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Run the vision tower on pixel inputs (before projection)."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Turn parsed image input into language-model-sized embeddings.

        Precomputed ``image_embeds`` are returned as-is; pixel inputs go
        through the vision tower and the projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor expands
        the image token (denoted as `32000`) in place, resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image features produced by the vision tower
        after feature selection.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions, kv_caches, attn_metadata: Passed through to the
                language model.
            intermediate_tensors: Currently ignored; ``None`` is passed to
                the language model in its place.
            pixel_values: (keyword) The pixels in each input image.
            image_embeds: (keyword) Precomputed image embeddings, used as-is
                without the vision tower or projector. Ignored if
                ``pixel_values`` is also given.

        See also:
            :class:`LlavaImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            # Overwrite the embeddings at image-token positions.
            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            # The language model consumes embeddings instead of ids now.
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, routing entries by top-level prefix.

        Entries under ``vision_tower``, ``multi_modal_projector`` and
        ``language_model`` are sent to the matching submodule.
        """
        # prepare weight iterators for components
        # NOTE: `itertools.tee` buffers every item one branch has read but
        # another has not. The branches are consumed one after another
        # below, so items read by the first branch stay buffered for the
        # others until they catch up.
        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_tower")
        self.vision_tower.load_weights(vit_weights)

        # load mlp projector; names are looked up directly, so an unknown
        # name raises KeyError
        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)