data.py 12.9 KB
Newer Older
1
2
from dataclasses import dataclass
from functools import cached_property
3
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal,
4
                    Optional, Tuple, Union, cast)
5

6
7
import torch
from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never
8
9

if TYPE_CHECKING:
10
11
    from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs,
                                 MultiModalPlaceholderDict)
12
    from vllm.multimodal.inputs import MultiModalInputsV2
13
14
15
16
17
18
19
20


class TextPrompt(TypedDict):
    """Schema for a text prompt."""

    prompt: str
    """The input text to be tokenized before passing to the model."""

21
    multi_modal_data: NotRequired["MultiModalDataDict"]
22
23
24
25
26
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

27
28
29
30
31
32
33
34
    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

35
36
37
38
39
40
41

class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt."""

    prompt_token_ids: List[int]
    """A list of token IDs to pass to the model."""

42
43
44
    token_type_ids: NotRequired[List[int]]
    """A list of token type IDs to pass to the cross encoder model."""

45
    multi_modal_data: NotRequired["MultiModalDataDict"]
46
    """
47
    Optional multi-modal data to pass to the model,
48
49
50
    if the model supports it.
    """

51
52
    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
53
    Optional multi-modal processor kwargs to be forwarded to the
54
55
56
57
58
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

59

60
SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
"""
Set of possible schemas for a single prompt:

- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)

Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`.

A prompt of type :class:`SingletonPrompt` may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`.
"""

# Covariant type variables for the encoder (``_T1_co``) and decoder
# (``_T2_co``) prompt types of :class:`ExplicitEncoderDecoderPrompt`.
# Both are bounded by, and default to, :data:`SingletonPrompt`.
_T1_co = TypeVar("_T1_co",
                 bound=SingletonPrompt,
                 default=SingletonPrompt,
                 covariant=True)
_T2_co = TypeVar("_T2_co",
                 bound=SingletonPrompt,
                 default=SingletonPrompt,
                 covariant=True)


# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
    """
    Represents an encoder/decoder model input prompt,
    comprising an explicit encoder prompt and a decoder prompt.

    The encoder and decoder prompts, respectively, may be formatted
    according to any of the :class:`SingletonPrompt` schemas,
    and are not required to have the same schema.

    Only the encoder prompt may have multi-modal data. mm_processor_kwargs
    should be at the top-level, and should not be set in the encoder/decoder
    prompts, since they are agnostic to the encoder/decoder.

    Note that an :class:`ExplicitEncoderDecoderPrompt` may not
    be used as an input to a decoder-only model,
    and that the :code:`encoder_prompt` and :code:`decoder_prompt`
    fields of this data structure themselves must be
    :class:`SingletonPrompt` instances.
    """

    encoder_prompt: _T1_co
    """The prompt for the encoder."""

    decoder_prompt: Optional[_T2_co]
    """
    The prompt for the decoder, or ``None`` if no decoder prompt is given.
    """

    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal processor kwargs. These apply to the prompt as a
    whole rather than to the encoder or decoder prompt individually.
    """


PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:

- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
- A single data structure containing both an encoder and a decoder prompt
  (:class:`ExplicitEncoderDecoderPrompt`)
"""


class TokenInputs(TypedDict):
    """Represents token-based inputs."""
133
134
135
136

    type: Literal["token"]
    """The type of inputs."""

137
    prompt_token_ids: List[int]
138
139
    """The token IDs of the prompt."""

140
141
142
    token_type_ids: NotRequired[List[int]]
    """The token type IDs of the prompt."""

143
    prompt: NotRequired[str]
144
145
146
147
    """
    The original prompt text corresponding to the token IDs, if available.
    """

148
    multi_modal_data: NotRequired["MultiModalDataDict"]
149
150
151
152
153
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

154
155
156
157
158
159
    multi_modal_inputs: NotRequired["MultiModalKwargs"]
    """
    Optional multi-modal inputs to pass to the model,
    if the model supports it.
    """

160
    multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"]
161
162
163
164
    """
    Placeholder ranges for the multi-modal data.
    """

165
166
167
168
169
    multi_modal_hashes: NotRequired[List[str]]
    """
    The hashes of the multi-modal data.
    """

170
    mm_processor_kwargs: NotRequired[Dict[str, Any]]
171
172
173
174
175
176
177
    """
    Optional multi-modal processor kwargs to be forwarded to the
    multimodal input mapper & processor. Note that if multiple modalities
    have registered mappers etc for the model being considered, we attempt
    to pass the mm_processor_kwargs to each of them.
    """

178

179
180
def token_inputs(
    prompt_token_ids: List[int],
    token_type_ids: Optional[List[int]] = None,
    prompt: Optional[str] = None,
    multi_modal_data: Optional["MultiModalDataDict"] = None,
    multi_modal_inputs: Optional["MultiModalKwargs"] = None,
    multi_modal_hashes: Optional[List[str]] = None,
    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> TokenInputs:
    """
    Construct :class:`TokenInputs` from optional values.

    Only ``prompt_token_ids`` is required. Each optional argument is stored
    under the key of the same name only when it is not ``None``. Omitted
    values therefore leave their key absent instead of setting it to
    ``None``.

    Args:
        prompt_token_ids: The token IDs of the prompt.
        token_type_ids: Optional token type IDs of the prompt.
        prompt: Optional original prompt text.
        multi_modal_data: Optional multi-modal data.
        multi_modal_inputs: Optional multi-modal inputs.
        multi_modal_hashes: Optional hashes of the multi-modal data.
        multi_modal_placeholders: Optional placeholder ranges for the
            multi-modal data.
        mm_processor_kwargs: Optional multi-modal processor kwargs.

    Returns:
        A :class:`TokenInputs` dict whose ``type`` is ``"token"``.
    """
    inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)

    if prompt is not None:
        inputs["prompt"] = prompt
    if token_type_ids is not None:
        inputs["token_type_ids"] = token_type_ids
    if multi_modal_data is not None:
        inputs["multi_modal_data"] = multi_modal_data
    if multi_modal_inputs is not None:
        inputs["multi_modal_inputs"] = multi_modal_inputs
    if multi_modal_hashes is not None:
        inputs["multi_modal_hashes"] = multi_modal_hashes
    if multi_modal_placeholders is not None:
        inputs["multi_modal_placeholders"] = multi_modal_placeholders
    if mm_processor_kwargs is not None:
        inputs["mm_processor_kwargs"] = mm_processor_kwargs

    return inputs


DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputsV2"]
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""


class EncoderDecoderInputs(TypedDict):
    """
    The inputs in :class:`~vllm.LLMEngine` before they are
    passed to the model executor.

    This specifies the required data for encoder-decoder models.
    """
    encoder: Union[TokenInputs, "MultiModalInputsV2"]
    """The inputs for the encoder portion."""

    decoder: Union[TokenInputs, "MultiModalInputsV2"]
    """The inputs for the decoder portion."""


SingletonInputs = Union[TokenInputs, "MultiModalInputsV2"]
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
"""


@dataclass
class SingletonInputsAdapter:
    """
    Unified interface to access the components of :class:`SingletonInputs`.

    Each property dispatches on ``inputs["type"]`` (``"token"`` or
    ``"multimodal"``). Its value is computed on first access and then cached
    on the instance (via :func:`functools.cached_property`).
    """
    inputs: SingletonInputs

    @cached_property
    def prompt(self) -> Optional[str]:
        """The original prompt text, or ``None`` if it is not present."""
        inputs = self.inputs

        if inputs["type"] == "token" or inputs["type"] == "multimodal":
            return inputs.get("prompt")

        assert_never(inputs)  # type: ignore[arg-type]

    @cached_property
    def prompt_token_ids(self) -> List[int]:
        """The prompt token IDs, or an empty list if they are not present."""
        inputs = self.inputs

        if inputs["type"] == "token" or inputs["type"] == "multimodal":
            return inputs.get("prompt_token_ids", [])

        assert_never(inputs)  # type: ignore[arg-type]

    @cached_property
    def token_type_ids(self) -> List[int]:
        """The token type IDs, or an empty list if they are not present."""
        inputs = self.inputs

        if inputs["type"] == "token" or inputs["type"] == "multimodal":
            return inputs.get("token_type_ids", [])

        assert_never(inputs)  # type: ignore[arg-type]

    @cached_property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        """
        Prompt embeddings. Neither supported input type carries them, so
        this is always ``None``.
        """
        inputs = self.inputs

        if inputs["type"] == "token" or inputs["type"] == "multimodal":
            return None

        assert_never(inputs)  # type: ignore[arg-type]

    @cached_property
    def multi_modal_data(self) -> "MultiModalDataDict":
        """
        The multi-modal data of the prompt (empty dict if absent).

        For ``"multimodal"`` inputs this returns ``mm_kwargs``, i.e. the same
        value as :attr:`multi_modal_inputs`.
        """
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("multi_modal_data", {})

        if inputs["type"] == "multimodal":
            return inputs.get("mm_kwargs", {})

        assert_never(inputs)  # type: ignore[arg-type]

    @cached_property
    def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]:
        """The multi-modal inputs of the prompt (empty dict if absent)."""
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("multi_modal_inputs", {})

        if inputs["type"] == "multimodal":
            return inputs.get("mm_kwargs", {})

        assert_never(inputs)  # type: ignore[arg-type]

    @cached_property
    def multi_modal_hashes(self) -> List[str]:
        """The hashes of the multi-modal data (empty list if absent)."""
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("multi_modal_hashes", [])

        if inputs["type"] == "multimodal":
            # only the case when we use MultiModalInputsV2
            return inputs.get("mm_hashes", [])  # type: ignore[return-value]

        assert_never(inputs)  # type: ignore[arg-type]

    @cached_property
    def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
        """Placeholder ranges for the multi-modal data (empty if absent)."""
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("multi_modal_placeholders", {})

        if inputs["type"] == "multimodal":
            return inputs.get("mm_placeholders", {})

        assert_never(inputs)  # type: ignore[arg-type]

    @cached_property
    def mm_processor_kwargs(self) -> Dict[str, Any]:
        """
        Multi-modal processor kwargs. These are always empty for
        ``"multimodal"`` inputs.
        """
        inputs = self.inputs

        if inputs["type"] == "token":
            return inputs.get("mm_processor_kwargs", {})

        if inputs["type"] == "multimodal":
            return {}

        assert_never(inputs)  # type: ignore[arg-type]


ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
"""
The inputs to :data:`vllm.inputs.InputProcessor`.
"""

# Non-covariant type variables for the encoder (``_T1``) and decoder
# (``_T2``) prompt types, used by the prompt helper functions below.
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)


def build_explicit_enc_dec_prompt(
    encoder_prompt: _T1,
    decoder_prompt: Optional[_T2],
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
    """
    Build an :class:`ExplicitEncoderDecoderPrompt` from its parts.

    Args:
        encoder_prompt: The prompt for the encoder.
        decoder_prompt: The prompt for the decoder, or ``None``.
        mm_processor_kwargs: Multi-modal processor kwargs stored at the top
            level of the result. When ``None``, a new empty dict is used, so
            the ``mm_processor_kwargs`` key is always present.

    Returns:
        The assembled encoder/decoder prompt.
    """
    if mm_processor_kwargs is None:
        mm_processor_kwargs = {}
    return ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder_prompt,
        decoder_prompt=decoder_prompt,
        mm_processor_kwargs=mm_processor_kwargs,
    )


def zip_enc_dec_prompts(
    enc_prompts: Iterable[_T1],
    dec_prompts: Iterable[Optional[_T2]],
    mm_processor_kwargs: Optional[Union[Iterable[Dict[str, Any]],
                                        Dict[str, Any]]] = None,
) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
    """
    Zip encoder and decoder prompts together into a list of
    :class:`ExplicitEncoderDecoderPrompt` instances.

    ``mm_processor_kwargs`` may also be provided; if a dict is passed, the same
    dictionary will be used for every encoder/decoder prompt. If an iterable is
    provided, it will be zipped with the encoder/decoder prompts. If it is
    omitted, each resulting prompt receives its own empty dict.

    As with :func:`zip`, the result stops at the shortest of the inputs.
    """
    if mm_processor_kwargs is None or isinstance(mm_processor_kwargs, dict):
        # One (possibly absent) kwargs dict applies to every pair. Passing
        # ``None`` straight through lets build_explicit_enc_dec_prompt create
        # a fresh ``{}`` per prompt. Otherwise a single shared empty dict
        # would be aliased into every result.
        shared_kwargs = cast(Optional[Dict[str, Any]], mm_processor_kwargs)
        return [
            build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
                                          shared_kwargs)
            for (encoder_prompt,
                 decoder_prompt) in zip(enc_prompts, dec_prompts)
        ]
    return [
        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
                                      mm_proc_kwargs)
        for (encoder_prompt, decoder_prompt, mm_proc_kwargs
             ) in zip(enc_prompts, dec_prompts, mm_processor_kwargs)
    ]


def to_enc_dec_tuple_list(
    enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> List[Tuple[_T1, Optional[_T2]]]:
    """
    Unpack each :class:`ExplicitEncoderDecoderPrompt` into an
    ``(encoder_prompt, decoder_prompt)`` tuple, preserving input order.

    Any ``mm_processor_kwargs`` on the prompts are not carried over.
    """
    pairs: List[Tuple[_T1, Optional[_T2]]] = []
    for prompt in enc_dec_prompts:
        pairs.append((prompt["encoder_prompt"], prompt["decoder_prompt"]))
    return pairs