# SPDX-License-Identifier: Apache-2.0

3
from abc import ABC
4
5
from collections.abc import Mapping
from dataclasses import dataclass, field
6
from typing import Generic, NamedTuple, Optional, TypeVar, cast
7
8
9
10
11

import numpy as np
import numpy.typing as npt
from PIL import Image

12
import vllm.envs as envs
13
14
from vllm.logger import init_logger

15
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
16
17
                     MultiModalInputs, MultiModalKwargs,
                     MultiModalPlaceholderDict)
18
19
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                         EncDecMultiModalProcessor)
20
21
22
23
24
25

# Module-level logger; used below to emit one-time profiling warnings.
logger = init_logger(__name__)


@dataclass
class ProcessorInputs:
    """
    Represents the keyword arguments to
    :meth:`vllm.multimodal.processing.BaseMultiModalProcessor.apply`.
    """

    # Prompt text; forwarded to the processor as ``prompt``.
    prompt_text: str
    # Multi-modal items; forwarded to the processor as ``mm_data``.
    mm_data: MultiModalDataDict
    # Extra keyword arguments forwarded as ``hf_processor_mm_kwargs``.
    # Each instance gets its own empty dict by default.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)


35
36
37
38
39
40
41
42
43
44
45
46
47
48
class DummyEncoderData(NamedTuple):
    """Encoder-side dummy input used when profiling a model."""

    # Token ids of the dummy encoder prompt.
    prompt_token_ids: list[int]


class DummyDecoderData(NamedTuple):
    """Decoder-side dummy input used when profiling a model."""

    # Token ids of the dummy decoder prompt.
    prompt_token_ids: list[int]
    # Processed multi-modal keyword arguments that accompany the prompt.
    multi_modal_data: MultiModalKwargs
    # Placeholder information for the multi-modal items, keyed by modality.
    multi_modal_placeholders: MultiModalPlaceholderDict


49
50
51
52
# Type variable for the concrete processing-info class; it parameterizes the
# generic classes defined below.
_I = TypeVar("_I", bound=BaseProcessingInfo)


class BaseDummyInputsBuilder(ABC, Generic[_I]):
    """
    Abstract base class that constructs the dummy data to profile
    multi-modal models.

    Subclasses are expected to implement :meth:`get_dummy_text` and
    :meth:`get_dummy_mm_data`. During the transition period, a subclass that
    only overrides the legacy :meth:`get_dummy_processor_inputs` still gets
    a working :meth:`get_dummy_text` (with a one-time warning).
    """

    def __init__(self, info: _I) -> None:
        super().__init__()

        # Processing info that the builder consults (e.g. ``info.ctx``).
        self.info = info

    # TODO: @abstractmethod after transition
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """
        Build the text input corresponding to :code:`mm_counts`.

        The default implementation is a transition shim: when the subclass
        overrides :meth:`get_dummy_processor_inputs`, the prompt text is
        taken from that method (called with the model's ``max_model_len``
        as the sequence length) and a one-time warning is logged.

        Raises:
            NotImplementedError: If the subclass overrides neither this
                method nor :meth:`get_dummy_processor_inputs`.
        """
        legacy_impl = type(self).get_dummy_processor_inputs
        if legacy_impl is BaseDummyInputsBuilder.get_dummy_processor_inputs:
            # Nothing to fall back to: the base implementation of
            # `get_dummy_processor_inputs` would call back into this method.
            raise NotImplementedError

        logger.warning_once("`get_dummy_processor_inputs` has been split up "
                            "into `get_dummy_text` and `get_dummy_mm_data`. "
                            "These two methods will be marked as abstract "
                            "in an upcoming release.")

        seq_len = self.info.ctx.model_config.max_model_len
        return self.get_dummy_processor_inputs(seq_len, mm_counts).prompt_text

    # TODO: @abstractmethod after transition
    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """
        Build the multimodal input which, after processing, results in
        the maximum possible number of placeholder tokens.

        Raises:
            NotImplementedError: Always, in this base class; subclasses
                must provide the implementation.
        """
        raise NotImplementedError

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """
        Build the input which, after processing, results in
        the maximum possible number of placeholder tokens.

        The default implementation combines :meth:`get_dummy_text` and
        :meth:`get_dummy_mm_data`; ``hf_processor_mm_kwargs`` is left at
        its default (empty) value.
        """
        dummy_text = self.get_dummy_text(mm_counts)
        dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)

        return ProcessorInputs(prompt_text=dummy_text, mm_data=dummy_mm_data)

    def _get_dummy_audios(
        self,
        *,
        length: int,
        num_audios: int,
    ) -> list[npt.NDArray]:
        """
        Return ``num_audios`` dummy audio clips, each a 1-D all-zero array
        with ``length`` samples.

        Every list entry refers to the same array object.
        """
        if num_audios == 0:
            return []

        audio = np.zeros((length, ))
        return [audio] * num_audios

    def _get_dummy_images(
        self,
        *,
        width: int,
        height: int,
        num_images: int,
    ) -> list[Image.Image]:
        """
        Return ``num_images`` dummy RGB images of size ``width`` x ``height``.

        Every list entry refers to the same image object.
        """
        if num_images == 0:
            return []

        image = Image.new("RGB", (width, height), color=255)
        return [image] * num_images

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[npt.NDArray]:
        """
        Return ``num_videos`` dummy videos, each an array of shape
        ``(num_frames, width, height, 3)`` filled with 255.

        Every list entry refers to the same array object.
        """
        if num_videos == 0:
            return []

        # Use uint8, the conventional dtype for 8-bit RGB frames. Without an
        # explicit dtype NumPy picks its default integer type (typically
        # int64), making each dummy video eight times larger than needed.
        # NOTE(review): the axis order is (frames, width, height, channels);
        # confirm whether consumers expect height before width.
        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
        return [video] * num_videos

142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163

class MultiModalProfiler(Generic[_I]):
    """
    Contains code for running memory profiling for multi-modal models.
    """

    def __init__(
        self,
        processor: BaseMultiModalProcessor[_I],
    ) -> None:
        super().__init__()

        # Processor used to turn the dummy inputs into model inputs.
        self.processor = processor

    @property
    def processing_info(self) -> BaseProcessingInfo:
        """The ``info`` object of the wrapped processor."""
        return self.processor.info

    @property
    def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
        """The dummy inputs builder of the wrapped processor."""
        return self.processor.dummy_inputs

    def get_mm_limits(self) -> Mapping[str, int]:
        """Return the allowed multi-modal limits reported by the info."""
        return self.processing_info.get_allowed_mm_limits()

    def _get_dummy_mm_inputs(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> MultiModalInputs:
        """
        Build the dummy processor inputs and run them through the processor.

        Args:
            seq_len: Sequence length handed to the dummy inputs builder.
            mm_counts: Number of items per modality; defaults to
                :meth:`get_mm_limits` when ``None``.
        """
        if mm_counts is None:
            mm_counts = self.get_mm_limits()

        factory = self.dummy_inputs
        processor_inputs = factory.get_dummy_processor_inputs(
            seq_len, mm_counts)

        return self.processor.apply(
            prompt=processor_inputs.prompt_text,
            mm_data=processor_inputs.mm_data,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
        )

    def _get_mm_num_tokens(
        self,
        mm_inputs: MultiModalInputs,
    ) -> Mapping[str, int]:
        """
        Sum, per modality, the number of embeddings reported by the
        placeholders in ``mm_inputs["mm_placeholders"]``.
        """
        placeholders_by_modality = mm_inputs["mm_placeholders"]

        return {
            modality: sum(item.get_num_embeds() for item in placeholders)
            for modality, placeholders in placeholders_by_modality.items()
        }

    def _warn_seq_len_too_short(
        self,
        subject: str,
        seq_len: int,
        total_len: int,
        mm_inputs: MultiModalInputs,
    ) -> None:
        """
        Log (once) that the profiling sequence length cannot hold the
        worst-case multi-modal prompt.

        ``subject`` is the leading noun phrase of the message, e.g.
        ``"The sequence length"``; the rest of the text is shared between
        the encoder and decoder code paths.
        """
        # `max_num_batched_tokens` is defined by `SchedulerConfig`
        logger.warning_once(
            subject +
            " used for profiling (max_num_batched_tokens / max_num_seqs = %d) "  # noqa: E501
            "is too short to hold the multi-modal embeddings in the worst case (%d tokens in total, out of which %s are reserved for multi-modal embeddings). "  # noqa: E501
            "This may cause certain multi-modal inputs to fail during inference, even when the input text is short. "  # noqa: E501
            "To avoid this, you should increase `max_model_len`, reduce `max_num_seqs`, and/or reduce `mm_counts`.",  # noqa: E501
            seq_len,
            total_len,
            str(self._get_mm_num_tokens(mm_inputs)),
        )

    def get_encoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyEncoderData:
        """
        Build the dummy encoder prompt used for profiling.

        When the processor sets ``pad_dummy_encoder_prompt``, the encoder
        token ids are zero-padded up to ``seq_len``. Otherwise, a one-time
        warning is logged if the prompt is longer than ``seq_len`` and V1
        is not in use.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)

        # For encoder-decoder models, use encoder prompt token ids instead of
        # decoder prompt to construct dummy seq_data for encoder profiling.
        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]

        total_len = len(encoder_prompt_token_ids)

        processor = cast(EncDecMultiModalProcessor, self.processor)
        if processor.pad_dummy_encoder_prompt:
            # Pad only when the prompt is shorter than `seq_len`.
            num_tokens_to_pad = max(seq_len - total_len, 0)
            encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
        # NOTE: Whisper allows total_len > seq_len.
        elif total_len > seq_len and not envs.VLLM_USE_V1:
            self._warn_seq_len_too_short(
                "The encoder sequence length",
                seq_len,
                total_len,
                mm_inputs,
            )

        return DummyEncoderData(encoder_prompt_token_ids)

    def get_decoder_dummy_data(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyDecoderData:
        """
        Build the dummy decoder prompt and multi-modal data for profiling.

        The prompt token ids are zero-padded up to ``seq_len`` when shorter.
        A one-time warning is logged if the prompt is longer than
        ``seq_len`` and V1 is not in use.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)

        # V0 does not support chunked prefill.
        if total_len > seq_len and not envs.VLLM_USE_V1:
            self._warn_seq_len_too_short(
                "The sequence length",
                seq_len,
                total_len,
                mm_inputs,
            )

        if total_len < seq_len:
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return DummyDecoderData(
            prompt_token_ids=prompt_token_ids,
            multi_modal_data=mm_inputs["mm_kwargs"],
            multi_modal_placeholders=mm_inputs["mm_placeholders"],
        )

    def get_mm_max_tokens(
        self,
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> Mapping[str, int]:
        """
        Return the per-modality embedding counts obtained from the dummy
        inputs built for ``seq_len`` and ``mm_counts``.
        """
        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)

        return self._get_mm_num_tokens(mm_inputs)