profiling.py 9.64 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from abc import ABC
4
5
from collections.abc import Mapping
from dataclasses import dataclass, field
6
from typing import Generic, NamedTuple, Optional, TypeVar, cast
7
8
9
10
11

import numpy as np
import numpy.typing as npt
from PIL import Image

12
import vllm.envs as envs
13
14
from vllm.logger import init_logger

15
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
16
17
                     MultiModalInputs, MultiModalKwargs,
                     MultiModalPlaceholderDict)
18
19
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                         EncDecMultiModalProcessor)
20
21
22
23
24
25

# Module-level logger, created with this module's name.
logger = init_logger(__name__)


@dataclass
class ProcessorInputs:
    """
    Represents the keyword arguments to
    :meth:`vllm.multimodal.processing.BaseMultiModalProcessor.apply`.
    """

    # Text prompt; forwarded as the `prompt` argument of `apply`.
    prompt_text: str

    # Multi-modal items; forwarded as the `mm_data` argument of `apply`.
    mm_data: MultiModalDataDict

    # Extra keyword arguments forwarded as `hf_processor_mm_kwargs`.
    # `default_factory` gives every instance its own empty dict instead of
    # sharing one mutable default.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)


35
36
37
38
39
40
41
42
43
44
45
46
47
48
class DummyEncoderData(NamedTuple):
    """
    Dummy encoder-side data used for profiling.

    Produced by :meth:`MultiModalProfiler.get_encoder_dummy_data`.
    """

    # Encoder prompt token IDs of the processed dummy inputs. When the
    # processor sets `pad_dummy_encoder_prompt`, they are padded with `0`
    # up to the profiling sequence length.
    prompt_token_ids: list[int]


class DummyDecoderData(NamedTuple):
    """
    Dummy decoder-side data used for profiling.

    Produced by :meth:`MultiModalProfiler.get_decoder_dummy_data`.
    """

    # Prompt token IDs of the processed dummy inputs, padded with `0` up to
    # the profiling sequence length when they are shorter than it.
    prompt_token_ids: list[int]
    # The `mm_kwargs` entry of the processed dummy inputs.
    multi_modal_data: MultiModalKwargs
    # The `mm_placeholders` entry of the processed dummy inputs.
    multi_modal_placeholders: MultiModalPlaceholderDict


49
50
51
52
# Type variable for the processing-info type that the generic classes in
# this module are parameterized over; bounded by `BaseProcessingInfo`.
_I = TypeVar("_I", bound=BaseProcessingInfo)


class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.

    Subclasses are expected to override :meth:`get_dummy_text` and
    :meth:`get_dummy_mm_data`. Overriding only the legacy
    :meth:`get_dummy_processor_inputs` is still accepted during the
    transition period.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        # Processing info object; `get_dummy_text` reads
        # `info.ctx.model_config.max_model_len` from it.
        self.info = info

    # TODO: @abstractmethod after transition
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Build the text input corresponding to :code:`mm_counts`.

        The default implementation is a compatibility shim: when the
        subclass overrides the legacy :meth:`get_dummy_processor_inputs`,
        the prompt text is taken from its result (called with the model's
        :code:`max_model_len` as :code:`seq_len`) and a warning is logged
        via :code:`logger.warning_once`.

        Raises:
            NotImplementedError: If the subclass overrides neither this
                method nor :meth:`get_dummy_processor_inputs`.
        """
        # Without a legacy override there is nothing to delegate to, and
        # delegating to the base implementation would recurse back here.
        if (type(self).get_dummy_processor_inputs ==
                BaseDummyInputsBuilder.get_dummy_processor_inputs):
            raise NotImplementedError

        logger.warning_once("`get_dummy_processor_inputs` has been split up "
                            "into `get_dummy_text` and `get_dummy_mm_data`. "
                            "These two methods will be marked as abstract "
                            "in an upcoming release.")

        seq_len = self.info.ctx.model_config.max_model_len
        return self.get_dummy_processor_inputs(seq_len, mm_counts).prompt_text

    # TODO: @abstractmethod after transition
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """
        Build the multimodal input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Combines :meth:`get_dummy_text` and :meth:`get_dummy_mm_data` into
        a single :class:`ProcessorInputs`; `hf_processor_mm_kwargs` is left
        at its default.
        """
        dummy_text = self.get_dummy_text(mm_counts)
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)

        return ProcessorInputs(prompt_text=dummy_text, mm_data=dummy_mm_data)

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
    ) -> list[npt.NDArray]:
        """
        Return :code:`num_audios` dummy audios, each a 1-D all-zero array
        with :code:`length` elements.

        Every list entry refers to the same array object.
        """
        if num_audios == 0:
            return []

        audio = np.zeros((length, ))
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
    ) -> list[Image.Image]:
        """
        Return :code:`num_images` dummy RGB images of size
        :code:`(width, height)`.

        Every list entry refers to the same image object.
        """
        if num_images == 0:
            return []

        image = Image.new("RGB", (width, height), color=255)
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[npt.NDArray]:
        """
        Return :code:`num_videos` dummy videos, each an array of shape
        :code:`(num_frames, width, height, 3)` filled with 255.

        Every list entry refers to the same array object.
        """
        if num_videos == 0:
            return []

        # 255 fits in one byte per channel; without an explicit dtype NumPy
        # would allocate a default-integer array (8x larger on most
        # platforms) for the same pixel values.
        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
        return [video] * num_videos

142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163

class MultiModalProfiler(Generic[_I]):
    """
    Contains code for running memory profiling for multi-modal models.
    """

    def __init__(
        self,
        processor: BaseMultiModalProcessor[_I],
    ) -> None:
        super().__init__()

        self.processor = processor

    @property
    def processing_info(self) -> BaseProcessingInfo:
        """The `info` object of the wrapped processor."""
        return self.processor.info

    @property
    def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
        """The dummy inputs builder of the wrapped processor."""
        return self.processor.dummy_inputs

    def get_mm_limits(self) -> Mapping[str, int]:
        """
        Return the configured per-prompt item limit of every modality
        supported by the model.

        Raises:
            ValueError: If the configured limit of a modality exceeds the
                (non-``None``) limit supported by the model.
        """
        mm_config = self.processing_info.ctx.get_mm_config()
        supported_mm_limits = self.processing_info.get_supported_mm_limits()

        mm_limits = {
            modality: mm_config.get_limit_per_prompt(modality)
            for modality in supported_mm_limits
        }

        for modality, supported_limit in supported_mm_limits.items():
            limit = mm_limits[modality]
            # A supported limit of `None` is never treated as exceeded.
            if supported_limit is not None and supported_limit < limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but this model only supports "
                    f"at most {supported_limit} {modality} items.")

        return mm_limits

    def _get_dummy_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> MultiModalInputs:
        """
        Build the dummy processor inputs and run them through the processor.

        When :code:`mm_counts` is ``None``, the result of
        :meth:`get_mm_limits` is used instead.
        """
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        factory = self.dummy_inputs
        processor_inputs = factory.get_dummy_processor_inputs(
            seq_len, mm_counts)

        return self.processor.apply(
            prompt=processor_inputs.prompt_text,
            mm_data=processor_inputs.mm_data,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
        )

    def _get_mm_num_tokens(
        self,
        mm_inputs: MultiModalInputs,
    ) -> Mapping[str, int]:
        """
        Sum :code:`get_num_embeds()` over the placeholders of each modality
        in :code:`mm_inputs["mm_placeholders"]`.
        """
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        return {
            modality: sum(item.get_num_embeds() for item in placeholders)
            for modality, placeholders in placeholders_by_modality.items()
        }

    def get_encoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyEncoderData:
        """
        Build the dummy encoder prompt used for profiling.

        When the processor sets :code:`pad_dummy_encoder_prompt`, the
        encoder prompt is padded with ``0`` up to :code:`seq_len`.
        Otherwise, a warning is logged if the prompt is longer than
        :code:`seq_len` and V1 is not in use.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)

        # For encoder-decoder models, use encoder prompt token ids instead of
        # decoder prompt to construct dummy seq_data for encoder profiling.
        # Copy the list so that padding does not mutate the processor output.
        encoder_prompt_token_ids = list(mm_inputs["encoder_prompt_token_ids"])

        total_len = len(encoder_prompt_token_ids)

        processor = cast(EncDecMultiModalProcessor, self.processor)
        if processor.pad_dummy_encoder_prompt:
            num_tokens_to_pad = max(total_len, seq_len) - total_len
            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
        # NOTE: Whisper allows total_len > seq_len.
        elif total_len > seq_len and not envs.VLLM_USE_V1:
            # `max_num_batched_tokens` is defined by `SchedulerConfig`
            logger.warning_once(
                "The encoder sequence length used for profiling ("
                f"max_num_batched_tokens / max_num_seqs = {seq_len}) "
                "is too short "
                "to hold the multi-modal embeddings in the worst case "
                f"({total_len} tokens in total, out of which "
                f"{self._get_mm_num_tokens(mm_inputs)} are reserved for "
                "multi-modal embeddings). This may cause certain "
                "multi-modal inputs to fail during inference, even when "
                "the input text is short. To avoid this, you should "
                "increase `max_model_len`, reduce `max_num_seqs`, "
                "and/or reduce `mm_counts`.")

        return DummyEncoderData(encoder_prompt_token_ids)

    def get_decoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyDecoderData:
        """
        Build the dummy decoder prompt and multi-modal data used for
        profiling.

        The prompt is padded with ``0`` up to :code:`seq_len` when it is
        shorter. A warning is logged if it is longer than :code:`seq_len`
        and V1 is not in use.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)

        # Copy the list so that padding does not mutate the processor output.
        prompt_token_ids = list(mm_inputs["prompt_token_ids"])
        total_len = len(prompt_token_ids)

        # V0 does not support chunked prefill.
        if total_len > seq_len and not envs.VLLM_USE_V1:
            # `max_num_batched_tokens` is defined by `SchedulerConfig`
            logger.warning_once(
                "The sequence length used for profiling ("
                f"max_num_batched_tokens / max_num_seqs = {seq_len}) "
                "is too short "
                "to hold the multi-modal embeddings in the worst case "
                f"({total_len} tokens in total, out of which "
                f"{self._get_mm_num_tokens(mm_inputs)} are reserved for "
                "multi-modal embeddings). This may cause certain "
                "multi-modal inputs to fail during inference, even when "
                "the input text is short. To avoid this, you should "
                "increase `max_model_len`, reduce `max_num_seqs`, "
                "and/or reduce `mm_counts`.")

        if total_len < seq_len:
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return DummyDecoderData(
            prompt_token_ids=prompt_token_ids,
            multi_modal_data=mm_inputs["mm_kwargs"],
            multi_modal_placeholders=mm_inputs["mm_placeholders"],
        )

    def get_mm_max_tokens(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> Mapping[str, int]:
        """
        Return, per modality, the number of placeholder embeddings in the
        processed dummy inputs.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)

        return self._get_mm_num_tokens(mm_inputs)