preprocess.py 24 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Mapping
5
from typing import Any, Optional, Union, cast
6
7
8
9
10

from typing_extensions import assert_never

from vllm.config import ModelConfig
from vllm.logger import init_logger
11
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
12
from vllm.multimodal.cache import BaseMultiModalProcessorCache
13
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
14
                                    MultiModalInputs, MultiModalUUIDDict)
15
from vllm.multimodal.processing import BaseMultiModalProcessor
16
from vllm.transformers_utils.tokenizer import AnyTokenizer
17

18
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
19
20
21
22
                   EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
                   ProcessorInputs, PromptType, SingletonInputs,
                   SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
                   embeds_inputs, token_inputs)
23
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
24
25
26
27
28
29
30
31
32

# Module-level logger; used by the tokenizer and decoder-start-token
# accessors below to warn when they fall back to None / <BOS>.
logger = init_logger(__name__)


class InputPreprocessor:

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: Optional[AnyTokenizer],
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
    ) -> None:
        """Store the collaborators used to turn prompts into model inputs.

        Args:
            model_config: Model configuration consulted by every
                preprocessing step.
            tokenizer: Tokenizer for text prompts, or ``None`` when no
                tokenizer was initialized.
            mm_registry: Registry used to create the multi-modal processor.
            mm_processor_cache: Optional cache handed to the multi-modal
                processor when it is created.
        """
        super().__init__()

        # Multi-modal collaborators.
        self.mm_registry = mm_registry
        self.mm_processor_cache = mm_processor_cache

        # Model configuration and text tokenizer.
        self.model_config = model_config
        self.tokenizer = tokenizer

    def get_tokenizer(self) -> AnyTokenizer:
45
46
47
48
49
50
        if self.tokenizer is None:
            raise ValueError("You cannot pass text prompts when "
                             "`skip_tokenizer_init` is True")

        return self.tokenizer

51
    def get_bos_token_id(self) -> Optional[int]:
52
53
54
55
56
        if self.tokenizer is None:
            logger.warning("Using None for BOS token id because tokenizer "
                           "is not initialized")
            return None

57
        return self.tokenizer.bos_token_id
58

59
    def get_eos_token_id(self) -> Optional[int]:
60
61
62
63
64
        if self.tokenizer is None:
            logger.warning("Using None for EOS token id because tokenizer "
                           "is not initialized")
            return None

65
        return self.tokenizer.eos_token_id
66
67

    def get_decoder_start_token_id(self) -> Optional[int]:
68
        """
69
70
71
        Obtain the decoder start token id employed by an encoder/decoder
        model. Returns None for non-encoder/decoder models or if the
        model config is unavailable.
72
        """
73

74
        if not self.model_config.is_encoder_decoder:
75
76
77
            logger.warning_once(
                "Using None for decoder start token id because "
                "this is not an encoder/decoder model.")
78
79
            return None

80
        if self.model_config is None or self.model_config.hf_config is None:
81
82
83
            logger.warning_once(
                "Using None for decoder start token id because "
                "model config is not available.")
84
85
86
            return None

        dec_start_token_id = getattr(self.model_config.hf_config,
87
                                     "decoder_start_token_id", None)
88
        if dec_start_token_id is None:
89
90
91
92
            logger.warning_once(
                "Falling back on <BOS> for decoder start token "
                "id because decoder start token id is not "
                "available.")
93
94
95
96
            dec_start_token_id = self.get_bos_token_id()

        return dec_start_token_id

97
    def _get_default_enc_dec_decoder_prompt(self) -> list[int]:
98
        """
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
        Specifically for encoder/decoder models:
        generate a default decoder prompt for when
        the user specifies only the encoder prompt.

        Encoder/decoder models utilize the decoder
        prompt in different ways; as new models are
        added, it is intended that this function
        will be extended to produce differing
        default decoder prompts, depending on the
        model variety.

        Absent a special case, the default behavior
        of this method is to mirror the behavior of
        the HuggingFace (HF) GenerationMixin for a None
        decoder prompt, which is to employ a logit processor
        setting to force the first decoded token to be <BOS>.
        Here, this behavior is approximated by having the
        "default" decoder prompt be <BOS>.

        However, it is possible that in the future
119
        other models may have different or more
120
121
122
123
124
125
126
        complex logic for the default decoder prompt.
        This motivates having a special helper method
        for default decoder prompts.

        Returns:

        * prompt_token_ids
127
        """
128
129
130
131
132
133
134

        bos_token_id = self.get_bos_token_id()
        assert bos_token_id is not None
        return [bos_token_id]

    def _prepare_decoder_input_ids_for_generation(
        self,
135
136
        decoder_input_ids: Optional[list[int]],
    ) -> list[int]:
137
138
139
        """
        Prepares `decoder_input_ids` for generation with encoder-decoder models.

140
141
142
143
        Based on:
        https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
        specifically,
        `GenerationMixin._prepare_decoder_input_ids_for_generation()`.
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161

        Arguments:

        * decoder_input_ids: input token ids to preprocess

        Returns:

        * Processed token list
        """

        decoder_start_token_id = self.get_decoder_start_token_id()
        assert decoder_start_token_id is not None

        if decoder_input_ids is None:
            # no decoder prompt input ->
            # use decoder_start_token_id as decoder_input_ids
            decoder_input_ids = self._get_default_enc_dec_decoder_prompt()

162
163
        if (len(decoder_input_ids) == 0
                or decoder_input_ids[0] != decoder_start_token_id):
164
165
166
167
            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

        return decoder_input_ids

168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
    def _get_tokenization_kw(
        self,
        overrides: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        kwargs = dict[str, Any]()

        if self.model_config.hf_config.model_type == "whisper":
            # For Whisper, special tokens should be provided by the user based
            # on the task and language of their request. Also needed to avoid
            # appending an EOS token to the prompt which disrupts generation.
            kwargs["add_special_tokens"] = False

        if overrides:
            kwargs.update(overrides)

        return kwargs

185
186
187
    def _tokenize_prompt(
        self,
        prompt: str,
188
        tokenization_kwargs: Optional[dict[str, Any]] = None,
189
    ) -> list[int]:
190
191
192
193
        """
        Apply the model's tokenizer to a text prompt, returning the
        corresponding token IDs.
        """
194
        tokenizer = self.get_tokenizer()
195
        tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
196

197
        encoder_config = self.model_config.encoder_config
198

199
        if encoder_config and encoder_config.get("do_lower_case", False):
200
201
            prompt = prompt.lower()

202
        return tokenizer.encode(prompt, **tokenization_kwargs)
203

204
    def _get_mm_tokenizer(self) -> AnyTokenizer:
205
206
207
208
209
        # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
        # while using also multi-modal input
        if not self.tokenizer:
            return cast(AnyTokenizer, object())  # Dummy

210
211
        tokenizer = self.get_tokenizer()
        return tokenizer
212

213
214
215
    def _get_mm_processor(self) -> BaseMultiModalProcessor:
        if not hasattr(self, "_mm_processor"):
            tokenizer = self._get_mm_tokenizer()
216

217
218
219
220
221
222
223
            self._mm_processor = self.mm_registry.create_processor(
                self.model_config,
                tokenizer=tokenizer,
                cache=self.mm_processor_cache,
            )

        return self._mm_processor
224

225
226
    def _process_multimodal(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Mapping[str, object]],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalInputs:
        """
        Run the model's multi-modal processor over a prompt (text or token
        IDs) together with its multi-modal data. The result holds the token
        IDs and the multi-modal metadata.

        Raises:
            ValueError: If the processor's ``mm_hashes`` contain anything
                other than strings at their leaves.
        """
        processor = self._get_mm_processor()
        hf_mm_kwargs = ({} if mm_processor_kwargs is None else
                        mm_processor_kwargs)

        mm_input = processor.apply(
            prompt,
            mm_data,
            hf_processor_mm_kwargs=hf_mm_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        # Every multi-modal item must be identified by a string hash; any
        # other value points at a faulty custom processor.
        mm_hashes = mm_input["mm_hashes"]
        if contains_only_strings(mm_hashes):
            return mm_input

        raise ValueError(
            f"mm_hashes must contain only strings, got: {mm_hashes}. "
            "This is likely due to an incorrect custom implementation of "
            "MultiModalProcessor.apply method.")

    def _process_embeds(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        """Validate a prompt-embeddings input and wrap it as `EmbedsInputs`.

        Raises:
            ValueError: If prompt embeddings are not enabled, or the
                embeddings are not 2-D (a 3-D tensor first has dim 0
                squeezed).
        """
        if not self.model_config.enable_prompt_embeds:
            raise ValueError("You must set `--enable-prompt-embeds` to input "
                             "`prompt_embeds`.")

        embeds = parsed_content["prompt_embeds"]

        # The expected layout is (seq_len, hidden_size). A batch of one,
        # i.e. (1, seq_len, hidden_size), is unambiguous, so the leading
        # dimension is squeezed away.
        if embeds.ndim == 3:
            embeds = embeds.squeeze(dim=0)
        if embeds.ndim != 2:
            raise ValueError(
                "prompt_embeds must be of shape (seq_len, hidden_size).")

        # Moving to CPU here lets the MsgpackEncoder serialize the tensor
        # between processes without a hidden device transfer on the
        # critical path of generation.
        embeds = embeds.cpu()
        return embeds_inputs(prompt_embeds=embeds,
                             cache_salt=parsed_content.get("cache_salt"))

    def _truncate_inputs(
            self,
            inputs: list[int],
            tokenization_kwargs: Optional[dict[str, Any]] = None) -> list[int]:

        if not tokenization_kwargs or "truncation" not in \
                tokenization_kwargs or self.tokenizer is None:
            return inputs

        max_length = tokenization_kwargs["max_length"]

        if self.tokenizer.truncation_side == "left":
            return inputs[-max_length:]
        else:
            return inputs[:max_length]

306
307
308
    def _process_tokens(
        self,
        parsed_content: TokensPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """Turn a token-ID prompt into model inputs.

        The token IDs are first truncated according to
        ``tokenization_kwargs``. Multi-modal models then route them through
        the multi-modal processor. Other models reject ``multi_modal_data``
        and wrap the IDs directly. A non-empty ``cache_salt`` is copied onto
        the result.
        """
        token_ids = self._truncate_inputs(parsed_content["prompt_token_ids"],
                                          tokenization_kwargs)

        inputs: Union[TokenInputs, MultiModalInputs]
        if not self.model_config.is_multimodal_model:
            if parsed_content.get("multi_modal_data"):
                raise ValueError(
                    "This model does not support multimodal inputs")
            inputs = token_inputs(token_ids)
        else:
            inputs = self._process_multimodal(
                token_ids,
                parsed_content.get("multi_modal_data", {}),
                parsed_content.get("mm_processor_kwargs"),
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        cache_salt = parsed_content.get("cache_salt")
        if cache_salt:
            inputs["cache_salt"] = cache_salt

        return inputs

    def _process_text(
        self,
        parsed_content: TextPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> Union[TokenInputs, MultiModalInputs]:
        """Turn a text prompt into model inputs.

        Multi-modal models pass the raw text, plus any multi-modal data, to
        the multi-modal processor. Other models reject ``multi_modal_data``
        and tokenize the text themselves. A non-empty ``cache_salt`` is
        copied onto the result.
        """
        text = parsed_content["prompt"]

        inputs: Union[TokenInputs, MultiModalInputs]
        if not self.model_config.is_multimodal_model:
            if parsed_content.get("multi_modal_data"):
                raise ValueError(
                    "This model does not support multimodal inputs")
            token_ids = self._tokenize_prompt(
                text,
                tokenization_kwargs=tokenization_kwargs,
            )
            inputs = token_inputs(token_ids)
        else:
            inputs = self._process_multimodal(
                text,
                parsed_content.get("multi_modal_data", {}),
                parsed_content.get("mm_processor_kwargs"),
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        cache_salt = parsed_content.get("cache_salt")
        if cache_salt:
            inputs["cache_salt"] = cache_salt

        return inputs

    def _prompt_to_llm_inputs(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> SingletonInputs:
        """
        Extract the singleton inputs from a prompt.

        Arguments:

        * prompt: single encoder or decoder input prompt
        * tokenization_kwargs: tokenizer options. For text prompts they are
          used when tokenizing; for token prompts they drive the truncation
          of the given IDs. They are not used for embedding prompts.
        * mm_uuids: multi-modal item UUIDs, forwarded to text/token
          processing

        Returns:

        * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
        """
        parsed = parse_singleton_prompt(prompt)

        if parsed["type"] == "embeds":
            return self._process_embeds(parsed["content"])
        if parsed["type"] == "tokens":
            # Forward tokenization_kwargs so that `_process_tokens` can apply
            # the requested truncation to the token IDs. Without them,
            # `_truncate_inputs` is always a no-op for token prompts.
            return self._process_tokens(
                parsed["content"],
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
        if parsed["type"] == "text":
            return self._process_text(
                parsed["content"],
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
        if parsed["type"] == "str":
            return self._process_text(
                TextPrompt(prompt=parsed["content"]),
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        assert_never(parsed)

    def _build_enc_dec_llm_inputs(
        self,
        encoder_inputs: SingletonInputs,
        decoder_inputs: Optional[SingletonInputs],
    ) -> EncoderDecoderInputs:
        """Combine encoder and optional decoder inputs for an enc/dec model.

        When no decoder inputs are given, the decoder prompt is built here:

        * Whisper reuses a copy of the encoder prompt's token IDs.
        * Other models use the default decoder prompt.

        When decoder inputs are given, their token IDs are rewritten in
        place so that they start with the decoder start token.

        Raises:
            ValueError: If either side is an embeddings input, or the decoder
                inputs carry multi-modal data.
        """
        uses_embeds = encoder_inputs["type"] == "embeds" or bool(
            decoder_inputs and decoder_inputs["type"] == "embeds")
        if uses_embeds:
            raise ValueError("Embedding inputs are not supported for encoder-"
                             "decoder models")

        # Embeddings were ruled out above; narrow the types for mypy.
        encoder_inputs = cast(Union[TokenInputs, MultiModalInputs],
                              encoder_inputs)
        decoder_inputs = cast(Optional[Union[TokenInputs, MultiModalInputs]],
                              decoder_inputs)

        if decoder_inputs is not None:
            if "multi_modal_data" in decoder_inputs:
                raise ValueError("Multi-modal decoder inputs of encoder-"
                                 "decoder models are not supported yet")
            decoder_inputs["prompt_token_ids"] = (
                self._prepare_decoder_input_ids_for_generation(
                    decoder_inputs["prompt_token_ids"]))
        elif self.model_config.hf_config.model_type == "whisper":
            # For Whisper the text prompt belongs to the decoder, so it is
            # copied over from the encoder. The encoder tokens are later
            # overridden by the audio features.
            decoder_inputs = token_inputs(
                encoder_inputs["prompt_token_ids"].copy())
        else:
            decoder_inputs = token_inputs(
                self._prepare_decoder_input_ids_for_generation(None))

        return EncoderDecoderInputs(
            encoder=encoder_inputs,
            decoder=decoder_inputs,
        )

    def _split_enc_dec_mm_inputs(
        self,
        inputs: Union[SingletonInputs, MultiModalEncDecInputs],
        decoder_inputs_to_override: Optional[SingletonInputs] = None,
    ) -> tuple[SingletonInputs, SingletonInputs]:
        """
        For encoder/decoder models only:
        Separate Encoder/Decoder inputs from a MultiModalEncDecInputs.

        Text-only inputs produce:

        * an empty encoder prompt
        * the override (or the inputs themselves) as the decoder inputs

        Multi-modal inputs produce:

        * the processor's ``encoder_prompt_token_ids`` on the encoder side
        * the multi-modal kwargs, hashes and placeholders on the decoder
          side; its token IDs come from the override when given, and a
          non-empty ``cache_salt`` is carried over

        Raises:
            ValueError: If either side is an embeddings input.
            RuntimeError: If multi-modal inputs lack
                ``encoder_prompt_token_ids``.
        """
        override = decoder_inputs_to_override
        if inputs["type"] == "embeds" or (override
                                          and override["type"] == "embeds"):
            raise ValueError("Embedding inputs are not supported for encoder-"
                             "decoder models")

        # Embeddings were ruled out above; narrow the types for mypy.
        inputs = cast(
            Union[TokenInputs, MultiModalInputs, MultiModalEncDecInputs],
            inputs,
        )
        override = cast(
            Optional[Union[TokenInputs, MultiModalInputs]],
            override,
        )

        encoder_inputs: SingletonInputs
        decoder_inputs: SingletonInputs

        if inputs["type"] == "token":  # Text-only inputs
            encoder_inputs = token_inputs(prompt_token_ids=[])
            decoder_inputs = override or inputs
        elif inputs["type"] == "multimodal":  # Multimodal data inputs
            if "encoder_prompt_token_ids" not in inputs:
                raise RuntimeError("You should register an encoder-decoder "
                                   "multi-modal processor for encoder-decoder "
                                   "models.")
            inputs = cast(MultiModalEncDecInputs, inputs)

            encoder_inputs = token_inputs(inputs["encoder_prompt_token_ids"])

            token_source = override or inputs
            decoder_inputs = MultiModalInputs(
                type="multimodal",
                prompt_token_ids=token_source["prompt_token_ids"],
                mm_kwargs=inputs["mm_kwargs"],
                mm_hashes=inputs["mm_hashes"],
                mm_placeholders=inputs["mm_placeholders"],
            )
            cache_salt = inputs.get("cache_salt")
            if cache_salt:
                decoder_inputs["cache_salt"] = cache_salt
        else:
            assert_never(inputs)  # type: ignore[arg-type]

        return encoder_inputs, decoder_inputs

    def _process_encoder_decoder_prompt(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> EncoderDecoderInputs:
        """
        For encoder/decoder models only:
        Process an input prompt into an
        [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
        instance.

        Two prompt shapes are accepted:

        * Singleton prompt: only the encoder prompt is given. The decoder
          prompt is either split out by the multi-modal processor
          (multi-modal models) or inferred when the final inputs are built.
        * Explicit encoder/decoder prompt: both sub-prompts are given, and
          each may be any singleton prompt type. For multi-modal models the
          explicit decoder prompt overrides the decoder prompt produced by
          the processor.

        ``tokenization_kwargs`` and ``mm_uuids`` are applied to the encoder
        (or singleton) prompt only.

        Arguments:

        * prompt: an input prompt

        Returns:

        * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
          instance
        """
        encoder_inputs: SingletonInputs
        decoder_inputs: Optional[SingletonInputs]

        if not is_explicit_encoder_decoder_prompt(prompt):
            # `cast` is needed for mypy, but not pyright
            inputs = self._prompt_to_llm_inputs(
                cast(SingletonPrompt, prompt),
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
            if self.model_config.is_multimodal_model:
                # Encoder-Decoder Multimodal model
                encoder_inputs, decoder_inputs = (
                    self._split_enc_dec_mm_inputs(inputs))
            else:
                encoder_inputs, decoder_inputs = inputs, None
        else:
            # `cast` is needed for mypy, but not pyright
            explicit = cast(ExplicitEncoderDecoderPrompt, prompt)
            encoder_inputs = self._prompt_to_llm_inputs(
                explicit["encoder_prompt"],
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
            decoder_prompt = explicit["decoder_prompt"]
            decoder_inputs = (None if decoder_prompt is None else
                              self._prompt_to_llm_inputs(decoder_prompt))

            if self.model_config.is_multimodal_model:
                # Replace the processor's decoder prompt with the explicit
                # decoder prompt.
                encoder_inputs, decoder_inputs = (
                    self._split_enc_dec_mm_inputs(encoder_inputs,
                                                  decoder_inputs))

        return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

    def _build_decoder_only_llm_inputs(
        self,
589
        prompt_inputs: DecoderOnlyInputs,
590
    ) -> DecoderOnlyInputs:
591
592
593
        if "prompt_token_ids" in prompt_inputs:
            prompt_inputs = cast(Union[TokenInputs, MultiModalInputs],
                                 prompt_inputs)  # Needed for mypy
594

595
        return prompt_inputs
596
597
598

    def _process_decoder_only_prompt(
        self,
        prompt: SingletonPrompt,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> DecoderOnlyInputs:
        """
        For decoder-only models:
        Process an input prompt into a
        [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance.

        Arguments:

        * prompt: input prompt
        * tokenization_kwargs: tokenizer options forwarded to prompt
          processing
        * mm_uuids: multi-modal item UUIDs forwarded to prompt processing

        Returns:

        * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance
        """
        singleton_inputs = self._prompt_to_llm_inputs(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )
        return self._build_decoder_only_llm_inputs(singleton_inputs)

    def preprocess(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> ProcessorInputs:
        """Preprocess the input prompt.

        Encoder/decoder models map the prompt onto separate encoder and
        decoder inputs. Decoder-only models accept singleton prompts only.

        Raises:
            ValueError: If an explicit encoder/decoder prompt is passed to a
                decoder-only model.
        """
        if not self.model_config.is_encoder_decoder:
            if is_explicit_encoder_decoder_prompt(prompt):
                raise ValueError("Cannot pass encoder-decoder prompt "
                                 "to decoder-only models")

            # Decoder-only operation.
            # `cast` is needed for mypy, but not pyright
            return self._process_decoder_only_prompt(
                cast(SingletonPrompt, prompt),
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        # Encoder-decoder models need the prompt split between the encoder
        # and the decoder.
        return self._process_encoder_decoder_prompt(
            prompt,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

    def clear_cache(self) -> None:
        if self.mm_processor_cache is not None:
            self.mm_processor_cache.clear_cache()
658
659
660
661
662
663
664
665
666
667
668
669


def contains_only_strings(obj: object):
    """Report whether ``obj`` holds only strings at its leaves.

    Accepted shapes:

    * a ``str``
    * a ``list`` whose elements are all ``str`` (lists are not searched
      recursively)
    * a ``dict`` whose values all pass this same check (dicts may nest)

    Anything else, including tuples and ``None``, is rejected.
    """
    if isinstance(obj, dict):
        return all(contains_only_strings(value) for value in obj.values())
    if isinstance(obj, list):
        return all(isinstance(item, str) for item in obj)
    return isinstance(obj, str)