fuyu.py 12.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# coding=utf-8
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
19
from array import array
20
from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict
21
22
23
24
25
26
27
28
29

import torch
import torch.nn as nn
import torch.utils.checkpoint
from PIL import Image
from transformers import FuyuConfig, FuyuImageProcessor

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
30
31
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
                         token_inputs)
32
from vllm.model_executor.layers.linear import ColumnParallelLinear
33
from vllm.model_executor.layers.quantization import QuantizationConfig
34
from vllm.model_executor.layers.sampler import SamplerOutput
35
36
37
38
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
39
40
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer
41
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
42
                           SequenceData)
43

44
from .interfaces import SupportsMultiModal, SupportsPP
45
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95

# The image placeholder token id and the image-row newline token id are not
# available from the HF config, so they are hard-coded here.
_IMAGE_TOKEN_ID, _NEWLINE_TOKEN_ID = 71011, 71019

# Pixel size (height, width) used when sizing the largest image prompt and
# the dummy profiling images.
MAX_IMAGE_FEATURE_SIZE_HEIGHT, MAX_IMAGE_FEATURE_SIZE_WIDTH = 1080, 1920


class FuyuImagePixelInputs(TypedDict):
    """Image patch input for Fuyu, tagged with ``type="pixel_values"``."""

    type: Literal["pixel_values"]
    # Shape:
    # (batch_size, num_patches, patch_size_x * patch_size_y * num_channels)
    # NOTE(review): the model's validation only checks the last dimension.
    data: torch.Tensor


def _calculate_num_image_tokens(
    height: int,
    width: int,
    patch_size: int = 30,
) -> Tuple[int, int]:
    """
    Calculate the number of image tokens needed for a given image size.

    The expected Fuyu image prompt has the format:
        (image_token * ncols + newline_token) * nrows

    Args:
        height: Image height in pixels.
        width: Image width in pixels.
        patch_size: Side length in pixels of one square patch. Each patch,
            including a partial one at the right/bottom edge, becomes one
            image token. Defaults to 30, the value previously hard-coded
            here (presumably the model's ``config.patch_size``).

    Returns:
        ``(ncols, nrows)``: the number of image tokens along x, and the
        number of rows of image tokens along y.

    Raises:
        ValueError: if ``patch_size`` is not positive.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    # ceil: a partially covered patch still needs its own token.
    ncol = math.ceil(width / patch_size)
    nrow = math.ceil(height / patch_size)
    return ncol, nrow


def get_max_fuyu_image_feature_size():
    """Return ``(ncols, nrows)`` of image tokens for an image of size
    MAX_IMAGE_FEATURE_SIZE_HEIGHT x MAX_IMAGE_FEATURE_SIZE_WIDTH."""
    max_height, max_width = (MAX_IMAGE_FEATURE_SIZE_HEIGHT,
                             MAX_IMAGE_FEATURE_SIZE_WIDTH)
    return _calculate_num_image_tokens(height=max_height, width=max_width)


def get_max_fuyu_image_tokens(ctx: InputContext):
    """Return the number of prompt tokens taken by the largest image.

    There are ``nrows`` rows, each made of ``ncols`` image tokens followed by
    one newline token. ``ctx`` is not used; it is presumably required by the
    ``register_max_image_tokens`` callback signature.
    """
    ncols, nrows = get_max_fuyu_image_feature_size()
    tokens_per_row = ncols + 1  # +1 for the newline token closing each row
    return nrows * tokens_per_row


96
def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
    """Build dummy prompt token data holding ``num_images`` max-size images.

    Each image contributes ``nrows`` rows of ``ncols`` image tokens, every row
    terminated by a newline token. The rest of ``seq_len`` is padded with
    token id 0. If the images alone need more than ``seq_len`` tokens, no
    padding is added and the result is longer than ``seq_len``.
    """
    ncol, nrow = get_max_fuyu_image_feature_size()
    tokens_per_image = get_max_fuyu_image_tokens(ctx)

    row = [_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]
    single_image = array(VLLM_TOKEN_ID_ARRAY_TYPE, row) * nrow
    token_ids = single_image * num_images

    pad_len = seq_len - tokens_per_image * num_images
    # Repeating by a negative count yields an empty array, i.e. no padding.
    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * pad_len

    return SequenceData(token_ids)


def dummy_image_for_fuyu(
    num_images: int,
    *,
    image_width: int,
    image_height: int,
):
    """Create multimodal dummy data of black RGB images.

    When ``num_images`` is 1 the image is returned bare; otherwise a list
    repeating the same image object ``num_images`` times is returned.
    """
    black_image = Image.new("RGB", (image_width, image_height), color=0)
    if num_images == 1:
        return {"image": black_image}
    return {"image": [black_image] * num_images}
117
118


119
120
121
122
123
124
125
def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
                        mm_counts: Mapping[str, int]):
    """Return ``(seq_data, mm_data)`` dummy inputs for ``mm_counts["image"]``
    images of size MAX_IMAGE_FEATURE_SIZE_WIDTH x MAX_IMAGE_FEATURE_SIZE_HEIGHT.
    """
    image_count = mm_counts["image"]
    dummy_tokens = dummy_seq_data_for_fuyu(ctx, seq_len, image_count)
    dummy_images = dummy_image_for_fuyu(
        image_count,
        image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )
    return dummy_tokens, dummy_images


def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
                           data: Image.Image):
    """Run the HF Fuyu image processor on ``data``.

    Returns the output of ``preprocess_with_tokenizer_info``. Callers read
    its ``"image_patches"`` and ``"image_input_ids"`` entries.
    """
    encoded = image_processor.preprocess(data, return_tensors="pt")
    images = encoded["images"]

    # First image of each batch entry, plus a subsequence axis of size 1.
    batch_images = torch.stack([entry[0] for entry in images]).unsqueeze(1)
    unpadded_h = torch.tensor(encoded["image_unpadded_heights"])
    unpadded_w = torch.tensor(encoded["image_unpadded_widths"])
    # Mark every batch entry as carrying an image.
    present = torch.ones(len(images), 1, 1)

    return image_processor.preprocess_with_tokenizer_info(
        image_input=batch_images,
        image_present=present,
        image_unpadded_h=unpadded_h,
        image_unpadded_w=unpadded_w,
        image_placeholder_id=_IMAGE_TOKEN_ID,
        image_newline_id=_NEWLINE_TOKEN_ID,
        variable_sized=True,
    )


153
154
def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Insert Fuyu image tokens into a prompt that carries one image.

    Inputs without ``"image"`` multimodal data are returned unchanged.
    Otherwise, the image is converted to patch tensors and the returned
    token ids are:
        image token ids (rows of image tokens + newline) + BOS
        + prompt token ids without their first token + "\\x04" (BOA).
    The text prompt (when present) gets "\\x04" appended. A prompt given
    only as token ids has no text and stays ``None``.

    Raises:
        NotImplementedError: if the image is given as a tensor (embeddings).
        TypeError: if the image is neither a PIL image nor a tensor.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    model_config = ctx.model_config
    image_data = multi_modal_data["image"]
    new_multi_modal_data = {}
    # process image data
    if isinstance(image_data, Image.Image):
        # Fuyu's image_processor can also finish token padding
        image_processor: FuyuImageProcessor = cached_get_image_processor(
            model_config.model)

        model_image_input = _fuyu_image_preprocess(image_processor, image_data)
        image_patches = torch.cat([
            image_patch[0]
            for image_patch in model_image_input["image_patches"]
        ])
        new_multi_modal_data["image"] = image_patches

    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    # process prompts
    prompt = inputs.get("prompt")
    prompt_token_ids = inputs["prompt_token_ids"]

    tokenizer = cached_get_tokenizer(model_config.model)
    # dim0 is batch_size, dim1 is subseq_size which will always be 1
    image_input_id_batches: List[List[
        torch.Tensor]] = model_image_input["image_input_ids"]
    image_input_ids: List[int] = image_input_id_batches[0][0].tolist()
    # [1:] drops the first encoded id; presumably a prefix/space token the
    # tokenizer emits before "<s>" / "\x04" -- TODO confirm.
    bos_token = tokenizer.encode("<s>", add_special_tokens=False)[1:]
    boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:]

    # Requests built from token ids alone have no text prompt; keep it None
    # instead of failing on ``None + str``.
    new_prompt = None if prompt is None else prompt + "\x04"
    # prompt_token_ids[1:] skips the first prompt token (presumably the
    # tokenizer-added BOS) because bos_token is inserted explicitly here.
    new_prompt_token_ids = (image_input_ids + bos_token +
                            prompt_token_ids[1:] + boa_token)

    return token_inputs(prompt=new_prompt,
                        prompt_token_ids=new_prompt_token_ids,
                        multi_modal_data=new_multi_modal_data)
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212


def input_mapper_for_fuyu(ctx: InputContext, data: object):
    """Wrap image data as ``pixel_values`` model inputs.

    A PIL image is first converted into a stacked tensor of patches; any
    other value is wrapped as-is, because the input processor has already
    turned the image into patches.
    """
    model_config = ctx.model_config
    if not isinstance(data, Image.Image):
        # image has been processed with prompt in input processor
        return MultiModalInputs({"pixel_values": data})

    # Fuyu's image_processor can also finish token padding
    processor: FuyuImageProcessor = cached_get_image_processor(
        model_config.model)
    processed = _fuyu_image_preprocess(processor, data)
    patches = torch.stack(
        [per_image[0] for per_image in processed["image_patches"]])
    return MultiModalInputs({"pixel_values": patches})
214
215
216
217
218
219


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Fuyu multimodal causal LM.

    Flattened image patches are projected by ``vision_embed_tokens`` to the
    hidden size and merged into the token embeddings at positions holding
    ``image_token_id``. Everything else is delegated to a
    ``PersimmonForCausalLM`` language model.
    """

    def __init__(self,
                 config: FuyuConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.text_config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        # Length of one flattened patch: patch_size * patch_size pixels times
        # the number of channels.
        self.image_feature_size = config.patch_size**2 * config.num_channels

        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            config.hidden_size,
            quant_config=quant_config,
            gather_output=True,
        )
        self.language_model = PersimmonForCausalLM(config.text_config,
                                                   cache_config=cache_config,
                                                   quant_config=quant_config)
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @property
    def sampler(self):
        """The language model's sampler."""
        return self.language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check the patch dimension of ``data`` and cast it to the dtype of
        ``vision_embed_tokens``.

        Raises:
            ValueError: if the last dimension is not
                ``num_channels * patch_size * patch_size``.
        """
        h = w = self.config.patch_size
        num_channels = self.config.num_channels
        expected_dims = num_channels * h * w

        # Every slice along dim 0 shares the tensor's last dimension, so one
        # check replaces a Python loop over all patches.
        actual_dims = data.size(-1)
        if actual_dims != expected_dims:
            raise ValueError(
                "The expected shape of pixel values per image per batch "
                f"per patch is {expected_dims}. "
                f"You supplied {tuple(data.shape)}.")

        return data.to(self.vision_embed_tokens.weight.dtype)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
        """Pop ``pixel_values`` from ``kwargs`` and validate it.

        Returns ``None`` when no ``pixel_values`` were given.

        Raises:
            ValueError: if ``pixel_values`` is neither a tensor nor a list,
                or has the wrong patch dimension.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image patches. "
                             f"Got type: {type(pixel_values)}")

        return FuyuImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(
                flatten_bn(pixel_values, concat=True)),
        )

    def _process_image_input(
            self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
        """Project image patches to the language model hidden size."""
        # The layer returns (output, bias); only the output is used.
        vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ):
        """Run the language model, with image embeddings merged in if given.

        When ``intermediate_tensors`` is provided (presumably a non-first
        pipeline-parallel stage), ``input_ids`` and any image inputs are
        ignored and ``None`` is passed for both ids and embeddings.
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            image_input = self._parse_and_validate_image_input(**kwargs)
            if image_input is not None:
                vision_embeddings = self._process_image_input(image_input)
                inputs_embeds = self.language_model.model.embed_tokens(
                    input_ids)
                # Overwrite embeddings at image placeholder positions.
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, vision_embeddings,
                    self.image_token_id)
            else:
                inputs_embeds = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits with the language model's head and processor."""
        logits = self.language_model.logits_processor(
            self.language_model.lm_head, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens with the language model's sampler."""
        next_tokens = self.language_model.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into this module via AutoWeightsLoader."""
        loader = AutoWeightsLoader(self)
        loader.load_weights(weights)