# processing.py
import re
from abc import ABC, abstractmethod
3
from collections import defaultdict
4
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
5
from dataclasses import dataclass, field
6
from functools import lru_cache
7
8
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
                    TypeVar, Union)
9

10
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
11

12
13
import vllm.envs as envs
from vllm.inputs import InputProcessingContext
14
from vllm.logger import init_logger
15
16
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                               encode_tokens)
17
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
18

19
from .hasher import MultiModalHasher
20
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
21
22
                     MultiModalInputsV2, MultiModalKwargs,
                     MultiModalKwargsItem, PlaceholderRange)
23
from .parse import MultiModalDataItems, MultiModalDataParser
24
25
26

if TYPE_CHECKING:
    from .profiling import BaseDummyInputsBuilder
27

28
# Module-level logger obtained from vLLM's `init_logger` helper.
logger = init_logger(__name__)
29
30

# A prompt fragment is held either as text or as a list of token IDs.
# `_S` pins a generic function to exactly one of those two representations,
# while `_PromptSeq` is the plain union of both.
_S = TypeVar("_S", str, list[int])
_PromptSeq = Union[str, list[int]]
32

33
34

@dataclass
class PromptReplacement:
    """
    Defines how to replace portions of an input prompt with placeholder tokens.

    Call :meth:`bind` with a tokenizer to obtain a
    :class:`BoundPromptReplacement` that can convert between text and tokens.
    """

    modality: str
    """Which modality this replacement applies to."""

    target: _PromptSeq
    """The token sequence (or text) that is searched for and replaced."""

    replacement: Union[Callable[[int], _PromptSeq],
                       _PromptSeq] = field(repr=False)
    """
    Maps the index of the processed item within :attr:`modality` to the
    token sequence (or text) that replaces :attr:`target`.

    When the replacement does not depend on the item, the token sequence
    (or text) may be given directly instead of a function.
    """

    def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement":
        """
        Attach :code:`tokenizer`, producing a :class:`BoundPromptReplacement`
        that converts the target and replacements between text and tokens.
        """
        target, replacement = self.target, self.replacement
        return BoundPromptReplacement(
            tokenizer=tokenizer,
            modality=self.modality,
            _target=target,
            _replacement=replacement,
        )
63
64


65
66
67
68
69
70
71
@lru_cache(maxsize=2048)
def _cached_encode(
    tokenizer: AnyTokenizer,
    text: str,
    *,
    add_special_tokens: bool = False,
) -> list[int]:
    """
    Memoized wrapper around :func:`encode_tokens`.

    Results are cached by the call arguments. Because :func:`lru_cache`
    hands back the same object on a cache hit, the returned list is shared
    between callers and must not be mutated.
    """
    token_ids = encode_tokens(tokenizer,
                              text,
                              add_special_tokens=add_special_tokens)
    return token_ids
75
76


77
78
79
80
81
82
83
@lru_cache(maxsize=2048)
def _cached_decode(
    tokenizer: AnyTokenizer,
    token_ids: tuple[int, ...],
    *,
    skip_special_tokens: bool = False,
) -> str:
    """
    Memoized wrapper around :func:`decode_tokens`.

    :code:`token_ids` is a tuple so it is hashable and can be part of the
    cache key. It is turned back into a list before decoding.
    """
    ids_as_list = list(token_ids)
    return decode_tokens(tokenizer,
                         ids_as_list,
                         skip_special_tokens=skip_special_tokens)
87
88
89
90
91


class _HasModalityAttr(Protocol):
    """Structural type: anything with a plain ``modality`` attribute."""

    modality: str


class _HasModalityProp(Protocol):
    """Structural type: anything with a read-only ``modality`` property."""

    @property
    def modality(self) -> str:
        ...


# Any object that can report its modality, via attribute or property.
_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp])


def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
    """Group :code:`values` by their ``modality`` via :func:`full_groupby`."""

    def _modality_of(value: _M) -> str:
        return value.modality

    return full_groupby(values, key=_modality_of)


@dataclass
class _BoundPromptSequence:
    """
    A prompt fragment stored as text, token IDs, or both.

    The tokenizer is kept alongside so that whichever representation is
    missing can be derived lazily, on first access, and then remembered.
    """

    tokenizer: AnyTokenizer = field(repr=False)

    _text: Optional[str]
    _token_ids: Optional[list[int]]

    def __post_init__(self) -> None:
        has_neither = self._text is None and self._token_ids is None
        if has_neither:
            raise ValueError("At least one of 'text' and 'token_ids' must be "
                             "specified")

    @property
    def text(self) -> str:
        """Text form; decoded from the token IDs the first time it is read."""
        if self._text is not None:
            return self._text

        assert self._token_ids is not None
        self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))
        return self._text

    @property
    def token_ids(self) -> list[int]:
        """Token-ID form; encoded from the text the first time it is read."""
        if self._token_ids is not None:
            return self._token_ids

        assert self._text is not None
        self._token_ids = _cached_encode(self.tokenizer, self._text)
        return self._token_ids


@dataclass
class BoundPromptReplacement:
    """
    A :class:`PromptReplacement` bound to a tokenizer to automatically
    convert :attr:`target` and the result of :meth:`get_replacement` between
    token sequence and text representations.
    """
    tokenizer: AnyTokenizer = field(repr=False)
    modality: str

    _target: _PromptSeq
    _replacement: Union[Callable[[int], _PromptSeq],
                        _PromptSeq] = field(repr=False)

    def __post_init__(self) -> None:
        # Memo for a callable `_replacement`, keyed by the item index.
        self._replacement_cache = dict[int, _BoundPromptSequence]()

    def _to_bound_seq(self, seq: _PromptSeq) -> _BoundPromptSequence:
        """Wrap :code:`seq` so its text/token forms can be derived lazily."""
        return _BoundPromptSequence(
            tokenizer=self.tokenizer,
            _text=seq if isinstance(seq, str) else None,
            _token_ids=seq if isinstance(seq, list) else None,
        )

    @property
    def target(self) -> _BoundPromptSequence:
        """The token sequence (or text) to find and replace."""
        return self._to_bound_seq(self._target)

    def get_replacement(self, item_idx: int) -> _BoundPromptSequence:
        """
        Given the index of the processed item within :attr:`modality`,
        output the replacement token sequence (or text).

        Results of a callable replacement are memoized per item index;
        a static replacement is re-wrapped on every call.
        """
        replacement = self._replacement
        if not callable(replacement):
            return self._to_bound_seq(replacement)

        if item_idx in self._replacement_cache:
            return self._replacement_cache[item_idx]

        bound = self._to_bound_seq(replacement(item_idx))
        self._replacement_cache[item_idx] = bound
        return bound


192
193
194
class _TokenMatch(NamedTuple):
    """Half-open span ``[start_idx, end_idx)`` of a match in a token list."""

    start_idx: int
    end_idx: int


def iter_token_matches(
    token_ids: list[int],
    match_ids: list[int],
) -> Iterable[_TokenMatch]:
    """
    Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

    The scan runs left to right and never reports overlapping occurrences:
    after a hit, searching resumes immediately after the matched span.

    Note that empty matches are ignored.
    """
    width = len(match_ids)
    if width == 0:
        return

    last_start = len(token_ids) - width
    pos = 0
    while pos <= last_start:
        stop = pos + width
        if token_ids[pos:stop] != match_ids:
            pos += 1
            continue

        yield _TokenMatch(start_idx=pos, end_idx=stop)
        # Jump past this hit so overlapping occurrences are skipped.
        pos = stop
223
224


225
226
@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
    """
    One occurrence of a replacement's target within a prompt.

    Concrete subclasses supply where the occurrence starts and ends.
    """

    prompt_repl: BoundPromptReplacement

    @property
    def modality(self) -> str:
        """Modality of the replacement that produced this match."""
        return self.prompt_repl.modality

    @property
    @abstractmethod
    def start_idx(self) -> int:
        raise NotImplementedError

    @property
    @abstractmethod
    def end_idx(self) -> int:
        raise NotImplementedError

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return (f"{cls_name}(modality={self.modality!r}, "
                f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
    """A match located in a prompt given as a list of token IDs."""

    match: _TokenMatch

    @property
    def start_idx(self) -> int:
        """Index of the first matched token."""
        token_match = self.match
        return token_match.start_idx

    @property
    def end_idx(self) -> int:
        """Index just past the last matched token."""
        token_match = self.match
        return token_match.end_idx


@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch):
    """A match located in a text prompt, backed by a regex match object."""

    match: re.Match[str]

    @property
    def start_idx(self) -> int:
        """Character offset where the match begins."""
        begin, _ = self.match.span()
        return begin

    @property
    def end_idx(self) -> int:
        """Character offset just past the end of the match."""
        _, finish = self.match.span()
        return finish

273

274
@dataclass
class PlaceholderInfo:
    """Position and tokens of one multi-modal item's placeholder in a prompt."""

    modality: str
    item_idx: int
    start_idx: int
    replacement: list[int]

    @property
    def length(self) -> int:
        """Number of placeholder tokens."""
        return len(self.replacement)

    def to_range(self) -> PlaceholderRange:
        """Express this placeholder as a :class:`PlaceholderRange`."""
        offset, length = self.start_idx, self.length
        return PlaceholderRange(offset=offset, length=length)
290
291
292
293


def find_token_matches(
    prompt: list[int],
    prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Matches are listed replacement by replacement; within one replacement
    they appear left to right and do not overlap.
    """
    found: list[_PromptReplacementTokenMatch] = []
    for prompt_repl in prompt_repls:
        target_ids = prompt_repl.target.token_ids
        found.extend(
            _PromptReplacementTokenMatch(prompt_repl, match)
            for match in iter_token_matches(prompt, target_ids))

    return found


def find_text_matches(
    prompt: str,
    prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
    """
    Return each target of :code:`prompt_repls` found in :code:`prompt`.

    Matches are listed replacement by replacement; within one replacement
    they appear left to right and do not overlap.

    Note that empty matches are ignored, consistent with
    :func:`find_token_matches`.
    """
    matches = list[_PromptReplacementTextMatch]()
    for prompt_repl in prompt_repls:
        target_text = prompt_repl.target.text
        # An empty pattern matches at every position of `prompt`, which would
        # insert a replacement between every pair of characters. Skip it, just
        # like `iter_token_matches` skips empty token targets.
        if not target_text:
            continue

        pattern = re.escape(target_text)
        matches.extend(
            _PromptReplacementTextMatch(prompt_repl, match)
            for match in re.finditer(pattern, prompt))

    return matches


def _resolve_matches(
    prompt: _PromptSeq,
    mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
) -> list[_PromptReplacementMatch]:
    """
    Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
    and sort them such that earlier matches take priority over later ones.

    Raises:
        ValueError: If two matches (of any modality) cover the same position
            of :code:`prompt`.
    """
    all_matches = [
        match for modality_matches in mm_matches.values()
        for match in modality_matches
    ]

    # owner[i] is the match that has already claimed position i, if any.
    owner: list[Optional[_PromptReplacementMatch]] = [None] * len(prompt)

    for match in all_matches:
        for idx in range(match.start_idx, match.end_idx):
            previous = owner[idx]
            if previous is not None:
                raise ValueError("Found overlapping matches "
                                 f"({previous} and {match}) "
                                 f"at index={idx} of prompt={prompt}")

            owner[idx] = match

    all_matches.sort(key=lambda m: m.start_idx)
    return all_matches


def _replace_matches(
    prompt: _S,
    mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[_S]:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    Returns the output as an ordered list of pieces of the same type as
    :code:`prompt`, to be concatenated by the caller. Matches of a modality
    beyond its item count in :code:`mm_item_counts` are left untouched.
    """
    is_text = isinstance(prompt, str)
    pieces = list[_S]()
    cursor = 0
    items_used = defaultdict[str, int](lambda: 0)

    for match in _resolve_matches(prompt, mm_matches):
        modality = match.modality
        item_idx = items_used[modality]
        if item_idx >= mm_item_counts.get(modality, 0):
            continue

        bound = match.prompt_repl.get_replacement(item_idx)
        repl_seq = bound.text if is_text else bound.token_ids
        pieces.append(prompt[cursor:match.start_idx] + repl_seq)

        cursor = match.end_idx
        items_used[modality] += 1

    pieces.append(prompt[cursor:])
    return pieces


def replace_token_matches(
    prompt: list[int],
    mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]],
    mm_item_counts: Mapping[str, int],
) -> list[int]:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    When :code:`mm_matches` is empty, :code:`prompt` itself is returned.
    """
    if mm_matches:
        pieces = _replace_matches(prompt, mm_matches, mm_item_counts)
        return flatten_2d_lists(pieces)

    return prompt
391
392


393
394
def replace_text_matches(
    prompt: str,
    mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]],
    mm_item_counts: Mapping[str, int],
) -> str:
    """
    Apply the replacements in :code:`mm_matches` to :code:`prompt`.

    When :code:`mm_matches` is empty, :code:`prompt` is returned unchanged.
    """
    if mm_matches:
        pieces = _replace_matches(prompt, mm_matches, mm_item_counts)
        return "".join(pieces)

    return prompt
405
406


407
408
409
def _iter_modality_placeholders(
    prompt: list[int],
    modality: str,
    modality_repls: Sequence[BoundPromptReplacement],
    modal_item_count: int,
) -> Iterable[PlaceholderInfo]:
    """
    Scan :code:`prompt` left to right for the placeholder tokens of the
    items of :code:`modality`.

    At each position, the replacements in :code:`modality_repls` are tried
    in order, using the token sequence for the next unassigned item index.
    After a hit, scanning resumes right after it, so placeholders never
    overlap. Scanning stops once :code:`modal_item_count` placeholders have
    been found. Empty replacements never match.
    """
    if modal_item_count == 0:
        return

    n_tokens = len(prompt)
    next_item = 0
    pos = 0

    while pos < n_tokens:
        for repl_info in modality_repls:
            candidate = repl_info.get_replacement(next_item).token_ids
            stop = pos + len(candidate)
            if not candidate or stop > n_tokens:
                continue
            if prompt[pos:stop] != candidate:
                continue

            yield PlaceholderInfo(
                modality=modality,
                item_idx=next_item,
                start_idx=pos,
                replacement=candidate,
            )

            next_item += 1
            if next_item >= modal_item_count:
                return

            # Resume after this placeholder to exclude overlapping matches.
            pos = stop
            break
        else:
            # No replacement matched at `pos`; advance by one token.
            pos += 1
451
452


453
def _iter_placeholders(
    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderInfo]:
    """
    For each modality, yield each set of placeholder tokens found in
    :code:`prompt`.

    Modalities counted in :code:`mm_item_counts` but absent from
    :code:`mm_prompt_repls` are skipped.

    Note that empty matches are ignored.
    """
    for modality, modal_item_count in mm_item_counts.items():
        if modality not in mm_prompt_repls:
            continue

        yield from _iter_modality_placeholders(
            prompt,
            modality,
            mm_prompt_repls[modality],
            modal_item_count,
        )

473

474
def find_mm_placeholders(
    mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
    prompt: list[int],
    mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderInfo]]:
    """
    Locate the placeholder tokens of every multi-modal item in
    :code:`prompt`, grouped by modality.
    """
    placeholders = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
    grouped = full_groupby_modality(placeholders)
    return dict(grouped)


483
484
485
486
487
488
489
490
class ProcessingCache:
    """
    LRU cache of processed multi-modal items.

    Each entry is keyed by a hash (via :class:`MultiModalHasher`) of the
    model ID, the modality, the raw input item and the HF processor kwargs.
    """

    def __init__(self, capacity: int) -> None:
        super().__init__()

        # DEBUG: Set to None to disable
        self.debug_cache_hit_ratio_steps: Optional[int] = None

        self._cache = LRUCache[str, MultiModalKwargsItem](capacity)

    @staticmethod
    def _hash_key(
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> str:
        """Compute the cache key shared by :meth:`get` and :meth:`put`."""
        return MultiModalHasher.hash_kwargs(model_id=model_id,
                                            **{modality: input_item},
                                            **input_kwargs)

    def _maybe_log_cache_stats(self) -> None:
        """Every N lookups (N = debug setting), log the hit ratio at DEBUG."""
        steps = self.debug_cache_hit_ratio_steps
        if steps:
            stats = self._cache.stat()
            if stats.total % steps == 0:
                logger.debug("ProcessingCache: hit_ratio = %.2f",
                             stats.hit_ratio)

    def get(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
    ) -> Optional[MultiModalKwargsItem]:
        """
        Look up a processed multi-modal item, returning it if present.

        The lookup key depends on:

        - The model ID
        - The modality of the item
        - The original data item passed to the HF processor
        - The configuration options of the HF processor
        """
        self._maybe_log_cache_stats()

        key = self._hash_key(model_id, modality, input_item, input_kwargs)
        return self._cache.get(key)

    def put(
        self,
        model_id: str,
        modality: str,
        input_item: object,
        input_kwargs: Mapping[str, object],
        output_kwargs: MultiModalKwargsItem,
    ) -> None:
        """
        Store a processed multi-modal item under the same key that
        :meth:`get` would compute for these dependencies.
        """
        key = self._hash_key(model_id, modality, input_item, input_kwargs)
        self._cache.put(key, output_kwargs)
542
543


544
class BaseProcessingInfo:
    """Base class to provide the information necessary for data processing."""

    # NOTE(review): this class does not inherit from `ABC`, so the
    # `@abstractmethod` markers below are not enforced when a subclass is
    # instantiated; subclasses are nonetheless expected to implement them.

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        self.ctx = ctx

    @property
    def model_id(self) -> str:
        """The ``model`` field of the model config held by :attr:`ctx`."""
        model_config = self.ctx.model_config
        return model_config.model

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by :attr:`ctx`."""
        return self.ctx.tokenizer

    def get_hf_config(self) -> PretrainedConfig:
        """Return the HF model config provided by :attr:`ctx`."""
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Return the HF processor provided by :attr:`ctx`.

        Subclasses can override this method to handle
        specific kwargs from model config or user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        A modality that is missing from the returned dictionary is
        not supported at all.
        """
        raise NotImplementedError

    @abstractmethod
    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        """
        Return, for each modality, the largest number of tokens that a
        single data item can take up.

        The returned dictionary should have the same keys as the one
        returned by :meth:`get_supported_mm_limits`.
        """
        raise NotImplementedError


# Type of the processing-info object that a multi-modal processor is
# parameterized by (see `BaseMultiModalProcessor.info`).
_I = TypeVar("_I", bound=BaseProcessingInfo)
594

595
596

class BaseMultiModalProcessor(ABC, Generic[_I]):
597
    """
598
    Abstract base class to process multi-modal inputs to be used in vLLM.
599
600

    Not to be confused with :class:`transformers.ProcessorMixin`.
601
602
    """

603
    def __init__(
        self,
        info: _I,
        dummy_inputs: "BaseDummyInputsBuilder[_I]",
        *,
        cache: Optional[ProcessingCache] = None,
        enable_sanity_checks: bool = True,
    ) -> None:
        """
        Args:
            info: Processing information for the model.
            dummy_inputs: Builder of dummy inputs; used to produce text to
                accompany multi-modal data processed on its own.
            cache: Optional cache of processed multi-modal items.
            enable_sanity_checks: Whether to run extra internal consistency
                checks.
        """
        super().__init__()

        self.info = info
        self.dummy_inputs = dummy_inputs
        self.cache = cache
        self.enable_sanity_checks = enable_sanity_checks

        # Built once here via an overridable hook.
        self.data_parser = self._get_data_parser()

618
    def __call__(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """Shorthand for :meth:`apply` with the same arguments."""
        result = self.apply(prompt, mm_data, hf_processor_mm_kwargs)
        return result
625

626
627
    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Build the parser that normalizes raw multi-modal data items
        before they are handed to :meth:`_get_hf_mm_data`.

        To support more modalities, override this to return a subclass
        of :class:`MultiModalDataParser` with additional subparsers.
        """
        parser = MultiModalDataParser()
        return parser

    def _to_mm_items(
        self,
        mm_data: MultiModalDataDict,
    ) -> MultiModalDataItems:
        """
        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
        before passing them to :meth:`_get_hf_mm_data`.

        Raises:
            ValueError: If a modality has more items than its
                ``--limit-mm-per-prompt`` setting allows (1 if unset).
        """
        mm_items = self.data_parser.parse_mm_data(mm_data)
        limit_per_prompt = self.info.ctx.get_mm_config().limit_per_prompt

        for modality, items in mm_items.items():
            num_items = len(items)
            limit = limit_per_prompt.get(modality, 1)
            if num_items <= limit:
                continue

            raise ValueError(
                f"You set {modality}={limit} (or defaulted to 1) in "
                f"`--limit-mm-per-prompt`, but passed {num_items} "
                f"{modality} items in the same prompt.")

        return mm_items
656

657
658
659
660
661
662
663
664
665
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """
        Describe the fields of the HF-processed data.

        Args:
            hf_inputs: Output of the HF processor.
            hf_processor_mm_kwargs: Keyword arguments given to the HF
                processor.

        Returns:
            A mapping from each field to its :class:`MultiModalFieldConfig`.
        """
        raise NotImplementedError

666
667
    @abstractmethod
    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """
        Given the original multi-modal items for this modality
        and HF-processed data, output the replacements to perform.

        Notes:
            - You should not assume that HF processor always performs prompt
              replacement: :meth:`_apply_hf_processor_main` calls the HF
              processor on text-only and multimodal-only inputs separately
              (instead of passing them in the same call) whenever the prompt
              is given as token IDs or
              :code:`enable_hf_prompt_replacement=False`.
            - The replacement information returned by this method is also used
              to determine the placeholder token positions for each multi-modal
              item.
        """
        raise NotImplementedError
687

688
    def _find_mm_placeholders(
        self,
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        new_token_ids: list[int],
        mm_item_counts: Mapping[str, int],
    ) -> Mapping[str, list[PlaceholderInfo]]:
        """
        Locate the placeholder tokens of each multi-modal item in
        :code:`new_token_ids`, grouped by modality.

        Delegates to the module-level :func:`find_mm_placeholders`.
        """
        return find_mm_placeholders(
            mm_prompt_repls,
            new_token_ids,
            mm_item_counts,
        )
696

697
    def _get_hf_mm_data(
        self,
        mm_items: MultiModalDataItems,
    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
        """
        Split the parsed items into data for the HF processor and data that
        bypasses it.

        Returns:
            A ``(processor_data, passthrough_data)`` pair, each merged across
            all modalities (later modalities win on key collisions).
        """
        processor_data = dict[str, object]()
        passthrough_data = dict[str, object]()

        for per_modality in mm_items.values():
            processor_data.update(per_modality.get_processor_data())
            passthrough_data.update(per_modality.get_passthrough_data())

        return processor_data, passthrough_data

710
711
712
    def _call_hf_processor(
        self,
        prompt: str,
        # Not to be confused with `mm_data` in `self.apply`.
        # This refers to the data to be passed to HF processor.
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor on the prompt text together with its
        associated multi-modal data.

        The processor is obtained via :meth:`BaseProcessingInfo.get_hf_processor`
        with :code:`mm_kwargs`, and invoked through the processing context.
        """
        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        processor_inputs = dict(text=prompt, **mm_data)
        return self.info.ctx.call_hf_processor(
            hf_processor,
            processor_inputs,
            mm_kwargs,
        )

728
    def _apply_hf_processor_text_mm(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the prompt text and multi-modal data
        together.

        Returns:
            The prompt token IDs output by the HF processor, and the
            multi-modal kwargs built from its remaining outputs merged with
            any passthrough data.
        """
        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

        processed_data = self._call_hf_processor(
            prompt=prompt_text,
            mm_data=processor_data,
            mm_kwargs=hf_processor_mm_kwargs,
        )
        processed_data.update(passthrough_data)

        # A single prompt was processed; the unpacking requires exactly one
        # row of `input_ids`.
        (prompt_ids, ) = processed_data.pop("input_ids").tolist()

        fields_config = self._get_mm_fields_config(processed_data,
                                                   hf_processor_mm_kwargs)
        mm_kwargs = MultiModalKwargs.from_hf_inputs(processed_data,
                                                    fields_config)

        return prompt_ids, mm_kwargs

756
    def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
        """
        Apply the HF processor on the prompt text only.

        The HF processor is called with the text alone: no multi-modal items
        (an empty :class:`MultiModalDataItems`) and no extra processor
        kwargs. Only the resulting prompt token IDs are returned; the
        multi-modal outputs are discarded.
        """
        prompt_ids, _ = self._apply_hf_processor_text_mm(
            prompt_text=prompt_text,
            mm_items=MultiModalDataItems({}),
            hf_processor_mm_kwargs={},
        )

        return prompt_ids

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """
        Apply the HF processor on the prompt tokens only.

        Most HF processors take prompt text, not prompt tokens, so the default
        implementation returns :code:`prompt_tokens` as-is (the same list).
        If the HF processor adds or removes tokens unrelated to multi-modal
        data, override this so its output matches what
        :meth:`_apply_hf_processor_text_only` gives for the equivalent text.
        """
        return prompt_tokens

    def _apply_hf_processor_mm_only(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalKwargs:
        """
        Apply the HF processor on the multi-modal data only.

        Since HF processor requires that text and multi-modal items
        correspond to each other, dummy text matching the item counts is
        generated with :attr:`dummy_inputs` (a :class:`BaseDummyInputsBuilder`)
        to go along with the multi-modal data. The token IDs of that dummy
        text are discarded; only the multi-modal kwargs are returned.
        """
        mm_counts = mm_items.get_all_counts()

        dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
            self.info.ctx.model_config.max_model_len,
            mm_counts,
        )

        _, mm_kwargs = self._apply_hf_processor_text_mm(
            prompt_text=dummy_inputs.prompt_text,
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return mm_kwargs

    def _apply_hf_processor_main(
        self,
        prompt: Union[str, list[int]],
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        enable_hf_prompt_replacement: bool,
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the prompt text and multi-modal data.

        If :code:`prompt` is text and :code:`enable_hf_prompt_replacement`
        is true, the HF processor is called once on the text and the
        multi-modal data together. Otherwise, the prompt (text or token IDs)
        and the multi-modal data are processed separately, the latter via
        :meth:`_apply_hf_processor_mm_only`.

        Note:
            If :code:`enable_hf_prompt_replacement=True` and :code:`prompt`
            is text, the prompt should correspond to the multi-modal items.
            Pass :code:`enable_hf_prompt_replacement=False` when it does not,
            e.g. when only a subset of the items needs processing.
        """
        if isinstance(prompt, str):
            if enable_hf_prompt_replacement:
                return self._apply_hf_processor_text_mm(
                    prompt_text=prompt,
                    mm_items=mm_items,
                    hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                )

            prompt_ids = self._apply_hf_processor_text_only(prompt)
        else:
            prompt_ids = self._apply_hf_processor_tokens_only(prompt)

        # The prompt and the multi-modal data are processed independently,
        # so the data need not match the prompt's placeholders.
        mm_missing_kwargs = self._apply_hf_processor_mm_only(
            mm_items=mm_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        )

        return prompt_ids, mm_missing_kwargs

    def _cached_apply_hf_processor(
        self,
        prompt: Union[str, list[int]],
        mm_data_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> tuple[list[int], MultiModalKwargs]:
        """
        Apply the HF processor on the full prompt text,
        caching the results and reusing cached results.

        Only items that miss the cache are sent to the HF processor (separately
        from the prompt). Their outputs are stored in the cache and merged with
        the cached ones in the original item order. The cache is bypassed
        entirely when there is none or when any item carries passthrough data.
        """
        cache = self.cache
        model_id = self.info.model_id

        _, passthrough_data = self._get_hf_mm_data(mm_data_items)
        if cache is None or passthrough_data:
            return self._apply_hf_processor_main(
                prompt=prompt,
                mm_items=mm_data_items,
                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
                enable_hf_prompt_replacement=True,
            )

        # Look up every item; `None` marks a cache miss.
        cached_by_modality: dict[str, list[Optional[MultiModalKwargsItem]]] = {}
        for modality, items in mm_data_items.items():
            cached_by_modality[modality] = [
                cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                for item in items
            ]

        missing_idxs_by_modality = {
            modality: [i for i, kw in enumerate(kw_items) if kw is None]
            for modality, kw_items in cached_by_modality.items()
        }
        missing_data = {
            modality: [mm_data_items[modality][i] for i in idxs]
            for modality, idxs in missing_idxs_by_modality.items()
        }
        mm_missing_data_items = self._to_mm_items(missing_data)

        # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
        # so we need to pass `enable_hf_prompt_replacement=False`
        prompt_ids, mm_missing_kwargs = self._apply_hf_processor_main(
            prompt=prompt,
            mm_items=mm_missing_data_items,
            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
            enable_hf_prompt_replacement=False,
        )

        # Per modality, index of the next freshly processed item to consume.
        next_missing = dict.fromkeys(mm_missing_data_items, 0)

        merged_kw_items = list[MultiModalKwargsItem]()
        for modality, kw_items in cached_by_modality.items():
            for idx, kw_item in enumerate(kw_items):
                if kw_item is None:
                    kw_item = mm_missing_kwargs.get_item(
                        modality,
                        next_missing[modality],
                    )
                    cache.put(
                        model_id,
                        modality,
                        mm_data_items[modality][idx],
                        hf_processor_mm_kwargs,
                        kw_item,
                    )
                    next_missing[modality] += 1

                merged_kw_items.append(kw_item)

        if self.enable_sanity_checks:
            # Every freshly processed item must have been consumed exactly once.
            mm_missing_counts = mm_missing_data_items.get_all_counts()
            assert all(
                item_count == mm_missing_counts[modality]
                for modality, item_count in next_missing.items()), dict(
                    mm_missing_next_idx=next_missing,
                    mm_missing_counts=mm_missing_counts)

        mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
        return prompt_ids, mm_kwargs
935

936
    def _bind_and_group_repls(
        self,
        prompt_repls: list[PromptReplacement],
    ) -> dict[str, list[BoundPromptReplacement]]:
        """
        Bind every unbound prompt replacement to this processor's tokenizer,
        then group the bound replacements with ``full_groupby_modality``.

        Returns:
            A plain ``dict`` mapping each modality key to the list of bound
            replacements for it.
        """
        tokenizer = self.info.get_tokenizer()

        # Lazily bind so grouping consumes the replacements in input order.
        bound_repls = (repl.bind(tokenizer) for repl in prompt_repls)
        grouped_repls = full_groupby_modality(bound_repls)

        return dict(grouped_repls)

    def _always_apply_prompt_replacements(self) -> bool:
        """
        A flag which can be overridden so that
        :meth:`_apply_prompt_replacements` is always called even if we
949
950
        detect that HF has performed processing via
        :meth:`_find_placeholders_by_modality`.
951

952
953
        This is useful in cases where :meth:`_find_placeholders_by_modality`
        cannot be reliably used to detect whether HF has performed processing.
954
955
956
        """
        return False

957
958
959
    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
        """
        Insert placeholder tokens into ``token_ids`` using the bound prompt
        replacements in ``mm_prompt_repls``.

        Matches are first searched for directly in the token IDs. Only if some
        modality has fewer token-level matches than its entry in
        ``mm_item_counts`` do we fall back to decoding the token IDs,
        replacing on the text, and re-encoding the result without special
        tokens.

        Returns:
            A tuple of the updated token IDs, their decoded text, and the
            placeholders that :meth:`_find_mm_placeholders` locates in the
            updated token IDs.
        """
        tokenizer = self.info.get_tokenizer()

        token_matches_by_modality = {
            modality: find_token_matches(token_ids, repls)
            for modality, repls in mm_prompt_repls.items()
        }
        match_count_by_modality = {
            modality: len(found)
            for modality, found in token_matches_by_modality.items()
        }

        # If the search text does not represent a special token,
        # it may have different token IDs in the prompt, because
        # the tokens may go across the boundaries of the search text.
        # ----
        # e.g. when searching for "foo" in "food", if "food" itself makes
        # up a token, then the token ID of "foo" will not appear at all
        # ----
        # Since it is inefficient to search for all possible tokenizations
        # of the search text in the prompt, we instead perform string
        # replacement on the decoded token IDs, then encode them back.
        enough_token_matches = all(
            match_count_by_modality.get(modality, 0) >= count
            for modality, count in mm_item_counts.items())

        if enough_token_matches:
            new_token_ids = replace_token_matches(
                token_ids,
                token_matches_by_modality,
                mm_item_counts,
            )
            new_text = decode_tokens(tokenizer, new_token_ids)
            used_matches = token_matches_by_modality
        else:
            old_text = decode_tokens(tokenizer, token_ids)

            text_matches_by_modality = {
                modality: find_text_matches(old_text, repls)
                for modality, repls in mm_prompt_repls.items()
            }
            new_text = replace_text_matches(
                old_text,
                text_matches_by_modality,
                mm_item_counts,
            )
            new_token_ids = encode_tokens(tokenizer,
                                          new_text,
                                          add_special_tokens=False)
            used_matches = text_matches_by_modality

        # Whichever path ran, remember which replacement produced each match
        # so the placeholders can be located in the new token IDs.
        matched_repls = {
            modality: [match.prompt_repl for match in matches]
            for modality, matches in used_matches.items()
        }

        placeholders = self._find_mm_placeholders(
            matched_repls,
            new_token_ids,
            mm_item_counts,
        )

        return new_token_ids, new_text, placeholders

    def _validate_mm_kwargs(
        self,
        mm_kwargs: MultiModalKwargs,
        mm_item_counts: Mapping[str, int],
    ) -> None:
        """
        Ensure that, for every modality in ``mm_item_counts``, ``mm_kwargs``
        holds exactly as many items as there are input data items.

        A modality absent from ``mm_kwargs.modalities`` counts as zero items.

        Raises:
            RuntimeError: If any modality's item count in ``mm_kwargs``
                differs from the expected count.
        """
        for modality, expected in mm_item_counts.items():
            found = (mm_kwargs.get_items(modality)
                     if modality in mm_kwargs.modalities else [])

            if len(found) == expected:
                continue

            raise RuntimeError(
                f"Expected there to be {expected} {modality} items in "
                f"keyword arguments corresponding to {expected} "
                f"{modality} data items, but only found {len(found)}! "
                "There is likely a problem with your "
                "implementation of merged multi-modal processor for this "
                "model (usually arising from an inconsistency between "
                "`_call_hf_processor` and `_get_mm_fields_config`).")

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderInfo]],
        mm_item_counts: Mapping[str, int],
        *,
        allow_missing: bool = False,
    ) -> Mapping[str, int]:
        """
        Compare the number of placeholders found per modality against the
        number of input items of that modality.

        Args:
            mm_placeholders: Placeholders found, keyed by modality. A missing
                key counts as zero placeholders.
            mm_item_counts: Expected number of items per modality.
            allow_missing: If ``True``, mismatches are tolerated and only
                reported through the return value.

        Returns:
            For each modality in ``mm_item_counts``, the expected count minus
            the number of placeholders found.

        Raises:
            RuntimeError: If a count mismatches and ``allow_missing`` is
                ``False``.
        """
        missing_counts: dict[str, int] = {}

        for modality, expected in mm_item_counts.items():
            found = len(mm_placeholders.get(modality, []))

            if not allow_missing and found != expected:
                raise RuntimeError(
                    f"Expected there to be {expected} prompt replacements "
                    f"corresponding to {expected} {modality} items, but only "
                    f"found {found} prompt replacements! Either "
                    "the prompt text has missing/incorrect tokens for "
                    "multi-modal inputs, or there is a problem with your "
                    "implementation of merged multi-modal processor for this "
                    "model (usually arising from an inconsistency between "
                    "`_call_hf_processor` and `_get_prompt_replacements`).")

            missing_counts[modality] = expected - found

        return missing_counts

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputsV2:
        """
        Process multi-modal inputs to be used in vLLM.

        The main steps are:

        1. Apply HF Processor on prompt text and multi-modal data together,
           outputting token IDs and processed tensors.
        2. Find and replace sequences in the token IDs with placeholder tokens.
           The number of placeholder tokens equals the feature size of the
           multi-modal data outputted by the multi-modal encoder.
        3. Extract information about the placeholder tokens from the
           processed token IDs.

        Args:
            prompt: The prompt, either as text or as token IDs.
            mm_data: The raw multi-modal data, keyed by modality.
            hf_processor_mm_kwargs: Extra keyword arguments used for HF
                processing; they are also folded into the V1 item hashes.

        Returns:
            A ``MultiModalInputsV2`` holding the final prompt text and token
            IDs, the processed multi-modal kwargs, per-item hashes (only when
            ``envs.VLLM_USE_V1`` is set, otherwise ``None``), and the
            placeholder ranges per modality.

        Raises:
            ValueError: If, for some modality, only part of the items already
                have placeholder tokens after HF processing.
            RuntimeError: If the processed kwargs or the final placeholders do
                not match the number of input items.
        """
        mm_items = self._to_mm_items(mm_data)

        # Create MM hashes (only used in V1)
        # TODO: Use these hash keys for caching operations in apply_hf_processor
        # instead of rehashing.
        if envs.VLLM_USE_V1:
            model_id = self.info.model_id
            mm_hashes = {
                modality: [
                    MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: item},
                                                 **hf_processor_mm_kwargs)
                    for item in items
                ]
                for modality, items in mm_items.items()
            }
        else:
            mm_hashes = None

        prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
            prompt,
            mm_items,
            hf_processor_mm_kwargs,
        )

        unbound_prompt_repls = self._get_prompt_replacements(
            mm_items,
            hf_processor_mm_kwargs,
            mm_kwargs,
        )
        mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls)

        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        # Placeholders that the HF processor may already have inserted.
        hf_mm_placeholders = self._find_mm_placeholders(
            mm_prompt_repls,
            prompt_ids,
            mm_item_counts,
        )

        if self._always_apply_prompt_replacements():
            # Disregard HF's output and apply every replacement ourselves.
            mm_missing_repl_counts = mm_item_counts
            mm_missing_repls = dict(mm_prompt_repls)
        else:
            mm_missing_repl_counts = self._validate_mm_placeholders(
                hf_mm_placeholders,
                mm_item_counts,
                allow_missing=True,
            )

            mm_missing_repls = dict[str, list[BoundPromptReplacement]]()
            for modality, missing_repl_count in mm_missing_repl_counts.items():
                if missing_repl_count == 0:
                    # Every item of this modality already has placeholders.
                    mm_missing_repls[modality] = []
                elif missing_repl_count == mm_item_counts.get(modality, 0):
                    # No placeholders were found for this modality, so all
                    # of its replacements still need to be applied.
                    mm_missing_repls[modality] = mm_prompt_repls[modality]
                else:
                    raise ValueError("Partial prompt replacement within "
                                     f"{modality=} is not supported")

        # If HF processor already inserts placeholder tokens,
        # there is no need for us to insert them.
        # NOTE: iterate over `.values()`. Iterating over `.items()` yields
        # `(modality, repls)` pairs, which always have length 2, so this
        # check could never succeed for a non-empty mapping.
        if all(len(repls) == 0 for repls in mm_missing_repls.values()):
            tokenizer = self.info.get_tokenizer()
            prompt = decode_tokens(tokenizer, prompt_ids)
            mm_placeholders = hf_mm_placeholders
        else:
            (
                prompt_ids,
                prompt,
                missing_mm_placeholders,
            ) = self._apply_prompt_replacements(
                prompt_ids,
                mm_missing_repls,
                mm_missing_repl_counts,
            )

            mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders}

        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        mm_placeholder_ranges = {
            modality: [item.to_range() for item in placeholders]
            for modality, placeholders in mm_placeholders.items()
        }

        return MultiModalInputsV2(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )