llava.py 13.4 KB
Newer Older
1
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
2
3

import torch
4
import torch.nn as nn
5
from transformers import CLIPVisionConfig, LlavaConfig
6
7

from vllm.attention import AttentionMetadata
8
from vllm.config import CacheConfig, MultiModalConfig
9
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
10
11
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
12
13
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
14
15
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
16
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
17
from vllm.model_executor.models.clip import CLIPVisionModel
18
19
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
20
from vllm.multimodal import MULTIMODAL_REGISTRY
21
from vllm.sequence import IntermediateTensors, SamplerOutput
22

23
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
24
                   get_max_clip_image_tokens, input_processor_for_clip)
25
from .interfaces import SupportsVision
26
from .utils import merge_vision_embeddings
27

28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Substring rewrites applied to checkpoint weight names in
# ``LlavaForConditionalGeneration.load_weights`` so they line up with this
# module's parameter names: ``lm_head`` is a top-level attribute here, and the
# ``LlamaModel`` is stored directly as ``language_model``.
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}


# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision features into the text hidden size.

    Computes ``linear_2(act(linear_1(x)))``. The submodule names are part of
    the checkpoint contract (weights are matched by parameter name), so they
    must not be renamed.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        # vision_hidden_size -> text_hidden_size
        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        # Activation resolved by name (e.g. the HF config's
        # ``projector_hidden_act`` string).
        self.act = get_act_fn(projector_hidden_act)
        # text_hidden_size -> text_hidden_size
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project features whose last dim is ``vision_hidden_size``."""
        projected = self.linear_1(image_features)
        return self.linear_2(self.act(projected))


56
57
58
class LlavaImagePixelInputs(TypedDict):
    """Image input given as raw pixel values for the vision tower."""

    type: Literal["pixel_values"]
    # Shape: `(batch_size, num_channels, height, width)`
    data: torch.Tensor


# Alias for the accepted image input type(s); currently pixel values only.
LlavaImageInputs = LlavaImagePixelInputs
63
64


65
66
67
68
69
70
71
72
73
74
75
def get_max_llava_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens for one LLaVA image.

    The count is delegated to the CLIP helper; only CLIP vision towers are
    supported.

    Raises:
        NotImplementedError: If ``vision_config`` is not a
            ``CLIPVisionConfig``.
    """
    vision_cfg = ctx.get_hf_config(LlavaConfig).vision_config

    if not isinstance(vision_cfg, CLIPVisionConfig):
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_cfg)}")

    return get_max_clip_image_tokens(vision_cfg)


76
77
78
79
80
81
82
83
84
85
86
def dummy_data_for_llava(ctx: InputContext, seq_len: int):
    """Build the dummy (sequence, multimodal) data registered for LLaVA.

    Args:
        ctx: Input context used to fetch the HF ``LlavaConfig``.
        seq_len: Length of the dummy token sequence to build.

    Returns:
        A ``(seq_data, mm_data)`` tuple from the CLIP dummy-data helpers,
        with image placeholders using ``image_token_index``.

    Raises:
        NotImplementedError: If ``vision_config`` is not a
            ``CLIPVisionConfig``.
    """
    hf_cfg = ctx.get_hf_config(LlavaConfig)
    vision_cfg = hf_cfg.vision_config

    if not isinstance(vision_cfg, CLIPVisionConfig):
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_cfg)}")

    seq_data = dummy_seq_data_for_clip(
        vision_cfg,
        seq_len,
        image_token_id=hf_cfg.image_token_index,
    )
    mm_data = dummy_image_for_clip(vision_cfg)
    return seq_data, mm_data


94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
    """Expand image placeholders in ``llm_inputs`` for LLaVA prompts.

    Prompts without ``"image"`` multimodal data are returned unchanged;
    otherwise the work is delegated to the CLIP input processor using the
    model's ``image_token_index``.

    Raises:
        NotImplementedError: If ``vision_config`` is not a
            ``CLIPVisionConfig``.
    """
    mm_data = llm_inputs.get("multi_modal_data")
    if mm_data is None or "image" not in mm_data:
        # Text-only prompt: nothing to expand.
        return llm_inputs

    model_cfg = ctx.model_config
    hf_cfg = ctx.get_hf_config(LlavaConfig)
    vision_cfg = hf_cfg.vision_config

    if not isinstance(vision_cfg, CLIPVisionConfig):
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_cfg)}")

    return input_processor_for_clip(
        model_cfg,
        vision_cfg,
        llm_inputs,
        image_token_id=hf_cfg.image_token_index,
    )


115
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
    """LLaVA: CLIP vision tower + MLP projector + Llama language model.

    Features from ``vision_tower`` are mapped into the text embedding space by
    ``multi_modal_projector`` and merged into the token embeddings at the
    ``config.image_token_index`` positions before the language model runs.
    """

    def __init__(self,
                 config: LlavaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = CLIPVisionModel(config.vision_config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.quant_config = quant_config
        self.language_model = LlamaModel(config.text_config, cache_config,
                                         quant_config)

        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size,
            quant_config=quant_config)
        logit_scale = getattr(config, "logit_scale", 1.0)
        # Take the vocab size from ``text_config``, consistent with
        # ``unpadded_vocab_size`` / ``lm_head`` above. The top-level
        # ``LlavaConfig.vocab_size`` is deprecated (and later removed) in
        # recent transformers releases.
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.text_config.vocab_size,
                                                logit_scale)
        self.sampler = Sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every dim after the first is ``(3, size, size)``.

        ``size`` is ``config.vision_config.image_size``. The leading (batch)
        dimension is not checked.

        Returns:
            ``data`` unchanged.

        Raises:
            ValueError: If the trailing dimensions do not match.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract ``pixel_values`` from ``kwargs`` as a typed input.

        Returns:
            ``None`` when no ``pixel_values`` were supplied, otherwise a
            validated :class:`LlavaImagePixelInputs`.

        Raises:
            ValueError: If ``pixel_values`` is not a ``torch.Tensor`` or has
                the wrong shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return LlavaImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply ``vision_feature_select_strategy`` to encoder outputs.

        ``"default"`` drops the first token along dim 1 (for CLIP presumably
        the class token); ``"full"`` keeps every token.

        Raises:
            ValueError: For any other strategy name.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        """Run ``vision_tower`` and apply the feature-selection strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values,
                                      self.config.vision_feature_layer)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Encode the pixel tensor in ``inputs["data"]`` into features."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Encode and project images into the text embedding space."""
        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor expands
        the image token into additional image tokens (denoted as `32000`),
        resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        Returns:
            Hidden states from the language model; logits are produced
            separately by :meth:`compute_logits`.

        See also:
            :class:`LlavaImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            # Presumably replaces the embeddings at `image_token_index`
            # positions with the projected vision features.
            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            # The language model consumes `inputs_embeds` instead.
            input_ids = None
        else:
            inputs_embeds = None

        # NOTE(review): `intermediate_tensors` is accepted for interface
        # compatibility but not forwarded; the language model always receives
        # None here (presumably pipeline parallelism is not wired up for this
        # model yet — confirm before relying on it).
        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            None,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the model's sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint ``(name, tensor)`` pairs into this module.

        Names are first rewritten with ``_KEYS_TO_MODIFY_MAPPING``. For names
        without ``"vision"``, q/k/v and gate/up projections are loaded as
        shards of the fused ``qkv_proj`` / ``gate_up_proj`` parameters.
        Everything else uses the parameter's ``weight_loader`` (falling back
        to ``default_weight_loader``).
        """
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # Checkpoint entries with no counterpart parameter here.
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # No fused-shard match: load the tensor as-is.
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)