# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Fuyu model."""

21
import math
22
from collections.abc import Iterable, Mapping, Sequence
23
from typing import Annotated, Literal, Optional
24
25
26

import torch
import torch.nn as nn
27
from transformers import BatchFeature, FuyuConfig, FuyuImageProcessor, FuyuProcessor
28

29
from vllm.config import VllmConfig
30
from vllm.config.multimodal import BaseDummyOptions
31
32
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
33
from vllm.multimodal import MULTIMODAL_REGISTRY
34
35
36
37
38
39
40
41
42
43
44
45
46
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
47
from vllm.multimodal.profiling import BaseDummyInputsBuilder
48
from vllm.sequence import IntermediateTensors
49
from vllm.utils.tensor_schema import TensorSchema, TensorShape
50

51
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
52
from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix
53
54
55
56
57
58

# The following two token IDs could not be found in the HF config, so they
# are hard-coded here. They are used below to build the image placeholder
# rows: a run of image tokens followed by one newline token per row.
_IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019


59
class FuyuImagePatchInputs(TensorSchema):
    """
    Schema for the flattened image patches consumed by the Fuyu model.

    Dimensions:
        - bn: Batch size * number of images
        - bnp: Batch size * number of images * number of patches
        - fn: patch_size_x * patch_size_y * num_channels
    """

    type: Literal["image_patches"] = "image_patches"

    image_patches_flat: Annotated[torch.Tensor, TensorShape("bnp", "fn")]

    patches_per_image: Annotated[list[int], TensorShape("bn")]
    """
    The total number of patches for each image in the batch.

    This is used to split the embeddings, whose first two dimensions are
    flattened just like `image_patches_flat`.
    """
78

79

80
81
class FuyuProcessingInfo(BaseProcessingInfo):
    """Accessors for the Fuyu HF config/processor and image-size helpers."""

    def get_hf_config(self):
        """Return the HF config, requested from the context as `FuyuConfig`."""
        return self.ctx.get_hf_config(FuyuConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF `FuyuProcessor`, forwarding `kwargs` to the context."""
        return self.ctx.get_hf_processor(FuyuProcessor, **kwargs)

    def get_image_processor(self, **kwargs: object) -> FuyuImageProcessor:
        """Return the image processor attached to the HF processor."""
        return self.get_hf_processor(**kwargs).image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality item limit: one image per prompt."""
        return {"image": 1}

    def get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """
        Compute the patch grid that an image of the given size maps to.

        If the image exceeds the image processor's target size in either
        dimension, both dimensions are first scaled down by the same factor
        (the smaller of the two target/actual ratios) and truncated to
        integers. Images that already fit are left at their original size.

        Returns:
            `(ncols, nrows)`: the number of patches per row and the number
            of patch rows, each rounded up to cover partial patches.
        """
        image_processor = self.get_image_processor()
        target_width = image_processor.size["width"]
        target_height = image_processor.size["height"]
        patch_width = image_processor.patch_size["width"]
        patch_height = image_processor.patch_size["height"]

        if not (image_width <= target_width and image_height <= target_height):
            # Shrink with a single factor so the aspect ratio is preserved.
            height_scale_factor = target_height / image_height
            width_scale_factor = target_width / image_width
            optimal_scale_factor = min(height_scale_factor, width_scale_factor)

            image_height = int(image_height * optimal_scale_factor)
            image_width = int(image_width * optimal_scale_factor)

        ncols = math.ceil(image_width / patch_width)
        nrows = math.ceil(image_height / patch_height)
        return ncols, nrows

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of patches (`ncols * nrows`) for an image."""
        ncols, nrows = self.get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
        )

        return ncols * nrows

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image processor's target size as an `ImageSize`."""
        image_processor = self.get_image_processor()
        return ImageSize(
            width=image_processor.size["width"], height=image_processor.size["height"]
        )
135

136
137

class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
    """Builds dummy text and image inputs for Fuyu."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the dummy prompt text, which is always empty for Fuyu."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """
        Build dummy images at the size reported as having the most features.

        Args:
            seq_len: Unused by this implementation.
            mm_counts: Number of items per modality; only `"image"` is read,
                defaulting to 0 when absent.
            mm_options: Optional per-modality overrides; the `"image"` entry,
                if any, is forwarded to `_get_dummy_images`.
        """
        target_width, target_height = self.info.get_image_size_with_most_features()
        num_images = mm_counts.get("image", 0)

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


162
class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
    """Multimodal processor that prepares prompts and image patches for Fuyu."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor, or tokenize directly for text-only input.

        With multimodal data present, the parent implementation is called
        and its `image_patches` output is passed through `flatten_bn`; the
        per-image patch counts are stored under `patches_per_image`.
        """
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        image_patches = processed_outputs["image_patches"]
        processed_outputs["image_patches"] = flatten_bn(image_patches)
        # Record each image's patch count, taken from the unflattened output.
        processed_outputs["patches_per_image"] = torch.tensor(
            [len(p) for p in image_patches]
        )

        return processed_outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Ensure the token list ends with the `<0x04>` (boa) token.

        The list is modified in place and also returned. An empty list is
        handled by appending the boa token rather than indexing into it.
        """
        # HF processor adds boa_token_id
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        boa_token_id = vocab["<0x04>"]
        # Guard against an empty prompt: `prompt_tokens[-1]` would raise
        # IndexError on an empty list.
        if not prompt_tokens or prompt_tokens[-1] != boa_token_id:
            prompt_tokens.append(boa_token_id)

        return prompt_tokens

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """
        Describe how the processor outputs map onto image items.

        `image_patches` is a flat field sliced by `patches_per_image`, and
        `patches_per_image` is batched per image. When the processor output
        has no `patches_per_image`, an empty tensor is used for the sizes.
        """
        patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))

        return dict(
            image_patches=MultiModalFieldConfig.flat_from_sizes(
                "image", patches_per_image
            ),
            patches_per_image=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """
        Build the prompt replacement that inserts image placeholder tokens.

        The replacement targets the tokenizer's BOS token id and substitutes
        `nrows` rows of `ncols` image tokens, each row terminated by a
        newline token, followed by the HF config's BOS token id. Only the
        image tokens are selected as embedding positions.
        """
        hf_config = self.info.get_hf_config()
        bos_token_id = hf_config.bos_token_id
        assert isinstance(bos_token_id, int)

        tokenizer = self.info.get_tokenizer()
        # NOTE(review): the replacement target is named "eot" but is read
        # from the tokenizer's bos_token_id - confirm this is intentional.
        eot_token_id = tokenizer.bos_token_id
        assert isinstance(eot_token_id, int)

        def get_replacement_fuyu(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = self.info.get_image_feature_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )
            image_tokens = ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows

            return PromptUpdateDetails.select_token_id(
                image_tokens + [bos_token_id],
                embed_token_id=_IMAGE_TOKEN_ID,
            )

        return [
            PromptReplacement(
                modality="image",
                target=[eot_token_id],
                replacement=get_replacement_fuyu,
            )
        ]


257
258
259
260
261
@MULTIMODAL_REGISTRY.register_processor(
    FuyuMultiModalProcessor,
    info=FuyuProcessingInfo,
    dummy_inputs=FuyuDummyInputsBuilder,
)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """
    Fuyu model: a linear projection of image patches plus a Persimmon
    language model.
    """

    merge_by_field_config = True

    # Prefix rewrites from HF checkpoint parameter names to this module's
    # parameter names.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "model.language_model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """
        Return the prompt placeholder string for a modality.

        Returns `None` for any modality starting with `"image"`.

        Raises:
            ValueError: If the modality does not start with `"image"`.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        Build the vision patch projection and the Persimmon language model.

        Args:
            vllm_config: vLLM configuration; the HF config, quantization
                config and multimodal config are read from it.
            prefix: Parameter-name prefix for the language model submodule.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.vocab_size = config.text_config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        # Size of one flattened patch: patch_size * patch_size * num_channels.
        self.image_feature_size = config.patch_size**2 * config.num_channels

        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            config.hidden_size,
            quant_config=quant_config,
            gather_output=True,
        )
        self.language_model = PersimmonForCausalLM(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[FuyuImagePatchInputs]:
        """
        Extract image inputs from `kwargs` and wrap them in the schema.

        Returns `None` when no `image_patches` were supplied. The `fn`
        dimension is bound to `self.image_feature_size`.
        """
        image_patches = kwargs.pop("image_patches", None)
        patches_per_image = kwargs.pop("patches_per_image", None)

        if image_patches is None:
            return None

        return FuyuImagePatchInputs(
            type="image_patches",
            image_patches_flat=image_patches,
            patches_per_image=patches_per_image,
            resolve_bindings={"fn": self.image_feature_size},
        )

    def _process_image_input(
        self, image_input: FuyuImagePatchInputs
    ) -> MultiModalEmbeddings:
        """
        Project the flattened patches and split the result per image.

        `patches_per_image` may be a tensor or a plain sequence of ints; the
        schema annotates it as `list[int]`, so both forms are accepted.
        """
        image_patches_flat = image_input["image_patches_flat"]
        patches_per_image = image_input["patches_per_image"]

        assert self.vision_embed_tokens is not None
        # The layer returns a tuple; only the first element is used.
        vision_embeddings_flat, _ = self.vision_embed_tokens(image_patches_flat)

        split_sizes = (
            patches_per_image.tolist()
            if isinstance(patches_per_image, torch.Tensor)
            else list(patches_per_image)
        )
        return vision_embeddings_flat.split(split_sizes, dim=0)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped Persimmon language model."""
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or an empty list without image input."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        """
        Run the language model and return its hidden states.

        When `intermediate_tensors` is provided, `inputs_embeds` is
        discarded before the call (presumably because a later
        pipeline-parallel stage starts from the intermediate tensors -
        TODO confirm). Extra `kwargs` are accepted but not used here.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Compute logits via the language model's logits processor and LM head."""
        logits = self.language_model.logits_processor(
            self.language_model.lm_head, hidden_states
        )
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights through `AutoWeightsLoader` and return its result."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)