"vscode:/vscode.git/clone" did not exist on "73a484caa1ad320d6e695f098c25c479a71e6774"
fuyu.py 13.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
18
from array import array
19
from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict
20
21
22
23
24

import torch
import torch.nn as nn
import torch.utils.checkpoint
from PIL import Image
25
from transformers import FuyuImageProcessor
26
27

from vllm.attention import AttentionMetadata
28
from vllm.config import VllmConfig
29
30
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext, token_inputs)
31
from vllm.model_executor.layers.linear import ColumnParallelLinear
32
from vllm.model_executor.layers.sampler import SamplerOutput
33
34
35
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
36
from vllm.multimodal.base import MultiModalKwargs
37
from vllm.multimodal.image import cached_get_image_processor
38
39
from vllm.multimodal.utils import (cached_get_tokenizer,
                                   consecutive_placeholder_ranges)
40
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
41
                           SequenceData)
42
from vllm.utils import is_list_of
43

44
from .interfaces import SupportsMultiModal, SupportsPP
45
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95

# The image-placeholder and image-newline token IDs are not exposed by the
# HF config, so they are hard-coded here.
_IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019

MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920


class FuyuImagePixelInputs(TypedDict):
    """Typed dict carrying Fuyu image input as flattened pixel patches."""
    # Discriminator tag; the only allowed value is "pixel_values".
    type: Literal["pixel_values"]
    # Tensor holding the image patches; see the shape note below.
    data: torch.Tensor
    """
    Shape: 
    (batch_size, num_patches, patch_size_x * patch_size_y * num_channels)
    """


def _calculate_num_image_tokens(
    height: int,
    width: int,
) -> Tuple[int, int]:
    """
    calculate number of image tokens needed for a given image size
    The expected Fuyu image prompts is in format:
        (image_token * ncols + newline_token) * nrows
    args:
        image_size: Tuple[int, int] - (width, height) of the image
    returns:
        ncols: int - number of image tokens in x direction
        nrows: int - number of image tokens in y direction
    """
    ncol = math.ceil(width / 30)
    nrow = math.ceil(height / 30)
    return ncol, nrow


def get_max_fuyu_image_feature_size():
    """Return the image-token grid for the largest supported image.

    The result is whatever ``_calculate_num_image_tokens`` yields for an
    image of MAX_IMAGE_FEATURE_SIZE_HEIGHT x MAX_IMAGE_FEATURE_SIZE_WIDTH.
    """
    max_token_grid = _calculate_num_image_tokens(
        height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )
    return max_token_grid


def get_max_fuyu_image_tokens(ctx: InputContext):
    """Return the token count of one maximum-size image.

    ``ctx`` is accepted to match the registry callback signature but is
    not read here.
    """
    grid_cols, grid_rows = get_max_fuyu_image_feature_size()
    # Each row contributes its image tokens plus one extra token
    # (the prompt format appends a newline token per row).
    tokens_per_row = grid_cols + 1
    return tokens_per_row * grid_rows


96
def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
    """Build dummy sequence data holding ``num_images`` max-size images.

    Returns a ``(SequenceData, placeholder_ranges)`` pair, where the ranges
    dict maps "image" to consecutive placeholder ranges, one per image.
    """
    grid_cols, grid_rows = get_max_fuyu_image_feature_size()
    tokens_per_image = get_max_fuyu_image_tokens(ctx)

    # One grid row: `grid_cols` image tokens terminated by a newline token.
    row_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [_IMAGE_TOKEN_ID]) * grid_cols
    row_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [_NEWLINE_TOKEN_ID])

    # Repeat the row for a full image, then the image for every dummy image.
    single_image_ids = row_ids * grid_rows
    token_ids = single_image_ids * num_images

    # Fill whatever is left of the sequence with token id 0.
    padding_len = seq_len - tokens_per_image * num_images
    token_ids.extend(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * padding_len)

    placeholder_ranges = {
        "image":
        consecutive_placeholder_ranges(num_items=num_images,
                                       item_size=tokens_per_image)
    }
    return SequenceData(token_ids), placeholder_ranges
111
112
113


def dummy_image_for_fuyu(
    num_images: int,
    *,
    image_width: int,
    image_height: int,
):
    """Create dummy multi-modal data made of blank RGB images.

    Returns ``{"image": <image>}`` when exactly one image is requested,
    otherwise ``{"image": [<image>, ...]}`` with ``num_images`` entries
    that all reference the same image object.
    """
    blank = Image.new("RGB", (image_width, image_height), color=0)
    if num_images == 1:
        return {"image": blank}
    return {"image": [blank] * num_images}
121
122


123
124
125
def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
                        mm_counts: Mapping[str, int]):
    """Assemble the ``DummyData`` (tokens, images, ranges) for Fuyu.

    The number of dummy images is read from ``mm_counts["image"]``; every
    dummy image uses the maximum supported width and height.
    """
    image_count = mm_counts["image"]
    seq_data, placeholder_ranges = dummy_seq_data_for_fuyu(
        ctx, seq_len, image_count)
    dummy_images = dummy_image_for_fuyu(
        image_count,
        image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )
    return DummyData(seq_data, dummy_images, placeholder_ranges)
131
132
133


def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
                           data: List[Image.Image]):
    """Run the HF Fuyu image processor and its tokenizer-aware second stage.

    The images are first encoded with ``image_processor.preprocess`` and the
    result is then fed to ``preprocess_with_tokenizer_info`` together with
    the hard-coded image placeholder / newline token ids. The output of that
    second call is returned unchanged.
    """
    encoded = image_processor.preprocess(data, return_tensors="pt")
    per_image_entries = encoded["images"]

    # Keep only the first element of every per-image entry and add a
    # singleton dimension at index 1.
    first_elements = [entry[0] for entry in per_image_entries]
    stacked_images = torch.stack(first_elements).unsqueeze(1)

    unpadded_heights = torch.tensor(encoded["image_unpadded_heights"])
    unpadded_widths = torch.tensor(encoded["image_unpadded_widths"])

    # Every batch entry is flagged as present.
    num_entries = len(per_image_entries)
    present_flags = torch.ones(num_entries, 1, 1)

    return image_processor.preprocess_with_tokenizer_info(
        image_input=stacked_images,
        image_present=present_flags,
        image_unpadded_h=unpadded_heights,
        image_unpadded_w=unpadded_widths,
        image_placeholder_id=_IMAGE_TOKEN_ID,
        image_newline_id=_NEWLINE_TOKEN_ID,
        variable_sized=True,
    )


157
158
def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Expand a prompt with Fuyu image tokens and pre-computed patches.

    Inputs that carry no "image" multi-modal data are returned untouched.
    Otherwise the image(s) are run through the Fuyu image processor, the
    resulting image token ids are prepended to the prompt token ids, and the
    "\\x04" marker is appended to both the text prompt and the token ids.

    Args:
        ctx: input context; only ``ctx.model_config.model`` is read, to look
            up the cached image processor and tokenizer.
        inputs: decoder-only inputs; "prompt_token_ids" is required,
            "prompt" and "multi_modal_data" are optional.

    Returns:
        ``inputs`` itself when there is no image, otherwise new token inputs
        whose multi-modal "image" entry is the concatenated patch tensor.

    Raises:
        NotImplementedError: if the image data is given as tensors.
        TypeError: if the image data is of any other unsupported type.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    model_config = ctx.model_config
    image_data = multi_modal_data["image"]
    new_multi_modal_data = {}
    image_list = image_data if isinstance(image_data, list) else [image_data]

    # process image data
    if is_list_of(image_list, Image.Image):
        # Fuyu's image_processor can also finish token padding
        image_processor: FuyuImageProcessor = cached_get_image_processor(
            model_config.model)

        # Hand over the normalized list: it matches the helper's
        # List[Image.Image] annotation and what input_mapper_for_fuyu does.
        model_image_input = _fuyu_image_preprocess(image_processor, image_list)
        image_patches = torch.cat([
            image_patch[0]
            for image_patch in model_image_input["image_patches"]
        ])
        new_multi_modal_data["image"] = image_patches

    elif is_list_of(image_list, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    # process prompts
    prompt = inputs.get("prompt")
    prompt_token_ids = inputs["prompt_token_ids"]
    tokenizer = cached_get_tokenizer(model_config.model)
    # dim0 is batch_size, dim1 is subseq_size which will always be 1
    image_input_ids: List[List[
        torch.Tensor]] = model_image_input["image_input_ids"]
    # NOTE(review): only the first batch entry is used, so with several
    # images only the first image's token ids reach the prompt - confirm
    # whether multi-image prompts are meant to be supported here.
    image_input_ids = image_input_ids[0][0].tolist()
    # The first encoded token is dropped for both markers - presumably a
    # prefix token the tokenizer adds on its own; TODO confirm.
    bos_token = tokenizer.encode("<s>", add_special_tokens=False)[1:]
    boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:]

    # "prompt" is optional (token-only requests): keep it as None instead of
    # failing on `None + str`.
    new_prompt = None if prompt is None else prompt + "\x04"
    # The first prompt token is replaced by `bos_token`, which now follows
    # the image tokens.
    new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
        1:] + boa_token

    return token_inputs(prompt=new_prompt,
                        prompt_token_ids=new_prompt_token_ids,
                        multi_modal_data=new_multi_modal_data)
203
204
205
206


def input_mapper_for_fuyu(ctx: InputContext, data: object):
    """Map image data to the ``pixel_values`` keyword argument of the model.

    PIL images (a single one or a list) are converted to stacked image
    patches with the Fuyu image processor. Any other data is forwarded
    unchanged, since images are normally already processed together with the
    prompt in the input processor.
    """
    model_config = ctx.model_config
    candidates = [data] if not isinstance(data, list) else data

    if not is_list_of(candidates, Image.Image):
        # image has been processed with prompt in input processor
        return MultiModalKwargs({"pixel_values": data})

    # Fuyu's image_processor can also finish token padding
    image_processor: FuyuImageProcessor = cached_get_image_processor(
        model_config.model)
    processed = _fuyu_image_preprocess(image_processor, candidates)

    patches_per_image = [
        patch_group[0] for patch_group in processed["image_patches"]
    ]
    pixel_values = torch.stack(patches_per_image)
    return MultiModalKwargs({"pixel_values": pixel_values})
221
222
223
224
225
226


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Fuyu model: a linear patch projection in front of a Persimmon LM.

    Image patches are projected by ``vision_embed_tokens`` and merged into
    the text embeddings at the positions of the image token id before the
    wrapped ``PersimmonForCausalLM`` is run.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config

        self.config = hf_config
        self.multimodal_config = model_config.multimodal_config

        self.padding_idx = hf_config.pad_token_id
        self.vocab_size = hf_config.text_config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        # Length of one flattened patch: patch_size^2 pixels per channel.
        self.image_feature_size = (hf_config.patch_size**2 *
                                   hf_config.num_channels)

        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            hf_config.hidden_size,
            quant_config=vllm_config.quant_config,
            gather_output=True,
        )
        self.language_model = PersimmonForCausalLM(
            vllm_config.with_hf_config(hf_config.text_config))

        # Expose the language model's factory under this model's name.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @property
    def sampler(self):
        """The sampler owned by the wrapped language model."""
        return self.language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check the last dimension of every entry and cast to weight dtype.

        Raises:
            ValueError: if an entry's last dimension differs from
                ``num_channels * patch_size * patch_size``.
        """
        patch_side = self.config.patch_size
        expected_dims = self.config.num_channels * patch_side * patch_side

        for entry in data:
            if entry.size(-1) != expected_dims:
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f" per patch is {expected_dims}. "
                    f"You supplied {tuple(entry.shape)}.")

        return data.to(self.vision_embed_tokens.weight.dtype)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
        """Extract and validate ``pixel_values`` from the keyword arguments.

        Returns ``None`` when no ``pixel_values`` were passed.

        Raises:
            ValueError: if ``pixel_values`` is neither a tensor nor a list.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image patches. "
                             f"Got type: {type(pixel_values)}")

        flattened = flatten_bn(pixel_values, concat=True)
        return FuyuImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(flattened),
        )

    def _process_image_input(
            self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
        """Project the image patches with ``vision_embed_tokens``."""
        assert self.vision_embed_tokens is not None
        # The layer returns a pair; only its first element is needed here.
        patch_embeddings, _ = self.vision_embed_tokens(image_input["data"])
        return patch_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ):
        """Run the language model, merging image embeddings when present.

        When ``intermediate_tensors`` is given, ``input_ids`` is discarded
        and no embeddings are computed in this method. Otherwise image
        embeddings, if any, replace the text embeddings at the image token
        positions.
        """
        inputs_embeds = None

        if intermediate_tensors is not None:
            input_ids = None
        else:
            image_input = self._parse_and_validate_image_input(**kwargs)
            if image_input is not None:
                vision_embeddings = self._process_image_input(image_input)
                text_embeds = self.language_model.model.embed_tokens(
                    input_ids)
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, text_embeds, vision_embeddings,
                    self.image_token_id)

        return self.language_model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits with the language model's head and processor."""
        language_model = self.language_model
        return language_model.logits_processor(language_model.lm_head,
                                               hidden_states,
                                               sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens with the language model's sampler."""
        return self.language_model.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` weights through ``AutoWeightsLoader``."""
        AutoWeightsLoader(self).load_weights(weights)