# data.py
from dataclasses import dataclass
from functools import cached_property
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal,
                    Optional, Tuple, Union, cast)

import torch
from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never

if TYPE_CHECKING:
    from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
    from vllm.multimodal.inputs import MultiModalInputsV2


class TextPrompt(TypedDict):
    """Schema for a text prompt."""

    prompt: str
    """The input text to be tokenized before passing to the model."""

20
    multi_modal_data: NotRequired["MultiModalDataDict"]
21
22
23
24
25
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

26
27
28
29
30
31
32
33
    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

34
35
36
37
38
39
40

class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt."""

    prompt_token_ids: List[int]
    """A list of token IDs to pass to the model."""

41
42
43
    token_type_ids: NotRequired[List[int]]
    """A list of token type IDs to pass to the cross encoder model."""

44
    multi_modal_data: NotRequired["MultiModalDataDict"]
45
    """
46
    DEPRECATED: Optional multi-modal data to pass to the model,
47
48
49
    if the model supports it.
    """

50
51
    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
52
    DEPRECATED: Optional multi-modal processor kwargs to be forwarded to the
53
54
55
56
57
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

58

59
SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
"""
Set of possible schemas for a single prompt:

- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)

Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`

A prompt of type :class:`SingletonPrompt` may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
"""

# Covariant type parameters of ExplicitEncoderDecoderPrompt: ``_T1_co`` types
# its ``encoder_prompt`` field and ``_T2_co`` its ``decoder_prompt`` field.
# Both are restricted to (and default to) SingletonPrompt.
_T1_co = TypeVar("_T1_co",
                 bound=SingletonPrompt,
                 default=SingletonPrompt,
                 covariant=True)
_T2_co = TypeVar("_T2_co",
                 bound=SingletonPrompt,
                 default=SingletonPrompt,
                 covariant=True)


# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
    """
    Represents an encoder/decoder model input prompt,
    comprising an explicit encoder prompt and a decoder prompt.

    The encoder and decoder prompts, respectively, may be formatted
    according to any of the :class:`SingletonPrompt` schemas,
    and are not required to have the same schema.

    Only the encoder prompt may have multi-modal data. mm_processor_kwargs
    should be at the top-level, and should not be set in the encoder/decoder
    prompts, since they are agnostic to the encoder/decoder.

    Note that an :class:`ExplicitEncoderDecoderPrompt` may not
    be used as an input to a decoder-only model,
    and that the :code:`encoder_prompt` and :code:`decoder_prompt`
    fields of this data structure themselves must be
    :class:`SingletonPrompt` instances.
    """

    # The prompt for the encoder (required).
    encoder_prompt: _T1_co

    # The prompt for the decoder; the key is required but its value
    # may be ``None``.
    decoder_prompt: Optional[_T2_co]

    # Top-level multi-modal processor kwargs (may be omitted).
    mm_processor_kwargs: NotRequired[Dict[str, Any]]


PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:

- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
- A single data structure containing both an encoder and a decoder prompt
  (:class:`ExplicitEncoderDecoderPrompt`)
"""


class TokenInputs(TypedDict):
    """Represents token-based inputs."""
132
133
134
135

    type: Literal["token"]
    """The type of inputs."""

136
    prompt_token_ids: List[int]
137
138
    """The token IDs of the prompt."""

139
140
141
    token_type_ids: NotRequired[List[int]]
    """The token type IDs of the prompt."""

142
    prompt: NotRequired[str]
143
144
145
146
    """
    The original prompt text corresponding to the token IDs, if available.
    """

147
    multi_modal_data: NotRequired["MultiModalDataDict"]
148
149
150
151
152
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

153
    multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"]
154
155
156
157
    """
    Placeholder ranges for the multi-modal data.
    """

158
    mm_processor_kwargs: NotRequired[Dict[str, Any]]
159
160
161
162
163
164
165
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

166

167
168
def token_inputs(
    prompt_token_ids: List[int],
    token_type_ids: Optional[List[int]] = None,
    prompt: Optional[str] = None,
    multi_modal_data: Optional["MultiModalDataDict"] = None,
    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> TokenInputs:
    """Construct :class:`TokenInputs` from optional values.

    Args:
        prompt_token_ids: The token IDs of the prompt (always stored).
        token_type_ids: Optional token type IDs of the prompt.
        prompt: Optional original prompt text.
        multi_modal_data: Optional multi-modal data.
        multi_modal_placeholders: Optional placeholder ranges for the
            multi-modal data.
        mm_processor_kwargs: Optional multi-modal processor kwargs.

    Returns:
        A :class:`TokenInputs` with ``type="token"``. Any optional argument
        that is ``None`` is omitted from the dict entirely rather than being
        stored as a ``None`` value. The given objects are stored as-is
        (not copied).
    """
    inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)

    # Only populate NotRequired keys that were actually provided.
    if prompt is not None:
        inputs["prompt"] = prompt
    if token_type_ids is not None:
        inputs["token_type_ids"] = token_type_ids
    if multi_modal_data is not None:
        inputs["multi_modal_data"] = multi_modal_data
    if multi_modal_placeholders is not None:
        inputs["multi_modal_placeholders"] = multi_modal_placeholders
    if mm_processor_kwargs is not None:
        inputs["mm_processor_kwargs"] = mm_processor_kwargs

    return inputs


DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputsV2"]
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""


class EncoderDecoderInputs(TypedDict):
    """
    The inputs in :class:`~vllm.LLMEngine` before they are
    passed to the model executor.

    This specifies the required data for encoder-decoder models.
    Both fields are required.
    """
    encoder: Union[TokenInputs, "MultiModalInputsV2"]
    """The inputs for the encoder portion."""

    decoder: Union[TokenInputs, "MultiModalInputsV2"]
    """The inputs for the decoder portion."""


SingletonInputs = Union[TokenInputs, "MultiModalInputsV2"]
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
"""


@dataclass
class SingletonInputsAdapter:
    """
    Unified interface to access the components of :class:`SingletonInputs`.

    Every accessor dispatches on ``inputs["type"]``, which is expected to be
    either ``"token"`` (:class:`TokenInputs`) or ``"multimodal"``; any other
    value falls through to :func:`assert_never`.

    Accessors are :func:`~functools.cached_property` objects: each is computed
    on first access and then reused, so later mutations of :attr:`inputs` are
    not reflected in values that have already been read.
    """
    inputs: SingletonInputs

    @cached_property
    def prompt(self) -> Optional[str]:
        """The original prompt text, or ``None`` if it is not present."""
        inputs = self.inputs

        if inputs["type"] == "token" or inputs["type"] == "multimodal":
            return inputs.get("prompt")

        assert_never(inputs)

    @cached_property
    def prompt_token_ids(self) -> List[int]:
        """The prompt's token IDs, or an empty list if the key is absent."""
        inputs = self.inputs

        if inputs["type"] == "token" or inputs["type"] == "multimodal":
            return inputs.get("prompt_token_ids", [])

        assert_never(inputs)

    @cached_property
    def token_type_ids(self) -> List[int]:
        """The prompt's token type IDs, or an empty list if absent."""
        inputs = self.inputs

        if inputs["type"] == "token" or inputs["type"] == "multimodal":
            return inputs.get("token_type_ids", [])

        assert_never(inputs)

    @cached_property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        """Prompt embeddings; always ``None`` for the supported input types."""
        inputs = self.inputs

        if inputs["type"] == "token" or inputs["type"] == "multimodal":
            return None

        assert_never(inputs)

    @cached_property
    def multi_modal_data(self) -> "MultiModalDataDict":
        """Multi-modal data for the prompt, or an empty dict if absent.

        Token inputs store it under ``multi_modal_data``; multimodal inputs
        are read from their ``mm_kwargs`` key instead.
        """
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("multi_modal_data", {})

        if inputs["type"] == "multimodal":
            return inputs.get("mm_kwargs", {})

        assert_never(inputs)

    @cached_property
    def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
        """Placeholder ranges for the multi-modal data, or an empty dict.

        Token inputs store them under ``multi_modal_placeholders``;
        multimodal inputs are read from their ``mm_placeholders`` key.
        """
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("multi_modal_placeholders", {})

        if inputs["type"] == "multimodal":
            return inputs.get("mm_placeholders", {})

        assert_never(inputs)

    @cached_property
    def mm_processor_kwargs(self) -> Dict[str, Any]:
        """Multi-modal processor kwargs, or an empty dict if absent.

        Always an empty dict for ``"multimodal"`` inputs.
        """
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("mm_processor_kwargs", {})

        if inputs["type"] == "multimodal":
            return {}

        assert_never(inputs)


ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
"""
The inputs to :data:`vllm.inputs.InputProcessor`.
"""

# Invariant counterparts of ``_T1_co``/``_T2_co``, used by the helper
# functions that build and unpack ExplicitEncoderDecoderPrompt instances to
# tie the encoder (``_T1``) and decoder (``_T2``) prompt types through their
# signatures.
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)


def build_explicit_enc_dec_prompt(
    encoder_prompt: _T1,
    decoder_prompt: Optional[_T2],
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
    """Build an :class:`ExplicitEncoderDecoderPrompt` from its parts.

    Args:
        encoder_prompt: The prompt for the encoder.
        decoder_prompt: The prompt for the decoder, or ``None``.
        mm_processor_kwargs: Multi-modal processor kwargs to place at the top
            level of the result. ``None`` is replaced by a new empty dict, so
            the ``mm_processor_kwargs`` key is always present.

    Returns:
        A new :class:`ExplicitEncoderDecoderPrompt`. A dict passed as
        ``mm_processor_kwargs`` is stored as-is (not copied).
    """
    if mm_processor_kwargs is None:
        mm_processor_kwargs = {}
    return ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder_prompt,
        decoder_prompt=decoder_prompt,
        mm_processor_kwargs=mm_processor_kwargs)


def zip_enc_dec_prompts(
    enc_prompts: Iterable[_T1],
    dec_prompts: Iterable[Optional[_T2]],
    mm_processor_kwargs: Optional[Union[Iterable[Dict[str, Any]],
                                        Dict[str, Any]]] = None,
) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
    """
    Zip encoder and decoder prompts together into a list of
    :class:`ExplicitEncoderDecoderPrompt` instances.

    ``mm_processor_kwargs`` may also be provided; if a dict is passed, the same
    dictionary will be used for every encoder/decoder prompt. If an iterable is
    provided, it will be zipped with the encoder/decoder prompts. If it is
    omitted (``None``), every resulting prompt gets its own empty dict.

    As with :func:`zip`, the result stops at the shortest input.
    """
    if mm_processor_kwargs is None or isinstance(mm_processor_kwargs, dict):
        # One kwargs value applies to every pair. ``None`` is forwarded
        # unchanged so build_explicit_enc_dec_prompt creates a separate
        # empty dict per prompt, instead of all prompts aliasing one dict.
        shared_kwargs = cast(Optional[Dict[str, Any]], mm_processor_kwargs)
        return [
            build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
                                          shared_kwargs)
            for encoder_prompt, decoder_prompt in zip(enc_prompts, dec_prompts)
        ]

    return [
        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
                                      mm_proc_kwargs)
        for encoder_prompt, decoder_prompt, mm_proc_kwargs in zip(
            enc_prompts, dec_prompts, mm_processor_kwargs)
    ]


def to_enc_dec_tuple_list(
    enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> List[Tuple[_T1, Optional[_T2]]]:
    """Unpack explicit encoder/decoder prompts into tuples.

    Each input becomes an ``(encoder_prompt, decoder_prompt)`` pair, in
    order. Any top-level ``mm_processor_kwargs`` is not included.
    """
    pairs: List[Tuple[_T1, Optional[_T2]]] = []
    for prompt in enc_dec_prompts:
        pairs.append((prompt["encoder_prompt"], prompt["decoder_prompt"]))
    return pairs


def __getattr__(name: str):
    """Resolve deprecated module-level names on access (PEP 562).

    Accessing one of the old names emits a :class:`DeprecationWarning` and
    returns the object it was renamed to:

    - ``PromptInput`` -> :data:`PromptType`
    - ``LLMInputs`` -> :data:`DecoderOnlyInputs`
    - ``EncoderDecoderLLMInputs`` -> :class:`EncoderDecoderInputs`

    Raises:
        AttributeError: For any other name, matching normal module
            attribute lookup.
    """
    import warnings

    if name == "PromptInput":
        msg = ("PromptInput has been renamed to PromptType. "
               "The original name will be removed in an upcoming version.")

        # stacklevel=2 attributes the warning to the code doing the access.
        warnings.warn(DeprecationWarning(msg), stacklevel=2)

        return PromptType

    if name == "LLMInputs":
        msg = ("LLMInputs has been renamed to DecoderOnlyInputs. "
               "The original name will be removed in an upcoming version.")

        warnings.warn(DeprecationWarning(msg), stacklevel=2)

        return DecoderOnlyInputs

    if name == "EncoderDecoderLLMInputs":
        msg = (
            "EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. "
            "The original name will be removed in an upcoming version.")

        warnings.warn(DeprecationWarning(msg), stacklevel=2)

        return EncoderDecoderInputs

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")