profiling.py 6.55 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
5
from collections.abc import Mapping
from dataclasses import dataclass, field
6
from typing import TYPE_CHECKING, Generic
7
8
9
10
11

import numpy as np
import numpy.typing as npt
from PIL import Image

12
13
14
15
16
17
from vllm.config.multimodal import (
    AudioDummyOptions,
    BaseDummyOptions,
    ImageDummyOptions,
    VideoDummyOptions,
)
18
19
from vllm.logger import init_logger

20
21
22
23
24
25
26
27
from .inputs import MultiModalDataDict

# `_I` parameterizes `BaseDummyInputsBuilder`. Static type checkers take its
# definition from `.processing`; at runtime a stand-in TypeVar is created
# locally instead of importing it.
# NOTE(review): presumably this avoids a circular import with
# `.processing` — confirm.
if TYPE_CHECKING:
    from .processing import _I
else:
    from typing import TypeVar

    _I = TypeVar("_I")
28
29
30
31
32
33

# Module-level logger; the dummy-data helpers in this file use it to warn
# when a configured override exceeds the model's maximum.
logger = init_logger(__name__)


@dataclass
class ProcessorInputs:
    """
    Bundle of the keyword arguments passed to
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
    """

    # Prompt, given either as a string or as a list of ints
    # (presumably token IDs — confirm against the processor).
    prompt: str | list[int]

    # Multi-modal data, typed as `MultiModalDataDict`.
    mm_data: MultiModalDataDict

    # Optional keyword mappings. Both use `default_factory=dict`, so every
    # instance gets its own fresh empty dict when the argument is omitted.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
43
44


45
class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract builder of the dummy inputs used to profile multi-modal models.

    Subclasses must implement `get_dummy_text` and `get_dummy_mm_data`.
    The `_get_dummy_audios`, `_get_dummy_images` and `_get_dummy_videos`
    helpers create placeholder items of a requested size, optionally
    shrunk by per-modality override options.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        # Stored as-is; this base class never reads it.
        self.info = info

    @abstractmethod
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Return the text input that matches `mm_counts`.

        Args:
            mm_counts: Count of items per modality
        """
        raise NotImplementedError

    @abstractmethod
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """
        Return the multimodal input that, once processed, yields the
        largest possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Optional per-modality options.
                       When None, implementations fall back to the model
                       defaults (kept for backward compatibility).
                       When given, implementations may use them to
                       customize the generated dummy data.
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> ProcessorInputs:
        """
        Assemble the processor inputs that, once processed, yield the
        largest possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Optional per-modality options

        Returns:
            A `ProcessorInputs` holding the dummy text, the dummy
            multimodal data, and tokenization kwargs with truncation
            turned off.
        """
        # Keyword arguments are evaluated left to right, so the text is
        # built before the multimodal data. `get_dummy_mm_data` is the single
        # entry point whether or not `mm_options` is supplied.
        return ProcessorInputs(
            prompt=self.get_dummy_text(mm_counts),
            mm_data=self.get_dummy_mm_data(seq_len, mm_counts, mm_options),
            tokenization_kwargs={"truncation": False},
        )

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
        overrides: AudioDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """
        Create `num_audios` dummy audio clips as all-zero 1-D arrays.

        Args:
            length: Number of elements per array; also the upper bound
                for any override.
            num_audios: Number of list entries to return.
            overrides: Optional options whose truthy `length` replaces
                `length` when it does not exceed it. A larger value only
                triggers a warning.

        Returns:
            An empty list when `num_audios` is 0; otherwise a list that
            repeats one shared array `num_audios` times.
        """
        if num_audios == 0:
            return []

        requested_length = overrides.length if overrides else None
        if requested_length:
            if requested_length > length:
                logger.warning(
                    "audio.length override (%d) exceeds model's "
                    "maximum length (%d), will be ignored",
                    requested_length,
                    length,
                )
            else:
                length = requested_length

        dummy_audio = np.zeros((length,))
        return [dummy_audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
        overrides: ImageDummyOptions | None = None,
    ) -> list[Image.Image]:
        """
        Create `num_images` dummy RGB images of the requested size.

        Args:
            width: Image width; also the upper bound for a width override.
            height: Image height; also the upper bound for a height
                override.
            num_images: Number of list entries to return.
            overrides: Optional options whose truthy `width` / `height`
                replace the corresponding argument when they do not
                exceed it. A larger value only triggers a warning.

        Returns:
            An empty list when `num_images` is 0; otherwise a list that
            repeats one shared image object `num_images` times.
        """
        if num_images == 0:
            return []

        if overrides:
            requested_width = overrides.width
            if requested_width:
                if requested_width > width:
                    logger.warning(
                        "image.width override (%d) exceeds model's "
                        "maximum width (%d), will be ignored",
                        requested_width,
                        width,
                    )
                else:
                    width = requested_width

            requested_height = overrides.height
            if requested_height:
                if requested_height > height:
                    logger.warning(
                        "image.height override (%d) exceeds model's "
                        "maximum height (%d), will be ignored",
                        requested_height,
                        height,
                    )
                else:
                    height = requested_height

        dummy_image = Image.new("RGB", (width, height), color=255)
        return [dummy_image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: VideoDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """
        Create `num_videos` dummy videos as uint8 arrays filled with 255.

        Args:
            width: Frame width; also the upper bound for a width override.
            height: Frame height; also the upper bound for a height
                override.
            num_frames: Frames per video; also the upper bound for a
                frame-count override.
            num_videos: Number of list entries to return.
            overrides: Optional options whose truthy `num_frames` /
                `width` / `height` replace the corresponding argument
                when they do not exceed it. A larger value only triggers
                a warning.

        Returns:
            An empty list when `num_videos` is 0; otherwise a list that
            repeats one shared array `num_videos` times.
        """
        if num_videos == 0:
            return []

        if overrides:
            requested_frames = overrides.num_frames
            if requested_frames:
                if requested_frames > num_frames:
                    logger.warning(
                        "video.num_frames override (%d) exceeds model's "
                        "maximum number of frames (%d), will be ignored",
                        requested_frames,
                        num_frames,
                    )
                else:
                    num_frames = requested_frames

            requested_width = overrides.width
            if requested_width:
                if requested_width > width:
                    logger.warning(
                        "video.width override (%d) exceeds model's "
                        "maximum width (%d), will be ignored",
                        requested_width,
                        width,
                    )
                else:
                    width = requested_width

            requested_height = overrides.height
            if requested_height:
                if requested_height > height:
                    logger.warning(
                        "video.height override (%d) exceeds model's "
                        "maximum height (%d), will be ignored",
                        requested_height,
                        height,
                    )
                else:
                    height = requested_height

        # NOTE(review): the array is laid out as
        # (num_frames, width, height, 3) — confirm whether consumers expect
        # height before width.
        frame_shape = (num_frames, width, height, 3)
        dummy_video = np.full(frame_shape, 255, dtype=np.uint8)
        return [dummy_video] * num_videos