parse.py 4.24 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Sequence
4
from typing import (TYPE_CHECKING, Literal, NamedTuple, Optional, TypedDict,
5
                    Union, cast)
6
7
8
9
10

from typing_extensions import TypeIs

from vllm.utils import is_list_of

11
12
13
from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs,
                   PromptType, SingletonInputs, SingletonPrompt, TextPrompt,
                   TokensPrompt)
14

15
16
17
if TYPE_CHECKING:
    import torch

18

19
def parse_raw_prompts(
    prompt: Union[str, list[str], list[int], list[list[int]]],
) -> Union[Sequence[TextPrompt], Sequence[TokensPrompt]]:
    """Normalize a raw prompt argument into a list of prompt dicts.

    Four input shapes are recognised:

    1. a single string            -> one ``TextPrompt``
    2. a list of strings          -> one ``TextPrompt`` per string
    3. a list of ints (token ids) -> one ``TokensPrompt`` for the whole list
    4. a list of lists of ints    -> one ``TokensPrompt`` per inner list

    Element types are checked with ``vllm.utils.is_list_of``. For shape 4
    only the first inner list is inspected for emptiness and for holding
    ints; the remaining inner lists are passed through as given.

    Args:
        prompt: The raw prompt in one of the shapes listed above.

    Returns:
        A list of ``TextPrompt`` (shapes 1 and 2) or of ``TokensPrompt``
        (shapes 3 and 4).

    Raises:
        ValueError: If ``prompt`` is an empty list, or a list of lists whose
            first inner list is empty.
        TypeError: If ``prompt`` matches none of the four shapes.
    """
    if isinstance(prompt, str):
        # case 1: a string
        return [TextPrompt(prompt=prompt)]

    if isinstance(prompt, list):
        if not prompt:
            raise ValueError("please provide at least one prompt")

        if is_list_of(prompt, str):
            # case 2: array of strings
            prompt = cast(list[str], prompt)
            return [TextPrompt(prompt=elem) for elem in prompt]

        if is_list_of(prompt, int):
            # case 3: array of tokens
            prompt = cast(list[int], prompt)
            return [TokensPrompt(prompt_token_ids=prompt)]

        if is_list_of(prompt, list):
            prompt = cast(list[list[int]], prompt)
            if not prompt[0]:
                raise ValueError("please provide at least one prompt")

            if is_list_of(prompt[0], int):
                # case 4: array of token arrays
                return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]

    # Reached for non-str/non-list inputs and for lists that matched none of
    # the element-type checks above.
    raise TypeError("prompt must be a string, array of strings, "
                    "array of tokens, or array of token arrays")


class ParsedStrPrompt(TypedDict):
    """Tagged form of a prompt that was supplied as a bare string.

    Built by ``parse_singleton_prompt`` when its argument is a ``str``.
    """
    # Discriminator; always the literal "str" for this variant.
    type: Literal["str"]
    # The prompt string as it was passed in.
    content: str


class ParsedTextPrompt(TypedDict):
    """Tagged form of a dict prompt that is identified by its "prompt" key.

    Built by ``parse_singleton_prompt`` for a dict that has a "prompt" key
    but neither "prompt_embeds" nor "prompt_token_ids".
    """
    # Discriminator; always the literal "text" for this variant.
    type: Literal["text"]
    # The prompt dict as it was passed in.
    content: TextPrompt


class ParsedTokensPrompt(TypedDict):
    """Tagged form of a dict prompt identified by "prompt_token_ids".

    Built by ``parse_singleton_prompt`` for a dict that has a
    "prompt_token_ids" key and no "prompt_embeds" key.
    """
    # Discriminator; always the literal "tokens" for this variant.
    type: Literal["tokens"]
    # The prompt dict as it was passed in.
    content: TokensPrompt


66
class ParsedEmbedsPrompt(TypedDict):
    """Tagged form of a dict prompt identified by its "prompt_embeds" key.

    Built by ``parse_singleton_prompt`` for any dict that has a
    "prompt_embeds" key; that key is tested before the others.
    """
    # Discriminator; always the literal "embeds" for this variant.
    type: Literal["embeds"]
    # The prompt dict as it was passed in.
    content: EmbedsPrompt


71
72
73
74
75
# Every tagged shape that parse_singleton_prompt can return; the "type"
# literal carried by each member tells the variants apart.
ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt,
                              ParsedTokensPrompt, ParsedEmbedsPrompt]


def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
    """Wrap a singleton prompt in a dict tagged with the kind of prompt it is.

    A ``str`` becomes a ``ParsedStrPrompt``. A ``dict`` is classified by the
    first of these keys it contains, tested in this order:

    1. ``"prompt_embeds"``    -> ``ParsedEmbedsPrompt``
    2. ``"prompt_token_ids"`` -> ``ParsedTokensPrompt``
    3. ``"prompt"``           -> ``ParsedTextPrompt``

    Because of that order, a dict holding both ``"prompt_token_ids"`` and
    ``"prompt"`` is reported as a tokens prompt. The original object is
    stored under ``content`` without being copied.

    Args:
        prompt: A string or a prompt dict.

    Returns:
        A dict with a ``type`` tag and the original prompt as ``content``.

    Raises:
        TypeError: If ``prompt`` is neither a string nor a dict, or is a dict
            with none of the three recognised keys.
    """
    if isinstance(prompt, str):
        return ParsedStrPrompt(type="str", content=prompt)

    if isinstance(prompt, dict):
        # Type ignores are because mypy does not correctly infer the TypedDicts
        # Pyright does succeed.
        if "prompt_embeds" in prompt:
            return ParsedEmbedsPrompt(
                type="embeds", content=prompt)  # type: ignore[typeddict-item]
        if "prompt_token_ids" in prompt:
            return ParsedTokensPrompt(
                type="tokens", content=prompt)  # type: ignore[typeddict-item]
        if "prompt" in prompt:
            return ParsedTextPrompt(type="text", content=prompt)

    # Also reached for a dict that carries none of the recognised keys.
    raise TypeError(
        "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt")
91
92
93


def is_explicit_encoder_decoder_prompt(
    prompt: PromptType,
) -> TypeIs[ExplicitEncoderDecoderPrompt]:
    """Tell whether ``prompt`` is an explicit encoder/decoder prompt.

    The test is purely structural: ``prompt`` must be a ``dict`` that has an
    ``"encoder_prompt"`` key. The value stored under that key is not
    inspected. The ``TypeIs`` return annotation lets type checkers narrow
    ``prompt`` in both branches of a caller's ``if``.

    Args:
        prompt: Any supported prompt value.

    Returns:
        ``True`` if ``prompt`` is a dict containing ``"encoder_prompt"``,
        otherwise ``False``.
    """
    return isinstance(prompt, dict) and "encoder_prompt" in prompt
96
97


98
99
100
101
102
103
104
105
106
107
108
def split_enc_dec_inputs(
    inputs: ProcessorInputs,
) -> tuple[Optional[SingletonInputs], SingletonInputs]:
    """Split processed inputs into an ``(encoder, decoder)`` pair.

    If ``inputs`` has both an ``"encoder"`` and a ``"decoder"`` key, the two
    values are returned in that order. Otherwise ``inputs`` is treated as
    decoder-only and returned unchanged in the second slot, with ``None`` in
    the encoder slot.
    """
    if not ("encoder" in inputs and "decoder" in inputs):
        return None, inputs

    # NOTE: This passes pyright but not mypy
    encoder_part = inputs["encoder"]  # type: ignore[typeddict-item]
    decoder_part = inputs["decoder"]  # type: ignore[typeddict-item]
    return encoder_part, decoder_part
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128


class PromptComponents(NamedTuple):
    """Text, token-id and embedding parts pulled out of a prompt.

    Returned by ``get_prompt_components``. Every field defaults to ``None``.
    """
    # Prompt text: the bare string itself, or the dict's "prompt" entry.
    text: Optional[str] = None
    # The dict's "prompt_token_ids" entry, if present.
    token_ids: Optional[list[int]] = None
    # The dict's "prompt_embeds" entry, if present. The annotation is a
    # string because torch is imported only under TYPE_CHECKING.
    embeds: Optional["torch.Tensor"] = None


def get_prompt_components(prompt: PromptType) -> PromptComponents:
    """Extract the text, token ids and embeddings carried by ``prompt``.

    A bare string yields components with only ``text`` set. A dict with a
    truthy ``"encoder_prompt"`` entry is resolved by recursing into that
    entry. Any other dict is read through its ``"prompt"``,
    ``"prompt_token_ids"`` and ``"prompt_embeds"`` keys, and a missing key
    gives ``None`` for that component.
    """
    if isinstance(prompt, str):
        return PromptComponents(text=prompt)

    # A falsy encoder prompt (missing, None, empty) falls through to the
    # plain key lookup below instead of recursing.
    encoder_prompt = prompt.get("encoder_prompt")
    if encoder_prompt:
        return get_prompt_components(encoder_prompt)  # type: ignore[arg-type]

    text = prompt.get("prompt")
    token_ids = prompt.get("prompt_token_ids")
    embeds = prompt.get("prompt_embeds")
    return PromptComponents(
        text=text,  # type: ignore[arg-type]
        token_ids=token_ids,  # type: ignore[arg-type]
        embeds=embeds,
    )