fuyu.py 14.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
20
from collections.abc import Iterable, Mapping, Sequence
21
from typing import Literal, Optional, Set, Tuple, TypedDict
22
23
24

import torch
import torch.nn as nn
25
26
from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
                          FuyuProcessor)
27

28
from vllm.config import VllmConfig
29
from vllm.model_executor.layers.linear import ColumnParallelLinear
30
from vllm.model_executor.layers.sampler import SamplerOutput
31
32
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
33
from vllm.multimodal import MULTIMODAL_REGISTRY
34
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
35
36
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataItems)
37
from vllm.multimodal.processing import (BaseMultiModalProcessor,
38
                                        BaseProcessingInfo, PromptReplacement,
39
                                        PromptUpdate, PromptUpdateDetails)
40
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
41
from vllm.sequence import IntermediateTensors
42

43
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
44
45
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                    merge_multimodal_embeddings)
46
47
48
49
50
51

# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019


52
53
class FuyuImagePatchInputs(TypedDict):
    """Flattened image patches plus the per-image patch counts."""

    type: Literal["image_patches"]

    flat_data: torch.Tensor
    """
    Shape:
    `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
    """

    patches_per_image: list[int]
    """
    Number of patches contributed by each image in the batch.

    Used to split embeddings computed from `flat_data` (whose leading
    dimensions are flattened in the same way) back into per-image chunks.
    """


69
class FuyuProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Fuyu.

    Provides typed access to the HF config/processor and the arithmetic
    that maps an image size onto a grid of patches.
    """

    def get_hf_config(self):
        """Return the model's HF config, typed as `FuyuConfig`."""
        return self.ctx.get_hf_config(FuyuConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF `FuyuProcessor`, forwarding `kwargs`."""
        return self.ctx.get_hf_processor(FuyuProcessor, **kwargs)

    def get_image_processor(self) -> FuyuImageProcessor:
        """Return the image processor owned by the HF processor."""
        hf_processor = self.get_hf_processor()
        return hf_processor.image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # At most one image per prompt.
        return {"image": 1}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Upper bound on placeholder tokens produced for a single image.

        Each grid row contributes one token per column plus one trailing
        newline token, so an ``ncols x nrows`` grid yields
        ``nrows * (ncols + 1)`` tokens.
        """
        largest_w, largest_h = self.get_image_size_with_most_features()
        ncols, nrows = self.get_image_feature_grid_size(
            image_width=largest_w,
            image_height=largest_h,
        )
        return {"image": nrows * (ncols + 1)}

    def get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Return ``(ncols, nrows)`` of the patch grid for an image.

        An image larger than the processor's target size in either
        dimension is first scaled down by a single factor (aspect ratio
        kept, sides truncated to ``int``) so that it fits. Smaller images
        are left unchanged. Each side is then divided by the patch size and
        rounded up.
        """
        processor = self.get_image_processor()
        max_w = processor.size["width"]
        max_h = processor.size["height"]
        patch_w = processor.patch_size["width"]
        patch_h = processor.patch_size["height"]

        fits = image_width <= max_w and image_height <= max_h
        if not fits:
            scale = min(max_h / image_height, max_w / image_width)
            image_height = int(image_height * scale)
            image_width = int(image_width * scale)

        grid_cols = math.ceil(image_width / patch_w)
        grid_rows = math.ceil(image_height / patch_h)
        return grid_cols, grid_rows

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the processor's target size.

        Larger images are downscaled to fit this size, so it yields the
        largest patch grid.
        """
        target = self.get_image_processor().size
        return ImageSize(width=target["width"], height=target["height"])

127
128
129

class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
    """Builds dummy Fuyu processor inputs for profiling."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return an empty text prompt plus ``mm_counts["image"]`` images.

        The dummy images use the size that produces the most image tokens.
        """
        width, height = self.info.get_image_size_with_most_features()
        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=mm_counts.get("image", 0),
        )
        return ProcessorInputs(prompt_text="",
                               mm_data={"image": dummy_images})


152
class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
    """Multimodal processor for Fuyu.

    Wraps the HF `FuyuProcessor` and describes how each image expands into
    a grid of image/newline placeholder tokens in the prompt.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and normalize its output for vLLM.

        Text-only prompts skip the HF processor and are tokenized directly.
        They get the same trailing token the HF processor would append (see
        `_apply_hf_processor_tokens_only`). For image prompts, the leading
        singleton dimension of ``image_patches`` is dropped.
        """
        if not mm_data:
            # Text-only: tokenize directly to avoid a warning from the HF
            # logger.
            tokenizer = self.info.get_tokenizer()
            token_ids = self._apply_hf_processor_tokens_only(
                tokenizer.encode(prompt))
            return BatchFeature(dict(input_ids=[token_ids]), tensor_type="pt")

        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        patches = outputs.get("image_patches")
        if patches is None:
            return outputs

        images = mm_data["images"]
        assert isinstance(images, list)

        # HF output:  (1, num_images, Pn, Px * Py * C)
        # Normalized: (num_images, Pn, Px * Py * C)
        assert isinstance(patches, list) and len(patches) == 1
        patches_for_batch = patches[0]
        assert (isinstance(patches_for_batch, torch.Tensor)
                and len(patches_for_batch) == len(images))

        outputs["image_patches"] = patches_for_batch
        return outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Append the ``<0x04>`` (boa) token, mirroring the HF processor."""
        vocab = self.info.get_tokenizer().get_vocab()
        boa_token_id = vocab["<0x04>"]
        return prompt_tokens + [boa_token_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # ``image_patches`` is batched along the image items.
        return {"image_patches": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Expand the tokenizer's BOS token into each image's placeholders.

        Each image becomes ``nrows`` rows of ``ncols`` image tokens, each
        row ending in a newline token, followed by the config's BOS token.
        Only the image/newline tokens are marked as features.
        """
        bos_token_id = self.info.get_hf_config().bos_token_id
        assert isinstance(bos_token_id, int)

        # The replacement targets the tokenizer's BOS token.
        target_token_id = self.info.get_tokenizer().bos_token_id
        assert isinstance(target_token_id, int)

        def get_replacement_fuyu(item_idx: int):
            # Looked up lazily: only called when image items exist.
            image_items = mm_items.get_items("image", ImageProcessorItems)
            size = image_items.get_image_size(item_idx)

            ncols, nrows = self.info.get_image_feature_grid_size(
                image_width=size.width,
                image_height=size.height,
            )
            row = [_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]
            image_tokens = row * nrows

            return PromptUpdateDetails(
                full=image_tokens + [bos_token_id],
                features=image_tokens,
            )

        replacement = PromptReplacement(
            modality="image",
            target=[target_token_id],
            replacement=get_replacement_fuyu,
        )
        return [replacement]


246
247
248
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
                                        info=FuyuProcessingInfo,
                                        dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Fuyu: a linear patch embedding in front of a Persimmon LM.

    Image patches are projected to the LM hidden size by
    ``vision_embed_tokens`` and merged into the text embeddings at
    ``_IMAGE_TOKEN_ID`` positions. Decoding, logits and sampling are
    delegated to ``language_model``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.vocab_size = config.text_config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        # Length of one flattened patch: patch_size * patch_size * channels.
        self.image_feature_size = config.patch_size**2 * config.num_channels

        # Linear projection from flattened patches to the LM hidden size.
        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            config.hidden_size,
            quant_config=quant_config,
            gather_output=True,
        )
        self.language_model = PersimmonForCausalLM(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @property
    def sampler(self):
        return self.language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check the per-patch feature size and cast to the embedding dtype.

        Raises:
            ValueError: if the last dimension of ``data`` is not
                ``num_channels * patch_size * patch_size``.
        """
        h = w = self.config.patch_size
        num_channels = self.config.num_channels
        expected_dims = num_channels * h * w

        # Every patch (row) of a tensor shares the same last dimension, so
        # a single check covers all of them; no per-patch Python loop.
        actual_dims = data.size(-1)
        if actual_dims != expected_dims:
            raise ValueError(
                "The expected shape of pixel values per image per batch "
                f"per patch is {expected_dims}. "
                f"You supplied {tuple(data.shape)}.")

        return data.to(self.vision_embed_tokens.weight.dtype)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
        """Build ``FuyuImagePatchInputs`` from the ``image_patches`` kwarg.

        Returns ``None`` when ``image_patches`` is absent.

        Raises:
            ValueError: if ``image_patches`` is neither a tensor nor a list,
                or its per-patch size is wrong.
        """
        image_patches = kwargs.pop("image_patches", None)
        if image_patches is None:
            return None

        if not isinstance(image_patches, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image patches. "
                             f"Got type: {type(image_patches)}")

        # Iterating this yields one patch tensor per image (its first dim is
        # that image's patch count, recorded in ``patches_per_image``).
        image_patches_flat = flatten_bn(image_patches)

        return FuyuImagePatchInputs(
            type="image_patches",
            flat_data=self._validate_pixel_values(
                flatten_bn(image_patches_flat, concat=True)),
            patches_per_image=[x.size(0) for x in image_patches_flat],
        )

    def _process_image_input(
            self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
        """Embed all patches in one call, then split the result per image."""
        image_patches_flat = image_input["flat_data"]
        patches_per_image = image_input["patches_per_image"]

        assert self.vision_embed_tokens is not None
        vision_embeddings_flat, _ = self.vision_embed_tokens(
            image_patches_flat)
        return vision_embeddings_flat.split(patches_per_image, dim=0)

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Return per-image embeddings, or ``None`` if no images were given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and splice in image embeddings, if any.

        Image embeddings replace the positions holding ``_IMAGE_TOKEN_ID``.
        An empty embeddings sequence is treated like ``None``, so no merge
        is attempted.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if (multimodal_embeddings is not None
                and len(multimodal_embeddings) != 0):
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                _IMAGE_TOKEN_ID)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        """Run the language model on text (and optional image) inputs.

        When ``intermediate_tensors`` is given, any ``inputs_embeds`` is
        discarded. Otherwise, if ``inputs_embeds`` is missing, it is built
        here from ``input_ids`` and the image kwargs.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.language_model.logits_processor(
            self.language_model.lm_head, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens using the language model's sampler."""
        next_tokens = self.language_model.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights; returns the set of loaded names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)