Unverified commit 609ef61f, authored by Isotr0py and committed by GitHub
Browse files

[Bugfix] Fix profiling OOM and decouple encoder multimodal profiling (#14361)


Signed-off-by: Isotr0py <2037008807@qq.com>
parent db84f5eb
...@@ -873,7 +873,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): ...@@ -873,7 +873,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
exc_ctx = pytest.raises(ValueError, match="this model only supports") exc_ctx = pytest.raises(ValueError, match="this model only supports")
with exc_ctx: with exc_ctx:
profiler.get_dummy_data(model_config.max_model_len) profiler.get_decoder_dummy_data(model_config.max_model_len)
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
......
...@@ -335,8 +335,10 @@ class InputRegistry: ...@@ -335,8 +335,10 @@ class InputRegistry:
tokenizer, tokenizer,
disable_cache=True) disable_cache=True)
profiler = MultiModalProfiler(processor) profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_dummy_data( dummy_data_factory = (profiler.get_encoder_dummy_data
seq_len, is_encoder_data=is_encoder_data) if is_encoder_data else
profiler.get_decoder_dummy_data)
dummy_data = dummy_data_factory(seq_len)
else: else:
model_cls, _ = get_model_architecture(model_config) model_cls, _ = get_model_architecture(model_config)
if is_encoder_data: if is_encoder_data:
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Generic, TypeVar from typing import Generic, TypeVar, cast
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
...@@ -13,7 +13,8 @@ import vllm.envs as envs ...@@ -13,7 +13,8 @@ import vllm.envs as envs
from vllm.inputs import DummyData from vllm.inputs import DummyData
from vllm.logger import init_logger from vllm.logger import init_logger
from .inputs import MultiModalDataDict, MultiModalInputs from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
from .processing import BaseMultiModalProcessor, BaseProcessingInfo from .processing import BaseMultiModalProcessor, BaseProcessingInfo
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -142,14 +143,10 @@ class MultiModalProfiler(Generic[_I]): ...@@ -142,14 +143,10 @@ class MultiModalProfiler(Generic[_I]):
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
) )
def get_dummy_data( def get_and_validate_mm_inputs(
self, self,
seq_len: int, seq_len: int,
is_encoder_data: bool = False, ) -> tuple[MultiModalInputs, Mapping[str, int]]:
) -> DummyData:
# Avoid circular import
from vllm.sequence import SequenceData
mm_counts = self.get_mm_limits() mm_counts = self.get_mm_limits()
info = self.processing_info info = self.processing_info
...@@ -165,11 +162,6 @@ class MultiModalProfiler(Generic[_I]): ...@@ -165,11 +162,6 @@ class MultiModalProfiler(Generic[_I]):
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
placeholders_by_modality = mm_inputs["mm_placeholders"] placeholders_by_modality = mm_inputs["mm_placeholders"]
# For encoder-decoder models, use encoder prompt token ids instead of
# decoder prompt to construct dummy seq_data for encoder profiling.
prompt_token_ids = (
mm_inputs["prompt_token_ids"] if not is_encoder_data else
mm_inputs["encoder_prompt_token_ids"]) # type: ignore
total_placeholders_by_modality = { total_placeholders_by_modality = {
modality: sum(item["length"] for item in placeholders) modality: sum(item["length"] for item in placeholders)
...@@ -185,28 +177,60 @@ class MultiModalProfiler(Generic[_I]): ...@@ -185,28 +177,60 @@ class MultiModalProfiler(Generic[_I]):
f"{total_placeholders_by_modality} placeholder tokens, which " f"{total_placeholders_by_modality} placeholder tokens, which "
f"is not the expected {expected_placeholders_by_modality} " f"is not the expected {expected_placeholders_by_modality} "
"tokens.") "tokens.")
return mm_inputs, total_placeholders_by_modality
def get_encoder_dummy_data(
self,
seq_len: int,
) -> DummyData:
# Avoid circular import
from vllm.sequence import SequenceData
mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len)
mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
# For encoder-decoder models, use encoder prompt token ids instead of
# decoder prompt to construct dummy seq_data for encoder profiling.
encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
total_len = len(encoder_prompt_token_ids)
num_tokens_to_pad = max(total_len, seq_len) - total_len
encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
return DummyData(
seq_data=SequenceData.from_seqs(encoder_prompt_token_ids),
multi_modal_data=None,
multi_modal_placeholders=None,
)
def get_decoder_dummy_data(
self,
seq_len: int,
) -> DummyData:
# Avoid circular import
from vllm.sequence import SequenceData
(mm_inputs, total_placeholders_by_modality
) = self.get_and_validate_mm_inputs(seq_len)
prompt_token_ids = mm_inputs["prompt_token_ids"]
total_len = len(prompt_token_ids) total_len = len(prompt_token_ids)
# V0 does not support chunked prefill. # V0 does not support chunked prefill.
if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data: if total_len > seq_len and not envs.VLLM_USE_V1:
if total_len > seq_len and not is_encoder_data: logger.warning(
logger.warning( "The context length (%d) of the model is too short "
"The context length (%d) of the model is too short " "to hold the multi-modal embeddings in the worst case "
"to hold the multi-modal embeddings in the worst case " "(%d tokens in total, out of which %s are reserved for "
"(%d tokens in total, out of which %s are reserved for " "multi-modal embeddings). This may cause certain "
"multi-modal embeddings). This may cause certain " "multi-modal inputs to fail during inference, even when "
"multi-modal inputs to fail during inference, even when " "the input text is short. To avoid this, you should "
"the input text is short. To avoid this, you should " "increase `max_model_len`, reduce `max_num_seqs`, "
"increase `max_model_len`, reduce `max_num_seqs`, " "and/or reduce `mm_counts`.", seq_len, total_len,
"and/or reduce `mm_counts`.", seq_len, total_len, total_placeholders_by_modality)
total_placeholders_by_modality)
num_tokens_to_pad = max(total_len, seq_len) - total_len
prompt_token_ids.extend([0] * num_tokens_to_pad)
return DummyData( return DummyData(
seq_data=SequenceData.from_seqs(prompt_token_ids), seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
multi_modal_data=None, multi_modal_data=None,
multi_modal_placeholders=None, multi_modal_placeholders=None,
) )
...@@ -216,5 +240,5 @@ class MultiModalProfiler(Generic[_I]): ...@@ -216,5 +240,5 @@ class MultiModalProfiler(Generic[_I]):
return DummyData( return DummyData(
seq_data=SequenceData.from_seqs(prompt_token_ids), seq_data=SequenceData.from_seqs(prompt_token_ids),
multi_modal_data=mm_inputs["mm_kwargs"], multi_modal_data=mm_inputs["mm_kwargs"],
multi_modal_placeholders=placeholders_by_modality, multi_modal_placeholders=mm_inputs["mm_placeholders"],
) )
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment