paligemma.py 14.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Iterable, Mapping, Sequence
4
from typing import Annotated, Literal, Optional, Union
Roger Wang's avatar
Roger Wang committed
5
6
7

import torch
from torch import nn
8
from transformers import BatchFeature, PaliGemmaConfig
Roger Wang's avatar
Roger Wang committed
9

10
from vllm.config import VllmConfig
11
from vllm.config.multimodal import BaseDummyOptions
Roger Wang's avatar
Roger Wang committed
12
13
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
14
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
15
16
                                    MultiModalInputs, MultiModalKwargsItems,
                                    MultiModalUUIDDict)
17
18
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   MultiModalDataItems)
19
20
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptIndexTargets,
21
                                        PromptInsertion, PromptUpdate,
22
                                        PromptUpdateDetails)
23
from vllm.multimodal.profiling import BaseDummyInputsBuilder
24
from vllm.sequence import IntermediateTensors
25
from vllm.utils.tensor_schema import TensorSchema, TensorShape
Roger Wang's avatar
Roger Wang committed
26

27
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
28
from .siglip import SiglipVisionModel
29
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
30
                    init_vllm_registered_model, maybe_prefix)
31
from .vision import get_vision_encoder_info
Roger Wang's avatar
Roger Wang committed
32
33
34
35

logger = init_logger(__name__)


36
37
38
39
40
41
42
43
44
45
class PaliGemmaImagePixelInputs(TensorSchema):
    """
    Image input supplied as raw pixel values.

    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """
    # Tag that distinguishes this schema from the embedding variant.
    type: Literal["pixel_values"] = "pixel_values"
    # Pixel tensor; the channel axis is fixed to 3 by the declared shape.
    data: Annotated[
        torch.Tensor,
        TensorShape("bn", 3, "h", "w"),
    ]
46
47


48
class PaliGemmaImageEmbeddingInputs(TensorSchema):
    """
    Image input supplied as already-computed embeddings.

    Dimensions:
        - bn: Batch size * number of images
        - ifs: Image feature size
        - hs: Hidden size (must match language model backbone)
    """
    # Tag that distinguishes this schema from the pixel-value variant.
    type: Literal["image_embeds"] = "image_embeds"
    # Embedding tensor laid out as (bn, ifs, hs) per the declared shape.
    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
57
58
59
60
61
62


# An image input is either raw pixel values or precomputed embeddings;
# consumers branch on the "type" tag carried by each schema.
PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
                             PaliGemmaImageEmbeddingInputs]


Roger Wang's avatar
Roger Wang committed
63
64
65
66
67
class PaliGemmaMultiModalProjector(nn.Module):
    """Single biased linear layer applied to vision-tower features.

    Maps the last dimension from `vision_hidden_size` to `projection_dim`.
    """

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        """
        Args:
            vision_hidden_size: Size of the last dimension of the incoming
                image features.
            projection_dim: Size of the last dimension of the output.
        """
        super().__init__()

        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Apply the linear projection to `image_features` and return it."""
        hidden_states = self.linear(image_features)
        return hidden_states


75
76
77
78
79
class PaliGemmaProcessingInfo(BaseProcessingInfo):
    """Processing metadata for PaliGemma: config access, modality limits
    and image token counts."""

    def get_hf_config(self) -> PaliGemmaConfig:
        """Return the HF config, requested from the processing context as a
        `PaliGemmaConfig`."""
        return self.ctx.get_hf_config(PaliGemmaConfig)

    def get_vision_encoder_info(self):
        """Return the vision-encoder info object derived from the HF config."""
        return get_vision_encoder_info(self.get_hf_config())

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality item limit: one image per prompt."""
        return {"image": 1}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for an image of the given size.

        The computation is delegated to the vision-encoder info object.
        """
        vision_encoder_info = self.get_vision_encoder_info()

        return vision_encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
98
99
100
101
102


class PaliGemmaDummyInputsBuilder(
        BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
    """Builds placeholder text and image inputs for PaliGemma."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the dummy prompt text, which is always empty."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return dummy image data sized from the vision config.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Item count per modality; only "image" is read and a
                missing entry counts as 0.
            mm_options: Optional per-modality dummy options; the "image"
                entry, if any, is forwarded as `overrides`.

        Returns:
            A dict with a single "image" key holding square dummy images
            whose side is `vision_config.image_size`.
        """
        hf_config = self.info.get_hf_config()
        vision_config = hf_config.vision_config
        max_image_size = vision_config.image_size

        num_images = mm_counts.get("image", 0)

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image":
            self._get_dummy_images(width=max_image_size,
                                   height=max_image_size,
                                   num_images=num_images,
                                   overrides=image_overrides)
        }


class PaliGemmaMultiModalProcessor(
        BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
    """Multimodal processor for PaliGemma prompts.

    Inserts the image placeholder tokens at the start of the prompt and makes
    sure the final prompt ends with a newline token.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or only the tokenizer for text-only input.

        When `mm_data` is empty the prompt is tokenized directly, without
        special tokens, and wrapped in a `BatchFeature`; otherwise the call is
        delegated to the base class.
        """
        tokenizer = self.info.get_tokenizer()
        if not mm_data:
            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        return super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `pixel_values` as a batched field of the image modality."""
        return dict(pixel_values=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the prompt insertion that adds image tokens plus a BOS.

        The inserted sequence is `<image> * n + <bos>`, where `n` is the
        feature size of the item (embeddings) or the token count computed
        from the image size (pixels).
        """
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        tokenizer = self.info.get_tokenizer()

        bos_token_id = tokenizer.bos_token_id
        assert isinstance(bos_token_id, int)

        def get_insertion(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                )

            image_tokens = [image_token_id] * num_image_tokens

            # Only the image tokens are selected as embedding positions;
            # the trailing BOS is part of the insertion but not selected.
            return PromptUpdateDetails.select_token_id(
                image_tokens + [bos_token_id],
                embed_token_id=image_token_id,
            )

        # Paligemma 1 and 2 have different tokenizer.add_bos_token
        # Insert <image>*n + <bos> after <bos> for Paligemma 1
        # Insert <image>*n + <bos> for Paligemma 2
        return [
            PromptInsertion(
                modality="image",
                target=PromptIndexTargets.prefix(
                    [bos_token_id] if tokenizer.add_bos_token else []),
                insertion=get_insertion,
            )
        ]

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """Run the base processing, then ensure the prompt ends in a newline.

        If the resulting token list is non-empty and does not already end
        with the newline token, that token is appended.
        """
        mm_inputs = super().apply(prompt,
                                  mm_data,
                                  hf_processor_mm_kwargs,
                                  tokenization_kwargs,
                                  mm_uuids=mm_uuids)
        prompt_token_ids = mm_inputs["prompt_token_ids"]

        tokenizer = self.info.get_tokenizer()
        newline_prompt = "\n"
        newline_token_id = tokenizer.encode(newline_prompt)[-1]  # 108
        # Force to add newline at the end of prompt for paligemma's format
        # This step can NOT be replaced by current PromptUpdate methods
        if prompt_token_ids and prompt_token_ids[-1] != newline_token_id:
            prompt_token_ids.append(newline_token_id)
            mm_inputs["prompt_token_ids"] = prompt_token_ids

        return mm_inputs


@MULTIMODAL_REGISTRY.register_processor(
    PaliGemmaMultiModalProcessor,
    info=PaliGemmaProcessingInfo,
    dummy_inputs=PaliGemmaDummyInputsBuilder)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """PaliGemma: a SigLIP vision tower, a linear projector and a Gemma
    (or Gemma2) language model."""

    # Fused module name -> names of the separate modules it packs.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the placeholder string for a modality item.

        Returns `None` for image modalities.

        Raises:
            ValueError: If `modality` does not start with "image".
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower, projector and language model.

        Args:
            vllm_config: Supplies the HF config, the quantization config and
                the multimodal config.
            prefix: Prefix prepended to submodule names.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_tower = SiglipVisionModel(config.vision_config,
                                              quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "vision_tower"))
        self.multi_modal_projector = PaliGemmaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            projection_dim=config.vision_config.projection_dim)

        self.quant_config = quant_config

        # Pick the registered language model from the text config's type:
        # "gemma" selects Gemma, anything else selects Gemma2.
        if config.text_config.model_type == "gemma":
            config.text_config.architectures = ["GemmaForCausalLM"]
        else:
            config.text_config.architectures = ["Gemma2ForCausalLM"]
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        # Fold the optional config-level logit scale (default 1.0) into the
        # language model's logits processor.
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.language_model.logits_processor.scale *= logit_scale

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
        """Build an image-input schema from `pixel_values` / `image_embeds`.

        Returns `None` when neither keyword is present. When both are given,
        `pixel_values` takes precedence.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            pixel_values = flatten_bn(pixel_values, concat=True)

            # Bind the symbolic height/width to the configured image size.
            h = w = self.config.vision_config.image_size
            return PaliGemmaImagePixelInputs(type="pixel_values",
                                             data=pixel_values,
                                             resolve_bindings={
                                                 "h": h,
                                                 "w": w
                                             })

        if image_embeds is not None:
            image_embeds = flatten_bn(image_embeds, concat=True)

            return PaliGemmaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run `vision_tower` on `pixel_values` cast to the dtype of the
        tower's input-embedding weight."""

        target_dtype = vision_tower.get_input_embeddings().weight.dtype
        image_features = vision_tower(pixel_values.to(dtype=target_dtype))

        return image_features

    def _process_image_input(
        self,
        image_input: PaliGemmaImageInputs,
    ) -> torch.Tensor:
        """Turn a parsed image input into projected image embeddings.

        Precomputed embeddings are returned unchanged; pixel values go
        through the vision tower and then the projector.
        """

        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        pixel_values = image_input["data"]
        image_features = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )

        return self.multi_modal_projector(image_features)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model module."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return scaled image embeddings, or `[]` if there is no image input.

        The embeddings are multiplied by `config.hidden_size ** -0.5`.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
        vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5)
        return vision_embeddings

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object) -> IntermediateTensors:
        """Run the language model backbone and return its output.

        When `intermediate_tensors` is provided, `inputs_embeds` is
        discarded before the call. Extra keyword arguments are ignored.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights via `AutoWeightsLoader`, remapping checkpoint names
        with `hf_to_vllm_mapper`, and return the loader's result."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)