# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
21
from collections.abc import Iterable, Mapping, Sequence
22
from typing import Annotated, Literal, Optional
23
24
25

import torch
import torch.nn as nn
26
27
from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
                          FuyuProcessor)
28

29
from vllm.config import VllmConfig
30
31
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
32
from vllm.multimodal import MULTIMODAL_REGISTRY
33
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
34
                                    MultiModalKwargsItems)
35
36
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataItems)
37
from vllm.multimodal.processing import (BaseMultiModalProcessor,
38
                                        BaseProcessingInfo, PromptReplacement,
39
                                        PromptUpdate, PromptUpdateDetails)
40
from vllm.multimodal.profiling import BaseDummyInputsBuilder
41
from vllm.sequence import IntermediateTensors
42
from vllm.utils.tensor_schema import TensorSchema, TensorShape
43

44
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
45
from .utils import AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix
46
47
48
49
50
51

# The following two token ids cannot be found in the HF config, so they are
# hard-coded here.
# Placeholder token id repeated once per patch column when an image is
# expanded into prompt tokens; it is also the token selected for embedding.
_IMAGE_TOKEN_ID = 71011
# Token id appended after each row of image placeholder tokens.
_NEWLINE_TOKEN_ID = 71019


52
class FuyuImagePatchInputs(TensorSchema):
    """
    Schema describing the flattened image patches fed to the Fuyu model.

    Dimensions:
        - bn: Batch size * number of images
        - bnp: Batch size * number of images * number of patches
        - fn: patch_size_x * patch_size_y * num_channels
    """

    type: Literal["image_patches"] = "image_patches"

    flat_data: Annotated[torch.Tensor, TensorShape("bnp", "fn")]

    patches_per_image: Annotated[list[int], TensorShape("bn")]
    """
    Total patch count of every image in the batch.

    Needed to split the embeddings back into per-image chunks, because their
    first two dimensions are flattened in the same way as `flat_data`.
    """
74

75

76
class FuyuProcessingInfo(BaseProcessingInfo):
    """Access to the Fuyu HF config/processor plus image-token accounting."""

    def get_hf_config(self):
        """Return the HF config, requested from the context as `FuyuConfig`."""
        return self.ctx.get_hf_config(FuyuConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF processor, requested as `FuyuProcessor`."""
        return self.ctx.get_hf_processor(FuyuProcessor, **kwargs)

    def get_image_processor(self, **kwargs: object) -> FuyuImageProcessor:
        """Return the `image_processor` attribute of the HF processor."""
        hf_processor = self.get_hf_processor(**kwargs)
        return hf_processor.image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Report the per-prompt limits: a single image."""
        return {"image": 1}

    def get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """
        Compute the patch grid covering an image of the given size.

        An image exceeding the processor's target size in either dimension
        is first scaled down, keeping its aspect ratio, until it fits.

        Returns:
            `(ncols, nrows)`: number of patch columns and patch rows.
        """
        image_processor = self.get_image_processor()
        target_size = image_processor.size
        patch_size = image_processor.patch_size

        max_width, max_height = target_size["width"], target_size["height"]
        patch_width, patch_height = patch_size["width"], patch_size["height"]

        if image_width > max_width or image_height > max_height:
            # The smaller of the two ratios makes both dimensions fit.
            scale = min(max_height / image_height, max_width / image_width)
            image_height = int(image_height * scale)
            image_width = int(image_width * scale)

        # Partially covered patches still count, hence the ceiling.
        return (
            math.ceil(image_width / patch_width),
            math.ceil(image_height / patch_height),
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of grid cells (columns * rows) for an image."""
        grid_cols, grid_rows = self.get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
        )
        return grid_cols * grid_rows

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image processor's target size as an `ImageSize`."""
        target_size = self.get_image_processor().size
        return ImageSize(width=target_size["width"],
                         height=target_size["height"])

132
133
134

class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
    """Builds the dummy text and image inputs for Fuyu."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the dummy prompt text, which is always empty."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """
        Build dummy images sized to yield the most image features.

        Args:
            seq_len: Unused by this builder.
            mm_counts: Item count per modality; a missing "image" entry
                is treated as zero images.
        """
        largest_size = self.info.get_image_size_with_most_features()
        dummy_width, dummy_height = largest_size
        image_count = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(width=dummy_width,
                                              height=dummy_height,
                                              num_images=image_count)
        return {"image": dummy_images}


155
class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
    """Multimodal processor that adapts the HF Fuyu processor outputs."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor and normalize the layout of `image_patches`.

        Text-only prompts bypass the HF processor: they are tokenized
        directly and passed through `_apply_hf_processor_tokens_only`.

        Raises:
            AssertionError: If `image_patches` is neither a list nor a
                tensor, or its length disagrees with the number of images.
        """
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        image_patches = processed_outputs.get("image_patches")
        if image_patches is not None:
            images = mm_data["images"]
            assert isinstance(images, list)

            # Original output: (1, num_images, Pn, Px * Py * C)
            # New output: (num_images, Pn, Px * Py * C)
            if isinstance(image_patches, list):
                # Before Transformers 4.53, image_patches is a list with
                # shape (1, num_images, Pn, Px * Py * C); unwrap the
                # outer singleton dimension.
                assert len(image_patches) == 1
                assert (isinstance(image_patches[0], torch.Tensor)
                        and len(image_patches[0]) == len(images))
                processed_outputs["image_patches"] = image_patches[0]
            elif isinstance(image_patches, torch.Tensor):
                # After Transformers 4.53, image_patches is already a
                # tensor with shape (num_images, Pn, Px * Py * C).
                assert len(image_patches) == len(images)
            else:
                raise AssertionError("This line should be unreachable.")

        return processed_outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Ensure the token list ends with the BOA token (`<0x04>`).

        The list is modified in place and also returned. An empty list is
        handled by appending the BOA token rather than failing on the
        `prompt_tokens[-1]` lookup.
        """
        # HF processor adds boa_token_id
        tokenizer = self.info.get_tokenizer()
        boa_token_id = tokenizer.get_vocab()["<0x04>"]

        if not prompt_tokens or prompt_tokens[-1] != boa_token_id:
            prompt_tokens.append(boa_token_id)

        return prompt_tokens

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `image_patches` as a batched field of the image modality."""
        return dict(image_patches=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """
        Describe how each image is expanded into placeholder tokens.

        The tokenizer's BOS token id is the replacement target. It is
        replaced by one row of `_IMAGE_TOKEN_ID` per patch row, each row
        followed by `_NEWLINE_TOKEN_ID`, and finally the BOS token id taken
        from the HF config.
        """
        hf_config = self.info.get_hf_config()
        bos_token_id = hf_config.bos_token_id
        assert isinstance(bos_token_id, int)

        tokenizer = self.info.get_tokenizer()
        # NOTE(review): the target is the *tokenizer's* BOS id even though
        # the variable is named "eot" - confirm the naming against the
        # tokenizer's special tokens.
        eot_token_id = tokenizer.bos_token_id
        assert isinstance(eot_token_id, int)

        def get_replacement_fuyu(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = self.info.get_image_feature_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )
            # One newline token terminates every row of image tokens.
            image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
                            [_NEWLINE_TOKEN_ID]) * nrows

            return PromptUpdateDetails.select_token_id(
                image_tokens + [bos_token_id],
                embed_token_id=_IMAGE_TOKEN_ID,
            )

        return [
            PromptReplacement(
                modality="image",
                target=[eot_token_id],
                replacement=get_replacement_fuyu,
            )
        ]


262
263
264
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
                                        info=FuyuProcessingInfo,
                                        dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Fuyu: a linear patch embedding in front of a Persimmon language model."""

    # Checkpoint-name prefixes rewritten to this module's parameter names.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "model.language_model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """
        Return the prompt placeholder string for a modality.

        Returns `None` for image modalities.

        Raises:
            ValueError: If `modality` does not start with "image".
        """
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the patch projection and the wrapped language model."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.vocab_size = config.text_config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        # Length of one flattened patch: patch_size^2 * num_channels.
        self.image_feature_size = config.patch_size**2 * config.num_channels

        # Projects each flattened patch to the hidden size.
        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            config.hidden_size,
            quant_config=quant_config,
            gather_output=True,
        )
        self.language_model = PersimmonForCausalLM(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
        """
        Build `FuyuImagePatchInputs` from the `image_patches` kwarg.

        Returns `None` when `image_patches` is absent or `None`.
        """
        image_patches = kwargs.pop("image_patches", None)
        if image_patches is None:
            return None

        # First flatten keeps per-image entries (used for the patch counts);
        # the second concatenates them into a single tensor.
        image_patches_flat = flatten_bn(image_patches)
        flat_data = flatten_bn(image_patches_flat, concat=True)

        return FuyuImagePatchInputs(
            type="image_patches",
            flat_data=flat_data,
            patches_per_image=[x.size(0) for x in image_patches_flat],
            resolve_bindings={"fn": self.image_feature_size},
        )

    def _process_image_input(
            self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
        """Embed all patches at once, then split the result per image."""
        image_patches_flat = image_input["flat_data"]
        patches_per_image = image_input["patches_per_image"]

        assert self.vision_embed_tokens is not None
        # The layer returns a pair; its second element is not needed here.
        vision_embeddings_flat, _ = self.vision_embed_tokens(
            image_patches_flat)

        return vision_embeddings_flat.split(patches_per_image, dim=0)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return the image embeddings, or `[]` if no image input is given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        """
        Run the language model and return its hidden states.

        `inputs_embeds` is discarded whenever `intermediate_tensors` is
        supplied. Extra keyword arguments are accepted but not used.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Apply the language model's logits processor with its `lm_head`."""
        logits = self.language_model.logits_processor(
            self.language_model.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """
        Load checkpoint weights into this module.

        Weight names are first rewritten through `hf_to_vllm_mapper`, so
        checkpoints using the `model.vision_embed_tokens.`,
        `model.language_model.` and `lm_head.` prefixes resolve to this
        module's parameters. Names without those prefixes pass through
        unchanged.

        Returns:
            The value returned by `AutoWeightsLoader.load_weights`.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)