profiling.py 12.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
5
from collections.abc import Mapping
from dataclasses import dataclass, field
6
from typing import Generic, NamedTuple, Optional, TypeVar, Union, cast
7
8
9
10
11

import numpy as np
import numpy.typing as npt
from PIL import Image

12
import vllm.envs as envs
13
14
15
16
17
18
from vllm.config.multimodal import (
    AudioDummyOptions,
    BaseDummyOptions,
    ImageDummyOptions,
    VideoDummyOptions,
)
19
20
from vllm.logger import init_logger

21
22
23
24
25
26
27
28
29
30
31
32
from .inputs import (
    MultiModalDataDict,
    MultiModalEncDecInputs,
    MultiModalInputs,
    MultiModalKwargsItems,
    MultiModalPlaceholderDict,
)
from .processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    EncDecMultiModalProcessor,
)
33
34
35
36
37
38

# Module-level logger, named after this module.
logger = init_logger(__name__)


@dataclass
class ProcessorInputs:
    """
    Bundle of keyword arguments destined for
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].

    Each field corresponds to the parameter of the same name on that method.
    """

    # Prompt, given either as text or as a list of ints (token IDs).
    prompt: Union[str, list[int]]
    # Raw multi-modal data to be processed alongside the prompt.
    mm_data: MultiModalDataDict
    # Extra kwargs forwarded to the HF processor; empty unless supplied.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
    # Extra kwargs forwarded to tokenization; empty unless supplied.
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
49


50
51
52
53
54
55
56
57
58
59
class DummyEncoderData(NamedTuple):
    """Dummy data used for profiling."""

    # Token IDs of the dummy encoder prompt.
    prompt_token_ids: list[int]


class DummyDecoderData(NamedTuple):
    """Dummy data used for profiling."""

    # Token IDs of the dummy decoder prompt.
    prompt_token_ids: list[int]
    # Processed multi-modal items produced from the dummy inputs.
    multi_modal_data: MultiModalKwargsItems
    # Placeholder information for the multi-modal items, per modality.
    multi_modal_placeholders: MultiModalPlaceholderDict


64
65
66
67
# Generic parameter for the model-specific processing-info type; any type
# substituted for it must subclass BaseProcessingInfo.
_I = TypeVar("_I", bound=BaseProcessingInfo)


class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.

    Subclasses implement `get_dummy_text` and `get_dummy_mm_data`. The
    `_get_dummy_audios` / `_get_dummy_images` / `_get_dummy_videos` helpers
    build placeholder media items that subclasses can use.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        self.info = info

    @abstractmethod
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Build the text input corresponding to `mm_counts`.
        """
        raise NotImplementedError

    @abstractmethod
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """
        Build the multimodal input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional).
                       If None, use model defaults for backward compatibility.
                       If provided, models can use these to customize dummy
                       data generation.
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional)

        Returns:
            A `ProcessorInputs` holding the dummy text, the dummy multi-modal
            data, and tokenization kwargs that disable truncation.
        """
        dummy_text = self.get_dummy_text(mm_counts)

        # Use the unified function for both legacy and configurable cases
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)

        # Truncation is disabled so the dummy prompt is not shortened.
        tokenization_kwargs = {"truncation": False}

        return ProcessorInputs(
            prompt=dummy_text,
            mm_data=dummy_mm_data,
            tokenization_kwargs=tokenization_kwargs,
        )

    @staticmethod
    def _clamp_dummy_override(
        option_name: str,
        limit_name: str,
        override: Optional[int],
        limit: int,
    ) -> int:
        """
        Resolve an optional user override against the model's maximum.

        Args:
            option_name: Option name shown in the warning, e.g. "image.width".
            limit_name: Name of the limit shown in the warning, e.g. "width".
            override: User-supplied value; a falsy value (None or 0) means
                "no override".
            limit: The model's maximum for this option.

        Returns:
            `limit` when there is no override, otherwise
            `min(limit, override)`. A warning is logged when the override
            exceeds `limit`, because it cannot take effect.
        """
        if not override:
            return limit

        if override > limit:
            logger.warning(
                "%s override (%d) exceeds model's maximum %s (%d), "
                "will be ignored",
                option_name,
                override,
                limit_name,
                limit,
            )

        return min(limit, override)

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
        overrides: Optional[AudioDummyOptions] = None,
    ) -> list[npt.NDArray]:
        """
        Build `num_audios` silent audio arrays of `length` samples.

        `overrides.length`, if set, can only shrink `length`. All returned
        entries are the same array object.
        """
        if num_audios == 0:
            return []

        if overrides:
            length = self._clamp_dummy_override(
                "audio.length", "length", overrides.length, length
            )

        audio = np.zeros((length,))
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
        overrides: Optional[ImageDummyOptions] = None,
    ) -> list[Image.Image]:
        """
        Build `num_images` white RGB images of size `width` x `height`.

        `overrides.width` / `overrides.height`, if set, can only shrink the
        corresponding dimension. All returned entries are the same object.
        """
        if num_images == 0:
            return []

        if overrides:
            width = self._clamp_dummy_override(
                "image.width", "width", overrides.width, width
            )
            height = self._clamp_dummy_override(
                "image.height", "height", overrides.height, height
            )

        image = Image.new("RGB", (width, height), color=255)
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: Optional[VideoDummyOptions] = None,
    ) -> list[npt.NDArray]:
        """
        Build `num_videos` white RGB videos of `num_frames` frames each.

        Each video is a uint8 array of shape
        `(num_frames, height, width, 3)`. `overrides.num_frames` /
        `overrides.width` / `overrides.height`, if set, can only shrink the
        corresponding dimension. All returned entries are the same array
        object.
        """
        if num_videos == 0:
            return []

        if overrides:
            num_frames = self._clamp_dummy_override(
                "video.num_frames",
                "number of frames",
                overrides.num_frames,
                num_frames,
            )
            width = self._clamp_dummy_override(
                "video.width", "width", overrides.width, width
            )
            height = self._clamp_dummy_override(
                "video.height", "height", overrides.height, height
            )

        # Frames use the row-major (height, width, channels) layout, the same
        # layout `np.asarray` gives for the PIL images built above. uint8
        # matches 8-bit RGB pixel data (255 = white).
        video = np.full((num_frames, height, width, 3), 255, dtype=np.uint8)
        return [video] * num_videos

229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250

class MultiModalProfiler(Generic[_I]):
    """
    Contains code for running memory profiling for multi-modal models.

    Dummy inputs are built by the processor's `dummy_inputs` builder and run
    through `processor.apply` to obtain prompt token IDs, multi-modal kwargs
    and placeholders.
    """

    def __init__(
        self,
        processor: BaseMultiModalProcessor[_I],
    ) -> None:
        super().__init__()

        self.processor = processor

    @property
    def processing_info(self) -> BaseProcessingInfo:
        """The processing info attached to the wrapped processor."""
        return self.processor.info

    @property
    def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
        """The dummy-input builder attached to the wrapped processor."""
        return self.processor.dummy_inputs

    def get_mm_limits(self) -> Mapping[str, int]:
        """Return the processor's `allowed_mm_limits`, used as default counts."""
        return self.processor.allowed_mm_limits

    def _get_dummy_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalInputs:
        """
        Build dummy processor inputs and run them through the processor.

        Args:
            seq_len: Sequence length passed to the dummy-input builder.
            mm_counts: Items per modality; defaults to `get_mm_limits()`.
            mm_options: Optional per-modality dummy-data options.

        Returns:
            The result of `self.processor.apply` on the dummy inputs.
        """
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        factory = self.dummy_inputs
        processor_inputs = factory.get_dummy_processor_inputs(
            seq_len, mm_counts, mm_options
        )

        return self.processor.apply(
            prompt=processor_inputs.prompt,
            mm_data=processor_inputs.mm_data,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
            tokenization_kwargs=processor_inputs.tokenization_kwargs,
        )

    def _get_mm_num_tokens(
        self,
        mm_inputs: MultiModalInputs,
        mm_embeddings_only: bool = True,
    ) -> Mapping[str, int]:
        """
        Sum placeholder token counts per modality.

        Args:
            mm_inputs: Processed inputs whose `mm_placeholders` are counted.
            mm_embeddings_only: If True, count only embedding positions
                (`get_num_embeds()`); otherwise count the full placeholder
                `length`, including any non-embedding tokens in between.
        """
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        return {
            modality: sum(
                item.get_num_embeds() if mm_embeddings_only else item.length
                for item in placeholders
            )
            for modality, placeholders in placeholders_by_modality.items()
        }

    def get_encoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> DummyEncoderData:
        """
        Build dummy encoder inputs for profiling encoder-decoder models.

        If the processor sets `pad_dummy_encoder_prompt`, the encoder prompt
        is padded with token ID 0 up to `seq_len`. Otherwise, a warning is
        logged (only when `envs.VLLM_USE_V1` is false) if the prompt is
        longer than `seq_len`.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)
        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)

        # For encoder-decoder models, use encoder prompt token ids instead of
        # decoder prompt to construct dummy seq_data for encoder profiling.
        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]

        total_len = len(encoder_prompt_token_ids)

        processor = cast(EncDecMultiModalProcessor, self.processor)
        if processor.pad_dummy_encoder_prompt:
            num_tokens_to_pad = max(total_len, seq_len) - total_len
            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
        # NOTE: Whisper allows total_len > seq_len.
        elif total_len > seq_len and not envs.VLLM_USE_V1:
            # `max_num_batched_tokens` is defined by `SchedulerConfig`
            logger.warning_once(
                "The encoder sequence length used for profiling (max_num_batched_tokens / max_num_seqs = %d) "  # noqa: E501
                "is too short to hold the multi-modal embeddings in the worst case (%d tokens in total, out of which %s are reserved for multi-modal embeddings). "  # noqa: E501
                "This may cause certain multi-modal inputs to fail during inference, even when the input text is short. "  # noqa: E501
                "To avoid this, you should increase `max_model_len`, reduce `max_num_seqs`, and/or reduce `mm_counts`.",  # noqa: E501
                seq_len,
                total_len,
                str(self._get_mm_num_tokens(mm_inputs)),
            )

        return DummyEncoderData(encoder_prompt_token_ids)

    def get_decoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> DummyDecoderData:
        """
        Build dummy decoder inputs for profiling.

        The prompt token IDs are padded with token ID 0 up to `seq_len`.
        Longer prompts are left as-is.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)

        if total_len < seq_len:
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return DummyDecoderData(
            prompt_token_ids=prompt_token_ids,
            multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
            multi_modal_placeholders=mm_inputs["mm_placeholders"],
        )

    def _get_mm_max_tokens(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        mm_embeddings_only: bool = True,
    ) -> Mapping[str, int]:
        """
        Return per-modality token counts for the worst-case dummy inputs.

        If the processing info provides `get_mm_max_tokens_per_item`, its
        result is returned directly. Otherwise, dummy inputs are processed and
        their placeholders are counted via `_get_mm_num_tokens`.
        """
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        max_tokens_per_item = self.processing_info.get_mm_max_tokens_per_item(
            seq_len=seq_len,
            mm_counts=mm_counts,
        )
        if max_tokens_per_item is not None:
            return max_tokens_per_item

        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
        return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)

    def get_mm_max_contiguous_tokens(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> Mapping[str, int]:
        """
        Returns the maximum length of the multimodal (image placeholders+text)
        tokens, including any break/text tokens in-between image embeddings.

        `<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`
        Returns 9, even when the number of image embeddings is 6.

        This is important to take into account when profiling and
        initializing the encoder cache size.
        """

        return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)