# llava.py (15.2 KB)
1
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
2
3

import torch
4
import torch.nn as nn
5
from transformers import CLIPVisionConfig, LlavaConfig
6
7

from vllm.attention import AttentionMetadata
8
from vllm.config import CacheConfig, VisionLanguageConfig
9
from vllm.inputs import INPUT_REGISTRY, InputContext
10
11
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
12
13
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
14
15
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
16
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
17
from vllm.model_executor.models.clip import CLIPVisionModel
18
19
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
20
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
21
22
from vllm.sequence import SamplerOutput

23
24
from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip,
                   dummy_seq_data_for_clip)
25
from .interfaces import SupportsVision
26

27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}


# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision features into the text hidden space.

    Computes ``linear_2(act(linear_1(x)))``; both linear layers have a bias.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        # vision_hidden_size -> text_hidden_size
        self.linear_1 = nn.Linear(vision_hidden_size, text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        # text_hidden_size -> text_hidden_size
        self.linear_2 = nn.Linear(text_hidden_size, text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` whose last dim is vision_hidden_size."""
        projected = self.linear_1(image_features)
        activated = self.act(projected)
        return self.linear_2(activated)
def merge_vision_embeddings(input_ids: torch.Tensor,
                            inputs_embeds: torch.Tensor,
                            vision_embeddings: torch.Tensor,
                            image_token_id: int) -> torch.Tensor:
    """In place merges in vision_embeddings with inputs_embeds.

    Every position where ``input_ids == image_token_id`` receives the next
    row of ``vision_embeddings`` after flattening all of its leading
    dimensions, e.g. ``(num_images, image_feature_size, hidden_size)`` is
    consumed as ``(num_images * image_feature_size, hidden_size)``.

    Args:
        input_ids: Token ids; ``input_ids == image_token_id`` is used as a
            boolean mask over the leading dimension(s) of ``inputs_embeds``.
        inputs_embeds: Text embeddings, overwritten in place at image-token
            positions.
        vision_embeddings: Image embeddings whose last dim matches the
            last dim of ``inputs_embeds``. May be non-contiguous.
        image_token_id: Token id marking image placeholder positions.

    Returns:
        ``inputs_embeds`` itself (modified in place).

    Raises:
        ValueError: If the number of image tokens in ``input_ids`` differs
            from the number of vision embedding rows.
    """
    hidden_size = vision_embeddings.shape[-1]
    image_feature_size = vision_embeddings.shape[:-1].numel()
    # reshape (rather than view) also accepts non-contiguous embeddings.
    flattened = vision_embeddings.reshape(image_feature_size, hidden_size)

    mask = (input_ids == image_token_id)
    # Reduce once to a Python int: avoids recomputing the sum for the error
    # message and prints a plain number instead of a tensor repr.
    num_image_tokens = int(mask.sum().item())
    if num_image_tokens != image_feature_size:
        raise ValueError(f"image_feature_size should be {image_feature_size}, "
                         f"but found: {num_image_tokens}")

    inputs_embeds[mask] = flattened
    return inputs_embeds
class LlavaImagePixelInputs(TypedDict):
    """Raw image pixels, to be encoded by the vision tower."""

    type: Literal["pixel_values"]
    # Shape: (batch_size, num_channels, height, width)
    data: torch.Tensor


class LlavaImageFeatureInputs(TypedDict):
    """Precomputed image features, passed directly to the projector."""

    type: Literal["image_features"]
    # Shape: (batch_size, image_feature_size, hidden_size)
    data: torch.Tensor


# Either kind of image input; tell them apart via the ``type`` key.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
def dummy_data_for_llava(ctx: InputContext, seq_len: int):
    """Return dummy ``(seq_data, mm_data)`` for LLaVA.

    Registered on the model class via ``INPUT_REGISTRY.register_dummy_data``.
    Only CLIP vision towers are supported; the multi-modal payload is dummy
    pixel data or dummy image features depending on the configured
    ``image_input_type``.

    Args:
        ctx: Input context providing the multimodal and HF configs.
        seq_len: Length of the dummy token sequence.

    Returns:
        A ``(seq_data, mm_data)`` pair built by the CLIP dummy-data helpers.

    Raises:
        NotImplementedError: If the vision config is not a
            ``CLIPVisionConfig`` or the image input type is not supported.
    """
    multimodal_config = ctx.get_multimodal_config()
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if not isinstance(vision_config, CLIPVisionConfig):
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    seq_data = dummy_seq_data_for_clip(
        vision_config,
        seq_len,
        image_token_id=hf_config.image_token_index,
    )

    image_input_type = multimodal_config.image_input_type
    ImageInputType = VisionLanguageConfig.ImageInputType
    mm_data: MultiModalData
    if image_input_type == ImageInputType.PIXEL_VALUES:
        mm_data = dummy_pixel_data_for_clip(vision_config)
    elif image_input_type == ImageInputType.IMAGE_FEATURES:
        mm_data = dummy_feature_data_for_clip(vision_config)
    else:
        # Without this branch mm_data would be unbound at the return below.
        raise NotImplementedError(
            f"Unsupported image input type: {image_input_type}")

    return seq_data, mm_data
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
    """LLaVA: CLIP vision tower + MLP projector feeding a Llama decoder.

    A CLIP vision tower is built only when the configured image input type
    is ``PIXEL_VALUES``; with precomputed image features it is ``None``.
    """

    def __init__(self,
                 config: LlavaConfig,
                 vlm_config: VisionLanguageConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.vlm_config = vlm_config

        vision_config = config.vision_config
        text_config = config.text_config

        # Raw pixels need encoding; precomputed features skip the tower.
        pixel_mode = VisionLanguageConfig.ImageInputType.PIXEL_VALUES
        self.vision_tower = (CLIPVisionModel(vision_config)
                             if vlm_config.image_input_type == pixel_mode
                             else None)

        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=vision_config.hidden_size,
            text_hidden_size=text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.quant_config = quant_config
        self.language_model = LlamaModel(text_config, cache_config,
                                         quant_config)

        self.unpadded_vocab_size = text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size)
        # NOTE(review): the logits processor takes config.vocab_size while
        # lm_head is sized from text_config.vocab_size — confirm they agree.
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()
    def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
        """Check ``data`` against the configured ``image_input_shape``.

        Only the non-batch dimensions are compared, so any batch size is
        accepted. Returns ``data`` unchanged when it matches.

        Raises:
            ValueError: If ``data.shape[1:]`` differs from
                ``vlm_config.image_input_shape[1:]``.
        """
        expected_tail = self.vlm_config.image_input_shape[1:]
        if list(data.shape[1:]) == list(expected_tail):
            return data

        raise ValueError(
            f"The expected image tensor shape is batch dimension plus "
            f"{expected_tail}. "
            f"You supplied {data.shape}. "
            f"If you are using vLLM's entrypoint, make sure your "
            f"supplied image input is consistent with "
            f"image_input_shape in engine args.")
    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract the image input matching the configured input type.

        Pops ``pixel_values`` and ``image_features`` from ``kwargs``. Returns
        ``None`` when the expected input is absent, or when the configured
        input type is neither ``PIXEL_VALUES`` nor ``IMAGE_FEATURES``.

        Raises:
            ValueError: If the other kind of image input was supplied, if the
                supplied input is not a ``torch.Tensor``, or if its shape is
                rejected by ``_validate_image_data``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_features = kwargs.pop("image_features", None)

        input_type = self.vlm_config.image_input_type
        ImageInputType = VisionLanguageConfig.ImageInputType

        if input_type == ImageInputType.PIXEL_VALUES:
            # Guard clauses: wrong kind -> error, missing -> no image,
            # wrong type -> error.
            if image_features is not None:
                raise ValueError(
                    "Expected pixel values but got image features")
            if pixel_values is None:
                return None
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
            validated_pixels = self._validate_image_data(pixel_values)
            return LlavaImagePixelInputs(type="pixel_values",
                                         data=validated_pixels)

        if input_type == ImageInputType.IMAGE_FEATURES:
            if pixel_values is not None:
                raise ValueError(
                    "Expected image features but got pixel values")
            if image_features is None:
                return None
            if not isinstance(image_features, torch.Tensor):
                raise ValueError("Incorrect type of image features. "
                                 f"Got type: {type(image_features)}")
            validated_features = self._validate_image_data(image_features)
            return LlavaImageFeatureInputs(type="image_features",
                                           data=validated_features)

        # Unrecognised input type: behave as if no image was given.
        return None
    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply a ``vision_feature_select_strategy`` to vision features.

        ``"full"`` keeps every position. ``"default"`` drops index 0 along
        dim 1 (presumably the CLS token, as in the HF code linked below).

        Raises:
            ValueError: For any other strategy name.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "full":
            return image_features
        if strategy == "default":
            return image_features[:, 1:]
        raise ValueError(f"Unexpected select feature strategy: {strategy}")
    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode ``pixel_values`` with ``vision_tower``, then select features.

        The pixels are moved to the vision tower's device first. Feature
        selection follows ``config.vision_feature_select_strategy``.
        """
        # The feature layer index is handed to the vision tower itself, so
        # no separate layer-selection step is needed here.
        pixels_on_device = pixel_values.to(vision_tower.device)
        raw_features = vision_tower(pixels_on_device,
                                    self.config.vision_feature_layer)

        strategy = self.config.vision_feature_select_strategy
        return self._select_image_features(raw_features, strategy=strategy)
    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Run the vision tower on ``inputs["data"]`` (tower must exist)."""
        assert self.vision_tower is not None
        return self._image_pixels_to_features(self.vision_tower,
                                              inputs["data"])
    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Turn parsed image input into projected embeddings.

        Pixel inputs are first encoded by the vision tower; feature inputs
        go to the multi-modal projector unchanged.
        """
        if image_input["type"] != "pixel_values":
            image_features = image_input["data"]
        else:
            assert self.vision_tower is not None
            image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is that `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.
        Concretely, consider a text prompt:
        "<image>\nUSER: What's the content of the image?\nASSISTANT:".
        Tokenizer outputs:
        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
        The to-be-inserted image has a size of 576 (24 * 24) along the context
        length dimension.
        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
        9047, 13566, 29901].
        There will be 576 `32000` in the `input_ids`.
        (32000 is the token id for `<image>`.)

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        This model has two modes of image inputs:
        `PIXEL_VALUES` and `IMAGE_FEATURES`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Passed through unchanged to the language model.
            kv_caches: Passed through unchanged to the language model.
            attn_metadata: Passed through unchanged to the language model.
            pixel_values: The pixels in each input image.
                Expects a batch with shape `[1, 3, 336, 336]`.
                (Only applicable to `PIXEL_VALUES` mode)
            image_features: The image features for each input image outputted by
                the vision tower before passing to the multi-modal projector.
                Expects a batch with shape `[1, 576, 1024]`.
                (Only applicable to `IMAGE_FEATURES` mode)

        Returns:
            The hidden states produced by the language model. Logits are
            computed separately by `compute_logits`.

        See also:
            Each input maps to huggingface implementation, as follows:

            - `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L360
            - `image_features`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L437
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.vlm_config.image_token_id)

            # The merged embeddings replace the token ids as model input.
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            inputs_embeds=inputs_embeds)

        return hidden_states
    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Turn ``hidden_states`` into logits using the ``lm_head`` weight."""
        return self.logits_processor(self.lm_head.weight, hidden_states,
                                     sampling_metadata)
    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` with the model's ``Sampler``."""
        return self.sampler(logits, sampling_metadata)
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` checkpoint pairs into this model.

        Names are rewritten with ``_KEYS_TO_MODIFY_MAPPING``. Separate
        q/k/v and gate/up projections are loaded as shards of the fused
        ``qkv_proj`` / ``gate_up_proj`` parameters. Weights whose name
        contains "vision" are loaded only if the model has a vision tower,
        and are otherwise ignored.
        """
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())

        def load_whole(target: str, tensor: torch.Tensor) -> None:
            # Use the parameter's own weight_loader when it defines one,
            # otherwise default_weight_loader.
            param = params_dict[target]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, tensor)

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue

            for old_key, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if old_key in name:
                    name = name.replace(old_key, new_key)

            if "vision" in name:
                # We only do sharding for language model and not vision
                # model for now. Without a vision tower these weights are
                # simply dropped.
                if self.vision_tower is not None:
                    load_whole(name, loaded_weight)
                continue

            # First stacked mapping whose shard name occurs in the weight name.
            shard = next((entry for entry in stacked_params_mapping
                          if entry[1] in name), None)
            if shard is None:
                load_whole(name, loaded_weight)
                continue

            param_name, weight_name, shard_id = shard
            param = params_dict[name.replace(weight_name, param_name)]
            param.weight_loader(param, loaded_weight, shard_id)