# profiling.py (11.9 KB)
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
5
from collections.abc import Mapping
from dataclasses import dataclass, field
6
from typing import Generic, NamedTuple, TypeVar, cast
7
8
9
10
11

import numpy as np
import numpy.typing as npt
from PIL import Image

12
13
14
15
16
17
from vllm.config.multimodal import (
    AudioDummyOptions,
    BaseDummyOptions,
    ImageDummyOptions,
    VideoDummyOptions,
)
18
19
from vllm.logger import init_logger

20
21
22
23
24
25
26
27
28
29
30
31
from .inputs import (
    MultiModalDataDict,
    MultiModalEncDecInputs,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalPlaceholderDict,
)
from .processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    EncDecMultiModalProcessor,
)
32
33
34
35
36
37

# Module-level logger; used below to warn when dummy-data overrides exceed
# the model's maximums.
logger = init_logger(__name__)


@dataclass
class ProcessorInputs:
    """
    Container for the keyword arguments handed to
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
    """

    # Prompt as a string or as a list of ints (presumably token IDs).
    prompt: str | list[int]
    # Multi-modal items passed as `mm_data` to `apply`.
    mm_data: MultiModalDataDict
    # Passed as `hf_processor_mm_kwargs` to `apply`; defaults to a fresh {}.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
    # Passed as `tokenization_kwargs` to `apply`; defaults to a fresh {}.
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
48


49
50
51
52
53
54
55
56
57
58
class DummyEncoderData(NamedTuple):
    """Dummy encoder-side input generated for memory profiling."""

    # Token IDs of the dummy encoder prompt.
    prompt_token_ids: list[int]


class DummyDecoderData(NamedTuple):
    """Dummy decoder-side input generated for memory profiling."""

    # Token IDs of the dummy decoder prompt.
    prompt_token_ids: list[int]
    # Processed multi-modal kwargs belonging to the dummy items.
    multi_modal_data: MultiModalKwargsItems
    # Per-modality placeholder information reported by the processor.
    multi_modal_placeholders: MultiModalPlaceholderDict


63
64
65
66
# Type variable for the processing-info type carried by the dummy-input
# builder and the profiler; restricted to BaseProcessingInfo subclasses.
_I = TypeVar("_I", bound=BaseProcessingInfo)


class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.

    Subclasses implement `get_dummy_text` and `get_dummy_mm_data`. The
    `_get_dummy_*` helpers build placeholder media at a requested size. That
    size may be reduced, but never increased, by user-supplied dummy options.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        self.info = info

    @abstractmethod
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Build the text input corresponding to `mm_counts`.
        """
        raise NotImplementedError

    @abstractmethod
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """
        Build the multimodal input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional).
                       If None, use model defaults for backward compatibility.
                       If provided, models can use these to customize dummy
                       data generation.
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional)

        Returns:
            A `ProcessorInputs` holding the dummy text, the dummy multimodal
            data and `tokenization_kwargs={"truncation": False}`.
        """
        dummy_text = self.get_dummy_text(mm_counts)

        # `mm_options` is forwarded unchanged. Implementations are expected
        # to fall back to model defaults when it is None.
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)

        # Ask the tokenizer not to truncate the dummy prompt.
        tokenization_kwargs = {"truncation": False}

        return ProcessorInputs(
            prompt=dummy_text,
            mm_data=dummy_mm_data,
            tokenization_kwargs=tokenization_kwargs,
        )

    @staticmethod
    def _clamp_dummy_override(
        option_name: str,
        dimension: str,
        override: int | None,
        model_max: int,
    ) -> int:
        """
        Resolve one dummy-data size override against the model's maximum.

        Args:
            option_name: Dotted option name shown in the warning
                (e.g. "image.width").
            dimension: Name of the limit shown in the warning
                (e.g. "width", "number of frames").
            override: Requested value. A falsy value (None or 0) means
                "not set".
            model_max: The model's maximum for this dimension.

        Returns:
            `model_max` when the override is unset, otherwise
            `min(model_max, override)`. An override above the maximum logs a
            warning and has no effect.
        """
        if not override:
            return model_max
        if override > model_max:
            logger.warning(
                "%s override (%d) exceeds model's maximum %s (%d), "
                "will be ignored",
                option_name,
                override,
                dimension,
                model_max,
            )
        return min(model_max, override)

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
        overrides: AudioDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """
        Build `num_audios` all-zero 1-D audio arrays.

        Args:
            length: The model's maximum audio length, used as the array length.
            num_audios: Number of audio items to return.
            overrides: Optional user options. `overrides.length` can shrink
                the length but never grow it.

        Returns:
            A list repeating the same array object `num_audios` times, or
            `[]` when `num_audios` is 0.
        """
        if num_audios == 0:
            return []

        if overrides:
            length = self._clamp_dummy_override(
                "audio.length", "length", overrides.length, length
            )

        audio = np.zeros((length,))
        # Every list entry refers to the same array object.
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
        overrides: ImageDummyOptions | None = None,
    ) -> list[Image.Image]:
        """
        Build `num_images` solid-color RGB PIL images.

        Args:
            width: The model's maximum image width.
            height: The model's maximum image height.
            num_images: Number of images to return.
            overrides: Optional user options. `overrides.width` and
                `overrides.height` can shrink the size but never grow it.

        Returns:
            A list repeating the same image object `num_images` times, or
            `[]` when `num_images` is 0.
        """
        if num_images == 0:
            return []

        if overrides:
            width = self._clamp_dummy_override(
                "image.width", "width", overrides.width, width
            )
            height = self._clamp_dummy_override(
                "image.height", "height", overrides.height, height
            )

        image = Image.new("RGB", (width, height), color=255)
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: VideoDummyOptions | None = None,
    ) -> list[npt.NDArray]:
        """
        Build `num_videos` uint8 video arrays filled with 255.

        Args:
            width: The model's maximum frame width.
            height: The model's maximum frame height.
            num_frames: The model's maximum number of frames.
            num_videos: Number of videos to return.
            overrides: Optional user options. `num_frames`, `width` and
                `height` can each be shrunk but never grown.

        Returns:
            A list repeating the same array object `num_videos` times, or
            `[]` when `num_videos` is 0.
        """
        if num_videos == 0:
            return []

        if overrides:
            num_frames = self._clamp_dummy_override(
                "video.num_frames",
                "number of frames",
                overrides.num_frames,
                num_frames,
            )
            width = self._clamp_dummy_override(
                "video.width", "width", overrides.width, width
            )
            height = self._clamp_dummy_override(
                "video.height", "height", overrides.height, height
            )

        # NOTE(review): frames are laid out as (num_frames, width, height, 3).
        # Image arrays (e.g. np.asarray of a PIL image) are
        # (height, width, channels), so the two orderings differ for
        # non-square sizes. Confirm what downstream video processors expect
        # before changing this.
        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
        return [video] * num_videos

228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249

class MultiModalProfiler(Generic[_I]):
    """
    Generates dummy multi-modal inputs through a processor and measures the
    resulting placeholder token counts, for memory profiling.
    """

    def __init__(
        self,
        processor: BaseMultiModalProcessor[_I],
    ) -> None:
        super().__init__()

        self.processor = processor

    @property
    def processing_info(self) -> BaseProcessingInfo:
        """The `info` object of the wrapped processor."""
        return self.processor.info

    @property
    def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
        """The dummy-input builder of the wrapped processor."""
        return self.processor.dummy_inputs

    def get_mm_limits(self) -> Mapping[str, int]:
        """Per-modality item limits allowed by the wrapped processor."""
        return self.processor.allowed_mm_limits

    def _get_dummy_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalInputs:
        """
        Build dummy processor inputs and run them through the processor.

        When `mm_counts` is None, the processor's allowed limits are used.
        """
        counts = self.get_mm_limits() if mm_counts is None else mm_counts

        builder = self.dummy_inputs
        request = builder.get_dummy_processor_inputs(seq_len, counts, mm_options)

        return self.processor.apply(
            prompt=request.prompt,
            mm_data=request.mm_data,
            hf_processor_mm_kwargs=request.hf_processor_mm_kwargs,
            tokenization_kwargs=request.tokenization_kwargs,
        )

    def _get_mm_num_tokens(
        self,
        mm_inputs: MultiModalInputs,
        mm_embeddings_only: bool = True,
    ) -> Mapping[str, int]:
        """
        Sum placeholder sizes per modality.

        Each placeholder contributes `get_num_embeds()` when
        `mm_embeddings_only` is True, otherwise its full `length`.
        """
        totals: dict[str, int] = {}
        for modality, placeholders in mm_inputs["mm_placeholders"].items():
            if mm_embeddings_only:
                totals[modality] = sum(p.get_num_embeds() for p in placeholders)
            else:
                totals[modality] = sum(p.length for p in placeholders)
        return totals

    def get_encoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> DummyEncoderData:
        """
        Build dummy encoder input from the encoder prompt of an
        encoder-decoder processor.

        If the processor's `pad_dummy_encoder_prompt` is set, the token list
        is right-padded with 0 up to `seq_len`. It is never truncated.
        """
        enc_inputs = cast(
            MultiModalEncDecInputs,
            self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options),
        )

        # Encoder profiling is driven by the encoder-side prompt, not the
        # decoder prompt.
        token_ids = enc_inputs["encoder_prompt_token_ids"]
        shortfall = max(seq_len - len(token_ids), 0)

        encdec_processor = cast(EncDecMultiModalProcessor, self.processor)
        if encdec_processor.pad_dummy_encoder_prompt:
            token_ids.extend([0] * shortfall)

        return DummyEncoderData(token_ids)

    def get_decoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> DummyDecoderData:
        """
        Build dummy decoder input, right-padding the prompt with 0 up to
        `seq_len`. A longer prompt is left as is.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)

        token_ids = mm_inputs["prompt_token_ids"]
        shortfall = seq_len - len(token_ids)
        if shortfall > 0:
            token_ids.extend([0] * shortfall)

        return DummyDecoderData(
            prompt_token_ids=token_ids,
            multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
            multi_modal_placeholders=mm_inputs["mm_placeholders"],
        )

    def _get_mm_max_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        mm_embeddings_only: bool = True,
    ) -> Mapping[str, int]:
        """
        Maximum token count per modality.

        If the processing info reports per-item maxima, those are returned
        for the modalities with a positive count. Otherwise the dummy inputs
        are actually processed and their placeholders are measured.
        """
        counts = self.get_mm_limits() if mm_counts is None else mm_counts

        per_item = self.processing_info.get_mm_max_tokens_per_item(
            seq_len=seq_len,
            mm_counts=counts,
        )

        if per_item is None:
            mm_inputs = self._get_dummy_mm_inputs(seq_len, counts)
            return self._get_mm_num_tokens(
                mm_inputs, mm_embeddings_only=mm_embeddings_only
            )

        # NOTE(review): this fast path does not consult `mm_embeddings_only`.
        # Confirm that is intended for `get_mm_max_contiguous_tokens`.
        return {
            modality: max_tokens
            for modality, max_tokens in per_item.items()
            if counts.get(modality, 0) > 0
        }

    def get_mm_max_contiguous_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
    ) -> Mapping[str, int]:
        """
        Returns the maximum length of the contiguous multimodal span
        (image placeholders plus any text), counting break/text tokens that
        sit between image embeddings.

        For example:
        `<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`
        counts as 9, although only 6 of those tokens are image embeddings.

        This matters when profiling and when sizing the encoder cache.
        """
        return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)