profiling.py 12.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
5
from collections.abc import Mapping
from dataclasses import dataclass, field
6
from typing import Generic, NamedTuple, Optional, TypeVar, Union, cast
7
8
9
10
11

import numpy as np
import numpy.typing as npt
from PIL import Image

12
import vllm.envs as envs
13
14
from vllm.config.multimodal import (AudioDummyOptions, BaseDummyOptions,
                                    ImageDummyOptions, VideoDummyOptions)
15
16
from vllm.logger import init_logger

17
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
18
                     MultiModalInputs, MultiModalKwargsItems,
19
                     MultiModalPlaceholderDict)
20
21
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                         EncDecMultiModalProcessor)
22
23
24
25
26
27

# Module-level logger; used below for the override and sequence-length
# warnings emitted while building dummy profiling data.
logger = init_logger(__name__)


@dataclass
class ProcessorInputs:
    """
    Represents the keyword arguments to
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].

    Each field is passed to `apply` as the keyword argument of the same name.
    """

    # The prompt, given either as a string or as a list of ints.
    prompt: Union[str, list[int]]
    # The multi-modal data that accompanies the prompt.
    mm_data: MultiModalDataDict
    # Extra keyword arguments forwarded as `hf_processor_mm_kwargs`;
    # defaults to a fresh empty dict per instance.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
    # Extra keyword arguments forwarded as `tokenization_kwargs`;
    # defaults to a fresh empty dict per instance.
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
36
37


38
39
40
41
42
43
44
45
46
47
class DummyEncoderData(NamedTuple):
    """Dummy encoder-side data used for profiling."""

    # Token IDs that make up the dummy prompt.
    prompt_token_ids: list[int]


class DummyDecoderData(NamedTuple):
    """Dummy decoder-side data used for profiling."""

    # Token IDs that make up the dummy prompt.
    prompt_token_ids: list[int]
    # Processed multi-modal keyword-argument items for the dummy prompt.
    multi_modal_data: MultiModalKwargsItems
    # Placeholder information for the multi-modal items in the prompt.
    multi_modal_placeholders: MultiModalPlaceholderDict


52
53
54
55
# Type variable for the processing-info type that the generic classes in this
# module are parameterized over; it must be a `BaseProcessingInfo` subtype.
_I = TypeVar("_I", bound=BaseProcessingInfo)


class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.

    Subclasses implement `get_dummy_text` and `get_dummy_mm_data`; the
    `_get_dummy_*` helpers build placeholder audio/image/video items.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        self.info = info

    @abstractmethod
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Build the text input corresponding to `mm_counts`.
        """
        raise NotImplementedError

    @abstractmethod
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """
        Build the multimodal input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional).
                       If None, use model defaults for backward compatibility.
                       If provided, models can use these to customize dummy
                       data generation.
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length
            mm_counts: Count of items per modality
            mm_options: Configurable options per modality (optional)

        Returns:
            A `ProcessorInputs` holding the dummy text, the dummy
            multi-modal data and tokenization kwargs with truncation
            disabled.
        """
        dummy_text = self.get_dummy_text(mm_counts)

        # Use the unified function for both legacy and configurable cases
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)

        # Truncation is explicitly switched off for the dummy prompt.
        tokenization_kwargs = {"truncation": False}

        return ProcessorInputs(prompt=dummy_text,
                               mm_data=dummy_mm_data,
                               tokenization_kwargs=tokenization_kwargs)

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
        overrides: Optional[AudioDummyOptions] = None,
    ) -> list[npt.NDArray]:
        """
        Build `num_audios` dummy audio items.

        Args:
            length: The model's maximum audio length (number of samples
                in the generated array).
            num_audios: Number of items to return.
            overrides: Optional user overrides. A truthy `overrides.length`
                may shrink `length`; a larger value is ignored with a warning.

        Returns:
            A list that repeats the *same* zero-filled 1-D array
            `num_audios` times (empty if `num_audios` is 0).
        """
        if num_audios == 0:
            return []
        if overrides and overrides.length:
            if overrides.length > length:
                logger.warning(
                    "audio.length override (%d) exceeds model's "
                    "maximum length (%d), will be ignored", overrides.length,
                    length)
            # Overrides can only reduce the size, never exceed the model max.
            length = min(length, overrides.length)
        audio = np.zeros((length, ))
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
        overrides: Optional[ImageDummyOptions] = None,
    ) -> list[Image.Image]:
        """
        Build `num_images` dummy image items.

        Args:
            width: The model's maximum image width.
            height: The model's maximum image height.
            num_images: Number of items to return.
            overrides: Optional user overrides. Truthy `width`/`height`
                values may shrink the image; larger values are ignored
                with a warning.

        Returns:
            A list that repeats the *same* RGB image (created with
            `color=255`) `num_images` times (empty if `num_images` is 0).
        """
        if num_images == 0:
            return []
        if overrides:
            # Overrides can only reduce each dimension, never exceed the
            # model max.
            if overrides.width:
                if overrides.width > width:
                    logger.warning(
                        "image.width override (%d) exceeds model's "
                        "maximum width (%d), will be ignored", overrides.width,
                        width)
                width = min(width, overrides.width)
            if overrides.height:
                if overrides.height > height:
                    logger.warning(
                        "image.height override (%d) exceeds model's "
                        "maximum height (%d), will be ignored",
                        overrides.height, height)
                height = min(height, overrides.height)
        image = Image.new("RGB", (width, height), color=255)
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
        overrides: Optional[VideoDummyOptions] = None,
    ) -> list[npt.NDArray]:
        """
        Build `num_videos` dummy video items.

        Args:
            width: The model's maximum frame width.
            height: The model's maximum frame height.
            num_frames: The model's maximum number of frames.
            num_videos: Number of items to return.
            overrides: Optional user overrides. Truthy `num_frames`/`width`/
                `height` values may shrink the video; larger values are
                ignored with a warning.

        Returns:
            A list that repeats the *same* `uint8` array of shape
            `(num_frames, width, height, 3)` filled with 255, `num_videos`
            times (empty if `num_videos` is 0).
        """
        if num_videos == 0:
            return []
        if overrides:
            # Overrides can only reduce each dimension, never exceed the
            # model max.
            if overrides.num_frames:
                if overrides.num_frames > num_frames:
                    logger.warning(
                        "video.num_frames override (%d) exceeds model's "
                        "maximum number of frames (%d), will be ignored",
                        overrides.num_frames, num_frames)
                num_frames = min(num_frames, overrides.num_frames)
            if overrides.width:
                if overrides.width > width:
                    logger.warning(
                        "video.width override (%d) exceeds model's "
                        "maximum width (%d), will be ignored", overrides.width,
                        width)
                width = min(width, overrides.width)
            if overrides.height:
                if overrides.height > height:
                    logger.warning(
                        "video.height override (%d) exceeds model's "
                        "maximum height (%d), will be ignored",
                        overrides.height, height)
                height = min(height, overrides.height)
        # Use an explicit 8-bit dtype: without it `np.full` picks the default
        # integer type, making the dummy video several times larger in memory
        # than 8-bit pixel data.
        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
        return [video] * num_videos

203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224

class MultiModalProfiler(Generic[_I]):
    """
    Contains code for running memory profiling for multi-modal models.
    """

    def __init__(
        self,
        processor: BaseMultiModalProcessor[_I],
    ) -> None:
        super().__init__()

        self.processor = processor

    @property
    def processing_info(self) -> BaseProcessingInfo:
        """The processing info object held by the processor."""
        return self.processor.info

    @property
    def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
        """The dummy inputs builder held by the processor."""
        return self.processor.dummy_inputs

    def get_mm_limits(self) -> Mapping[str, int]:
        """Return the processor's allowed item limit per modality."""
        return self.processor.allowed_mm_limits

    def _get_dummy_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalInputs:
        """
        Build dummy processor inputs and run them through the processor.

        Args:
            seq_len: Sequence length passed to the dummy inputs builder.
            mm_counts: Count of items per modality; defaults to
                `get_mm_limits()` when None.
            mm_options: Configurable options per modality (optional).

        Returns:
            The result of `self.processor.apply` on the dummy inputs.
        """
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        factory = self.dummy_inputs
        processor_inputs = factory.get_dummy_processor_inputs(
            seq_len, mm_counts, mm_options)

        return self.processor.apply(
            prompt=processor_inputs.prompt,
            mm_data=processor_inputs.mm_data,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
            tokenization_kwargs=processor_inputs.tokenization_kwargs,
        )

    def _get_mm_num_tokens(
        self,
        mm_inputs: MultiModalInputs,
        mm_embeddings_only: bool = True,
    ) -> Mapping[str, int]:
        """
        Sum the placeholder sizes of `mm_inputs` per modality.

        Args:
            mm_inputs: Processed inputs containing `"mm_placeholders"`.
            mm_embeddings_only: If True, sum `get_num_embeds()` of each
                placeholder; otherwise sum each placeholder's `length`.

        Returns:
            A mapping from modality to the summed count.
        """
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        return {
            modality:
            sum(item.get_num_embeds() if mm_embeddings_only else item.length
                for item in placeholders)
            for modality, placeholders in placeholders_by_modality.items()
        }

    def get_encoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> DummyEncoderData:
        """
        Build the dummy encoder prompt used for profiling.

        Args:
            seq_len: Target sequence length for the encoder prompt.
            mm_counts: Count of items per modality; defaults to
                `get_mm_limits()` when None.
            mm_options: Configurable options per modality (optional).

        Returns:
            A `DummyEncoderData` wrapping the encoder prompt token IDs.
            If the processor sets `pad_dummy_encoder_prompt`, the IDs are
            padded with 0 up to `seq_len`; otherwise they are returned
            unpadded and a one-time warning may be logged when they exceed
            `seq_len`.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)
        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)

        # For encoder-decoder models, use encoder prompt token ids instead of
        # decoder prompt to construct dummy seq_data for encoder profiling.
        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]

        total_len = len(encoder_prompt_token_ids)

        processor = cast(EncDecMultiModalProcessor, self.processor)
        if processor.pad_dummy_encoder_prompt:
            # Pad only when the prompt is shorter than `seq_len`.
            # NOTE: `extend` mutates the list stored in `mm_inputs` in place.
            num_tokens_to_pad = max(seq_len - total_len, 0)
            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
        # NOTE: Whisper allows total_len > seq_len.
        elif total_len > seq_len and not envs.VLLM_USE_V1:
            # `max_num_batched_tokens` is defined by `SchedulerConfig`
            logger.warning_once(
                "The encoder sequence length used for profiling (max_num_batched_tokens / max_num_seqs = %d) "  # noqa: E501
                "is too short to hold the multi-modal embeddings in the worst case (%d tokens in total, out of which %s are reserved for multi-modal embeddings). "  # noqa: E501
                "This may cause certain multi-modal inputs to fail during inference, even when the input text is short. "  # noqa: E501
                "To avoid this, you should increase `max_model_len`, reduce `max_num_seqs`, and/or reduce `mm_counts`.",  # noqa: E501
                seq_len,
                total_len,
                str(self._get_mm_num_tokens(mm_inputs)),
            )

        return DummyEncoderData(encoder_prompt_token_ids)

    def get_decoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> DummyDecoderData:
        """
        Build the dummy decoder prompt and multi-modal data for profiling.

        Args:
            seq_len: Minimum length of the returned prompt; shorter prompts
                are padded with 0 up to this length.
            mm_counts: Count of items per modality; defaults to
                `get_mm_limits()` when None.
            mm_options: Configurable options per modality (optional).

        Returns:
            A `DummyDecoderData` with the (possibly padded) prompt token
            IDs, the processed multi-modal kwargs and the placeholders.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)

        if total_len < seq_len:
            # NOTE: `extend` mutates the list stored in `mm_inputs` in place.
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return DummyDecoderData(
            prompt_token_ids=prompt_token_ids,
            multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
            multi_modal_placeholders=mm_inputs["mm_placeholders"],
        )

    def _get_mm_max_tokens(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
        mm_embeddings_only: bool = True,
    ) -> Mapping[str, int]:
        """
        Return the maximum number of multi-modal tokens per modality.

        The value reported by
        `processing_info.get_mm_max_tokens_per_item` is returned as-is when
        it is not None. Otherwise the dummy inputs are processed (with
        default options) and their placeholders are counted.

        Args:
            seq_len: Sequence length used to build the dummy inputs.
            mm_counts: Count of items per modality; defaults to
                `get_mm_limits()` when None.
            mm_embeddings_only: Forwarded to `_get_mm_num_tokens` on the
                fallback path.
        """
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        max_tokens_per_item = self.processing_info.get_mm_max_tokens_per_item(
            seq_len=seq_len,
            mm_counts=mm_counts,
        )
        if max_tokens_per_item is not None:
            return max_tokens_per_item

        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
        return self._get_mm_num_tokens(mm_inputs,
                                       mm_embeddings_only=mm_embeddings_only)

    def get_mm_max_contiguous_tokens(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> Mapping[str, int]:
        """
        Returns the maximum length of the multimodal (image placeholders+text)
        tokens, including any break/text tokens in-between image embeddings.

        For example, given the placeholder sequence

        `<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`

        this returns 9, even when the number of image embeddings is 6.

        This is important to take into account when profiling and
        initializing the encoder cache size.

        Args:
            seq_len: Sequence length used to build the dummy inputs.
            mm_counts: Count of items per modality; defaults to
                `get_mm_limits()` when None.
        """

        return self._get_mm_max_tokens(seq_len,
                                       mm_counts,
                                       mm_embeddings_only=False)