dummy_inputs.py 6.52 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
5
from collections.abc import Mapping
from dataclasses import dataclass, field
6
from typing import Generic, TypeVar
7
8
9
10
11

import numpy as np
import numpy.typing as npt
from PIL import Image

12
13
14
15
16
17
from vllm.config.multimodal import (
    AudioDummyOptions,
    BaseDummyOptions,
    ImageDummyOptions,
    VideoDummyOptions,
)
18
19
from vllm.logger import init_logger

20
21
from ..inputs import MultiModalDataDict
from .context import BaseProcessingInfo
22

23
# Type variable for the processing-info object that a dummy-inputs builder is
# parameterized over; bound to `BaseProcessingInfo`.
_I = TypeVar("_I", bound=BaseProcessingInfo)

# Module-level logger, used below to warn when a dummy-data override exceeds
# the maximum supplied by the model.
logger = init_logger(__name__)


@dataclass
class ProcessorInputs:
    """
    Represents the keyword arguments to
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
    """

    # Prompt, given either as a string or as a list of ints
    # (presumably token IDs -- confirm against the processor).
    prompt: str | list[int]

    # Multi-modal data that accompanies the prompt.
    mm_data: MultiModalDataDict

    # Extra keyword arguments (presumably forwarded to the HF processor --
    # confirm). Each instance gets its own empty dict by default.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)

    # Extra keyword arguments for tokenization. Each instance gets its own
    # empty dict by default.
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
39
40


41
def _clamp_to_model_max(
    model_max: int,
    override: int | None,
    warning_format: str,
) -> int:
    """
    Resolve a requested override against the maximum supplied by the model.

    Args:
        model_max: Upper bound supplied by the model.
        override: Requested value. A falsy value (`None` or `0`) means that
            no override was requested.
        warning_format: %-style logging format that receives
            `(override, model_max)` when the override exceeds `model_max`.

    Returns:
        `model_max` when no override was requested, otherwise the smaller
        of `model_max` and `override`.
    """
    if not override:
        return model_max

    if override > model_max:
        logger.warning(warning_format, override, model_max)

    return min(model_max, override)


class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        self.info = info

    @abstractmethod
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Build the text input corresponding to `mm_counts`.

        Args:
            mm_counts: Count of items per modality
        """
        raise NotImplementedError

    @abstractmethod
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """
        Build the multimodal input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional).
                       If None, use model defaults for backward compatibility.
                       If provided, models can use these to customize dummy
                       data generation.
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional)

        Returns:
            A `ProcessorInputs` holding the dummy text, the dummy
            multimodal data and tokenization kwargs with truncation
            turned off.
        """
        dummy_text = self.get_dummy_text(mm_counts)

        # Use the unified function for both legacy and configurable cases
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)

        tokenization_kwargs = {"truncation": False}

        return ProcessorInputs(
            prompt=dummy_text,
            mm_data=dummy_mm_data,
            tokenization_kwargs=tokenization_kwargs,
        )

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
        overrides: AudioDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """
        Build `num_audios` zero-filled 1-D dummy audio arrays.

        Args:
            length: Number of samples per audio; treated as the model's
                maximum when an override is given.
            num_audios: Number of audios to return.
            overrides: Optional options whose `length`, when set, replaces
                `length` unless it is larger, in which case a warning is
                logged and `length` is kept.

        Returns:
            A list with `num_audios` entries (empty when `num_audios` is 0).
            Every entry refers to the same array object.
        """
        if num_audios == 0:
            return []

        if overrides:
            length = _clamp_to_model_max(
                length,
                overrides.length,
                "audio.length override (%d) exceeds model's "
                "maximum length (%d), will be ignored",
            )

        audio = np.zeros((length,))
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
        overrides: ImageDummyOptions | None = None,
    ) -> list[Image.Image]:
        """
        Build `num_images` dummy RGB images of size `(width, height)`.

        Args:
            width: Image width; treated as the model's maximum when an
                override is given.
            height: Image height; treated as the model's maximum when an
                override is given.
            num_images: Number of images to return.
            overrides: Optional options whose `width` / `height`, when set,
                replace the corresponding argument unless they are larger,
                in which case a warning is logged and the argument is kept.

        Returns:
            A list with `num_images` entries (empty when `num_images` is 0).
            Every entry refers to the same image object.
        """
        if num_images == 0:
            return []

        if overrides:
            width = _clamp_to_model_max(
                width,
                overrides.width,
                "image.width override (%d) exceeds model's "
                "maximum width (%d), will be ignored",
            )
            height = _clamp_to_model_max(
                height,
                overrides.height,
                "image.height override (%d) exceeds model's "
                "maximum height (%d), will be ignored",
            )

        image = Image.new("RGB", (width, height), color=255)
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: VideoDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """
        Build `num_videos` dummy `uint8` video arrays filled with 255.

        Args:
            width: Frame width; treated as the model's maximum when an
                override is given.
            height: Frame height; treated as the model's maximum when an
                override is given.
            num_frames: Frames per video; treated as the model's maximum
                when an override is given.
            num_videos: Number of videos to return.
            overrides: Optional options whose `num_frames` / `width` /
                `height`, when set, replace the corresponding argument
                unless they are larger, in which case a warning is logged
                and the argument is kept.

        Returns:
            A list with `num_videos` entries (empty when `num_videos` is 0).
            Every entry refers to the same array object.
        """
        if num_videos == 0:
            return []

        if overrides:
            num_frames = _clamp_to_model_max(
                num_frames,
                overrides.num_frames,
                "video.num_frames override (%d) exceeds model's "
                "maximum number of frames (%d), will be ignored",
            )
            width = _clamp_to_model_max(
                width,
                overrides.width,
                "video.width override (%d) exceeds model's "
                "maximum width (%d), will be ignored",
            )
            height = _clamp_to_model_max(
                height,
                overrides.height,
                "video.height override (%d) exceeds model's "
                "maximum height (%d), will be ignored",
            )

        # NOTE(review): the axes are laid out as
        # (num_frames, width, height, 3), i.e. width comes before height.
        # Confirm that consumers do not expect height before width.
        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
        return [video] * num_videos