fuyu.py 14.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
21
from collections.abc import Iterable, Mapping, Sequence
22
from typing import Annotated, Literal, Optional
23
24
25

import torch
import torch.nn as nn
26
27
from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
                          FuyuProcessor)
28

29
from vllm.config import VllmConfig
30
31
32
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
33
from vllm.multimodal import MULTIMODAL_REGISTRY
34
35
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs)
36
37
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataItems)
38
from vllm.multimodal.processing import (BaseMultiModalProcessor,
39
                                        BaseProcessingInfo, PromptReplacement,
40
                                        PromptUpdate, PromptUpdateDetails)
41
from vllm.multimodal.profiling import BaseDummyInputsBuilder
42
from vllm.sequence import IntermediateTensors
43
from vllm.utils.tensor_schema import TensorSchema, TensorShape
44

45
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
46
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
47
                    merge_multimodal_embeddings)
48
49
50
51
52
53

# Token ids used to build Fuyu's image placeholder sequence. Neither id can be
# found in the HF config, so both are hard-coded here.
# Placeholder emitted once per image patch; the positions holding this id are
# the ones that receive the vision embeddings.
_IMAGE_TOKEN_ID = 71011
# Emitted after each row of image-patch placeholders.
_NEWLINE_TOKEN_ID = 71019


54
class FuyuImagePatchInputs(TensorSchema):
    """
    Image patches of a batch, flattened for the Fuyu patch embedding layer.

    Dimensions:
        - bn: Batch size * number of images
        - bnp: Batch size * number of images * number of patches
        - fn: patch_size_x * patch_size_y * num_channels
    """

    type: Literal["image_patches"] = "image_patches"

    # Patches of every image, concatenated along the first axis.
    flat_data: Annotated[torch.Tensor, TensorShape("bnp", "fn")]

    # Total number of patches for each image in the batch. Embeddings computed
    # from `flat_data` share its flattened leading axis, so these counts are
    # used to split them back into per-image chunks.
    patches_per_image: Annotated[list[int], TensorShape("bn")]
76

77

78
class FuyuProcessingInfo(BaseProcessingInfo):
    """Model-specific processing information and helpers for Fuyu."""

    def get_hf_config(self):
        """Return the model's HF config, typed as ``FuyuConfig``."""
        return self.ctx.get_hf_config(FuyuConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF ``FuyuProcessor`` created with ``kwargs``."""
        return self.ctx.get_hf_processor(FuyuProcessor, **kwargs)

    def get_image_processor(self, **kwargs: object) -> FuyuImageProcessor:
        """Return the image processor owned by the HF processor."""
        hf_processor = self.get_hf_processor(**kwargs)
        return hf_processor.image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # At most one image is accepted per prompt.
        return {"image": 1}

    def get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Return the ``(ncols, nrows)`` patch grid for an image.

        If the image exceeds the image processor's target ``size`` in either
        dimension, it is first scaled down by a single factor (keeping the
        aspect ratio, truncating to ints) so that it fits. The grid is then
        the ceiling of each dimension divided by the patch size.
        """
        processor = self.get_image_processor()
        box = processor.size
        patch = processor.patch_size

        box_width, box_height = box["width"], box["height"]

        fits_in_box = image_width <= box_width and image_height <= box_height
        if not fits_in_box:
            scale = min(box_height / image_height, box_width / image_width)
            image_height = int(image_height * scale)
            image_width = int(image_width * scale)

        cols = math.ceil(image_width / patch["width"])
        rows = math.ceil(image_height / patch["height"])
        return cols, rows

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of patch embeddings produced for one image."""
        cols, rows = self.get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
        )
        return cols * rows

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image processor's target size (the largest grid)."""
        box = self.get_image_processor().size
        return ImageSize(width=box["width"], height=box["height"])

134
135
136

class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
    """Builds dummy Fuyu inputs: an empty prompt plus full-size images."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # No dummy text is needed; only dummy images are generated.
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` dummy images (default 0).

        Each image has the size that yields the most image features.
        """
        width, height = self.info.get_image_size_with_most_features()
        image_count = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)
        return {"image": dummy_images}


157
class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
    """Multimodal processor that expands each image into Fuyu patch tokens."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and normalize its ``image_patches`` output.

        Text-only prompts skip the HF processor. They are tokenized directly
        and post-processed with ``_apply_hf_processor_tokens_only``.

        ``image_patches`` comes back in one of two forms, and both are
        normalized to the bare tensor:

        - before Transformers 4.53, a one-element list holding a tensor of
          shape ``(num_images, Pn, Px * Py * C)``;
        - from 4.53 onward, that tensor directly.
        """
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        image_patches = processed_outputs.get("image_patches")
        if image_patches is not None:
            images = mm_data["images"]
            assert isinstance(images, list)

            if isinstance(image_patches, list):
                # Before Transformers 4.53: (1, num_images, Pn, Px * Py * C)
                assert len(image_patches) == 1
                assert (isinstance(image_patches[0], torch.Tensor)
                        and len(image_patches[0]) == len(images))
                processed_outputs["image_patches"] = image_patches[0]
            elif isinstance(image_patches, torch.Tensor):
                # Transformers 4.53 and later: (num_images, Pn, Px * Py * C)
                assert len(image_patches) == len(images)
            else:
                raise AssertionError("This line should be unreachable.")

        return processed_outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Apply the token-level post-processing done by the HF processor.

        The HF processor adds the ``<0x04>`` (BOA) token to the prompt. It is
        appended here unless it is already the last token. An empty prompt
        yields ``[boa_token_id]``.

        The caller's ``prompt_tokens`` list is never mutated. It is returned
        as-is when no token needs to be added; otherwise a new list is built.
        """
        vocab = self.info.get_tokenizer().get_vocab()
        boa_token_id = vocab["<0x04>"]

        # Guard the empty case, which previously raised IndexError.
        if prompt_tokens and prompt_tokens[-1] == boa_token_id:
            return prompt_tokens

        # Build a new list instead of appending, so the input stays untouched.
        return [*prompt_tokens, boa_token_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # `image_patches` is batched per item of the "image" modality.
        return dict(image_patches=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace the image target token with each image's patch grid.

        The replacement target is the tokenizer's ``bos_token_id``. For each
        image, the replacement is:

        - ``nrows`` rows of ``ncols`` ``_IMAGE_TOKEN_ID`` tokens, each row
          followed by ``_NEWLINE_TOKEN_ID``;
        - then the HF config's ``bos_token_id``.

        Only the ``_IMAGE_TOKEN_ID`` positions receive embeddings.
        """
        hf_config = self.info.get_hf_config()
        bos_token_id = hf_config.bos_token_id
        assert isinstance(bos_token_id, int)

        tokenizer = self.info.get_tokenizer()
        # The tokenizer's BOS id is the token being replaced.
        eot_token_id = tokenizer.bos_token_id
        assert isinstance(eot_token_id, int)

        def get_replacement_fuyu(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = self.info.get_image_feature_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )
            image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
                            [_NEWLINE_TOKEN_ID]) * nrows

            return PromptUpdateDetails.select_token_id(
                image_tokens + [bos_token_id],
                embed_token_id=_IMAGE_TOKEN_ID,
            )

        return [
            PromptReplacement(
                modality="image",
                target=[eot_token_id],
                replacement=get_replacement_fuyu,
            )
        ]


264
265
266
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
                                        info=FuyuProcessingInfo,
                                        dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Fuyu: a linear patch embedding feeding a Persimmon language model."""

    # Renames checkpoint prefixes to this module's attribute layout. Applied
    # in `load_weights`. Names already in the local layout match none of
    # these prefixes and pass through unchanged.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "model.language_model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return None for image modalities; raise ValueError otherwise."""
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.vocab_size = config.text_config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        # Length of one flattened patch: patch_size**2 pixels per channel.
        self.image_feature_size = config.patch_size**2 * config.num_channels

        # Linear projection of flattened patches into the LM hidden size.
        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            config.hidden_size,
            quant_config=quant_config,
            gather_output=True,
        )
        self.language_model = PersimmonForCausalLM(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
        """Pop ``image_patches`` from kwargs and wrap it in the schema.

        Returns None when no ``image_patches`` were passed.
        """
        image_patches = kwargs.pop("image_patches", None)
        if image_patches is None:
            return None

        # One tensor per image; its first dimension is that image's patch
        # count.
        image_patches_flat = flatten_bn(image_patches)
        flat_data = flatten_bn(image_patches_flat, concat=True)

        return FuyuImagePatchInputs(
            type="image_patches",
            flat_data=flat_data,
            patches_per_image=[x.size(0) for x in image_patches_flat],
            resolve_bindings={"fn": self.image_feature_size},
        )

    def _process_image_input(
            self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
        """Embed all patches at once, then split them per image."""
        image_patches_flat = image_input["flat_data"]
        patches_per_image = image_input["patches_per_image"]

        assert self.vision_embed_tokens is not None
        vision_embeddings_flat, _ = self.vision_embed_tokens(
            image_patches_flat)

        return vision_embeddings_flat.split(patches_per_image, dim=0)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or [] when no images were passed."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed text tokens and scatter image embeddings into them.

        Image embeddings fill the positions whose id is ``_IMAGE_TOKEN_ID``.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        # `len` rather than truthiness: the embeddings may be a tuple of
        # tensors.
        if (multimodal_embeddings is not None
                and len(multimodal_embeddings) != 0):
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                _IMAGE_TOKEN_ID,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        """Run the language model and return its hidden states.

        When ``intermediate_tensors`` are given, ``inputs_embeds`` is
        ignored.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to logits via the language model's head."""
        logits = self.language_model.logits_processor(
            self.language_model.lm_head, hidden_states, sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming them via ``hf_to_vllm_mapper``.

        Without the mapper, checkpoints using the ``model.language_model.*``
        / ``model.vision_embed_tokens.*`` / ``lm_head.*`` layout would not
        match this module's parameters. Names already in the local layout
        are left unchanged by the mapper.

        Returns the set of names reported by ``AutoWeightsLoader``.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)