# SPDX-License-Identifier: Apache-2.0

from typing import List, Literal, Sequence, TypedDict, Union, cast, overload

from typing_extensions import TypeIs

from vllm.utils import is_list_of

from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
                   ProcessorInputs, PromptType, SingletonPrompt, TextPrompt,
                   TokensPrompt)


class ParsedText(TypedDict):
    """One text prompt entry as returned by ``parse_and_batch_prompt``."""

    # The prompt text.
    content: str
    # Discriminator: always ``False`` for a text entry.
    is_tokens: Literal[False]


class ParsedTokens(TypedDict):
    """One token-id prompt entry as returned by ``parse_and_batch_prompt``."""

    # The prompt as a list of token ids.
    content: List[int]
    # Discriminator: always ``True`` for a token entry.
    is_tokens: Literal[True]


@overload
def parse_and_batch_prompt(
        prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
    ...


@overload
def parse_and_batch_prompt(
        prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
    ...


def parse_and_batch_prompt(
    prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
    """Normalize a prompt argument into a batch of parsed entries.

    Four input shapes are accepted:

    1. a single string            -> one ``ParsedText``
    2. a list of strings          -> one ``ParsedText`` per string
    3. a list of token ids        -> one ``ParsedTokens`` for the whole list
    4. a list of token-id lists   -> one ``ParsedTokens`` per inner list

    Args:
        prompt: The prompt in one of the shapes above.

    Returns:
        A list of ``ParsedText`` entries (cases 1-2) or ``ParsedTokens``
        entries (cases 3-4).

    Raises:
        ValueError: If ``prompt`` is an empty list, or if it is a list of
            lists whose first inner list is empty.
        TypeError: If ``prompt`` matches none of the supported shapes.
    """
    if isinstance(prompt, str):
        # case 1: a string
        return [ParsedText(content=prompt, is_tokens=False)]

    if isinstance(prompt, list):
        if not prompt:
            raise ValueError("please provide at least one prompt")

        if is_list_of(prompt, str):
            # case 2: array of strings
            prompt = cast(List[str], prompt)
            return [
                ParsedText(content=elem, is_tokens=False) for elem in prompt
            ]
        if is_list_of(prompt, int):
            # case 3: array of tokens
            prompt = cast(List[int], prompt)
            return [ParsedTokens(content=prompt, is_tokens=True)]
        if is_list_of(prompt, list):
            prompt = cast(List[List[int]], prompt)
            # Only the first inner list is inspected directly here, both for
            # emptiness and for its element type.
            if not prompt[0]:
                raise ValueError("please provide at least one prompt")

            if is_list_of(prompt[0], int):
                # case 4: array of token arrays
                return [
                    ParsedTokens(content=elem, is_tokens=True)
                    for elem in prompt
                ]

    # Anything that did not return above is an unsupported shape.
    raise TypeError("prompt must be a string, array of strings, "
                    "array of tokens, or array of token arrays")


class ParsedStrPrompt(TypedDict):
    """Tagged result of ``parse_singleton_prompt`` for a plain string."""

    # Tag identifying this variant; always "str".
    type: Literal["str"]
    # The original prompt string.
    content: str


class ParsedTextPrompt(TypedDict):
    """Tagged result of ``parse_singleton_prompt`` for a ``TextPrompt``."""

    # Tag identifying this variant; always "text".
    type: Literal["text"]
    # The original prompt mapping, passed through unchanged.
    content: TextPrompt


class ParsedTokensPrompt(TypedDict):
    """Tagged result of ``parse_singleton_prompt`` for a ``TokensPrompt``."""

    # Tag identifying this variant; always "tokens".
    type: Literal["tokens"]
    # The original prompt mapping, passed through unchanged.
    content: TokensPrompt


def parse_singleton_prompt(
    prompt: SingletonPrompt,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
    """Classify a singleton prompt and wrap it in a tagged dict.

    Args:
        prompt: A plain string, or a dict carrying either a
            ``"prompt_token_ids"`` key or a ``"prompt"`` key.

    Returns:
        ``ParsedStrPrompt`` for a string, ``ParsedTokensPrompt`` for a dict
        containing ``"prompt_token_ids"``, or ``ParsedTextPrompt`` for a dict
        containing ``"prompt"``. The original value is stored under
        ``content`` unchanged.

    Raises:
        TypeError: If ``prompt`` is neither a string nor a dict with one of
            the recognized keys.
    """
    if isinstance(prompt, str):
        return ParsedStrPrompt(type="str", content=prompt)

    if isinstance(prompt, dict):
        # "prompt_token_ids" is tested first, so a dict holding both keys is
        # treated as a tokens prompt.
        if "prompt_token_ids" in prompt:
            return ParsedTokensPrompt(type="tokens",
                                      content=prompt)  # type: ignore
        if "prompt" in prompt:
            return ParsedTextPrompt(type="text", content=prompt)

    raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")


def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
    """Return whether ``prompt`` is a dict with a ``"prompt_token_ids"`` key."""
    if not isinstance(prompt, dict):
        return False
    return "prompt_token_ids" in prompt


def is_explicit_encoder_decoder_prompt(
        prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
    """Return whether ``prompt`` is a dict with an ``"encoder_prompt"`` key."""
    return isinstance(prompt, dict) and "encoder_prompt" in prompt


def is_encoder_decoder_inputs(
        inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]:
    """Return whether ``inputs`` has both ``"encoder"`` and ``"decoder"``."""
    return "encoder" in inputs and "decoder" in inputs