# SPDX-License-Identifier: Apache-2.0

# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict

import torch
import torch.nn as nn
from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
                          FuyuProcessor)

from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                    merge_multimodal_embeddings)

# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019


53
54
class FuyuImagePatchInputs(TypedDict):
    """Flattened image patches, as consumed by the Fuyu patch embedding."""

    type: Literal["image_patches"]

    flat_data: torch.Tensor
    """
    Shape:
    `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
    """

    patches_per_image: list[int]
    """
    The number of total patches for each image in the batch.

    This is used to split the embeddings, whose first two dimensions are
    flattened just like `flat_data`.
    """


class FuyuProcessingInfo(BaseProcessingInfo):
    """Access to the Fuyu HF config/processor and image-token accounting."""

    def get_hf_config(self):
        """Return the model's HF config, typed as :class:`FuyuConfig`."""
        return self.ctx.get_hf_config(FuyuConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF :class:`FuyuProcessor`, forwarding *kwargs*."""
        return self.ctx.get_hf_processor(FuyuProcessor, **kwargs)

    def get_image_processor(self) -> FuyuImageProcessor:
        """Return the image processor owned by the HF processor."""
        return self.get_hf_processor().image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Only the "image" modality is supported, with a limit of 1.
        return {"image": 1}

    def get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Return the ``(ncols, nrows)`` patch grid for an image.

        If the image exceeds the processor's target size in either
        dimension, it is first scaled down (preserving aspect ratio, with
        the result truncated to ints) so that it fits; smaller images are
        left unchanged. The grid is the number of patches needed to cover
        the resulting image, rounding up in each dimension.
        """
        image_processor = self.get_image_processor()

        target_width = image_processor.size["width"]
        target_height = image_processor.size["height"]

        patch_width = image_processor.patch_size["width"]
        patch_height = image_processor.patch_size["height"]

        if not (image_width <= target_width and image_height <= target_height):
            height_scale_factor = target_height / image_height
            width_scale_factor = target_width / image_width
            # Use the smaller factor so both dimensions fit the target.
            optimal_scale_factor = min(height_scale_factor, width_scale_factor)

            image_height = int(image_height * optimal_scale_factor)
            image_width = int(image_width * optimal_scale_factor)

        ncols = math.ceil(image_width / patch_width)
        nrows = math.ceil(image_height / patch_height)

        return ncols, nrows

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of patches (``ncols * nrows``) for an image."""
        ncols, nrows = self.get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
        )

        return ncols * nrows

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the processor's target size.

        Larger images are scaled down to fit this size, so it yields the
        largest patch grid.
        """
        image_processor = self.get_image_processor()

        return ImageSize(width=image_processor.size["width"],
                         height=image_processor.size["height"])


class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
    """Builds dummy text and images used for Fuyu profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        # The dummy prompt carries no text.
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` (default 0) dummy images.

        Each image has the size that produces the most patch features.
        """
        target_width, target_height = \
            self.info.get_image_size_with_most_features()

        num_images = mm_counts.get("image", 0)

        return {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }


class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
    """Runs the HF Fuyu processor and expands images into patch tokens."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Process *prompt* and *mm_data* with the HF processor.

        For text-only input the HF processor is bypassed: the prompt is
        tokenized directly and the token(s) the HF processor would append
        are added. When ``image_patches`` are produced, their leading
        singleton list wrapper is removed.
        """
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        image_patches = processed_outputs.get("image_patches")
        if image_patches is not None:
            images = mm_data["images"]
            assert isinstance(images, list)

            # Original output: (1, num_images, Pn, Px * Py * C)
            # New output: (num_images, Pn, Px * Py * C)
            assert (isinstance(image_patches, list)
                    and len(image_patches) == 1)
            assert (isinstance(image_patches[0], torch.Tensor)
                    and len(image_patches[0]) == len(images))

            processed_outputs["image_patches"] = image_patches[0]

        return processed_outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Append the boa token (vocab entry ``<0x04>``) to *prompt_tokens*.

        This reproduces what the HF processor adds for text-only prompts.
        """
        # HF processor adds boa_token_id
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        boa_token_id = vocab["<0x04>"]

        return prompt_tokens + [boa_token_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # ``image_patches`` is batched along the "image" modality.
        return dict(image_patches=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Describe how each image expands in the token prompt.

        The target is the tokenizer's BOS token. It is replaced by the
        image's patch grid followed by the config's BOS token. For every
        row, the grid holds ``ncols`` image tokens and one newline token.
        Only the image tokens are selected to receive patch embeddings.
        """
        hf_config = self.info.get_hf_config()
        bos_token_id = hf_config.bos_token_id
        assert isinstance(bos_token_id, int)

        tokenizer = self.info.get_tokenizer()
        # NOTE: despite its name, this holds the tokenizer's BOS id; it is
        # the token that gets replaced by the image grid below.
        eot_token_id = tokenizer.bos_token_id
        assert isinstance(eot_token_id, int)

        def get_replacement_fuyu(item_idx: int):
            """Build the replacement token sequence for image ``item_idx``."""
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = self.info.get_image_feature_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )
            # Row-major grid: each row is ncols image tokens + one newline.
            image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
                            [_NEWLINE_TOKEN_ID]) * nrows

            return PromptUpdateDetails.select_token_id(
                image_tokens + [bos_token_id],
                embed_token_id=_IMAGE_TOKEN_ID,
            )

        return [
            PromptReplacement(
                modality="image",
                target=[eot_token_id],
                replacement=get_replacement_fuyu,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
                                        info=FuyuProcessingInfo,
                                        dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Fuyu: a linear projection of raw image patches feeding a Persimmon
    causal language model."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.vocab_size = config.text_config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        # Size of one flattened patch: patch_size**2 * num_channels values.
        self.image_feature_size = config.patch_size**2 * config.num_channels

        # Projects flattened patches straight into the LM hidden size.
        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            config.hidden_size,
            quant_config=quant_config,
            gather_output=True,
        )
        self.language_model = PersimmonForCausalLM(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @property
    def sampler(self):
        return self.language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check the per-patch feature size and cast to the embedding dtype.

        Args:
            data: flattened patches; assumed 2-D
                ``(num_patches, num_channels * patch_size * patch_size)``
                (it comes from ``flatten_bn(..., concat=True)``).

        Returns:
            *data* cast to the dtype of ``vision_embed_tokens.weight``.

        Raises:
            ValueError: if the trailing dimension does not match.
        """
        h = w = self.config.patch_size
        num_channels = self.config.num_channels
        expected_dims = num_channels * h * w

        # Every row shares the trailing dim, so one check replaces a Python
        # loop over all patch rows. ``data.shape[1:]`` is the shape of one
        # row, which is what the error has always reported. An empty batch
        # is accepted, as before.
        if len(data) > 0 and data.size(-1) != expected_dims:
            raise ValueError(
                "The expected shape of pixel values per image per batch "
                f"per patch is {expected_dims}. "
                f"You supplied {tuple(data.shape[1:])}.")

        return data.to(self.vision_embed_tokens.weight.dtype)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
        """Pop ``image_patches`` from *kwargs* and package them.

        Returns None when no ``image_patches`` were passed.

        Raises:
            ValueError: if ``image_patches`` is not a tensor or list, or its
                per-patch feature size is wrong.
        """
        image_patches = kwargs.pop("image_patches", None)
        if image_patches is None:
            return None

        if not isinstance(image_patches, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image patches. "
                             f"Got type: {type(image_patches)}")

        image_patches_flat = flatten_bn(image_patches)

        return FuyuImagePatchInputs(
            type="image_patches",
            flat_data=self._validate_pixel_values(
                flatten_bn(image_patches_flat, concat=True)),
            # Per-image patch counts, used to split the embeddings later.
            patches_per_image=[x.size(0) for x in image_patches_flat],
        )

    def _process_image_input(
            self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
        """Embed all patches at once, then split them back per image."""
        image_patches_flat = image_input["flat_data"]
        patches_per_image = image_input["patches_per_image"]

        assert self.vision_embed_tokens is not None
        vision_embeddings_flat, _ = self.vision_embed_tokens(
            image_patches_flat)

        return vision_embeddings_flat.split(patches_per_image, dim=0)

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Return per-image patch embeddings, or None without images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed *input_ids*, placing image embeddings at image tokens."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                _IMAGE_TOKEN_ID,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        """Run the language model and return its hidden states."""
        if intermediate_tensors is not None:
            # With intermediate tensors supplied, no input embeddings are
            # passed to the language model.
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this condition is for v0 compatibility.
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits via the language model's processor and LM head."""
        logits = self.language_model.logits_processor(
            self.language_model.lm_head, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens using the language model's sampler."""
        next_tokens = self.language_model.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load *weights* with :class:`AutoWeightsLoader`; return the names
        reported as loaded."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)