paligemma.py 13.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Iterable, Mapping, Sequence
4
from typing import Annotated, Literal, Optional, Union
Roger Wang's avatar
Roger Wang committed
5
6
7

import torch
from torch import nn
8
from transformers import BatchFeature, PaliGemmaConfig
Roger Wang's avatar
Roger Wang committed
9

10
from vllm.config import VllmConfig
Roger Wang's avatar
Roger Wang committed
11
12
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
13
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
14
15
                                    MultiModalInputs, MultiModalKwargsItems,
                                    MultiModalUUIDDict)
16
17
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   MultiModalDataItems)
18
19
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptIndexTargets,
20
                                        PromptInsertion, PromptUpdate,
21
                                        PromptUpdateDetails)
22
from vllm.multimodal.profiling import BaseDummyInputsBuilder
23
from vllm.sequence import IntermediateTensors
24
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Roger Wang's avatar
Roger Wang committed
25

26
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
27
from .siglip import SiglipVisionModel
28
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
29
                    init_vllm_registered_model, maybe_prefix)
30
from .vision import get_vision_encoder_info
Roger Wang's avatar
Roger Wang committed
31
32
33
34

# Module-level logger, created via vLLM's `init_logger` under this module's name.
logger = init_logger(__name__)


35
36
37
38
39
40
41
42
43
44
class PaliGemmaImagePixelInputs(TensorSchema):
    """
    Schema for image inputs supplied as raw pixel values.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """
    # Tag that distinguishes this schema from the "image_embeds" variant.
    type: Literal["pixel_values"] = "pixel_values"
    # Pixel tensor; the channel dimension is fixed to 3 by the shape spec.
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
45
46


47
class PaliGemmaImageEmbeddingInputs(TensorSchema):
    """
    Schema for image inputs supplied as pre-computed embeddings.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """
    # Tag that distinguishes this schema from the "pixel_values" variant.
    type: Literal["image_embeds"] = "image_embeds"
    # Embedding tensor with symbolic dimensions (bn, ifs, hs).
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
56
57
58
59
60
61


# An image input is either raw pixels or pre-computed embeddings; consumers
# tell the two apart through the schema's `type` tag.
PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
                             PaliGemmaImageEmbeddingInputs]


Roger Wang's avatar
Roger Wang committed
62
63
64
65
66
class PaliGemmaMultiModalProjector(nn.Module):
    """Project vision-encoder features to `projection_dim` with one linear layer.

    Args:
        vision_hidden_size: Size of the last dimension of the incoming
            image features.
        projection_dim: Size of the last dimension of the projected output.
    """

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        super().__init__()

        # The submodule name `linear` is part of the parameter names
        # (`linear.weight`, `linear.bias`), so it must stay as is.
        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Return `image_features` passed through the linear projection."""
        return self.linear(image_features)


74
75
76
77
78
class PaliGemmaProcessingInfo(BaseProcessingInfo):
    """Processing metadata for PaliGemma: config access and image token counts."""

    def get_hf_config(self):
        """Return the HF config from the context, requested as `PaliGemmaConfig`."""
        return self.ctx.get_hf_config(PaliGemmaConfig)

    def get_vision_encoder_info(self):
        """Return the vision-encoder info object derived from the HF config."""
        return get_vision_encoder_info(self.get_hf_config())

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality item limit: a single image."""
        return {"image": 1}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for an image of the given size.

        The computation is delegated to the vision-encoder info object.
        """
        vision_encoder_info = self.get_vision_encoder_info()

        return vision_encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
97
98
99
100
101


class PaliGemmaDummyInputsBuilder(
        BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
    """Builds dummy text and image inputs for PaliGemma."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the dummy prompt text, which is always empty."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy square images sized to the vision config's `image_size`.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Per-modality item counts; only "image" is read and
                it defaults to 0 when absent.
        """
        hf_config = self.info.get_hf_config()
        vision_config = hf_config.vision_config
        max_image_size = vision_config.image_size

        num_images = mm_counts.get("image", 0)

        return {
            "image":
            self._get_dummy_images(width=max_image_size,
                                   height=max_image_size,
                                   num_images=num_images)
        }


class PaliGemmaMultiModalProcessor(
        BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
    """Multimodal processor for PaliGemma.

    Inserts the image placeholder tokens at the start of the prompt and
    makes sure the final prompt ends with a newline token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or only tokenize when `mm_data` is empty."""
        tokenizer = self.info.get_tokenizer()
        if not mm_data:
            # No multimodal data: tokenize the text without special tokens
            # instead of calling the HF processor.
            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `pixel_values` as a batched field of the "image" modality."""
        # NOTE(review): only `pixel_values` is declared here, although
        # `_get_prompt_updates` accepts ImageEmbeddingItems and the model
        # parses an `image_embeds` kwarg. Confirm whether `image_embeds`
        # should be declared as well.
        return dict(pixel_values=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the prompt insertion that adds the image placeholder tokens.

        The inserted sequence is `<image> * n + <bos>`, where `n` is the
        feature size of the item (for embeddings) or the token count
        computed from the image size (for pixel inputs).
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        tokenizer = self.info.get_tokenizer()

        bos_token_id = tokenizer.bos_token_id
        assert isinstance(bos_token_id, int)

        def get_insertion(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            image_tokens = [image_token_id] * num_image_tokens

            # Only the image tokens are selected as embedding positions;
            # the trailing <bos> is ordinary text.
            return PromptUpdateDetails.select_token_id(
                image_tokens + [bos_token_id],
                embed_token_id=image_token_id,
            )

        # Paligemma 1 and 2 have different tokenizer.add_bos_token
        # Insert <image>*n + <bos> after <bos> for Paligemma 1
        # Insert <image>*n + <bos> for Paligemma 2
        return [
            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.prefix(
                    [bos_token_id] if tokenizer.add_bos_token else []),
                insertion=get_insertion,
            )
        ]

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """Apply the base processing, then force a trailing newline token.

        If the processed prompt is non-empty and does not already end with
        the newline token, the token is appended to `prompt_token_ids` and
        "\\n" is appended to the prompt text.
        """
        mm_inputs = super().apply(prompt,
                                  mm_data,
                                  hf_processor_mm_kwargs,
                                  tokenization_kwargs,
                                  mm_uuids=mm_uuids)
        prompt_token_ids = mm_inputs["prompt_token_ids"]

        tokenizer = self.info.get_tokenizer()
        newline_prompt = "\n"
        # Take the last id in case the tokenizer prepends special tokens.
        newline_token_id = tokenizer.encode(newline_prompt)[-1]  # 108
        # Force to add newline at the end of prompt for paligemma's format
        # This step can NOT be replaced by current PromptUpdate methods
        if prompt_token_ids and prompt_token_ids[-1] != newline_token_id:
            prompt_token_ids.append(newline_token_id)
            mm_inputs["prompt_token_ids"] = prompt_token_ids
            mm_inputs["prompt"] += newline_prompt

        return mm_inputs


@MULTIMODAL_REGISTRY.register_processor(
    PaliGemmaMultiModalProcessor,
    info=PaliGemmaProcessingInfo,
    dummy_inputs=PaliGemmaDummyInputsBuilder)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """PaliGemma: SigLIP vision tower, linear projector and Gemma language model."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for a modality.

        Returns None for any modality starting with "image".

        Raises:
            ValueError: If the modality does not start with "image".
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower, the projector and the language model."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_tower = SiglipVisionModel(config.vision_config,
                                              quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "vision_tower"))
        self.multi_modal_projector = PaliGemmaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            projection_dim=config.vision_config.projection_dim)

        self.quant_config = quant_config

        # Choose the language-model architecture from the text config's
        # model type: "gemma" maps to Gemma, anything else to Gemma2.
        if config.text_config.model_type == "gemma":
            config.text_config.architectures = ["GemmaForCausalLM"]
        else:
            config.text_config.architectures = ["Gemma2ForCausalLM"]
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        # Fold the optional config-level `logit_scale` (default 1.0) into
        # the language model's logits processor.
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.language_model.logits_processor.scale *= logit_scale

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
        """Build a typed image input from the `pixel_values`/`image_embeds` kwargs.

        Returns None when neither kwarg is present. `pixel_values` takes
        precedence when both are given.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = flatten_bn(pixel_values, concat=True)

            # Bind the schema's symbolic h/w to the configured image size.
            h = w = self.config.vision_config.image_size
            return PaliGemmaImagePixelInputs(type="pixel_values",
                                             data=pixel_values,
                                             resolve_bindings={
                                                 "h": h,
                                                 "w": w
                                             })

        if image_embeds is not None:
            image_embeds = flatten_bn(image_embeds, concat=True)

            return PaliGemmaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower on `pixel_values` and return its output."""
        # Cast the pixels to the dtype of the tower's input-embedding weights.
        target_dtype = vision_tower.get_input_embeddings().weight.dtype
        image_features = vision_tower(pixel_values.to(dtype=target_dtype))

        return image_features

    def _process_image_input(
        self,
        image_input: PaliGemmaImageInputs,
    ) -> torch.Tensor:
        """Turn a parsed image input into embeddings.

        Pre-computed embeddings are returned as-is; pixel inputs go through
        the vision tower and then the multimodal projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        pixel_values = image_input["data"]
        image_features = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )

        return self.multi_modal_projector(image_features)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model module."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return scaled image embeddings, or an empty list without image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
        vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5)
        return vision_embeddings

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the inner language model and return its output.

        `inputs_embeds` is discarded whenever `intermediate_tensors` is
        provided.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights through `AutoWeightsLoader` using `hf_to_vllm_mapper`.

        Returns whatever the loader's `load_weights` returns.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)