profiling.py 9.88 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from abc import ABC, abstractmethod
4
5
from collections.abc import Mapping
from dataclasses import dataclass, field
6
from typing import Generic, NamedTuple, Optional, TypeVar, Union, cast
7
8
9
10
11

import numpy as np
import numpy.typing as npt
from PIL import Image

12
import vllm.envs as envs
13
14
from vllm.logger import init_logger

15
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
16
17
                     MultiModalInputs, MultiModalKwargs,
                     MultiModalPlaceholderDict)
18
19
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                         EncDecMultiModalProcessor)
20
21
22
23
24
25

logger = init_logger(__name__)


@dataclass
class ProcessorInputs:
    """
    Bundle of the keyword arguments accepted by
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
    """

    # Prompt, given either as text or as a list of token ids.
    prompt: Union[str, list[int]]
    # Multi-modal data that accompanies the prompt.
    mm_data: MultiModalDataDict
    # Keyword-argument mapping passed to `apply` under the same name;
    # a fresh empty dict per instance when omitted.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
    # Keyword-argument mapping passed to `apply` under the same name;
    # a fresh empty dict per instance when omitted.
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
34
35


36
37
38
39
40
41
42
43
44
45
46
47
48
49
class DummyEncoderData(NamedTuple):
    """Encoder-side dummy data produced for profiling."""

    # Token ids of the dummy prompt.
    prompt_token_ids: list[int]


class DummyDecoderData(NamedTuple):
    """Decoder-side dummy data produced for profiling."""

    # Token ids of the dummy prompt.
    prompt_token_ids: list[int]
    # Multi-modal keyword arguments that go with the prompt.
    multi_modal_data: MultiModalKwargs
    # Placeholder information for the multi-modal items in the prompt.
    multi_modal_placeholders: MultiModalPlaceholderDict


50
51
52
53
_I = TypeVar("_I", bound=BaseProcessingInfo)


class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.

    Subclasses must implement `get_dummy_text` and `get_dummy_mm_data`.
    The `_get_dummy_audios` / `_get_dummy_images` / `_get_dummy_videos`
    helpers build constant-valued media items of a requested size.
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        # Processing info object, kept for use by subclasses.
        self.info = info

    @abstractmethod
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Build the text input corresponding to `mm_counts`.

        Args:
            mm_counts: Mapping from a string key to an item count
                (presumably keyed by modality name).

        Returns:
            The dummy prompt text.
        """
        raise NotImplementedError

    @abstractmethod
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """
        Build the multimodal input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Args:
            seq_len: Sequence length the dummy data is built for.
            mm_counts: Mapping from a string key to an item count
                (presumably keyed by modality name).

        Returns:
            The dummy multi-modal data.
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Combines `get_dummy_text` and `get_dummy_mm_data` into a single
        `ProcessorInputs`, with `truncation` set to `False` in the
        tokenization keyword arguments.

        Args:
            seq_len: Forwarded to `get_dummy_mm_data`.
            mm_counts: Forwarded to both `get_dummy_text` and
                `get_dummy_mm_data`.

        Returns:
            The assembled `ProcessorInputs`.
        """
        dummy_text = self.get_dummy_text(mm_counts)
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
        tokenization_kwargs = {"truncation": False}

        return ProcessorInputs(prompt=dummy_text,
                               mm_data=dummy_mm_data,
                               tokenization_kwargs=tokenization_kwargs)

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
    ) -> list[npt.NDArray]:
        """
        Build `num_audios` dummy audio items.

        Args:
            length: Number of elements in each 1-D audio array.
            num_audios: Number of items to return.

        Returns:
            An empty list when `num_audios` is 0; otherwise a list whose
            entries all refer to the same zero-filled array of shape
            `(length,)`.
        """
        if num_audios == 0:
            return []

        audio = np.zeros((length, ))
        # One array is allocated; the list repeats references to it.
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
    ) -> list[Image.Image]:
        """
        Build `num_images` dummy images.

        Args:
            width: Image width passed to `Image.new`.
            height: Image height passed to `Image.new`.
            num_images: Number of items to return.

        Returns:
            An empty list when `num_images` is 0; otherwise a list whose
            entries all refer to the same RGB image of size
            `(width, height)` created with `color=255`.
        """
        if num_images == 0:
            return []

        image = Image.new("RGB", (width, height), color=255)
        # One image is allocated; the list repeats references to it.
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[npt.NDArray]:
        """
        Build `num_videos` dummy videos.

        Args:
            width: Size of the second array axis.
            height: Size of the third array axis.
            num_frames: Size of the first array axis.
            num_videos: Number of items to return.

        Returns:
            An empty list when `num_videos` is 0; otherwise a list whose
            entries all refer to the same `uint8` array of shape
            `(num_frames, width, height, 3)` filled with 255.
        """
        if num_videos == 0:
            return []

        # Pin the dtype: without it `np.full` infers the platform default
        # integer type from the fill value, which is usually 8 bytes per
        # element. `uint8` holds 255 exactly at one byte per element.
        # NOTE(review): the axis order (frames, width, height, channels) is
        # kept as found; confirm consumers do not expect height before width.
        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
        # One array is allocated; the list repeats references to it.
        return [video] * num_videos

136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157

class MultiModalProfiler(Generic[_I]):
    """
    Contains code for running memory profiling for multi-modal models.

    Wraps a multi-modal processor and uses its dummy-input builder to
    produce dummy encoder/decoder data and per-modality token counts.
    """

    def __init__(
        self,
        processor: BaseMultiModalProcessor[_I],
    ) -> None:
        super().__init__()

        self.processor = processor

    @property
    def processing_info(self) -> BaseProcessingInfo:
        """The `info` object of the wrapped processor."""
        return self.processor.info

    @property
    def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
        """The dummy-input builder of the wrapped processor."""
        return self.processor.dummy_inputs

    def get_mm_limits(self) -> Mapping[str, int]:
        """Return the allowed multi-modal limits from the processing info."""
        return self.processing_info.get_allowed_mm_limits()

    def _get_dummy_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> MultiModalInputs:
        """
        Build dummy processor inputs and run them through the processor.

        Args:
            seq_len: Sequence length forwarded to the dummy-input builder.
            mm_counts: Item counts to build dummy data for; defaults to
                `get_mm_limits()` when `None`.

        Returns:
            The result of `self.processor.apply` on the dummy inputs.
        """
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        factory = self.dummy_inputs
        processor_inputs = factory.get_dummy_processor_inputs(
            seq_len, mm_counts)

        return self.processor.apply(
            prompt=processor_inputs.prompt,
            mm_data=processor_inputs.mm_data,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
            tokenization_kwargs=processor_inputs.tokenization_kwargs,
        )

    def _get_mm_num_tokens(
        self,
        mm_inputs: MultiModalInputs,
    ) -> Mapping[str, int]:
        """
        Sum `get_num_embeds()` over the placeholders of each modality
        found in `mm_inputs["mm_placeholders"]`.
        """
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        return {
            modality: sum(item.get_num_embeds() for item in placeholders)
            for modality, placeholders in placeholders_by_modality.items()
        }

    def get_encoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyEncoderData:
        """
        Build dummy encoder data for profiling.

        Args:
            seq_len: Sequence length used for profiling.
            mm_counts: Item counts to build dummy data for; defaults to
                `get_mm_limits()` when `None`.

        Returns:
            A `DummyEncoderData` holding the encoder prompt token ids. When
            the processor sets `pad_dummy_encoder_prompt`, the ids are
            zero-padded up to `seq_len`; otherwise they are returned as-is
            and a warning may be logged if they exceed `seq_len`.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)

        # For encoder-decoder models, use encoder prompt token ids instead of
        # decoder prompt to construct dummy seq_data for encoder profiling.
        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]

        total_len = len(encoder_prompt_token_ids)

        processor = cast(EncDecMultiModalProcessor, self.processor)
        if processor.pad_dummy_encoder_prompt:
            # Pads only when the prompt is shorter than `seq_len`; the list
            # held by `mm_inputs` is extended in place.
            num_tokens_to_pad = max(total_len, seq_len) - total_len
            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
        # NOTE: Whisper allows total_len > seq_len.
        elif total_len > seq_len and not envs.VLLM_USE_V1:
            # `max_num_batched_tokens` is defined by `SchedulerConfig`
            logger.warning_once(
                "The encoder sequence length used for profiling (max_num_batched_tokens / max_num_seqs = %d) "  # noqa: E501
                "is too short to hold the multi-modal embeddings in the worst case (%d tokens in total, out of which %s are reserved for multi-modal embeddings). "  # noqa: E501
                "This may cause certain multi-modal inputs to fail during inference, even when the input text is short. "  # noqa: E501
                "To avoid this, you should increase `max_model_len`, reduce `max_num_seqs`, and/or reduce `mm_counts`.",  # noqa: E501
                seq_len,
                total_len,
                str(self._get_mm_num_tokens(mm_inputs)),
            )

        return DummyEncoderData(encoder_prompt_token_ids)

    def get_decoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyDecoderData:
        """
        Build dummy decoder data for profiling.

        Args:
            seq_len: Sequence length used for profiling.
            mm_counts: Item counts to build dummy data for; defaults to
                `get_mm_limits()` when `None`.

        Returns:
            A `DummyDecoderData` with the prompt token ids (zero-padded up
            to `seq_len` when shorter), the multi-modal kwargs, and the
            multi-modal placeholders taken from the processed dummy inputs.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)

        # V0 does not support chunked prefill.
        if total_len > seq_len and not envs.VLLM_USE_V1:
            # `max_num_batched_tokens` is defined by `SchedulerConfig`
            logger.warning_once(
                "The sequence length used for profiling (max_num_batched_tokens / max_num_seqs = %d) "  # noqa: E501
                "is too short to hold the multi-modal embeddings in the worst case (%d tokens in total, out of which %s are reserved for multi-modal embeddings). "  # noqa: E501
                "This may cause certain multi-modal inputs to fail during inference, even when the input text is short. "  # noqa: E501
                "To avoid this, you should increase `max_model_len`, reduce `max_num_seqs`, and/or reduce `mm_counts`.",  # noqa: E501
                seq_len,
                total_len,
                str(self._get_mm_num_tokens(mm_inputs)),
            )

        if total_len < seq_len:
            # The list held by `mm_inputs` is extended in place.
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return DummyDecoderData(
            prompt_token_ids=prompt_token_ids,
            multi_modal_data=mm_inputs["mm_kwargs"],
            multi_modal_placeholders=mm_inputs["mm_placeholders"],
        )

    def get_mm_max_tokens(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> Mapping[str, int]:
        """
        Return the maximum number of multi-modal tokens per modality.

        If the processing info supplies per-item maxima through
        `get_mm_max_tokens_per_item`, those are returned directly (after
        warning when their weighted total exceeds `seq_len`). Otherwise the
        dummy inputs are processed and the placeholder embeddings counted.

        Args:
            seq_len: Sequence length used for profiling.
            mm_counts: Item counts per modality; defaults to
                `get_mm_limits()` when `None`.

        Returns:
            Mapping from modality to token count.
        """
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        max_tokens_per_item = self.processing_info.get_mm_max_tokens_per_item(
            seq_len=seq_len,
            mm_counts=mm_counts,
        )
        if max_tokens_per_item is not None:
            # `mm_counts` has already been defaulted above, so the total is
            # always weighted by the counts of the modalities present in
            # both mappings.
            total_mm_tokens = sum(max_tokens_per_item[k] * mm_counts[k]
                                  for k in max_tokens_per_item.keys()
                                  & mm_counts.keys())
            if total_mm_tokens > seq_len:
                logger.warning_once(
                    "The sequence length (%d) is smaller than the pre-defined"
                    " worst-case total number of multimodal tokens (%d). "
                    "This may cause certain multi-modal inputs to fail during "
                    "inference. To avoid this, you should increase "
                    "`max_model_len` or reduce `mm_counts`.",
                    seq_len,
                    total_mm_tokens,
                )
            return max_tokens_per_item

        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
        return self._get_mm_num_tokens(mm_inputs)