# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
4
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast
5

6
import torch
7
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar
8
9

# `vllm.multimodal.inputs` is only imported for static type checking. At
# runtime the names are bound to `object` so that the annotations below (which
# are evaluated when the TypedDict classes are created) still resolve without
# importing that module -- presumably to avoid an import cycle.
if TYPE_CHECKING:
    from vllm.multimodal.inputs import (
        MultiModalDataDict,
        MultiModalInputs,
        MultiModalUUIDDict,
    )
else:
    MultiModalDataDict = object
    MultiModalInputs = object
    MultiModalUUIDDict = object


class TextPrompt(TypedDict):
    """Schema for a text prompt."""

    prompt: str
    """The input text to be tokenized before passing to the model."""

    multi_modal_data: NotRequired[MultiModalDataDict | None]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    mm_processor_kwargs: NotRequired[dict[str, Any] | None]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc. for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

    multi_modal_uuids: NotRequired[MultiModalUUIDDict]
    """
    Optional user-specified UUIDs for multimodal items, mapped by modality.
    Lists must match the number of items per modality and may contain `None`.
    For `None` entries, the hasher will compute IDs automatically; non-None
    entries override the default hashes for caching, and MUST be unique per
    multimodal item.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt."""

    prompt_token_ids: list[int]
    """A list of token IDs to pass to the model."""

    prompt: NotRequired[str]
    """The prompt text corresponding to the token IDs, if available."""

    token_type_ids: NotRequired[list[int]]
    """A list of token type IDs to pass to the cross encoder model."""

    multi_modal_data: NotRequired[MultiModalDataDict | None]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    mm_processor_kwargs: NotRequired[dict[str, Any] | None]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc. for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

    multi_modal_uuids: NotRequired[MultiModalUUIDDict]
    """
    Optional user-specified UUIDs for multimodal items, mapped by modality.
    Lists must match the number of items per modality and may contain `None`.
    For `None` entries, the hasher will compute IDs automatically; non-None
    entries override the default hashes for caching, and MUST be unique per
    multimodal item.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

96
97
98
99
100
101
class EmbedsPrompt(TypedDict):
    """Schema for a prompt provided via token embeddings."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

102
103
104
105
106
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

107

class DataPrompt(TypedDict):
    """Represents generic inputs handled by IO processor plugins."""

    data: Any
    """The input data."""

    data_format: str
    """The input data format."""


SingletonPrompt: TypeAlias = str | TextPrompt | TokensPrompt | EmbedsPrompt
"""
Set of possible schemas for a single prompt:

- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])

Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt].

A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be
employed as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt].
"""

def is_tokens_prompt(prompt: SingletonPrompt) -> TypeIs[TokensPrompt]:
    """Return whether `prompt` is a [`TokensPrompt`][vllm.inputs.data.TokensPrompt].

    A dict prompt qualifies when it has a `prompt_token_ids` key and no
    `prompt_embeds` key. Plain strings never qualify.
    """
    return (
        isinstance(prompt, dict)
        and "prompt_token_ids" in prompt
        and "prompt_embeds" not in prompt
    )


def is_embeds_prompt(prompt: SingletonPrompt) -> TypeIs[EmbedsPrompt]:
    """Return whether `prompt` is an [`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt].

    A dict prompt qualifies when it has a `prompt_embeds` key and no
    `prompt_token_ids` key. Plain strings never qualify.
    """
    return (
        isinstance(prompt, dict)
        and "prompt_token_ids" not in prompt
        and "prompt_embeds" in prompt
    )


# Covariant type variables for the encoder (`_T1_co`) and decoder (`_T2_co`)
# prompt types of `ExplicitEncoderDecoderPrompt`. Both are bounded by, and
# default to, the full `SingletonPrompt` union.
_T1_co = TypeVar(
    "_T1_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True
)
_T2_co = TypeVar(
    "_T2_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True
)

# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
    """
    Represents an encoder/decoder model input prompt,
    comprising an explicit encoder prompt and a decoder prompt.

    The encoder and decoder prompts, respectively, may be formatted
    according to any of the
    [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] schemas,
    and are not required to have the same schema.

    Only the encoder prompt may have multi-modal data. mm_processor_kwargs
    should be at the top-level, and should not be set in the encoder/decoder
    prompts, since they are agnostic to the encoder/decoder.

    Note that an
    [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
    may not be used as an input to a decoder-only model,
    and that the `encoder_prompt` and `decoder_prompt`
    fields of this data structure themselves must be
    [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] instances.
    """

    # The prompt for the encoder.
    encoder_prompt: _T1_co

    # The prompt for the decoder; may be `None` when no explicit decoder
    # prompt is supplied.
    decoder_prompt: _T2_co | None

    # Top-level multi-modal processor kwargs (see the class docstring).
    mm_processor_kwargs: NotRequired[dict[str, Any]]

PromptType: TypeAlias = SingletonPrompt | ExplicitEncoderDecoderPrompt
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:

- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
- A single data structure containing both an encoder and a decoder prompt
  ([`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt])
"""


210
211
class TokenInputs(TypedDict):
    """Represents token-based inputs."""
212
213
214
215

    type: Literal["token"]
    """The type of inputs."""

216
    prompt_token_ids: list[int]
217
218
    """The token IDs of the prompt."""

219
220
221
222
223
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

224

def token_inputs(
    prompt_token_ids: list[int],
    cache_salt: str | None = None,
) -> TokenInputs:
    """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
    values.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        cache_salt: Optional cache salt for prefix caching. When `None`, the
            `cache_salt` key is left out of the result entirely.

    Returns:
        A `TokenInputs` dict whose `type` is `"token"`.
    """
    inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)

    if cache_salt is not None:
        # `cache_salt` is NotRequired: omit it rather than storing None.
        inputs["cache_salt"] = cache_salt

    return inputs


239
240
241
242
243
244
245
246
247
class EmbedsInputs(TypedDict):
    """Represents embeddings-based inputs."""

    type: Literal["embeds"]
    """The type of inputs."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

248
249
250
251
252
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

253

def embeds_inputs(
    prompt_embeds: torch.Tensor,
    cache_salt: str | None = None,
) -> EmbedsInputs:
    """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
    values.

    Args:
        prompt_embeds: The embeddings of the prompt (stored as-is, not copied).
        cache_salt: Optional cache salt for prefix caching. When `None`, the
            `cache_salt` key is left out of the result entirely.

    Returns:
        An `EmbedsInputs` dict whose `type` is `"embeds"`.
    """
    inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)

    if cache_salt is not None:
        # `cache_salt` is NotRequired: omit it rather than storing None.
        inputs["cache_salt"] = cache_salt

    return inputs


DecoderOnlyInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs
"""
The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""


class EncoderDecoderInputs(TypedDict):
    """
    The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they
    are passed to the model executor.

    This specifies the required data for encoder-decoder models.
    """

    encoder: TokenInputs | MultiModalInputs
    """The inputs for the encoder portion."""

    decoder: TokenInputs | MultiModalInputs
    """The inputs for the decoder portion."""

SingletonInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs
"""
A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt]: the
token-based, embeddings-based, or multi-modal inputs derived from a single
prompt.
"""

ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs
"""
The outputs from [`vllm.inputs.preprocess.InputPreprocessor`][].
"""

# Invariant counterparts of `_T1_co` / `_T2_co`. The helper functions below
# use them to carry the concrete encoder/decoder prompt types through to
# their `ExplicitEncoderDecoderPrompt` results.
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)


def build_explicit_enc_dec_prompt(
    encoder_prompt: _T1,
    decoder_prompt: _T2 | None,
    mm_processor_kwargs: dict[str, Any] | None = None,
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
    """Build an
    [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt].

    Args:
        encoder_prompt: The prompt for the encoder.
        decoder_prompt: The prompt for the decoder, or `None`.
        mm_processor_kwargs: Multi-modal processor kwargs to attach at the
            top level. A given dict is stored as-is (not copied). `None` is
            replaced by a fresh empty dict, so the key is always present.

    Returns:
        The assembled encoder/decoder prompt.
    """
    if mm_processor_kwargs is None:
        mm_processor_kwargs = {}
    return ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder_prompt,
        decoder_prompt=decoder_prompt,
        mm_processor_kwargs=mm_processor_kwargs,
    )


def zip_enc_dec_prompts(
    enc_prompts: Iterable[_T1],
    dec_prompts: Iterable[_T2 | None],
    mm_processor_kwargs: Iterable[dict[str, Any]] | dict[str, Any] | None = None,
) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
    """
    Zip encoder and decoder prompts together into a list of
    [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
    instances.

    `mm_processor_kwargs` may also be provided; if a dict is passed, the same
    dictionary will be used for every encoder/decoder prompt. If an iterable is
    provided, it will be zipped with the encoder/decoder prompts. If it is
    omitted, each prompt receives its own empty dict.

    As with the built-in `zip`, the result stops at the shortest input.
    """
    # A dict is itself iterable, so it must be recognised before falling
    # through to the per-prompt iterable case.
    if mm_processor_kwargs is None or isinstance(mm_processor_kwargs, dict):
        # For None, `build_explicit_enc_dec_prompt` creates a fresh `{}` per
        # prompt, so the results never alias one shared empty dict.
        return [
            build_explicit_enc_dec_prompt(
                encoder_prompt, decoder_prompt, mm_processor_kwargs
            )
            for encoder_prompt, decoder_prompt in zip(enc_prompts, dec_prompts)
        ]
    return [
        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, mm_proc_kwargs)
        for encoder_prompt, decoder_prompt, mm_proc_kwargs in zip(
            enc_prompts, dec_prompts, mm_processor_kwargs
        )
    ]

def to_enc_dec_tuple_list(
    enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> list[tuple[_T1, _T2 | None]]:
    """Unpack encoder/decoder prompts into `(encoder_prompt, decoder_prompt)`
    tuples.

    Any top-level `mm_processor_kwargs` on the input prompts are not carried
    over to the result.
    """
    return [
        (enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"])
        for enc_dec_prompt in enc_dec_prompts
    ]