data.py 8.47 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast
4

5
from typing_extensions import NotRequired, TypedDict, TypeVar
6
7

if TYPE_CHECKING:
8
    from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
9
10
11
12
13
14
15
16


class TextPrompt(TypedDict):
    """Schema for a text prompt."""

    prompt: str
    """The input text to be tokenized before passing to the model."""

17
    multi_modal_data: NotRequired["MultiModalDataDict"]
18
19
20
21
22
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

23
    mm_processor_kwargs: NotRequired[dict[str, Any]]
24
25
26
27
28
29
30
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

31
32
33
34
35
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

36
37
38
39

class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt."""

40
    prompt_token_ids: list[int]
41
42
    """A list of token IDs to pass to the model."""

43
    token_type_ids: NotRequired[list[int]]
44
45
    """A list of token type IDs to pass to the cross encoder model."""

46
    multi_modal_data: NotRequired["MultiModalDataDict"]
47
    """
48
    Optional multi-modal data to pass to the model,
49
50
51
    if the model supports it.
    """

52
    mm_processor_kwargs: NotRequired[dict[str, Any]]
53
    """
54
    Optional multi-modal processor kwargs to be forwarded to the
55
56
57
58
59
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

60
61
62
63
64
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

65

66
SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
"""
Set of possible schemas for a single prompt:

- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)

Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`.

A prompt of type :class:`SingletonPrompt` may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`.
"""

# Covariant type variables for the encoder and decoder prompt slots of
# ExplicitEncoderDecoderPrompt. Both are bounded by, and default to, the
# full SingletonPrompt union.
_T1_co = TypeVar("_T1_co",
                 bound=SingletonPrompt,
                 default=SingletonPrompt,
                 covariant=True)
_T2_co = TypeVar("_T2_co",
                 bound=SingletonPrompt,
                 default=SingletonPrompt,
                 covariant=True)


# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
    """
    Represents an encoder/decoder model input prompt,
    comprising an explicit encoder prompt and a decoder prompt.

    The encoder and decoder prompts, respectively, may be formatted
    according to any of the :class:`SingletonPrompt` schemas,
    and are not required to have the same schema.

    Only the encoder prompt may have multi-modal data. mm_processor_kwargs
    should be at the top-level, and should not be set in the encoder/decoder
    prompts, since they are agnostic to the encoder/decoder.

    Note that an :class:`ExplicitEncoderDecoderPrompt` may not
    be used as an input to a decoder-only model,
    and that the :code:`encoder_prompt` and :code:`decoder_prompt`
    fields of this data structure themselves must be
    :class:`SingletonPrompt` instances.
    """

    encoder_prompt: _T1_co
    """The prompt passed to the encoder."""

    decoder_prompt: Optional[_T2_co]
    """
    The prompt passed to the decoder, or ``None`` when no decoder prompt
    is specified explicitly.
    """

    mm_processor_kwargs: NotRequired[dict[str, Any]]
    """
    Optional multi-modal processor kwargs that apply to the prompt pair
    as a whole rather than to either individual prompt.
    """


PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:

- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
- A single data structure containing both an encoder and a decoder prompt
  (:class:`ExplicitEncoderDecoderPrompt`)
"""


class TokenInputs(TypedDict):
    """Represents token-based inputs."""
139
140
141
142

    type: Literal["token"]
    """The type of inputs."""

143
    prompt_token_ids: list[int]
144
145
    """The token IDs of the prompt."""

146
    token_type_ids: NotRequired[list[int]]
147
148
    """The token type IDs of the prompt."""

149
    prompt: NotRequired[str]
150
151
152
153
    """
    The original prompt text corresponding to the token IDs, if available.
    """

154
155
156
157
158
    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

159

160
def token_inputs(
    prompt_token_ids: list[int],
    token_type_ids: Optional[list[int]] = None,
    prompt: Optional[str] = None,
    cache_salt: Optional[str] = None,
) -> TokenInputs:
    """
    Construct :class:`TokenInputs` from optional values.

    Optional fields are only set on the result when the corresponding
    argument is not ``None``, so absent values are omitted from the dict
    rather than stored as ``None``.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        token_type_ids: Optional token type IDs of the prompt.
        prompt: Optional original prompt text for the token IDs.
        cache_salt: Optional cache salt to be used for prefix caching.

    Returns:
        A :class:`TokenInputs` dict whose ``type`` is ``"token"``.
    """
    inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)

    if prompt is not None:
        inputs["prompt"] = prompt
    if token_type_ids is not None:
        inputs["token_type_ids"] = token_type_ids
    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs


DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputs"]
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""


class EncoderDecoderInputs(TypedDict):
    """
    The inputs in :class:`~vllm.LLMEngine` before they are
    passed to the model executor.

    This specifies the required data for encoder-decoder models.
    """

    encoder: Union[TokenInputs, "MultiModalInputs"]
    """The inputs for the encoder portion."""

    decoder: Union[TokenInputs, "MultiModalInputs"]
    """The inputs for the decoder portion."""


SingletonInputs = Union[TokenInputs, "MultiModalInputs"]
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
"""

ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
"""
The inputs to :data:`vllm.inputs.InputProcessor`.
"""

# Invariant counterparts of _T1_co/_T2_co, used as the prompt type
# parameters of the encoder/decoder prompt helper functions below.
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)


def build_explicit_enc_dec_prompt(
    encoder_prompt: _T1,
    decoder_prompt: Optional[_T2],
    mm_processor_kwargs: Optional[dict[str, Any]] = None,
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
    """
    Build an :class:`ExplicitEncoderDecoderPrompt` from its parts.

    Args:
        encoder_prompt: The prompt for the encoder.
        decoder_prompt: The prompt for the decoder, or ``None``.
        mm_processor_kwargs: Multi-modal processor kwargs for the prompt
            pair. When ``None``, a new empty dict is used, so the result
            always contains the ``mm_processor_kwargs`` key.

    Returns:
        The combined encoder/decoder prompt.
    """
    if mm_processor_kwargs is None:
        # Fresh dict per call; never share a default instance across prompts.
        mm_processor_kwargs = {}
    return ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder_prompt,
        decoder_prompt=decoder_prompt,
        mm_processor_kwargs=mm_processor_kwargs,
    )


def zip_enc_dec_prompts(
    enc_prompts: Iterable[_T1],
    dec_prompts: Iterable[Optional[_T2]],
    mm_processor_kwargs: Optional[Union[Iterable[dict[str, Any]],
                                        dict[str, Any]]] = None,
) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
    """
    Zip encoder and decoder prompts together into a list of
    :class:`ExplicitEncoderDecoderPrompt` instances.

    ``mm_processor_kwargs`` may also be provided; if a dict is passed, the same
    dictionary will be used for every encoder/decoder prompt. If an iterable is
    provided, it will be zipped with the encoder/decoder prompts.

    Note:
        Pairing uses :func:`zip`, so the result is only as long as the
        shortest input; surplus prompts (or kwargs) are silently dropped.
        When a single dict is used (including the default), every returned
        prompt references that same dict object; it is not copied.
    """
    if mm_processor_kwargs is None:
        mm_processor_kwargs = cast(dict[str, Any], {})
    if isinstance(mm_processor_kwargs, dict):
        return [
            build_explicit_enc_dec_prompt(
                encoder_prompt, decoder_prompt,
                cast(dict[str, Any], mm_processor_kwargs))
            for (encoder_prompt,
                 decoder_prompt) in zip(enc_prompts, dec_prompts)
        ]
    return [
        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
                                      mm_proc_kwargs)
        for (encoder_prompt, decoder_prompt, mm_proc_kwargs
             ) in zip(enc_prompts, dec_prompts, mm_processor_kwargs)
    ]


def to_enc_dec_tuple_list(
    enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> list[tuple[_T1, Optional[_T2]]]:
    """
    Unpack :class:`ExplicitEncoderDecoderPrompt` instances into
    ``(encoder_prompt, decoder_prompt)`` tuples.

    Any ``mm_processor_kwargs`` on the inputs are not carried over.
    """
    return [(enc_dec_prompt["encoder_prompt"],
             enc_dec_prompt["decoder_prompt"])
            for enc_dec_prompt in enc_dec_prompts]