Unverified Commit 609ef61f authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Bugfix] Fix profiling OOM and decouple encoder multimodal profiling (#14361)


Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
parent db84f5eb
......@@ -873,7 +873,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
exc_ctx = pytest.raises(ValueError, match="this model only supports")
with exc_ctx:
profiler.get_dummy_data(model_config.max_model_len)
profiler.get_decoder_dummy_data(model_config.max_model_len)
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
......
......@@ -335,8 +335,10 @@ class InputRegistry:
tokenizer,
disable_cache=True)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_dummy_data(
seq_len, is_encoder_data=is_encoder_data)
dummy_data_factory = (profiler.get_encoder_dummy_data
if is_encoder_data else
profiler.get_decoder_dummy_data)
dummy_data = dummy_data_factory(seq_len)
else:
model_cls, _ = get_model_architecture(model_config)
if is_encoder_data:
......
......@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Generic, TypeVar
from typing import Generic, TypeVar, cast
import numpy as np
import numpy.typing as npt
......@@ -13,7 +13,8 @@ import vllm.envs as envs
from vllm.inputs import DummyData
from vllm.logger import init_logger
from .inputs import MultiModalDataDict, MultiModalInputs
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
from .processing import BaseMultiModalProcessor, BaseProcessingInfo
logger = init_logger(__name__)
......@@ -142,14 +143,10 @@ class MultiModalProfiler(Generic[_I]):
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
def get_dummy_data(
def get_and_validate_mm_inputs(
self,
seq_len: int,
is_encoder_data: bool = False,
) -> DummyData:
# Avoid circular import
from vllm.sequence import SequenceData
) -> tuple[MultiModalInputs, Mapping[str, int]]:
mm_counts = self.get_mm_limits()
info = self.processing_info
......@@ -165,11 +162,6 @@ class MultiModalProfiler(Generic[_I]):
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
placeholders_by_modality = mm_inputs["mm_placeholders"]
# For encoder-decoder models, use encoder prompt token ids instead of
# decoder prompt to construct dummy seq_data for encoder profiling.
prompt_token_ids = (
mm_inputs["prompt_token_ids"] if not is_encoder_data else
mm_inputs["encoder_prompt_token_ids"]) # type: ignore
total_placeholders_by_modality = {
modality: sum(item["length"] for item in placeholders)
......@@ -185,12 +177,47 @@ class MultiModalProfiler(Generic[_I]):
f"{total_placeholders_by_modality} placeholder tokens, which "
f"is not the expected {expected_placeholders_by_modality} "
"tokens.")
return mm_inputs, total_placeholders_by_modality
def get_encoder_dummy_data(
self,
seq_len: int,
) -> DummyData:
# Avoid circular import
from vllm.sequence import SequenceData
mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len)
mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
# For encoder-decoder models, use encoder prompt token ids instead of
# decoder prompt to construct dummy seq_data for encoder profiling.
encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
total_len = len(encoder_prompt_token_ids)
num_tokens_to_pad = max(total_len, seq_len) - total_len
encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
return DummyData(
seq_data=SequenceData.from_seqs(encoder_prompt_token_ids),
multi_modal_data=None,
multi_modal_placeholders=None,
)
def get_decoder_dummy_data(
self,
seq_len: int,
) -> DummyData:
# Avoid circular import
from vllm.sequence import SequenceData
(mm_inputs, total_placeholders_by_modality
) = self.get_and_validate_mm_inputs(seq_len)
prompt_token_ids = mm_inputs["prompt_token_ids"]
total_len = len(prompt_token_ids)
# V0 does not support chunked prefill.
if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
if total_len > seq_len and not is_encoder_data:
if total_len > seq_len and not envs.VLLM_USE_V1:
logger.warning(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
......@@ -202,11 +229,8 @@ class MultiModalProfiler(Generic[_I]):
"and/or reduce `mm_counts`.", seq_len, total_len,
total_placeholders_by_modality)
num_tokens_to_pad = max(total_len, seq_len) - total_len
prompt_token_ids.extend([0] * num_tokens_to_pad)
return DummyData(
seq_data=SequenceData.from_seqs(prompt_token_ids),
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
multi_modal_data=None,
multi_modal_placeholders=None,
)
......@@ -216,5 +240,5 @@ class MultiModalProfiler(Generic[_I]):
return DummyData(
seq_data=SequenceData.from_seqs(prompt_token_ids),
multi_modal_data=mm_inputs["mm_kwargs"],
multi_modal_placeholders=placeholders_by_modality,
multi_modal_placeholders=mm_inputs["mm_placeholders"],
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment