profiling.py 11.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
5
from collections.abc import Mapping
from dataclasses import dataclass, field
6
from typing import Generic, NamedTuple, TypeVar, cast
7
8
9
10
11

import numpy as np
import numpy.typing as npt
from PIL import Image

12
13
14
15
16
17
from vllm.config.multimodal import (
    AudioDummyOptions,
    BaseDummyOptions,
    ImageDummyOptions,
    VideoDummyOptions,
)
18
19
from vllm.logger import init_logger

20
21
22
23
24
25
26
27
28
29
30
31
from .inputs import (
    MultiModalDataDict,
    MultiModalEncDecInputs,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalPlaceholderDict,
)
from .processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    EncDecMultiModalProcessor,
)
32
33
34
35
36
37

# Module-level logger; used below to warn when a dummy-data override
# exceeds the maximum value supplied for the model.
logger = init_logger(__name__)


@dataclass
class ProcessorInputs:
    """
    Represents the keyword arguments to
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
    """

    # Prompt, given either as text or as a list of token ids.
    prompt: str | list[int]
    # Multi-modal items that accompany the prompt.
    mm_data: MultiModalDataDict
    # Extra keyword arguments forwarded as `hf_processor_mm_kwargs`;
    # defaults to a fresh empty dict per instance.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
    # Extra keyword arguments forwarded as `tokenization_kwargs`;
    # defaults to a fresh empty dict per instance.
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
47
48


49
50
51
52
53
54
55
56
57
58
class DummyEncoderData(NamedTuple):
    """
    Encoder-side dummy input used for profiling.

    A single-field tuple holding the encoder prompt token ids.
    """

    # Token ids of the (possibly padded) encoder prompt.
    prompt_token_ids: list[int]


class DummyDecoderData(NamedTuple):
    """Dummy data used for profiling."""

    # Token ids of the (possibly padded) decoder prompt.
    prompt_token_ids: list[int]
    # Processed multi-modal keyword-argument items for the dummy prompt.
    multi_modal_data: MultiModalKwargsItems
    # Placeholder information for the multi-modal items in the prompt.
    multi_modal_placeholders: MultiModalPlaceholderDict


63
64
65
66
# Type variable for the processing-info type that the classes below are
# parameterized by; it is bounded by `BaseProcessingInfo`.
_I = TypeVar("_I", bound=BaseProcessingInfo)


class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.

    Subclasses must implement `get_dummy_text` and `get_dummy_mm_data`.
    The `_get_dummy_*` helpers build lists of blank audio, image and video
    items, optionally shrunk by user-provided override options.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        self.info = info

    @abstractmethod
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Build the text input corresponding to `mm_counts`.
        """
        raise NotImplementedError

    @abstractmethod
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """
        Build the multimodal input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional).
                       If None, use model defaults for backward compatibility.
                       If provided, models can use these to customize dummy
                       data generation.
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional)

        Returns:
            A `ProcessorInputs` holding the dummy text, the dummy
            multi-modal data, and tokenization kwargs with
            `truncation` set to `False`.
        """
        dummy_text = self.get_dummy_text(mm_counts)

        # Use the unified function for both legacy and configurable cases
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)

        tokenization_kwargs = {"truncation": False}

        return ProcessorInputs(
            prompt=dummy_text,
            mm_data=dummy_mm_data,
            tokenization_kwargs=tokenization_kwargs,
        )

    @staticmethod
    def _clamp_dummy_override(
        option_name: str,
        limit_desc: str,
        model_max: int,
        override: int | None,
    ) -> int:
        """
        Combine a model maximum with an optional user override.

        Args:
            option_name: Name of the option as shown in the warning
                (e.g. `"image.width"`).
            limit_desc: Description of the limit as shown in the warning
                (e.g. `"width"`).
            model_max: Maximum value supplied by the caller for the model.
            override: User-provided override; a falsy value (`None` or `0`)
                means "not set".

        Returns:
            `model_max` when no override is set, otherwise the smaller of
            `model_max` and `override`. A warning is logged when the
            override is larger than `model_max`.
        """
        if not override:
            return model_max

        if override > model_max:
            logger.warning(
                "%s override (%d) exceeds model's maximum %s (%d), will be ignored",
                option_name,
                override,
                limit_desc,
                model_max,
            )

        return min(model_max, override)

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
        overrides: AudioDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """
        Build `num_audios` dummy audio items.

        Args:
            length: Maximum number of samples per audio item.
            num_audios: Number of items to return.
            overrides: Optional options whose `length` can shrink the
                dummy audio; larger values are clamped to `length`.

        Returns:
            An empty list when `num_audios` is 0; otherwise a list that
            repeats one shared zero-filled 1-D array `num_audios` times.
        """
        if num_audios == 0:
            return []

        length = self._clamp_dummy_override(
            "audio.length",
            "length",
            length,
            overrides.length if overrides else None,
        )

        audio = np.zeros((length,))
        # Every entry references the same array object.
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
        overrides: ImageDummyOptions | None = None,
    ) -> list[Image.Image]:
        """
        Build `num_images` dummy images.

        Args:
            width: Maximum image width.
            height: Maximum image height.
            num_images: Number of items to return.
            overrides: Optional options whose `width` / `height` can shrink
                the dummy image; larger values are clamped to the maxima.

        Returns:
            An empty list when `num_images` is 0; otherwise a list that
            repeats one shared RGB image (created with `color=255`)
            `num_images` times.
        """
        if num_images == 0:
            return []

        if overrides:
            width = self._clamp_dummy_override(
                "image.width", "width", width, overrides.width
            )
            height = self._clamp_dummy_override(
                "image.height", "height", height, overrides.height
            )

        image = Image.new("RGB", (width, height), color=255)
        # Every entry references the same image object.
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: VideoDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """
        Build `num_videos` dummy videos.

        Args:
            width: Maximum frame width.
            height: Maximum frame height.
            num_frames: Maximum number of frames.
            num_videos: Number of items to return.
            overrides: Optional options whose `num_frames` / `width` /
                `height` can shrink the dummy video; larger values are
                clamped to the maxima.

        Returns:
            An empty list when `num_videos` is 0; otherwise a list that
            repeats one shared `uint8` array filled with 255
            `num_videos` times.
        """
        if num_videos == 0:
            return []

        if overrides:
            num_frames = self._clamp_dummy_override(
                "video.num_frames",
                "number of frames",
                num_frames,
                overrides.num_frames,
            )
            width = self._clamp_dummy_override(
                "video.width", "width", width, overrides.width
            )
            height = self._clamp_dummy_override(
                "video.height", "height", height, overrides.height
            )

        # NOTE(review): the axes are laid out as
        # (num_frames, width, height, 3); confirm whether consumers expect
        # height before width. Left unchanged to preserve behavior.
        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
        # Every entry references the same array object.
        return [video] * num_videos

228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249

class MultiModalProfiler(Generic[_I]):
    """
    Contains code for running memory profiling for multi-modal models.

    Wraps a multi-modal processor and uses its dummy-input builder to
    produce dummy encoder/decoder data and per-modality token counts.
    """

    def __init__(
        self,
        processor: BaseMultiModalProcessor[_I],
    ) -> None:
        super().__init__()

        self.processor = processor

    @property
    def processing_info(self) -> BaseProcessingInfo:
        """The `info` object of the wrapped processor."""
        return self.processor.info

    @property
    def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
        """The dummy-input builder of the wrapped processor."""
        return self.processor.dummy_inputs

    def get_mm_limits(self) -> Mapping[str, int]:
        """Return the wrapped processor's `allowed_mm_limits`."""
        return self.processor.allowed_mm_limits

    def _get_dummy_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalInputs:
        """
        Build dummy processor inputs and run them through the processor.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality; defaults to
                `get_mm_limits()` when None.
            mm_options: Configurable options per modality (optional)

        Returns:
            The result of `self.processor.apply` on the dummy inputs.
        """
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        factory = self.dummy_inputs
        processor_inputs = factory.get_dummy_processor_inputs(
            seq_len, mm_counts, mm_options
        )

        return self.processor.apply(
            prompt=processor_inputs.prompt,
            mm_data=processor_inputs.mm_data,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
            tokenization_kwargs=processor_inputs.tokenization_kwargs,
        )

    def _get_mm_num_tokens(
        self,
        mm_inputs: MultiModalInputs,
    ) -> Mapping[str, int]:
        """
        Sum, for each modality, the `get_num_embeds` value of every
        placeholder found in `mm_inputs["mm_placeholders"]`.
        """
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        return {
            modality: sum(item.get_num_embeds for item in placeholders)
            for modality, placeholders in placeholders_by_modality.items()
        }

    def get_encoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> DummyEncoderData:
        """
        Build dummy encoder data for profiling.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality (optional)
            mm_options: Configurable options per modality (optional)

        Returns:
            A `DummyEncoderData` wrapping the encoder prompt token ids.
            When the processor sets `pad_dummy_encoder_prompt`, the list
            is padded with token id 0 up to `seq_len`.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)
        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)

        # For encoder-decoder models, use encoder prompt token ids instead of
        # decoder prompt to construct dummy seq_data for encoder profiling.
        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]

        total_len = len(encoder_prompt_token_ids)

        processor = cast(EncDecMultiModalProcessor, self.processor)
        if processor.pad_dummy_encoder_prompt:
            # Evaluates to 0 when the prompt already has at least
            # `seq_len` tokens, so nothing is appended in that case.
            num_tokens_to_pad = max(total_len, seq_len) - total_len
            # The list taken from the processor output is extended in place.
            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)

        return DummyEncoderData(encoder_prompt_token_ids)

    def get_decoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> DummyDecoderData:
        """
        Build dummy decoder data for profiling.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality (optional)
            mm_options: Configurable options per modality (optional)

        Returns:
            A `DummyDecoderData` with the prompt token ids (padded with
            token id 0 up to `seq_len` when shorter), the multi-modal
            kwargs data, and the multi-modal placeholders.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)

        if total_len < seq_len:
            # The list taken from the processor output is extended in place.
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return DummyDecoderData(
            prompt_token_ids=prompt_token_ids,
            multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
            multi_modal_placeholders=mm_inputs["mm_placeholders"],
        )

    def get_mm_max_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
    ) -> Mapping[str, int]:
        """
        Returns the maximum number of embeddings per item of each modality, excluding
        any break/text tokens in-between multimodal embeddings/encoder outputs.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality; defaults to
                `get_mm_limits()` when None.
        """
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        max_tokens_per_item = self.processing_info.get_mm_max_tokens_per_item(
            seq_len=seq_len,
            mm_counts=mm_counts,
        )
        if max_tokens_per_item is not None:
            # Keep only the modalities that have a positive item count.
            return {
                modality: max_tokens
                for modality, max_tokens in max_tokens_per_item.items()
                if mm_counts.get(modality, 0) > 0
            }

        # Otherwise, process dummy inputs and count the resulting
        # placeholder embeddings.
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
        return self._get_mm_num_tokens(mm_inputs)