# llava.py (13.8 KB)
1
2
import itertools
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
3
4

import torch
5
import torch.nn as nn
6
from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
7
8

from vllm.attention import AttentionMetadata
9
from vllm.config import CacheConfig, MultiModalConfig
10
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
11
from vllm.model_executor.layers.activation import get_act_fn
12
13
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
14
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
15
from vllm.model_executor.sampling_metadata import SamplingMetadata
16
from vllm.multimodal import MULTIMODAL_REGISTRY
17
from vllm.sequence import IntermediateTensors, SamplerOutput
18

19
20
21
from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_max_clip_image_tokens,
                   input_processor_for_clip)
22
from .interfaces import SupportsVision
23
24
25
26
27
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                     input_processor_for_siglip)
from .utils import (filter_weights, init_vllm_registered_model,
                    merge_vision_embeddings)
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44


# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision-tower features into the text hidden size.

    The projection is ``linear_1 -> act -> linear_2``: ``linear_1`` maps
    ``vision_hidden_size`` to ``text_hidden_size`` and ``linear_2`` keeps
    ``text_hidden_size``. Both linear layers have a bias. The attribute names
    (``linear_1``, ``act``, ``linear_2``) are the parameter names matched
    against checkpoint keys in ``load_weights``, so they must not change.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        # The activation is looked up by name via vLLM's `get_act_fn`.
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` (last dim ``vision_hidden_size``) to
        ``text_hidden_size`` along the last dimension."""
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class LlavaImagePixelInputs(TypedDict):
    """Image input supplied to the model as raw pixel values."""

    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""


# Union of all accepted image input formats; only pixel values for now.
LlavaImageInputs = LlavaImagePixelInputs


def get_max_llava_image_tokens(ctx: InputContext):
    """Return the number of image feature tokens fed to the language model.

    The raw token count comes from the vision encoder (CLIP or SigLIP). With
    the ``"default"`` feature-select strategy the model drops the first
    feature (``LlavaForConditionalGeneration._select_image_features`` slices
    ``[:, 1:]``), so one token is subtracted; ``"full"`` keeps them all.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
        ValueError: If ``vision_feature_select_strategy`` is unrecognized.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_image_tokens = get_max_clip_image_tokens(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_image_tokens = get_max_siglip_image_tokens(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        return num_image_tokens - 1
    elif strategy == "full":
        return num_image_tokens
    else:
        raise ValueError(f"Unexpected select feature strategy: {strategy}")


def dummy_data_for_llava(ctx: InputContext, seq_len: int):
    """Build the dummy ``(seq_data, mm_data)`` pair for LLaVA.

    This is registered through ``INPUT_REGISTRY.register_dummy_data`` on the
    model class. The number of placeholder image tokens is overridden with
    ``get_max_llava_image_tokens`` so the dummy prompt reserves exactly as many
    positions as the feature-select strategy produces.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    image_feature_size = get_max_llava_image_tokens(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_clip(vision_config)
        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

        mm_data = dummy_image_for_siglip(vision_config)
        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
    """Prepare ``llm_inputs`` that may contain an image for LLaVA.

    Inputs without ``multi_modal_data["image"]`` are returned unchanged.
    Otherwise the work is delegated to the encoder-specific (CLIP or SigLIP)
    input processor. It receives the LLaVA image token id and a feature-size
    override from ``get_max_llava_image_tokens``, which accounts for the
    feature-select strategy.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    image_feature_size = get_max_llava_image_tokens(ctx)

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            llm_inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def _init_vision_tower(hf_config: LlavaConfig):
    """Build the CLIP or SigLIP vision tower, truncated at the feature layer.

    ``hf_config.vision_feature_layer`` selects the encoder layer whose output
    LLaVA consumes. Only layers up to and including that layer are
    instantiated, so the tower's output already is the selected layer's
    output. A negative index counts from the end: ``-1`` keeps all
    ``num_hidden_layers`` layers, ``-2`` drops the last one, and so on.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor SigLIP.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        num_hidden_layers = (vision_config.num_hidden_layers +
                             vision_feature_layer + 1)
    else:
        num_hidden_layers = vision_feature_layer + 1

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
    """LLaVA: a vision tower and an MLP projector in front of a language model.

    Image features from ``vision_tower`` are passed through
    ``multi_modal_projector`` and merged into the text embeddings at the
    positions holding ``config.image_token_index``. The result is then run
    through ``language_model``.
    """

    def __init__(self,
                 config: LlavaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that ``data`` has shape ``(batch_size, 3, size, size)``.

        ``size`` is ``config.vision_config.image_size``.

        Raises:
            ValueError: If the non-batch dimensions do not match.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract ``pixel_values`` from ``kwargs``.

        Returns ``None`` when no image is supplied.

        Raises:
            ValueError: If ``pixel_values`` is not a tensor or has the wrong
                shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return LlavaImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the vision feature-select strategy.

        ``"default"`` drops the first feature along dim 1 (presumably the CLS
        token; confirm against the HF implementation linked below).
        ``"full"`` keeps every feature.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode ``pixel_values`` and apply the feature-select strategy."""

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Run the vision tower on the pixel tensor stored in ``inputs``."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Encode the image and project its features to the text hidden size."""
        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor repeats
        the image token (denoted as `32000`) in place, resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            intermediate_tensors: Forwarded unchanged to the language model.
            pixel_values: The pixels in each input image, passed via
                ``**kwargs``.

        See also:
            :class:`LlavaImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            # Overwrite the embeddings at image-token positions with the
            # projected vision embeddings.
            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            # The language model consumes `inputs_embeds` instead of ids.
            input_ids = None
        else:
            inputs_embeds = None

        # Forward `intermediate_tensors` instead of a hard-coded None so the
        # argument accepted by this method is not silently dropped.
        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Route checkpoint weights to the vision tower, projector and LLM.

        Each component receives the entries selected by ``filter_weights``
        for its prefix (``vision_tower``, ``multi_modal_projector``,
        ``language_model``).
        """
        # prepare weight iterators for components
        # NOTE(review): `itertools.tee` keeps every item one copy has consumed
        # until the other copies consume it too. Because the vision tower is
        # loaded first, the remaining checkpoint entries are likely buffered
        # in memory until the projector and LLM are loaded.
        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_tower")
        self.vision_tower.load_weights(vit_weights)

        # load mlp projector
        # presumably `filter_weights` strips the prefix so names match the
        # projector's own parameter names — verify in `.utils`.
        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)