# image.py
1
from functools import lru_cache
2
from typing import List, Optional, Tuple, TypeVar
3
4
5

import torch
from PIL import Image
6
from transformers import PreTrainedTokenizerBase
7

8
9
from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
10
from vllm.logger import init_logger
11
from vllm.transformers_utils.image_processor import get_image_processor
12
from vllm.transformers_utils.tokenizer import get_tokenizer
13

14
from .base import MultiModalInputs, MultiModalPlugin
15
16
17

# Module-level logger for this module, obtained from vLLM's `init_logger`.
logger = init_logger(__name__)

18
# Memoized `get_image_processor`: `functools.lru_cache` (default cache size)
# keys on the call arguments, so repeated calls with the same arguments reuse
# the earlier result. All arguments must therefore be hashable.
cached_get_image_processor = lru_cache(get_image_processor)
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Memoized `get_tokenizer`: `functools.lru_cache` (default cache size) keys on
# the call arguments, so repeated calls with the same arguments reuse the
# earlier result. All arguments must therefore be hashable.
cached_get_tokenizer = lru_cache(get_tokenizer)

# Utilities for image input processors
# A token is either its string form or its integer ID.
_T = TypeVar("_T", str, int)


def repeat_and_pad_token(
    token: _T,
    *,
    repeat_count: int = 1,
    pad_token_left: Optional[_T] = None,
    pad_token_right: Optional[_T] = None,
) -> List[_T]:
    """Return ``token`` repeated ``repeat_count`` times with optional padding.

    Args:
        token: The token (string or ID) to repeat.
        repeat_count: How many copies of ``token`` to emit.
        pad_token_left: If not ``None``, placed once before the copies.
        pad_token_right: If not ``None``, placed once after the copies.

    Returns:
        A new list: ``[pad_token_left?, token * repeat_count, pad_token_right?]``.
    """
    repeated = [token] * repeat_count
    # Compare against None explicitly: 0 and "" are legitimate pad tokens.
    prefix = [] if pad_token_left is None else [pad_token_left]
    suffix = [] if pad_token_right is None else [pad_token_right]
    return [*prefix, *repeated, *suffix]


def repeat_and_pad_image_tokens(
    tokenizer: PreTrainedTokenizerBase,
    prompt: Optional[str],
    prompt_token_ids: List[int],
    *,
    image_token_id: int,
    repeat_count: int = 1,
    pad_token_left: Optional[int] = None,
    pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int]]:
    """Expand the first image token in both the prompt text and its token IDs.

    The first occurrence of the image token is replaced by ``repeat_count``
    copies of itself, optionally surrounded by the left/right pad tokens.
    Any later occurrences are left untouched.

    Args:
        tokenizer: Used to decode the image and pad token IDs into the strings
            that are substituted into ``prompt``.
        prompt: Prompt text, or ``None`` when only token IDs are available.
        prompt_token_ids: Token IDs of the prompt.
        image_token_id: ID of the image placeholder token.
        repeat_count: Number of image tokens to emit for the placeholder.
        pad_token_left: Optional token ID inserted before the repeated tokens.
        pad_token_right: Optional token ID inserted after the repeated tokens.

    Returns:
        ``(new_prompt, new_token_ids)``. ``new_prompt`` is ``None`` when
        ``prompt`` is ``None``; ``new_token_ids`` is always a new list.
    """
    new_prompt: Optional[str] = None
    if prompt is not None:
        image_token_str = tokenizer.decode(image_token_id)

        left_pad_str = None
        if pad_token_left is not None:
            left_pad_str = tokenizer.decode(pad_token_left)

        right_pad_str = None
        if pad_token_right is not None:
            right_pad_str = tokenizer.decode(pad_token_right)

        expanded_strs = repeat_and_pad_token(
            image_token_str,
            repeat_count=repeat_count,
            pad_token_left=left_pad_str,
            pad_token_right=right_pad_str,
        )
        replacement_str = "".join(expanded_strs)

        occurrences = prompt.count(image_token_str)
        # 16 is an arbitrary cut-off between "the image token was repeated in
        # the prompt" and "a few separate image tokens were supplied".
        if occurrences > 16:
            logger.warning(
                "Please follow the prompt format that is "
                "documented on HuggingFace which does not involve "
                "repeating %s tokens.", image_token_str)
        elif occurrences > 1:
            logger.warning("Multiple image input is not supported yet, "
                           "so any extra image tokens will be treated "
                           "as plain text.")

        # Only the first image token is expanded; the rest stay as-is.
        new_prompt = prompt.replace(image_token_str, replacement_str, 1)

    # Work on a copy so the caller's list is never mutated.
    new_token_ids: List[int] = list(prompt_token_ids)
    try:
        first_image_pos = new_token_ids.index(image_token_id)
    except ValueError:
        # No image token present: the IDs are returned unchanged.
        pass
    else:
        # Splice the expanded tokens over the single placeholder.
        new_token_ids[first_image_pos:first_image_pos + 1] = (
            repeat_and_pad_token(
                image_token_id,
                repeat_count=repeat_count,
                pad_token_left=pad_token_left,
                pad_token_right=pad_token_right,
            ))

    return new_prompt, new_token_ids
100
101


102
class ImagePlugin(MultiModalPlugin):
    """Plugin for image data."""

    def get_data_key(self) -> str:
        """Return the data key handled by this plugin (``"image"``)."""
        return "image"

    def _get_hf_image_processor(self, model_config: ModelConfig):
        """Return the image processor for ``model_config``.

        The lookup goes through the module-level memoized
        ``cached_get_image_processor``, using the model name and the
        ``trust_remote_code`` flag from the config.
        """
        return cached_get_image_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code)

    def _default_input_mapper(self, ctx: InputContext,
                              data: object) -> MultiModalInputs:
        """Convert raw image data into ``MultiModalInputs``.

        Args:
            ctx: Input context; only ``ctx.model_config`` is read here.
            data: A ``PIL.Image.Image`` or a ``list``, which is passed to the
                image processor's ``preprocess``. A ``torch.Tensor`` is
                rejected.

        Returns:
            ``MultiModalInputs`` built from the ``.data`` of the processor's
            ``preprocess`` output (called with ``return_tensors="pt"``).

        Raises:
            RuntimeError: If no image processor is available for the model.
            NotImplementedError: If ``data`` is a ``torch.Tensor``.
            TypeError: If ``data`` is of any other type.
            Exception: Any error raised by the image processor is logged and
                re-raised unchanged.
        """
        model_config = ctx.model_config
        if isinstance(data, (Image.Image, list)):
            image_processor = self._get_hf_image_processor(model_config)
            if image_processor is None:
                raise RuntimeError("No HuggingFace processor is available "
                                   "to process the image object")
            try:
                batch_data = image_processor \
                    .preprocess(data, return_tensors="pt") \
                    .data
            except Exception:
                # Log which input failed, then let the original error
                # propagate to the caller.
                logger.error("Failed to process image (%s)", data)
                raise

            return MultiModalInputs(batch_data)
        elif isinstance(data, torch.Tensor):
            raise NotImplementedError("Embeddings input is not supported yet")

        raise TypeError(f"Invalid image type: {type(data)}")

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """Return the default maximum number of multimodal tokens.

        NOTE(review): the value 3000 is hard-coded and ``ctx`` is unused; the
        rationale for this number is not visible in this file.
        """
        return 3000