processing.py 8.78 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
from dataclasses import dataclass
from functools import lru_cache, partial
from typing import (Any, Callable, Collection, Dict, Generic, List, Mapping,
                    Optional, TypedDict, TypeVar, final)

from transformers import BatchFeature
from typing_extensions import TypeAlias

from vllm.inputs import InputProcessingContext
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import is_list_of

from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
                     MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
                     VideoItem)

# Type variable for the data item type of a modality (e.g. ImageItem).
_T = TypeVar("_T")

# Invoked as ``replace_fn(orig_item, processed_inputs, item_idx)`` (bound via
# ``functools.partial`` in ``MultiModalProcessor.apply``).
ReplacementFunc: TypeAlias = Callable[[_T, BatchFeature, int], List[int]]
"""
Given the original data item, HF-processed data, and index of the processed
item, output the replacement token IDs to be allocated in vLLM.
"""


@dataclass
class ModalityProcessingMetadata(Generic[_T]):
    """
    Processing metadata for a single modality.

    The type parameter is the type of the modality's data items
    (e.g. ``ImageItem``); it is threaded through to the replacement
    functions so that they are typed against the correct item type.
    """

    placeholder_replacements: Mapping[str, ReplacementFunc[_T]]
    """
    A dictionary mapping each placeholder text that may appear in the prompt
    to the function that computes its replacement token IDs.
    """


@final
class MultiModalProcessingMetadataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    # Processing metadata for image inputs.
    image: ModalityProcessingMetadata[ImageItem]
    # Processing metadata for video inputs.
    video: ModalityProcessingMetadata[VideoItem]
    # Processing metadata for audio inputs.
    audio: ModalityProcessingMetadata[AudioItem]


# Keyed by modality name; ``MultiModalProcessor.apply`` looks up
# ``self.metadata[modality]`` for every modality present in the input data.
MultiModalProcessingMetadata: TypeAlias = \
    Mapping[str, ModalityProcessingMetadata[Any]]
"""
A dictionary containing an entry for each modality type to process.

Note:
    This dictionary also accepts modality keys defined outside
    :class:`MultiModalProcessingMetadataBuiltins` as long as a customized plugin
    is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""

# Generic alias: e.g. ``MultiModalMultiData[ImageItem]`` is ``List[ImageItem]``.
MultiModalMultiData: TypeAlias = List[_T]
"""
A list of data items, where the number of data items allowed
per modality is restricted by :code:`--limit-mm-per-prompt`.
"""


# Marked ``@final``: this TypedDict is not meant to be subclassed.
# Each field holds a list of items of the corresponding modality.
@final
class MultiModalMultiDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    image: MultiModalMultiData[ImageItem]
    """The input images."""

    video: MultiModalMultiData[VideoItem]
    """The input videos."""

    audio: MultiModalMultiData[AudioItem]
    """The input audios."""


# Return type of ``to_multi_format``: every value is a list of data items.
MultiModalMultiDataDict: TypeAlias = Mapping[str, MultiModalMultiData[Any]]
"""
A dictionary containing an entry for each modality type to input.

Note:
    This dictionary also accepts modality keys defined outside
    :class:`MultiModalMultiDataBuiltins` as long as a customized plugin
    is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""


def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict:
    """
    Convert a :class:`MultiModalDataDict` containing single data items
    to a :class:`MultiModalMultiDataDict` containing multiple data items
    per entry.

    Every value is normalized to a list of items:

    - ``"video"``: a single video may itself be a list, so the value is
      treated as multiple items only if it is a list of lists.
    - Any other modality: a list is taken as multiple items; any other
      value is wrapped into a single-item list.

    Args:
        data: Multi-modal data keyed by modality name.

    Returns:
        A new dictionary with the same keys, whose values are all lists.
    """
    multi_data: Dict[str, MultiModalMultiData[Any]] = {}

    for modality, value in data.items():
        if modality == "video":
            # Special case since even a single item can be a list
            is_multi = is_list_of(value, list)
        else:
            is_multi = isinstance(value, list)

        multi_data[modality] = value if is_multi else [value]

    return multi_data


def encode_no_special_tokens(
    tokenizer: AnyTokenizer,
    text: str,
) -> List[int]:
    """
    Tokenize ``text`` without adding any special tokens.

    This is a backend-agnostic stand-in for HF's
    :code:`tokenizer.encode(text, add_special_tokens=False)`. Mistral
    tokenizers are handled by calling their inner ``tokenizer`` with
    ``bos=False`` and ``eos=False`` instead.
    """
    if not isinstance(tokenizer, MistralTokenizer):
        return tokenizer.encode(text, add_special_tokens=False)

    inner_tokenizer = tokenizer.tokenizer
    return inner_tokenizer.encode(text, bos=False, eos=False)


@lru_cache
def candidate_placeholders(
    tokenizer: AnyTokenizer,
    placeholder_text: str,
) -> Collection[List[int]]:
    """
    Generate token ID sequences that may represent a placeholder text.

    When the placeholder text is not mapped to a special token ID, it may be
    tokenized differently depending on whether it is at the start and/or end
    of the string. This returns the tokenization of ``placeholder_text`` for
    each of the four boundary combinations, with duplicates and empty
    sequences removed.

    The candidates are returned in a fixed order: standalone, preceded by
    text, surrounded by text, followed by text.

    Note:
        The result is cached by :func:`functools.lru_cache`, so every call
        with the same arguments returns the same list objects. Callers must
        not mutate them.
    """
    # "a" and "b" serve as neighbouring text; each must encode to exactly one
    # token so that it can be stripped off the combined encodings below.
    start_id, = encode_no_special_tokens(tokenizer, "a")
    end_id, = encode_no_special_tokens(tokenizer, "b")

    # Placeholder on its own (at both the start and the end of the string)
    candidate_basic = encode_no_special_tokens(tokenizer, placeholder_text)

    # Placeholder preceded by other text
    start_id_, *candidate_a = encode_no_special_tokens(
        tokenizer,
        f"a{placeholder_text}",
    )
    assert start_id == start_id_

    # Placeholder in the middle of the string
    start_id_, *candidate_ab, end_id_ = encode_no_special_tokens(
        tokenizer,
        f"a{placeholder_text}b",
    )
    assert start_id == start_id_ and end_id == end_id_

    # Placeholder followed by other text
    *candidate_b, end_id_ = encode_no_special_tokens(
        tokenizer,
        f"{placeholder_text}b",
    )
    assert end_id == end_id_

    # Deduplicate while preserving the order above. A set would make the
    # order arbitrary, and callers try the candidates in the returned order.
    # Tuples are used because lists are not hashable. Empty candidates are
    # dropped because an empty sequence would trivially "match" anywhere.
    unique_candidates = dict.fromkeys(
        tuple(c)
        for c in (candidate_basic, candidate_a, candidate_ab, candidate_b)
        if c
    )

    # Convert back to lists
    return [list(c) for c in unique_candidates]


def apply_placeholders(
    token_ids: List[int],
    placeholder_ids: List[int],
    get_replacement_ids: Callable[[], List[int]],
) -> Optional[PlaceholderRange]:
    """
    Find the first occurrence of :code:`placeholder_ids`,
    and replace it with the output of :code:`get_replacement_ids`.

    This function updates :code:`token_ids` in place.

    Args:
        token_ids: The token IDs to search; modified in place on a match.
        placeholder_ids: The token ID sequence to look for. An empty
            sequence never matches.
        get_replacement_ids: Called only when a match is found, to produce
            the token IDs that replace the matched span.

    Returns:
        The range occupied by the replacement token IDs in the updated
        :code:`token_ids`, or :code:`None` if no match was found.
    """
    placeholder_length = len(placeholder_ids)
    if placeholder_length == 0:
        # An empty pattern would otherwise trivially "match" at index 0.
        return None

    for start_idx in range(len(token_ids) - placeholder_length + 1):
        end_idx = start_idx + placeholder_length

        if token_ids[start_idx:end_idx] == placeholder_ids:
            replacement_ids = get_replacement_ids()
            token_ids[start_idx:end_idx] = replacement_ids

            # After the in-place replacement, the placeholder spans the
            # replacement IDs, whose length may differ from the match's.
            return PlaceholderRange(offset=start_idx,
                                    length=len(replacement_ids))

    return None


class MultiModalProcessor:
    """
    Helper class to process multi-modal inputs to be used in vLLM.

    The prompt and multi-modal data are first passed through the HF
    processor. The placeholders in the resulting token IDs are then replaced
    according to the per-modality :class:`ModalityProcessingMetadata`.
    """

    def __init__(
        self,
        ctx: InputProcessingContext,
        metadata: MultiModalProcessingMetadata,
    ) -> None:
        super().__init__()

        self.ctx = ctx
        self.metadata = metadata

    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """Alias of :meth:`apply`."""
        return self.apply(prompt, mm_data, mm_processor_kwargs)

    def apply(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """
        Process a prompt together with its multi-modal data.

        Args:
            prompt: The text prompt, containing placeholder text for the
                multi-modal data items.
            mm_data: The multi-modal data, keyed by modality. It is passed
                to the HF processor as keyword arguments.
            mm_processor_kwargs: Extra keyword arguments forwarded to the
                HF processor.

        Returns:
            The processed inputs. The placeholders in the prompt token IDs
            are replaced, and their locations are recorded per modality.
        """
        tokenizer = self.ctx.tokenizer
        hf_processor = self.ctx.get_hf_processor()

        processed_inputs = hf_processor(
            text=prompt,  # type: ignore
            **mm_data,
            **mm_processor_kwargs,
        )
        # The output is batched; the unpacking requires exactly one row
        # (the single prompt passed above).
        new_token_ids, = processed_inputs.pop("input_ids").tolist()
        # Every HF output except input_ids is kept as multi-modal kwargs.
        mm_kwargs = MultiModalKwargs(processed_inputs)

        mm_placeholders: Dict[str, List[PlaceholderRange]] = {}

        for modality, orig_inputs in to_multi_format(mm_data).items():
            assert isinstance(orig_inputs, list)

            metadata = self.metadata[modality]
            placeholder_replacements = metadata.placeholder_replacements

            modality_placeholders: List[PlaceholderRange] = []

            for item_idx, orig_item in enumerate(orig_inputs):
                for match_text, replace_fn in placeholder_replacements.items():
                    candidates = candidate_placeholders(tokenizer, match_text)
                    get_replacement_ids = partial(
                        replace_fn,
                        orig_item,
                        processed_inputs,
                        item_idx,
                    )

                    # NOTE(review): apply_placeholders searches from the
                    # start of new_token_ids, so if a replacement still
                    # contains match_ids, a later item may match inside an
                    # earlier item's replacement. Confirm, and track a search
                    # offset if so.
                    for match_ids in candidates:
                        # TODO(youkaichao): Don't update new_token_ids
                        placeholders = apply_placeholders(
                            new_token_ids,
                            match_ids,
                            get_replacement_ids,
                        )

                        if placeholders is not None:
                            modality_placeholders.append(placeholders)
                            # The candidates are alternative tokenizations of
                            # the same placeholder text. Once one of them has
                            # been replaced for this item, trying the others
                            # would consume placeholders of later items.
                            break

            mm_placeholders[modality] = modality_placeholders

        return MultiModalInputsV2(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=new_token_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholders,
        )