# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast
5

6
import torch
7
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar
8
9

if TYPE_CHECKING:
    # Needed only by static type checkers: these names appear in this module
    # solely as quoted (forward-reference) annotations, so the import is
    # skipped at runtime.
    from vllm.multimodal.inputs import (
        MultiModalDataDict,
        MultiModalInputs,
        MultiModalUUIDDict,
    )
15
16
17
18
19
20
21
22


class TextPrompt(TypedDict):
    """Schema for a text prompt."""

    prompt: str
    """The input text to be tokenized before passing to the model."""

23
    multi_modal_data: NotRequired["MultiModalDataDict"]
24
25
26
27
28
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

29
    mm_processor_kwargs: NotRequired[dict[str, Any]]
30
31
32
33
34
35
36
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

37
38
39
40
41
42
43
44
45
    multi_modal_uuids: NotRequired["MultiModalUUIDDict"]
    """
    Optional user-specified UUIDs for multimodal items, mapped by modality.
    Lists must match the number of items per modality and may contain `None`.
    For `None` entries, the hasher will compute IDs automatically; non-None
    entries override the default hashes for caching, and MUST be unique per
    multimodal item.
    """

46
47
48
49
50
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

51
52
53
54

class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt."""

55
    prompt_token_ids: list[int]
56
57
    """A list of token IDs to pass to the model."""

58
59
60
    prompt: NotRequired[str]
    """The prompt text corresponding to the token IDs, if available."""

61
    token_type_ids: NotRequired[list[int]]
62
63
    """A list of token type IDs to pass to the cross encoder model."""

64
    multi_modal_data: NotRequired["MultiModalDataDict"]
65
    """
66
    Optional multi-modal data to pass to the model,
67
68
69
    if the model supports it.
    """

70
    mm_processor_kwargs: NotRequired[dict[str, Any]]
71
    """
72
    Optional multi-modal processor kwargs to be forwarded to the
73
74
75
76
77
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

78
79
80
81
82
83
84
85
    multi_modal_uuids: NotRequired["MultiModalUUIDDict"]
    """
    Optional user-specified UUIDs for multimodal items, mapped by modality.
    Lists must match the number of items per modality and may contain `None`.
    For `None` entries, the hasher will compute IDs automatically; non-None
    entries override the default hashes for caching.
    """

86
87
88
89
90
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

91

92
93
94
95
96
97
class EmbedsPrompt(TypedDict):
    """Schema for a prompt provided via token embeddings."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

98
99
100
101
102
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

103

104
105
106
107
108
109
110
111
112
113
class DataPrompt(TypedDict):
    """Schema for generic inputs that are handled by IO processor plugins.

    Both keys are required.
    """

    data: Any
    """The input payload; no particular type is enforced (`Any`)."""

    data_format: str
    """A string naming the format of `data`."""


114
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""
Set of possible schemas for a single prompt:

- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])

Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]

A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be
employed as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
"""

138
139

def is_tokens_prompt(prompt: SingletonPrompt) -> TypeIs[TokensPrompt]:
    """Return whether `prompt` is a tokenized prompt.

    A prompt qualifies when it is a `dict` that contains the
    `"prompt_token_ids"` key and does not contain `"prompt_embeds"`.
    A plain `str` prompt never qualifies.
    """
    return (
        isinstance(prompt, dict)
        and "prompt_token_ids" in prompt
        and "prompt_embeds" not in prompt
    )
145
146
147


def is_embeds_prompt(prompt: SingletonPrompt) -> TypeIs[EmbedsPrompt]:
    """Return whether `prompt` is an embeddings prompt.

    A prompt qualifies when it is a `dict` that contains the
    `"prompt_embeds"` key and does not contain `"prompt_token_ids"`.
    A plain `str` prompt never qualifies.
    """
    return (
        isinstance(prompt, dict)
        and "prompt_token_ids" not in prompt
        and "prompt_embeds" in prompt
    )
153
154


155
156
157
158
159
160
# Covariant type variables used by `ExplicitEncoderDecoderPrompt`:
# `_T1_co` types its `encoder_prompt` field and `_T2_co` its `decoder_prompt`
# field. Both are bounded by, and default to, `SingletonPrompt`.
_T1_co = TypeVar(
    "_T1_co",
    bound=SingletonPrompt,
    default=SingletonPrompt,
    covariant=True,
)
_T2_co = TypeVar(
    "_T2_co",
    bound=SingletonPrompt,
    default=SingletonPrompt,
    covariant=True,
)
161

162
163
164

# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
    """
    Represents an encoder/decoder model input prompt,
    comprising an explicit encoder prompt and a decoder prompt.

    The encoder and decoder prompts, respectively, may be formatted
    according to any of the
    [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] schemas,
    and are not required to have the same schema.

    Only the encoder prompt may have multi-modal data. mm_processor_kwargs
    should be at the top-level, and should not be set in the encoder/decoder
    prompts, since they are agnostic to the encoder/decoder.

    Note that an
    [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
    may not be used as an input to a decoder-only model,
    and that the `encoder_prompt` and `decoder_prompt`
    fields of this data structure themselves must be
    [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] instances.
    """

    encoder_prompt: _T1_co
    """The prompt for the encoder."""

    decoder_prompt: Optional[_T2_co]
    """The prompt for the decoder; the key is required but may be `None`."""

    mm_processor_kwargs: NotRequired[dict[str, Any]]
    """
    Optional multi-modal processor kwargs, set here at the top level rather
    than inside the encoder/decoder prompts.
    """
191

192

193
PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:

- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
- A single data structure containing both an encoder and a decoder prompt
  ([`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt])
"""


206
207
class TokenInputs(TypedDict):
    """Represents token-based inputs."""
208
209
210
211

    type: Literal["token"]
    """The type of inputs."""

212
    prompt_token_ids: list[int]
213
214
    """The token IDs of the prompt."""

215
216
217
218
219
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

220

221
def token_inputs(
    prompt_token_ids: list[int],
    cache_salt: Optional[str] = None,
) -> TokenInputs:
    """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
    values.

    Args:
        prompt_token_ids: The token IDs of the prompt; stored as given.
        cache_salt: Optional cache salt. The `"cache_salt"` key is only
            added to the result when this is not `None`.

    Returns:
        A `TokenInputs` dict whose `type` is `"token"`.
    """
    inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)

    # Leave the optional key absent rather than storing `None`.
    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs


235
236
237
238
239
240
241
242
243
class EmbedsInputs(TypedDict):
    """Represents embeddings-based inputs."""

    type: Literal["embeds"]
    """The type of inputs."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

244
245
246
247
248
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

249

250
251
252
253
def embeds_inputs(
    prompt_embeds: torch.Tensor,
    cache_salt: Optional[str] = None,
) -> EmbedsInputs:
    """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
    values.

    Args:
        prompt_embeds: The embeddings of the prompt; stored as given.
        cache_salt: Optional cache salt. The `"cache_salt"` key is only
            added to the result when this is not `None`.

    Returns:
        An `EmbedsInputs` dict whose `type` is `"embeds"`.
    """
    inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)

    # Leave the optional key absent rather than storing `None`.
    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs


DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
"""
The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""


272
class EncoderDecoderInputs(TypedDict):
    """
    The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they
    are passed to the model executor.

    This specifies the required data for encoder-decoder models.
    Both the `encoder` and `decoder` keys are required.
    """

    encoder: Union[TokenInputs, "MultiModalInputs"]
    """The inputs for the encoder portion."""

    decoder: Union[TokenInputs, "MultiModalInputs"]
    """The inputs for the decoder portion."""
285

286

287
SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
"""
A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] which can be
passed to [`Sequence`][collections.abc.Sequence].
"""

ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
"""
The outputs from [`vllm.inputs.preprocess.InputPreprocessor`][].
"""
297

298
299
# Invariant type variables for the helper functions below: `_T1` is the
# encoder prompt type and `_T2` the decoder prompt type. Both are bounded by,
# and default to, `SingletonPrompt`.
_T1 = TypeVar(
    "_T1",
    bound=SingletonPrompt,
    default=SingletonPrompt,
)
_T2 = TypeVar(
    "_T2",
    bound=SingletonPrompt,
    default=SingletonPrompt,
)
300
301


302
303
304
def build_explicit_enc_dec_prompt(
    encoder_prompt: _T1,
    decoder_prompt: Optional[_T2],
    mm_processor_kwargs: Optional[dict[str, Any]] = None,
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
    """Build an
    [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
    from its parts.

    Args:
        encoder_prompt: The prompt for the encoder.
        decoder_prompt: The prompt for the decoder, or `None`.
        mm_processor_kwargs: Top-level multi-modal processor kwargs. When
            `None`, a new empty dict is stored; otherwise the given dict is
            stored as-is (not copied).

    Returns:
        A prompt dict that always contains the `mm_processor_kwargs` key.
    """
    if mm_processor_kwargs is None:
        # A fresh dict per call, so results never share state by accident.
        mm_processor_kwargs = {}
    return ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder_prompt,
        decoder_prompt=decoder_prompt,
        mm_processor_kwargs=mm_processor_kwargs,
    )
314
315
316
317
318


def zip_enc_dec_prompts(
    enc_prompts: Iterable[_T1],
    dec_prompts: Iterable[Optional[_T2]],
    mm_processor_kwargs: Optional[
        Union[Iterable[dict[str, Any]], dict[str, Any]]
    ] = None,
) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
    """
    Zip encoder and decoder prompts together into a list of
    [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
    instances.

    ``mm_processor_kwargs`` may also be provided; if a dict is passed, the same
    dictionary will be used for every encoder/decoder prompt. If an iterable is
    provided, it will be zipped with the encoder/decoder prompts. If it is
    omitted, every prompt receives its own new empty dictionary.

    The inputs are combined with the built-in ``zip``, so the result is as
    long as the shortest input; extra items in longer inputs are ignored.
    """
    if mm_processor_kwargs is None or isinstance(mm_processor_kwargs, dict):
        # A dict is deliberately shared by every prompt (see docstring).
        # `None` is forwarded unchanged so the builder creates a separate
        # empty dict per prompt instead of aliasing one `{}` across them all.
        # The cast pins the narrowed type for the type checker.
        shared_kwargs = cast(Optional[dict[str, Any]], mm_processor_kwargs)
        return [
            build_explicit_enc_dec_prompt(
                encoder_prompt,
                decoder_prompt,
                shared_kwargs,
            )
            for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts)
        ]
    return [
        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, mm_proc_kwargs)
        for (encoder_prompt, decoder_prompt, mm_proc_kwargs) in zip(
            enc_prompts, dec_prompts, mm_processor_kwargs
        )
    ]

350

351
352
def to_enc_dec_tuple_list(
    enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> list[tuple[_T1, Optional[_T2]]]:
    """Unpack explicit encoder/decoder prompts into tuples.

    Args:
        enc_dec_prompts: Prompts to unpack; each must have the
            `"encoder_prompt"` and `"decoder_prompt"` keys.

    Returns:
        One `(encoder_prompt, decoder_prompt)` tuple per input prompt, in
        input order. Any `mm_processor_kwargs` entry is not included.
    """
    return [
        (enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"])
        for enc_dec_prompt in enc_dec_prompts
    ]