paligemma.py 15 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Iterable, Mapping, Sequence
4
from typing import Annotated, Literal, Optional, Union
Roger Wang's avatar
Roger Wang committed
5
6
7

import torch
from torch import nn
8
from transformers import BatchFeature, PaliGemmaConfig
Roger Wang's avatar
Roger Wang committed
9

10
from vllm.config import VllmConfig
Roger Wang's avatar
Roger Wang committed
11
12
13
from vllm.logger import init_logger
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
14
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
15
16
                                    MultiModalInputs, MultiModalKwargsItems,
                                    MultiModalUUIDDict)
17
18
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   MultiModalDataItems)
19
20
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptIndexTargets,
21
                                        PromptInsertion, PromptUpdate,
22
                                        PromptUpdateDetails)
23
from vllm.multimodal.profiling import BaseDummyInputsBuilder
24
from vllm.sequence import IntermediateTensors
25
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Roger Wang's avatar
Roger Wang committed
26

27
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
28
from .siglip import SiglipVisionModel
29
30
31
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)
32
from .vision import get_vision_encoder_info
Roger Wang's avatar
Roger Wang committed
33
34
35
36

# Module-scoped logger obtained through vLLM's init_logger helper.
# NOTE(review): nothing in this module currently logs through it.
logger = init_logger(__name__)


class PaliGemmaImagePixelInputs(TensorSchema):
    """Raw image pixels that are fed through the SigLIP vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """
    # Discriminator tag separating this variant from image embeddings.
    type: Literal["pixel_values"] = "pixel_values"
    # Pixel tensor; the schema pins the channel axis to exactly 3.
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


class PaliGemmaImageEmbeddingInputs(TensorSchema):
    """Precomputed image embeddings, used instead of tower + projector output.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """
    # Discriminator tag separating this variant from raw pixel values.
    type: Literal["image_embeds"] = "image_embeds"
    # Embedding tensor; bypasses the vision tower and projector entirely.
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]


# Either form of image input understood by the PaliGemma model:
# raw pixel values or precomputed image embeddings.
PaliGemmaImageInputs = Union[
    PaliGemmaImagePixelInputs,
    PaliGemmaImageEmbeddingInputs,
]


class PaliGemmaMultiModalProjector(nn.Module):
    """Single affine layer mapping vision-tower features to ``projection_dim``.

    Args:
        vision_hidden_size: Size of the last dimension of incoming features.
        projection_dim: Size of the last dimension after projection.
    """

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        super().__init__()
        # Bias is requested explicitly (it is also nn.Linear's default).
        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Apply the linear projection along the last dimension."""
        return self.linear(image_features)


class PaliGemmaProcessingInfo(BaseProcessingInfo):
    """Config lookups, multimodal limits and image-token counts for PaliGemma."""

    def get_hf_config(self):
        """Return the HF config from the context, typed as PaliGemmaConfig."""
        return self.ctx.get_hf_config(PaliGemmaConfig)

    def get_vision_encoder_info(self):
        """Return the vision-encoder helper built from the HF config."""
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Only a single image item is allowed.
        return {"image": 1}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return how many image tokens an image of the given size needs.

        The computation is delegated entirely to the vision encoder helper.
        """
        encoder_info = self.get_vision_encoder_info()
        return encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )


class PaliGemmaDummyInputsBuilder(
        BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
    """Builds dummy text and image inputs for PaliGemma profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # No placeholder text is needed; image tokens are inserted by the
        # multimodal processor's prompt updates.
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` square dummy images (0 if absent).

        Each image side equals the vision config's ``image_size``.
        ``seq_len`` is accepted for interface compatibility but not used.
        """
        side = self.info.get_hf_config().vision_config.image_size
        image_count = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(width=side,
                                              height=side,
                                              num_images=image_count)
        return {"image": dummy_images}


class PaliGemmaMultiModalProcessor(
        BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
    """Processor that inserts PaliGemma image tokens and a trailing newline."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or tokenize directly for text-only input.

        Text-only prompts are encoded without special tokens and returned as
        a PyTorch ``BatchFeature`` holding a single ``input_ids`` row.
        """
        tokenizer = self.info.get_tokenizer()
        if mm_data:
            return super()._call_hf_processor(
                prompt=prompt,
                mm_data=mm_data,
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )

        # Text-only request: the HF processor is bypassed.
        token_ids = tokenizer.encode(prompt, add_special_tokens=False)
        return BatchFeature(dict(input_ids=[token_ids]), tensor_type="pt")

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # ``pixel_values`` is batched along the "image" modality.
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Insert ``<image>*n + <bos>`` at the start of the prompt."""
        image_token_id = self.info.get_hf_config().image_token_index

        tokenizer = self.info.get_tokenizer()

        bos_token_id = tokenizer.bos_token_id
        assert isinstance(bos_token_id, int)

        def build_insertion(item_idx: int):
            # Image items are looked up lazily, inside the callback.
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                # Precomputed embeddings carry their own feature size.
                n_tokens = images.get_feature_size(item_idx)
            else:
                size = images.get_image_size(item_idx)
                n_tokens = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )

            # n image tokens followed by BOS; embed_token_id selects the
            # image token id among the inserted tokens.
            return PromptUpdateDetails.select_token_id(
                [image_token_id] * n_tokens + [bos_token_id],
                embed_token_id=image_token_id,
            )

        # PaliGemma 1 and 2 tokenizers differ in ``add_bos_token``:
        #   PaliGemma 1 -> insert <image>*n + <bos> after the leading <bos>
        #   PaliGemma 2 -> insert <image>*n + <bos> at the very start
        prefix_ids = [bos_token_id] if tokenizer.add_bos_token else []
        return [
            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.prefix(prefix_ids),
                insertion=build_insertion,
            )
        ]

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """Run the base processing, then ensure the prompt ends with "\\n".

        PaliGemma's prompt format needs a trailing newline token. This cannot
        currently be expressed as a PromptUpdate, so it is appended here to
        both ``prompt_token_ids`` and ``prompt`` when not already present.
        Empty token lists are left untouched.
        """
        mm_inputs = super().apply(prompt,
                                  mm_data,
                                  hf_processor_mm_kwargs,
                                  tokenization_kwargs,
                                  mm_uuids=mm_uuids)
        token_ids = mm_inputs["prompt_token_ids"]

        newline = "\n"
        # Last id is taken, presumably to skip a BOS that encode() may
        # prepend (the original note says this is id 108) — verify.
        newline_id = self.info.get_tokenizer().encode(newline)[-1]

        if token_ids and token_ids[-1] != newline_id:
            token_ids.append(newline_id)
            mm_inputs["prompt_token_ids"] = token_ids
            mm_inputs["prompt"] += newline

        return mm_inputs


@MULTIMODAL_REGISTRY.register_processor(
    PaliGemmaMultiModalProcessor,
    info=PaliGemmaProcessingInfo,
    dummy_inputs=PaliGemmaDummyInputsBuilder)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """PaliGemma: SigLIP vision tower + linear projector + Gemma/Gemma2 LM.

    Projected image features are scaled by ``config.hidden_size ** -0.5`` and
    merged into the language model's input embeddings at positions holding
    ``config.image_token_index``.
    """

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for ``modality``.

        PaliGemma uses no textual image placeholder, so image modalities
        yield ``None``.

        Raises:
            ValueError: for any non-image modality.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_tower = SiglipVisionModel(config.vision_config,
                                              quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "vision_tower"))
        self.multi_modal_projector = PaliGemmaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            projection_dim=config.vision_config.projection_dim)

        self.quant_config = quant_config

        # Pick the text backbone from the text config's model_type:
        # "gemma" -> Gemma (PaliGemma 1), anything else -> Gemma2.
        if config.text_config.model_type == "gemma":
            config.text_config.architectures = ["GemmaForCausalLM"]
        else:
            config.text_config.architectures = ["Gemma2ForCausalLM"]
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        # Optional extra logit scaling from the top-level config.
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.language_model.logits_processor.scale *= logit_scale

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
        """Build a typed image input from ``pixel_values``/``image_embeds``.

        Returns ``None`` when neither kwarg is given. ``pixel_values`` wins
        when both are present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = flatten_bn(pixel_values, concat=True)

            # Pixel inputs must match the vision tower's square input size.
            h = w = self.config.vision_config.image_size
            return PaliGemmaImagePixelInputs(type="pixel_values",
                                             data=pixel_values,
                                             resolve_bindings={
                                                 "h": h,
                                                 "w": w
                                             })

        if image_embeds is not None:
            image_embeds = flatten_bn(image_embeds, concat=True)

            return PaliGemmaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run ``vision_tower`` on pixels cast to its embedding weight dtype."""
        target_dtype = vision_tower.get_input_embeddings().weight.dtype
        image_features = vision_tower(pixel_values.to(dtype=target_dtype))

        return image_features

    def _process_image_input(
        self,
        image_input: PaliGemmaImageInputs,
    ) -> torch.Tensor:
        """Return image embeddings, computing them from pixels if needed.

        Precomputed embeddings are returned unchanged; pixel values go
        through the vision tower and then the multimodal projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        pixel_values = image_input["data"]
        image_features = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )

        return self.multi_modal_projector(image_features)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return scaled image embeddings, or ``[]`` when there is no image."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
        vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and splice in any non-empty image embeddings."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        has_mm_embeddings = (multimodal_embeddings is not None
                             and len(multimodal_embeddings) != 0)
        if has_mm_embeddings:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_index)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language model backbone.

        Returns whatever the backbone returns. On the last pipeline stage
        that is the hidden-state tensor passed to ``compute_logits``;
        otherwise it is ``IntermediateTensors``.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping HF prefixes to vLLM names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)