from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
from torch import nn
from transformers import PaliGemmaConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.gemma import GemmaForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import AutoWeightsLoader, merge_multimodal_embeddings

logger = init_logger(__name__)


class PaliGemmaImagePixelInputs(TypedDict):
    """Image input supplied as raw pixel values (tag: ``"pixel_values"``)."""

    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


class PaliGemmaImageEmbeddingInputs(TypedDict):
    """Image input supplied as embeddings (tag: ``"image_embeds"``)."""

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Tagged union of the two accepted image input forms; consumers branch on the
# "type" key ("pixel_values" or "image_embeds").
PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
                             PaliGemmaImageEmbeddingInputs]


def get_max_paligemma_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens for the configured model.

    Reads the ``PaliGemmaConfig`` from ``ctx`` and delegates to
    ``get_max_siglip_image_tokens`` with its ``vision_config``.
    """
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config

    return get_max_siglip_image_tokens(vision_config)


def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
                             mm_counts: Mapping[str, int]):
    """Build dummy sequence data and dummy image data for PaliGemma.

    Args:
        ctx: Input context used to look up the ``PaliGemmaConfig``.
        seq_len: Sequence length forwarded to ``dummy_seq_data_for_siglip``.
        mm_counts: Per-modality item counts; the ``"image"`` entry is
            required and gives the number of images to generate.

    Returns:
        A ``(seq_data, mm_data)`` tuple holding the results of
        ``dummy_seq_data_for_siglip`` and ``dummy_image_for_siglip``.
    """
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    seq_data = dummy_seq_data_for_siglip(
        vision_config,
        seq_len,
        num_images,
        image_token_id=hf_config.image_token_index,
    )

    mm_data = dummy_image_for_siglip(vision_config, num_images)
    return seq_data, mm_data


def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
    """
    Rewrite an image request into the prompt layout PaliGemma expects.

    The correct prompt format needs to be:
    '<image>' * image_feature_size + '<bos>' + prompt + '\\n'

    Inputs without image data are returned unchanged. Otherwise a new
    ``LLMInputs`` is returned; the caller's prompt and token list are never
    modified. If no text prompt was supplied, the returned prompt is ``None``.

    See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
    """ # noqa

    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    image_token_id = hf_config.image_token_index

    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    image_feature_size = hf_config.text_config.num_image_tokens
    image_token_str = tokenizer.decode(image_token_id)
    bos_token = tokenizer.decode(hf_config.bos_token_id)
    image_token_str_pad = image_token_str * image_feature_size
    image_token_ids_pad = [image_token_id] * image_feature_size

    orig_prompt = llm_inputs.get("prompt")
    orig_prompt_ids = llm_inputs.get("prompt_token_ids")

    if orig_prompt is not None and image_token_str in orig_prompt:
        logger.warning(
            "The image token '%s' was detected in the prompt and "
            "will be removed. Please follow the proper prompt format"
            " documented on HuggingFace.", image_token_str)
        orig_prompt = orig_prompt.replace(image_token_str, "")
        # Build a filtered copy rather than calling list.remove():
        #  * the caller's token list must not be mutated,
        #  * every occurrence has to go, mirroring str.replace above,
        #  * a missing id must not raise ValueError.
        orig_prompt_ids = [
            token_id for token_id in orig_prompt_ids
            if token_id != image_token_id
        ]

    if orig_prompt is None:
        # Only token ids were given; do not fabricate text (the f-string
        # below would otherwise embed the literal string "None").
        new_prompt = None
    else:
        new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
    new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  #newline

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)


class PaliGemmaMultiModalProjector(nn.Module):
    """Single biased linear layer applied to vision-tower features."""

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        """Create the projection layer.

        Args:
            vision_hidden_size: Size of the last dimension of the input
                features.
            projection_dim: Size of the last dimension of the output.
        """
        super().__init__()

        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Apply the linear projection to ``image_features``."""
        hidden_states = self.linear(image_features)
        return hidden_states


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """PaliGemma model: SigLIP vision tower + linear projector + Gemma LM.

    Image inputs (pixel values or precomputed embeddings) are turned into
    embeddings, scaled, and merged into the text embeddings at the positions
    of ``config.image_token_index`` before the Gemma decoder is run.
    """

    def __init__(self,
                 config: PaliGemmaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        """Build the vision tower, the projector and the language model.

        Args:
            config: HuggingFace PaliGemma configuration.
            multimodal_config: Multi-modal settings, stored on the instance.
            cache_config: Forwarded to ``GemmaForCausalLM``.
            quant_config: Forwarded to ``GemmaForCausalLM`` and stored on the
                instance.
        """
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_tower = SiglipVisionModel(config.vision_config)
        self.multi_modal_projector = PaliGemmaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            projection_dim=config.vision_config.projection_dim)

        self.quant_config = quant_config
        self.language_model = GemmaForCausalLM(config.text_config,
                                               cache_config, quant_config)
        # Fold the optional top-level ``logit_scale`` into the language
        # model's logits processor (defaults to 1.0, i.e. a no-op).
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.language_model.logits_processor.scale *= logit_scale

        # Reuse the language model's factory for intermediate tensors.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @property
    def sampler(self):
        """Sampler of the wrapped language model."""
        return self.language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that ``data`` has shape ``(batch, 3, image_size, image_size)``.

        Returns:
            ``data`` unchanged.

        Raises:
            ValueError: If the non-batch dimensions differ from
                ``(3, image_size, image_size)``.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
        """Extract image inputs from the forward keyword arguments.

        ``pixel_values`` takes precedence over ``image_embeds`` when both are
        present.

        Returns:
            ``None`` when neither key is supplied, otherwise the matching
            tagged input dict.

        Raises:
            ValueError: If the supplied value is not a ``torch.Tensor`` or
                the pixel values have an unexpected shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Remove the N dimension until multiple images are supported.
            pixel_values = pixel_values.squeeze(1)

            return PaliGemmaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            # Remove the N dimension until multiple images are supported.
            image_embeds = image_embeds.squeeze(1)

            return PaliGemmaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run ``vision_tower`` on ``pixel_values``.

        The pixels are first cast to the dtype of the tower's input
        embedding weights.
        """
        target_dtype = vision_tower.get_input_embeddings().weight.dtype
        image_features = vision_tower(pixel_values.to(dtype=target_dtype))

        return image_features

    def _process_image_input(
        self,
        image_input: PaliGemmaImageInputs,
    ) -> torch.Tensor:
        """Turn a parsed image input into embeddings.

        Precomputed embeddings are returned as-is; pixel values go through
        the vision tower and then the multi-modal projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        pixel_values = image_input["data"]
        image_features = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )

        return self.multi_modal_projector(image_features)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs: object) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the Gemma decoder, merging image embeddings when present.

        Args:
            input_ids: Token ids; replaced by ``inputs_embeds`` when image
                input is present.
            positions: Forwarded to the language model.
            kv_caches: Forwarded to the language model.
            attn_metadata: Forwarded to the language model.
            intermediate_tensors: When given, both ``input_ids`` and
                ``inputs_embeds`` are dropped and only these tensors are
                forwarded to the language model.
            **kwargs: May carry ``pixel_values`` or ``image_embeds``.

        Returns:
            The output of ``self.language_model.model``. NOTE(review): taken
            to be a hidden-state tensor or ``IntermediateTensors`` — confirm
            against the Gemma model's contract.
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            parsed_image_input = self._parse_and_validate_image_input(**kwargs)

            if parsed_image_input is not None:
                vision_embeddings = self._process_image_input(
                    parsed_image_input)
                # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
                vision_embeddings = vision_embeddings * (
                    self.config.hidden_size**-0.5)

                inputs_embeds = self.language_model.model.get_input_embeddings(
                    input_ids)

                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, vision_embeddings,
                    self.config.image_token_index)

                # The merged embeddings now carry all the information.
                input_ids = None
            else:
                inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` weight pairs via ``AutoWeightsLoader``."""
        loader = AutoWeightsLoader(self)
        loader.load_weights(weights)