llava.py 13.7 KB
Newer Older
1
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
2
3

import torch
4
import torch.nn as nn
5
from transformers import CLIPVisionConfig, LlavaConfig
6
7

from vllm.attention import AttentionMetadata
8
from vllm.config import CacheConfig, MultiModalConfig
9
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
10
11
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
12
13
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
14
15
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
16
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
17
from vllm.model_executor.models.clip import CLIPVisionModel
18
19
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
20
from vllm.multimodal import MULTIMODAL_REGISTRY
21
from vllm.sequence import IntermediateTensors, SamplerOutput
22

23
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
24
                   get_max_clip_image_tokens, input_processor_for_clip)
25
from .interfaces import SupportsVision
26
from .utils import merge_vision_embeddings
27

28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}


# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP projecting vision features into the text hidden size.

    Computes ``linear_2(act(linear_1(x)))``. The attribute names
    ``linear_1`` / ``act`` / ``linear_2`` must stay as they are: checkpoint
    weights are matched to parameters by name when loading.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        # vision_hidden_size -> text_hidden_size, with bias.
        self.linear_1 = nn.Linear(vision_hidden_size, text_hidden_size,
                                  bias=True)
        # Activation looked up by name via get_act_fn.
        self.act = get_act_fn(projector_hidden_act)
        # text_hidden_size -> text_hidden_size, with bias.
        self.linear_2 = nn.Linear(text_hidden_size, text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` (last dim = vision hidden size)."""
        projected = self.linear_1(image_features)
        activated = self.act(projected)
        return self.linear_2(activated)


56
57
58
class LlavaImagePixelInputs(TypedDict):
    # Tag identifying this input variant; always "pixel_values".
    type: Literal["pixel_values"]
    # Shape: `(batch_size, num_channels, height, width)`
    data: torch.Tensor


# Alias used to annotate image inputs; currently identical to
# LlavaImagePixelInputs, the only variant defined here.
LlavaImageInputs = LlavaImagePixelInputs
63
64


65
66
67
68
69
70
71
72
73
74
75
def get_max_llava_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens for one LLaVA image.

    Only CLIP vision towers are supported. The count is delegated to
    ``get_max_clip_image_tokens``.

    Raises:
        NotImplementedError: if the vision config is not a CLIPVisionConfig.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config

    if not isinstance(vision_config, CLIPVisionConfig):
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return get_max_clip_image_tokens(vision_config)


76
77
78
79
80
81
82
83
84
85
86
def dummy_data_for_llava(ctx: InputContext, seq_len: int):
    """Build a dummy ``(seq_data, mm_data)`` pair of length ``seq_len``.

    Registered through ``INPUT_REGISTRY.register_dummy_data``. Only CLIP
    vision towers are supported; both halves come from the CLIP helpers.

    Raises:
        NotImplementedError: if the vision config is not a CLIPVisionConfig.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if not isinstance(vision_config, CLIPVisionConfig):
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    # Placeholder image tokens use the model's configured image token id.
    seq_data = dummy_seq_data_for_clip(
        vision_config,
        seq_len,
        image_token_id=hf_config.image_token_index,
    )
    mm_data = dummy_image_for_clip(vision_config)
    return seq_data, mm_data


94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
    """Preprocess ``llm_inputs`` that carry image data for a LLaVA model.

    Inputs without an ``"image"`` entry in ``multi_modal_data`` are returned
    unchanged. Otherwise the work is delegated to
    ``input_processor_for_clip`` with the model's image token id.

    Raises:
        NotImplementedError: if the vision config is not a CLIPVisionConfig.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    has_image = multi_modal_data is not None and "image" in multi_modal_data
    if not has_image:
        return llm_inputs

    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if not isinstance(vision_config, CLIPVisionConfig):
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return input_processor_for_clip(
        ctx.model_config,
        vision_config,
        llm_inputs,
        image_token_id=hf_config.image_token_index,
    )


115
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
    """LLaVA: CLIP vision tower + MLP projector + Llama language model.

    Image features from ``vision_tower`` are projected by
    ``multi_modal_projector`` and merged into the text embeddings at the
    positions of ``config.image_token_index`` before ``language_model`` runs.
    """

    def __init__(self,
                 config: LlavaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # Initialize the vision tower only up to the required feature layer,
        # so its output already is the selected feature layer.
        # Following HF semantics, `hidden_states[k]` is the output after k
        # encoder layers (index 0 is the embedding output). So:
        # - a non-negative k needs exactly k layers;
        # - a negative -m (counted from the end of the N+1 hidden states)
        #   needs N - m + 1 layers.
        vision_feature_layer = config.vision_feature_layer
        if vision_feature_layer < 0:
            num_hidden_layers = config.vision_config.num_hidden_layers \
                + vision_feature_layer + 1
        else:
            num_hidden_layers = vision_feature_layer

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = CLIPVisionModel(
            config.vision_config, num_hidden_layers_override=num_hidden_layers)

        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.quant_config = quant_config
        self.language_model = LlamaModel(config.text_config, cache_config,
                                         quant_config)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size,
            quant_config=quant_config)
        logit_scale = getattr(config, "logit_scale", 1.0)
        # Take the vocab size from the text config, consistent with
        # `unpadded_vocab_size` and the LM head above. The top-level
        # `LlavaConfig.vocab_size` is deprecated in recent transformers.
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.text_config.vocab_size,
            logit_scale)
        self.sampler = Sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that ``data`` is ``(batch_size, 3, image_size, image_size)``.

        Raises:
            ValueError: if the dimensions after the first do not match.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract ``pixel_values`` from ``kwargs``, or None if absent.

        Raises:
            ValueError: if ``pixel_values`` is not a tensor or has the wrong
                shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return LlavaImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply ``vision_feature_select_strategy`` to the tower output.

        ``"default"`` drops the first token position along dim 1 (the CLIP
        class embedding, per the HF reference below). ``"full"`` keeps all
        positions.

        Raises:
            ValueError: for any other strategy.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the vision tower and apply the feature-select strategy."""

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Turn validated pixel inputs into selected vision features."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Compute projected image embeddings (text hidden size)."""
        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in the KV cache, placeholder tokens have to be
        inserted before the input reaches the model. The input processor
        therefore expands the single image token (`32000`) in place into a
        run of repeated image tokens, resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        In this example, 575 tokens are inserted so that, together with the
        original image token, there are 576 (24 * 24) image tokens. That
        matches the number of image embeddings the vision encoder outputs and
        the language model receives.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image, passed via
                ``**kwargs``.

        Returns:
            The hidden states produced by ``language_model``.

        See also:
            :class:`LlavaImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            # Overwrite the embeddings at image-token positions with the
            # projected vision embeddings.
            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            # The language model now consumes `inputs_embeds` instead of ids.
            input_ids = None
        else:
            inputs_embeds = None

        # NOTE(review): `intermediate_tensors` is accepted for interface
        # compatibility but not forwarded (None is passed); pipeline
        # parallelism does not appear to be supported here. Confirm before
        # forwarding it.
        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            None,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from hidden states through the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``(name, tensor)`` pairs into this module.

        Names are first rewritten via ``_KEYS_TO_MODIFY_MAPPING``. How each
        weight is loaded:

        - Language-model q/k/v and gate/up projections are loaded as shards
          of the fused ``qkv_proj`` / ``gate_up_proj`` parameters.
        - Everything else, including the vision tower, uses the parameter's
          own ``weight_loader`` or ``default_weight_loader``.
        - Default-path weights with no matching parameter (e.g. vision
          layers beyond the truncated depth) are skipped.
        """
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True
            if use_default_weight_loading and name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)