llava.py 14.2 KB
Newer Older
1
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
2
3

import torch
4
import torch.nn as nn
5
from transformers import LlavaConfig
6
7

from vllm.attention import AttentionMetadata
8
from vllm.config import CacheConfig, VisionLanguageConfig
9
10
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
11
12
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
13
14
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
15
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
16
from vllm.model_executor.models.clip import CLIPVisionModel
17
18
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
19
20
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import get_dummy_image_data
21
22
from vllm.sequence import SamplerOutput

23
from .interfaces import SupportsVision
24

25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}


# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision-tower features into the text space.

    Computes ``linear_2(act(linear_1(x)))``. ``linear_1`` maps
    ``vision_hidden_size -> text_hidden_size`` and ``linear_2`` maps
    ``text_hidden_size -> text_hidden_size``; both carry a bias.

    The attribute names ``linear_1`` / ``linear_2`` must not change:
    checkpoint weights are looked up by parameter name when loading.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        # Registration order (linear_1, act, linear_2) is kept unchanged.
        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` whose last dim is the vision width."""
        projected = self.linear_1(image_features)
        return self.linear_2(self.act(projected))


53
54
55
56
def merge_vision_embeddings(input_ids: torch.Tensor,
                            inputs_embeds: torch.Tensor,
                            vision_embeddings: torch.Tensor,
                            image_token_id: int) -> torch.Tensor:
    """In place merges in vision_embeddings with inputs_embeds.

    Every position where ``input_ids == image_token_id`` is overwritten,
    in order, by one row of the flattened ``vision_embeddings``.

    Args:
        input_ids: Token ids; used as a boolean mask over the leading
            dimension(s) of ``inputs_embeds``.
        inputs_embeds: Text embeddings. Modified in place.
        vision_embeddings: Tensor whose first two dims are flattened
            together (``shape[0] * shape[1]`` rows) and whose last dim is
            the hidden size.
        image_token_id: Placeholder token id marking image positions.

    Returns:
        ``inputs_embeds``, the same tensor object, after the merge.

    Raises:
        ValueError: If the number of image-token positions differs from
            ``vision_embeddings.shape[0] * vision_embeddings.shape[1]``.
    """
    mask = input_ids == image_token_id
    # Count once, as a plain int, so the error message shows a number
    # rather than a tensor repr.
    num_image_tokens = int(mask.sum())

    num_images, tokens_per_image = vision_embeddings.shape[:2]
    image_feature_size = num_images * tokens_per_image
    if num_image_tokens != image_feature_size:
        raise ValueError(f"image_feature_size should be {image_feature_size}, "
                         f"but found: {num_image_tokens}")

    # reshape (not view) so non-contiguous embeddings are accepted too;
    # view() raises when the leading dims cannot be merged without a copy.
    inputs_embeds[mask] = vision_embeddings.reshape(
        image_feature_size, vision_embeddings.shape[-1])

    return inputs_embeds
69

70

71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
class LlavaImagePixelInputs(TypedDict):
    """Raw image pixels, to be run through the vision tower.

    ``data`` shape: ``(batch_size, num_channels, height, width)``.
    """

    type: Literal["pixel_values"]
    data: torch.Tensor


class LlavaImageFeatureInputs(TypedDict):
    """Precomputed vision-tower features, passed straight to the projector.

    ``data`` shape: ``(batch_size, image_feature_size, hidden_size)``.
    """

    type: Literal["image_features"]
    data: torch.Tensor


# Either accepted image input; the ``type`` key tells the two apart.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]


86
87
88
@MULTIMODAL_REGISTRY.register_image_feature_input()
@MULTIMODAL_REGISTRY.register_image_pixel_input()
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
    """LLaVA-1.5: CLIP vision tower + MLP projector + Llama language model.

    Image inputs arrive either as raw pixels (``PIXEL_VALUES`` mode, which
    builds and uses the vision tower) or as vision-tower features
    (``IMAGE_FEATURES`` mode, where no vision tower is built). Either way
    they are projected into the text embedding space and written into the
    token embeddings at the image placeholder positions.
    """

    def __init__(self,
                 config: LlavaConfig,
                 vlm_config: VisionLanguageConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.vlm_config = vlm_config

        # The vision tower is only needed when raw pixels are supplied; in
        # IMAGE_FEATURES mode the inputs are already vision-tower outputs.
        if self.vlm_config.image_input_type == (
                VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
            self.vision_tower = CLIPVisionModel(config.vision_config)
        else:
            self.vision_tower = None

        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.quant_config = quant_config
        self.language_model = LlamaModel(config.text_config, cache_config,
                                         quant_config)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size)
        logit_scale = getattr(config, "logit_scale", 1.0)
        # Take the vocab size from text_config, the same source as
        # unpadded_vocab_size above. The top-level LlavaConfig.vocab_size
        # is deprecated in recent transformers releases.
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.text_config.vocab_size,
            logit_scale)
        self.sampler = Sampler()

    def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
        """Check ``data`` against ``image_input_shape``, ignoring batch dim.

        Returns ``data`` unchanged.

        Raises:
            ValueError: If any non-batch dimension differs from the
                configured ``image_input_shape``.
        """
        if list(data.shape[1:]) != list(self.vlm_config.image_input_shape[1:]):
            raise ValueError(
                f"The expected image tensor shape is batch dimension plus "
                f"{self.vlm_config.image_input_shape[1:]}. "
                f"You supplied {data.shape}. "
                f"If you are using vLLM's entrypoint, make sure your "
                f"supplied image input is consistent with "
                f"image_input_shape in engine args.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract the image input matching the configured input type.

        Reads ``pixel_values`` and ``image_features`` from ``kwargs``.
        Returns ``None`` when the configured kind of input is absent.

        Raises:
            ValueError: If the non-configured kind of input is supplied, if
                the value is not a ``torch.Tensor``, or if its shape fails
                ``_validate_image_data``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_features = kwargs.pop("image_features", None)

        expected_input_type = self.vlm_config.image_input_type
        ImageInputType = VisionLanguageConfig.ImageInputType

        if expected_input_type == ImageInputType.PIXEL_VALUES:
            if image_features is not None:
                raise ValueError(
                    "Expected pixel values but got image features")
            if pixel_values is None:
                return None

            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_image_data(pixel_values),
            )

        if expected_input_type == ImageInputType.IMAGE_FEATURES:
            if pixel_values is not None:
                raise ValueError(
                    "Expected image features but got pixel values")
            if image_features is None:
                return None

            if not isinstance(image_features, torch.Tensor):
                raise ValueError("Incorrect type of image features. "
                                 f"Got type: {type(image_features)}")

            return LlavaImageFeatureInputs(
                type="image_features",
                data=self._validate_image_data(image_features),
            )

        return None

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the HF ``vision_feature_select_strategy``.

        ``"default"`` drops the first token along dim 1 (presumably the CLS
        token). ``"full"`` keeps all tokens.

        Raises:
            ValueError: For any other strategy name.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        """Run ``vision_tower`` on ``pixel_values`` and select features."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values.to(vision_tower.device),
                                      self.config.vision_feature_layer)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Turn pixel inputs into vision-tower features (PIXEL_VALUES mode)."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Map either input kind to projected embeddings in the text space."""
        if image_input["type"] == "pixel_values":
            # _process_image_pixels asserts the vision tower is present.
            image_features = self._process_image_pixels(image_input)
        else:
            # IMAGE_FEATURES mode: the data already is vision-tower output.
            image_features = image_input["data"]

        return self.multi_modal_projector(image_features)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.
        Concretely, consider a text prompt:
        "<image>\\nUSER: What's the content of the image?\\nASSISTANT:".
        Tokenizer outputs:
        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
        The to-be-inserted image has a size of 576 (24 * 24) along the context
        length dimension.
        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
        9047, 13566, 29901].
        There will be 576 `32000` in the `input_ids`.
        (32000 is the token id for `<image>`.)

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        This model has two modes of image inputs:
        `PIXEL_VALUES` and `IMAGE_FEATURES`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.
                Expects a batch with shape `[1, 3, 336, 336]`.
                (Only applicable to `PIXEL_VALUES` mode)
            image_features: The image features for each input image outputted by
                the vision tower before passing to the multi-modal projector.
                Expects a batch with shape `[1, 576, 1024]`.
                (Only applicable to `IMAGE_FEATURES` mode)

        Returns:
            The hidden states produced by the language model. Call
            `compute_logits` on them to get logits.

        See also:
            Each input maps to huggingface implementation, as follows:

            - `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L360
            - `image_features`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L437
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.vlm_config.image_token_id)

            # The merged embeddings replace the token ids as model input.
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the shared ``Sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``(name, tensor)`` pairs into this module.

        Each name is first rewritten via ``_KEYS_TO_MODIFY_MAPPING``.
        Vision weights use the default (unsharded) loader, and are dropped
        when no vision tower exists. Language-model q/k/v and gate/up
        projections go into their fused ``qkv_proj`` / ``gate_up_proj``
        parameters. Everything else uses the parameter's ``weight_loader``,
        falling back to ``default_weight_loader``.
        """
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # Rotary-embedding inv_freq tensors are not loaded.
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
                # Without a vision tower (IMAGE_FEATURES mode) there is no
                # parameter to load into, so the weight is skipped.
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)