# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Literal, TypeAlias

import torch
from typing_extensions import NotRequired, TypedDict, assert_never

if TYPE_CHECKING:
    from vllm.multimodal.inputs import (
        MultiModalDataDict,
        MultiModalEncDecInputs,
        MultiModalInputs,
        MultiModalUUIDDict,
    )
else:
    MultiModalDataDict = object
    MultiModalEncDecInputs = object
    MultiModalInputs = object
    MultiModalUUIDDict = object


# Inputs to LLM API
class _PromptOptions(TypedDict):
    """
    Additional options available to all
    [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt].
    """

    multi_modal_data: NotRequired[MultiModalDataDict | None]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    mm_processor_kwargs: NotRequired[dict[str, Any] | None]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers, etc., for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

    multi_modal_uuids: NotRequired[MultiModalUUIDDict]
    """
    Optional user-specified UUIDs for multimodal items, mapped by modality.
    Lists must match the number of items per modality and may contain `None`.
    For `None` entries, the hasher will compute IDs automatically; non-None
    entries override the default hashes for caching, and MUST be unique per
    multimodal item.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
    """

class TextPrompt(_PromptOptions):
    """Schema for a text prompt."""

    prompt: str
    """The input text to be tokenized before passing to the model."""


class TokensPrompt(_PromptOptions):
    """Schema for a tokenized prompt."""

    prompt_token_ids: list[int]
    """A list of token IDs to pass to the model."""

    prompt: NotRequired[str]
    """The prompt text corresponding to the token IDs, if available."""

    token_type_ids: NotRequired[list[int]]
    """A list of token type IDs to pass to the cross encoder model."""

class EmbedsPrompt(_PromptOptions):
    """Schema for a prompt provided via token embeddings."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

    prompt: NotRequired[str]
    """The prompt text corresponding to the token embeddings, if available."""

DecoderOnlyPrompt: TypeAlias = (
    str | TextPrompt | list[int] | TokensPrompt | EmbedsPrompt
)
"""
Schema of a prompt for a decoder-only model:

- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt (list of token IDs, or
  [`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])

For encoder-decoder models, passing a singleton prompt is shorthand for passing
`ExplicitEncoderDecoderPrompt(encoder_prompt=prompt, decoder_prompt=None)`.
"""


EncoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt
"""
Schema of a prompt for the encoder part of an encoder-decoder model:

- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt (list of token IDs, or
  [`TokensPrompt`][vllm.inputs.data.TokensPrompt])
"""


DecoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt
"""
Schema of a prompt for the decoder part of an encoder-decoder model:

- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt (list of token IDs, or
  [`TokensPrompt`][vllm.inputs.data.TokensPrompt])

Note:
    Multi-modal inputs are not supported for decoder prompts.
"""

class ExplicitEncoderDecoderPrompt(TypedDict):
    """
    Schema for a pair of encoder and decoder singleton prompts.

    Note:
        This schema is not valid for decoder-only models.
    """

    encoder_prompt: EncoderPrompt
    """The prompt for the encoder part of the model."""

    decoder_prompt: DecoderPrompt | None
    """
    The prompt for the decoder part of the model.

    Passing `None` will cause the prompt to be inferred automatically.
    """

EncoderDecoderPrompt: TypeAlias = EncoderPrompt | ExplicitEncoderDecoderPrompt
"""
Schema for a prompt for an encoder-decoder model.

You can pass a singleton encoder prompt, in which case the decoder prompt is
considered to be `None` (i.e., infer automatically).
"""


SingletonPrompt: TypeAlias = DecoderOnlyPrompt | EncoderPrompt | DecoderPrompt
"""
Schema for a single prompt. This is as opposed to a data structure
which encapsulates multiple prompts, such as
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt].
"""


PromptType: TypeAlias = DecoderOnlyPrompt | EncoderDecoderPrompt
"""
Schema for any prompt, regardless of model type.

This is the input format accepted by most [`LLM`][vllm.entrypoints.llm.LLM] APIs.
"""


class DataPrompt(_PromptOptions):
    """
    Represents generic inputs that are converted to
    [`PromptType`][vllm.inputs.data.PromptType] by IO processor plugins.

    Besides `data` and `data_format`, it accepts every optional field
    inherited from `_PromptOptions` (`multi_modal_data`,
    `mm_processor_kwargs`, `multi_modal_uuids`, `cache_salt`).
    """

    data: Any
    """The input data."""

    data_format: str
    """The input data format."""


# Outputs of processor
class _InputOptions(TypedDict):
    """
    Additional options available to all input types.
    """

190
191
192
    arrival_time: NotRequired[float]
    """The time when the input was received (before rendering)."""

193
194
195
196
197
    cache_salt: NotRequired[str]
    """Optional cache salt to be used for prefix caching."""


class TokenInputs(_InputOptions):
    """Represents token-based inputs."""

    type: Literal["token"]
    """The type of inputs."""

    prompt_token_ids: list[int]
    """The token IDs of the prompt."""

    prompt: NotRequired[str]
    """The prompt text corresponding to the token IDs, if available."""

def token_inputs(
    prompt_token_ids: list[int],
    *,
    prompt: str | None = None,
    cache_salt: str | None = None,
) -> TokenInputs:
    """
    Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
    values.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        prompt: The prompt text, if available. The key is omitted from the
            result when this is `None`.
        cache_salt: Optional cache salt for prefix caching. The key is
            omitted from the result when this is `None`.

    Returns:
        A `TokenInputs` dict whose `type` is `"token"`.
    """
    inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)

    # Optional keys are only set when provided, so `NotRequired` fields stay
    # absent rather than holding `None`.
    if prompt is not None:
        inputs["prompt"] = prompt
    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs


class EmbedsInputs(_InputOptions):
    """Represents embeddings-based inputs."""

    type: Literal["embeds"]
    """The type of inputs."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""

    prompt: NotRequired[str]
    """The prompt text corresponding to the token embeddings, if available."""

def embeds_inputs(
    prompt_embeds: torch.Tensor,
    *,
    prompt: str | None = None,
    cache_salt: str | None = None,
) -> EmbedsInputs:
    """
    Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
    values.

    Args:
        prompt_embeds: The embeddings of the prompt.
        prompt: The prompt text, if available. The key is omitted from the
            result when this is `None`.
        cache_salt: Optional cache salt for prefix caching. The key is
            omitted from the result when this is `None`.

    Returns:
        An `EmbedsInputs` dict whose `type` is `"embeds"`.
    """
    inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)

    # Optional keys are only set when provided, so `NotRequired` fields stay
    # absent rather than holding `None`.
    if prompt is not None:
        inputs["prompt"] = prompt
    if cache_salt is not None:
        inputs["cache_salt"] = cache_salt

    return inputs


DecoderOnlyInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs
"""
A processed prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for decoder-only models.
"""


EncoderInputs: TypeAlias = TokenInputs | MultiModalEncDecInputs
"""
A processed encoder prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for encoder-decoder models.
"""


DecoderInputs: TypeAlias = TokenInputs | MultiModalInputs
"""
A processed decoder prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for encoder-decoder models.
"""


class EncoderDecoderInputs(TypedDict):
    """
    A processed pair of encoder and decoder singleton prompts from
    [`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
    which can be passed to
    [`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
    for encoder-decoder models.
    """

    type: Literal["enc_dec"]
    """The type of inputs."""

    encoder_prompt: EncoderInputs
    """The inputs for the encoder portion."""

    decoder_prompt: DecoderInputs
    """The inputs for the decoder portion."""

    arrival_time: NotRequired[float]
    """The time when the input was received (before rendering)."""

ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs
"""
A processed prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor].
"""


SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs
"""The inputs for a single encoder/decoder prompt."""


def _validate_enc_inputs(inputs: SingletonInputs) -> EncoderInputs:
    """
    Check that `inputs` may be used as the encoder side of an
    encoder-decoder request, and return it unchanged.

    Raises:
        ValueError: If `inputs` is embeddings-based (`type == "embeds"`).
        RuntimeError: If `inputs` is multi-modal but has no
            `encoder_prompt_token_ids` key.
    """
    input_type = inputs["type"]

    if input_type == "embeds":
        raise ValueError(
            "Embedding inputs are not supported for encoder-decoder models"
        )

    # Token inputs pass straight through; multi-modal inputs must carry the
    # encoder-side token IDs.
    if input_type != "multimodal" or "encoder_prompt_token_ids" in inputs:
        return inputs  # type: ignore[return-value]

    raise RuntimeError(
        "You should register an encoder-decoder multi-modal processor "
        "for encoder-decoder models."
    )


def _validate_dec_inputs(inputs: SingletonInputs) -> DecoderInputs:
    """
    Check that `inputs` may be used as the decoder side of an
    encoder-decoder request, and return it unchanged.

    Raises:
        ValueError: If `inputs` is embeddings-based (`type == "embeds"`).
    """
    if inputs["type"] != "embeds":
        return inputs

    raise ValueError(
        "Embedding inputs are not supported for encoder-decoder models"
    )


def _prepare_decoder_input_ids_for_generation(
    decoder_input_ids: list[int],
    decoder_start_token_id: int,
) -> list[int]:
    """
    Prepare `decoder_input_ids` for generation with encoder-decoder models,
    according to `GenerationMixin._prepare_decoder_input_ids_for_generation()`.

    Source:
    https://github.com/huggingface/transformers/blob/v5.1.0/src/transformers/generation/utils.py
    """
    if len(decoder_input_ids) == 0 or decoder_input_ids[0] != decoder_start_token_id:
        decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

    return decoder_input_ids


def build_enc_dec_inputs(
    encoder_inputs: SingletonInputs,
    decoder_inputs: SingletonInputs | None,
    decoder_start_token_id: int,
) -> EncoderDecoderInputs:
    """
    Combine processed encoder and decoder prompts into
    [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs].

    Args:
        encoder_inputs: The processed encoder prompt.
        decoder_inputs: The processed decoder prompt, or `None` to use
            `encoder_inputs` as the decoder prompt as well.
        decoder_start_token_id: Token ID the decoder prompt must start with;
            it is prepended when missing.

    Returns:
        The paired inputs:

        - Multi-modal encoder inputs: the encoder side becomes a token prompt
          built from `encoder_prompt_token_ids` / `encoder_prompt`, and the
          multi-modal data (`mm_kwargs`, `mm_hashes`, `mm_placeholders`) is
          attached to the decoder side.
        - Token encoder inputs: the encoder side is an empty token prompt and
          the decoder side is a copy of the decoder inputs.

        A truthy `cache_salt` on the encoder inputs is copied to the decoder
        side.

    Raises:
        ValueError: If either input is embeddings-based.
        RuntimeError: If multi-modal encoder inputs lack
            `encoder_prompt_token_ids`.

    The caller's `encoder_inputs` / `decoder_inputs` dicts are not modified.
    """
    enc_inputs = _validate_enc_inputs(encoder_inputs)

    if decoder_inputs is None:
        dec_inputs: DecoderInputs = enc_inputs
    else:
        dec_inputs = _validate_dec_inputs(decoder_inputs)

    enc_inputs_new: EncoderInputs
    dec_inputs_new: DecoderInputs

    if enc_inputs["type"] == "multimodal":
        from vllm.multimodal.inputs import mm_inputs

        enc_inputs_new = token_inputs(
            enc_inputs["encoder_prompt_token_ids"],
            prompt=enc_inputs.get("encoder_prompt"),
        )
        dec_inputs_new = mm_inputs(
            prompt_token_ids=dec_inputs["prompt_token_ids"],
            prompt=dec_inputs.get("prompt"),
            mm_kwargs=enc_inputs["mm_kwargs"],
            mm_hashes=enc_inputs["mm_hashes"],
            mm_placeholders=enc_inputs["mm_placeholders"],
        )
    elif enc_inputs["type"] == "token":
        enc_inputs_new = token_inputs(prompt_token_ids=[])
        # Shallow copy: `dec_inputs` is a caller-owned dict (it IS
        # `encoder_inputs` when no decoder prompt was given), and the
        # assignments below must not leak back into it.
        dec_inputs_new = dec_inputs.copy()
    else:
        assert_never(enc_inputs)

    dec_inputs_new["prompt_token_ids"] = _prepare_decoder_input_ids_for_generation(
        dec_inputs_new["prompt_token_ids"],
        decoder_start_token_id,
    )

    if cache_salt := enc_inputs.get("cache_salt"):
        dec_inputs_new["cache_salt"] = cache_salt

    return EncoderDecoderInputs(
        type="enc_dec",
        encoder_prompt=enc_inputs_new,
        decoder_prompt=dec_inputs_new,
    )