# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
20
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
21
                    TypedDict)
22
23
24

import torch
import torch.nn as nn
25
26
from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
                          FuyuProcessor)
27
28

from vllm.attention import AttentionMetadata
29
from vllm.config import VllmConfig
30
from vllm.model_executor.layers.linear import ColumnParallelLinear
31
from vllm.model_executor.layers.sampler import SamplerOutput
32
33
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
34
from vllm.multimodal import MULTIMODAL_REGISTRY
35
36
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
37
38
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataItems)
39
from vllm.multimodal.processing import (BaseMultiModalProcessor,
40
41
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptReplacementDetails)
42
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
43
from vllm.sequence import IntermediateTensors
44

45
from .interfaces import SupportsMultiModal, SupportsPP
46
47
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                    merge_multimodal_embeddings)
48
49
50
51
52
53

# These two token ids are hard-coded because they cannot be found in the
# HF config.
# _IMAGE_TOKEN_ID is the placeholder id repeated for every image patch column
# in the prompt and later matched when vision embeddings are merged into the
# text embeddings; _NEWLINE_TOKEN_ID is appended after each row of image
# placeholders.
_IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019


54
55
class FuyuImagePatchInputs(TypedDict):
    """Flattened image patches consumed by the Fuyu vision embedding layer.

    Keys:
        type: Always the literal ``"image_patches"``.
        flat_data: Patch tensor with shape
            `(batch_size * num_patches,
            patch_size_x * patch_size_y * num_channels)`.
        patches_per_image: List of number of total patches for each image
            in the batch. This is used to restore the first two dimensions
            of `flat_data`.
    """

    type: Literal["image_patches"]
    flat_data: torch.Tensor
    patches_per_image: List[int]


69
class FuyuProcessingInfo(BaseProcessingInfo):
    """Accessors for the Fuyu HF config/processor and image token sizing."""

    def get_hf_config(self):
        """Return the HF config, requested from the context as FuyuConfig."""
        return self.ctx.get_hf_config(FuyuConfig)

    def get_hf_processor(self):
        """Return the HF processor, requested as FuyuProcessor."""
        return self.ctx.get_hf_processor(FuyuProcessor)

    def get_image_processor(self) -> FuyuImageProcessor:
        """Return the image processor held by the HF processor."""
        return self.get_hf_processor().image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality item limit: a single image."""
        return {"image": 1}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        """Return the largest number of prompt tokens one image can occupy.

        The count is taken for the image size with the most features. Each
        grid row contributes ``ncols + 1`` tokens (the extra one matches the
        newline token emitted after every row of image placeholders).

        Args:
            seq_len: Unused by this implementation.
        """
        target_width, target_height = self.get_image_size_with_most_features()

        max_ncols, max_nrows = self.get_image_feature_grid_size(
            image_width=target_width,
            image_height=target_height,
        )
        max_image_tokens = (max_ncols + 1) * max_nrows

        return {"image": max_image_tokens}

    def get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        """Return ``(ncols, nrows)`` of the patch grid for an image.

        An image that exceeds the processor's target size in either
        dimension is first scaled down (same factor for both axes, results
        truncated to int) so that it fits. The grid is then the ceiling of
        each dimension divided by the corresponding patch dimension.

        Args:
            image_width: Width of the input image in pixels.
            image_height: Height of the input image in pixels.
        """
        image_processor = self.get_image_processor()
        target_width = image_processor.size["width"]
        target_height = image_processor.size["height"]
        # Read the patch dimensions from the image processor instead of
        # assuming a fixed 30x30 patch, so a processor configured with a
        # different patch size yields the right grid.
        patch_width = image_processor.patch_size["width"]
        patch_height = image_processor.patch_size["height"]

        if not (image_width <= target_width and image_height <= target_height):
            height_scale_factor = target_height / image_height
            width_scale_factor = target_width / image_width
            optimal_scale_factor = min(height_scale_factor, width_scale_factor)

            image_height = int(image_height * optimal_scale_factor)
            image_width = int(image_width * optimal_scale_factor)

        ncols = math.ceil(image_width / patch_width)
        nrows = math.ceil(image_height / patch_height)
        return ncols, nrows

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image processor's configured target size."""
        image_processor = self.get_image_processor()
        return ImageSize(width=image_processor.size["width"],
                         height=image_processor.size["height"])

121
122
123

class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
    """Builds dummy processor inputs for Fuyu."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return an empty prompt paired with dummy images.

        The dummy images use the size reported by
        ``get_image_size_with_most_features``; their number is
        ``mm_counts["image"]`` (0 when the key is absent).

        Args:
            seq_len: Unused by this implementation.
            mm_counts: Number of items requested per modality.
        """
        max_width, max_height = (
            self.info.get_image_size_with_most_features())
        image_count = mm_counts.get("image", 0)

        dummy_images = self._get_dummy_images(
            width=max_width,
            height=max_height,
            num_images=image_count,
        )

        return ProcessorInputs(prompt_text="", mm_data={"image": dummy_images})


146
class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
    """Multimodal processor for Fuyu prompts and image patches."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and drop the leading dim of `image_patches`.

        Text-only input is tokenized directly with the tokenizer (plus the
        token appended by `_apply_hf_processor_tokens_only`) instead of
        going through the HF processor.
        """
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            tokenizer = self.info.get_tokenizer()
            token_ids = self._apply_hf_processor_tokens_only(
                tokenizer.encode(prompt))
            return BatchFeature({"input_ids": [token_ids]}, tensor_type="pt")

        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        patches = outputs.get("image_patches")
        if patches is None:
            return outputs

        images = mm_data["images"]
        assert isinstance(images, list)

        # Original output: (1, num_images, Pn, Px * Py * C)
        # New output: (num_images, Pn, Px * Py * C)
        assert isinstance(patches, list) and len(patches) == 1
        first = patches[0]
        assert isinstance(first, torch.Tensor) and len(first) == len(images)

        outputs["image_patches"] = first
        return outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Return `prompt_tokens` followed by the id of the `<0x04>` token."""
        # HF processor adds boa_token_id
        boa_token_id = self.info.get_tokenizer().get_vocab()["<0x04>"]
        return prompt_tokens + [boa_token_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `image_patches` as a field batched over the image items."""
        return {"image_patches": MultiModalFieldConfig.batched("image")}

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Return the single replacement rule that expands an image.

        The token matched in the prompt is the tokenizer's `bos_token_id`.
        It is replaced by `nrows` rows of `ncols` image tokens, each row
        ending with a newline token, followed by the HF config's
        `bos_token_id`. Only the image/newline tokens are reported as
        `features`.
        """
        bos_token_id = self.info.get_hf_config().bos_token_id
        assert isinstance(bos_token_id, int)

        # The tokenizer's BOS id is the token that gets replaced.
        target_token_id = self.info.get_tokenizer().bos_token_id
        assert isinstance(target_token_id, int)

        def get_replacement_fuyu(item_idx: int):
            image_items = mm_items.get_items("image", ImageProcessorItems)
            size = image_items.get_image_size(item_idx)

            ncols, nrows = self.info.get_image_feature_grid_size(
                image_width=size.width,
                image_height=size.height,
            )

            # One row: `ncols` image placeholders closed by a newline token.
            row_tokens = [_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]
            image_tokens = row_tokens * nrows

            return PromptReplacementDetails(
                full=image_tokens + [bos_token_id],
                features=image_tokens,
            )

        replacement = PromptReplacement(
            modality="image",
            target=[target_token_id],
            replacement=get_replacement_fuyu,
        )
        return [replacement]


240
241
242
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
                                        info=FuyuProcessingInfo,
                                        dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Fuyu: a linear image-patch embedding in front of a Persimmon LM."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the patch projection and the wrapped language model.

        Args:
            vllm_config: Source of the HF config, quantization config and
                multimodal config.
            prefix: Name prefix forwarded to the language model.
        """
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config

        self.config = hf_config
        self.multimodal_config = model_config.multimodal_config

        self.padding_idx = hf_config.pad_token_id
        self.vocab_size = hf_config.text_config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        # Length of one flattened patch: patch_size^2 pixels per channel.
        self.image_feature_size = (hf_config.patch_size**2 *
                                   hf_config.num_channels)

        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            hf_config.hidden_size,
            quant_config=vllm_config.quant_config,
            gather_output=True,
        )
        self.language_model = PersimmonForCausalLM(
            vllm_config=vllm_config.with_hf_config(hf_config.text_config),
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @property
    def sampler(self):
        """Sampler of the wrapped language model."""
        return self.language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check the last dimension of every entry and cast the dtype.

        Returns:
            `data` converted to the dtype of the patch projection weight.

        Raises:
            ValueError: If an entry's last dimension differs from
                ``num_channels * patch_size * patch_size``.
        """
        patch_size = self.config.patch_size
        expected_dims = self.config.num_channels * patch_size * patch_size

        for entry in data:
            if entry.size(-1) != expected_dims:
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f" per patch is {expected_dims}. "
                    f"You supplied {tuple(entry.shape)}.")

        return data.to(self.vision_embed_tokens.weight.dtype)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
        """Extract `image_patches` from kwargs and flatten it.

        Returns:
            None when no `image_patches` were passed, otherwise the
            validated, concatenated patches plus the patch count per image.

        Raises:
            ValueError: If `image_patches` is neither a tensor nor a list.
        """
        image_patches = kwargs.pop("image_patches", None)
        if image_patches is None:
            return None

        if not isinstance(image_patches, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image patches. "
                             f"Got type: {type(image_patches)}")

        per_image_patches = flatten_bn(image_patches)
        flat_data = self._validate_pixel_values(
            flatten_bn(per_image_patches, concat=True))
        patch_counts = [patches.size(0) for patches in per_image_patches]

        return FuyuImagePatchInputs(
            type="image_patches",
            flat_data=flat_data,
            patches_per_image=patch_counts,
        )

    def _process_image_input(
            self, image_input: FuyuImagePatchInputs) -> NestedTensors:
        """Project the flat patches and split the result per image."""
        flat_patches = image_input["flat_data"]
        patch_counts = image_input["patches_per_image"]

        assert self.vision_embed_tokens is not None
        # The layer returns a pair; only the first element is used.
        flat_embeddings, _ = self.vision_embed_tokens(flat_patches)
        return flat_embeddings.split(patch_counts, dim=0)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return image embeddings for the kwargs, or None without images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed `input_ids`, merging in image embeddings when provided.

        Image embeddings are merged at the `_IMAGE_TOKEN_ID` placeholders.
        """
        text_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is None:
            return text_embeds
        return merge_multimodal_embeddings(input_ids, text_embeds,
                                           multimodal_embeddings,
                                           _IMAGE_TOKEN_ID)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        """Run the language model and return its output.

        When `intermediate_tensors` is given, `inputs_embeds` is discarded.
        Otherwise, if `inputs_embeds` is missing, it is computed here from
        `input_ids` and any image kwargs, and `input_ids` is cleared.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            image_embeds = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids, image_embeds)
            input_ids = None

        return self.language_model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Apply the language model's logits processor with its LM head."""
        language_model = self.language_model
        return language_model.logits_processor(language_model.lm_head,
                                               hidden_states,
                                               sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens with the language model's sampler."""
        return self.language_model.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load `weights` through AutoWeightsLoader and return its result."""
        return AutoWeightsLoader(self).load_weights(weights)