paligemma.py 12.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
4
                    TypedDict, Union)
Roger Wang's avatar
Roger Wang committed
5
6
7

import torch
from torch import nn
8
from transformers import PaliGemmaConfig
Roger Wang's avatar
Roger Wang committed
9
10

from vllm.attention import AttentionMetadata
11
from vllm.config import VllmConfig
12
13
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
Roger Wang's avatar
Roger Wang committed
14
from vllm.logger import init_logger
15
from vllm.model_executor.layers.sampler import SamplerOutput
Roger Wang's avatar
Roger Wang committed
16
17
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
18
from vllm.multimodal.inputs import NestedTensors
19
from vllm.multimodal.utils import cached_get_tokenizer
20
from vllm.sequence import IntermediateTensors
Roger Wang's avatar
Roger Wang committed
21

22
from .interfaces import SupportsMultiModal, SupportsPP
23
24
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
25
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
26
                    maybe_prefix, merge_multimodal_embeddings)
Roger Wang's avatar
Roger Wang committed
27
28
29
30

# Module-level logger; input_processor_for_paligemma uses it to warn when the
# user prompt already contains the image token.
logger = init_logger(__name__)


class PaliGemmaImagePixelInputs(TypedDict):
    """Image input given as raw pixel values (encoded by the vision tower)."""
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


class PaliGemmaImageEmbeddingInputs(TypedDict):
    """Image input given as precomputed embeddings (vision tower skipped)."""
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


# An image input is either raw pixels or precomputed embeddings; consumers
# tell the two apart through the "type" key.
PaliGemmaImageInputs = Union[
    PaliGemmaImagePixelInputs,
    PaliGemmaImageEmbeddingInputs,
]


def get_max_paligemma_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens for this model.

    Reads the vision sub-config from the model's ``PaliGemmaConfig`` and
    delegates the computation to ``get_max_siglip_image_tokens``.
    """
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config

    return get_max_siglip_image_tokens(vision_config)


def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
                             mm_counts: Mapping[str, int]):
    """Build dummy sequence + image data for PaliGemma.

    Args:
        ctx: Input context used to fetch the ``PaliGemmaConfig``.
        seq_len: Length of the dummy token sequence.
        mm_counts: Number of items per modality; must contain ``"image"``.

    Returns:
        ``DummyData`` holding the SigLIP dummy sequence (with image
        placeholder tokens), the dummy images and the placeholder ranges.
    """
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    seq_data, ranges = dummy_seq_data_for_siglip(
        vision_config,
        seq_len,
        num_images,
        image_token_id=hf_config.image_token_index,
    )

    mm_data = dummy_image_for_siglip(vision_config, num_images)
    return DummyData(seq_data, mm_data, ranges)


def input_processor_for_paligemma(ctx: InputContext,
                                  inputs: DecoderOnlyInputs):
    """
    The correct prompt format needs to be:
    '<image>' * image_feature_size + '<bos>' + prompt + '\n'

    See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55

    Inputs without image data are returned unchanged. Any image tokens the
    user already placed in the prompt (text or token ids) are stripped with
    a warning before the prompt is re-padded. The caller's token-id list is
    never mutated.
    """ # noqa
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PaliGemmaConfig)

    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    image_feature_size = hf_config.text_config.num_image_tokens
    image_token_id = hf_config.image_token_index
    image_token_str = tokenizer.decode(image_token_id)
    bos_token = tokenizer.decode(hf_config.bos_token_id)
    image_token_str_pad = image_token_str * image_feature_size
    image_token_ids_pad = [image_token_id] * image_feature_size
    # NOTE(review): 108 is assumed to be the Gemma newline ("\n") token id,
    # matching the trailing '\n' appended to the text prompt below.
    newline_token_id = 108

    orig_prompt = inputs.get("prompt")
    # Copy so that the filtering below never mutates the caller's list.
    orig_prompt_ids = list(inputs.get("prompt_token_ids"))

    prompt_has_image_token = (orig_prompt is not None
                              and image_token_str in orig_prompt)
    if prompt_has_image_token or image_token_id in orig_prompt_ids:
        logger.warning(
            "The image token '%s' was detected in the prompt and "
            "will be removed. Please follow the proper prompt format"
            " documented on HuggingFace.", image_token_str)
        if orig_prompt is not None:
            orig_prompt = orig_prompt.replace(image_token_str, "")
        # Drop every occurrence (list.remove only drops the first) so the
        # token ids stay consistent with the cleaned prompt text.
        orig_prompt_ids = [
            tok for tok in orig_prompt_ids if tok != image_token_id
        ]

    # When only token ids were supplied, keep the prompt text unset instead
    # of rendering the literal string "None".
    new_prompt = (None if orig_prompt is None else
                  f"{image_token_str_pad}{bos_token}{orig_prompt}\n")

    # The PaliGemma 2 tokenizer does not include a starting BOS token
    if not orig_prompt_ids or orig_prompt_ids[0] != hf_config.bos_token_id:
        orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids

    new_token_ids = image_token_ids_pad + orig_prompt_ids + [newline_token_id]

    # NOTE: Create a defensive copy of the original inputs
    return token_inputs(prompt_token_ids=new_token_ids,
                        prompt=new_prompt,
                        multi_modal_data=multi_modal_data)


class PaliGemmaMultiModalProjector(nn.Module):
    """Single linear layer (with bias) projecting vision features.

    Args:
        vision_hidden_size: Feature size produced by the vision tower.
        projection_dim: Output feature size of the projection.
    """

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        super().__init__()

        # NOTE(review): the attribute name ``linear`` presumably determines
        # the checkpoint parameter names loaded via AutoWeightsLoader; keep
        # it stable.
        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project the last dimension of ``image_features``."""
        hidden_states = self.linear(image_features)
        return hidden_states


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """PaliGemma: SigLIP vision tower + linear projector + Gemma/Gemma 2 LM.

    Image features are projected, scaled by ``hidden_size ** -0.5`` and merged
    into the language model's input embeddings at image-token positions.
    """
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_tower = SiglipVisionModel(config.vision_config,
                                              quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "vision_tower"))
        self.multi_modal_projector = PaliGemmaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            projection_dim=config.vision_config.projection_dim)

        self.quant_config = quant_config

        # Pick the text backbone from the text config's model type:
        # "gemma" -> Gemma, anything else -> Gemma 2.
        if config.text_config.model_type == "gemma":
            config.text_config.architectures = ["GemmaForCausalLM"]
        else:
            config.text_config.architectures = ["Gemma2ForCausalLM"]
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.language_model.logits_processor.scale *= logit_scale

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @property
    def sampler(self):
        """Sampler of the wrapped language model."""
        return self.language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check ``data`` is ``(batch_size, 3, image_size, image_size)``.

        Raises:
            ValueError: if the non-batch dimensions do not match.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
        """Pop ``pixel_values`` / ``image_embeds`` from kwargs and wrap them.

        Returns None when neither is present. ``pixel_values`` takes
        precedence if both are given.

        Raises:
            ValueError: if the provided value is not a ``torch.Tensor`` or
                pixel values have the wrong shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Remove the N dimension until multiple images are supported.
            pixel_values = pixel_values.squeeze(1)

            return PaliGemmaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            # Remove the N dimension until multiple images are supported.
            image_embeds = image_embeds.squeeze(1)

            return PaliGemmaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower, casting pixels to its embedding dtype."""
        target_dtype = vision_tower.get_input_embeddings().weight.dtype
        image_features = vision_tower(pixel_values.to(dtype=target_dtype))

        return image_features

    def _process_image_input(
        self,
        image_input: PaliGemmaImageInputs,
    ) -> torch.Tensor:
        """Return projected image features.

        Precomputed embeddings are returned as-is. Pixel values go through
        the vision tower and then the projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        pixel_values = image_input["data"]
        image_features = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )

        return self.multi_modal_projector(image_features)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Compute scaled vision embeddings, or None if there is no image."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
        vision_embeddings = vision_embeddings * (self.config.hidden_size**-0.5)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and overwrite image-token positions."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_index)
        return inputs_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]:
        """Run the language-model backbone.

        If ``intermediate_tensors`` is given (presumably a non-first
        pipeline stage), ``inputs_embeds`` is discarded. Otherwise, when no
        ``inputs_embeds`` is supplied, they are built here from
        ``input_ids`` and any image kwargs.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights; returns the set of loaded names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)