from dataclasses import dataclass
from functools import cached_property
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal,
                    Optional, Tuple, Union, cast)

import torch
from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never

if TYPE_CHECKING:
    from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs,
                                 MultiModalPlaceholderDict)
    from vllm.multimodal.inputs import MultiModalInputsV2


class TextPrompt(TypedDict):
    """Schema for a text prompt."""

    prompt: str
    """The input text to be tokenized before passing to the model."""

21
    multi_modal_data: NotRequired["MultiModalDataDict"]
22
23
24
25
26
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

27
28
29
30
31
32
33
34
    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

35
36
37
38
39
40
41

class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt."""

    prompt_token_ids: List[int]
    """A list of token IDs to pass to the model."""

42
43
44
    token_type_ids: NotRequired[List[int]]
    """A list of token type IDs to pass to the cross encoder model."""

45
    multi_modal_data: NotRequired["MultiModalDataDict"]
46
    """
47
    DEPRECATED: Optional multi-modal data to pass to the model,
48
49
50
    if the model supports it.
    """

51
52
    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
53
    DEPRECATED: Optional multi-modal processor kwargs to be forwarded to the
54
55
56
57
58
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

59

60
SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
"""
Set of possible schemas for a single prompt:

- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)

Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, such as the one
used by encoder/decoder models when the user wants to
specify both the encoder and decoder prompts explicitly
(see :class:`ExplicitEncoderDecoderPrompt`).

A prompt of type :class:`SingletonPrompt` may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder prompt is not specified explicitly, or
(3) a member of a larger data structure encapsulating
more than one prompt, such as :class:`ExplicitEncoderDecoderPrompt`.
"""

# Covariant type variables for the encoder (_T1_co) and decoder (_T2_co)
# prompt types of ExplicitEncoderDecoderPrompt. Both are bounded by, and
# default to, SingletonPrompt, so the unparameterized generic accepts any
# single-prompt schema.
_T1_co = TypeVar("_T1_co",
                 bound=SingletonPrompt,
                 default=SingletonPrompt,
                 covariant=True)
_T2_co = TypeVar("_T2_co",
                 bound=SingletonPrompt,
                 default=SingletonPrompt,
                 covariant=True)


# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
    """
    Represents an encoder/decoder model input prompt,
    comprising an explicit encoder prompt and a decoder prompt.

    The encoder and decoder prompts, respectively, may be formatted
    according to any of the :class:`SingletonPrompt` schemas,
    and are not required to have the same schema.

    Only the encoder prompt may have multi-modal data. mm_processor_kwargs
    should be at the top-level, and should not be set in the encoder/decoder
    prompts, since they are agnostic to the encoder/decoder.

    Note that an :class:`ExplicitEncoderDecoderPrompt` may not
    be used as an input to a decoder-only model,
    and that the :code:`encoder_prompt` and :code:`decoder_prompt`
    fields of this data structure themselves must be
    :class:`SingletonPrompt` instances.
    """

    # Prompt for the encoder portion of the model.
    encoder_prompt: _T1_co

    # Prompt for the decoder portion of the model. Typed as Optional, so
    # ``None`` is allowed when no explicit decoder prompt is given.
    decoder_prompt: Optional[_T2_co]

    # Top-level multi-modal processor kwargs; kept here rather than on the
    # individual encoder/decoder prompts (see the class docstring).
    mm_processor_kwargs: NotRequired[Dict[str, Any]]


PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:

- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
- A single data structure containing both an encoder and a decoder prompt
  (:class:`ExplicitEncoderDecoderPrompt`)
"""


class TokenInputs(TypedDict):
    """Represents token-based inputs."""
133
134
135
136

    type: Literal["token"]
    """The type of inputs."""

137
    prompt_token_ids: List[int]
138
139
    """The token IDs of the prompt."""

140
141
142
    token_type_ids: NotRequired[List[int]]
    """The token type IDs of the prompt."""

143
    prompt: NotRequired[str]
144
145
146
147
    """
    The original prompt text corresponding to the token IDs, if available.
    """

148
    multi_modal_data: NotRequired["MultiModalDataDict"]
149
150
151
152
153
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

154
155
156
157
158
159
    multi_modal_inputs: NotRequired["MultiModalKwargs"]
    """
    Optional multi-modal inputs to pass to the model,
    if the model supports it.
    """

160
    multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"]
161
162
163
164
    """
    Placeholder ranges for the multi-modal data.
    """

165
    mm_processor_kwargs: NotRequired[Dict[str, Any]]
166
167
168
169
170
171
172
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

173

174
175
def token_inputs(
    prompt_token_ids: List[int],
    token_type_ids: Optional[List[int]] = None,
    prompt: Optional[str] = None,
    multi_modal_data: Optional["MultiModalDataDict"] = None,
    multi_modal_inputs: Optional["MultiModalKwargs"] = None,
    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> TokenInputs:
    """
    Construct :class:`TokenInputs` from optional values.

    Only arguments that are not ``None`` are stored. Absent fields are left
    out of the resulting dict entirely rather than being set to ``None``.

    Args:
        prompt_token_ids: The token IDs of the prompt (always stored).
        token_type_ids: Optional token type IDs of the prompt.
        prompt: Optional original prompt text for the token IDs.
        multi_modal_data: Optional multi-modal data.
        multi_modal_inputs: Optional multi-modal inputs.
        multi_modal_placeholders: Optional placeholder ranges for the
            multi-modal data.
        mm_processor_kwargs: Optional multi-modal processor kwargs.

    Returns:
        A :class:`TokenInputs` dict with ``type="token"``.
    """
    inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)

    if prompt is not None:
        inputs["prompt"] = prompt
    if token_type_ids is not None:
        inputs["token_type_ids"] = token_type_ids
    if multi_modal_data is not None:
        inputs["multi_modal_data"] = multi_modal_data
    if multi_modal_inputs is not None:
        inputs["multi_modal_inputs"] = multi_modal_inputs
    if multi_modal_placeholders is not None:
        inputs["multi_modal_placeholders"] = multi_modal_placeholders
    if mm_processor_kwargs is not None:
        inputs["mm_processor_kwargs"] = mm_processor_kwargs

    return inputs


DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputsV2"]
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""


class EncoderDecoderInputs(TypedDict):
    """
    The inputs in :class:`~vllm.LLMEngine` before they are
    passed to the model executor.

    This specifies the required data for encoder-decoder models.
    Both the encoder and decoder parts are required.
    """
    encoder: Union[TokenInputs, "MultiModalInputsV2"]
    """The inputs for the encoder portion."""

    decoder: Union[TokenInputs, "MultiModalInputsV2"]
    """The inputs for the decoder portion."""


SingletonInputs = Union[TokenInputs, "MultiModalInputsV2"]
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
"""


@dataclass
class SingletonInputsAdapter:
    """
    Unified interface to access the components of :class:`SingletonInputs`.

    Each property dispatches on ``inputs["type"]``: ``"token"`` for
    :class:`TokenInputs` and ``"multimodal"`` for ``MultiModalInputsV2``.
    When the underlying field is absent, the property returns a default
    (``None``, ``[]`` or ``{}``). Values are computed on first access via
    :func:`functools.cached_property`, so later mutations of ``inputs`` are
    not reflected in already-accessed properties.
    """
    inputs: SingletonInputs

    @cached_property
    def prompt(self) -> Optional[str]:
        """The original prompt text, or ``None`` if it is not present."""
        inputs = self.inputs

        if inputs["type"] == "token" or inputs["type"] == "multimodal":
            return inputs.get("prompt")

        assert_never(inputs)

    @cached_property
    def prompt_token_ids(self) -> List[int]:
        """The prompt token IDs, or ``[]`` if they are not present."""
        inputs = self.inputs

        if inputs["type"] == "token" or inputs["type"] == "multimodal":
            return inputs.get("prompt_token_ids", [])

        assert_never(inputs)

    @cached_property
    def token_type_ids(self) -> List[int]:
        """The token type IDs, or ``[]`` if they are not present."""
        inputs = self.inputs

        if inputs["type"] == "token" or inputs["type"] == "multimodal":
            return inputs.get("token_type_ids", [])

        assert_never(inputs)

    @cached_property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        """Always ``None`` for the input types handled here."""
        inputs = self.inputs

        if inputs["type"] == "token" or inputs["type"] == "multimodal":
            return None

        assert_never(inputs)

    @cached_property
    def multi_modal_data(self) -> "MultiModalDataDict":
        """
        Multi-modal data for token inputs, or ``{}`` if absent.

        For ``"multimodal"`` inputs this returns the ``mm_kwargs`` field.
        """
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("multi_modal_data", {})

        if inputs["type"] == "multimodal":
            # NOTE(review): this returns the processed ``mm_kwargs`` rather
            # than raw multi-modal data; confirm callers expect this.
            return inputs.get("mm_kwargs", {})

        assert_never(inputs)

    @cached_property
    def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:
        """
        Multi-modal inputs (``mm_kwargs`` for multimodal inputs), or ``{}``.
        """
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("multi_modal_inputs", {})

        if inputs["type"] == "multimodal":
            return inputs.get("mm_kwargs", {})

        assert_never(inputs)

    @cached_property
    def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
        """
        Placeholder ranges (``mm_placeholders`` for multimodal inputs), or
        ``{}`` if absent.
        """
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("multi_modal_placeholders", {})

        if inputs["type"] == "multimodal":
            return inputs.get("mm_placeholders", {})

        assert_never(inputs)

    @cached_property
    def mm_processor_kwargs(self) -> Dict[str, Any]:
        """
        Processor kwargs for token inputs, or ``{}``.

        Multimodal inputs always return ``{}``.
        """
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("mm_processor_kwargs", {})

        if inputs["type"] == "multimodal":
            return {}

        assert_never(inputs)


ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
"""
The inputs to :data:`vllm.inputs.InputProcessor`.
"""

# Invariant counterparts of _T1_co/_T2_co. The helper functions that build
# and unpack ExplicitEncoderDecoderPrompt use them to tie their argument
# types to their return types.
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)


def build_explicit_enc_dec_prompt(
    encoder_prompt: _T1,
    decoder_prompt: Optional[_T2],
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
    """
    Build an :class:`ExplicitEncoderDecoderPrompt` from its components.

    Args:
        encoder_prompt: The prompt for the encoder.
        decoder_prompt: The prompt for the decoder, or ``None``.
        mm_processor_kwargs: Top-level multi-modal processor kwargs. ``None``
            is replaced by a fresh empty dict. A provided dict is stored
            as-is (not copied).

    Returns:
        A dict with ``encoder_prompt``, ``decoder_prompt`` and
        ``mm_processor_kwargs`` keys. ``mm_processor_kwargs`` is always set.
    """
    if mm_processor_kwargs is None:
        mm_processor_kwargs = {}
    return ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder_prompt,
        decoder_prompt=decoder_prompt,
        mm_processor_kwargs=mm_processor_kwargs)


def zip_enc_dec_prompts(
    enc_prompts: Iterable[_T1],
    dec_prompts: Iterable[Optional[_T2]],
    mm_processor_kwargs: Optional[Union[Iterable[Dict[str, Any]],
                                        Dict[str, Any]]] = None,
) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
    """
    Zip encoder and decoder prompts together into a list of
    :class:`ExplicitEncoderDecoderPrompt` instances.

    ``mm_processor_kwargs`` may also be provided. If a dict is passed, the
    same dictionary object (not a copy) is used for every encoder/decoder
    prompt. If an iterable is provided, it is zipped with the encoder/decoder
    prompts.

    As with :func:`zip`, the output stops at the shortest of the zipped
    iterables. Surplus items in the longer ones are silently ignored.
    """
    if mm_processor_kwargs is None:
        mm_processor_kwargs = cast(Dict[str, Any], {})
    if isinstance(mm_processor_kwargs, dict):
        # A single dict is shared by every resulting prompt.
        return [
            build_explicit_enc_dec_prompt(
                encoder_prompt, decoder_prompt,
                cast(Dict[str, Any], mm_processor_kwargs))
            for (encoder_prompt,
                 decoder_prompt) in zip(enc_prompts, dec_prompts)
        ]
    return [
        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
                                      mm_proc_kwargs)
        for (encoder_prompt, decoder_prompt, mm_proc_kwargs
             ) in zip(enc_prompts, dec_prompts, mm_processor_kwargs)
    ]


def to_enc_dec_tuple_list(
    enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> List[Tuple[_T1, Optional[_T2]]]:
    """
    Unpack explicit encoder/decoder prompts into ``(encoder, decoder)`` pairs.

    Only the ``encoder_prompt`` and ``decoder_prompt`` fields are kept. Any
    ``mm_processor_kwargs`` on the input prompts is dropped.
    """
    pairs = []
    for item in enc_dec_prompts:
        enc = item["encoder_prompt"]
        dec = item["decoder_prompt"]
        pairs.append((enc, dec))
    return pairs