"vscode:/vscode.git/clone" did not exist on "782505ed8eb4f1b27cccd009a8dc9b69f6ad6ebc"
data.py 11.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
4

5
import torch
6
from typing_extensions import NotRequired, TypedDict, assert_never
7
8

if TYPE_CHECKING:
    # Imported for static analysis only; these names are never imported from
    # `vllm.multimodal.inputs` at runtime by this block.
    from vllm.multimodal.inputs import (
        MultiModalDataDict,
        MultiModalEncDecInputs,
        MultiModalInputs,
        MultiModalUUIDDict,
    )
else:
    # Runtime placeholders so that annotations and type aliases below which
    # mention these names can still be evaluated.
    MultiModalDataDict = object
    MultiModalEncDecInputs = object
    MultiModalInputs = object
    MultiModalUUIDDict = object
20
21


22
23
24
25
26
27
28
# Inputs to LLM API
class _PromptOptions(TypedDict):
    """
    Additional options available to all
    [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt].

    Every key is declared `NotRequired`, so any of them may be omitted.
    """

    multi_modal_data: NotRequired[MultiModalDataDict | None]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    mm_processor_kwargs: NotRequired[dict[str, Any] | None]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

    multi_modal_uuids: NotRequired[MultiModalUUIDDict]
    """
    Optional user-specified UUIDs for multimodal items, mapped by modality.
    Lists must match the number of items per modality and may contain `None`.
    For `None` entries, the hasher will compute IDs automatically; non-None
    entries override the default hashes for caching, and MUST be unique per
    multimodal item.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

57

58
class TextPrompt(_PromptOptions):
    """Schema for a text prompt."""

    prompt: str
    """The input text to be tokenized before passing to the model."""


65
class TokensPrompt(_PromptOptions):
    """Schema for a tokenized prompt."""

    prompt_token_ids: list[int]
    """A list of token IDs to pass to the model."""

    prompt: NotRequired[str]
    """The prompt text corresponding to the token IDs, if available."""

    token_type_ids: NotRequired[list[int]]
    """A list of token type IDs to pass to the cross encoder model."""

77

78
class EmbedsPrompt(_PromptOptions):
    """Schema for a prompt provided via token embeddings."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

    prompt: NotRequired[str]
    """The prompt text corresponding to the token embeddings, if available."""

87

88
89
90
91
92
DecoderOnlyPrompt: TypeAlias = (
    str | TextPrompt | list[int] | TokensPrompt | EmbedsPrompt
)
"""
Schema of a prompt for a decoder-only model:

- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt (list of token IDs, or
  [`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])

For encoder-decoder models, passing a singleton prompt is shorthand for passing
`ExplicitEncoderDecoderPrompt(encoder_prompt=prompt, decoder_prompt=None)`.
"""


EncoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt
"""
Schema of a prompt for the encoder part of an encoder-decoder model:

- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt (list of token IDs, or
  [`TokensPrompt`][vllm.inputs.data.TokensPrompt])
"""


DecoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt
"""
Schema of a prompt for the decoder part of an encoder-decoder model:

- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt (list of token IDs, or
  [`TokensPrompt`][vllm.inputs.data.TokensPrompt])

Note:
    Multi-modal inputs are not supported for decoder prompts.
"""
125

126

127
class ExplicitEncoderDecoderPrompt(TypedDict):
    """
    Schema for a pair of encoder and decoder singleton prompts.

    Note:
        This schema is not valid for decoder-only models.
    """

    encoder_prompt: EncoderPrompt
    """The prompt for the encoder part of the model."""

    decoder_prompt: DecoderPrompt | None
    """
    The prompt for the decoder part of the model.

    Passing `None` will cause the prompt to be inferred automatically.
    """
144

145

146
EncoderDecoderPrompt: TypeAlias = EncoderPrompt | ExplicitEncoderDecoderPrompt
"""
Schema for a prompt for an encoder-decoder model.

You can pass a singleton encoder prompt, in which case the decoder prompt is
considered to be `None` (i.e., infer automatically).
"""


SingletonPrompt: TypeAlias = DecoderOnlyPrompt | EncoderPrompt | DecoderPrompt
"""
Schema for a single prompt. This is as opposed to a data structure
which encapsulates multiple prompts, such as
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt].
"""


PromptType: TypeAlias = DecoderOnlyPrompt | EncoderDecoderPrompt
"""
Schema for any prompt, regardless of model type.

This is the input format accepted by most [`LLM`][vllm.entrypoints.llm.LLM] APIs.
"""


class DataPrompt(_PromptOptions):
    """
    Represents generic inputs that are converted to
    [`PromptType`][vllm.inputs.data.PromptType] by IO processor plugins.

    `data` and `data_format` are required keys; the optional keys declared on
    `_PromptOptions` are inherited.
    """

    data: Any
    """The input data."""

    data_format: str
    """The input data format."""


# Outputs of processor
class _InputOptions(TypedDict):
    """
    Additional options available to all input types.

    `cache_salt` is declared `NotRequired`, so the key may be absent.
    """

    cache_salt: NotRequired[str]
    """Optional cache salt to be used for prefix caching."""


class TokenInputs(_InputOptions):
    """Represents token-based inputs."""

    type: Literal["token"]
    """The type of inputs."""

    prompt_token_ids: list[int]
    """The token IDs of the prompt."""

    prompt: NotRequired[str]
    """The prompt text corresponding to the token IDs, if available."""

206

207
def token_inputs(
    prompt_token_ids: list[int],
    *,
    prompt: str | None = None,
    cache_salt: str | None = None,
) -> TokenInputs:
    """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
    values.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        prompt: The prompt text, if available.
        cache_salt: Optional cache salt.

    Returns:
        A `TokenInputs` dict with `type="token"`. The `prompt` and
        `cache_salt` keys are only present when the corresponding argument
        is not `None`.
    """
    inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)

    # Optional keys are left out entirely rather than being stored as `None`.
    if prompt is not None:
        inputs["prompt"] = prompt
    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs


225
class EmbedsInputs(_InputOptions):
    """Represents embeddings-based inputs."""

    type: Literal["embeds"]
    """The type of inputs."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

    prompt: NotRequired[str]
    """The prompt text corresponding to the token embeddings, if available."""

237

238
239
def embeds_inputs(
    prompt_embeds: torch.Tensor,
    *,
    prompt: str | None = None,
    cache_salt: str | None = None,
) -> EmbedsInputs:
    """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
    values.

    Args:
        prompt_embeds: The embeddings of the prompt.
        prompt: The prompt text, if available.
        cache_salt: Optional cache salt.

    Returns:
        An `EmbedsInputs` dict with `type="embeds"`. The `prompt` and
        `cache_salt` keys are only present when the corresponding argument
        is not `None`.
    """
    inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)

    # Optional keys are left out entirely rather than being stored as `None`.
    if prompt is not None:
        inputs["prompt"] = prompt
    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs


256
DecoderOnlyInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs
"""
A processed prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for decoder-only models.
"""


EncoderInputs: TypeAlias = TokenInputs | MultiModalEncDecInputs
"""
A processed encoder prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for encoder-decoder models.
"""


DecoderInputs: TypeAlias = TokenInputs | MultiModalInputs
"""
A processed decoder prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for encoder-decoder models.
"""


286
class EncoderDecoderInputs(TypedDict):
    """
    A processed pair of encoder and decoder singleton prompts from
    [`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
    which can be passed to
    [`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
    for encoder-decoder models.
    """

    type: Literal["enc_dec"]
    """The type of inputs."""

    encoder_prompt: EncoderInputs
    """The inputs for the encoder portion."""

    decoder_prompt: DecoderInputs
    """The inputs for the decoder portion."""
302

303

304
ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs
"""
A processed prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor].
"""


SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs
"""The inputs for a single encoder/decoder prompt."""


def _validate_enc_inputs(inputs: SingletonInputs) -> EncoderInputs:
    """
    Check that `inputs` may be used as the encoder side of an
    encoder-decoder request and return it unchanged.

    Raises:
        ValueError: If `inputs["type"]` is `"embeds"`.
        RuntimeError: If `inputs["type"]` is `"multimodal"` but the
            `"encoder_prompt_token_ids"` key is missing.
    """
    input_type = inputs["type"]

    if input_type == "embeds":
        raise ValueError(
            "Embedding inputs are not supported for encoder-decoder models"
        )

    if input_type == "multimodal":
        # Multi-modal inputs must carry the encoder token IDs to be usable here.
        if "encoder_prompt_token_ids" not in inputs:
            raise RuntimeError(
                "You should register an encoder-decoder multi-modal processor "
                "for encoder-decoder models."
            )

    return inputs  # type: ignore[return-value]


def _validate_dec_inputs(inputs: SingletonInputs) -> DecoderInputs:
    """
    Check that `inputs` may be used as the decoder side of an
    encoder-decoder request and return it unchanged.

    Raises:
        ValueError: If `inputs["type"]` is `"embeds"`.
    """
    is_embeds = inputs["type"] == "embeds"
    if not is_embeds:
        return inputs

    raise ValueError(
        "Embedding inputs are not supported for encoder-decoder models"
    )


def _prepare_decoder_input_ids_for_generation(
    decoder_input_ids: list[int],
    decoder_start_token_id: int,
) -> list[int]:
    """
    Ensure the decoder token IDs begin with the decoder start token, following
    `GenerationMixin._prepare_decoder_input_ids_for_generation()`.

    Source:
    https://github.com/huggingface/transformers/blob/v5.1.0/src/transformers/generation/utils.py

    Args:
        decoder_input_ids: The decoder token IDs; may be empty.
        decoder_start_token_id: The token ID expected at position 0.

    Returns:
        The same list object when it already starts with
        `decoder_start_token_id`; otherwise a new list with that token
        placed in front of the original IDs.
    """
    already_prefixed = (
        len(decoder_input_ids) != 0
        and decoder_input_ids[0] == decoder_start_token_id
    )
    if already_prefixed:
        return decoder_input_ids

    return [decoder_start_token_id] + decoder_input_ids


def build_enc_dec_inputs(
    encoder_inputs: SingletonInputs,
    decoder_inputs: SingletonInputs | None,
    decoder_start_token_id: int,
) -> EncoderDecoderInputs:
    """
    Combine processed encoder and decoder inputs into
    [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs].

    Args:
        encoder_inputs: The processed encoder-side inputs.
        decoder_inputs: The processed decoder-side inputs. When `None`, the
            validated encoder inputs are used as the decoder inputs.
        decoder_start_token_id: Token ID that the decoder token IDs are made
            to start with.

    Returns:
        A dict with `type="enc_dec"` holding the resulting encoder and
        decoder inputs.

    Raises:
        ValueError: If either side has `type == "embeds"`.
        RuntimeError: If the encoder inputs are multi-modal but lack
            `"encoder_prompt_token_ids"`.

    Note:
        In the `"token"` branch the decoder dict is not copied, so the
        `"prompt_token_ids"` / `"cache_salt"` assignments below write into
        the dict object that was passed in.
    """
    validated_enc = _validate_enc_inputs(encoder_inputs)

    validated_dec: DecoderInputs
    if decoder_inputs is not None:
        validated_dec = _validate_dec_inputs(decoder_inputs)
    else:
        validated_dec = validated_enc

    final_enc: EncoderInputs
    final_dec: DecoderInputs

    if validated_enc["type"] == "token":
        # Token-only encoder inputs: the encoder side gets an empty token
        # list and the decoder inputs are passed through as they are.
        final_enc = token_inputs(prompt_token_ids=[])
        final_dec = validated_dec
    elif validated_enc["type"] == "multimodal":
        from vllm.multimodal.inputs import mm_inputs

        # Encoder side keeps only the encoder tokens/text; the multi-modal
        # payload is attached to the decoder side.
        final_enc = token_inputs(
            validated_enc["encoder_prompt_token_ids"],
            prompt=validated_enc.get("encoder_prompt"),
        )
        final_dec = mm_inputs(
            prompt_token_ids=validated_dec["prompt_token_ids"],
            prompt=validated_dec.get("prompt"),
            mm_kwargs=validated_enc["mm_kwargs"],
            mm_hashes=validated_enc["mm_hashes"],
            mm_placeholders=validated_enc["mm_placeholders"],
        )
    else:
        assert_never(validated_enc)

    start_prefixed_ids = _prepare_decoder_input_ids_for_generation(
        final_dec["prompt_token_ids"],
        decoder_start_token_id,
    )
    final_dec["prompt_token_ids"] = start_prefixed_ids

    # Only a truthy cache salt from the encoder inputs is copied over.
    cache_salt = validated_enc.get("cache_salt")
    if cache_salt:
        final_dec["cache_salt"] = cache_salt

    return EncoderDecoderInputs(
        type="enc_dec",
        encoder_prompt=final_enc,
        decoder_prompt=final_dec,
    )