data.py 12.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from collections.abc import Iterable
4
5
from dataclasses import dataclass
from functools import cached_property
6
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast
7

8
9
import torch
from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never
10
11

if TYPE_CHECKING:
12
13
    from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs,
                                 MultiModalPlaceholderDict)
14
    from vllm.multimodal.inputs import MultiModalInputs
15
16
17
18
19
20
21
22


class TextPrompt(TypedDict):
    """Schema for a text prompt."""

    prompt: str
    """The input text to be tokenized before passing to the model."""

23
    multi_modal_data: NotRequired["MultiModalDataDict"]
24
25
26
27
28
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

29
    mm_processor_kwargs: NotRequired[dict[str, Any]]
30
31
32
33
34
35
36
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

37
38
39
40

class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt."""

41
    prompt_token_ids: list[int]
42
43
    """A list of token IDs to pass to the model."""

44
    token_type_ids: NotRequired[list[int]]
45
46
    """A list of token type IDs to pass to the cross encoder model."""

47
    multi_modal_data: NotRequired["MultiModalDataDict"]
48
    """
49
    Optional multi-modal data to pass to the model,
50
51
52
    if the model supports it.
    """

53
    mm_processor_kwargs: NotRequired[dict[str, Any]]
54
    """
55
    Optional multi-modal processor kwargs to be forwarded to the
56
57
58
59
60
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

61

62
SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
"""
Set of possible schemas for a single prompt:

- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)

Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`

A prompt of type :class:`SingletonPrompt` may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
"""

83
# Covariant type variables for the encoder (_T1_co) and decoder (_T2_co)
# prompt slots of ExplicitEncoderDecoderPrompt. Both are bounded by, and
# default to, SingletonPrompt.
_T1_co = TypeVar("_T1_co",
                 bound=SingletonPrompt,
                 default=SingletonPrompt,
                 covariant=True)
_T2_co = TypeVar("_T2_co",
                 bound=SingletonPrompt,
                 default=SingletonPrompt,
                 covariant=True)
91

92
93
94

# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
    """
    Represents an encoder/decoder model input prompt,
    comprising an explicit encoder prompt and a decoder prompt.

    The encoder and decoder prompts, respectively, may be formatted
    according to any of the :class:`SingletonPrompt` schemas,
    and are not required to have the same schema.

    Only the encoder prompt may have multi-modal data. mm_processor_kwargs
    should be at the top-level, and should not be set in the encoder/decoder
    prompts, since they are agnostic to the encoder/decoder.

    Note that an :class:`ExplicitEncoderDecoderPrompt` may not
    be used as an input to a decoder-only model,
    and that the :code:`encoder_prompt` and :code:`decoder_prompt`
    fields of this data structure themselves must be
    :class:`SingletonPrompt` instances.
    """

    encoder_prompt: _T1_co

    # The key is always present, but its value may be ``None``.
    decoder_prompt: Optional[_T2_co]

    mm_processor_kwargs: NotRequired[dict[str, Any]]
119

120

121
PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:

- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
- A single data structure containing both an encoder and a decoder prompt
  (:class:`ExplicitEncoderDecoderPrompt`)
"""


133
134
class TokenInputs(TypedDict):
    """Represents token-based inputs."""
135
136
137
138

    type: Literal["token"]
    """The type of inputs."""

139
    prompt_token_ids: list[int]
140
141
    """The token IDs of the prompt."""

142
    token_type_ids: NotRequired[list[int]]
143
144
    """The token type IDs of the prompt."""

145
    prompt: NotRequired[str]
146
147
148
149
    """
    The original prompt text corresponding to the token IDs, if available.
    """

150
    multi_modal_data: NotRequired["MultiModalDataDict"]
151
152
153
154
155
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

156
157
158
159
160
161
    multi_modal_inputs: NotRequired["MultiModalKwargs"]
    """
    Optional multi-modal inputs to pass to the model,
    if the model supports it.
    """

162
    multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"]
163
164
165
166
    """
    Placeholder ranges for the multi-modal data.
    """

167
    multi_modal_hashes: NotRequired[list[str]]
168
169
170
171
    """
    The hashes of the multi-modal data.
    """

172
    mm_processor_kwargs: NotRequired[dict[str, Any]]
173
174
175
176
177
178
179
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

180

181
def token_inputs(
    prompt_token_ids: list[int],
    token_type_ids: Optional[list[int]] = None,
    prompt: Optional[str] = None,
    multi_modal_data: Optional["MultiModalDataDict"] = None,
    multi_modal_inputs: Optional["MultiModalKwargs"] = None,
    multi_modal_hashes: Optional[list[str]] = None,
    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
    mm_processor_kwargs: Optional[dict[str, Any]] = None,
) -> TokenInputs:
    """
    Construct :class:`TokenInputs` from optional values.

    The result always has ``type="token"`` and the given
    ``prompt_token_ids``. Every other argument is stored under the key of
    the same name only when it is not ``None``, so an omitted value leaves
    the key absent rather than present with a ``None`` value.
    """
    inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)

    # Keys are assigned one by one with literal names so that static type
    # checkers can validate each of them against the TypedDict schema.
    if prompt is not None:
        inputs["prompt"] = prompt
    if token_type_ids is not None:
        inputs["token_type_ids"] = token_type_ids
    if multi_modal_data is not None:
        inputs["multi_modal_data"] = multi_modal_data
    if multi_modal_inputs is not None:
        inputs["multi_modal_inputs"] = multi_modal_inputs
    if multi_modal_hashes is not None:
        inputs["multi_modal_hashes"] = multi_modal_hashes
    if multi_modal_placeholders is not None:
        inputs["multi_modal_placeholders"] = multi_modal_placeholders
    if mm_processor_kwargs is not None:
        inputs["mm_processor_kwargs"] = mm_processor_kwargs

    return inputs


212
DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputs"]
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""


220
class EncoderDecoderInputs(TypedDict):
    """
    The inputs in :class:`~vllm.LLMEngine` before they are
    passed to the model executor.

    This specifies the required data for encoder-decoder models.
    """

    encoder: Union[TokenInputs, "MultiModalInputs"]
    """The inputs for the encoder portion."""

    decoder: Union[TokenInputs, "MultiModalInputs"]
    """The inputs for the decoder portion."""
232

233

234
SingletonInputs = Union[TokenInputs, "MultiModalInputs"]
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
"""

240
241
242
243
244
245
246
247
248
249
250
251
252
253
254

@dataclass
class SingletonInputsAdapter:
    """
    Unified interface to access the components of :class:`SingletonInputs`.

    Every property dispatches on ``inputs["type"]`` (``"token"`` or
    ``"multimodal"``) and reads the matching key from the wrapped
    dictionary. Properties are ``cached_property`` instances, so each value
    is computed on first access and then reused for this adapter instance.
    Any other ``type`` value falls through to ``assert_never``.
    """
    inputs: SingletonInputs

    @cached_property
    def prompt(self) -> Optional[str]:
        """The prompt text, or ``None`` if no ``"prompt"`` key is set."""
        inputs = self.inputs

        if inputs["type"] == "token" or inputs["type"] == "multimodal":
            return inputs.get("prompt")

        assert_never(inputs)  # type: ignore[arg-type]

    @cached_property
    def prompt_token_ids(self) -> list[int]:
        """The prompt token IDs, or an empty list if the key is absent."""
        inputs = self.inputs

        if inputs["type"] == "token" or inputs["type"] == "multimodal":
            return inputs.get("prompt_token_ids", [])

        assert_never(inputs)  # type: ignore[arg-type]

    @cached_property
    def token_type_ids(self) -> list[int]:
        """The token type IDs, or an empty list if the key is absent."""
        inputs = self.inputs

        if inputs["type"] == "token" or inputs["type"] == "multimodal":
            return inputs.get("token_type_ids", [])

        assert_never(inputs)  # type: ignore[arg-type]

    @cached_property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        """Always ``None`` for both supported input types."""
        inputs = self.inputs

        if inputs["type"] == "token" or inputs["type"] == "multimodal":
            return None

        assert_never(inputs)  # type: ignore[arg-type]

    @cached_property
    def multi_modal_data(self) -> "MultiModalDataDict":
        """
        The multi-modal data, or an empty dict if the key is absent.

        Token inputs are read from ``"multi_modal_data"``; multimodal
        inputs are read from ``"mm_kwargs"``.
        """
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("multi_modal_data", {})

        if inputs["type"] == "multimodal":
            # NOTE(review): same key as ``multi_modal_inputs`` below;
            # confirm that returning ``mm_kwargs`` here is intended.
            return inputs.get("mm_kwargs", {})

        assert_never(inputs)  # type: ignore[arg-type]

    @cached_property
    def multi_modal_inputs(self) -> Union[dict, "MultiModalKwargs"]:
        """
        The multi-modal inputs, or an empty dict if the key is absent.

        Token inputs are read from ``"multi_modal_inputs"``; multimodal
        inputs are read from ``"mm_kwargs"``.
        """
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("multi_modal_inputs", {})

        if inputs["type"] == "multimodal":
            return inputs.get("mm_kwargs", {})

        assert_never(inputs)  # type: ignore[arg-type]

    @cached_property
    def multi_modal_hashes(self) -> list[str]:
        """
        The multi-modal hashes, or an empty list if the key is absent.

        Token inputs are read from ``"multi_modal_hashes"``; multimodal
        inputs are read from ``"mm_hashes"``.
        """
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("multi_modal_hashes", [])

        if inputs["type"] == "multimodal":
            # only the case when we use MultiModalInputs
            return inputs.get("mm_hashes", [])  # type: ignore[return-value]

        assert_never(inputs)  # type: ignore[arg-type]

    @cached_property
    def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
        """
        The placeholder ranges, or an empty dict if the key is absent.

        Token inputs are read from ``"multi_modal_placeholders"``;
        multimodal inputs are read from ``"mm_placeholders"``.
        """
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("multi_modal_placeholders", {})

        if inputs["type"] == "multimodal":
            return inputs.get("mm_placeholders", {})

        assert_never(inputs)  # type: ignore[arg-type]

    @cached_property
    def mm_processor_kwargs(self) -> dict[str, Any]:
        """
        The processor kwargs stored on token inputs (empty dict if the key
        is absent); always an empty dict for multimodal inputs.
        """
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("mm_processor_kwargs", {})

        if inputs["type"] == "multimodal":
            return {}

        assert_never(inputs)  # type: ignore[arg-type]
344
345


346
347
348
349
# Union of the processed-input schemas: the decoder-only forms
# (``DecoderOnlyInputs``) and the paired encoder/decoder form
# (``EncoderDecoderInputs``).
ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
"""
The inputs to :data:`vllm.inputs.InputProcessor`.
"""
350

351
352
# Invariant type variables used by the encoder/decoder prompt helper
# functions; both are bounded by, and default to, SingletonPrompt.
_T1 = TypeVar(
    "_T1",
    bound=SingletonPrompt,
    default=SingletonPrompt,
)
_T2 = TypeVar(
    "_T2",
    bound=SingletonPrompt,
    default=SingletonPrompt,
)
353
354


355
356
357
def build_explicit_enc_dec_prompt(
    encoder_prompt: _T1,
    decoder_prompt: Optional[_T2],
    mm_processor_kwargs: Optional[dict[str, Any]] = None,
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
    """
    Build an :class:`ExplicitEncoderDecoderPrompt` from its components.

    Args:
        encoder_prompt: The prompt for the encoder.
        decoder_prompt: The prompt for the decoder; may be ``None``.
        mm_processor_kwargs: Top-level multi-modal processor kwargs.
            ``None`` is replaced by a new empty dict, so the returned
            prompt always contains the ``mm_processor_kwargs`` key.

    Returns:
        The assembled prompt. A dict passed as ``mm_processor_kwargs`` is
        stored as-is, not copied.
    """
    if mm_processor_kwargs is None:
        mm_processor_kwargs = {}
    return ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder_prompt,
        decoder_prompt=decoder_prompt,
        mm_processor_kwargs=mm_processor_kwargs)
366
367
368
369
370


def zip_enc_dec_prompts(
    enc_prompts: Iterable[_T1],
    dec_prompts: Iterable[Optional[_T2]],
    mm_processor_kwargs: Optional[Union[Iterable[dict[str, Any]],
                                        dict[str, Any]]] = None,
) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
    """
    Zip encoder and decoder prompts together into a list of
    :class:`ExplicitEncoderDecoderPrompt` instances.

    ``mm_processor_kwargs`` may also be provided; if a dict is passed, the same
    dictionary will be used for every encoder/decoder prompt. If an iterable is
    provided, it will be zipped with the encoder/decoder prompts. If it is
    omitted, every prompt receives its own empty dictionary.

    The inputs are combined with :func:`zip`, so the result is as long as
    the shortest of the iterables that were supplied.
    """
    if mm_processor_kwargs is None:
        # Create a separate empty dict per prompt so that mutating one
        # prompt's kwargs cannot leak into the others.
        return [
            build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, {})
            for (encoder_prompt,
                 decoder_prompt) in zip(enc_prompts, dec_prompts)
        ]
    if isinstance(mm_processor_kwargs, dict):
        # A single dict is deliberately shared by every prompt.
        return [
            build_explicit_enc_dec_prompt(
                encoder_prompt, decoder_prompt,
                cast(dict[str, Any], mm_processor_kwargs))
            for (encoder_prompt,
                 decoder_prompt) in zip(enc_prompts, dec_prompts)
        ]
    return [
        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
                                      mm_proc_kwargs)
        for (encoder_prompt, decoder_prompt, mm_proc_kwargs
             ) in zip(enc_prompts, dec_prompts, mm_processor_kwargs)
    ]

399

400
401
def to_enc_dec_tuple_list(
    enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> list[tuple[_T1, Optional[_T2]]]:
    """
    Convert explicit encoder/decoder prompts into a list of
    ``(encoder_prompt, decoder_prompt)`` tuples, in input order.

    Only those two keys are read; ``mm_processor_kwargs`` is not carried
    over into the tuples.
    """
    return [(enc_dec_prompt["encoder_prompt"],
             enc_dec_prompt["decoder_prompt"])
            for enc_dec_prompt in enc_dec_prompts]