# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast
5

6
import torch
7
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar
8
9

if TYPE_CHECKING:
    # Imported for static type checking only: this module refers to these
    # names exclusively through quoted (string) annotations, so the import
    # is never executed at runtime.
    from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
11
12
13
14
15
16
17
18


class TextPrompt(TypedDict):
    """Schema for a text prompt."""

    prompt: str
    """The input text to be tokenized before passing to the model."""

19
    multi_modal_data: NotRequired["MultiModalDataDict"]
20
21
22
23
24
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

25
    mm_processor_kwargs: NotRequired[dict[str, Any]]
26
27
28
29
30
31
32
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

33
34
35
36
37
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

38
39
40
41

class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt."""

42
    prompt_token_ids: list[int]
43
44
    """A list of token IDs to pass to the model."""

45
    token_type_ids: NotRequired[list[int]]
46
47
    """A list of token type IDs to pass to the cross encoder model."""

48
    multi_modal_data: NotRequired["MultiModalDataDict"]
49
    """
50
    Optional multi-modal data to pass to the model,
51
52
53
    if the model supports it.
    """

54
    mm_processor_kwargs: NotRequired[dict[str, Any]]
55
    """
56
    Optional multi-modal processor kwargs to be forwarded to the
57
58
59
60
61
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

62
63
64
65
66
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

67

68
69
70
71
72
73
class EmbedsPrompt(TypedDict):
    """Schema for a prompt provided via token embeddings."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

74
75
76
77
78
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

79
80

SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""
Set of possible schemas for a single prompt:

- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])

Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]

A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be
employed as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
"""

104
105
106
107
108
109
110
111
112
113
114

def is_tokens_prompt(prompt: SingletonPrompt) -> TypeIs[TokensPrompt]:
    """Return whether `prompt` is a dict-based tokenized prompt.

    True only for a `dict` that contains the key `"prompt_token_ids"` and
    does not contain the key `"prompt_embeds"`.
    """
    # Plain strings (and any other non-dict value) are never token prompts.
    if not isinstance(prompt, dict):
        return False
    return "prompt_token_ids" in prompt and "prompt_embeds" not in prompt


def is_embeds_prompt(prompt: SingletonPrompt) -> TypeIs[EmbedsPrompt]:
    """Return whether `prompt` is a dict-based embeddings prompt.

    True only for a `dict` that contains the key `"prompt_embeds"` and
    does not contain the key `"prompt_token_ids"`.
    """
    # Plain strings (and any other non-dict value) are never embeds prompts.
    if not isinstance(prompt, dict):
        return False
    return "prompt_token_ids" not in prompt and "prompt_embeds" in prompt


115
# Covariant type variables for the encoder (`_T1_co`) and decoder (`_T2_co`)
# prompt types; both are bound to, and default to, `SingletonPrompt`.
_T1_co = TypeVar("_T1_co",
                 bound=SingletonPrompt,
                 default=SingletonPrompt,
                 covariant=True)
_T2_co = TypeVar("_T2_co",
                 bound=SingletonPrompt,
                 default=SingletonPrompt,
                 covariant=True)
123

124
125
126

# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
    """
    Represents an encoder/decoder model input prompt,
    comprising an explicit encoder prompt and a decoder prompt.

    The encoder and decoder prompts, respectively, may be formatted
    according to any of the
    [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] schemas,
    and are not required to have the same schema.

    Only the encoder prompt may have multi-modal data. mm_processor_kwargs
    should be at the top-level, and should not be set in the encoder/decoder
    prompts, since they are agnostic to the encoder/decoder.

    Note that an
    [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
    may not be used as an input to a decoder-only model,
    and that the `encoder_prompt` and `decoder_prompt`
    fields of this data structure themselves must be
    [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] instances.
    """

    encoder_prompt: _T1_co
    """The prompt for the encoder."""

    decoder_prompt: Optional[_T2_co]
    """The prompt for the decoder; the type permits `None`."""

    mm_processor_kwargs: NotRequired[dict[str, Any]]
    """
    Optional multi-modal processor kwargs, set at the top level rather than
    in the encoder/decoder prompts.
    """
153

154

155
PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:

- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
- A single data structure containing both an encoder and a decoder prompt
  ([`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt])
"""


168
169
class TokenInputs(TypedDict):
    """Represents token-based inputs."""
170
171
172
173

    type: Literal["token"]
    """The type of inputs."""

174
    prompt_token_ids: list[int]
175
176
    """The token IDs of the prompt."""

177
    token_type_ids: NotRequired[list[int]]
178
179
    """The token type IDs of the prompt."""

180
    prompt: NotRequired[str]
181
182
183
184
    """
    The original prompt text corresponding to the token IDs, if available.
    """

185
186
187
188
189
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

190

191
def token_inputs(
    prompt_token_ids: list[int],
    token_type_ids: Optional[list[int]] = None,
    prompt: Optional[str] = None,
    cache_salt: Optional[str] = None,
) -> TokenInputs:
    """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
    values.

    `type` is always set to `"token"`. Each optional field is added to the
    result only when its argument is not `None`; otherwise the key is left
    out entirely.
    """
    inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)

    if prompt is not None:
        inputs["prompt"] = prompt
    if token_type_ids is not None:
        inputs["token_type_ids"] = token_type_ids
    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs


211
212
213
214
215
216
217
218
219
class EmbedsInputs(TypedDict):
    """Represents embeddings-based inputs."""

    type: Literal["embeds"]
    """The type of inputs."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

220
221
222
223
224
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

225

226
227
228
229
def embeds_inputs(
    prompt_embeds: torch.Tensor,
    cache_salt: Optional[str] = None,
) -> EmbedsInputs:
    """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
    values.

    `type` is always set to `"embeds"`. `cache_salt` is added to the result
    only when it is not `None`; otherwise the key is left out entirely.
    """
    inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)

    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs


DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
"""
The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""


248
class EncoderDecoderInputs(TypedDict):
    """
    The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they
    are passed to the model executor.

    This specifies the required data for encoder-decoder models.
    """

    encoder: Union[TokenInputs, "MultiModalInputs"]
    """The inputs for the encoder portion."""

    decoder: Union[TokenInputs, "MultiModalInputs"]
    """The inputs for the decoder portion."""
261

262

263
SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
"""
A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] which can be
passed to [`vllm.sequence.Sequence`][].
"""

ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
"""
The outputs from [`vllm.inputs.preprocess.InputPreprocessor`][].
"""
273

274
275
# Type variables for the encoder (`_T1`) and decoder (`_T2`) prompt types in
# the helper functions below; both are bound to, and default to,
# `SingletonPrompt`.
_T1 = TypeVar(
    "_T1",
    bound=SingletonPrompt,
    default=SingletonPrompt,
)
_T2 = TypeVar(
    "_T2",
    bound=SingletonPrompt,
    default=SingletonPrompt,
)
276
277


278
279
280
def build_explicit_enc_dec_prompt(
    encoder_prompt: _T1,
    decoder_prompt: Optional[_T2],
    mm_processor_kwargs: Optional[dict[str, Any]] = None,
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
    """
    Build an
    [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
    from an encoder prompt and a decoder prompt.

    If `mm_processor_kwargs` is `None`, a new empty dict is stored instead,
    so the `mm_processor_kwargs` key is always present in the result. A
    dict passed by the caller is stored as-is (not copied).
    """
    if mm_processor_kwargs is None:
        mm_processor_kwargs = {}
    return ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder_prompt,
        decoder_prompt=decoder_prompt,
        mm_processor_kwargs=mm_processor_kwargs,
    )
290
291
292
293
294


def zip_enc_dec_prompts(
    enc_prompts: Iterable[_T1],
    dec_prompts: Iterable[Optional[_T2]],
    mm_processor_kwargs: Optional[Union[Iterable[dict[str, Any]],
                                        dict[str, Any]]] = None,
) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
    """
    Zip encoder and decoder prompts together into a list of
    [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
    instances.

    `mm_processor_kwargs` may also be provided; if a dict is passed, the same
    dictionary will be used for every encoder/decoder prompt. If an iterable is
    provided, it will be zipped with the encoder/decoder prompts. If it is
    `None`, a single empty dict is created and shared by every prompt.

    The inputs are combined with the built-in `zip`, so the result has as
    many entries as the shortest input; surplus items are ignored.
    """
    if mm_processor_kwargs is None:
        mm_processor_kwargs = cast(dict[str, Any], {})
    if isinstance(mm_processor_kwargs, dict):
        # One dict object is deliberately shared across all prompts; the
        # cast is loop-invariant, so it is done once up front.
        shared_kwargs = cast(dict[str, Any], mm_processor_kwargs)
        return [
            build_explicit_enc_dec_prompt(
                encoder_prompt,
                decoder_prompt,
                shared_kwargs,
            ) for (encoder_prompt,
                   decoder_prompt) in zip(enc_prompts, dec_prompts)
        ]
    # Otherwise treat it as an iterable yielding one kwargs dict per prompt.
    return [
        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
                                      mm_proc_kwargs)
        for (encoder_prompt, decoder_prompt, mm_proc_kwargs
             ) in zip(enc_prompts, dec_prompts, mm_processor_kwargs)
    ]

325

326
327
def to_enc_dec_tuple_list(
    enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> list[tuple[_T1, Optional[_T2]]]:
    """
    Convert an iterable of
    [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
    into a list of `(encoder_prompt, decoder_prompt)` tuples.

    Only the `encoder_prompt` and `decoder_prompt` fields are read;
    `mm_processor_kwargs` is not included in the output.
    """
    return [(enc_dec_prompt["encoder_prompt"],
             enc_dec_prompt["decoder_prompt"])
            for enc_dec_prompt in enc_dec_prompts]