parse.py 3.1 KB
Newer Older
1
2
3
4
5
6
from typing import List, Literal, Sequence, TypedDict, Union, overload

from typing_extensions import TypeIs

from vllm.utils import is_list_of

7
8
9
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
                   ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt,
                   TextPrompt, TokensPrompt)
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63


class ParsedText(TypedDict):
    """A single text prompt, tagged so it can be told apart from tokens."""
    # The prompt text.
    content: str
    # Discriminator: always False for text entries.
    is_tokens: Literal[False]


class ParsedTokens(TypedDict):
    """A single token-list prompt, tagged so it can be told apart from text."""
    # The prompt as a list of integers (token ids, going by the field names
    # used elsewhere in this module).
    content: List[int]
    # Discriminator: always True for token entries.
    is_tokens: Literal[True]


# Typing overload: a string or a list of strings maps to text entries.
@overload
def parse_and_batch_prompt(
        prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
    ...


# Typing overload: a list of ints or a list of int lists maps to token
# entries.
@overload
def parse_and_batch_prompt(
        prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
    ...


def parse_and_batch_prompt(
    prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
    """Normalize a prompt argument into a batch of tagged prompt entries.

    Four input shapes are recognized:

    1. a string: one ``ParsedText`` entry;
    2. a list of strings: one ``ParsedText`` entry per string;
    3. a list of ints: one ``ParsedTokens`` entry holding the whole list;
    4. a list of lists of ints: one ``ParsedTokens`` entry per inner list.

    List element types are decided by ``is_list_of``. For shape 4 only the
    first inner list is checked for emptiness and element type; later inner
    lists are passed through as given.

    Raises:
        ValueError: if ``prompt`` is an empty list, or a list of lists whose
            first inner list is empty.
        TypeError: if ``prompt`` matches none of the shapes above.
    """
    if isinstance(prompt, str):
        # case 1: a string
        return [ParsedText(content=prompt, is_tokens=False)]

    if isinstance(prompt, list):
        if not prompt:
            raise ValueError("please provide at least one prompt")

        if is_list_of(prompt, str):
            # case 2: array of strings
            return [
                ParsedText(content=elem, is_tokens=False) for elem in prompt
            ]
        if is_list_of(prompt, int):
            # case 3: array of tokens
            return [ParsedTokens(content=prompt, is_tokens=True)]
        if is_list_of(prompt, list):
            if len(prompt[0]) == 0:
                raise ValueError("please provide at least one prompt")

            if is_list_of(prompt[0], int):
                # case 4: array of token arrays
                return [
                    ParsedTokens(content=elem, is_tokens=True)
                    for elem in prompt
                ]

    # Anything that did not match one of the four shapes lands here,
    # including lists of lists whose first inner list is not ints.
    raise TypeError("prompt must be a string, array of strings, "
                    "array of tokens, or array of token arrays")


class ParsedStrPrompt(TypedDict):
    """Tagged wrapper for a prompt that was given as a plain string."""
    # Discriminator: always "str" for this variant.
    type: Literal["str"]
    # The prompt string itself.
    content: str


class ParsedTextPrompt(TypedDict):
    """Tagged wrapper for a prompt that was given as a ``TextPrompt``."""
    # Discriminator: always "text" for this variant.
    type: Literal["text"]
    # The original TextPrompt object.
    content: TextPrompt


class ParsedTokensPrompt(TypedDict):
    """Tagged wrapper for a prompt that was given as a ``TokensPrompt``."""
    # Discriminator: always "tokens" for this variant.
    type: Literal["tokens"]
    # The original TokensPrompt object.
    content: TokensPrompt


def parse_singleton_prompt(
    prompt: SingletonPrompt,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
    """Wrap a single prompt in a dict tagged with the kind of prompt it is.

    A string becomes a ``ParsedStrPrompt``. A dict is classified by its keys:
    one containing ``"prompt_token_ids"`` becomes a ``ParsedTokensPrompt``,
    otherwise one containing ``"prompt"`` becomes a ``ParsedTextPrompt``.
    The token key is tested first, so a dict carrying both keys is treated
    as tokens.

    Raises:
        TypeError: if ``prompt`` is neither a string nor a dict with one of
            the two recognized keys.
    """
    if isinstance(prompt, str):
        return ParsedStrPrompt(type="str", content=prompt)
    elif isinstance(prompt, dict):
        if "prompt_token_ids" in prompt:
            return ParsedTokensPrompt(type="tokens",
                                      content=prompt)  # type: ignore
        elif "prompt" in prompt:
            return ParsedTextPrompt(type="text", content=prompt)

    raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
96
97
98


def is_explicit_encoder_decoder_prompt(
        prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
    """Return True if ``prompt`` is a dict with an ``"encoder_prompt"`` key.

    The ``TypeIs`` return annotation lets static type checkers narrow
    ``prompt`` to ``ExplicitEncoderDecoderPrompt`` when this returns True.
    """
    return isinstance(prompt, dict) and "encoder_prompt" in prompt
101
102


103
104
105
def is_encoder_decoder_inputs(
    inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
) -> TypeIs[EncoderDecoderInputs]:
    """Return True if ``inputs`` contains ``"encoder_prompt_token_ids"``.

    The ``TypeIs`` return annotation lets static type checkers narrow
    ``inputs`` to ``EncoderDecoderInputs`` when this returns True.
    """
    return "encoder_prompt_token_ids" in inputs