parse.py 3.24 KB
Newer Older
1
from typing import List, Literal, Sequence, TypedDict, Union, cast, overload
2
3
4
5
6

from typing_extensions import TypeIs

from vllm.utils import is_list_of

7
8
9
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
                   ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt,
                   TextPrompt, TokensPrompt)
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46


class ParsedText(TypedDict):
    """One text prompt produced by ``parse_and_batch_prompt``."""

    content: str  # the prompt text, exactly as supplied by the caller
    is_tokens: Literal[False]  # tag: always False, separates it from tokens


class ParsedTokens(TypedDict):
    """One token-id prompt produced by ``parse_and_batch_prompt``."""

    content: List[int]  # the prompt as a list of integer token ids
    is_tokens: Literal[True]  # tag: always True, separates it from text


@overload
def parse_and_batch_prompt(
        prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
    ...


@overload
def parse_and_batch_prompt(
        prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
    ...


def parse_and_batch_prompt(
    prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
    """Normalize a prompt (or batch of prompts) into a list of parsed prompts.

    Four input shapes are accepted:

    1. ``str`` -> a one-element list holding a :class:`ParsedText`.
    2. ``List[str]`` -> one :class:`ParsedText` per string.
    3. ``List[int]`` -> a one-element list holding a :class:`ParsedTokens`
       whose ``content`` is the whole token list.
    4. ``List[List[int]]`` -> one :class:`ParsedTokens` per inner list.

    List element types are detected with ``is_list_of``. The returned
    entries reference the caller's string/list objects; nothing is copied.

    Args:
        prompt: The prompt, or batch of prompts, to parse.

    Returns:
        A list of ``ParsedText`` (cases 1-2) or ``ParsedTokens`` (cases 3-4).

    Raises:
        ValueError: If ``prompt`` is an empty list, or a list of lists whose
            first inner list is empty.
        TypeError: If ``prompt`` is not recognized as any of the shapes above.
    """
    if isinstance(prompt, str):
        # case 1: a string
        return [ParsedText(content=prompt, is_tokens=False)]

    if isinstance(prompt, list):
        if len(prompt) == 0:
            raise ValueError("please provide at least one prompt")

        if is_list_of(prompt, str):
            # case 2: array of strings
            prompt = cast(List[str], prompt)
            return [
                ParsedText(content=elem, is_tokens=False) for elem in prompt
            ]
        if is_list_of(prompt, int):
            # case 3: array of tokens
            prompt = cast(List[int], prompt)
            return [ParsedTokens(content=prompt, is_tokens=True)]
        if is_list_of(prompt, list):
            prompt = cast(List[List[int]], prompt)
            # NOTE(review): only the first inner list is checked for
            # emptiness and for holding ints; later inner lists are passed
            # through unvalidated.
            if len(prompt[0]) == 0:
                raise ValueError("please provide at least one prompt")

            if is_list_of(prompt[0], int):
                # case 4: array of token arrays
                return [
                    ParsedTokens(content=elem, is_tokens=True)
                    for elem in prompt
                ]

    # Not a str, or a list shape none of the checks above accepted.
    raise TypeError("prompt must be a string, array of strings, "
                    "array of tokens, or array of token arrays")


class ParsedStrPrompt(TypedDict):
    """``parse_singleton_prompt`` result for a plain-string prompt."""

    type: Literal["str"]  # union tag identifying this variant
    content: str  # the original string prompt


class ParsedTextPrompt(TypedDict):
    """``parse_singleton_prompt`` result for a dict prompt with ``"prompt"``."""

    type: Literal["text"]  # union tag identifying this variant
    content: TextPrompt  # the original dict prompt, unchanged


class ParsedTokensPrompt(TypedDict):
    """``parse_singleton_prompt`` result for a dict with ``"prompt_token_ids"``."""

    type: Literal["tokens"]  # union tag identifying this variant
    content: TokensPrompt  # the original dict prompt, unchanged


def parse_singleton_prompt(
    prompt: SingletonPrompt,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
    """Classify a single (non-batched) prompt by its form.

    Args:
        prompt: A plain string, or a dict-based prompt such as a
            ``TextPrompt`` or ``TokensPrompt``.

    Returns:
        A tagged dict whose ``type`` is ``"str"``, ``"text"`` or ``"tokens"``
        and whose ``content`` is the original ``prompt`` object.

    Raises:
        TypeError: If ``prompt`` is neither a string nor a dict containing a
            ``"prompt_token_ids"`` or ``"prompt"`` key.
    """
    if isinstance(prompt, str):
        return ParsedStrPrompt(type="str", content=prompt)
    elif isinstance(prompt, dict):
        # "prompt_token_ids" is checked first, so a dict carrying both keys
        # is classified as a tokens prompt.
        if "prompt_token_ids" in prompt:
            return ParsedTokensPrompt(type="tokens",
                                      content=prompt)  # type: ignore
        elif "prompt" in prompt:
            return ParsedTextPrompt(type="text", content=prompt)

    raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
99
100
101


def is_explicit_encoder_decoder_prompt(
        prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
    """Return whether ``prompt`` is an explicit encoder/decoder prompt.

    A prompt qualifies when it is a dict containing an ``"encoder_prompt"``
    key. The ``TypeIs`` return type lets type checkers narrow ``prompt`` to
    ``ExplicitEncoderDecoderPrompt`` where this returns ``True``.
    """
    return isinstance(prompt, dict) and "encoder_prompt" in prompt
104
105


106
107
108
def is_encoder_decoder_inputs(
    inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
) -> TypeIs[EncoderDecoderInputs]:
    """Return whether ``inputs`` are encoder/decoder inputs.

    The check is a membership test for the ``"encoder_prompt_token_ids"``
    key. The ``TypeIs`` return type lets type checkers narrow ``inputs`` to
    ``EncoderDecoderInputs`` where this returns ``True``.
    """
    return "encoder_prompt_token_ids" in inputs