dummy_inputs.py 6.97 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
5
from collections.abc import Mapping
from dataclasses import dataclass, field
6
from typing import Generic, TypeVar
7
8
9
10
11

import numpy as np
import numpy.typing as npt
from PIL import Image

12
13
14
15
16
17
from vllm.config.multimodal import (
    AudioDummyOptions,
    BaseDummyOptions,
    ImageDummyOptions,
    VideoDummyOptions,
)
18
19
from vllm.logger import init_logger

20
from ..inputs import MultiModalDataDict
21
from ..parse import MultiModalDataItems
22
from .context import BaseProcessingInfo
23

24
# Logger for this module; the dummy-input builders use it to warn about
# user-supplied overrides that exceed the model's limits.
logger = init_logger(__name__)

# Processing-info type that a concrete dummy-inputs builder is parameterized by.
_I = TypeVar("_I", bound=BaseProcessingInfo)


@dataclass
class ProcessorInputs:
    """
    Represents the keyword arguments to
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
    """

    # Prompt as plain text or as a list of ints (presumably token IDs).
    prompt: str | list[int]
    # Parsed multi-modal items that accompany the prompt.
    mm_items: MultiModalDataItems
    # Extra keyword arguments for the HF processor (empty by default).
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
    # Keyword arguments for tokenization (empty by default).
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
40
41


42
class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.

    Subclasses implement `get_dummy_text` and `get_dummy_mm_data`; the
    `_get_dummy_*` helpers build placeholder media of a requested size,
    optionally shrunk by user-supplied per-modality overrides.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        # Model processing info; used below to parse the dummy multi-modal
        # data into `MultiModalDataItems`.
        self.info = info

    @abstractmethod
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Build the text input corresponding to `mm_counts`.
        """
        raise NotImplementedError

    @abstractmethod
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        """
        Build the multimodal input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional).
                       If None, use model defaults for backward compatibility.
                       If provided, models can use these to customize dummy
                       data generation.
            mm_processor_kwargs: Additional keyword arguments
                                for hf_processor (optional)
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional)
            mm_processor_kwargs: Additional keyword arguments
                                for hf_processor (optional)

        Returns:
            A `ProcessorInputs` bundling the dummy text prompt, the parsed
            dummy multi-modal items, the processor kwargs (or `{}`), and
            tokenization kwargs that disable truncation.
        """
        dummy_text = self.get_dummy_text(mm_counts)
        dummy_mm_data = self.get_dummy_mm_data(
            seq_len,
            mm_counts,
            mm_options,
            mm_processor_kwargs=mm_processor_kwargs,
        )
        # The dummy data is synthetic, so it is parsed without validation.
        dummy_mm_items = self.info.parse_mm_data(dummy_mm_data, validate=False)

        # Truncation is turned off, presumably so that no placeholder tokens
        # in the dummy prompt are cut off during profiling.
        tokenization_kwargs = {"truncation": False}

        return ProcessorInputs(
            prompt=dummy_text,
            mm_items=dummy_mm_items,
            hf_processor_mm_kwargs=mm_processor_kwargs or {},
            tokenization_kwargs=tokenization_kwargs,
        )

    @staticmethod
    def _resolve_override(
        option_name: str,
        limit_name: str,
        limit: int,
        override: int | None,
    ) -> int:
        """
        Apply a user-supplied dummy-data override to a model limit.

        Args:
            option_name: Option name shown in the warning (e.g. "image.width").
            limit_name: Limit name shown in the warning (e.g. "width").
            limit: The model's maximum value for this dimension.
            override: Requested value; a falsy value (None or 0) means
                "no override".

        Returns:
            `limit` if there is no override, otherwise `min(limit, override)`.
            If the override exceeds the limit, a warning is logged and the
            limit is kept.
        """
        if not override:
            return limit
        if override > limit:
            logger.warning(
                "%s override (%d) exceeds model's maximum %s (%d), "
                "will be ignored",
                option_name,
                override,
                limit_name,
                limit,
            )
        return min(limit, override)

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
        overrides: AudioDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """
        Build `num_audios` all-zero 1-D audio arrays with `length` elements.

        `overrides.length`, if set, can only shrink `length`. A larger value
        is ignored with a warning.
        """
        if num_audios == 0:
            return []

        if overrides:
            length = self._resolve_override(
                "audio.length", "length", length, overrides.length
            )

        audio = np.zeros((length,))
        # The same array object is repeated; callers should not mutate it.
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
        overrides: ImageDummyOptions | None = None,
    ) -> list[Image.Image]:
        """
        Build `num_images` RGB images of size `width` x `height`.

        `overrides.width` / `overrides.height`, if set, can only shrink the
        corresponding dimension. Larger values are ignored with a warning.
        """
        if num_images == 0:
            return []

        if overrides:
            width = self._resolve_override(
                "image.width", "width", width, overrides.width
            )
            height = self._resolve_override(
                "image.height", "height", height, overrides.height
            )

        image = Image.new("RGB", (width, height), color=255)
        # The same image object is repeated; callers should not mutate it.
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: VideoDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """
        Build `num_videos` uint8 videos filled with 255.

        Each video has shape `(num_frames, height, width, 3)`.

        `overrides.num_frames` / `.width` / `.height`, if set, can only
        shrink the corresponding dimension. Larger values are ignored with
        a warning.
        """
        if num_videos == 0:
            return []

        if overrides:
            num_frames = self._resolve_override(
                "video.num_frames",
                "number of frames",
                num_frames,
                overrides.num_frames,
            )
            width = self._resolve_override(
                "video.width", "width", width, overrides.width
            )
            height = self._resolve_override(
                "video.height", "height", height, overrides.height
            )

        # Frames use the row-major (height, width, channels) layout, the same
        # layout a PIL image of size (width, height) has as an array. The
        # previous (num_frames, width, height, 3) shape transposed every
        # frame whenever width != height.
        video = np.full((num_frames, height, width, 3), 255, dtype=np.uint8)
        # The same array object is repeated; callers should not mutate it.
        return [video] * num_videos