# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Generic, NamedTuple, Optional, TypeVar, cast

import numpy as np
import numpy.typing as npt
from PIL import Image

import vllm.envs as envs
from vllm.logger import init_logger

from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                     MultiModalInputs, MultiModalKwargs,
                     MultiModalPlaceholderDict)
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                         EncDecMultiModalProcessor)

# Module-level logger; used by the profiler below to emit its
# "sequence length is too short" warnings via `warning_once`.
logger = init_logger(__name__)


@dataclass
class ProcessorInputs:
    """
    Bundle of keyword arguments for
    :meth:`vllm.multimodal.processing.BaseMultiModalProcessor.apply`.

    Dummy-input builders return an instance of this class from
    ``get_dummy_processor_inputs``; the profiler then forwards each field
    to the processor's ``apply`` method.
    """

    # Prompt text; forwarded to ``apply`` as ``prompt``.
    prompt_text: str
    # Multi-modal items; forwarded to ``apply`` as ``mm_data``.
    mm_data: MultiModalDataDict
    # Extra keyword arguments; forwarded to ``apply`` as
    # ``hf_processor_mm_kwargs``. Each instance gets its own empty dict
    # when no value is supplied.
    hf_processor_mm_kwargs: Mapping[str, object] = field(
        default_factory=dict)


class DummyEncoderData(NamedTuple):
    """
    Dummy encoder inputs used for profiling.

    Built by ``MultiModalProfiler.get_encoder_dummy_data``.
    """

    # Encoder prompt token ids of the processed dummy input; may have been
    # extended with ``0`` entries when the processor requests padding.
    prompt_token_ids: list[int]


class DummyDecoderData(NamedTuple):
    """
    Dummy decoder inputs used for profiling.

    Built by ``MultiModalProfiler.get_decoder_dummy_data``.
    """

    # Prompt token ids of the processed dummy input, extended with ``0``
    # entries when they are shorter than the profiled sequence length.
    prompt_token_ids: list[int]
    # The ``mm_kwargs`` entry of the processed dummy input.
    multi_modal_data: MultiModalKwargs
    # The ``mm_placeholders`` entry of the processed dummy input.
    multi_modal_placeholders: MultiModalPlaceholderDict


# Type variable for the concrete processing-info type; it parameterizes both
# `BaseDummyInputsBuilder` and `MultiModalProfiler` below.
_I = TypeVar("_I", bound=BaseProcessingInfo)


class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.

    Subclasses must implement :meth:`get_dummy_processor_inputs`. The
    ``_get_dummy_*`` helpers build placeholder audio, image and video items
    of a requested size.
    """

    def __init__(self, info: _I) -> None:
        """
        Args:
            info: Processing info object; stored as ``self.info``.
        """
        super().__init__()

        self.info = info

    @abstractmethod
    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        :code:`self.info.get_mm_max_tokens_per_item()` placeholder tokens.

        Args:
            seq_len: Sequence length used for profiling.
            mm_counts: Number of items to build, keyed by modality.

        Returns:
            The prompt text, multi-modal data and processor keyword
            arguments to pass to the multi-modal processor.
        """
        raise NotImplementedError

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
    ) -> list[npt.NDArray]:
        """
        Build ``num_audios`` dummy audio items.

        Args:
            length: Number of elements in each 1-D audio array.
            num_audios: Number of items to return.

        Returns:
            A list that repeats one all-zero array of shape ``(length,)``.
            Every entry is the same object, not a copy.
        """
        audio = np.zeros((length, ))
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
    ) -> list[Image.Image]:
        """
        Build ``num_images`` dummy images.

        Args:
            width: Image width passed to ``Image.new``.
            height: Image height passed to ``Image.new``.
            num_images: Number of items to return.

        Returns:
            A list that repeats one ``"RGB"`` image of size
            ``(width, height)`` created with ``color=255``. Every entry is
            the same object, not a copy.
        """
        image = Image.new("RGB", (width, height), color=255)
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[npt.NDArray]:
        """
        Build ``num_videos`` dummy videos.

        Args:
            width: Size of the second array axis.
            height: Size of the third array axis.
            num_frames: Number of frames (size of the first array axis).
            num_videos: Number of items to return.

        Returns:
            A list that repeats one ``uint8`` array of shape
            ``(num_frames, width, height, 3)`` filled with 255. Every entry
            is the same object, not a copy.
        """
        # Request uint8 explicitly: without a dtype, NumPy picks its default
        # integer type (typically 64-bit) for the fill value, which makes the
        # dummy video several times larger than 0-255 pixel data needs.
        # NOTE(review): the axes are ordered (width, height) rather than the
        # more common (height, width) frame layout -- confirm this is
        # intended before changing it.
        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
        return [video] * num_videos


class MultiModalProfiler(Generic[_I]):
    """
    Contains code for running memory profiling for multi-modal models.

    The profiler obtains dummy inputs from the processor's dummy-input
    builder, runs them through the processor, checks the resulting
    placeholder counts, and packages the token ids as encoder or decoder
    dummy data.
    """

    def __init__(
        self,
        processor: BaseMultiModalProcessor[_I],
    ) -> None:
        """
        Args:
            processor: The multi-modal processor to profile; stored as
                ``self.processor``.
        """
        super().__init__()

        self.processor = processor

    @property
    def processing_info(self) -> BaseProcessingInfo:
        """The ``info`` object of the wrapped processor."""
        return self.processor.info

    @property
    def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
        """The ``dummy_inputs`` builder of the wrapped processor."""
        return self.processor.dummy_inputs

    def get_mm_limits(self) -> Mapping[str, int]:
        """
        Return the configured per-prompt item limit of every modality
        reported by ``get_supported_mm_limits``.

        Raises:
            ValueError: If a configured limit is larger than the limit the
                model reports for that modality.
        """
        info = self.processing_info
        mm_config = info.ctx.get_mm_config()
        supported = info.get_supported_mm_limits()

        configured = {
            name: mm_config.get_limit_per_prompt(name)
            for name in supported
        }

        for name, max_supported in supported.items():
            requested = configured[name]

            # A `None` entry is never treated as exceeded.
            if max_supported is None or requested <= max_supported:
                continue

            raise ValueError(
                f"You set {name}={requested} (or defaulted to 1) in "
                f"`--limit-mm-per-prompt`, but this model only supports "
                f"at most {max_supported} {name} items.")

        return configured

    def _get_dummy_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalInputs:
        """
        Build dummy processor inputs and run them through the processor.

        Args:
            seq_len: Sequence length used for profiling.
            mm_counts: Number of items per modality.

        Returns:
            Whatever ``self.processor.apply`` returns for the dummy inputs.
        """
        builder = self.dummy_inputs
        dummy = builder.get_dummy_processor_inputs(seq_len, mm_counts)

        return self.processor.apply(
            prompt=dummy.prompt_text,
            mm_data=dummy.mm_data,
            hf_processor_mm_kwargs=dummy.hf_processor_mm_kwargs,
        )

    def get_and_validate_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> tuple[MultiModalInputs, Mapping[str, int]]:
        """
        Process the dummy inputs and check their placeholder counts.

        Args:
            seq_len: Sequence length used for profiling.
            mm_counts: Number of items per modality. When ``None``, the
                result of :meth:`get_mm_limits` is used.

        Returns:
            The processed dummy inputs together with the total number of
            placeholder embeddings per modality.

        Raises:
            AssertionError: If ``mm_counts`` contains a modality that
                ``get_mm_max_tokens_per_item`` does not report, or if the
                processed placeholder totals differ from
                ``max tokens per item * item count``.
        """
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        max_tokens_per_item = self.processing_info.get_mm_max_tokens_per_item(
            seq_len, mm_counts)

        unknown_modalities = mm_counts.keys() - max_tokens_per_item.keys()
        if unknown_modalities:
            raise AssertionError(
                "The keys returned by `get_supported_mm_limits` "
                f"({set(mm_counts.keys())}) should be a subset of those "
                "returned by `get_mm_max_tokens_per_item` "
                f"({set(max_tokens_per_item.keys())})")

        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        # What the processor actually produced for each modality.
        actual_totals = {}
        for name, placeholders in placeholders_by_modality.items():
            actual_totals[name] = sum(item.get_num_embeds()
                                      for item in placeholders)

        # What the processing info says the worst case should be.
        expected_totals = {}
        for name in placeholders_by_modality:
            expected_totals[name] = (max_tokens_per_item[name] *
                                     mm_counts[name])

        if actual_totals != expected_totals:
            raise AssertionError(
                f"The processed dummy data has a total of "
                f"{actual_totals} placeholder tokens, which "
                f"is not the expected {expected_totals} "
                "tokens.")

        return mm_inputs, actual_totals

    def get_encoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyEncoderData:
        """
        Build the dummy encoder prompt used for profiling.

        Args:
            seq_len: Sequence length used for profiling.
            mm_counts: Number of items per modality; see
                :meth:`get_and_validate_mm_inputs`.

        Returns:
            The encoder prompt token ids of the processed dummy inputs. If
            the processor sets ``pad_dummy_encoder_prompt``, the list is
            extended in place with ``0`` up to ``seq_len``.
        """
        mm_inputs, placeholder_totals = self.get_and_validate_mm_inputs(
            seq_len, mm_counts)
        enc_dec_inputs = cast(MultiModalEncDecInputs, mm_inputs)

        # Encoder profiling is driven by the encoder prompt, not by the
        # decoder prompt.
        encoder_prompt_token_ids = enc_dec_inputs["encoder_prompt_token_ids"]
        prompt_len = len(encoder_prompt_token_ids)

        # Encoder-decoder multimodal models only support v0
        if prompt_len > seq_len:
            # `max_num_batched_tokens` is defined by `SchedulerConfig`
            msg = ("The encoder sequence length used for profiling ("
                   f"max_num_batched_tokens / max_num_seqs = {seq_len}) "
                   " is too short "
                   "to hold the multi-modal embeddings in the worst case "
                   f"({prompt_len} tokens in total, out of which "
                   f"{placeholder_totals} are reserved for "
                   "multi-modal embeddings). This may cause certain "
                   "multi-modal inputs to fail during inference, even when "
                   "the input text is short. To avoid this, you should "
                   "increase `max_model_len`, reduce `max_num_seqs`, "
                   "and/or reduce `mm_counts`.")
            logger.warning_once(msg)

        enc_dec_processor = cast(EncDecMultiModalProcessor, self.processor)
        if enc_dec_processor.pad_dummy_encoder_prompt:
            # Zero when the prompt already reaches `seq_len`.
            shortfall = max(0, seq_len - prompt_len)
            encoder_prompt_token_ids.extend([0] * shortfall)

        return DummyEncoderData(prompt_token_ids=encoder_prompt_token_ids)

    def get_decoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyDecoderData:
        """
        Build the dummy decoder inputs used for profiling.

        Args:
            seq_len: Sequence length used for profiling.
            mm_counts: Number of items per modality; see
                :meth:`get_and_validate_mm_inputs`.

        Returns:
            The prompt token ids of the processed dummy inputs (extended in
            place with ``0`` up to ``seq_len`` when shorter), together with
            the processed ``mm_kwargs`` and ``mm_placeholders``.
        """
        mm_inputs, placeholder_totals = self.get_and_validate_mm_inputs(
            seq_len, mm_counts)

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        prompt_len = len(prompt_token_ids)

        if prompt_len > seq_len:
            # V0 does not support chunked prefill.
            if not envs.VLLM_USE_V1:
                # `max_num_batched_tokens` is defined by `SchedulerConfig`
                msg = (
                    "The sequence length used for profiling ("
                    f"max_num_batched_tokens / max_num_seqs = {seq_len}) "
                    "is too short "
                    "to hold the multi-modal embeddings in the worst case "
                    f"({prompt_len} tokens in total, out of which "
                    f"{placeholder_totals} are reserved for "
                    "multi-modal embeddings). This may cause certain "
                    "multi-modal inputs to fail during inference, even when "
                    "the input text is short. To avoid this, you should "
                    "increase `max_model_len`, reduce `max_num_seqs`, "
                    "and/or reduce `mm_counts`.")
                logger.warning_once(msg)

        shortfall = seq_len - prompt_len
        if shortfall > 0:
            prompt_token_ids.extend([0] * shortfall)

        return DummyDecoderData(
            prompt_token_ids=prompt_token_ids,
            multi_modal_data=mm_inputs["mm_kwargs"],
            multi_modal_placeholders=mm_inputs["mm_placeholders"],
        )