# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
4
5
from collections.abc import Mapping
from dataclasses import dataclass, field
6
from typing import Generic, TypeVar
7
8
9
10
11

import numpy as np
import numpy.typing as npt
from PIL import Image

12
13
14
15
16
17
from vllm.config.multimodal import (
    AudioDummyOptions,
    BaseDummyOptions,
    ImageDummyOptions,
    VideoDummyOptions,
)
18
19
from vllm.logger import init_logger

20
from ..inputs import MultiModalDataDict
21
from ..parse import MultiModalDataItems
22
from .context import BaseProcessingInfo
23

24
# Module-level logger; the dummy-input builders below use it to warn about
# out-of-range dummy-data overrides.
logger = init_logger(__name__)

# Type parameter for the processing-info object a dummy-inputs builder holds.
_I = TypeVar("_I", bound=BaseProcessingInfo)


@dataclass
class ProcessorInputs:
    """
    Container for the keyword arguments accepted by
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].

    `prompt` and `mm_items` must be supplied; the two kwargs mappings are
    optional and each instance receives its own empty dict by default.
    """

    # Either a text prompt or a list of ints (presumably token IDs).
    prompt: str | list[int]
    # Multi-modal items accompanying the prompt.
    mm_items: MultiModalDataItems
    # Optional kwargs mappings; `default_factory=dict` avoids sharing one
    # mutable default between instances.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)


42
class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.

    Subclasses implement `get_dummy_text` and `get_dummy_mm_data`. The
    protected `_get_dummy_*` helpers build placeholder audio, image and
    video items of a requested size, optionally shrunk (never grown) by
    per-modality override options.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        # Model processing info; `get_dummy_processor_inputs` uses it to
        # parse the generated dummy multi-modal data.
        self.info = info

    @abstractmethod
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Build the text input corresponding to `mm_counts`.
        """
        raise NotImplementedError

    @abstractmethod
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """
        Build the multimodal input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional).
                       If None, use model defaults for backward compatibility.
                       If provided, models can use these to customize dummy
                       data generation.
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Combines `get_dummy_text` and `get_dummy_mm_data` into a
        `ProcessorInputs`. The multimodal data is parsed with
        `validate=False`, and tokenization is requested without truncation.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional),
                forwarded unchanged to `get_dummy_mm_data`.

        Returns:
            The dummy prompt, its parsed multimodal items and the
            tokenization kwargs.
        """
        dummy_text = self.get_dummy_text(mm_counts)
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
        dummy_mm_items = self.info.parse_mm_data(dummy_mm_data, validate=False)

        # NOTE(review): truncation is presumably disabled so the dummy prompt
        # keeps all of its placeholder tokens during profiling; confirm.
        tokenization_kwargs = {"truncation": False}

        return ProcessorInputs(
            prompt=dummy_text,
            mm_items=dummy_mm_items,
            tokenization_kwargs=tokenization_kwargs,
        )

    @staticmethod
    def _resolve_dummy_override(
        option: str,
        limit_desc: str,
        override: int | None,
        maximum: int,
    ) -> int:
        """
        Return the value to use for one dimension of the dummy data.

        A falsy `override` (None or 0) leaves `maximum` unchanged. Otherwise
        the result is `min(maximum, override)`. An override larger than
        `maximum` is logged as a warning and has no effect.

        Args:
            option: Qualified option name for the log message,
                e.g. ``"image.width"``.
            limit_desc: Description of the limit for the log message,
                e.g. ``"width"`` or ``"number of frames"``.
            override: User-provided value for the option, if any.
            maximum: The model's maximum for this dimension.
        """
        if not override:
            return maximum
        if override > maximum:
            logger.warning(
                "%s override (%d) exceeds model's maximum %s (%d), "
                "will be ignored",
                option,
                override,
                limit_desc,
                maximum,
            )
        return min(maximum, override)

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
        overrides: AudioDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """
        Build `num_audios` 1-D all-zero arrays of `length` elements.

        `overrides.length`, when set, may shrink `length` but never grow it.
        Every entry of the returned list is the same array object.
        """
        if num_audios == 0:
            return []

        if overrides:
            length = self._resolve_dummy_override(
                "audio.length", "length", overrides.length, length
            )

        audio = np.zeros((length,))
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
        overrides: ImageDummyOptions | None = None,
    ) -> list[Image.Image]:
        """
        Build `num_images` RGB PIL images of size `(width, height)`.

        `overrides.width` / `overrides.height`, when set, may shrink the
        corresponding dimension but never grow it. Every entry of the
        returned list is the same image object.
        """
        if num_images == 0:
            return []

        if overrides:
            width = self._resolve_dummy_override(
                "image.width", "width", overrides.width, width
            )
            height = self._resolve_dummy_override(
                "image.height", "height", overrides.height, height
            )

        # NOTE(review): Pillow packs an integer fill for multi-band modes, so
        # `color=255` may not be white; only the size appears to matter here.
        image = Image.new("RGB", (width, height), color=255)
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: VideoDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """
        Build `num_videos` uint8 videos filled with 255.

        Each video has shape `(num_frames, height, width, 3)`, i.e. frames
        laid out like the numpy view of a `(width, height)` PIL image.
        `overrides.num_frames` / `width` / `height`, when set, may shrink the
        corresponding dimension but never grow it. Every entry of the
        returned list is the same array object.
        """
        if num_videos == 0:
            return []

        if overrides:
            num_frames = self._resolve_dummy_override(
                "video.num_frames",
                "number of frames",
                overrides.num_frames,
                num_frames,
            )
            width = self._resolve_dummy_override(
                "video.width", "width", overrides.width, width
            )
            height = self._resolve_dummy_override(
                "video.height", "height", overrides.height, height
            )

        # Height before width: the previous (num_frames, width, height, 3)
        # layout transposed every non-square dummy video.
        video = np.full((num_frames, height, width, 3), 255, dtype=np.uint8)
        return [video] * num_videos