fuyu.py 13 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
21
from collections.abc import Iterable, Mapping, Sequence
22
from typing import Annotated, Literal, Optional
23
24
25

import torch
import torch.nn as nn
26
27
from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
                          FuyuProcessor)
28

29
from vllm.config import VllmConfig
30
from vllm.config.multimodal import BaseDummyOptions
31
32
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
33
from vllm.multimodal import MULTIMODAL_REGISTRY
34
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
35
                                    MultiModalKwargsItems)
36
37
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataItems)
38
from vllm.multimodal.processing import (BaseMultiModalProcessor,
39
                                        BaseProcessingInfo, PromptReplacement,
40
                                        PromptUpdate, PromptUpdateDetails)
41
from vllm.multimodal.profiling import BaseDummyInputsBuilder
42
from vllm.sequence import IntermediateTensors
43
from vllm.utils.tensor_schema import TensorSchema, TensorShape
44

45
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
46
from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix
47
48
49
50
51
52

# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019


53
class FuyuImagePatchInputs(TensorSchema):
    """
    Flattened image patches for a batch of images, together with the number
    of patches each image contributed.

    Dimensions:
        - bn: Batch size * number of images
        - bnp: Batch size * number of images * number of patches
        - fn: patch_size_x * patch_size_y * num_channels
    """

    type: Literal["image_patches"] = "image_patches"

    # Patches of all images concatenated along the first dimension.
    image_patches_flat: Annotated[torch.Tensor, TensorShape("bnp", "fn")]

    # The number of total patches for each image in the batch.
    #
    # This is used to split the embeddings, whose first two dimensions are
    # flattened just like `image_patches_flat`, back into per-image chunks.
    # The multimodal processor builds this as a 1-D integer tensor (one entry
    # per image), so it is annotated as a tensor rather than `list[int]`.
    patches_per_image: Annotated[torch.Tensor, TensorShape("bn")]
72

73

74
class FuyuProcessingInfo(BaseProcessingInfo):
    """Config/processor accessors and image-token accounting for Fuyu."""

    def get_hf_config(self):
        """Return the model's HF config as a `FuyuConfig`."""
        return self.ctx.get_hf_config(FuyuConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF `FuyuProcessor`, forwarding any keyword overrides."""
        return self.ctx.get_hf_processor(FuyuProcessor, **kwargs)

    def get_image_processor(self, **kwargs: object) -> FuyuImageProcessor:
        """Return the image processor owned by the HF processor."""
        hf_processor = self.get_hf_processor(**kwargs)
        return hf_processor.image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Only a single image is supported.
        return {"image": 1}

    def get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Return ``(ncols, nrows)`` of the patch grid for an image.

        An image larger than the processor's target size in either dimension
        is first scaled down, preserving aspect ratio, so that it fits within
        the target size. Each dimension is then divided into patches,
        rounding up so partial patches count.
        """
        processor = self.get_image_processor()
        max_width = processor.size["width"]
        max_height = processor.size["height"]
        patch_w = processor.patch_size["width"]
        patch_h = processor.patch_size["height"]

        width, height = image_width, image_height
        if width > max_width or height > max_height:
            scale = min(max_height / height, max_width / width)
            height = int(height * scale)
            width = int(width * scale)

        return math.ceil(width / patch_w), math.ceil(height / patch_h)

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image-patch tokens (``ncols * nrows``)."""
        grid_cols, grid_rows = self.get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
        )
        return grid_cols * grid_rows

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the processor's target size, which yields the most patches."""
        target = self.get_image_processor().size
        return ImageSize(width=target["width"], height=target["height"])

130
131
132

class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
    """Builds the dummy text and image inputs used for profiling Fuyu."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # The dummy prompt has no text; only dummy images are generated.
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` dummy images at the largest size.

        The largest size is the one with the most features. Per-image
        overrides are taken from ``mm_options["image"]`` when provided.
        """
        largest = self.info.get_image_size_with_most_features()

        overrides = None
        if mm_options:
            overrides = mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=largest.width,
            height=largest.height,
            num_images=mm_counts.get("image", 0),
            overrides=overrides,
        )
        return {"image": dummy_images}


157
class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
    """Turns Fuyu prompts + images into token IDs and flattened patches."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and flatten its image patches.

        Text-only prompts bypass the HF processor. They are tokenized
        directly and passed through `_apply_hf_processor_tokens_only`, which
        adds the token the HF processor would have appended.

        For prompts with images, `image_patches` is flattened with
        `flatten_bn`. `patches_per_image` records each image's patch count
        so the flat tensor can be split per image later.
        """
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        # NOTE(review): assumes each entry of `image_patches` holds the
        # patches of a single image, with the patch count as its length —
        # confirm against the HF FuyuProcessor output format.
        image_patches = processed_outputs["image_patches"]
        processed_outputs["image_patches"] = flatten_bn(image_patches)
        processed_outputs["patches_per_image"] = torch.tensor(
            [len(p) for p in image_patches])

        return processed_outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Mirror the HF processor's token edits on a tokenized prompt.

        The HF processor adds the ``boa_token_id`` (vocab entry ``<0x04>``).
        This method appends it unless the prompt already ends with it.

        Args:
            prompt_tokens: Token IDs of the prompt. May be empty. The list is
                not modified.

        Returns:
            The token IDs, guaranteed to end with ``boa_token_id``.
        """
        vocab = self.info.get_tokenizer().get_vocab()
        boa_token_id = vocab["<0x04>"]

        # Check emptiness first: `prompt_tokens[-1]` would raise IndexError
        # on an empty prompt.
        if prompt_tokens and prompt_tokens[-1] == boa_token_id:
            return prompt_tokens

        # Build a new list instead of appending, so the caller's list is
        # left untouched.
        return [*prompt_tokens, boa_token_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how processed fields map back to individual images.

        `image_patches` is one flat tensor, split per image using the sizes
        in `patches_per_image`. `patches_per_image` has one entry per image.
        """
        patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))

        return dict(
            image_patches=MultiModalFieldConfig.flat_from_sizes(
                "image", patches_per_image),
            patches_per_image=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand the prompt marker for each image into patch placeholders.

        Each image replaces one occurrence of the tokenizer's BOS token with
        `nrows` rows. Each row is `ncols` `_IMAGE_TOKEN_ID` tokens followed by
        `_NEWLINE_TOKEN_ID`, and the expansion ends with the config's BOS
        token. Only the `_IMAGE_TOKEN_ID` positions receive image embeddings.
        """
        hf_config = self.info.get_hf_config()
        bos_token_id = hf_config.bos_token_id
        assert isinstance(bos_token_id, int)

        tokenizer = self.info.get_tokenizer()
        # The token that marks where an image goes in the prompt.
        eot_token_id = tokenizer.bos_token_id
        assert isinstance(eot_token_id, int)

        def get_replacement_fuyu(item_idx: int):
            """Build the placeholder tokens for the `item_idx`-th image."""
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = self.info.get_image_feature_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )
            image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
                            [_NEWLINE_TOKEN_ID]) * nrows

            return PromptUpdateDetails.select_token_id(
                image_tokens + [bos_token_id],
                embed_token_id=_IMAGE_TOKEN_ID,
            )

        return [
            PromptReplacement(
                modality="image",
                target=[eot_token_id],
                replacement=get_replacement_fuyu,
            )
        ]


252
253
254
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
                                        info=FuyuProcessingInfo,
                                        dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Fuyu model: image patches pass through one linear projection
    (`vision_embed_tokens`) and feed a Persimmon causal language model."""
    merge_by_field_config = True

    # Rename HF checkpoint parameter prefixes to this module's attributes.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "model.language_model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the textual placeholder for the `i`-th item of `modality`.

        Image modalities use no textual placeholder, so ``None`` is returned.

        Raises:
            ValueError: If `modality` is not an image modality.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.vocab_size = config.text_config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        # Length of one flattened patch: patch_size * patch_size * channels.
        self.image_feature_size = config.patch_size**2 * config.num_channels

        # Projects each flattened patch straight to the LM hidden size.
        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            config.hidden_size,
            quant_config=quant_config,
            gather_output=True,
        )
        self.language_model = PersimmonForCausalLM(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
        """Pop image kwargs and wrap them in a validated schema.

        Returns ``None`` when no `image_patches` were passed.
        """
        image_patches = kwargs.pop("image_patches", None)
        patches_per_image = kwargs.pop("patches_per_image", None)

        if image_patches is None:
            return None

        return FuyuImagePatchInputs(
            type="image_patches",
            image_patches_flat=image_patches,
            patches_per_image=patches_per_image,
            resolve_bindings={"fn": self.image_feature_size},
        )

    def _process_image_input(
            self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
        """Project flattened patches and split the result per image."""
        image_patches_flat = image_input["image_patches_flat"]
        patches_per_image = image_input["patches_per_image"]

        assert self.vision_embed_tokens is not None
        vision_embeddings_flat, _ = self.vision_embed_tokens(
            image_patches_flat)

        # The processor builds the patch counts as a 1-D tensor, but a plain
        # sequence of ints is also accepted. `Tensor.split` needs Python ints
        # either way, and calling `.tolist()` on a list would fail.
        if isinstance(patches_per_image, torch.Tensor):
            split_sizes = patches_per_image.tolist()
        else:
            split_sizes = [int(n) for n in patches_per_image]

        return vision_embeddings_flat.split(split_sizes, dim=0)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped Persimmon language model."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or ``[]`` when no images were given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        """Run the language model and return its hidden states.

        When `intermediate_tensors` is provided, `inputs_embeds` is ignored.
        Presumably this is a later pipeline-parallel stage that receives
        hidden states instead of embeddings.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Compute logits via the language model's head and logits processor."""
        logits = self.language_model.logits_processor(
            self.language_model.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights.

        Returns the set of loaded parameter names.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)