parse.py 3.12 KB
Newer Older
1
2
3
4
5
6
7
from typing import List, Literal, Sequence, TypedDict, Union, overload

from typing_extensions import TypeIs

from vllm.utils import is_list_of

from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
8
9
                   LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
                   TokensPrompt)
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63


class ParsedText(TypedDict):
    """One text prompt as emitted by ``parse_and_batch_prompt``."""

    # The prompt string itself.
    content: str
    # Discriminator tag; always ``False`` for text prompts.
    is_tokens: Literal[False]


class ParsedTokens(TypedDict):
    """One token-array prompt as emitted by ``parse_and_batch_prompt``."""

    # The prompt as a list of integer tokens.
    content: List[int]
    # Discriminator tag; always ``True`` for token prompts.
    is_tokens: Literal[True]


@overload
def parse_and_batch_prompt(
    prompt: Union[str, List[str]],
) -> Sequence[ParsedText]:
    """Typing overload: a string or list of strings yields text prompts."""
    ...


@overload
def parse_and_batch_prompt(
    prompt: Union[List[int], List[List[int]]],
) -> Sequence[ParsedTokens]:
    """Typing overload: a token list or list of token lists yields token
    prompts."""
    ...


def parse_and_batch_prompt(
    prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
    """Normalize a prompt argument into a batch of tagged prompts.

    Accepted forms:

    1. a single string -> a one-element list of ``ParsedText``;
    2. a list of strings -> one ``ParsedText`` per string;
    3. a list of ints (tokens) -> a one-element list of ``ParsedTokens``;
    4. a list of int lists -> one ``ParsedTokens`` per inner list.

    Element types are checked with ``is_list_of``.
    NOTE(review): whether ``is_list_of`` inspects every element or only the
    first depends on its implementation in ``vllm.utils`` -- confirm.
    For form 4, only the first inner list is checked for emptiness here.

    Raises:
        ValueError: if the outer list is empty, or (form 4) its first inner
            list is empty.
        TypeError: if ``prompt`` matches none of the accepted forms.
    """
    if isinstance(prompt, str):
        # case 1: a string
        return [ParsedText(content=prompt, is_tokens=False)]

    if isinstance(prompt, list):
        if len(prompt) == 0:
            raise ValueError("please provide at least one prompt")

        if is_list_of(prompt, str):
            # case 2: array of strings
            return [
                ParsedText(content=elem, is_tokens=False) for elem in prompt
            ]
        if is_list_of(prompt, int):
            # case 3: array of tokens
            return [ParsedTokens(content=prompt, is_tokens=True)]
        if is_list_of(prompt, list):
            if len(prompt[0]) == 0:
                raise ValueError("please provide at least one prompt")

            if is_list_of(prompt[0], int):
                # case 4: array of token arrays
                return [
                    ParsedTokens(content=elem, is_tokens=True)
                    for elem in prompt
                ]

    # Reached for non-str/non-list input, or a list whose element type
    # matched none of the cases above.
    raise TypeError("prompt must be a string, array of strings, "
                    "array of tokens, or array of token arrays")


class ParsedStrPrompt(TypedDict):
    """Tagged wrapper for a singleton prompt given as a plain string."""

    # Variant tag; always ``"str"``.
    type: Literal["str"]
    # The original string prompt.
    content: str


class ParsedTextPrompt(TypedDict):
    """Tagged wrapper for a singleton prompt given as a ``TextPrompt``."""

    # Variant tag; always ``"text"``.
    type: Literal["text"]
    # The original ``TextPrompt`` object.
    content: TextPrompt


class ParsedTokensPrompt(TypedDict):
    """Tagged wrapper for a singleton prompt given as a ``TokensPrompt``."""

    # Variant tag; always ``"tokens"``.
    type: Literal["tokens"]
    # The original ``TokensPrompt`` object.
    content: TokensPrompt


def parse_singleton_prompt(
    inputs: SingletonPromptInputs,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
    """Classify a single prompt and wrap it in a tagged dict.

    - A ``str`` becomes ``ParsedStrPrompt`` (``type="str"``).
    - A dict containing ``"prompt_token_ids"`` becomes ``ParsedTokensPrompt``
      (``type="tokens"``). This check runs first, so it wins when the dict
      also contains ``"prompt"``.
    - Otherwise, a dict containing ``"prompt"`` becomes ``ParsedTextPrompt``
      (``type="text"``).

    The input object itself is stored under ``content``; it is not copied.

    Raises:
        TypeError: for any other input, including dicts with neither key.
    """
    if isinstance(inputs, str):
        return ParsedStrPrompt(type="str", content=inputs)

    if isinstance(inputs, dict):
        # Token IDs are checked before "prompt", so they take precedence.
        if "prompt_token_ids" in inputs:
            return ParsedTokensPrompt(type="tokens",
                                      content=inputs)  # type: ignore
        if "prompt" in inputs:
            return ParsedTextPrompt(type="text", content=inputs)

    raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
96
97
98
99
100
101
102
103
104
105
106


def is_explicit_encoder_decoder_prompt(
    inputs: PromptInputs,
) -> TypeIs[ExplicitEncoderDecoderPrompt]:
    """Return whether *inputs* is a dict that has an ``encoder_prompt`` key.

    Declared as a ``TypeIs`` guard, so a ``True`` result narrows *inputs* to
    ``ExplicitEncoderDecoderPrompt`` for type checkers.
    """
    if not isinstance(inputs, dict):
        return False
    return "encoder_prompt" in inputs


def is_valid_encoder_decoder_llm_inputs(
    inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
) -> TypeIs[EncoderDecoderLLMInputs]:
    """Return whether *inputs* contains ``encoder_prompt_token_ids``.

    Declared as a ``TypeIs`` guard, so a ``True`` result narrows *inputs* to
    ``EncoderDecoderLLMInputs`` for type checkers.
    """
    has_encoder_ids = "encoder_prompt_token_ids" in inputs
    return has_encoder_ids