# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence
3
from typing import Literal, Optional, TypedDict, Union, cast, overload
4
5
6
7
8

from typing_extensions import TypeIs

from vllm.utils import is_list_of

9
10
11
from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs,
                   PromptType, SingletonInputs, SingletonPrompt, TextPrompt,
                   TokensPrompt)
12
13
14
15
16
17
18
19


class ParsedText(TypedDict):
    """One text prompt entry, as returned by `parse_and_batch_prompt`."""
    # The prompt text itself.
    content: str
    # Discriminator: always False for text entries, so callers can tell
    # this apart from `ParsedTokens`.
    is_tokens: Literal[False]


class ParsedTokens(TypedDict):
    """One pre-tokenized prompt entry, as returned by
    `parse_and_batch_prompt`."""
    # The token IDs that make up the prompt.
    content: list[int]
    # Discriminator: always True for token entries, so callers can tell
    # this apart from `ParsedText`.
    is_tokens: Literal[True]


@overload
def parse_and_batch_prompt(
        prompt: Union[str, list[str]]) -> Sequence[ParsedText]:
    ...


@overload
def parse_and_batch_prompt(
        prompt: Union[list[int], list[list[int]]]) -> Sequence[ParsedTokens]:
    ...


def parse_and_batch_prompt(
    prompt: Union[str, list[str], list[int], list[list[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
    """Normalize a single prompt or a batch of prompts into a list.

    Four input shapes are accepted:

    1. a string                -> one `ParsedText`
    2. a list of strings       -> one `ParsedText` per string
    3. a list of token IDs     -> one `ParsedTokens` wrapping the whole list
    4. a list of token ID lists -> one `ParsedTokens` per inner list

    Args:
        prompt: The prompt(s) in one of the shapes listed above.

    Returns:
        A list of `ParsedText` (shapes 1-2) or `ParsedTokens` (shapes 3-4).

    Raises:
        ValueError: If `prompt` is an empty list, or is a list of lists
            whose first inner list is empty.
        TypeError: If `prompt` matches none of the accepted shapes.
    """
    if isinstance(prompt, str):
        # case 1: a string
        return [ParsedText(content=prompt, is_tokens=False)]

    if isinstance(prompt, list):
        if not prompt:
            raise ValueError("please provide at least one prompt")

        # NOTE(review): `is_list_of` comes from `vllm.utils`; whether it
        # inspects every element or only the first is not visible from
        # this file -- confirm before relying on per-element validation.
        if is_list_of(prompt, str):
            # case 2: array of strings
            prompt = cast(list[str], prompt)
            return [
                ParsedText(content=elem, is_tokens=False) for elem in prompt
            ]
        if is_list_of(prompt, int):
            # case 3: array of tokens
            prompt = cast(list[int], prompt)
            return [ParsedTokens(content=prompt, is_tokens=True)]
        if is_list_of(prompt, list):
            prompt = cast(list[list[int]], prompt)
            # Only the first inner list is checked for emptiness here.
            if not prompt[0]:
                raise ValueError("please provide at least one prompt")

            if is_list_of(prompt[0], int):
                # case 4: array of token arrays
                return [
                    ParsedTokens(content=elem, is_tokens=True)
                    for elem in prompt
                ]

    # Anything that did not match one of the four cases above, including
    # a list of lists whose first inner list does not hold ints.
    raise TypeError("prompt must be a string, array of strings, "
                    "array of tokens, or array of token arrays")


class ParsedStrPrompt(TypedDict):
    """Tagged result of `parse_singleton_prompt` for a plain `str` prompt."""
    # Discriminator tag identifying this variant.
    type: Literal["str"]
    # The original prompt string.
    content: str


class ParsedTextPrompt(TypedDict):
    """Tagged result of `parse_singleton_prompt` for a dict prompt that
    carries a "prompt" key."""
    # Discriminator tag identifying this variant.
    type: Literal["text"]
    # The original prompt dict, passed through unchanged.
    content: TextPrompt


class ParsedTokensPrompt(TypedDict):
    """Tagged result of `parse_singleton_prompt` for a dict prompt that
    carries a "prompt_token_ids" key."""
    # Discriminator tag identifying this variant.
    type: Literal["tokens"]
    # The original prompt dict, passed through unchanged.
    content: TokensPrompt


88
89
90
91
92
class ParsedEmbedsPrompt(TypedDict):
    """Tagged result of `parse_singleton_prompt` for a dict prompt that
    carries a "prompt_embeds" key."""
    # Discriminator tag identifying this variant.
    type: Literal['embeds']
    # The original prompt dict, passed through unchanged.
    content: EmbedsPrompt


93
94
95
96
# Union of every tagged variant that `parse_singleton_prompt` can return;
# the `type` field of each member tells the variants apart.
ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt,
                              ParsedTokensPrompt, ParsedEmbedsPrompt]


97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
@overload
def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
    ...


@overload
def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt:
    ...


@overload
def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt:
    ...


@overload
def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
    ...


def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
    """Wrap a singleton prompt in a dict tagged with its kind.

    A `str` is tagged "str". A dict is classified by the first matching
    key, checked in this order: "prompt_embeds" -> "embeds",
    "prompt_token_ids" -> "tokens", "prompt" -> "text". The prompt itself
    is stored unchanged under "content".

    Args:
        prompt: A string or a prompt dict.

    Returns:
        The tagged wrapper corresponding to the detected prompt kind.

    Raises:
        TypeError: If `prompt` is neither a `str` nor a dict containing
            one of the recognized keys.
    """
    if isinstance(prompt, str):
        return ParsedStrPrompt(type="str", content=prompt)
    elif isinstance(prompt, dict):
        # Type ignores are because mypy does not correctly infer the TypedDicts
        # Pyright does succeed.
        if "prompt_embeds" in prompt:
            return ParsedEmbedsPrompt(
                type="embeds", content=prompt)  # type: ignore[typeddict-item]
        elif "prompt_token_ids" in prompt:
            return ParsedTokensPrompt(
                type="tokens", content=prompt)  # type: ignore[typeddict-item]
        elif "prompt" in prompt:
            return ParsedTextPrompt(type="text", content=prompt)
    # Reached for non-str/non-dict inputs and for dicts with none of the
    # recognized keys.
    raise TypeError(
        "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt")
133
134
135


def is_explicit_encoder_decoder_prompt(
        prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
    """Return whether `prompt` is a dict containing an "encoder_prompt" key.

    The `TypeIs` return annotation lets static type checkers narrow
    `prompt` to `ExplicitEncoderDecoderPrompt` when this returns True.
    """
    return isinstance(prompt, dict) and "encoder_prompt" in prompt
138
139


140
141
142
143
144
145
146
147
148
149
150
def split_enc_dec_inputs(
    inputs: ProcessorInputs,
) -> tuple[Optional[SingletonInputs], SingletonInputs]:
    """Split processor inputs into an `(encoder, decoder)` pair.

    When `inputs` contains both an "encoder" and a "decoder" key, the two
    corresponding values are returned. Otherwise `inputs` itself is
    returned in the decoder position and the encoder position is `None`.
    """
    # Guard clause: without both keys there is nothing to split.
    if "encoder" not in inputs or "decoder" not in inputs:
        return None, inputs

    # NOTE: This passes pyright but not mypy
    encoder_inputs = inputs["encoder"]  # type: ignore[typeddict-item]
    decoder_inputs = inputs["decoder"]  # type: ignore[typeddict-item]
    return encoder_inputs, decoder_inputs