data.py 9.51 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast
4

5
import torch
6
from typing_extensions import NotRequired, TypedDict, TypeVar
7
8

if TYPE_CHECKING:
9
    from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
10
11
12
13
14
15
16
17


class TextPrompt(TypedDict):
    """Schema for a text prompt."""

    prompt: str
    """The input text to be tokenized before passing to the model."""

18
    multi_modal_data: NotRequired["MultiModalDataDict"]
19
20
21
22
23
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

24
    mm_processor_kwargs: NotRequired[dict[str, Any]]
25
26
27
28
29
30
31
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

32
33
34
35
36
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

37
38
39
40

class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt."""

41
    prompt_token_ids: list[int]
42
43
    """A list of token IDs to pass to the model."""

44
    token_type_ids: NotRequired[list[int]]
45
46
    """A list of token type IDs to pass to the cross encoder model."""

47
    multi_modal_data: NotRequired["MultiModalDataDict"]
48
    """
49
    Optional multi-modal data to pass to the model,
50
51
52
    if the model supports it.
    """

53
    mm_processor_kwargs: NotRequired[dict[str, Any]]
54
    """
55
    Optional multi-modal processor kwargs to be forwarded to the
56
57
58
59
60
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

61
62
63
64
65
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

66

67
68
69
70
71
72
class EmbedsPrompt(TypedDict):
    """Schema for a prompt provided via token embeddings."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

73
74
75
76
77
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

78
79

SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""
Set of possible schemas for a single prompt:

- A text prompt ({class}`str` or {class}`TextPrompt`)
- A tokenized prompt ({class}`TokensPrompt`)
- An embeddings prompt ({class}`EmbedsPrompt`)

Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. {class}`ExplicitEncoderDecoderPrompt`

A prompt of type {class}`SingletonPrompt` may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e. {class}`ExplicitEncoderDecoderPrompt`
"""

101
# Covariant type variables for the encoder (`_T1_co`) and decoder (`_T2_co`)
# prompt slots of `ExplicitEncoderDecoderPrompt`. Each is bounded by, and
# defaults to, `SingletonPrompt`.
_T1_co = TypeVar("_T1_co",
                 bound=SingletonPrompt,
                 default=SingletonPrompt,
                 covariant=True)
_T2_co = TypeVar("_T2_co",
                 bound=SingletonPrompt,
                 default=SingletonPrompt,
                 covariant=True)
109

110
111
112

# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
    """
    Represents an encoder/decoder model input prompt,
    comprising an explicit encoder prompt and a decoder prompt.

    The encoder and decoder prompts, respectively, may be formatted
    according to any of the {class}`SingletonPrompt` schemas,
    and are not required to have the same schema.

    Only the encoder prompt may have multi-modal data. mm_processor_kwargs
    should be at the top-level, and should not be set in the encoder/decoder
    prompts, since they are agnostic to the encoder/decoder.

    Note that an {class}`ExplicitEncoderDecoderPrompt` may not
    be used as an input to a decoder-only model,
    and that the `encoder_prompt` and `decoder_prompt`
    fields of this data structure themselves must be
    {class}`SingletonPrompt` instances.
    """

    # Prompt for the encoder; always required.
    encoder_prompt: _T1_co

    # Prompt for the decoder; the key is required but its value may be None.
    decoder_prompt: Optional[_T2_co]

    # Optional top-level multi-modal processor kwargs (see class docstring).
    mm_processor_kwargs: NotRequired[dict[str, Any]]
137

138

139
PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:

- A text prompt ({class}`str` or {class}`TextPrompt`)
- A tokenized prompt ({class}`TokensPrompt`)
- An embeddings prompt ({class}`EmbedsPrompt`)
- A single data structure containing both an encoder and a decoder prompt
  ({class}`ExplicitEncoderDecoderPrompt`)
"""


152
153
class TokenInputs(TypedDict):
    """Represents token-based inputs."""
154
155
156
157

    type: Literal["token"]
    """The type of inputs."""

158
    prompt_token_ids: list[int]
159
160
    """The token IDs of the prompt."""

161
    token_type_ids: NotRequired[list[int]]
162
163
    """The token type IDs of the prompt."""

164
    prompt: NotRequired[str]
165
166
167
168
    """
    The original prompt text corresponding to the token IDs, if available.
    """

169
170
171
172
173
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

174

175
def token_inputs(
    prompt_token_ids: list[int],
    token_type_ids: Optional[list[int]] = None,
    prompt: Optional[str] = None,
    cache_salt: Optional[str] = None,
) -> TokenInputs:
    """Construct {class}`TokenInputs` from optional values.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        token_type_ids: The token type IDs of the prompt, if any.
        prompt: The original prompt text corresponding to the token IDs,
            if available.
        cache_salt: Optional cache salt to be used for prefix caching.

    Returns:
        A {class}`TokenInputs` dict with `type` set to `"token"`. Each
        optional key is present only when its argument is not `None`.
    """
    inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)

    # Leave optional keys absent rather than storing explicit `None` values.
    if prompt is not None:
        inputs["prompt"] = prompt
    if token_type_ids is not None:
        inputs["token_type_ids"] = token_type_ids
    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs


194
195
196
197
198
199
200
201
202
class EmbedsInputs(TypedDict):
    """Represents embeddings-based inputs."""

    type: Literal["embeds"]
    """The type of inputs."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

203
204
205
206
207
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

208

209
210
211
212
def embeds_inputs(
    prompt_embeds: torch.Tensor,
    cache_salt: Optional[str] = None,
) -> EmbedsInputs:
    """Construct {class}`EmbedsInputs` from optional values.

    Args:
        prompt_embeds: The embeddings of the prompt.
        cache_salt: Optional cache salt to be used for prefix caching.

    Returns:
        An {class}`EmbedsInputs` dict with `type` set to `"embeds"`. The
        `cache_salt` key is present only when the argument is not `None`.
    """
    inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)

    # Leave the optional key absent rather than storing an explicit `None`.
    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs


DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
"""
The inputs in {class}`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""


230
class EncoderDecoderInputs(TypedDict):
    """
    The inputs in {class}`~vllm.LLMEngine` before they are
    passed to the model executor.

    This specifies the required data for encoder-decoder models.
    """

    encoder: Union[TokenInputs, "MultiModalInputs"]
    """The inputs for the encoder portion."""

    decoder: Union[TokenInputs, "MultiModalInputs"]
    """The inputs for the decoder portion."""
242

243

244
SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
"""
A processed {class}`SingletonPrompt` which can be passed to
{class}`vllm.sequence.Sequence`.
"""

ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
"""
The inputs to {data}`vllm.inputs.InputProcessor`.
"""
254

255
256
# Invariant counterparts of the covariant prompt type variables: `_T1` stands
# for the encoder prompt type and `_T2` for the decoder prompt type in the
# helper-function signatures. Each is bounded by, and defaults to,
# `SingletonPrompt`.
_T1 = TypeVar(
    "_T1",
    bound=SingletonPrompt,
    default=SingletonPrompt,
)
_T2 = TypeVar(
    "_T2",
    bound=SingletonPrompt,
    default=SingletonPrompt,
)
257
258


259
260
261
def build_explicit_enc_dec_prompt(
    encoder_prompt: _T1,
    decoder_prompt: Optional[_T2],
    mm_processor_kwargs: Optional[dict[str, Any]] = None,
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
    """Build an {class}`ExplicitEncoderDecoderPrompt` from its parts.

    Args:
        encoder_prompt: The prompt for the encoder.
        decoder_prompt: The prompt for the decoder, or `None`.
        mm_processor_kwargs: Optional multi-modal processor kwargs stored at
            the top level of the result. When `None`, a new empty dict is
            stored instead, so the key is always present.

    Returns:
        An {class}`ExplicitEncoderDecoderPrompt` containing the given
        prompts and processor kwargs.
    """
    if mm_processor_kwargs is None:
        mm_processor_kwargs = {}
    return ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder_prompt,
        decoder_prompt=decoder_prompt,
        mm_processor_kwargs=mm_processor_kwargs)
270
271
272
273
274


def zip_enc_dec_prompts(
    enc_prompts: Iterable[_T1],
    dec_prompts: Iterable[Optional[_T2]],
    mm_processor_kwargs: Optional[Union[Iterable[dict[str, Any]],
                                        dict[str, Any]]] = None,
) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
    """
    Zip encoder and decoder prompts together into a list of
    {class}`ExplicitEncoderDecoderPrompt` instances.

    `mm_processor_kwargs` may also be provided; if a dict is passed, the same
    dictionary will be used for every encoder/decoder prompt. If an iterable is
    provided, it will be zipped with the encoder/decoder prompts.

    Args:
        enc_prompts: The encoder prompts.
        dec_prompts: The decoder prompts; individual entries may be `None`.
        mm_processor_kwargs: A single dict shared by every prompt, an
            iterable yielding one dict per prompt, or `None` (treated as a
            single shared empty dict).

    Returns:
        One {class}`ExplicitEncoderDecoderPrompt` per zipped entry. Because
        the inputs are combined with `zip`, the result is as long as the
        shortest input iterable.
    """
    if mm_processor_kwargs is None:
        mm_processor_kwargs = cast(dict[str, Any], {})

    if isinstance(mm_processor_kwargs, dict):
        # One dict object is shared (not copied) across every prompt.
        shared_kwargs = cast(dict[str, Any], mm_processor_kwargs)
        return [
            build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
                                          shared_kwargs)
            for (encoder_prompt,
                 decoder_prompt) in zip(enc_prompts, dec_prompts)
        ]

    # Per-prompt kwargs: consume the three iterables in lockstep.
    return [
        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
                                      mm_proc_kwargs)
        for (encoder_prompt, decoder_prompt, mm_proc_kwargs
             ) in zip(enc_prompts, dec_prompts, mm_processor_kwargs)
    ]

303

304
305
def to_enc_dec_tuple_list(
    enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> list[tuple[_T1, Optional[_T2]]]:
    """Unpack explicit encoder/decoder prompts into a list of tuples.

    Args:
        enc_dec_prompts: The {class}`ExplicitEncoderDecoderPrompt` instances
            to unpack.

    Returns:
        A list of `(encoder_prompt, decoder_prompt)` tuples, one per input
        prompt and in the same order. Any `mm_processor_kwargs` entry is not
        included in the tuples.
    """
    return [(enc_dec_prompt["encoder_prompt"],
             enc_dec_prompt["decoder_prompt"])
            for enc_dec_prompt in enc_dec_prompts]