# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict
4
5
6

from typing_extensions import TypeIs

7
from vllm.utils import length_from_prompt_token_ids_or_embeds
8

9
10
11
12
13
14
15
16
17
18
from .data import (
    EmbedsPrompt,
    ExplicitEncoderDecoderPrompt,
    ProcessorInputs,
    PromptType,
    SingletonInputs,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
19

20
21
22
if TYPE_CHECKING:
    import torch

23

24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class ParsedStrPrompt(TypedDict):
    """Tagged form of a prompt that was supplied as a bare string.

    ``parse_singleton_prompt`` builds this variant when its argument is a
    ``str``.
    """

    # Discriminator; always the literal "str" for this variant.
    type: Literal["str"]
    # The prompt string itself.
    content: str


class ParsedTextPrompt(TypedDict):
    """Tagged form of a dict prompt that carries a ``"prompt"`` key.

    ``parse_singleton_prompt`` builds this variant only after ruling out the
    ``"prompt_embeds"`` and ``"prompt_token_ids"`` keys.
    """

    # Discriminator; always the literal "text" for this variant.
    type: Literal["text"]
    # The original prompt dict, passed through as-is.
    content: TextPrompt


class ParsedTokensPrompt(TypedDict):
    """Tagged form of a dict prompt that carries a ``"prompt_token_ids"`` key.

    ``parse_singleton_prompt`` builds this variant when the dict has
    ``"prompt_token_ids"`` but no ``"prompt_embeds"``.
    """

    # Discriminator; always the literal "tokens" for this variant.
    type: Literal["tokens"]
    # The original prompt dict, passed through as-is.
    content: TokensPrompt


39
class ParsedEmbedsPrompt(TypedDict):
    """Tagged form of a dict prompt that carries a ``"prompt_embeds"`` key.

    ``parse_singleton_prompt`` checks for ``"prompt_embeds"`` before any other
    key, so this variant wins whenever that key is present.
    """

    # Discriminator; always the literal "embeds" for this variant.
    type: Literal["embeds"]
    # The original prompt dict, passed through as-is.
    content: EmbedsPrompt


44
45
46
# Union of every tagged variant that parse_singleton_prompt can return;
# callers can branch on the "type" field to tell the variants apart.
ParsedSingletonPrompt: TypeAlias = (
    ParsedStrPrompt | ParsedTextPrompt | ParsedTokensPrompt | ParsedEmbedsPrompt
)
47
48
49


def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
    """Wrap ``prompt`` in a tagged dict that records which kind it is.

    A ``str`` yields a ``"str"`` variant. A ``dict`` is classified by the
    first matching key, checked in this order: ``"prompt_embeds"``,
    ``"prompt_token_ids"``, ``"prompt"``. The original object is stored
    unchanged under ``content``.

    Raises:
        TypeError: If ``prompt`` is neither a string nor a dict, or is a dict
            containing none of the recognised keys.
    """
    if isinstance(prompt, str):
        return ParsedStrPrompt(type="str", content=prompt)
    elif isinstance(prompt, dict):
        # Type ignores are because mypy does not correctly infer the TypedDicts
        # Pyright does succeed.
        if "prompt_embeds" in prompt:
            return ParsedEmbedsPrompt(type="embeds", content=prompt)  # type: ignore[typeddict-item]
        elif "prompt_token_ids" in prompt:
            return ParsedTokensPrompt(type="tokens", content=prompt)  # type: ignore[typeddict-item]
        elif "prompt" in prompt:
            return ParsedTextPrompt(type="text", content=prompt)

    # Reached for non-str/non-dict inputs and for dicts with no known key.
    raise TypeError(
        "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt"
    )
64
65
66


def is_explicit_encoder_decoder_prompt(
    prompt: PromptType,
) -> TypeIs[ExplicitEncoderDecoderPrompt]:
    """Tell whether ``prompt`` is a dict that has an ``"encoder_prompt"`` key.

    Only the presence of the key is checked, not its value. The ``TypeIs``
    return type lets static checkers narrow ``prompt`` in both branches of
    the caller's conditional.
    """
    return isinstance(prompt, dict) and "encoder_prompt" in prompt
70
71


72
73
def split_enc_dec_inputs(
    inputs: ProcessorInputs,
) -> tuple[SingletonInputs | None, SingletonInputs]:
    """Split processor inputs into an ``(encoder, decoder)`` pair.

    When ``inputs`` contains both an ``"encoder"`` and a ``"decoder"`` key,
    the two corresponding values are returned. Otherwise ``inputs`` is
    returned whole as the decoder half, with ``None`` as the encoder half.
    """
    if "encoder" in inputs and "decoder" in inputs:
        # NOTE: This passes pyright but not mypy
        return (
            inputs["encoder"],  # type: ignore[typeddict-item]
            inputs["decoder"],  # type: ignore[typeddict-item]
        )

    return None, inputs
83
84
85


class PromptComponents(NamedTuple):
86
87
88
    text: str | None = None
    token_ids: list[int] | None = None
    embeds: "torch.Tensor | None" = None
89
90
91
92
93
94


def get_prompt_components(prompt: PromptType) -> PromptComponents:
    """Extract the text, token IDs and embeddings carried by ``prompt``.

    A bare string becomes the ``text`` component. A mapping with an
    ``"encoder_prompt"`` entry is resolved through that entry alone. Any
    other mapping contributes its ``"prompt"``, ``"prompt_token_ids"`` and
    ``"prompt_embeds"`` entries, with ``None`` standing in for absent keys.
    """
    if isinstance(prompt, str):
        return PromptComponents(text=prompt)

    # Compare against None instead of relying on truthiness: an empty-string
    # encoder prompt is still a prompt and must yield text="" exactly like a
    # top-level "" does, rather than falling through and reporting text=None.
    encoder_prompt = prompt.get("encoder_prompt")
    if encoder_prompt is not None:
        return get_prompt_components(encoder_prompt)  # type: ignore[arg-type]

    return PromptComponents(
        text=prompt.get("prompt"),  # type: ignore[arg-type]
        token_ids=prompt.get("prompt_token_ids"),  # type: ignore[arg-type]
        embeds=prompt.get("prompt_embeds"),
    )
103
104
105
106
107
108
109


def get_prompt_len(prompt: TokensPrompt | EmbedsPrompt):
    """Compute the length of a tokens or embeddings prompt.

    The ``"prompt_token_ids"`` and ``"prompt_embeds"`` entries are looked up
    (a missing key is forwarded as ``None``) and handed to
    ``length_from_prompt_token_ids_or_embeds``, whose result is returned.
    """
    token_ids = prompt.get("prompt_token_ids")
    embeds = prompt.get("prompt_embeds")
    return length_from_prompt_token_ids_or_embeds(
        token_ids,  # type: ignore[arg-type]
        embeds,  # type: ignore[arg-type]
    )