# dummy_inputs.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
from collections.abc import Mapping
5
from typing import Generic, TypeVar
6
7
8
9
10

import numpy as np
import numpy.typing as npt
from PIL import Image

11
12
13
14
15
16
from vllm.config.multimodal import (
    AudioDummyOptions,
    BaseDummyOptions,
    ImageDummyOptions,
    VideoDummyOptions,
)
17
18
from vllm.logger import init_logger

19
20
from ..inputs import MultiModalDataDict
from .context import BaseProcessingInfo
21
from .inputs import ProcessorInputs
22

23
# Type parameter for the processing-info object wrapped by a dummy inputs
# builder; constrained to `BaseProcessingInfo` via `bound=`.
_I = TypeVar("_I", bound=BaseProcessingInfo)

logger = init_logger(__name__)


class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.

    Subclasses implement `get_dummy_text` and `get_dummy_mm_data`; the
    `_get_dummy_audios` / `_get_dummy_images` / `_get_dummy_videos` helpers
    build placeholder media items, optionally shrunk by user overrides.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        self.info = info

    @abstractmethod
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Build the text input corresponding to `mm_counts`.
        """
        raise NotImplementedError

    @abstractmethod
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """
        Build the multimodal input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional).
                       If None, use model defaults for backward compatibility.
                       If provided, models can use these to customize dummy
                       data generation.

        Returns:
            The dummy multimodal data, keyed by modality.
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional);
                forwarded unchanged to `get_dummy_mm_data`.

        Returns:
            `ProcessorInputs` holding the prompt from `get_dummy_text`, the
            items parsed from `get_dummy_mm_data` by `self.info.parse_mm_data`
            (with `validate=False`), and tokenization kwargs that disable
            truncation.
        """
        dummy_text = self.get_dummy_text(mm_counts)
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
        dummy_mm_items = self.info.parse_mm_data(dummy_mm_data, validate=False)

        # Ask tokenization not to truncate the dummy prompt.
        tokenization_kwargs = {"truncation": False}

        return ProcessorInputs(
            prompt=dummy_text,
            mm_data_items=dummy_mm_items,
            tokenization_kwargs=tokenization_kwargs,
        )

    @staticmethod
    def _clamp_override(override: int | None, maximum: int, msg: str) -> int:
        """
        Apply a user-provided size override, never exceeding `maximum`.

        Args:
            override: The requested value. A falsy value (`None` or 0) means
                "not set" and leaves `maximum` unchanged.
            maximum: The model's maximum for this dimension.
            msg: A %-style warning format taking `(override, maximum)` as its
                two `%d` arguments; logged when `override > maximum`.

        Returns:
            `maximum` if `override` is falsy, else `min(maximum, override)`.
        """
        if not override:
            return maximum
        if override > maximum:
            logger.warning(msg, override, maximum)
        return min(maximum, override)

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
        overrides: AudioDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """
        Build `num_audios` all-zero 1-D audio arrays.

        Args:
            length: Maximum number of elements per audio array.
            num_audios: Number of audio items to return.
            overrides: Optional user options; a set `length` smaller than the
                maximum shrinks the arrays, a larger one is logged and ignored.

        Returns:
            A list of `num_audios` references to the same array object.
        """
        if num_audios == 0:
            return []

        if overrides:
            length = self._clamp_override(
                overrides.length,
                length,
                "audio.length override (%d) exceeds model's "
                "maximum length (%d), will be ignored",
            )

        audio = np.zeros((length,))
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
        overrides: ImageDummyOptions | None = None,
    ) -> list[Image.Image]:
        """
        Build `num_images` RGB images of size `(width, height)`.

        Args:
            width: Maximum image width.
            height: Maximum image height.
            num_images: Number of image items to return.
            overrides: Optional user options; set `width`/`height` values
                smaller than the maximums shrink the images, larger ones are
                logged and ignored.

        Returns:
            A list of `num_images` references to the same image object.
        """
        if num_images == 0:
            return []

        if overrides:
            width = self._clamp_override(
                overrides.width,
                width,
                "image.width override (%d) exceeds model's "
                "maximum width (%d), will be ignored",
            )
            height = self._clamp_override(
                overrides.height,
                height,
                "image.height override (%d) exceeds model's "
                "maximum height (%d), will be ignored",
            )

        image = Image.new("RGB", (width, height), color=255)
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: VideoDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """
        Build `num_videos` uint8 videos filled with 255.

        Args:
            width: Maximum frame width.
            height: Maximum frame height.
            num_frames: Maximum number of frames per video.
            num_videos: Number of video items to return.
            overrides: Optional user options; set `num_frames`/`width`/
                `height` values smaller than the maximums shrink the videos,
                larger ones are logged and ignored.

        Returns:
            A list of `num_videos` references to the same array of shape
            `(num_frames, height, width, 3)`.
        """
        if num_videos == 0:
            return []

        if overrides:
            num_frames = self._clamp_override(
                overrides.num_frames,
                num_frames,
                "video.num_frames override (%d) exceeds model's "
                "maximum number of frames (%d), will be ignored",
            )
            width = self._clamp_override(
                overrides.width,
                width,
                "video.width override (%d) exceeds model's "
                "maximum width (%d), will be ignored",
            )
            height = self._clamp_override(
                overrides.height,
                height,
                "video.height override (%d) exceeds model's "
                "maximum height (%d), will be ignored",
            )

        # Numpy image/frame arrays index rows (height) before columns (width),
        # e.g. `np.asarray(PIL.Image)` is (height, width, channels); putting
        # width first would transpose non-square frames.
        video = np.full((num_frames, height, width, 3), 255, dtype=np.uint8)
        return [video] * num_videos