# parse.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Sequence
4
from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict, cast
5
6
7

from typing_extensions import TypeIs

8
from vllm.utils.collection_utils import is_list_of
9

10
11
12
13
14
15
16
17
18
19
from .data import (
    EmbedsPrompt,
    ExplicitEncoderDecoderPrompt,
    ProcessorInputs,
    PromptType,
    SingletonInputs,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
20

21
22
23
if TYPE_CHECKING:
    import torch

24

25
def parse_raw_prompts(
26
27
    prompt: str | list[str] | list[int] | list[list[int]],
) -> Sequence[TextPrompt] | Sequence[TokensPrompt]:
28
29
    if isinstance(prompt, str):
        # case 1: a string
30
        return [TextPrompt(prompt=prompt)]
31
32
33
34
35

    if isinstance(prompt, list):
        if len(prompt) == 0:
            raise ValueError("please provide at least one prompt")

36
        # case 2: array of strings
37
        if is_list_of(prompt, str):
38
            prompt = cast(list[str], prompt)
39
            return [TextPrompt(prompt=elem) for elem in prompt]
40
41

        # case 3: array of tokens
42
        if is_list_of(prompt, int):
43
            prompt = cast(list[int], prompt)
44
            return [TokensPrompt(prompt_token_ids=prompt)]
45
46

        # case 4: array of token arrays
47
        if is_list_of(prompt, list):
48
49
50
51
52
53
54
55
56
57
58
            if len(prompt) == 1 and isinstance(prompt[0], list) and len(prompt[0]) == 0:
                raise ValueError("please provide at least one prompt")
            for elem in prompt:
                if not isinstance(elem, list):
                    raise TypeError(
                        "prompt must be a list of lists, but found a non-list element."
                    )
                if not is_list_of(elem, int):
                    raise TypeError(
                        "Nested lists of tokens must contain only integers."
                    )
59
60
61

            prompt = cast(list[list[int]], prompt)
            return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
62

63
64
65
66
    raise TypeError(
        "prompt must be a string, array of strings, "
        "array of tokens, or array of token arrays"
    )
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83


class ParsedStrPrompt(TypedDict):
    """Tagged wrapper for a prompt supplied as a bare string."""

    # Discriminator: always the literal "str" for this variant.
    type: Literal["str"]
    # The prompt string itself.
    content: str


class ParsedTextPrompt(TypedDict):
    """Tagged wrapper around a ``TextPrompt``."""

    # Discriminator: always the literal "text" for this variant.
    type: Literal["text"]
    # The wrapped prompt.
    content: TextPrompt


class ParsedTokensPrompt(TypedDict):
    """Tagged wrapper around a ``TokensPrompt``."""

    # Discriminator: always the literal "tokens" for this variant.
    type: Literal["tokens"]
    # The wrapped prompt.
    content: TokensPrompt


84
class ParsedEmbedsPrompt(TypedDict):
    """Tagged wrapper around an ``EmbedsPrompt``."""

    # Discriminator: always the literal "embeds" for this variant.
    type: Literal["embeds"]
    # The wrapped prompt.
    content: EmbedsPrompt


89
90
91
# Union of every tagged wrapper; the ``type`` literal tells the variants apart.
ParsedSingletonPrompt: TypeAlias = (
    ParsedStrPrompt | ParsedTextPrompt | ParsedTokensPrompt | ParsedEmbedsPrompt
)
92
93
94


def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
    """Wrap a singleton prompt in a dict tagged with its kind.

    Args:
        prompt: Either a string or a prompt dict.

    Returns:
        A dict with a ``type`` tag (``"str"``, ``"embeds"``, ``"tokens"`` or
        ``"text"``) and the original prompt under ``content``.

    Raises:
        TypeError: If ``prompt`` is neither a string nor a dict containing
            one of the keys ``"prompt_embeds"``, ``"prompt_token_ids"`` or
            ``"prompt"``.
    """
    if isinstance(prompt, str):
        return ParsedStrPrompt(type="str", content=prompt)

    if isinstance(prompt, dict):
        # Type ignores are because mypy does not correctly infer the TypedDicts
        # Pyright does succeed.
        # The keys are tested in a fixed order, so a dict carrying several of
        # them is classified by the first match: embeds, then tokens, then text.
        if "prompt_embeds" in prompt:
            return ParsedEmbedsPrompt(type="embeds", content=prompt)  # type: ignore[typeddict-item]
        if "prompt_token_ids" in prompt:
            return ParsedTokensPrompt(type="tokens", content=prompt)  # type: ignore[typeddict-item]
        if "prompt" in prompt:
            return ParsedTextPrompt(type="text", content=prompt)

    raise TypeError(
        "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt"
    )
109
110
111


def is_explicit_encoder_decoder_prompt(
112
113
    prompt: PromptType,
) -> TypeIs[ExplicitEncoderDecoderPrompt]:
114
    return isinstance(prompt, dict) and "encoder_prompt" in prompt
115
116


117
118
def split_enc_dec_inputs(
    inputs: ProcessorInputs,
119
) -> tuple[SingletonInputs | None, SingletonInputs]:
120
121
122
123
124
125
126
127
    if "encoder" in inputs and "decoder" in inputs:
        # NOTE: This passes pyright but not mypy
        return (
            inputs["encoder"],  # type: ignore[typeddict-item]
            inputs["decoder"],  # type: ignore[typeddict-item]
        )

    return None, inputs
128
129
130


class PromptComponents(NamedTuple):
131
132
133
    text: str | None = None
    token_ids: list[int] | None = None
    embeds: "torch.Tensor | None" = None
134
135
136
137
138
139


def get_prompt_components(prompt: PromptType) -> PromptComponents:
    """Extract the text, token ids and embeddings carried by a prompt.

    Args:
        prompt: A string or a prompt dict. A dict with a non-``None``
            ``"encoder_prompt"`` entry is resolved through that entry.

    Returns:
        A ``PromptComponents`` with ``text`` set for a string prompt, or with
        the values of the ``"prompt"``, ``"prompt_token_ids"`` and
        ``"prompt_embeds"`` keys for a dict prompt (``None`` for missing keys).
    """
    if isinstance(prompt, str):
        return PromptComponents(text=prompt)

    # Compare against None instead of relying on truthiness: an empty-string
    # encoder prompt is falsy, and would otherwise fall through to the lookups
    # below and come back as text=None, unlike a top-level "" prompt.
    encoder_prompt = prompt.get("encoder_prompt")
    if encoder_prompt is not None:
        return get_prompt_components(encoder_prompt)  # type: ignore[arg-type]

    return PromptComponents(
        text=prompt.get("prompt"),  # type: ignore[arg-type]
        token_ids=prompt.get("prompt_token_ids"),  # type: ignore[arg-type]
        embeds=prompt.get("prompt_embeds"),
    )