paligemma.py 13.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Iterable, Mapping, Sequence
4
from typing import Annotated, Literal, Optional, Union
Roger Wang's avatar
Roger Wang committed
5
6
7

import torch
from torch import nn
8
from transformers import BatchFeature, PaliGemmaConfig
Roger Wang's avatar
Roger Wang committed
9

10
from vllm.config import VllmConfig
11
from vllm.config.multimodal import BaseDummyOptions
Roger Wang's avatar
Roger Wang committed
12
13
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalUUIDDict,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptIndexTargets,
    PromptInsertion,
    PromptUpdate,
    PromptUpdateDetails,
)
34
from vllm.multimodal.profiling import BaseDummyInputsBuilder
35
from vllm.sequence import IntermediateTensors
36
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Roger Wang's avatar
Roger Wang committed
37

38
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
39
from .siglip import SiglipVisionModel
40
41
42
43
44
45
46
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
47
from .vision import get_vision_encoder_info
Roger Wang's avatar
Roger Wang committed
48
49
50
51

# Module-level logger (not referenced by the code visible in this file).
logger = init_logger(__name__)


52
53
54
55
56
57
58
59
class PaliGemmaImagePixelInputs(TensorSchema):
    """
    Image input given as raw pixel tensors.

    Dimensions:
        - bn: Batch size * number of images (images flattened across batch)
        - c: Number of channels (always 3)
        - h: Height in pixels
        - w: Width in pixels
    """

    # Tag that distinguishes this variant from the image-embedding variant.
    type: Literal["pixel_values"] = "pixel_values"
    # Declared layout (bn, 3, h, w); presumably enforced by TensorSchema.
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
63
64


65
class PaliGemmaImageEmbeddingInputs(TensorSchema):
    """
    Image input given as precomputed feature embeddings.

    Dimensions:
        - bn: Batch size * number of images (images flattened across batch)
        - ifs: Number of feature vectors per image
        - hs: Hidden size (must match the language model backbone)
    """

    # Tag that distinguishes this variant from the pixel-value variant.
    type: Literal["image_embeds"] = "image_embeds"
    # Declared layout (bn, ifs, hs); presumably enforced by TensorSchema.
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
75
76


77
# Either raw pixels or precomputed image embeddings; the ``type`` field of
# each variant tells them apart.
PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs, PaliGemmaImageEmbeddingInputs]
78
79


Roger Wang's avatar
Roger Wang committed
80
81
82
83
class PaliGemmaMultiModalProjector(nn.Module):
    """Single linear layer mapping vision-tower features to the projection width.

    Args:
        vision_hidden_size: Width of the incoming vision features.
        projection_dim: Width of the projected output features.
    """

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        super().__init__()
        # Keep the attribute name ``linear``: parameters are presumably
        # matched to checkpoint weights by this name.
        self.linear = nn.Linear(
            in_features=vision_hidden_size,
            out_features=projection_dim,
            bias=True,
        )

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Apply the affine projection along the last dimension."""
        return self.linear(image_features)


91
92
93
94
class PaliGemmaProcessingInfo(BaseProcessingInfo):
    """Config and size metadata consumed by the PaliGemma multimodal processor."""

    def get_hf_config(self):
        """Return the model's HF config, requested as ``PaliGemmaConfig``."""
        return self.ctx.get_hf_config(PaliGemmaConfig)

    def get_vision_encoder_info(self):
        """Return the vision-encoder helper built from the HF config."""
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Declare a limit of one item for the ``image`` modality."""
        return dict(image=1)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image placeholder tokens for one image.

        The count is delegated to the vision-encoder helper.
        """
        encoder_info = self.get_vision_encoder_info()
        token_count = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return token_count
113
114


115
class PaliGemmaDummyInputsBuilder(BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
    """Builds placeholder text and images used for profiling PaliGemma."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return an empty prompt; no textual image placeholder is needed."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` (default 0) dummy square images.

        Each image is ``vision_config.image_size`` pixels on a side, and any
        ``"image"`` entry in ``mm_options`` is forwarded as overrides.
        """
        side = self.info.get_hf_config().vision_config.image_size
        image_count = mm_counts.get("image", 0)
        overrides = None if not mm_options else mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=side,
            height=side,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


143
class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
    """Multimodal processor for PaliGemma.

    Image tokens are inserted at the start of the prompt, and the final
    token sequence is forced to end with a newline token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or tokenize directly for text-only input.

        Without multimodal data the prompt is encoded without special tokens
        and returned as a one-row ``input_ids`` batch of PyTorch tensors.
        """
        tokenizer = self.info.get_tokenizer()
        if mm_data:
            return super()._call_hf_processor(
                prompt=prompt,
                mm_data=mm_data,
                mm_kwargs=mm_kwargs,
                tok_kwargs=tok_kwargs,
            )

        token_ids = tokenizer.encode(prompt, add_special_tokens=False)
        return BatchFeature({"input_ids": [token_ids]}, tensor_type="pt")

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """``pixel_values`` is batched along the ``image`` modality."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return one prefix insertion that places ``<image>`` * n + ``<bos>``."""
        image_token_id = self.info.get_hf_config().image_token_index

        tokenizer = self.info.get_tokenizer()
        bos_token_id = tokenizer.bos_token_id
        assert isinstance(bos_token_id, int)

        def build_image_insertion(item_idx: int):
            """Token run for image ``item_idx``: n image tokens then ``<bos>``.

            Only positions holding ``image_token_id`` receive embeddings.
            """
            # Looked up lazily so text-only requests never touch image items.
            image_items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(image_items, ImageEmbeddingItems):
                n_tokens = image_items.get_feature_size(item_idx)
            else:
                size = image_items.get_image_size(item_idx)
                n_tokens = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                )

            return PromptUpdateDetails.select_token_id(
                [image_token_id] * n_tokens + [bos_token_id],
                embed_token_id=image_token_id,
            )

        # PaliGemma 1 and 2 tokenizers differ in ``add_bos_token``:
        # - when it is set (PaliGemma 1), insert after the leading <bos>;
        # - otherwise (PaliGemma 2), insert at the very start of the prompt.
        prefix_tokens = [bos_token_id] if tokenizer.add_bos_token else []
        return [
            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.prefix(prefix_tokens),
                insertion=build_image_insertion,
            )
        ]

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """Run base processing, then ensure the prompt ends with a newline.

        PaliGemma's prompt format requires a trailing newline token. This
        cannot be expressed with the current ``PromptUpdate`` mechanisms, so
        the token is appended after processing when the non-empty prompt does
        not already end with it.
        """
        mm_inputs = super().apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        token_ids = mm_inputs["prompt_token_ids"]
        tokenizer = self.info.get_tokenizer()
        # Take the last id in case special tokens are prepended (108 here).
        newline_id = tokenizer.encode("\n")[-1]

        needs_newline = len(token_ids) > 0 and token_ids[-1] != newline_id
        if needs_newline:
            token_ids.append(newline_id)
            mm_inputs["prompt_token_ids"] = token_ids

        return mm_inputs


@MULTIMODAL_REGISTRY.register_processor(
    PaliGemmaMultiModalProcessor,
    info=PaliGemmaProcessingInfo,
    dummy_inputs=PaliGemmaDummyInputsBuilder,
)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """PaliGemma 1/2: SigLIP vision tower + linear projector + Gemma/Gemma2 LM."""

    # Fused vLLM parameters -> the separate checkpoint projections they pack.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return no text placeholder for images; raise for other modalities.

        Image tokens are inserted by the multimodal processor instead.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_tower = SiglipVisionModel(
            config.vision_config,
            quant_config,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.multi_modal_projector = PaliGemmaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            projection_dim=config.vision_config.projection_dim,
        )

        self.quant_config = quant_config

        # PaliGemma 1 uses a Gemma backbone, PaliGemma 2 a Gemma2 backbone.
        if config.text_config.model_type == "gemma":
            config.text_config.architectures = ["GemmaForCausalLM"]
        else:
            config.text_config.architectures = ["Gemma2ForCausalLM"]
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.language_model.logits_processor.scale *= logit_scale

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[PaliGemmaImageInputs]:
        """Build a typed image input from ``pixel_values`` or ``image_embeds``.

        Returns ``None`` when neither is present. ``pixel_values`` takes
        precedence when both are given.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = flatten_bn(pixel_values, concat=True)

            # Pixel inputs must be at the vision tower's square resolution.
            h = w = self.config.vision_config.image_size
            return PaliGemmaImagePixelInputs(
                type="pixel_values",
                data=pixel_values,
                resolve_bindings={"h": h, "w": w},
            )

        if image_embeds is not None:
            image_embeds = flatten_bn(image_embeds, concat=True)

            return PaliGemmaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower after casting pixels to its weight dtype."""
        target_dtype = vision_tower.get_input_embeddings().weight.dtype
        image_features = vision_tower(pixel_values.to(dtype=target_dtype))

        return image_features

    def _process_image_input(
        self,
        image_input: PaliGemmaImageInputs,
    ) -> torch.Tensor:
        """Return projected image features (embeddings are passed through)."""
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        pixel_values = image_input["data"]
        image_features = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )

        return self.multi_modal_projector(image_features)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings scaled by 1/sqrt(LM hidden size), or ``[]``."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        # Mirrors HF PaliGemma `get_image_features`, which divides by
        # sqrt(text_config.hidden_size); presumably this cancels the Gemma
        # embedding normalizer. Use the language model's width rather than
        # the top-level `config.hidden_size`. For PaliGemma 1 the two
        # coincide; PaliGemma 2 configs can leave the top-level value at its
        # default while the text width differs.
        text_hidden_size = self.config.text_config.hidden_size
        vision_embeddings = vision_embeddings * (text_hidden_size**-0.5)
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language model backbone.

        Returns hidden states, or intermediate tensors on non-final
        pipeline ranks.
        """
        # Later pipeline stages consume intermediate tensors, not embeddings.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming HF prefixes to vLLM module paths."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)