processing.py 21.3 KB
Newer Older
1
2
3
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
4
from dataclasses import dataclass
5
6
7
from functools import lru_cache
from itertools import groupby
from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union
8

9
import numpy as np
10
from transformers import BatchFeature
11
from typing_extensions import TypeAlias, TypedDict
12
13
14

from vllm.inputs import InputProcessingContext
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
15
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
16
17
18
19
20

from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
                     MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
                     VideoItem)

21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

def bind_prompt_sequence(
    seq: Union[str, list[int]],
    tokenizer: AnyTokenizer,
) -> "_BoundPromptSequence":
    """
    Wrap :code:`seq` together with :code:`tokenizer` in a
    :class:`_BoundPromptSequence`.

    A :code:`str` input populates the text form and a :code:`list` input
    populates the token ID form; the other form starts out unset.
    """
    text: Optional[str] = None
    token_ids: Optional[list[int]] = None

    if isinstance(seq, str):
        text = seq
    elif isinstance(seq, list):
        token_ids = seq

    return _BoundPromptSequence(
        tokenizer=tokenizer,
        _text=text,
        _token_ids=token_ids,
    )


37
# Generic placeholder for the type of a single multi-modal data item.
_T = TypeVar("_T")
38
# A prompt representation: either text (str) or token IDs (list[int]).
_S = TypeVar("_S", str, list[int])
39

40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77

@dataclass
class PromptReplacement(Generic[_S, _T]):
    """
    Describes how to replace one text or token sequence in a prompt with a
    repeated unit sequence.
    """

    target: _S
    """The text or token sequence to search for and replace."""

    repl_unit: _S
    """
    A single unit of the replacement text or token sequence.

    The full replacement is this unit repeated :code:`repl_count` times.
    """

    repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int]
    """
    How many times :code:`repl_unit` is repeated to form the replacement.

    Either a constant integer, or a callable that receives the original
    multi-modal items for this modality, the HF-processed data, and the
    index of the processed item, and returns the number of repetitions.
    """

    def __repr__(self) -> str:
        # ``repl_count`` is intentionally not part of the representation.
        return (f"{type(self).__name__}(target={self.target!r}, "
                f"repl_unit={self.repl_unit!r})")

    def bind(
        self,
        modality: str,
        tokenizer: AnyTokenizer,
    ) -> "_BoundPromptReplacement[_T]":
        """
        Attach :code:`modality` and :code:`tokenizer` to this replacement,
        binding both the target and the replacement unit to the tokenizer.
        """
        bound_target = bind_prompt_sequence(self.target, tokenizer)
        bound_unit = bind_prompt_sequence(self.repl_unit, tokenizer)

        return _BoundPromptReplacement(
            modality=modality,
            target=bound_target,
            repl_unit=bound_unit,
            repl_count=self.repl_count,
        )
78
79
80
81


@dataclass
class ModalityProcessingMetadata(Generic[_T]):
    """Processing metadata for a single modality."""

    prompt_repls: Sequence[Union[PromptReplacement[str, _T],
                                 PromptReplacement[list[int], _T]]]
    """
    Defines each text or token sequence to replace in the HF-processed prompt.

    This is skipped if the HF-processed prompt is found to already contain
    the replacement prompts.
    """


class MultiModalProcessingMetadataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    # ``total=False``: each of the keys below may be omitted.
    image: ModalityProcessingMetadata[ImageItem]
    video: ModalityProcessingMetadata[VideoItem]
    audio: ModalityProcessingMetadata[AudioItem]


# Declared as a plain ``Mapping[str, ...]`` so that any modality key
# type-checks, not only the built-in ones.
MultiModalProcessingMetadata: TypeAlias = \
    Mapping[str, ModalityProcessingMetadata[Any]]
"""
A dictionary containing an entry for each modality type to process.

Note:
    This dictionary also accepts modality keys defined outside
    :class:`MultiModalProcessingMetadataBuiltins` as long as a customized plugin
    is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""


113
114
115
116
117
118
119
120
121
122
123
124
125
126
def _encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Tokenize :code:`text`, mirroring HF's
    :code:`tokenizer.encode(text, add_special_tokens=...)` regardless of
    the tokenizer backend.

    For a :class:`MistralTokenizer`, :code:`add_special_tokens` is forwarded
    as both the :code:`bos` and :code:`eos` flags of the wrapped tokenizer.
    """
    if not isinstance(tokenizer, MistralTokenizer):
        return tokenizer.encode(text, add_special_tokens=add_special_tokens)

    mistral_backend = tokenizer.tokenizer
    return mistral_backend.encode(
        text,
        bos=add_special_tokens,
        eos=add_special_tokens,
    )
129
130


131
132
133
134
135
136
137
138
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Memoized variant of :func:`_encode`, keeping up to 2048 results.

    All arguments form the cache key, so :code:`tokenizer` must be hashable.

    NOTE(review): a cache hit returns the very same list object as earlier
    calls, so callers should treat the result as read-only.
    """
    return _encode(tokenizer, text, add_special_tokens=add_special_tokens)
139
140


141
142
143
144
145
146
147
148
149
150
151
def _decode(
    tokenizer: AnyTokenizer,
    token_ids: list[int],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Backend-agnostic equivalent of HF's
    :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.

    This simply forwards both arguments to :code:`tokenizer.decode`.
    """
    return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
152
153


154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Memoized variant of :func:`_decode`, keeping up to 2048 results.

    The token IDs are taken as a tuple so that they are hashable and can be
    part of the cache key; they are converted to a list before decoding.
    """
    return _decode(tokenizer,
                   list(token_ids),
                   skip_special_tokens=skip_special_tokens)


class _HasModalityAttr(Protocol):
    """Structural type: objects with a :code:`modality` attribute."""
    modality: str

169

170
class _HasModalityProp(Protocol):
    """Structural type: objects with a read-only :code:`modality` property."""

    @property
    def modality(self) -> str:
        ...


# Any object exposing a ``modality`` string, as an attribute or a property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """
    Group :code:`values` by their :code:`modality` using :func:`full_groupby`.
    """

    def _modality_of(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A prompt sequence tied to a tokenizer, available as text or token IDs.

    Whichever form was not supplied is computed on first access through
    the cached encode/decode helpers and then stored on the instance.
    """

    tokenizer: AnyTokenizer
    _text: Optional[str]
    _token_ids: Optional[list[int]]

    def __post_init__(self) -> None:
        has_text = self._text is not None
        has_token_ids = self._token_ids is not None

        if not (has_text or has_token_ids):
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """The sequence as text, decoded from the token IDs if needed."""
        known_text = self._text
        if known_text is not None:
            return known_text

        assert self._token_ids is not None
        decoded = _cached_decode(self.tokenizer, tuple(self._token_ids))
        self._text = decoded
        return decoded

    @property
    def token_ids(self) -> list[int]:
        """The sequence as token IDs, encoded from the text if needed."""
        known_ids = self._token_ids
        if known_ids is not None:
            return known_ids

        assert self._text is not None
        encoded = _cached_encode(self.tokenizer, self._text)
        self._token_ids = encoded
        return encoded

    def __repr__(self) -> str:
        # Shows the raw stored fields without triggering a conversion.
        return (f"{type(self).__name__}(_text={self._text!r}, "
                f"_token_ids={self._token_ids!r})")


@dataclass
class _BoundPromptReplacement(Generic[_T]):
    """
    A prompt replacement for one modality whose target and replacement unit
    are bound to a tokenizer.
    """

    modality: str
    target: _BoundPromptSequence
    repl_unit: _BoundPromptSequence
    repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int]

    def get_count(
        self,
        mm_items: list[_T],
        hf_inputs: BatchFeature,
        item_idx: int,
    ) -> int:
        """
        Return how many times the replacement unit is repeated for the item
        at :code:`item_idx`.

        An integer :code:`repl_count` is returned as-is; otherwise it is
        called with the given arguments.
        """
        count_or_fn = self.repl_count

        if not isinstance(count_or_fn, int):
            return count_or_fn(mm_items, hf_inputs, item_idx)

        return count_or_fn


def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]:
    """
    Normalize :code:`data` so that every entry maps to a list of items.

    A value that already counts as multiple items is kept as-is; any other
    value is wrapped in a single-element list.
    """
    multi_data = dict[str, list[Any]]()

    for modality, value in data.items():
        if modality == "video":
            # Special case since even a single item can be a list
            is_multi = is_list_of(value, list)
        else:
            is_multi = isinstance(value, list)

        multi_data[modality] = value if is_multi else [value]  # type: ignore[index]

    return multi_data


259
260
261
262
263
264
265
266
class _TokenRun(NamedTuple):
    """A run of identical consecutive tokens."""
    # The token ID repeated throughout the run
    token_id: int

    # Index of the first token of the run
    start_idx: int
    # Number of tokens in the run
    length: int


def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]:
    """
    Split :code:`token_ids` into runs of identical consecutive tokens,
    yielding the token ID, starting index and length of each run in order.
    """
    run_start = 0

    for token_id, run in groupby(token_ids):
        run_length = len(list(run))
        yield _TokenRun(
            token_id=token_id,
            start_idx=run_start,
            length=run_length,
        )

        run_start += run_length


class _PlaceholderInfo(NamedTuple):
    """Location of one run of placeholder tokens for a modality."""
    # Modality the placeholder tokens belong to
    modality: str
    # Index of the first placeholder token
    offset: int
    # Number of placeholder tokens
    length: int

    def to_range(self) -> PlaceholderRange:
        """Return the offset and length as a :class:`PlaceholderRange`."""
        return PlaceholderRange(offset=self.offset, length=self.length)


def iter_placeholders(
    prompt_repls: Sequence[_BoundPromptReplacement[Any]],
    token_ids: list[int],
    *,
    min_placeholder_count: int,
) -> Iterable[_PlaceholderInfo]:
    """
    Scan :code:`token_ids` for runs of placeholder tokens.

    A run of identical tokens is reported once per modality whose
    replacement units contain that token, provided the run is strictly
    longer than :code:`min_placeholder_count`.
    """
    # Collect, per modality, every token ID used by its replacement units
    placeholder_ids_by_modality = dict[str, set[int]]()
    for modality, repls in full_groupby_modality(prompt_repls):
        unit_ids = set[int]()
        for repl in repls:
            unit_ids.update(repl.repl_unit.token_ids)

        placeholder_ids_by_modality[modality] = unit_ids

    for run in iter_token_runs(token_ids):
        if run.length <= min_placeholder_count:
            continue

        for modality, unit_ids in placeholder_ids_by_modality.items():
            if run.token_id not in unit_ids:
                continue

            yield _PlaceholderInfo(
                modality=modality,
                offset=run.start_idx,
                length=run.length,
            )
314
315


316
317
318
class _TokenMatch(NamedTuple):
    """Half-open index range ``[start_idx, end_idx)`` of a token match."""
    start_idx: int
    end_idx: int


def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Iterable[_TokenMatch]:
    """
    Yield each non-overlapping occurrence of :code:`match_ids` in
    :code:`token_ids`, scanning from left to right.

    Nothing is yielded when :code:`match_ids` is empty, since an empty
    pattern would otherwise "match" at every position.
    """
    match_len = len(match_ids)
    if match_len == 0:
        return

    start_idx = 0
    last_start_idx = len(token_ids) - match_len

    while start_idx <= last_start_idx:
        end_idx = start_idx + match_len

        if token_ids[start_idx:end_idx] == match_ids:
            yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)

            # Resume after this match to exclude overlapping matches
            start_idx = end_idx
        else:
            start_idx += 1
337
338


339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
class _PromptReplacementMatch(ABC, Generic[_T, _S]):
    """
    Abstract base for one occurrence of a replacement target in a prompt.

    Subclasses provide the matched index range and build the replacement
    in the prompt's own representation (:code:`_S`).
    """
    # The bound replacement whose target was matched
    prompt_repl: _BoundPromptReplacement[_T]

    @property
    def modality(self) -> str:
        """Modality of the underlying prompt replacement."""
        return self.prompt_repl.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        """Index in the prompt where the match starts."""
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        """Index in the prompt where the match ends (used as a slice end)."""
        raise NotImplementedError

    @abstractmethod
    def get_repl(
        self,
        mm_items: list[_T],
        hf_inputs: BatchFeature,
        item_idx: int,
    ) -> _S:
        """Build the replacement sequence for the item at ``item_idx``."""
        raise NotImplementedError

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]):
    """An occurrence of a replacement target within a token ID prompt."""

    prompt_repl: _BoundPromptReplacement[_T]
    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        token_match = self.match
        return token_match.start_idx

    @property
    def end_idx(self) -> int:
        token_match = self.match
        return token_match.end_idx

    def get_repl(
        self,
        mm_items: list[_T],
        hf_inputs: BatchFeature,
        item_idx: int,
    ) -> list[int]:
        """Return the unit token IDs repeated once per required count."""
        repl = self.prompt_repl
        num_units = repl.get_count(mm_items, hf_inputs, item_idx)
        unit_ids = repl.repl_unit.token_ids

        return unit_ids * num_units
392

393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448

@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]):
    """An occurrence of a replacement target within a text prompt."""

    prompt_repl: _BoundPromptReplacement[_T]
    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        text_match = self.match
        return text_match.start()

    @property
    def end_idx(self) -> int:
        text_match = self.match
        return text_match.end()

    def get_repl(
        self,
        mm_items: list[_T],
        hf_inputs: BatchFeature,
        item_idx: int,
    ) -> str:
        """Return the unit text repeated once per required count."""
        repl = self.prompt_repl
        num_units = repl.get_count(mm_items, hf_inputs, item_idx)
        unit_text = repl.repl_unit.text

        return unit_text * num_units


def find_token_matches(
    prompt: list[int],
    prompt_repls: Sequence[_BoundPromptReplacement[_T]],
) -> list[_PromptReplacementTokenMatch[_T]]:
    """
    Collect every occurrence of each replacement's target token IDs in
    :code:`prompt`, grouped in the order of :code:`prompt_repls`.
    """
    found: list[_PromptReplacementTokenMatch[_T]] = []

    for repl in prompt_repls:
        target_ids = repl.target.token_ids
        found.extend(
            _PromptReplacementTokenMatch(repl, token_match)
            for token_match in iter_token_matches(prompt, target_ids))

    return found


def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[_BoundPromptReplacement[_T]],
) -> list[_PromptReplacementTextMatch[_T]]:
    """
    Collect every occurrence of each replacement's target text in
    :code:`prompt`, grouped in the order of :code:`prompt_repls`.

    The target text is matched literally (it is regex-escaped first).
    """
    found: list[_PromptReplacementTextMatch[_T]] = []

    for repl in prompt_repls:
        literal_pattern = re.escape(repl.target.text)

        for text_match in re.finditer(literal_pattern, prompt):
            found.append(_PromptReplacementTextMatch(repl, text_match))

    return found


def _resolve_matches(
    prompt: _S,
    matches: Sequence[_PromptReplacementMatch[_T, _S]],
) -> list[_PromptReplacementMatch[_T, _S]]:
    """
    Validate that no two of :code:`matches` cover the same prompt index,
    then return them ordered by starting index.

    Raises:
        ValueError: If any prompt index is covered by more than one match.
    """
    # Count how many matches cover each position of the prompt
    coverage = np.zeros(len(prompt), dtype=int)
    for match in matches:
        coverage[match.start_idx:match.end_idx] += 1

    duplicate_matches_idxs = np.nonzero(coverage > 1)[0]
    if duplicate_matches_idxs.size > 0:
        raise ValueError("Unable to find a unique replacement "
                         f"at indices={duplicate_matches_idxs} "
                         f"of prompt={prompt}")

    def _start_of(match: _PromptReplacementMatch[_T, _S]) -> int:
        return match.start_idx

    return sorted(matches, key=_start_of)


def _replace_matches(
    prompt: _S,
    matches: Sequence[_PromptReplacementMatch[_T, _S]],
    mm_items_by_modality: Mapping[str, list[_T]],
    hf_inputs: BatchFeature,
) -> list[_S]:
    """
    Split :code:`prompt` into consecutive segments with each match
    substituted by its replacement.

    Matches are consumed in resolved order, one multi-modal item per match
    and per modality; once a modality has no items left, its remaining
    matches are left untouched. Concatenating the returned segments gives
    the final prompt.
    """
    segments: list[_S] = []
    cursor = 0
    items_used = dict.fromkeys(mm_items_by_modality, 0)

    for match in _resolve_matches(prompt, matches):
        modality = match.modality
        modality_items = mm_items_by_modality[modality]
        item_idx = items_used[modality]

        # Only replace while there are still items for this modality
        if item_idx < len(modality_items):
            match_start = match.start_idx
            match_end = match.end_idx
            replacement = match.get_repl(modality_items, hf_inputs, item_idx)

            segments.append(prompt[cursor:match_start] + replacement)
            cursor = match_end
            items_used[modality] = item_idx + 1

    # Whatever follows the last replaced match
    segments.append(prompt[cursor:])

    return segments


def replace_token_matches(
    prompt: list[int],
    matches: Sequence[_PromptReplacementMatch[_T, list[int]]],
    mm_items_by_modality: Mapping[str, list[_T]],
    hf_inputs: BatchFeature,
) -> list[int]:
    """
    Apply the replacements described by :code:`matches` to the token ID
    :code:`prompt`.

    When there are no matches, :code:`prompt` itself is returned.
    """
    if matches:
        segments = _replace_matches(
            prompt,
            matches,
            mm_items_by_modality,
            hf_inputs,
        )
        return flatten_2d_lists(segments)

    return prompt
512
513


514
515
516
517
518
519
520
521
522
def replace_text_matches(
    prompt: str,
    matches: Sequence[_PromptReplacementMatch[_T, str]],
    mm_items_by_modality: Mapping[str, list[_T]],
    hf_inputs: BatchFeature,
) -> str:
    """
    Apply the replacements described by :code:`matches` to the text
    :code:`prompt`.

    When there are no matches, :code:`prompt` itself is returned.
    """
    if matches:
        segments = _replace_matches(
            prompt,
            matches,
            mm_items_by_modality,
            hf_inputs,
        )
        return "".join(segments)

    return prompt
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556


class MultiModalProcessor:
    """
    Helper class to process multi-modal inputs to be used in vLLM.
    """

    def __init__(
        self,
        ctx: InputProcessingContext,
        metadata: MultiModalProcessingMetadata,
    ) -> None:
        super().__init__()

        self.ctx = ctx
        self.metadata = metadata

    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """Shorthand for :meth:`apply`."""
        return self.apply(prompt, mm_data, mm_processor_kwargs)

    def _find_placeholders(
        self,
        all_prompt_repls: Sequence[_BoundPromptReplacement[Any]],
        new_token_ids: list[int],
        *,
        # To avoid false positives from multi-input when detecting
        # whether placeholder tokens have been inserted, in case
        # the target sequence is a subset of the replacement tokens
        min_placeholder_count: int = 16,
    ) -> list[_PlaceholderInfo]:
        """
        Return every run of placeholder tokens found in
        :code:`new_token_ids`, as a list.
        """
        return list(
            iter_placeholders(
                all_prompt_repls,
                new_token_ids,
                min_placeholder_count=min_placeholder_count,
            ))

    def _apply_hf_processor(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Call the HF processor obtained from the context on the prompt text,
        passing the multi-modal data and the extra processor kwargs as
        keyword arguments.
        """
        hf_processor = self.ctx.get_hf_processor()

        return hf_processor(
            text=prompt,  # type: ignore
            **mm_data,
            **mm_processor_kwargs,
        )

    def _bind_prompt_replacements(
        self,
        mm_data: MultiModalDataDict,
    ) -> list[_BoundPromptReplacement[Any]]:
        """
        Bind the prompt replacements of every modality that is present in
        :code:`mm_data` to the context's tokenizer.
        """
        tokenizer = self.ctx.tokenizer

        return [
            prompt_repl.bind(modality, tokenizer)
            for modality, metadata in self.metadata.items()
            if modality in mm_data for prompt_repl in metadata.prompt_repls
        ]

    def _apply_prompt_replacements(
        self,
        mm_data: MultiModalDataDict,
        hf_inputs: BatchFeature,
        token_ids: list[int],
        prompt_repls: Sequence[_BoundPromptReplacement[Any]],
    ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
        """
        Replace the targets of :code:`prompt_repls` in :code:`token_ids`.

        Replacement is done directly on the token IDs when every modality
        that has replacements also has at least one token-level match per
        multi-modal item; otherwise it falls back to replacing on the
        decoded text and re-encoding the result.

        Returns:
            The new token IDs, the corresponding text, and the placeholder
            runs found in the new token IDs.
        """
        tokenizer = self.ctx.tokenizer

        mm_items = to_multi_format(mm_data)
        token_matches = find_token_matches(token_ids, prompt_repls)

        # Count matches per modality. A modality without any token match
        # does not show up in the grouping, so it must be looked up with a
        # default of zero rather than being silently skipped.
        token_match_counts = {
            modality: len(matches)
            for modality, matches in full_groupby_modality(token_matches)
        }
        repl_modalities = {
            prompt_repl.modality
            for prompt_repl in prompt_repls
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string
        # replacement on the decoded token IDs, then encode them back.
        if all(
            token_match_counts.get(modality, 0)
            >= len(mm_items.get(modality, []))
            for modality in repl_modalities
        ):  # yapf: disable
            token_ids = replace_token_matches(
                token_ids,
                token_matches,
                mm_items,
                hf_inputs,
            )

            text = _decode(tokenizer, token_ids)
            matched_repls = [match.prompt_repl for match in token_matches]
        else:
            text = _decode(tokenizer, token_ids)

            text_matches = find_text_matches(text, prompt_repls)
            text = replace_text_matches(
                text,
                text_matches,
                mm_items,
                hf_inputs,
            )

            token_ids = _encode(tokenizer, text)
            matched_repls = [match.prompt_repl for match in text_matches]

        placeholders = self._find_placeholders(matched_repls, token_ids)

        # Sanity check
        assert len(placeholders) == len(matched_repls), dict(
            # Log this information for easier debugging
            text=text,
            token_ids=token_ids,
            placeholders=placeholders,
            matched_repls=matched_repls,
        )

        return token_ids, text, placeholders

    def apply(
        self,
        prompt_text: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and replace sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.
        """
        tokenizer = self.ctx.tokenizer

        hf_inputs = self._apply_hf_processor(prompt_text, mm_data,
                                             mm_processor_kwargs)
        # Exactly one prompt is expected in the processed batch
        prompt_ids, = hf_inputs.pop("input_ids").tolist()
        mm_kwargs = MultiModalKwargs(hf_inputs)

        all_prompt_repls = self._bind_prompt_replacements(mm_data)

        # If HF processor already inserts placeholder tokens,
        # there is no need for us to insert them
        all_placeholders = self._find_placeholders(all_prompt_repls,
                                                   prompt_ids)
        if all_placeholders:
            prompt_text = _decode(tokenizer, prompt_ids)
        else:
            (
                prompt_ids,
                prompt_text,
                all_placeholders,
            ) = self._apply_prompt_replacements(
                mm_data,
                hf_inputs,
                prompt_ids,
                all_prompt_repls,
            )

        mm_placeholders = {
            modality: [item.to_range() for item in items]
            for modality, items in full_groupby_modality(all_placeholders)
        }

        return MultiModalInputsV2(
            type="multimodal",
            prompt=prompt_text,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_placeholders=mm_placeholders,
        )