fuyu.py 14.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
21
from collections.abc import Iterable, Mapping, Sequence
22
from typing import Annotated, Literal, Optional
23
24
25

import torch
import torch.nn as nn
26
27
from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
                          FuyuProcessor)
28

29
from vllm.config import VllmConfig
30
31
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
32
from vllm.multimodal import MULTIMODAL_REGISTRY
33
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
34
                                    MultiModalKwargsItems)
35
36
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataItems)
37
from vllm.multimodal.processing import (BaseMultiModalProcessor,
38
                                        BaseProcessingInfo, PromptReplacement,
39
                                        PromptUpdate, PromptUpdateDetails)
40
from vllm.multimodal.profiling import BaseDummyInputsBuilder
41
from vllm.sequence import IntermediateTensors
42
from vllm.utils.tensor_schema import TensorSchema, TensorShape
43

44
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
45
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
46
                    merge_multimodal_embeddings)
47
48
49
50
51
52

# The following two token ids cannot be found in the HF config, so they are
# hard-coded here.
# _IMAGE_TOKEN_ID is the placeholder token repeated for every patch column of
# an image row; _NEWLINE_TOKEN_ID is the token appended after each such row.
_IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019


53
class FuyuImagePatchInputs(TensorSchema):
    """
    Schema describing the flattened image patches given to the model.

    Dimensions:
        - bn: Batch size * number of images
        - bnp: Batch size * number of images * number of patches
        - fn: patch_size_x * patch_size_y * num_channels
    """

    type: Literal["image_patches"] = "image_patches"

    flat_data: Annotated[torch.Tensor, TensorShape("bnp", "fn")]

    patches_per_image: Annotated[list[int], TensorShape("bn")]
    """
    Total number of patches belonging to each image in the batch.

    It is needed to split the embeddings back per image, because their first
    two dimensions are flattened in the same way as `flat_data`.
    """
75

76

77
class FuyuProcessingInfo(BaseProcessingInfo):
    """Fuyu processing info: HF config/processor access and the arithmetic
    that maps an image size to its patch grid."""

    def get_hf_config(self):
        """Return the HF config, requested from the context as `FuyuConfig`."""
        return self.ctx.get_hf_config(FuyuConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the `FuyuProcessor` from the context, forwarding `kwargs`."""
        return self.ctx.get_hf_processor(FuyuProcessor, **kwargs)

    def get_image_processor(self, **kwargs: object) -> FuyuImageProcessor:
        """Return the `image_processor` attribute of the HF processor."""
        hf_processor = self.get_hf_processor(**kwargs)
        return hf_processor.image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Report a limit of 1 for the "image" modality."""
        return {"image": 1}

    def get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Compute the patch grid of an image as `(ncols, nrows)`.

        An image exceeding the image processor's target size in either
        dimension is first scaled down by the smaller of the two
        target/actual ratios, truncating the scaled sizes to integers. The
        grid counts are the ceilings of size divided by patch size.
        """
        processor = self.get_image_processor()

        max_width = processor.size["width"]
        max_height = processor.size["height"]
        patch_width = processor.patch_size["width"]
        patch_height = processor.patch_size["height"]

        if image_width > max_width or image_height > max_height:
            # Same factor on both axes, chosen so the image fits the target.
            scale = min(max_height / image_height, max_width / image_width)

            image_height = int(image_height * scale)
            image_width = int(image_width * scale)

        num_cols = math.ceil(image_width / patch_width)
        num_rows = math.ceil(image_height / patch_height)
        return num_cols, num_rows

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of cells in the image's patch grid."""
        grid_size = self.get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
        )
        num_cols, num_rows = grid_size
        return num_cols * num_rows

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image processor's target size as an `ImageSize`."""
        processor = self.get_image_processor()
        width = processor.size["width"]
        height = processor.size["height"]
        return ImageSize(width=width, height=height)

133
134
135

class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
    """Builds the dummy text and dummy image inputs for Fuyu."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the dummy prompt text, which is always the empty string."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images sized by `get_image_size_with_most_features`.

        The number of images is `mm_counts["image"]`, defaulting to 0 when
        that key is absent. `seq_len` is not used here.
        """
        (target_width,
         target_height) = self.info.get_image_size_with_most_features()
        image_count = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(
            width=target_width,
            height=target_height,
            num_images=image_count,
        )
        return {"image": dummy_images}


156
class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
    """Multi-modal processor for Fuyu, wrapping the HF processor call and
    describing how image placeholders are inserted into the prompt."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and normalize the `image_patches` layout.

        Text-only prompts bypass the HF processor: they are tokenized
        directly and passed through `_apply_hf_processor_tokens_only`.

        Returns:
            The processed outputs. If `image_patches` is present, it is
            stored with the images as its leading dimension.

        Raises:
            AssertionError: If `image_patches` has an unexpected type or its
                length does not match the number of input images.
        """
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        image_patches = processed_outputs.get("image_patches")
        if image_patches is not None:
            images = mm_data["images"]
            assert isinstance(images, list)

            # Before Transformers 4.53, image_patches is a list with shape
            # (1, num_images, Pn, Px * Py * C); unwrap the outer list so the
            # stored value has shape (num_images, Pn, Px * Py * C).
            if isinstance(image_patches, list):
                assert len(image_patches) == 1
                assert (isinstance(image_patches[0], torch.Tensor)
                        and len(image_patches[0]) == len(images))
                processed_outputs["image_patches"] = image_patches[0]
            # After Transformers 4.53, image_patches is already a tensor
            # with shape (num_images, Pn, Px * Py * C).
            elif isinstance(image_patches, torch.Tensor):
                assert len(image_patches) == len(images)
            else:
                raise AssertionError("This line should be unreachable.")

        return processed_outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Ensure the token prompt ends with the `boa` token (`<0x04>`).

        The HF processor adds this token to text prompts, so it is added
        here as well for prompts that are already tokenized.

        Args:
            prompt_tokens: Token ids of the prompt. The list is modified in
                place when the token has to be appended.

        Returns:
            The same list, ending with the `boa` token id.
        """
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        boa_token_id = vocab["<0x04>"]
        # An empty prompt has no last token to inspect; indexing [-1] would
        # raise IndexError, so treat it as "token missing" and append.
        if not prompt_tokens or prompt_tokens[-1] != boa_token_id:
            prompt_tokens.append(boa_token_id)

        return prompt_tokens

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `image_patches` as a field batched over "image"."""
        return dict(image_patches=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Describe how the image placeholder tokens enter the prompt.

        The token equal to the tokenizer's `bos_token_id` is the replacement
        target. It is replaced by `nrows` rows of `ncols` image tokens, each
        row followed by a newline token, and finally the HF config's
        `bos_token_id`.
        """
        hf_config = self.info.get_hf_config()
        bos_token_id = hf_config.bos_token_id
        assert isinstance(bos_token_id, int)

        tokenizer = self.info.get_tokenizer()
        eot_token_id = tokenizer.bos_token_id
        assert isinstance(eot_token_id, int)

        def get_replacement_fuyu(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = self.info.get_image_feature_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )
            image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
                            [_NEWLINE_TOKEN_ID]) * nrows

            # _IMAGE_TOKEN_ID is passed as the embed token id of the
            # replacement sequence.
            return PromptUpdateDetails.select_token_id(
                image_tokens + [bos_token_id],
                embed_token_id=_IMAGE_TOKEN_ID,
            )

        return [
            PromptReplacement(
                modality="image",
                target=[eot_token_id],
                replacement=get_replacement_fuyu,
            )
        ]


263
264
265
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
                                        info=FuyuProcessingInfo,
                                        dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Fuyu model: a linear projection of flattened image patches combined
    with a `PersimmonForCausalLM` language model."""

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "model.language_model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return `None` for image modalities.

        Raises:
            ValueError: If `modality` does not start with "image".
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")

        return None

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the patch projection and the Persimmon language model."""
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        text_config = hf_config.text_config

        self.config = hf_config
        self.multimodal_config = model_config.multimodal_config

        self.vocab_size = text_config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        # Length of one flattened patch: patch_size^2 * num_channels.
        self.image_feature_size = (hf_config.patch_size**2 *
                                   hf_config.num_channels)

        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            hf_config.hidden_size,
            quant_config=vllm_config.quant_config,
            gather_output=True,
        )
        self.language_model = PersimmonForCausalLM(
            vllm_config=vllm_config.with_hf_config(text_config),
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
        """Build `FuyuImagePatchInputs` from the `image_patches` kwarg.

        Returns `None` when `image_patches` is absent or `None`.
        """
        image_patches = kwargs.pop("image_patches", None)
        if image_patches is None:
            return None

        per_image_patches = flatten_bn(image_patches)
        return FuyuImagePatchInputs(
            type="image_patches",
            flat_data=flatten_bn(per_image_patches, concat=True),
            patches_per_image=[p.size(0) for p in per_image_patches],
            resolve_bindings={"fn": self.image_feature_size},
        )

    def _process_image_input(
            self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
        """Project the flat patches and split the result per image."""
        flat_patches = image_input["flat_data"]
        split_sizes = image_input["patches_per_image"]

        assert self.vision_embed_tokens is not None
        flat_embeddings, _ = self.vision_embed_tokens(flat_patches)

        return flat_embeddings.split(split_sizes, dim=0)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return the image embeddings, or `[]` without image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is not None:
            return self._process_image_input(image_input)

        return []

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed `input_ids`, merging in any multimodal embeddings.

        The merge is keyed on `_IMAGE_TOKEN_ID` and skipped when
        `multimodal_embeddings` is `None` or empty.
        """
        text_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return text_embeds

        return merge_multimodal_embeddings(
            input_ids,
            text_embeds,
            multimodal_embeddings,
            _IMAGE_TOKEN_ID,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        """Run the language model and return its output.

        When `intermediate_tensors` is given, `inputs_embeds` is discarded.
        Otherwise, if no `inputs_embeds` were supplied, they are computed
        from `input_ids` and the multimodal kwargs, and `input_ids` is
        cleared before calling the language model.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this branch is for v0 compatibility.
            mm_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      mm_embeddings)
            input_ids = None

        return self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Apply the language model's logits processor with its `lm_head`."""
        language_model = self.language_model
        return language_model.logits_processor(language_model.lm_head,
                                               hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load `weights` through an `AutoWeightsLoader` over this module."""
        return AutoWeightsLoader(self).load_weights(weights)