# llava.py (13.8 KB)
1
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
2
3

import torch
4
import torch.nn as nn
5
from transformers import CLIPVisionConfig, LlavaConfig
6
7

from vllm.attention import AttentionMetadata
8
from vllm.config import CacheConfig, MultiModalConfig
9
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
10
11
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
12
13
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
14
15
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
16
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
17
from vllm.model_executor.models.clip import CLIPVisionModel
18
19
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
20
from vllm.multimodal import MULTIMODAL_REGISTRY
21
from vllm.sequence import IntermediateTensors, SamplerOutput
22

23
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
24
                   get_max_clip_image_tokens, input_processor_for_clip)
25
from .interfaces import SupportsVision
26
from .utils import merge_vision_embeddings
27

28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Substring renames applied to incoming weight names in
# `LlavaForConditionalGeneration.load_weights` before they are looked up in
# the module's parameter dict: each key found in a weight name is replaced by
# its value.
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}


# TODO(xwjiang): Run a benchmark and decide whether this projector should use
# tensor parallelism (TP).
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision features to the text hidden size.

    The projector applies ``linear_1`` (vision_hidden_size ->
    text_hidden_size), the configured activation, and then ``linear_2``
    (text_hidden_size -> text_hidden_size). Both linear layers use a bias.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        """Create the projector layers.

        Args:
            vision_hidden_size: Size of the last dimension of the incoming
                image features.
            text_hidden_size: Size of the last dimension of the output.
            projector_hidden_act: Name of the activation, resolved through
                ``get_act_fn``.
        """
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` through linear -> activation -> linear."""
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


56
57
58
class LlavaImagePixelInputs(TypedDict):
    """Image input given as raw pixel values, tagged with its variant name."""

    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""
60
61


62
# Alias for the image input variants this model accepts; pixel values are the
# only variant defined here.
LlavaImageInputs = LlavaImagePixelInputs
63
64


65
66
67
68
69
70
71
72
73
74
75
def get_max_llava_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens for the configured model.

    The vision config is read from the model's ``LlavaConfig``. Only
    ``CLIPVisionConfig`` is handled, by delegating to
    ``get_max_clip_image_tokens``.

    Raises:
        NotImplementedError: If the vision config is not a
            ``CLIPVisionConfig``.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config

    # Guard clause: reject every vision backbone other than CLIP up front.
    if not isinstance(vision_config, CLIPVisionConfig):
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    return get_max_clip_image_tokens(vision_config)


76
77
78
79
80
81
82
83
84
85
86
def dummy_data_for_llava(ctx: InputContext, seq_len: int):
    """Build dummy sequence data and dummy image data for a LLaVA model.

    Args:
        ctx: Input context used to fetch the model's ``LlavaConfig``.
        seq_len: Sequence length forwarded to ``dummy_seq_data_for_clip``.

    Returns:
        A ``(seq_data, mm_data)`` tuple produced by the CLIP dummy-data
        helpers, using the config's ``image_token_index`` as the image token.

    Raises:
        NotImplementedError: If the vision config is not a
            ``CLIPVisionConfig``.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            image_token_id=hf_config.image_token_index,
        )

        mm_data = dummy_image_for_clip(vision_config)
        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
    """Run the vision-specific input processor when an image is attached.

    Inputs that carry no ``"multi_modal_data"`` entry, or whose multi-modal
    data has no ``"image"`` key, are returned untouched. Otherwise the
    request is delegated to ``input_processor_for_clip`` with the config's
    ``image_token_index`` as the image token.

    Raises:
        NotImplementedError: If the vision config is not a
            ``CLIPVisionConfig``.
    """
    mm_data = llm_inputs.get("multi_modal_data")
    has_image = mm_data is not None and "image" in mm_data
    if not has_image:
        return llm_inputs

    model_config = ctx.model_config
    llava_config = ctx.get_hf_config(LlavaConfig)
    vision_config = llava_config.vision_config

    # Guard clause: CLIP is the only vision backbone handled here.
    if not isinstance(vision_config, CLIPVisionConfig):
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    return input_processor_for_clip(
        model_config,
        vision_config,
        llm_inputs,
        image_token_id=llava_config.image_token_index,
    )


115
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
    """LLaVA model: CLIP vision tower + MLP projector + Llama language model.

    Image pixels are encoded by the vision tower, projected to the text
    hidden size, and merged into the token embeddings at the positions of
    the image placeholder token before the language model runs.
    """

    def __init__(self,
                 config: LlavaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        """Build the vision tower, projector, language model and LM head.

        Args:
            config: Hugging Face LLaVA config holding the vision and text
                sub-configs.
            multimodal_config: Multi-modal settings; stored on the instance.
            cache_config: Forwarded to the language model.
            quant_config: Forwarded to the language model and the LM head.
        """
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # Initialize the vision tower only up to the required feature layer.
        # A negative index counts from the end of the encoder, so it is
        # converted into the number of layers that must be instantiated.
        vision_feature_layer = config.vision_feature_layer
        if vision_feature_layer < 0:
            num_hidden_layers = config.vision_config.num_hidden_layers \
                + vision_feature_layer + 1
        else:
            num_hidden_layers = vision_feature_layer + 1

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = CLIPVisionModel(
            config.vision_config, num_hidden_layers_override=num_hidden_layers)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.quant_config = quant_config
        self.language_model = LlamaModel(config.text_config, cache_config,
                                         quant_config)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size,
            quant_config=quant_config)
        # Configs without a `logit_scale` attribute fall back to 1.0.
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.text_config.vocab_size,
                                                logit_scale)
        self.sampler = Sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every dimension after the first is `(3, size, size)`.

        `size` is `config.vision_config.image_size`. The tensor is returned
        unchanged when the check passes.

        Raises:
            ValueError: If the trailing dimensions do not match.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract and validate the `pixel_values` keyword argument.

        Returns:
            `None` when no `pixel_values` were passed, otherwise a
            `LlavaImagePixelInputs` wrapping the validated tensor.

        Raises:
            ValueError: If `pixel_values` is not a tensor or has an
                unexpected shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return LlavaImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Select which positions of the vision tower output to keep.

        `"default"` drops the first position along dimension 1 (presumably
        the CLS token -- confirm against the referenced HF implementation);
        `"full"` keeps everything.

        Raises:
            ValueError: For any other strategy name.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode pixels with `vision_tower` and apply the select strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Run the vision tower on the pixel tensor stored in `inputs`."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Turn a validated image input into projected vision embeddings."""
        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    # NOTE(review): the return annotation says `SamplerOutput`, but the value
    # returned is whatever the language model produces (named
    # `hidden_states`), which `compute_logits` consumes as a tensor.
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        See also:
            :class:`LlavaImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            # The merged embeddings replace the token ids as model input.
            input_ids = None
        else:
            inputs_embeds = None

        # `intermediate_tensors` is accepted by this method but not forwarded;
        # the language model always receives `None` in that position.
        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            None,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from hidden states via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from `logits` using the model's sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load `(name, tensor)` pairs into this module's parameters.

        Names are first rewritten with `_KEYS_TO_MODIFY_MAPPING`. Weights
        whose name contains `"vision"` use the parameter's own
        `weight_loader` if present, else `default_weight_loader`. Other
        weights that match an entry in `stacked_params_mapping` are loaded
        into the corresponding fused parameter with a shard id; the rest fall
        back to the same default path. Names not present in the parameter
        dict are skipped on the default path.
        """
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # No stacked mapping matched this weight name.
                    use_default_weight_loading = True
            if use_default_weight_loading and name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)