"docs/vscode:/vscode.git/clone" did not exist on "7af553ea30031446b4c1c74ad83187f9fd3de4e7"
llava.py 12.9 KB
Newer Older
1
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
2
3

import torch
4
import torch.nn as nn
5
from transformers import CLIPVisionConfig, LlavaConfig
6
7

from vllm.attention import AttentionMetadata
8
from vllm.config import CacheConfig, MultiModalConfig
9
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
10
11
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
12
13
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
14
15
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
16
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
17
from vllm.model_executor.models.clip import CLIPVisionModel
18
19
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
20
from vllm.multimodal import MULTIMODAL_REGISTRY
21
from vllm.sequence import IntermediateTensors, SamplerOutput
22

23
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
24
                   get_max_clip_image_tokens, input_processor_for_clip)
25
from .interfaces import SupportsVision
26
from .utils import merge_vision_embeddings
27

28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}


# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision features into the text hidden space.

    Computes ``linear_2(act(linear_1(x)))``.

    Args:
        vision_hidden_size: Size of the last dimension of the incoming
            image features.
        text_hidden_size: Output size of both linear layers (so it is also
            the intermediate width).
        projector_hidden_act: Activation name, resolved with ``get_act_fn``.
    """

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project ``image_features`` (last dim ``vision_hidden_size``) to
        tensors whose last dim is ``text_hidden_size``."""
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


56
57
58
class LlavaImagePixelInputs(TypedDict):
    """Image input supplied as raw pixel values."""

    type: Literal["pixel_values"]
    # Shape: `(batch_size, num_channels, height, width)`
    data: torch.Tensor


# Pixel values are currently the only supported image input format, so the
# general image-input type is simply an alias for it.
LlavaImageInputs = LlavaImagePixelInputs
63
64


65
66
67
68
69
70
71
72
73
74
75
def get_max_llava_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens for this LLaVA model.

    Only CLIP vision towers are handled; the count comes from
    ``get_max_clip_image_tokens`` applied to the model's vision config.

    Raises:
        NotImplementedError: If the vision config is not a
            ``CLIPVisionConfig``.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config

    if not isinstance(vision_config, CLIPVisionConfig):
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    return get_max_clip_image_tokens(vision_config)


76
77
78
79
80
81
82
83
84
85
86
def dummy_data_for_llava(ctx: InputContext, seq_len: int):
    """Build dummy token data and a dummy image for a LLaVA model.

    Only CLIP vision towers are handled: the token sequence comes from
    ``dummy_seq_data_for_clip`` (using the config's ``image_token_index`` as
    the image token id) and the image from ``dummy_image_for_clip``.

    Args:
        ctx: Input context used to fetch the model's ``LlavaConfig``.
        seq_len: Length passed through to ``dummy_seq_data_for_clip``.

    Returns:
        A ``(seq_data, mm_data)`` tuple.

    Raises:
        NotImplementedError: If the vision config is not a
            ``CLIPVisionConfig``.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            image_token_id=hf_config.image_token_index,
        )
        mm_data = dummy_image_for_clip(vision_config)

        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
    """Run the vision-tower-specific input processor on ``llm_inputs``.

    If no image is attached (``multi_modal_data`` is missing or has no
    ``"image"`` key), ``llm_inputs`` is returned unchanged. Otherwise the
    work is delegated to ``input_processor_for_clip``, using the config's
    ``image_token_index`` as the image token id.

    Raises:
        NotImplementedError: If an image is attached and the vision config
            is not a ``CLIPVisionConfig``.
    """
    mm_data = llm_inputs.get("multi_modal_data")
    has_image = mm_data is not None and "image" in mm_data
    if not has_image:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if not isinstance(vision_config, CLIPVisionConfig):
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    return input_processor_for_clip(
        model_config,
        vision_config,
        llm_inputs,
        image_token_id=hf_config.image_token_index,
    )


115
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
    """LLaVA: a CLIP vision tower and MLP projector feeding a Llama model.

    Image features from ``vision_tower`` are mapped by
    ``multi_modal_projector`` into the text hidden space and merged into the
    token embeddings at positions holding ``config.image_token_index``.
    """

    def __init__(self,
                 config: LlavaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = CLIPVisionModel(config.vision_config)
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)

        self.quant_config = quant_config
        self.language_model = LlamaModel(config.text_config, cache_config,
                                         quant_config)
        self.unpadded_vocab_size = config.text_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.text_config.hidden_size,
            org_num_embeddings=self.language_model.org_vocab_size,
            quant_config=quant_config)
        logit_scale = getattr(config, "logit_scale", 1.0)
        # Read the vocab size from text_config, like unpadded_vocab_size
        # above; the top-level `LlavaConfig.vocab_size` attribute is
        # deprecated in newer transformers releases.
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.text_config.vocab_size,
                                                logit_scale)
        self.sampler = Sampler()

    def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
        """Check that ``data`` is ``(batch_size, 3, image_size, image_size)``.

        Returns:
            ``data`` unchanged.

        Raises:
            ValueError: If the non-batch dimensions do not match the vision
                config's ``image_size`` with 3 channels.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Extract ``pixel_values`` from ``kwargs`` and validate it.

        Returns:
            ``None`` when no ``pixel_values`` were passed, otherwise a
            ``LlavaImagePixelInputs`` wrapping the validated tensor.

        Raises:
            ValueError: If ``pixel_values`` is not a tensor or has the wrong
                shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return LlavaImagePixelInputs(
            type="pixel_values",
            data=self._validate_image_data(pixel_values),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the configured vision-feature selection strategy.

        ``"default"`` drops index 0 along dim 1; ``"full"`` keeps everything.

        Raises:
            ValueError: For any other strategy name.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the vision tower and select features per the config."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values,
                                      self.config.vision_feature_layer)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Turn pixel-value inputs into vision-tower features."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Compute image features and project them into the text space."""
        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.
        Concretely, consider a text prompt:
        "<image>\nUSER: What's the content of the image?\nASSISTANT:".
        Tokenizer outputs:
        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
        The to-be-inserted image has a size of 576 (24 * 24) along the context
        length dimension.
        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
        9047, 13566, 29901].
        There will be 576 `32000` in the `input_ids`.
        (32000 is the token id for `<image>`.)

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            intermediate_tensors: Passed through unchanged to the language
                model.
            pixel_values: The pixels in each input image (given via
                ``kwargs``), expected as
                ``(batch_size, 3, image_size, image_size)``.

        Returns:
            The hidden states returned by the language model.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            # Feed the merged embeddings instead of token ids so the image
            # features are what the language model sees.
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            intermediate_tensors,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits from ``hidden_states`` using ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` checkpoint weights into this module.

        Skips ``rotary_emb.inv_freq`` and the vision ``post_layernorm``.
        Renames per ``_KEYS_TO_MODIFY_MAPPING``. Non-vision q/k/v and
        gate/up projections are loaded as shards of the fused ``qkv_proj`` /
        ``gate_up_proj`` parameters. Everything else (including all vision
        weights) uses the parameter's ``weight_loader`` or
        ``default_weight_loader``.
        """
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # No stacked mapping matched: load as a plain parameter.
                    use_default_weight_loading = True
            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)