data.py 11.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Iterable
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast
6

7
import torch
8
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar
9

10
11
from vllm.sampling_params import SamplingParams

12
if TYPE_CHECKING:
    # Only imported for static analysis; this branch is skipped at runtime.
    from vllm.multimodal.inputs import (
        MultiModalDataDict,
        MultiModalInputs,
        MultiModalUUIDDict,
    )
else:
    # Runtime placeholders so that annotations and aliases in this module
    # which mention these names can still be evaluated.
    MultiModalDataDict = object
    MultiModalInputs = object
    MultiModalUUIDDict = object
22
23
24
25
26
27
28
29


class TextPrompt(TypedDict):
    """Schema for a text prompt.

    Only `prompt` is required; every other key is optional.
    """

    prompt: str
    """The input text to be tokenized before passing to the model."""

    multi_modal_data: NotRequired[MultiModalDataDict | None]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    mm_processor_kwargs: NotRequired[dict[str, Any] | None]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

    multi_modal_uuids: NotRequired[MultiModalUUIDDict]
    """
    Optional user-specified UUIDs for multimodal items, mapped by modality.
    Lists must match the number of items per modality and may contain `None`.
    For `None` entries, the hasher will compute IDs automatically; non-None
    entries override the default hashes for caching, and MUST be unique per
    multimodal item.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

58
59
60
61

class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt.

    Only `prompt_token_ids` is required; every other key is optional.
    """

    prompt_token_ids: list[int]
    """A list of token IDs to pass to the model."""

    prompt: NotRequired[str]
    """The prompt text corresponding to the token IDs, if available."""

    token_type_ids: NotRequired[list[int]]
    """A list of token type IDs to pass to the cross encoder model."""

    multi_modal_data: NotRequired[MultiModalDataDict | None]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    mm_processor_kwargs: NotRequired[dict[str, Any] | None]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

    multi_modal_uuids: NotRequired[MultiModalUUIDDict]
    """
    Optional user-specified UUIDs for multimodal items, mapped by modality.
    Lists must match the number of items per modality and may contain `None`.
    For `None` entries, the hasher will compute IDs automatically; non-None
    entries override the default hashes for caching.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

98

99
100
101
102
103
104
class EmbedsPrompt(TypedDict):
    """Schema for a prompt provided via token embeddings."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

105
106
107
108
109
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

110

111
112
113
114
115
116
117
118
119
120
class DataPrompt(TypedDict):
    """Generic input payload that is handled by IO processor plugins."""

    data: Any
    """The input payload itself."""

    data_format: str
    """The format that the input payload is in."""


121
SingletonPrompt: TypeAlias = str | TextPrompt | TokensPrompt | EmbedsPrompt
"""
Set of possible schemas for a single prompt:

- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])

Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]

A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be
employed as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
"""

145
146

def is_tokens_prompt(prompt: SingletonPrompt) -> TypeIs[TokensPrompt]:
    """Return whether `prompt` is a tokenized prompt.

    A prompt qualifies when it is a dict that contains the
    `"prompt_token_ids"` key and does not contain `"prompt_embeds"`.
    Non-dict prompts (e.g. plain strings) never qualify.
    """
    return (
        isinstance(prompt, dict)
        and "prompt_token_ids" in prompt
        and "prompt_embeds" not in prompt
    )
152
153
154


def is_embeds_prompt(prompt: SingletonPrompt) -> TypeIs[EmbedsPrompt]:
    """Return whether `prompt` is an embeddings prompt.

    A prompt qualifies when it is a dict that contains the
    `"prompt_embeds"` key and does not contain `"prompt_token_ids"`.
    Non-dict prompts (e.g. plain strings) never qualify.
    """
    return (
        isinstance(prompt, dict)
        and "prompt_token_ids" not in prompt
        and "prompt_embeds" in prompt
    )
160
161


162
163
164
165
166
167
# Covariant type parameters, each bounded by `SingletonPrompt` and falling
# back to `SingletonPrompt` when no explicit type argument is supplied.
_T1_co = TypeVar(
    "_T1_co",
    bound=SingletonPrompt,
    default=SingletonPrompt,
    covariant=True,
)
_T2_co = TypeVar(
    "_T2_co",
    bound=SingletonPrompt,
    default=SingletonPrompt,
    covariant=True,
)
168

169
170
171

# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
    """
    Represents an encoder/decoder model input prompt,
    comprising an explicit encoder prompt and a decoder prompt.

    The encoder and decoder prompts, respectively, may be formatted
    according to any of the
    [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] schemas,
    and are not required to have the same schema.

    Only the encoder prompt may have multi-modal data. mm_processor_kwargs
    should be at the top-level, and should not be set in the encoder/decoder
    prompts, since they are agnostic to the encoder/decoder.

    Note that an
    [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
    may not be used as an input to a decoder-only model,
    and that the `encoder_prompt` and `decoder_prompt`
    fields of this data structure themselves must be
    [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] instances.
    """

    encoder_prompt: _T1_co

    # The decoder prompt may be `None`.
    decoder_prompt: _T2_co | None

    mm_processor_kwargs: NotRequired[dict[str, Any]]
198

199

200
PromptType: TypeAlias = SingletonPrompt | ExplicitEncoderDecoderPrompt
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:

- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
- A single data structure containing both an encoder and a decoder prompt
  ([`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt])
"""


213
214
class TokenInputs(TypedDict):
    """Represents token-based inputs."""
215
216
217
218

    type: Literal["token"]
    """The type of inputs."""

219
    prompt_token_ids: list[int]
220
221
    """The token IDs of the prompt."""

222
223
224
225
226
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

227

228
def token_inputs(
    prompt_token_ids: list[int],
    cache_salt: str | None = None,
) -> TokenInputs:
    """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
    values.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        cache_salt: Optional cache salt. The `cache_salt` key is only set
            on the result when this is not `None`.

    Returns:
        A `TokenInputs` dict whose `type` is `"token"`.
    """
    inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)

    # Leave the optional key absent rather than storing an explicit None.
    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs


242
243
244
245
246
247
248
249
250
class EmbedsInputs(TypedDict):
    """Represents embeddings-based inputs."""

    type: Literal["embeds"]
    """The type of inputs."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

251
252
253
254
255
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

256

257
258
def embeds_inputs(
    prompt_embeds: torch.Tensor,
    cache_salt: str | None = None,
) -> EmbedsInputs:
    """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
    values.

    Args:
        prompt_embeds: The embeddings of the prompt.
        cache_salt: Optional cache salt. The `cache_salt` key is only set
            on the result when this is not `None`.

    Returns:
        An `EmbedsInputs` dict whose `type` is `"embeds"`.
    """
    inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)

    # Leave the optional key absent rather than storing an explicit None.
    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs


271
DecoderOnlyInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs
"""
The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""


279
class EncoderDecoderInputs(TypedDict):
    """
    The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they
    are passed to the model executor.

    This specifies the required data for encoder-decoder models.
    """

    encoder: TokenInputs | MultiModalInputs
    """The inputs for the encoder portion."""

    decoder: TokenInputs | MultiModalInputs
    """The inputs for the decoder portion."""
292

293

294
# NOTE(review): the docstring's `Sequence` cross-reference points at
# `collections.abc.Sequence`, which looks unintended - confirm the target.
SingletonInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs
"""
A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] which can be
passed to [`Sequence`][collections.abc.Sequence].
"""

300
ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs
"""
The outputs from [`vllm.inputs.preprocess.InputPreprocessor`][].
"""
304

305
306
# Invariant type parameters, each bounded by `SingletonPrompt` and falling
# back to `SingletonPrompt` when no explicit type argument is supplied.
_T1 = TypeVar(
    "_T1",
    bound=SingletonPrompt,
    default=SingletonPrompt,
)
_T2 = TypeVar(
    "_T2",
    bound=SingletonPrompt,
    default=SingletonPrompt,
)
307
308


309
310
def build_explicit_enc_dec_prompt(
    encoder_prompt: _T1,
    decoder_prompt: _T2 | None,
    mm_processor_kwargs: dict[str, Any] | None = None,
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
    """
    Build an
    [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
    from its parts.

    Args:
        encoder_prompt: The prompt for the encoder.
        decoder_prompt: The prompt for the decoder; may be `None`.
        mm_processor_kwargs: Optional multi-modal processor kwargs. When
            `None`, a new empty dict is stored instead, so the
            `mm_processor_kwargs` key is always present on the result.

    Returns:
        The assembled prompt dict. A dict passed as `mm_processor_kwargs`
        is stored as-is (not copied).
    """
    if mm_processor_kwargs is None:
        mm_processor_kwargs = {}
    return ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder_prompt,
        decoder_prompt=decoder_prompt,
        mm_processor_kwargs=mm_processor_kwargs,
    )
321
322
323
324


def zip_enc_dec_prompts(
    enc_prompts: Iterable[_T1],
    dec_prompts: Iterable[_T2 | None],
    mm_processor_kwargs: Iterable[dict[str, Any]] | dict[str, Any] | None = None,
) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
    """
    Zip encoder and decoder prompts together into a list of
    [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
    instances.

    `mm_processor_kwargs` may also be provided; if a dict is passed, the same
    dictionary will be used for every encoder/decoder prompt. If an iterable is
    provided, it will be zipped with the encoder/decoder prompts.

    If `mm_processor_kwargs` is `None`, a single empty dict is created and
    shared by every resulting prompt.

    The inputs are combined with the built-in `zip`, so the result is as long
    as the shortest input; extra items in longer inputs are ignored.
    """
    if mm_processor_kwargs is None:
        mm_processor_kwargs = cast(dict[str, Any], {})
    if isinstance(mm_processor_kwargs, dict):
        # A single dict: the very same object is attached to every prompt.
        return [
            build_explicit_enc_dec_prompt(
                encoder_prompt,
                decoder_prompt,
                cast(dict[str, Any], mm_processor_kwargs),
            )
            for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts)
        ]
    # An iterable of dicts: pair each one with its encoder/decoder prompts.
    return [
        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, mm_proc_kwargs)
        for (encoder_prompt, decoder_prompt, mm_proc_kwargs) in zip(
            enc_prompts, dec_prompts, mm_processor_kwargs
        )
    ]

355

356
357
def to_enc_dec_tuple_list(
    enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> list[tuple[_T1, _T2 | None]]:
    """
    Unpack encoder/decoder prompts into `(encoder_prompt, decoder_prompt)`
    tuples.

    Args:
        enc_dec_prompts: The prompts to unpack. Each must have the
            `encoder_prompt` and `decoder_prompt` keys.

    Returns:
        One tuple per input prompt, in input order. Any
        `mm_processor_kwargs` on the prompts are not included.
    """
    return [
        (enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"])
        for enc_dec_prompt in enc_dec_prompts
    ]
363
364
365
366
367
368
369
370
371
372
373
374


@dataclass
class StreamingInput:
    """One input item of a streaming generation request.

    Used together with generate() for multi-turn streaming sessions, in
    which the inputs are supplied through an async generator.
    """

    # The prompt for this input.
    prompt: PromptType
    # Sampling parameters for this input; defaults to None.
    sampling_params: SamplingParams | None = None