"""
Schemas and utilities for preprocessing inputs.
"""

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
7
from collections.abc import Mapping, Sequence
8
9
10
11
from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload

from vllm.inputs import (
    EmbedsPrompt,
12
    EngineInput,
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
    ExplicitEncoderDecoderPrompt,
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.collection_utils import is_list_of

if TYPE_CHECKING:
    import torch

    from vllm.config import ModelConfig
    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam


@overload
def prompt_to_seq(
    prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
) -> Sequence[SingletonPrompt]: ...


@overload
def prompt_to_seq(  # type: ignore[misc]
    prompt_or_prompts: ExplicitEncoderDecoderPrompt
    | Sequence[ExplicitEncoderDecoderPrompt],
) -> Sequence[ExplicitEncoderDecoderPrompt]: ...


@overload
def prompt_to_seq(  # type: ignore[misc]
    prompt_or_prompts: PromptType | Sequence[PromptType],
) -> Sequence[PromptType]: ...


def prompt_to_seq(
    prompt_or_prompts: PromptType | bytes | Sequence[PromptType | bytes],
) -> Sequence[PromptType]:
    """
    Normalize a single prompt or a batch of prompts into a sequence.

    A value is treated as ONE prompt when it is a ``dict``, ``str`` or
    ``bytes``, or when it is a non-empty sequence accepted by
    ``is_list_of(..., int)`` (a list of token IDs). Such a value is wrapped in
    a one-element list. Anything else is assumed to already be a sequence of
    prompts and is returned unchanged (no copy is made).
    """
    is_single = isinstance(prompt_or_prompts, (dict, str, bytes))

    # A flat list of ints is one tokenized prompt rather than a batch; an
    # empty sequence is left alone and treated as an empty batch.
    if not is_single and len(prompt_or_prompts) > 0:
        is_single = is_list_of(prompt_or_prompts, int)

    if is_single:
        return [prompt_or_prompts]  # type: ignore[list-item]

    return prompt_or_prompts  # type: ignore[return-value]


def conversation_to_seq(
    conversation_or_conversations: list["ChatCompletionMessageParam"]
    | Sequence[list["ChatCompletionMessageParam"]],
) -> Sequence[list["ChatCompletionMessageParam"]]:
    """
    Normalize one chat conversation or a batch of conversations into a sequence.

    A non-empty list accepted by ``is_list_of(..., dict)`` is one conversation
    (a list of message dicts) and is wrapped in a one-element list. Anything
    else, including an empty sequence, is returned unchanged.
    """
    is_single_conversation = len(conversation_or_conversations) > 0 and is_list_of(
        conversation_or_conversations, dict
    )

    if not is_single_conversation:
        return conversation_or_conversations  # type: ignore[return-value]

    return [conversation_or_conversations]  # type: ignore[list-item]


DecoderOnlyDictPrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
"""
A [`DecoderOnlyPrompt`][vllm.inputs.llm.DecoderOnlyPrompt]
that has been standardized into a dictionary.
"""


EncoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
"""
An [`EncoderPrompt`][vllm.inputs.llm.EncoderPrompt]
that has been standardized into a dictionary.

Unlike a decoder-only prompt, it cannot carry embeddings.
"""


DecoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
"""
A [`DecoderPrompt`][vllm.inputs.llm.DecoderPrompt]
that has been standardized into a dictionary.

Unlike a decoder-only prompt, it cannot carry embeddings.
"""


class EncoderDecoderDictPrompt(TypedDict):
    """
    An [`EncoderDecoderPrompt`][vllm.inputs.llm.EncoderDecoderPrompt]
    that has been standardized into a dictionary.
    """

    # Standardized encoder-side prompt (text or token IDs).
    encoder_prompt: EncoderDictPrompt

    # Standardized decoder-side prompt, or None when no explicit decoder
    # prompt was supplied.
    decoder_prompt: DecoderDictPrompt | None


SingletonDictPrompt: TypeAlias = (
    DecoderOnlyDictPrompt | EncoderDictPrompt | DecoderDictPrompt
)
"""
A [`SingletonPrompt`][vllm.inputs.llm.SingletonPrompt]
that has been standardized into a dictionary.
"""


DictPrompt: TypeAlias = DecoderOnlyDictPrompt | EncoderDecoderDictPrompt
"""
A [`PromptType`][vllm.inputs.llm.PromptType]
that has been standardized into a dictionary.
"""


119
120
121
122
123
124
125
126
127
128
129
130
131
def _validate_prompt_dict(prompt: Mapping[str, object]) -> None:
    """Reject malformed dict prompts before renderer tokenization."""
    if (
        "prompt" not in prompt
        or "prompt_token_ids" in prompt
        or "prompt_embeds" in prompt
    ):
        return

    if not isinstance(prompt["prompt"], str):
        raise TypeError("Prompt text should be a string")


132
def parse_dec_only_prompt(prompt: PromptType | object) -> DecoderOnlyDictPrompt:
    """
    Parse a prompt for a decoder-only model and normalize it to a dictionary.

    - A ``str`` becomes a ``TextPrompt``.
    - A ``list`` becomes a ``TokensPrompt``; it must be a list of ints.
    - A ``dict`` is returned as-is (not copied), provided it is not an
      encoder-decoder prompt and contains text, token IDs, or embeddings.

    Raises:
        TypeError: On an unsupported prompt type, a non-int token list, an
            encoder-decoder dict, a text-only dict whose text is not a
            ``str``, or a dict with no prompt content.
    """
    if isinstance(prompt, str):
        return TextPrompt(prompt=prompt)

    if isinstance(prompt, list):
        if is_list_of(prompt, int):
            return TokensPrompt(prompt_token_ids=prompt)
        raise TypeError("Token prompt should be a list of integers")

    if not isinstance(prompt, dict):
        raise TypeError("Prompt should be a string, list of tokens, or dictionary")

    if "encoder_prompt" in prompt:
        raise TypeError("Cannot pass encoder-decoder prompt to decoder-only models")

    _validate_prompt_dict(prompt)

    content_keys = ("prompt", "prompt_token_ids", "prompt_embeds")
    if any(key in prompt for key in content_keys):
        return prompt  # type: ignore[return-value]

    raise TypeError("Prompt dictionary must contain text, tokens, or embeddings")


163
def _parse_enc_prompt(prompt: PromptType | object) -> EncoderDictPrompt:
    """
    Parse the encoder side of an encoder-decoder prompt into a dictionary.

    Accepts a ``str`` (wrapped as ``TextPrompt``), a list of ints (wrapped as
    ``TokensPrompt``), or a dict containing ``"prompt"`` or
    ``"prompt_token_ids"`` (returned as-is). Embedding prompts are rejected.

    Raises:
        TypeError: On an unsupported prompt type, a non-int token list, a
            text-only dict whose text is not a ``str``, a dict with
            ``"prompt_embeds"``, or a dict with neither text nor tokens.
    """
    if isinstance(prompt, str):
        return TextPrompt(prompt=prompt)

    if isinstance(prompt, list):
        if not is_list_of(prompt, int):
            raise TypeError("Token prompt should be a list of integers")
        return TokensPrompt(prompt_token_ids=prompt)

    if not isinstance(prompt, dict):
        raise TypeError("Prompt should be a string, list of tokens, or dictionary")

    _validate_prompt_dict(prompt)

    if "prompt_embeds" in prompt:
        raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")

    has_content = "prompt" in prompt or "prompt_token_ids" in prompt
    if not has_content:
        raise TypeError("Prompt dictionary must contain text or tokens")

    return prompt  # type: ignore[return-value]


187
def _parse_dec_prompt(prompt: PromptType | object) -> DecoderDictPrompt:
    """
    Parse the decoder side of an encoder-decoder prompt into a dictionary.

    Accepts a ``str`` (wrapped as ``TextPrompt``), a list of ints (wrapped as
    ``TokensPrompt``), or a dict containing ``"prompt"`` or
    ``"prompt_token_ids"`` (returned as-is). Embeddings and multi-modal
    fields are rejected on the decoder side.

    Raises:
        TypeError: On an unsupported prompt type, a non-int token list, a
            text-only dict whose text is not a ``str``, a dict with
            ``"prompt_embeds"`` or any multi-modal key, or a dict with
            neither text nor tokens.
    """
    if isinstance(prompt, str):
        return TextPrompt(prompt=prompt)

    if isinstance(prompt, list):
        if not is_list_of(prompt, int):
            raise TypeError("Token prompt should be a list of integers")
        return TokensPrompt(prompt_token_ids=prompt)

    if not isinstance(prompt, dict):
        raise TypeError("Prompt should be a string, list of tokens, or dictionary")

    _validate_prompt_dict(prompt)

    if "prompt_embeds" in prompt:
        raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")

    mm_keys = ("multi_modal_data", "mm_processor_kwargs", "multi_modal_uuids")
    if any(key in prompt for key in mm_keys):
        raise TypeError("Cannot pass multi-modal inputs to decoder prompt")

    if "prompt" not in prompt and "prompt_token_ids" not in prompt:
        raise TypeError("Prompt dictionary must contain text or tokens")

    return prompt  # type: ignore[return-value]


218
def parse_enc_dec_prompt(prompt: PromptType | object) -> EncoderDecoderDictPrompt:
    """
    Parse a prompt for an encoder-decoder model and normalize it to a dictionary.

    A dict containing ``"encoder_prompt"`` is treated as an explicit
    encoder/decoder pair. Its ``"decoder_prompt"`` entry may be ``None`` or
    omitted entirely; either way the result has ``decoder_prompt=None``.
    Any other value is parsed as the encoder prompt alone.

    Raises:
        TypeError: Propagated from ``_parse_enc_prompt`` / ``_parse_dec_prompt``
            when either side is malformed.
    """
    if isinstance(prompt, dict) and "encoder_prompt" in prompt:
        enc_prompt = prompt["encoder_prompt"]  # type: ignore[typeddict-item]
        # Use ``.get`` so that an explicit prompt that leaves out the optional
        # decoder side behaves like ``decoder_prompt=None`` instead of failing
        # with a bare KeyError.
        dec_prompt = prompt.get("decoder_prompt")  # type: ignore[typeddict-item]
    else:
        enc_prompt = prompt
        dec_prompt = None

    return EncoderDecoderDictPrompt(
        encoder_prompt=_parse_enc_prompt(enc_prompt),
        decoder_prompt=None if dec_prompt is None else _parse_dec_prompt(dec_prompt),
    )


def parse_model_prompt(model_config: "ModelConfig", prompt: object):
    """
    Normalize ``prompt`` to a dictionary using the parser for the model type.

    Models with ``model_config.is_encoder_decoder`` set go through
    ``parse_enc_dec_prompt``; all others go through ``parse_dec_only_prompt``.
    """
    parse = (
        parse_enc_dec_prompt
        if model_config.is_encoder_decoder
        else parse_dec_only_prompt
    )
    return parse(prompt)


class PromptComponents(NamedTuple):
    """
    The text, token IDs, and embeddings of a prompt, any of which may be absent.

    Every field defaults to ``None``. ``extract_prompt_components`` fills them
    from the ``"prompt"``, ``"prompt_token_ids"`` and ``"prompt_embeds"`` keys
    of a normalized prompt dictionary.
    """

    # Prompt text, or None if the prompt has no text.
    text: str | None = None
    # Prompt token IDs, or None if the prompt has no token IDs.
    token_ids: list[int] | None = None
    # Prompt embeddings, or None. The annotation is a string because torch is
    # only imported under TYPE_CHECKING in this module.
    embeds: "torch.Tensor | None" = None


248
249
def extract_target_prompt(model_config: "ModelConfig", prompt: object):
    """
    Return the normalized prompt that prompt metadata is read from.

    For encoder-decoder models this is the parsed encoder prompt; for
    decoder-only models it is the whole parsed prompt. Used by
    ``extract_prompt_components`` and ``extract_prompt_len``.
    """
    if model_config.is_encoder_decoder:
        enc_dec_prompt = parse_enc_dec_prompt(prompt)
        return enc_dec_prompt["encoder_prompt"]

    return parse_dec_only_prompt(prompt)

255
256
257

def extract_prompt_components(
    model_config: "ModelConfig",
    prompt: PromptType | EngineInput,
) -> PromptComponents:
    """
    Split the target prompt (see ``extract_target_prompt``) into its parts.

    Any of text, token IDs, or embeddings missing from the normalized prompt
    is reported as ``None``.
    """
    target = extract_target_prompt(model_config, prompt)

    text = target.get("prompt")
    token_ids = target.get("prompt_token_ids")
    embeds = target.get("prompt_embeds")

    return PromptComponents(text=text, token_ids=token_ids, embeds=embeds)


269
def extract_prompt_len(
    model_config: "ModelConfig",
    prompt: PromptType | EngineInput,
):
    """
    Return the length of the target prompt (see ``extract_target_prompt``).

    Delegates to ``length_from_prompt_token_ids_or_embeds``, passing the
    prompt's token IDs and embeddings (each ``None`` when absent).
    """
    target = extract_target_prompt(model_config, prompt)
    token_ids = target.get("prompt_token_ids")
    embeds = target.get("prompt_embeds")
    return length_from_prompt_token_ids_or_embeds(token_ids, embeds)