fuyu.py 12.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
21
from collections.abc import Iterable, Mapping, Sequence
22
from typing import Annotated, Literal, Optional
23
24
25

import torch
import torch.nn as nn
26
27
from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
                          FuyuProcessor)
28

29
from vllm.config import VllmConfig
30
31
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
32
from vllm.multimodal import MULTIMODAL_REGISTRY
33
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
34
                                    MultiModalKwargsItems)
35
36
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataItems)
37
from vllm.multimodal.processing import (BaseMultiModalProcessor,
38
                                        BaseProcessingInfo, PromptReplacement,
39
                                        PromptUpdate, PromptUpdateDetails)
40
from vllm.multimodal.profiling import BaseDummyInputsBuilder
41
from vllm.sequence import IntermediateTensors
42
from vllm.utils.tensor_schema import TensorSchema, TensorShape
43

44
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
45
from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix
46
47
48
49
50
51

# These two token ids could not be found in the HF config, so they are
# hard-coded here.
# Placeholder id repeated once per patch column when the image prompt
# replacement is built (see FuyuMultiModalProcessor._get_prompt_updates).
_IMAGE_TOKEN_ID = 71011
# Id appended after each row of image placeholder tokens.
_NEWLINE_TOKEN_ID = 71019


52
class FuyuImagePatchInputs(TensorSchema):
    """
    Schema for the flattened image patch inputs consumed by the Fuyu model.

    Dimensions:
        - bn: Batch size * number of images
        - bnp: Batch size * number of images * number of patches
        - fn: patch_size_x * patch_size_y * num_channels
    """

    # Tag identifying this input kind; "image_patches" is the only value.
    type: Literal["image_patches"] = "image_patches"

    # Patch features with the batch/image/patch dimensions flattened into
    # a single leading dimension ("bnp").
    image_patches_flat: Annotated[torch.Tensor, TensorShape("bnp", "fn")]

    # NOTE(review): annotated as list[int], but the processor in this file
    # builds the value with `torch.tensor(...)` and the model calls
    # `.tolist()` on it, so a tensor appears to be passed at runtime.
    # Confirm and align the annotation.
    patches_per_image: Annotated[list[int], TensorShape("bn")]
    """
    The number of total patches for each image in the batch.
    
    This is used to split the embeddings which has the first two dimensions
    flattened just like `image_patches_flat`.
    """
71

72

73
class FuyuProcessingInfo(BaseProcessingInfo):
    """Access to the Fuyu HF config/processor plus image-grid arithmetic."""

    def get_hf_config(self):
        """Return the HF config, requested from the context as FuyuConfig."""
        return self.ctx.get_hf_config(FuyuConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF processor, requested as FuyuProcessor.

        Any keyword arguments are forwarded to the context unchanged.
        """
        return self.ctx.get_hf_processor(FuyuProcessor, **kwargs)

    def get_image_processor(self, **kwargs: object) -> FuyuImageProcessor:
        """Return the `image_processor` attribute of the HF processor."""
        hf_processor = self.get_hf_processor(**kwargs)
        return hf_processor.image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality item limit: one image."""
        return {"image": 1}

    def get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Compute the patch grid for an image of the given size.

        An image exceeding the processor's target size in either dimension
        is first scaled down, keeping its aspect ratio, until it fits; the
        scaled dimensions are truncated to integers.

        Returns:
            `(ncols, nrows)`: the number of patches across and down, each
            rounded up so partial patches count as whole ones.
        """
        processor = self.get_image_processor()

        max_width = processor.size["width"]
        max_height = processor.size["height"]

        patch_w = processor.patch_size["width"]
        patch_h = processor.patch_size["height"]

        too_large = image_width > max_width or image_height > max_height
        if too_large:
            # The smaller of the two ratios makes both dimensions fit.
            scale = min(max_height / image_height, max_width / image_width)

            image_height = int(image_height * scale)
            image_width = int(image_width * scale)

        return (
            math.ceil(image_width / patch_w),
            math.ceil(image_height / patch_h),
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of cells (columns * rows) in the patch grid."""
        grid_cols, grid_rows = self.get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
        )
        return grid_cols * grid_rows

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image processor's configured target size."""
        processor = self.get_image_processor()
        target = processor.size
        return ImageSize(width=target["width"], height=target["height"])

129
130
131

class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
    """Builds placeholder text and image inputs for Fuyu."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the dummy prompt text, which is always empty."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images sized to the feature-richest image size.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Item counts per modality; a missing "image" entry
                counts as zero images.
        """
        (width, height) = self.info.get_image_size_with_most_features()
        image_count = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
        )
        return {"image": dummy_images}


152
class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
    """Prepares Fuyu prompts and image patches for the model."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and post-process its image patches.

        For text-only input the HF processor is bypassed: the prompt is
        tokenized directly and the boa token is appended if needed.
        Otherwise the parent implementation is called, the returned
        "image_patches" entry is passed through `flatten_bn`, and a
        "patches_per_image" tensor holding the length of each per-image
        entry is added.
        """
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        image_patches = processed_outputs["image_patches"]
        processed_outputs["image_patches"] = flatten_bn(image_patches)
        # Record the per-image lengths before they are lost to flattening.
        processed_outputs["patches_per_image"] = torch.tensor(
            [len(p) for p in image_patches])

        return processed_outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Ensure the token list ends with the boa token ("<0x04>").

        Mirrors the HF processor, which adds this token itself. The list
        is extended in place and also returned. An empty list is handled
        by appending the token instead of failing on `[-1]`.

        Raises:
            KeyError: If the tokenizer vocabulary has no "<0x04>" entry.
        """
        # HF processor adds boa_token_id
        tokenizer = self.info.get_tokenizer()
        boa_token_id = tokenizer.get_vocab()["<0x04>"]

        # `not prompt_tokens` guards the empty case, where indexing the
        # last element would raise IndexError.
        if not prompt_tokens or prompt_tokens[-1] != boa_token_id:
            prompt_tokens.append(boa_token_id)

        return prompt_tokens

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how the processed fields map onto individual images.

        "image_patches" is flat and split by the "patches_per_image"
        sizes; "patches_per_image" is batched with one entry per image.
        When "patches_per_image" is absent an empty tensor is used.
        """
        patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))

        return dict(
            image_patches=MultiModalFieldConfig.flat_from_sizes(
                "image", patches_per_image),
            patches_per_image=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Build the prompt replacement that inserts image placeholders.

        The replacement targets the tokenizer's bos token id. For each
        image it substitutes `nrows` rows, each made of `ncols` image
        tokens followed by one newline token, and then the config's bos
        token id. `_IMAGE_TOKEN_ID` is passed as the `embed_token_id`.
        """
        hf_config = self.info.get_hf_config()
        bos_token_id = hf_config.bos_token_id
        assert isinstance(bos_token_id, int)

        tokenizer = self.info.get_tokenizer()
        # The token being replaced is the tokenizer's bos token.
        target_token_id = tokenizer.bos_token_id
        assert isinstance(target_token_id, int)

        def get_replacement_fuyu(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = self.info.get_image_feature_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )
            # One newline token closes every row of image tokens.
            image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
                            [_NEWLINE_TOKEN_ID]) * nrows

            return PromptUpdateDetails.select_token_id(
                image_tokens + [bos_token_id],
                embed_token_id=_IMAGE_TOKEN_ID,
            )

        return [
            PromptReplacement(
                modality="image",
                target=[target_token_id],
                replacement=get_replacement_fuyu,
            )
        ]


247
248
249
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
                                        info=FuyuProcessingInfo,
                                        dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Fuyu: a linear image-patch embedding feeding a Persimmon LM."""

    merge_by_field_config = True

    # Prefix renames from checkpoint weight names to this module's
    # attribute layout; applied to incoming weights in `load_weights`.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "model.language_model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for the `i`-th item of a modality.

        Returns:
            None for any modality whose name starts with "image".

        Raises:
            ValueError: For every other modality.
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the patch embedding layer and the Persimmon language model.

        Args:
            vllm_config: Supplies the HF config, quantization config and
                multimodal config.
            prefix: Module-name prefix under which the language model is
                registered.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.vocab_size = config.text_config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        # Number of values in one flattened patch.
        self.image_feature_size = config.patch_size**2 * config.num_channels

        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            config.hidden_size,
            quant_config=quant_config,
            gather_output=True,
        )
        self.language_model = PersimmonForCausalLM(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
        """Extract image inputs from `kwargs`, or None when there are none.

        Pops "image_patches" and "patches_per_image" from `kwargs` and
        wraps them in a `FuyuImagePatchInputs`, binding the "fn" dimension
        to `self.image_feature_size`.
        """
        image_patches = kwargs.pop("image_patches", None)
        patches_per_image = kwargs.pop("patches_per_image", None)

        if image_patches is None:
            return None

        return FuyuImagePatchInputs(
            type="image_patches",
            image_patches_flat=image_patches,
            patches_per_image=patches_per_image,
            resolve_bindings={"fn": self.image_feature_size},
        )

    def _process_image_input(
            self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
        """Embed the flat patches and split the result per image.

        The flat patches go through `vision_embed_tokens`; the output is
        split along dim 0 using the sizes in `patches_per_image`.
        """
        image_patches_flat = image_input["image_patches_flat"]
        patches_per_image = image_input["patches_per_image"]

        assert self.vision_embed_tokens is not None
        # The layer returns a pair; only the first element is used.
        vision_embeddings_flat, _ = self.vision_embed_tokens(
            image_patches_flat)

        return vision_embeddings_flat.split(patches_per_image.tolist(), dim=0)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped Persimmon language model."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or an empty list without images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        """Run the language model and return its hidden states.

        `inputs_embeds` is discarded whenever `intermediate_tensors` is
        given. Extra keyword arguments are accepted but not forwarded.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Apply the language model's logits processor with its `lm_head`."""
        logits = self.language_model.logits_processor(
            self.language_model.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load `(name, tensor)` pairs and return the loader's result.

        Names are first remapped through `hf_to_vllm_mapper`, so weights
        stored under "model.language_model.", "model.vision_embed_tokens."
        or "lm_head." reach the matching submodule. Names that match none
        of the mapper prefixes are left as they are.
        """
        loader = AutoWeightsLoader(self)
        # Without `mapper`, the class-level prefix table is never applied.
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)