paligemma.py 13.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Iterable, Mapping, Sequence
4
from typing import Annotated, Literal, Optional, Union
Roger Wang's avatar
Roger Wang committed
5
6
7

import torch
from torch import nn
8
from transformers import BatchFeature, PaliGemmaConfig
Roger Wang's avatar
Roger Wang committed
9

10
from vllm.config import VllmConfig
Roger Wang's avatar
Roger Wang committed
11
12
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
13
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
14
15
                                    MultiModalInputs, MultiModalKwargsItems,
                                    MultiModalUUIDDict)
16
17
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   MultiModalDataItems)
18
19
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptIndexTargets,
20
                                        PromptInsertion, PromptUpdate,
21
                                        PromptUpdateDetails)
22
from vllm.multimodal.profiling import BaseDummyInputsBuilder
23
from vllm.sequence import IntermediateTensors
24
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Roger Wang's avatar
Roger Wang committed
25

26
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
27
from .siglip import SiglipVisionModel
28
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
29
                    init_vllm_registered_model, maybe_prefix)
30
from .vision import get_vision_encoder_info
Roger Wang's avatar
Roger Wang committed
31
32
33
34

# Module-level logger, named after this module via init_logger(__name__).
logger = init_logger(__name__)


35
36
37
38
39
40
41
42
43
44
class PaliGemmaImagePixelInputs(TensorSchema):
    """
    Raw image pixels that still have to go through the vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """
    # Discriminator used by the model to tell pixel and embedding inputs
    # apart.
    type: Literal["pixel_values"] = "pixel_values"
    # Channel dimension is pinned to 3 (RGB); h / w are bound at
    # construction time.
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
45
46


47
class PaliGemmaImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings that bypass the vision tower.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """
    # Discriminator used by the model to tell pixel and embedding inputs
    # apart.
    type: Literal["image_embeds"] = "image_embeds"
    # One row of `ifs` feature vectors per image.
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
56
57
58
59
60
61


# An image input is either raw pixels or precomputed embeddings; the two
# are distinguished at runtime by their `type` field.
PaliGemmaImageInputs = Union[
    PaliGemmaImagePixelInputs,
    PaliGemmaImageEmbeddingInputs,
]


Roger Wang's avatar
Roger Wang committed
62
63
64
65
66
class PaliGemmaMultiModalProjector(nn.Module):
    """Single affine layer mapping vision-tower features to the width the
    language model consumes (``projection_dim``)."""

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        super().__init__()
        # Keep the attribute name `linear`: checkpoint tensors are matched to
        # module paths by name (see the weight-prefix mapper in this file).
        self.linear = nn.Linear(in_features=vision_hidden_size,
                                out_features=projection_dim,
                                bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project the last dimension of ``image_features`` from
        ``vision_hidden_size`` to ``projection_dim``."""
        return self.linear(image_features)


74
75
76
77
78
class PaliGemmaProcessingInfo(BaseProcessingInfo):
    """Model-specific metadata queried by the PaliGemma processor."""

    def get_hf_config(self):
        """Return the model's HF config, requested as ``PaliGemmaConfig``."""
        return self.ctx.get_hf_config(PaliGemmaConfig)

    def get_vision_encoder_info(self):
        """Return the vision-encoder helper derived from the HF config."""
        hf_config = self.get_hf_config()
        return get_vision_encoder_info(hf_config)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # At most one image is supported.
        limits: dict[str, Optional[int]] = {"image": 1}
        return limits

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return how many image placeholder tokens an image of the given
        size expands to, as reported by the vision-encoder helper."""
        encoder_info = self.get_vision_encoder_info()
        token_count = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        return token_count
97
98
99
100
101


class PaliGemmaDummyInputsBuilder(
        BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
    """Builds placeholder text/images used for profiling PaliGemma."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # No prompt text is needed: image tokens are inserted by the
        # processor's prompt updates.
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` (default 0) square dummy images whose
        side equals ``vision_config.image_size``."""
        side = self.info.get_hf_config().vision_config.image_size
        image_count = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(width=side,
                                              height=side,
                                              num_images=image_count)
        return {"image": dummy_images}


class PaliGemmaMultiModalProcessor(
        BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
    """Multimodal processor for PaliGemma 1 and 2.

    Inserts ``<image>`` placeholder tokens followed by ``<bos>`` at the start
    of the prompt, and makes sure the processed prompt ends with a newline
    token, as PaliGemma's prompt format requires.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Tokenize ``prompt`` and, when images are given, preprocess them.

        Without multimodal data the HF processor is bypassed and the prompt
        is tokenized directly with ``add_special_tokens=False``.
        NOTE(review): presumably because the HF PaliGemma processor expects
        images — confirm.
        """
        tokenizer = self.info.get_tokenizer()
        if not mm_data:
            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare the per-image tensors that are forwarded to the model.

        ``image_embeds`` is declared next to ``pixel_values`` for two reasons.
        ``_get_prompt_updates`` accepts precomputed ``ImageEmbeddingItems``,
        and the model's ``_parse_and_validate_image_input`` reads an
        ``image_embeds`` kwarg. Without this entry, such embeddings would
        presumably not be attached to their image items.
        Assumption, matching other vLLM VLM processors: fields missing from
        ``hf_inputs`` are skipped.
        """
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Insert ``<image>`` * n followed by ``<bos>`` at the prompt start.

        n is the embedding feature size for precomputed embeddings, or the
        vision encoder's token count for raw images.
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        tokenizer = self.info.get_tokenizer()

        bos_token_id = tokenizer.bos_token_id
        assert isinstance(bos_token_id, int)

        def get_insertion(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            image_tokens = [image_token_id] * num_image_tokens

            # Only the <image> tokens receive image embeddings; the trailing
            # <bos> stays a regular text token.
            return PromptUpdateDetails.select_token_id(
                image_tokens + [bos_token_id],
                embed_token_id=image_token_id,
            )

        # Paligemma 1 and 2 have different tokenizer.add_bos_token settings:
        # - Paligemma 1: insert <image>*n + <bos> after the leading <bos>.
        # - Paligemma 2: insert <image>*n + <bos> at the very start.
        return [
            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.prefix(
                    [bos_token_id] if tokenizer.add_bos_token else []),
                insertion=get_insertion,
            )
        ]

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """Run the base processing, then append a trailing newline token to
        non-empty prompts that do not already end with one."""
        mm_inputs = super().apply(prompt,
                                  mm_data,
                                  hf_processor_mm_kwargs,
                                  tokenization_kwargs,
                                  mm_uuids=mm_uuids)
        prompt_token_ids = mm_inputs["prompt_token_ids"]

        tokenizer = self.info.get_tokenizer()
        newline_prompt = "\n"
        # Take the last id so any <bos> the tokenizer prepends is skipped
        # (the original note gives 108 for the Gemma tokenizer).
        newline_token_id = tokenizer.encode(newline_prompt)[-1]
        # PaliGemma's format requires the prompt to end with a newline.
        # This step can NOT be expressed with the current PromptUpdate
        # methods, so it is applied after processing.
        if prompt_token_ids and prompt_token_ids[-1] != newline_token_id:
            prompt_token_ids.append(newline_token_id)
            mm_inputs["prompt_token_ids"] = prompt_token_ids

        return mm_inputs


@MULTIMODAL_REGISTRY.register_processor(
    PaliGemmaMultiModalProcessor,
    info=PaliGemmaProcessingInfo,
    dummy_inputs=PaliGemmaDummyInputsBuilder)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """PaliGemma: SigLIP vision tower + linear projector feeding a Gemma /
    Gemma2 causal language model."""

    # Fused projections in the language model and their per-checkpoint parts.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """No textual image placeholder is used; image tokens are inserted by
        the multimodal processor instead.

        Raises:
            ValueError: for any non-image modality.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_tower = SiglipVisionModel(config.vision_config,
                                              quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "vision_tower"))
        self.multi_modal_projector = PaliGemmaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            projection_dim=config.vision_config.projection_dim)

        self.quant_config = quant_config

        # Pin the text backbone architecture: "gemma" maps to Gemma, every
        # other text model_type is treated as Gemma2.
        if config.text_config.model_type == "gemma":
            config.text_config.architectures = ["GemmaForCausalLM"]
        else:
            config.text_config.architectures = ["Gemma2ForCausalLM"]
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        # Optional extra logit scaling from the top-level config (default 1).
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.language_model.logits_processor.scale *= logit_scale

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
        """Build a typed image input from ``pixel_values`` or
        ``image_embeds`` in ``kwargs``.

        Returns ``None`` when neither is present. ``pixel_values`` takes
        precedence when both are given.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = flatten_bn(pixel_values, concat=True)

            # Bind h / w to the vision tower's configured image size
            # (presumably checked by TensorSchema).
            h = w = self.config.vision_config.image_size
            return PaliGemmaImagePixelInputs(type="pixel_values",
                                             data=pixel_values,
                                             resolve_bindings={
                                                 "h": h,
                                                 "w": w
                                             })

        if image_embeds is not None:
            image_embeds = flatten_bn(image_embeds, concat=True)

            return PaliGemmaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run ``vision_tower`` on ``pixel_values`` after casting them to the
        dtype of the tower's input-embedding weights."""
        target_dtype = vision_tower.get_input_embeddings().weight.dtype
        image_features = vision_tower(pixel_values.to(dtype=target_dtype))

        return image_features

    def _process_image_input(
        self,
        image_input: PaliGemmaImageInputs,
    ) -> torch.Tensor:
        """Turn a parsed image input into projected image features.

        Precomputed embeddings are returned unchanged. Pixels go through the
        vision tower and then the projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        pixel_values = image_input["data"]
        image_features = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )

        return self.multi_modal_projector(image_features)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped text backbone."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings scaled by ``hidden_size ** -0.5``, or an
        empty list when no image input is present."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
        # NOTE(review): the scale uses the top-level config.hidden_size and is
        # also applied to user-supplied image_embeds — confirm both match the
        # intended embedding contract.
        vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5)
        return vision_embeddings

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language-model backbone.

        Returns whatever the backbone returns: hidden states (consumed by
        ``compute_logits``) or ``IntermediateTensors`` for pipeline stages.
        """
        # When resuming from a previous pipeline stage, ignore inputs_embeds.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming HF prefixes via
        ``hf_to_vllm_mapper``; returns the loaded parameter names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)