paligemma.py 11.7 KB
Newer Older
1
2
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)
Roger Wang's avatar
Roger Wang committed
3
4
5

import torch
from torch import nn
6
from transformers import PaliGemmaConfig
Roger Wang's avatar
Roger Wang committed
7
8
9
10
11

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
12
from vllm.model_executor.layers.quantization import QuantizationConfig
13
from vllm.model_executor.layers.sampler import SamplerOutput
Roger Wang's avatar
Roger Wang committed
14
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
15
from vllm.model_executor.models.gemma import GemmaForCausalLM
Roger Wang's avatar
Roger Wang committed
16
17
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
18
from vllm.multimodal.utils import cached_get_tokenizer
19
from vllm.sequence import IntermediateTensors
Roger Wang's avatar
Roger Wang committed
20

21
from .interfaces import SupportsMultiModal, SupportsPP
22
23
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
24
from .utils import group_weights_with_prefix, merge_multimodal_embeddings
Roger Wang's avatar
Roger Wang committed
25
26
27
28

# Module-level logger (e.g. for warning about image tokens found in prompts).
logger = init_logger(__name__)


class PaliGemmaImagePixelInputs(TypedDict):
    """Image input given as raw pixel values, to be run through the vision
    tower and projector before being merged into the text embeddings."""

    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


class PaliGemmaImageEmbeddingInputs(TypedDict):
    """Image input given as precomputed embeddings; these skip the vision
    tower and projector and are merged into the text embeddings directly."""

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


# A parsed image input is one of two variants, told apart by its "type" key:
# raw pixels (encoded by the vision tower) or precomputed embeddings.
PaliGemmaImageInputs = Union[
    PaliGemmaImagePixelInputs,
    PaliGemmaImageEmbeddingInputs,
]


def get_max_paligemma_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens for PaliGemma.

    Registered with ``MULTIMODAL_REGISTRY.register_max_image_tokens``. The
    computation is delegated to ``get_max_siglip_image_tokens`` using the
    vision sub-config nested inside the HF ``PaliGemmaConfig``.

    Args:
        ctx: Input context used to look up the HF ``PaliGemmaConfig``.
    """
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config

    return get_max_siglip_image_tokens(vision_config)


def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
                             mm_counts: Mapping[str, int]):
    """Build dummy token sequence data and dummy images for PaliGemma.

    Registered with ``INPUT_REGISTRY.register_dummy_data``. Both pieces are
    produced by the SigLIP dummy-data helpers, parameterised with PaliGemma's
    vision config and its image token id.

    Args:
        ctx: Input context used to look up the HF ``PaliGemmaConfig``.
        seq_len: Sequence length forwarded to ``dummy_seq_data_for_siglip``.
        mm_counts: Item count per modality; only the ``"image"`` entry is
            read (a missing key raises ``KeyError``).

    Returns:
        A ``(seq_data, mm_data)`` tuple.
    """
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    seq_data = dummy_seq_data_for_siglip(
        vision_config,
        seq_len,
        num_images,
        image_token_id=hf_config.image_token_index,
    )

    mm_data = dummy_image_for_siglip(vision_config, num_images)
    return seq_data, mm_data


def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
    """Rewrite a prompt into the format PaliGemma expects.

    The correct prompt format needs to be:
    ``'<image>' * image_feature_size + '<bos>' + prompt + newline``

    See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55

    Inputs without image data are returned unchanged. Any image tokens the
    user already placed in the prompt (text or token ids) are stripped with a
    warning, because the correctly sized run of image tokens is prepended
    here.

    Args:
        ctx: Input context providing the model config and HF config.
        llm_inputs: The original inputs. They are not mutated.

    Returns:
        A new ``LLMInputs``. ``prompt`` is ``None`` when the original inputs
        had no prompt text (token ids only).
    """ # noqa

    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PaliGemmaConfig)

    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    image_feature_size = hf_config.text_config.num_image_tokens
    image_token_id = hf_config.image_token_index
    image_token_str = tokenizer.decode(image_token_id)
    bos_token = tokenizer.decode(hf_config.bos_token_id)
    image_token_str_pad = image_token_str * image_feature_size
    image_token_ids_pad = [image_token_id] * image_feature_size

    orig_prompt = llm_inputs.get("prompt")
    # Copy so the caller's token list is never mutated in place.
    orig_prompt_ids = list(llm_inputs.get("prompt_token_ids"))

    found_in_prompt = (orig_prompt is not None
                       and image_token_str in orig_prompt)
    found_in_ids = image_token_id in orig_prompt_ids
    if found_in_prompt or found_in_ids:
        logger.warning(
            "The image token '%s' was detected in the prompt and "
            "will be removed. Please follow the proper prompt format"
            " documented on HuggingFace.", image_token_str)
        if orig_prompt is not None:
            orig_prompt = orig_prompt.replace(image_token_str, "")
        # Drop every occurrence (mirroring str.replace above) so that only
        # the padded image tokens prepended below remain; otherwise the
        # number of image tokens would not match the image features.
        orig_prompt_ids = [
            token_id for token_id in orig_prompt_ids
            if token_id != image_token_id
        ]

    # Token-ids-only inputs carry no prompt text; do not fabricate "None".
    new_prompt = (None if orig_prompt is None else
                  f"{image_token_str_pad}{bos_token}{orig_prompt}\n")
    # NOTE(review): BOS is not prepended to the ids here; presumably the
    # tokenizer already put it at the start of orig_prompt_ids — confirm.
    new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  # newline

    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)


class PaliGemmaMultiModalProjector(nn.Module):
    """Single linear layer (with bias) mapping vision-tower features to
    ``projection_dim``.

    Args:
        vision_hidden_size: Size of the last dimension of incoming features.
        projection_dim: Size of the last dimension of the output.
    """

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        super().__init__()

        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Project the last dimension of ``image_features``."""
        hidden_states = self.linear(image_features)
        return hidden_states


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """PaliGemma: a SigLIP vision tower and linear projector feeding Gemma.

    Projected image features are scaled and merged into the text embeddings
    (via ``merge_multimodal_embeddings``, keyed on
    ``config.image_token_index``). The wrapped ``GemmaForCausalLM`` then
    produces hidden states, logits and samples.
    """

    def __init__(self,
                 config: PaliGemmaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_tower = SiglipVisionModel(config.vision_config)
        self.multi_modal_projector = PaliGemmaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            projection_dim=config.vision_config.projection_dim)

        self.quant_config = quant_config
        self.language_model = GemmaForCausalLM(config.text_config,
                                               cache_config, quant_config)
        # Fold an optional config-level logit scale into the LM's logits
        # processor; 1.0 (no change) when the config does not define one.
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.language_model.logits_processor.scale *= logit_scale

        # Reuse the LM's factory for empty intermediate tensors (presumably
        # required by the SupportsPP / pipeline-parallel interface).
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @property
    def sampler(self):
        """The wrapped language model's sampler."""
        return self.language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that ``data`` has shape ``(batch, 3, size, size)``.

        ``size`` is ``config.vision_config.image_size``. The batch dimension
        is not checked.

        Raises:
            ValueError: If the non-batch dimensions do not match.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
        """Extract image inputs from the keyword arguments of ``forward``.

        ``pixel_values`` takes precedence over ``image_embeds`` when both are
        present. A size-1 dimension at index 1 is squeezed away because only
        one image per prompt is currently supported.

        Returns:
            The parsed input, or ``None`` if neither key is present.

        Raises:
            ValueError: If a value is not a tensor, or if pixel values have
                the wrong shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Remove the N dimension until multiple images are supported.
            pixel_values = pixel_values.squeeze(1)

            return PaliGemmaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            # Remove the N dimension until multiple images are supported.
            image_embeds = image_embeds.squeeze(1)

            return PaliGemmaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run ``vision_tower`` on ``pixel_values``.

        The pixels are first cast to the dtype of the tower's input
        embedding weights.
        """
        target_dtype = vision_tower.get_input_embeddings().weight.dtype
        image_features = vision_tower(pixel_values.to(dtype=target_dtype))

        return image_features

    def _process_image_input(
        self,
        image_input: PaliGemmaImageInputs,
    ) -> torch.Tensor:
        """Turn a parsed image input into embeddings for the language model.

        Precomputed embeddings are returned unchanged. Pixel values go
        through the vision tower and then the multimodal projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        pixel_values = image_input["data"]
        image_features = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )

        return self.multi_modal_projector(image_features)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]:
        """Run the language model, injecting image embeddings when present.

        If ``intermediate_tensors`` is given, both ``input_ids`` and any image
        inputs are ignored and the LM runs from those tensors. Otherwise the
        image inputs found in ``kwargs`` are:

        1. encoded,
        2. scaled by ``config.hidden_size ** -0.5`` (as in HF), and
        3. merged into the token embeddings.

        Returns:
            The output of ``self.language_model.model``.
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            parsed_image_input = self._parse_and_validate_image_input(**kwargs)

            if parsed_image_input is not None:
                vision_embeddings = self._process_image_input(
                    parsed_image_input)
                # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
                vision_embeddings = vision_embeddings * (
                    self.config.hidden_size**-0.5)

                inputs_embeds = self.language_model.model.get_input_embeddings(
                    input_ids)

                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, vision_embeddings,
                    self.config.image_token_index)

                # The LM consumes inputs_embeds instead of token ids now.
                input_ids = None
            else:
                inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the wrapped language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the wrapped language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights, routed by top-level name prefix.

        Weights are grouped with ``group_weights_with_prefix``:

        - ``vision_tower`` weights go to the vision tower's loader.
        - ``multi_modal_projector`` weights are loaded parameter by
          parameter; an unknown name raises ``KeyError``.
        - ``language_model`` weights go to the LM's loader.
        """
        # prepare weight iterators for components
        weights_group = group_weights_with_prefix(weights)

        # load vision tower
        self.vision_tower.load_weights(weights_group["vision_tower"])

        # load mlp projector
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in weights_group["multi_modal_projector"]:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        self.language_model.load_weights(weights_group["language_model"])