fuyu.py 12.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
19
20
"""PyTorch Fuyu model."""

21
import math
22
from collections.abc import Iterable, Mapping, Sequence
23
from typing import Annotated, Literal
24
25
26

import torch
import torch.nn as nn
27
from transformers import BatchFeature, FuyuConfig, FuyuImageProcessor, FuyuProcessor
28

29
from vllm.config import VllmConfig
30
from vllm.config.multimodal import BaseDummyOptions
31
32
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
33
from vllm.multimodal import MULTIMODAL_REGISTRY
34
35
36
37
38
39
40
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
41
    BaseDummyInputsBuilder,
42
43
44
45
46
47
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
48
from vllm.sequence import IntermediateTensors
49
from vllm.utils.tensor_schema import TensorSchema, TensorShape
50

51
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
52
from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix
53
54
55
56
57
58

# These two token IDs cannot be found in the HF config, so they are
# hard-coded here.
# Placeholder token repeated once per patch column in every row of an image;
# it is also the token selected as the embedding position in the prompt update.
_IMAGE_TOKEN_ID = 71011
# Token appended after each row of image placeholder tokens.
_NEWLINE_TOKEN_ID = 71019


59
class FuyuImagePatchInputs(TensorSchema):
    """
    Schema for the flattened image patches handed to the Fuyu model.

    Dimensions:
        - bn: Batch size * number of images
        - bnp: Batch size * number of images * number of patches
        - fn: patch_size_x * patch_size_y * num_channels
    """

    # Fixed literal tag identifying this kind of image input.
    type: Literal["image_patches"] = "image_patches"

    # Patches of all images concatenated along the first dimension; each row
    # holds the `fn` features of one patch.
    image_patches_flat: Annotated[torch.Tensor, TensorShape("bnp", "fn")]

    # NOTE(review): annotated as list[int], while the multimodal processor in
    # this file stores a torch.tensor under this key — confirm which type
    # actually reaches the model.
    patches_per_image: Annotated[list[int], TensorShape("bn")]
    """
    The number of total patches for each image in the batch.
    
    This is used to split the embeddings which has the first two dimensions
    flattened just like `image_patches_flat`.
    """
78

79

80
81
class FuyuProcessingInfo(BaseProcessingInfo):
    """Accessors for Fuyu's HF config/processor plus image-token sizing helpers."""

    def get_hf_config(self):
        """Return the HF config, requested from the context as a `FuyuConfig`."""
        return self.ctx.get_hf_config(FuyuConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF `FuyuProcessor`, forwarding `kwargs` to the context."""
        return self.ctx.get_hf_processor(FuyuProcessor, **kwargs)

    def get_image_processor(self, **kwargs: object) -> FuyuImageProcessor:
        """Return the `image_processor` attribute of the HF processor."""
        hf_processor = self.get_hf_processor(**kwargs)
        return hf_processor.image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Report the per-modality item limit: a single image."""
        return dict(image=1)

    def get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Return `(ncols, nrows)` of the patch grid covering an image.

        An image exceeding the image processor's target size in either
        dimension is first scaled down (aspect ratio preserved, dimensions
        truncated to integers) so that it fits; the grid is then the
        ceiling of each dimension divided by the patch size.
        """
        processor = self.get_image_processor()

        max_size = processor.size
        max_width, max_height = max_size["width"], max_size["height"]

        patch_size = processor.patch_size
        patch_w, patch_h = patch_size["width"], patch_size["height"]

        # Only shrink; images already within the target size keep their dims.
        if image_width > max_width or image_height > max_height:
            scale = min(max_height / image_height, max_width / image_width)
            image_height = int(image_height * scale)
            image_width = int(image_width * scale)

        return math.ceil(image_width / patch_w), math.ceil(image_height / patch_h)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of patches (grid columns * grid rows) for an image."""
        grid = self.get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
        )
        return math.prod(grid)

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image processor's target size as an `ImageSize`."""
        target = self.get_image_processor().size
        return ImageSize(width=target["width"], height=target["height"])
135

136
137

class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
    """Produces dummy text and image inputs for Fuyu."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the dummy prompt text, which is always the empty string."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        """Build dummy images at the size reported by the processing info.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Requested item count per modality; only the "image"
                entry is read (0 if absent).
            mm_options: Optional per-modality options; the "image" entry, if
                any, is passed on as overrides for the dummy images.
            mm_processor_kwargs: Not used by this implementation.

        Returns:
            A dict with a single "image" key holding the dummy images.
        """
        max_width, max_height = self.info.get_image_size_with_most_features()
        image_count = mm_counts.get("image", 0)

        if mm_options:
            image_overrides = mm_options.get("image")
        else:
            image_overrides = None

        dummy_images = self._get_dummy_images(
            width=max_width,
            height=max_height,
            num_images=image_count,
            overrides=image_overrides,
        )
        return {"image": dummy_images}


163
class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
    """Multimodal processor for Fuyu: runs the HF processor and builds prompt updates."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and post-process its image patches.

        For text-only input the prompt is tokenized directly (bypassing the
        HF processor) and the trailing `<0x04>` token is appended. Otherwise
        the parent implementation is called and its "image_patches" output is
        flattened, with the per-image patch counts stored under
        "patches_per_image".
        """
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        # Count patches per image from the un-flattened output before it is
        # replaced by the flattened version.
        image_patches = processed_outputs["image_patches"]
        processed_outputs["image_patches"] = flatten_bn(image_patches)
        processed_outputs["patches_per_image"] = torch.tensor(
            [len(p) for p in image_patches]
        )

        return processed_outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Ensure the token sequence ends with the `<0x04>` token.

        Args:
            prompt_tokens: Token IDs of the prompt; may be empty. The list is
                not modified.

        Returns:
            `prompt_tokens` itself if it already ends with the `<0x04>` token,
            otherwise a new list with that token appended.
        """
        # HF processor adds boa_token_id
        boa_token_id = self.info.get_tokenizer().get_vocab()["<0x04>"]

        # Guard the empty case: indexing [-1] on an empty list would raise.
        if prompt_tokens and prompt_tokens[-1] == boa_token_id:
            return prompt_tokens

        # Build a new list rather than mutating the caller's sequence.
        return [*prompt_tokens, boa_token_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how the processor outputs map onto individual images.

        "image_patches" is a flat field partitioned by the sizes in
        "patches_per_image" (an empty tensor if absent), and
        "patches_per_image" is batched per image.
        """
        patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))

        return dict(
            image_patches=MultiModalFieldConfig.flat_from_sizes(
                "image", patches_per_image
            ),
            patches_per_image=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the prompt replacement that inserts image placeholder tokens.

        The target token is replaced, per image, by `nrows` rows of `ncols`
        image tokens each followed by a newline token, and finally the HF
        config's BOS token. Only the image tokens are selected as embedding
        positions.
        """
        hf_config = self.info.get_hf_config()
        bos_token_id = hf_config.bos_token_id
        assert isinstance(bos_token_id, int)

        tokenizer = self.info.get_tokenizer()
        # NOTE(review): named "eot" but read from tokenizer.bos_token_id —
        # confirm the naming/intent.
        eot_token_id = tokenizer.bos_token_id
        assert isinstance(eot_token_id, int)

        def get_replacement_fuyu(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = self.info.get_image_feature_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )
            # One row = ncols image tokens + a newline token.
            image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows

            return PromptUpdateDetails.select_token_id(
                image_tokens + [bos_token_id],
                embed_token_id=_IMAGE_TOKEN_ID,
            )

        return [
            PromptReplacement(
                modality="image",
                target=[eot_token_id],
                replacement=get_replacement_fuyu,
            )
        ]


258
259
260
261
262
@MULTIMODAL_REGISTRY.register_processor(
    FuyuMultiModalProcessor,
    info=FuyuProcessingInfo,
    dummy_inputs=FuyuDummyInputsBuilder,
)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Fuyu model: a linear image-patch embedding in front of a Persimmon LM."""

    # Prefix renames applied to checkpoint weight names when loading.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "model.language_model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the placeholder string for a modality item.

        Returns `None` for any modality starting with "image".

        Raises:
            ValueError: If the modality is not an image modality.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the patch embedding layer and the Persimmon language model.

        Args:
            vllm_config: Engine configuration; the HF config, quantization
                config and multimodal config are read from it.
            prefix: Module-name prefix used for the language model submodule.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.vocab_size = config.text_config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        # Number of features in one flattened patch.
        self.image_feature_size = config.patch_size**2 * config.num_channels

        with self._mark_tower_model(vllm_config, "image"):
            # Projects each flattened patch to the hidden size.
            self.vision_embed_tokens = ColumnParallelLinear(
                self.image_feature_size,
                config.hidden_size,
                quant_config=quant_config,
                gather_output=True,
            )

        with self._mark_language_model(vllm_config):
            self.language_model = PersimmonForCausalLM(
                vllm_config=vllm_config.with_hf_config(config.text_config),
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> FuyuImagePatchInputs | None:
        """Extract image inputs from `kwargs` into a `FuyuImagePatchInputs`.

        Returns `None` when no "image_patches" entry is present.
        """
        image_patches = kwargs.pop("image_patches", None)
        patches_per_image = kwargs.pop("patches_per_image", None)

        if image_patches is None:
            return None

        return FuyuImagePatchInputs(
            type="image_patches",
            image_patches_flat=image_patches,
            patches_per_image=patches_per_image,
            resolve_bindings={"fn": self.image_feature_size},
        )

    def _process_image_input(
        self, image_input: FuyuImagePatchInputs
    ) -> MultiModalEmbeddings:
        """Embed all patches and split the result back into per-image chunks.

        `patches_per_image` may be either a list of counts (as the schema
        annotates it) or an object exposing `.tolist()` such as a tensor.
        """
        image_patches_flat = image_input["image_patches_flat"]
        patches_per_image = image_input["patches_per_image"]

        vision_embeddings_flat, _ = self.vision_embed_tokens(image_patches_flat)

        # The schema annotates a list, but the processor stores a tensor;
        # accept both instead of unconditionally calling `.tolist()`.
        if isinstance(patches_per_image, list):
            split_sizes = [int(n) for n in patches_per_image]
        else:
            split_sizes = patches_per_image.tolist()

        return vision_embeddings_flat.split(split_sizes, dim=0)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or an empty list if no image is given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the language model and return its hidden states.

        When `intermediate_tensors` is provided, `inputs_embeds` is discarded
        before calling the language model.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, applying the class's prefix renames.

        Returns the value of `AutoWeightsLoader.load_weights`.
        """
        loader = AutoWeightsLoader(self)
        # Pass the mapper explicitly; otherwise the renames declared in
        # `hf_to_vllm_mapper` are never applied during loading.
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)