"tests/models/language/generation/test_common.py" did not exist on "b40cf6402e356a10415e969e648a32911fb9b8ec"
data.py 10.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Iterable
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast
6

7
import torch
8
from typing_extensions import NotRequired, TypedDict, TypeVar
9

10
11
from vllm.sampling_params import SamplingParams

12
if TYPE_CHECKING:
13
14
    from vllm.multimodal.inputs import (
        MultiModalDataDict,
15
        MultiModalEncDecInputs,
16
17
18
        MultiModalInputs,
        MultiModalUUIDDict,
    )
19
20
else:
    MultiModalDataDict = object
21
    MultiModalEncDecInputs = object
22
23
    MultiModalInputs = object
    MultiModalUUIDDict = object
24
25


26
class _CommonKeys(TypedDict):
    """
    Optional keys shared by the prompt schemas that subclass this
    (`TextPrompt`, `TokensPrompt`, `EmbedsPrompt` and `DataPrompt`).
    """

    multi_modal_data: NotRequired[MultiModalDataDict | None]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    mm_processor_kwargs: NotRequired[dict[str, Any] | None]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc. for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

    multi_modal_uuids: NotRequired[MultiModalUUIDDict]
    """
    Optional user-specified UUIDs for multimodal items, mapped by modality.
    Lists must match the number of items per modality and may contain `None`.
    For `None` entries, the hasher will compute IDs automatically; non-None
    entries override the default hashes for caching, and MUST be unique per
    multimodal item.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

55

56
57
58
59
60
61
62
63
class TextPrompt(_CommonKeys):
    """Prompt schema holding raw text that has not been tokenized yet."""

    prompt: str
    """Input text; it is tokenized before it is handed to the model."""


class TokensPrompt(_CommonKeys):
    """Schema for a tokenized prompt."""

    prompt_token_ids: list[int]
    """A list of token IDs to pass to the model."""

    prompt: NotRequired[str]
    """The prompt text corresponding to the token IDs, if available."""

    token_type_ids: NotRequired[list[int]]
    """A list of token type IDs to pass to the cross encoder model."""

75

76
class EmbedsPrompt(_CommonKeys):
    """Schema for a prompt provided via token embeddings."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

    prompt: NotRequired[str]
    """The prompt text corresponding to the token embeddings, if available."""

85

86
class DataPrompt(_CommonKeys):
    """Represents generic inputs handled by IO processor plugins."""

    data: Any
    """The input data, in the format named by `data_format`."""

    data_format: str
    """A string naming the format of `data`."""


96
SingletonPrompt: TypeAlias = str | TextPrompt | TokensPrompt | EmbedsPrompt
"""
Set of possible schemas for a single prompt:

- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])

Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt].

A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be
employed as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt].
"""

120

121
122
123
124
125
126
# Covariant type parameters for the encoder (_T1_co) and decoder (_T2_co)
# prompt slots of ExplicitEncoderDecoderPrompt; both are restricted to, and
# default to, SingletonPrompt.
_T1_co = TypeVar(
    "_T1_co", covariant=True, bound=SingletonPrompt, default=SingletonPrompt
)
_T2_co = TypeVar(
    "_T2_co", covariant=True, bound=SingletonPrompt, default=SingletonPrompt
)
127

128
129
130

# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
    """
    Represents an encoder/decoder model input prompt,
    comprising an explicit encoder prompt and a decoder prompt.

    The encoder and decoder prompts, respectively, may be formatted
    according to any of the
    [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] schemas,
    and are not required to have the same schema.

    Only the encoder prompt may have multi-modal data. mm_processor_kwargs
    should be at the top-level, and should not be set in the encoder/decoder
    prompts, since they are agnostic to the encoder/decoder.

    Note that an
    [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
    may not be used as an input to a decoder-only model,
    and that the `encoder_prompt` and `decoder_prompt`
    fields of this data structure themselves must be
    [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] instances.
    """

    encoder_prompt: _T1_co
    """The prompt for the encoder."""

    decoder_prompt: _T2_co | None
    """The prompt for the decoder, or `None` if none is given explicitly."""

    mm_processor_kwargs: NotRequired[dict[str, Any]]
    """Top-level multi-modal processor kwargs (see the class docstring)."""
157

158

159
PromptType: TypeAlias = SingletonPrompt | ExplicitEncoderDecoderPrompt[Any, Any]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:

- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
- A single data structure containing both an encoder and a decoder prompt
  ([`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt])
"""


172
173
class TokenInputs(TypedDict):
    """Represents token-based inputs."""

    type: Literal["token"]
    """The type of inputs."""

    prompt_token_ids: list[int]
    """The token IDs of the prompt."""

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

186

187
def token_inputs(
    prompt_token_ids: list[int],
    cache_salt: str | None = None,
) -> TokenInputs:
    """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
    values.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        cache_salt: Optional cache salt for prefix caching. When `None`, the
            `cache_salt` key is left out of the result entirely.

    Returns:
        A `TokenInputs` dict with `type="token"`.
    """
    inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)

    # `cache_salt` is declared NotRequired[str] (not Optional), so omit the key
    # instead of storing None.
    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs


201
202
203
204
205
206
207
208
209
class EmbedsInputs(TypedDict):
    """Represents embeddings-based inputs."""

    type: Literal["embeds"]
    """The type of inputs."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

215

216
217
def embeds_inputs(
    prompt_embeds: torch.Tensor,
    cache_salt: str | None = None,
) -> EmbedsInputs:
    """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
    values.

    Args:
        prompt_embeds: The embeddings of the prompt.
        cache_salt: Optional cache salt for prefix caching. When `None`, the
            `cache_salt` key is left out of the result entirely.

    Returns:
        An `EmbedsInputs` dict with `type="embeds"`.
    """
    inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)

    # `cache_salt` is declared NotRequired[str] (not Optional), so omit the key
    # instead of storing None.
    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs


230
DecoderOnlyInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs
"""
The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""


238
class EncoderDecoderInputs(TypedDict):
    """
    The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they
    are passed to the model executor.

    This specifies the required data for encoder-decoder models.
    """

    encoder: TokenInputs | MultiModalEncDecInputs
    """The inputs for the encoder portion."""

    decoder: TokenInputs | MultiModalInputs
    """The inputs for the decoder portion."""
251

252

253
SingletonInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs
"""
A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt]: the
token-based, embeddings-based, or multi-modal inputs for a single prompt.
(This is the same union as `DecoderOnlyInputs`.)
"""

259
ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs
"""
The outputs from [`vllm.inputs.preprocess.InputPreprocessor`][]: either
decoder-only inputs or encoder/decoder inputs.
"""
263

264
265
# Non-covariant type parameters for the encoder (_T1) and decoder (_T2) prompts
# used by the encoder/decoder helper functions; both default to SingletonPrompt.
_T1 = TypeVar("_T1", default=SingletonPrompt, bound=SingletonPrompt)
_T2 = TypeVar("_T2", default=SingletonPrompt, bound=SingletonPrompt)
266
267


268
269
def build_explicit_enc_dec_prompt(
    encoder_prompt: _T1,
    decoder_prompt: _T2 | None,
    mm_processor_kwargs: dict[str, Any] | None = None,
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
    """
    Build an
    [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
    from its parts.

    Args:
        encoder_prompt: The prompt for the encoder.
        decoder_prompt: The prompt for the decoder, or `None`.
        mm_processor_kwargs: Multi-modal processor kwargs stored at the top
            level of the result. `None` is replaced by a new empty dict, so
            the `mm_processor_kwargs` key is always present.

    Returns:
        The assembled prompt. A non-`None` `mm_processor_kwargs` dict is stored
        as-is (by reference, not copied).
    """
    if mm_processor_kwargs is None:
        # Evaluated per call, so every prompt built without kwargs gets its own dict.
        mm_processor_kwargs = {}
    return ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder_prompt,
        decoder_prompt=decoder_prompt,
        mm_processor_kwargs=mm_processor_kwargs,
    )
280
281
282
283


def zip_enc_dec_prompts(
    enc_prompts: Iterable[_T1],
    dec_prompts: Iterable[_T2 | None],
    mm_processor_kwargs: Iterable[dict[str, Any]] | dict[str, Any] | None = None,
) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
    """
    Zip encoder and decoder prompts together into a list of
    [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
    instances.

    `mm_processor_kwargs` may also be provided; if a dict is passed, the same
    dictionary (the same object, not a copy) will be used for every
    encoder/decoder prompt. If an iterable is provided, it will be zipped with
    the encoder/decoder prompts. If it is `None`, every resulting prompt gets
    its own independent empty dict.

    Like the builtin `zip`, pairing stops at the shortest of the inputs.
    """
    if mm_processor_kwargs is None or isinstance(mm_processor_kwargs, dict):
        # Forward `None` unchanged: build_explicit_enc_dec_prompt then creates a
        # fresh `{}` for each prompt. Substituting a single `{}` here would make
        # every returned prompt alias the same mutable dict, so mutating one
        # prompt's kwargs would silently change all of them.
        return [
            build_explicit_enc_dec_prompt(
                encoder_prompt, decoder_prompt, mm_processor_kwargs
            )
            for encoder_prompt, decoder_prompt in zip(enc_prompts, dec_prompts)
        ]
    return [
        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, mm_proc_kwargs)
        for encoder_prompt, decoder_prompt, mm_proc_kwargs in zip(
            enc_prompts, dec_prompts, mm_processor_kwargs
        )
    ]

314

315
316
def to_enc_dec_tuple_list(
    enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> list[tuple[_T1, _T2 | None]]:
    """Unpack explicit encoder/decoder prompts into (encoder, decoder) pairs.

    Any top-level `mm_processor_kwargs` key on the input prompts is ignored.
    """
    pairs: list[tuple[_T1, _T2 | None]] = []
    for prompt in enc_dec_prompts:
        enc = prompt["encoder_prompt"]
        dec = prompt["decoder_prompt"]
        pairs.append((enc, dec))
    return pairs
322
323
324
325
326
327
328
329
330
331
332
333


@dataclass
class StreamingInput:
    """Input data for a streaming generation request.

    This is used with generate() to support multi-turn streaming sessions
    where inputs are provided via an async generator.
    """

    # The prompt for this turn; may be any PromptType form (text, tokens,
    # embeddings, or an explicit encoder/decoder prompt).
    prompt: PromptType
    # Optional per-input sampling parameters. How a None value is resolved is
    # decided by the consumer, which is not visible here.
    # NOTE(review): presumably falls back to the request's parameters — confirm.
    sampling_params: SamplingParams | None = None