fuyu.py 12.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
19
20
"""PyTorch Fuyu model."""

21
import math
22
from collections.abc import Iterable, Mapping, Sequence
23
from typing import Annotated, Literal
24
25
26

import torch
import torch.nn as nn
27
from transformers import BatchFeature, FuyuConfig, FuyuImageProcessor, FuyuProcessor
28

29
from vllm.config import VllmConfig
30
from vllm.config.multimodal import BaseDummyOptions
31
32
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
33
from vllm.multimodal import MULTIMODAL_REGISTRY
34
35
36
37
38
39
40
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
41
    BaseDummyInputsBuilder,
42
43
44
45
46
47
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
48
from vllm.sequence import IntermediateTensors
49
from vllm.utils.tensor_schema import TensorSchema, TensorShape
50

51
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
52
from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix
53
54
55
56
57
58

# Hard-coded token IDs: these two numbers could not be found in the HF config.
# _IMAGE_TOKEN_ID fills each image-patch position of the placeholder grid and
# _NEWLINE_TOKEN_ID terminates every row of that grid.
_IMAGE_TOKEN_ID: int = 71011
_NEWLINE_TOKEN_ID: int = 71019

class FuyuImagePatchInputs(TensorSchema):
    """
    Validated image-patch inputs for Fuyu.

    Dimensions:
        - bn: Batch size * number of images
        - bnp: Batch size * number of images * number of patches
        - fn: patch_size_x * patch_size_y * num_channels
    """

    type: Literal["image_patches"] = "image_patches"

    # All patches of all images, flattened along the first dimension.
    image_patches_flat: Annotated[torch.Tensor, TensorShape("bnp", "fn")]

    # The number of total patches for each image in the batch. This is used to
    # split the embeddings, which have their first two dimensions flattened
    # just like `image_patches_flat`. The multimodal processor builds this
    # value with `torch.tensor(...)`, so a tensor is accepted in addition to a
    # plain list of ints.
    patches_per_image: Annotated[torch.Tensor | list[int], TensorShape("bn")]


class FuyuProcessingInfo(BaseProcessingInfo):
    """Config/processor access and image-token accounting for Fuyu."""

    def get_hf_config(self):
        """Return the HF ``FuyuConfig`` from the processing context."""
        return self.ctx.get_hf_config(FuyuConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF ``FuyuProcessor``, forwarding ``kwargs`` unchanged."""
        return self.ctx.get_hf_processor(FuyuProcessor, **kwargs)

    def get_image_processor(self, **kwargs: object) -> FuyuImageProcessor:
        """Return the image processor owned by the HF ``FuyuProcessor``."""
        hf_processor = self.get_hf_processor(**kwargs)
        return hf_processor.image_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Report a limit of one item for the image modality."""
        return {"image": 1}

    def get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Return the patch grid ``(ncols, nrows)`` for an image of this size.

        If the image exceeds the image processor's target size in either
        dimension, both sides are first multiplied by the smaller of the two
        per-axis scale factors (preserving aspect ratio) and truncated to
        ints. Each side is then divided by the patch size, rounding up.
        """
        image_processor = self.get_image_processor()
        max_w = image_processor.size["width"]
        max_h = image_processor.size["height"]
        patch_w = image_processor.patch_size["width"]
        patch_h = image_processor.patch_size["height"]

        # Downscale only when the image does not already fit the target box.
        if image_width > max_w or image_height > max_h:
            scale = min(max_h / image_height, max_w / image_width)
            image_width = int(image_width * scale)
            image_height = int(image_height * scale)

        return math.ceil(image_width / patch_w), math.ceil(image_height / patch_h)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of patches (rows x columns) for the image."""
        ncols, nrows = self.get_image_feature_grid_size(
            image_width=image_width, image_height=image_height
        )
        return nrows * ncols

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image processor's target size (the largest grid)."""
        target = self.get_image_processor().size
        return ImageSize(width=target["width"], height=target["height"])


class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
    """Builds dummy Fuyu inputs sized at the image processor's target size."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return an empty dummy text prompt."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` dummy images (default 0).

        Images use the size reported by
        ``get_image_size_with_most_features``; any ``"image"`` entry of
        ``mm_options`` is forwarded as overrides.
        """
        width, height = self.info.get_image_size_with_most_features()
        overrides = None if not mm_options else mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
            overrides=overrides,
        )
        return {"image": dummy_images}


class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
    """Multimodal processor for Fuyu.

    Wraps the HF ``FuyuProcessor`` and expands each image placeholder in the
    prompt into a grid of image-patch tokens.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and flatten its image patches.

        Text-only prompts bypass the HF processor. They are tokenized directly
        and given the same trailing token that the HF processor adds.

        For prompts with images, ``image_patches`` is flattened across images,
        and ``patches_per_image`` records how many patches each image had.
        """
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        image_patches = processed_outputs["image_patches"]
        processed_outputs["image_patches"] = flatten_bn(image_patches)
        # Per-image patch counts, taken before flattening, so the flat tensor
        # can later be split back per image.
        processed_outputs["patches_per_image"] = torch.tensor(
            [len(p) for p in image_patches]
        )

        return processed_outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Append the ``<0x04>`` (boa) token, mirroring the HF processor.

        The token is not duplicated if ``prompt_tokens`` already ends with it.
        An empty list yields just the boa token.

        Returns:
            A new list; ``prompt_tokens`` is never modified. The caller may
            own it, e.g. user-supplied prompt token IDs.
        """
        # HF processor adds boa_token_id
        tokenizer = self.info.get_tokenizer()
        boa_token_id = tokenizer.get_vocab()["<0x04>"]

        if prompt_tokens and prompt_tokens[-1] == boa_token_id:
            return list(prompt_tokens)
        return [*prompt_tokens, boa_token_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how the processor outputs map onto individual images.

        ``image_patches`` is flat across images and is sliced per image using
        ``patches_per_image``. ``patches_per_image`` is batched per image.
        """
        patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))

        return dict(
            image_patches=MultiModalFieldConfig.flat_from_sizes(
                "image", patches_per_image
            ),
            patches_per_image=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image placeholder token with its patch-token grid.

        The replacement has three parts:
        - ``nrows`` rows, each made of ``ncols`` copies of ``_IMAGE_TOKEN_ID``
          followed by ``_NEWLINE_TOKEN_ID``;
        - then the config's BOS token;
        - with only the ``_IMAGE_TOKEN_ID`` positions selected for embeddings.
        """
        hf_config = self.info.get_hf_config()
        bos_token_id = hf_config.bos_token_id
        assert isinstance(bos_token_id, int)

        tokenizer = self.info.get_tokenizer()
        # NOTE: despite the name, this is the *tokenizer's* BOS token id.
        # Each image's placeholder in the prompt is this single token.
        eot_token_id = tokenizer.bos_token_id
        assert isinstance(eot_token_id, int)

        def get_replacement_fuyu(item_idx: int):
            # Looked up lazily: this only runs when image items exist.
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = self.info.get_image_feature_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )
            image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows

            return PromptUpdateDetails.select_token_id(
                image_tokens + [bos_token_id],
                embed_token_id=_IMAGE_TOKEN_ID,
            )

        return [
            PromptReplacement(
                modality="image",
                target=[eot_token_id],
                replacement=get_replacement_fuyu,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(
    FuyuMultiModalProcessor,
    info=FuyuProcessingInfo,
    dummy_inputs=FuyuDummyInputsBuilder,
)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Fuyu: a linear image-patch projection feeding a Persimmon LM.

    Raw image patches are embedded by one column-parallel linear layer
    (``vision_embed_tokens``). Text modelling, logits and pipeline-parallel
    plumbing are delegated to a ``PersimmonForCausalLM`` built from
    ``config.text_config``.
    """

    # Renames HF checkpoint weight prefixes onto this module's attributes.
    # Names that match none of these prefixes pass through unchanged.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "model.language_model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for ``modality``; images have none.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.vocab_size = config.text_config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        # Length of one flattened patch: patch_size**2 * num_channels.
        self.image_feature_size = config.patch_size**2 * config.num_channels

        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            config.hidden_size,
            quant_config=quant_config,
            gather_output=True,
            # Qualified module name, consistent with language_model below.
            # Quantization configs may key per-layer behaviour on it.
            prefix=maybe_prefix(prefix, "vision_embed_tokens"),
        )
        self.language_model = PersimmonForCausalLM(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> FuyuImagePatchInputs | None:
        """Pop image kwargs and wrap them in a validated schema.

        Returns ``None`` when no ``image_patches`` were supplied.
        """
        image_patches = kwargs.pop("image_patches", None)
        patches_per_image = kwargs.pop("patches_per_image", None)

        if image_patches is None:
            return None

        return FuyuImagePatchInputs(
            type="image_patches",
            image_patches_flat=image_patches,
            patches_per_image=patches_per_image,
            resolve_bindings={"fn": self.image_feature_size},
        )

    def _process_image_input(
        self, image_input: FuyuImagePatchInputs
    ) -> MultiModalEmbeddings:
        """Embed all patches at once, then split the result per image."""
        image_patches_flat = image_input["image_patches_flat"]
        patches_per_image = image_input["patches_per_image"]

        vision_embeddings_flat, _ = self.vision_embed_tokens(image_patches_flat)

        # The processor builds this with torch.tensor(...), but the schema also
        # admits a plain list of ints. Tensor.split needs a Python sequence.
        if isinstance(patches_per_image, torch.Tensor):
            split_sizes = patches_per_image.tolist()
        else:
            split_sizes = list(patches_per_image)
        return vision_embeddings_flat.split(split_sizes, dim=0)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped Persimmon language model."""
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or ``[]`` when there is no image."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the language model and return its output unchanged."""
        # With incoming intermediate tensors, any provided embeddings are
        # dropped. Presumably this is a non-first pipeline stage; confirm.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states through the language model's LM head."""
        logits = self.language_model.logits_processor(
            self.language_model.lm_head, hidden_states
        )
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming HF prefixes via the mapper.

        Passing ``hf_to_vllm_mapper`` lets checkpoints that use the
        ``model.vision_embed_tokens.`` / ``model.language_model.`` /
        ``lm_head.`` layout load. Names that already match this module's
        layout pass through unchanged.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)