Commit c1cacde6 authored by weishb

vllm-omni_0.15.0.rc1+fix1 first commit

parent 35607782
from vllm_omni.inputs.preprocess import OmniInputPreprocessor
def _make_preprocessor(monkeypatch):
preprocessor = object.__new__(OmniInputPreprocessor)
monkeypatch.setattr(preprocessor, "_truncate_inputs", lambda tokens, tokenization_kwargs=None: tokens)
monkeypatch.setattr(
preprocessor,
"_process_multimodal",
lambda *args, **kwargs: {"prompt_token_ids": [1, 2, 3]},
)
monkeypatch.setattr(preprocessor, "_tokenize_prompt", lambda prompt_text, tokenization_kwargs=None: [9, 8, 7])
return preprocessor
def test_process_tokens_keeps_additional_information(monkeypatch):
preprocessor = _make_preprocessor(monkeypatch)
parsed = {
"prompt_token_ids": [1, 2, 3],
"prompt_embeds": "embeds",
"additional_information": {"task": ["tts"], "lang": ["auto"]},
}
inputs = OmniInputPreprocessor._process_tokens(preprocessor, parsed)
assert inputs["prompt_token_ids"] == [1, 2, 3]
assert inputs["prompt_embeds"] == "embeds"
assert inputs["additional_information"] == {"task": ["tts"], "lang": ["auto"]}
def test_process_text_keeps_additional_information(monkeypatch):
preprocessor = _make_preprocessor(monkeypatch)
parsed = {
"prompt": "hello",
"prompt_embeds": "embeds",
"additional_information": {"speaker": ["alice"]},
}
inputs = OmniInputPreprocessor._process_text(preprocessor, parsed)
assert inputs["prompt_token_ids"] == [9, 8, 7]
assert inputs["prompt_embeds"] == "embeds"
assert inputs["additional_information"] == {"speaker": ["alice"]}
def test_process_text_multimodal_skips_empty_payloads(monkeypatch):
preprocessor = _make_preprocessor(monkeypatch)
parsed = {
"prompt": "hello",
"multi_modal_data": {"image": "fake"},
"prompt_embeds": None,
"additional_information": None,
}
inputs = OmniInputPreprocessor._process_text(preprocessor, parsed)
assert inputs["prompt_token_ids"] == [1, 2, 3]
assert "prompt_embeds" not in inputs
assert "additional_information" not in inputs
import uuid
import warnings
from queue import Empty, Queue
from typing import Any
from unittest.mock import MagicMock
import pytest
from vllm import SamplingParams
from vllm_omni.entrypoints.stage_utils import SHUTDOWN_TASK
# Suppress noisy DeprecationWarnings from optional Swig bindings imported by vLLM dependencies.
warnings.filterwarnings(
"ignore",
message=r"builtin type SwigPy.*has no __module__ attribute",
category=DeprecationWarning,
)
class _FakeEngineArgs(dict):
"""Fake engine args that can be used both as object attributes and as **kwargs."""
def __init__(self, args_dict: dict[str, Any]):
super().__init__(args_dict)
# Add required attributes if not present
if "model_stage" not in self:
self["model_stage"] = None
if "engine_output_type" not in self:
self["engine_output_type"] = None
# Also set as attributes for object-style access
for key, value in self.items():
setattr(self, key, value)
class _FakeStageConfig:
"""Fake stage config object that mimics the real stage config structure."""
def __init__(self, config_dict: dict[str, Any]):
# engine_args needs to work both as object (for OmniStage) and as dict (for **kwargs)
engine_args_dict = config_dict.get("engine_args", {})
self.engine_args = _FakeEngineArgs(engine_args_dict)
self.final_output = config_dict.get("final_output", False)
self.final_output_type = config_dict.get("final_output_type", None)
self.stage_id = config_dict.get("stage_id", 0)
# Store original dict for reference
self._config_dict = config_dict
class _FakeQueue:
"""Fake queue using standard library Queue to replace mp.Queue."""
def __init__(self, maxsize=0):
self._queue = Queue(maxsize=maxsize)
def put(self, item):
self._queue.put(item)
def put_nowait(self, item):
self._queue.put_nowait(item)
def get(self):
return self._queue.get()
def get_nowait(self):
return self._queue.get_nowait()
def empty(self):
return self._queue.empty()
class _FakeStage:
"""Lightweight Stage stub for multi-process pipeline version with queue support."""
def __init__(self, config, stage_init_timeout: int = 300):
# Handle both dict and object configs
if isinstance(config, dict):
config = _FakeStageConfig(config)
self.config = config
self.stage_config = config
self.engine = None
self.engine_outputs = None
# Set attributes that OmniStage expects
self.stage_id = getattr(config, "stage_id", 0)
self.engine_args = config.engine_args
self.model_stage = getattr(config.engine_args, "model_stage", None)
self.stage_type = "llm"
# set default sampling params
self.default_sampling_params = SamplingParams(temperature=1.0)
# Allow configuring final_output and final_output_type
self.final_output = config.final_output if hasattr(config, "final_output") else False
self.final_output_type = getattr(config, "final_output_type", None)
# Configurable processing logic, default returns placeholder
processed_input = getattr(config, "_config_dict", {}).get("processed_input", ["processed"])
self._processed_input = processed_input
# Queue references (set by attach_queues)
self._in_q = None
self._out_q = None
self._proc = None # Mock process reference
self._stage_init_timeout = max(0, int(stage_init_timeout))
def attach_queues(self, in_q, out_q):
"""Attach input and output queues."""
self._in_q = in_q
self._out_q = out_q
def init_stage_worker(
self,
model: str,
*,
is_async: bool = False,
shm_threshold_bytes: int = 65536,
ctx=None,
batch_timeout: int = 10,
**kwargs,
):
"""Mock init_stage_worker: don't start real process, just send stage_ready message."""
# Create a mock process object
self._proc = MagicMock()
self._proc.start = MagicMock()
self._proc.join = MagicMock()
self._proc.is_alive = MagicMock(return_value=False)
self._proc.terminate = MagicMock()
# Send stage_ready message to output queue
if self._out_q is not None:
try:
self._out_q.put_nowait({"type": "stage_ready", "stage_id": self.stage_id})
except Exception:
pass
def stop_stage_worker(self):
"""Mock stop_stage_worker: clean up queue references."""
if self._in_q is not None:
try:
self._in_q.put_nowait(SHUTDOWN_TASK)
except Exception:
pass
def submit(self, payload: dict[str, Any]):
"""Submit task to input queue."""
if self._in_q is not None:
self._in_q.put(payload)
def try_collect(self) -> Any:
"""Non-blocking collect from output queue."""
if self._out_q is None:
return None
try:
return self._out_q.get_nowait()
except Empty:
return None
def set_engine_outputs(self, outputs):
"""Set engine outputs for the stage."""
self.engine_outputs = outputs
def process_engine_inputs(self, stage_list, prompts):
"""Process engine inputs: return preset processed result."""
return self._processed_input
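# Illustrative sketch (not collected by pytest, name does not start with test_):
# how a _FakeStage is wired up in these tests. The payload keys below are
# assumptions that mirror the stub's submit/try_collect helpers.
def _example_fake_stage_roundtrip():
    stage = _FakeStage({"final_output": True, "final_output_type": "text"})
    in_q, out_q = _FakeQueue(), _FakeQueue()
    stage.attach_queues(in_q, out_q)
    stage.init_stage_worker(model="any")  # enqueues the stage_ready message
    ready = stage.try_collect()  # -> {"type": "stage_ready", "stage_id": 0}
    stage.submit({"request_id": "r0", "prompt": "hi"})  # lands on in_q for the worker
    return ready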
class _FakeEngine:
"""Lightweight Engine stub: provides generate iterator output."""
def __init__(self, outputs: list[Any]):
self._outputs = outputs
def generate(self, prompts, sampling_params):
# Record the most recent prompts for outer assertions
self._last_prompts = prompts
# Simplified: return preset list at once, ensuring iterability
yield from self._outputs
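# Sketch of how the engine stub is consumed (illustrative only):
#   engine = _FakeEngine([{"text": "a"}, {"text": "b"}])
#   list(engine.generate(["hi"], sampling_params=None))
#   -> [{"text": "a"}, {"text": "b"}], with engine._last_prompts == ["hi"]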
@pytest.fixture
def fake_stage_config():
return {
# Don't include 'model' in engine_args since it's passed separately
"engine_args": {},
"final_output": True,
"final_output_type": "text",
# Second stage will use processed_input to verify the chain
"processed_input": ["processed-by-stage"],
}
def _setup_engine_mocks(monkeypatch):
"""Helper function to set up common engine mocks."""
fake_engine = MagicMock()
# Add necessary attributes to fake_engine
fake_engine.tokenizer = MagicMock()
fake_engine.log_stats = False
fake_engine.vllm_config = MagicMock()
fake_engine.vllm_config.model_config = MagicMock()
fake_engine.vllm_config.model_config.io_processor_plugin = None
fake_engine.get_supported_tasks = MagicMock(return_value=[])
fake_engine.model_config = MagicMock()
fake_engine.model_config.io_processor_plugin = None
# Add registry with resolve_model_cls method
fake_registry = MagicMock()
fake_registry.resolve_model_cls = MagicMock(return_value=(MagicMock(), "test_arch"))
fake_engine.model_config.registry = fake_registry
fake_engine.vllm_config.model_config.registry = fake_registry
monkeypatch.setattr(
"vllm.v1.engine.llm_engine.LLMEngine.from_engine_args",
lambda **kw: fake_engine,
raising=False,
)
# Mock model_config.registry.resolve_model_cls to return a tuple
# Use a real class instead of MagicMock to avoid inspect.getsource issues
class FakeModelClass:
pass
monkeypatch.setattr(
"vllm.model_executor.model_loader.utils.get_model_architecture",
lambda model_config: (FakeModelClass, "test_arch"),
raising=False,
)
monkeypatch.setattr(
"vllm.model_executor.model_loader.utils._get_model_architecture",
lambda model_config: (FakeModelClass, "test_arch"),
raising=False,
)
# Mock try_create_mm_pooling_model_cls to return the class as-is
monkeypatch.setattr(
"vllm.model_executor.models.adapters.try_create_mm_pooling_model_cls",
lambda model_cls: model_cls,
raising=False,
)
# Mock _enable_processor_cache to return False
monkeypatch.setattr(
"vllm.multimodal.cache._enable_processor_cache",
lambda model_config, mm_registry: False,
raising=False,
)
# Mock get_io_processor to return None
monkeypatch.setattr(
"vllm.plugins.io_processors.get_io_processor",
lambda vllm_config, io_processor_plugin: None,
raising=False,
)
def _setup_multiprocessing_mocks(monkeypatch):
"""Helper function to set up multiprocessing mocks."""
import multiprocessing as mp
# Mock Process
fake_process_class = MagicMock()
fake_process_instance = MagicMock()
fake_process_instance.start = MagicMock()
fake_process_instance.join = MagicMock()
fake_process_instance.is_alive = MagicMock(return_value=False)
fake_process_instance.terminate = MagicMock()
fake_process_class.return_value = fake_process_instance
# Mock get_context to return a context with Queue that returns _FakeQueue
fake_ctx = MagicMock()
fake_ctx.Queue = lambda maxsize=0: _FakeQueue(maxsize=maxsize)
fake_ctx.Process = fake_process_class
def _mock_get_context(method):
return fake_ctx
monkeypatch.setattr(mp, "get_context", _mock_get_context, raising=False)
monkeypatch.setattr(mp, "Process", fake_process_class, raising=False)
def _setup_ipc_mocks(monkeypatch):
"""Helper function to set up IPC function mocks."""
# Mock _encode: simple serialization
def _fake_encode(obj, threshold, obj_key, shm_key):
return {obj_key: obj}
# Mock _load: extract object from result
def _fake_load(result, obj_key, shm_key):
return result.get(obj_key)
# Mock _set: return a byte payload whose length stands in for the serialized size
def _fake_set(obj):
return str(obj).encode()
monkeypatch.setattr("vllm_omni.entrypoints.omni._encode", _fake_encode, raising=False)
monkeypatch.setattr("vllm_omni.entrypoints.omni._load", _fake_load, raising=False)
monkeypatch.setattr("vllm_omni.entrypoints.omni._set", _fake_set, raising=False)
def _setup_log_mocks(monkeypatch):
"""Helper function to set up logging and stats mocks."""
# Mock OrchestratorMetrics to be a simple class that doesn't require file operations
class _FakeOrchestratorMetrics:
def __init__(self, num_stages, enable_stats, wall_start_ts):
self.num_stages = num_stages
self.enable_stats = enable_stats
self.stage_first_ts = [None] * num_stages
self.stage_last_ts = [None] * num_stages
self.e2e_done = set()
def on_stage_metrics(self, stage_id, req_id, metrics):
pass
def on_finalize_request(self, stage_id, req_id, start_ts):
self.e2e_done.add(req_id)
def on_forward(self, from_stage, to_stage, req_id, size_bytes, tx_ms, use_shm):
pass
def build_and_log_summary(self, final_stage_id):
return "Fake summary"
monkeypatch.setattr(
"vllm_omni.entrypoints.omni.OrchestratorMetrics",
_FakeOrchestratorMetrics,
raising=False,
)
@pytest.fixture(autouse=True)
def mock_get_config(monkeypatch):
"""Auto-mock get_config and related model loading functions to avoid model path validation."""
# CRITICAL: Mock tokenizer-related imports FIRST, before any module imports
# This prevents ImportError when async_omni is imported (which happens via omni_stage)
import sys
fake_tokenizer = MagicMock()
fake_tokenizer.encode = MagicMock(return_value=[1, 2, 3])
fake_tokenizer.decode = MagicMock(return_value="test")
# Mock init_tokenizer_from_configs (used in async_omni)
def _mock_init_tokenizer_from_configs(model_config=None, **kwargs):
return fake_tokenizer
# Strategy 1: Mock in the original location (vllm.transformers_utils.tokenizer)
# This works if the module hasn't been imported yet
monkeypatch.setattr(
"vllm.transformers_utils.tokenizer.init_tokenizer_from_configs",
_mock_init_tokenizer_from_configs,
raising=False,
)
# Strategy 2: If the module is already in sys.modules, patch it directly
tokenizer_module_path = "vllm.transformers_utils.tokenizer"
if tokenizer_module_path in sys.modules:
tokenizer_module = sys.modules[tokenizer_module_path]
setattr(tokenizer_module, "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs)
# CRITICAL: Mock length_from_prompt_token_ids_or_embeds BEFORE trying to mock async_omni
# This is because async_omni imports processor.py, which imports this function at module level
# Mock length_from_prompt_token_ids_or_embeds (used in processor.py)
def _mock_length_from_prompt_token_ids_or_embeds(prompt_token_ids=None, prompt_embeds=None):
# Return a reasonable default length
if prompt_token_ids is not None:
if isinstance(prompt_token_ids, list):
return len(prompt_token_ids)
elif hasattr(prompt_token_ids, "shape"):
return prompt_token_ids.shape[-1] if len(prompt_token_ids.shape) > 0 else 1
if prompt_embeds is not None:
if hasattr(prompt_embeds, "shape"):
return prompt_embeds.shape[-2] if len(prompt_embeds.shape) > 1 else 1
return 10 # Default length
# Mock in vllm.utils
monkeypatch.setattr(
"vllm.utils.length_from_prompt_token_ids_or_embeds",
_mock_length_from_prompt_token_ids_or_embeds,
raising=False,
)
# Also mock in processor module if it's imported
monkeypatch.setattr(
"vllm_omni.engine.input_processor.length_from_prompt_token_ids_or_embeds",
_mock_length_from_prompt_token_ids_or_embeds,
raising=False,
)
# If processor module is already imported, patch it directly
processor_module_path = "vllm_omni.engine.input_processor"
if processor_module_path in sys.modules:
processor_module = sys.modules[processor_module_path]
setattr(
processor_module, "length_from_prompt_token_ids_or_embeds", _mock_length_from_prompt_token_ids_or_embeds
)
# Strategy 3: Now mock async_omni AFTER length_from_prompt_token_ids_or_embeds is mocked
# This prevents ImportError when async_omni imports processor.py
monkeypatch.setattr(
"vllm_omni.entrypoints.async_omni.init_tokenizer_from_configs",
_mock_init_tokenizer_from_configs,
raising=False,
)
# Strategy 4: If async_omni is already imported, patch it directly
async_omni_path = "vllm_omni.entrypoints.async_omni"
if async_omni_path in sys.modules:
async_omni_module = sys.modules[async_omni_path]
setattr(async_omni_module, "init_tokenizer_from_configs", _mock_init_tokenizer_from_configs)
# Now mock get_config and other functions
fake_hf_config = MagicMock()
fake_hf_config.model_type = "qwen2_5_omni"
def _mock_get_config(model, **kwargs):
return fake_hf_config
monkeypatch.setattr(
"vllm.transformers_utils.config.get_config",
_mock_get_config,
raising=False,
)
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.get_config",
_mock_get_config,
raising=False,
)
# Mock transformers' cached_file to avoid downloading model configs
def _mock_cached_file(path_or_repo_id, *args, **kwargs):
import os
import tempfile
fake_config_file = os.path.join(tempfile.gettempdir(), "fake_config.json")
if not os.path.exists(fake_config_file):
with open(fake_config_file, "w") as f:
f.write('{"model_type": "qwen2_5_omni"}')
return fake_config_file
monkeypatch.setattr(
"transformers.utils.hub.cached_file",
_mock_cached_file,
raising=False,
)
monkeypatch.setattr(
"transformers.utils.hub.cached_files",
lambda path_or_repo_id, filenames, **kwargs: (
[_mock_cached_file(path_or_repo_id, filenames[0])] if filenames else None
),
raising=False,
)
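# Hedged sketch of the patching pattern this fixture repeats: patch the symbol
# where it is defined, then re-patch any consumer module that has already
# imported it by name. The helper name and parameters are illustrative only.
def _patch_defined_and_imported(monkeypatch, defining_path, attr_name, replacement, consumer_paths=()):
    import sys

    monkeypatch.setattr(f"{defining_path}.{attr_name}", replacement, raising=False)
    for path in consumer_paths:
        module = sys.modules.get(path)
        if module is not None:
            monkeypatch.setattr(module, attr_name, replacement, raising=False)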
def test_initialize_stage_configs_called_when_none(monkeypatch, fake_stage_config):
"""Test that stage configs are auto-loaded when stage_configs_path is None."""
def _fake_loader(model: str, base_engine_args=None):
return [
_FakeStageConfig(fake_stage_config),
_FakeStageConfig(fake_stage_config),
]
# Remove modules from cache BEFORE setting mocks
import sys
for module_name in [
"vllm_omni.entrypoints.utils",
"vllm_omni.entrypoints.omni",
"vllm_omni.entrypoints.omni_stage",
]:
if module_name in sys.modules:
del sys.modules[module_name]
# Set up mocks
_setup_engine_mocks(monkeypatch)
_setup_multiprocessing_mocks(monkeypatch)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
# Mock load_stage_configs_from_model
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)
# Replace OmniStage
monkeypatch.setattr(
"vllm_omni.entrypoints.omni_stage.OmniStage",
lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
raising=False,
)
# Import the module after mocks are set
import vllm_omni.entrypoints.omni as omni_module
# Patch the imported function and class in the module
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))
from vllm_omni.entrypoints.omni import Omni
omni = Omni(model="any", init_timeout=1)
# Verify: auto-loaded stage_configs and stage_list have consistent count
assert isinstance(omni.stage_configs, list)
assert len(omni.stage_configs) == 2
assert len(omni.stage_list) == 2
# Verify: each Stage is _FakeStage instance
for st in omni.stage_list:
assert isinstance(st, _FakeStage)
# Verify: queues are attached
for st in omni.stage_list:
assert st._in_q is not None
assert st._out_q is not None
# Verify: all stages are ready
assert len(omni._stages_ready) == 2
def test_generate_raises_on_length_mismatch(monkeypatch, fake_stage_config):
"""Test that generate raises ValueError when sampling_params_list length doesn't match."""
def _fake_loader(model: str, base_engine_args=None):
return [_FakeStageConfig(fake_stage_config)]
import sys
for module_name in [
"vllm_omni.entrypoints.utils",
"vllm_omni.entrypoints.omni",
"vllm_omni.entrypoints.omni_stage",
]:
if module_name in sys.modules:
del sys.modules[module_name]
_setup_engine_mocks(monkeypatch)
_setup_multiprocessing_mocks(monkeypatch)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)
monkeypatch.setattr(
"vllm_omni.entrypoints.omni_stage.OmniStage",
lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
raising=False,
)
import vllm_omni.entrypoints.omni as omni_module
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))
from vllm_omni.entrypoints.omni import Omni
omni = Omni(model="any", init_timeout=1)
with pytest.raises(ValueError):
omni.generate(prompts=["hi"], sampling_params_list=[])
def test_generate_pipeline_and_final_outputs(monkeypatch, fake_stage_config):
"""Test multi-stage generation pipeline with queue polling."""
stage_cfg0 = dict(fake_stage_config)
stage_cfg1 = dict(fake_stage_config)
stage_cfg1["processed_input"] = ["processed-for-stage-1"]
def _fake_loader(model: str, base_engine_args=None):
return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)]
import sys
for module_name in [
"vllm_omni.entrypoints.utils",
"vllm_omni.entrypoints.omni",
"vllm_omni.entrypoints.omni_stage",
]:
if module_name in sys.modules:
del sys.modules[module_name]
_setup_engine_mocks(monkeypatch)
_setup_multiprocessing_mocks(monkeypatch)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)
monkeypatch.setattr(
"vllm_omni.entrypoints.omni_stage.OmniStage",
lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
raising=False,
)
import vllm_omni.entrypoints.omni as omni_module
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))
# Mock uuid.uuid4() to return a predictable value for request ID generation
test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid)
monkeypatch.setattr(omni_module, "uuid", uuid)
from vllm_omni.entrypoints.omni import Omni
omni = Omni(model="any", init_timeout=1)
# Generate the expected request ID format: "0_<uuid>"
expected_request_id = f"0_{test_uuid}"
# Simulate worker behavior: manually put results into output queues
# Note: We put results before calling generate, which simulates worker processes
# that have already completed. The polling loop will collect them in stage order.
# Stage 0 output (will be collected first)
omni.stage_list[0]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 0, "text": "s0"}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
# Stage 1 output (will be collected after stage 0 forwards to it)
# Note: In real flow, stage 1 result would appear after stage 0 forwards,
# but for testing we pre-populate it. The polling loop processes stages
# in order, so stage 0 result will be collected first, then forwarded,
# then stage 1 result will be collected.
omni.stage_list[1]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 1, "text": "s1"}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
sampling_params_list = [
SamplingParams(temperature=0.7),
SamplingParams(temperature=0.8),
]
prompts = ["hi"]
outputs = omni.generate(prompts=prompts, sampling_params_list=sampling_params_list)
# Both stages have final_output=True, so should aggregate two OmniRequestOutput
assert len(outputs) == 2
# Verify stage outputs are set
assert omni.stage_list[0].engine_outputs == [{"stage": 0, "text": "s0"}]
assert omni.stage_list[1].engine_outputs == [{"stage": 1, "text": "s1"}]
# Verify stage 0 input queue received the task
assert not omni.stage_list[0]._in_q.empty()
# Verify stage 1 still returns its preset processed input (what stage 0's forward would feed it)
assert omni.stage_list[1].process_engine_inputs([], []) is not None
def test_generate_no_final_output_returns_empty(monkeypatch, fake_stage_config):
"""Test that generate returns empty list when all stages have final_output=False."""
stage_cfg0 = dict(fake_stage_config)
stage_cfg1 = dict(fake_stage_config)
stage_cfg0["final_output"] = False
stage_cfg1["final_output"] = False
def _fake_loader(model: str, base_engine_args=None):
return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)]
import sys
for module_name in [
"vllm_omni.entrypoints.utils",
"vllm_omni.entrypoints.omni",
"vllm_omni.entrypoints.omni_stage",
]:
if module_name in sys.modules:
del sys.modules[module_name]
_setup_engine_mocks(monkeypatch)
_setup_multiprocessing_mocks(monkeypatch)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)
monkeypatch.setattr(
"vllm_omni.entrypoints.omni_stage.OmniStage",
lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
raising=False,
)
import vllm_omni.entrypoints.omni as omni_module
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))
# Mock uuid.uuid4() to return a predictable value for request ID generation
test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid)
monkeypatch.setattr(omni_module, "uuid", uuid)
from vllm_omni.entrypoints.omni import Omni
omni = Omni(model="any", init_timeout=1)
# Generate the expected request ID format: "0_<uuid>"
expected_request_id = f"0_{test_uuid}"
# Simulate worker behavior: put results into output queues
omni.stage_list[0]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 0}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
omni.stage_list[1]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 1}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
outputs = omni.generate(
prompts=["p"],
sampling_params_list=[
SamplingParams(temperature=0.7),
SamplingParams(temperature=0.8),
],
)
assert outputs == []
def test_generate_sampling_params_none_use_default(monkeypatch, fake_stage_config):
"""Test that generate uses default sampling params when sampling_params_list is None."""
stage_cfg0 = dict(fake_stage_config)
stage_cfg1 = dict(fake_stage_config)
stage_cfg0["final_output"] = False
stage_cfg1["final_output"] = False
def _fake_loader(model: str, base_engine_args=None):
return [_FakeStageConfig(stage_cfg0), _FakeStageConfig(stage_cfg1)]
import sys
for module_name in [
"vllm_omni.entrypoints.utils",
"vllm_omni.entrypoints.omni",
"vllm_omni.entrypoints.omni_stage",
]:
if module_name in sys.modules:
del sys.modules[module_name]
_setup_engine_mocks(monkeypatch)
_setup_multiprocessing_mocks(monkeypatch)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)
monkeypatch.setattr(
"vllm_omni.entrypoints.omni_stage.OmniStage",
lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
raising=False,
)
import vllm_omni.entrypoints.omni as omni_module
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))
# Mock uuid.uuid4() to return a predictable value for request ID generation
test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid)
monkeypatch.setattr(omni_module, "uuid", uuid)
from vllm_omni.entrypoints.omni import Omni
omni = Omni(model="any", init_timeout=1)
# Generate the expected request ID format: "0_<uuid>"
expected_request_id = f"0_{test_uuid}"
# Simulate worker behavior: put results into output queues
omni.stage_list[0]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 0}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
omni.stage_list[1]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 1}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
# Use the default sampling params
omni.generate(prompts=["p"], sampling_params_list=None)
def test_wait_for_stages_ready_timeout(monkeypatch, fake_stage_config):
"""Test that _wait_for_stages_ready handles timeout correctly."""
def _fake_loader(model: str, base_engine_args=None):
return [_FakeStageConfig(fake_stage_config)]
import sys
for module_name in [
"vllm_omni.entrypoints.utils",
"vllm_omni.entrypoints.omni",
"vllm_omni.entrypoints.omni_stage",
]:
if module_name in sys.modules:
del sys.modules[module_name]
_setup_engine_mocks(monkeypatch)
_setup_multiprocessing_mocks(monkeypatch)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)
# Create a stage that doesn't send stage_ready message
class _FakeStageNoReady(_FakeStage):
def init_stage_worker(self, *args, **kwargs):
# Don't send stage_ready message
self._proc = MagicMock()
self._proc.start = MagicMock()
self._proc.join = MagicMock()
self._proc.is_alive = MagicMock(return_value=False)
self._proc.terminate = MagicMock()
monkeypatch.setattr(
"vllm_omni.entrypoints.omni_stage.OmniStage",
lambda cfg, **kwargs: _FakeStageNoReady(cfg, **kwargs),
raising=False,
)
import vllm_omni.entrypoints.omni as omni_module
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStageNoReady(cfg, **kwargs))
from vllm_omni.entrypoints.omni import Omni
# Use very short timeout
omni = Omni(model="any", init_timeout=0.01)
# Verify that no stages are ready
assert len(omni._stages_ready) == 0
def test_generate_handles_error_messages(monkeypatch, fake_stage_config):
"""Test that generate handles error messages from stages correctly."""
def _fake_loader(model: str, base_engine_args=None):
return [_FakeStageConfig(fake_stage_config)]
import sys
for module_name in [
"vllm_omni.entrypoints.utils",
"vllm_omni.entrypoints.omni",
"vllm_omni.entrypoints.omni_stage",
]:
if module_name in sys.modules:
del sys.modules[module_name]
_setup_engine_mocks(monkeypatch)
_setup_multiprocessing_mocks(monkeypatch)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)
monkeypatch.setattr(
"vllm_omni.entrypoints.omni_stage.OmniStage",
lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
raising=False,
)
import vllm_omni.entrypoints.omni as omni_module
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))
# Mock uuid.uuid4() to return a predictable value for request ID generation
test_uuid = uuid.UUID("00000000-0000-0000-0000-000000000000")
monkeypatch.setattr(uuid, "uuid4", lambda: test_uuid)
monkeypatch.setattr(omni_module, "uuid", uuid)
from vllm_omni.entrypoints.omni import Omni
omni = Omni(model="any", init_timeout=1)
# Generate the expected request ID format: "0_<uuid>"
expected_request_id = f"0_{test_uuid}"
# Put error message in output queue
omni.stage_list[0]._out_q.put_nowait(
{
"request_id": expected_request_id,
"error": "test error",
}
)
# Also put a valid result after error to allow the loop to complete
# (error handling continues the loop, so we need a valid result to finish)
omni.stage_list[0]._out_q.put_nowait(
{
"request_id": expected_request_id,
"engine_outputs": [{"stage": 0, "text": "result"}],
"metrics": {"num_tokens_out": 1, "stage_gen_time_ms": 10.0},
}
)
# Generate should handle error gracefully (log but continue)
sampling_params_list = [SamplingParams(temperature=0.7)]
outputs = omni.generate(prompts=["hi"], sampling_params_list=sampling_params_list)
# Should return final output (error was logged but didn't stop processing)
assert isinstance(outputs, list)
# Since final_output=True, should have one output
assert len(outputs) == 1
def test_close_sends_shutdown_signal(monkeypatch, fake_stage_config):
"""Test that close() sends shutdown signal to all input queues."""
def _fake_loader(model: str, base_engine_args=None):
return [_FakeStageConfig(fake_stage_config)]
import sys
for module_name in [
"vllm_omni.entrypoints.utils",
"vllm_omni.entrypoints.omni",
"vllm_omni.entrypoints.omni_stage",
]:
if module_name in sys.modules:
del sys.modules[module_name]
_setup_engine_mocks(monkeypatch)
_setup_multiprocessing_mocks(monkeypatch)
_setup_ipc_mocks(monkeypatch)
_setup_log_mocks(monkeypatch)
monkeypatch.setattr(
"vllm_omni.entrypoints.utils.load_stage_configs_from_model",
_fake_loader,
raising=False,
)
monkeypatch.setattr(
"vllm_omni.entrypoints.omni_stage.OmniStage",
lambda cfg, **kwargs: _FakeStage(cfg, **kwargs),
raising=False,
)
import vllm_omni.entrypoints.omni as omni_module
monkeypatch.setattr(omni_module, "load_stage_configs_from_model", _fake_loader)
monkeypatch.setattr(omni_module, "OmniStage", lambda cfg, **kwargs: _FakeStage(cfg, **kwargs))
from vllm_omni.entrypoints.omni import Omni
omni = Omni(model="any", init_timeout=1)
# Call close
omni.close()
# Verify the shutdown signal (SHUTDOWN_TASK) was sent to the input queue
# Use get_nowait to avoid blocking (close() uses put_nowait, so should be safe)
try:
shutdown_signal = omni.stage_list[0]._in_q.get_nowait()
assert shutdown_signal == SHUTDOWN_TASK
except Empty:
# If queue was already empty or only had stage_ready, that's also acceptable
# The important thing is that close() was called without error
pass
# Verify stop_stage_worker was called (process should be set)
assert omni.stage_list[0]._proc is not None
from types import SimpleNamespace
import torch
from vllm_omni.core.sched.output import OmniNewRequestData
def test_omni_new_request_data_copies_payloads():
prompt_embeds = torch.randn(2, 3)
additional_information = {
"speaker": ["test"],
"codes": torch.tensor([1, 2], dtype=torch.int64),
}
request = SimpleNamespace(
request_id="req-1",
external_req_id="ext-1",
prompt_token_ids=[101, 102],
mm_features=None,
sampling_params=None,
pooling_params=None,
num_computed_tokens=0,
lora_request=None,
prompt_embeds=prompt_embeds,
additional_information=additional_information,
)
data = OmniNewRequestData.from_request(request, ([0, 1],), prefill_token_ids=[101, 102])
assert data.prompt_embeds is prompt_embeds
assert data.additional_information is additional_information
assert data.prefill_token_ids == [101, 102]
def test_omni_new_request_data_allows_missing_payloads():
request = SimpleNamespace(
request_id="req-2",
external_req_id="ext-2",
prompt_token_ids=[201, 202],
mm_features=None,
sampling_params=None,
pooling_params=None,
num_computed_tokens=0,
lora_request=None,
prompt_embeds=None,
additional_information=None,
)
data = OmniNewRequestData.from_request(request, ([0],), prefill_token_ids=None)
assert data.prompt_embeds is None
assert data.additional_information is None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm_omni.entrypoints.omni_stage import _build_od_config
def test_build_od_config_includes_diffusion_fields():
engine_args = {
"cache_backend": "cache_dit",
"cache_config": {"Fn_compute_blocks": 2},
"vae_use_slicing": True,
}
od_config = _build_od_config(engine_args, model="dummy-model")
assert od_config["model"] == "dummy-model"
assert od_config["cache_backend"] == "cache_dit"
assert od_config["cache_config"]["Fn_compute_blocks"] == 2
assert od_config["vae_use_slicing"] is True
def test_build_od_config_respects_explicit_config():
engine_args = {
"od_config": {"cache_backend": "tea_cache"},
"cache_backend": "cache_dit",
}
od_config = _build_od_config(engine_args, model="dummy-model")
assert od_config == {"cache_backend": "tea_cache"}
import os
import sys
from unittest.mock import MagicMock
import pytest
from vllm_omni.entrypoints.stage_utils import set_stage_devices
def _make_dummy_torch(call_log):
class _Props:
def __init__(self, total):
self.total_memory = total
class _Cuda:
@staticmethod
def is_available():
return True
@staticmethod
def set_device(idx):
call_log.append(idx)
@staticmethod
def device_count():
return 2
@staticmethod
def get_device_properties(idx):
return _Props(total=16000)
@staticmethod
def mem_get_info(idx):
return (8000, 16000)
@staticmethod
def get_device_name(idx):
return f"gpu-{idx}"
class _Torch:
cuda = _Cuda
return _Torch
def _make_mock_platform(device_type: str = "cuda", env_var: str = "CUDA_VISIBLE_DEVICES"):
"""Create a mock platform for testing."""
mock_platform = MagicMock()
mock_platform.device_type = device_type
mock_platform.device_control_env_var = env_var
return mock_platform
@pytest.mark.usefixtures("clean_gpu_memory_between_tests")
def test_set_stage_devices_respects_logical_ids(monkeypatch):
# Preserve an existing logical mapping and ensure devices "0,1" map through it.
monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "6,7")
call_log: list[int] = []
dummy_torch = _make_dummy_torch(call_log)
monkeypatch.setitem(sys.modules, "torch", dummy_torch)
# Mock the platform at the source module where it's defined
mock_platform = _make_mock_platform(device_type="cuda", env_var="CUDA_VISIBLE_DEVICES")
monkeypatch.setattr(
"vllm_omni.platforms.current_omni_platform",
mock_platform,
)
set_stage_devices(stage_id=0, devices="0,1")
assert os.environ["CUDA_VISIBLE_DEVICES"] == "6,7"
@pytest.mark.usefixtures("clean_gpu_memory_between_tests")
def test_set_stage_devices_npu_platform(monkeypatch):
"""Test that set_stage_devices works correctly for NPU platform."""
monkeypatch.setenv("ASCEND_RT_VISIBLE_DEVICES", "4,5")
call_log: list[int] = []
# Create NPU mock torch
class _Npu:
@staticmethod
def is_available():
return True
@staticmethod
def set_device(idx):
call_log.append(idx)
@staticmethod
def device_count():
return 2
class _NpuTorch:
npu = _Npu
monkeypatch.setitem(sys.modules, "torch", _NpuTorch)
# Mock NPU platform at the source module where it's defined
mock_platform = _make_mock_platform(device_type="npu", env_var="ASCEND_RT_VISIBLE_DEVICES")
monkeypatch.setattr(
"vllm_omni.platforms.current_omni_platform",
mock_platform,
)
set_stage_devices(stage_id=0, devices="0,1")
assert os.environ["ASCEND_RT_VISIBLE_DEVICES"] == "4,5"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
def test_resolve_max_mel_frames_default():
from vllm_omni.model_executor.models.qwen2_5_omni.audio_length import resolve_max_mel_frames
assert resolve_max_mel_frames(None, default=30000) == 30000
assert resolve_max_mel_frames(None, default=6000) == 6000
def test_resolve_max_mel_frames_explicit():
from vllm_omni.model_executor.models.qwen2_5_omni.audio_length import resolve_max_mel_frames
# Explicit argument always wins over default
assert resolve_max_mel_frames(123, default=30000) == 123
assert resolve_max_mel_frames(6000, default=30000) == 6000
assert resolve_max_mel_frames(0, default=30000) == 0
@pytest.mark.parametrize("repeats", [2, 4])
@pytest.mark.parametrize("code_len", [0, 1, 32768])
@pytest.mark.parametrize("max_mel_frames", [None, -1, 0, 1, 6000, 30000])
def test_cap_and_align_mel_length_no_mismatch(repeats, code_len, max_mel_frames):
"""Guard that any max_mel_frames yields a mel length aligned to repeats, and
consistent with the truncated code length (prevents concat mismatch).
"""
from vllm_omni.model_executor.models.qwen2_5_omni.audio_length import cap_and_align_mel_length
target_code_len, target_mel_len = cap_and_align_mel_length(
code_len=code_len,
repeats=repeats,
max_mel_frames=max_mel_frames,
)
assert isinstance(target_code_len, int)
assert isinstance(target_mel_len, int)
if code_len == 0:
assert target_code_len == 0
assert target_mel_len == 0
return
assert target_code_len >= 1
assert target_mel_len >= repeats
assert target_mel_len % repeats == 0
assert target_mel_len == target_code_len * repeats
assert target_code_len <= code_len
if max_mel_frames is not None and int(max_mel_frames) > 0 and int(max_mel_frames) >= repeats:
assert target_mel_len <= int(max_mel_frames)
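# Minimal reference sketch of the invariant exercised above (an assumption, not
# the library implementation): the capped code length times `repeats` equals the
# mel length, and both respect max_mel_frames whenever it is a usable positive cap.
def _reference_cap_and_align(code_len, repeats, max_mel_frames):
    if code_len <= 0:
        return 0, 0
    capped = code_len
    if max_mel_frames is not None and int(max_mel_frames) >= repeats:
        capped = min(capped, int(max_mel_frames) // repeats)
    capped = max(capped, 1)
    return capped, capped * repeats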
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for OmniRequestOutput class."""
from unittest.mock import Mock
from PIL import Image
from vllm_omni.outputs import OmniRequestOutput
class TestOmniRequestOutput:
"""Tests for OmniRequestOutput class."""
def test_from_diffusion(self):
"""Test creating output from diffusion model."""
images = [Image.new("RGB", (64, 64), color="red")]
output = OmniRequestOutput.from_diffusion(
request_id="test-123",
images=images,
prompt="a cat",
metrics={"steps": 50},
)
assert output.request_id == "test-123"
assert output.images == images
assert output.prompt == "a cat"
assert output.metrics == {"steps": 50}
assert output.is_diffusion_output
assert output.num_images == 1
def test_from_pipeline(self):
"""Test creating output from pipeline stage."""
mock_request_output = Mock()
mock_request_output.request_id = "pipeline-123"
mock_request_output.prompt_token_ids = [1, 2, 3]
mock_request_output.outputs = [Mock()]
mock_request_output.encoder_prompt_token_ids = None
mock_request_output.prompt_logprobs = None
mock_request_output.num_cached_tokens = 10
mock_request_output.kv_transfer_params = None
mock_request_output.multimodal_output = {"image": Mock()}
output = OmniRequestOutput.from_pipeline(
stage_id=0,
final_output_type="text",
request_output=mock_request_output,
)
assert output.request_id == "pipeline-123"
assert output.stage_id == 0
assert output.final_output_type == "text"
assert output.is_pipeline_output
def test_prompt_token_ids_property(self):
"""Test prompt_token_ids property for streaming compatibility."""
mock_request_output = Mock()
mock_request_output.prompt_token_ids = [1, 2, 3, 4, 5]
output = OmniRequestOutput.from_pipeline(
stage_id=0,
final_output_type="text",
request_output=mock_request_output,
)
assert output.prompt_token_ids == [1, 2, 3, 4, 5]
def test_prompt_token_ids_none_when_no_request_output(self):
"""Test prompt_token_ids returns None when no request_output."""
output = OmniRequestOutput.from_diffusion(
request_id="test-123",
images=[],
prompt="a cat",
)
assert output.prompt_token_ids is None
def test_outputs_property(self):
"""Test outputs property for chat completion compatibility."""
mock_output = Mock()
mock_request_output = Mock()
mock_request_output.outputs = [mock_output]
output = OmniRequestOutput.from_pipeline(
stage_id=0,
final_output_type="text",
request_output=mock_request_output,
)
assert output.outputs == [mock_output]
def test_outputs_empty_when_no_request_output(self):
"""Test outputs returns empty list when no request_output."""
output = OmniRequestOutput.from_diffusion(
request_id="test-123",
images=[],
prompt="a cat",
)
assert output.outputs == []
def test_encoder_prompt_token_ids_property(self):
"""Test encoder_prompt_token_ids property."""
mock_request_output = Mock()
mock_request_output.encoder_prompt_token_ids = [10, 20, 30]
output = OmniRequestOutput.from_pipeline(
stage_id=0,
final_output_type="text",
request_output=mock_request_output,
)
assert output.encoder_prompt_token_ids == [10, 20, 30]
def test_num_cached_tokens_property(self):
"""Test num_cached_tokens property."""
mock_request_output = Mock()
mock_request_output.num_cached_tokens = 42
output = OmniRequestOutput.from_pipeline(
stage_id=0,
final_output_type="text",
request_output=mock_request_output,
)
assert output.num_cached_tokens == 42
def test_multimodal_output_property(self):
"""Test multimodal_output property."""
mock_request_output = Mock()
mock_audio = Mock()
expected_output = {"audio": mock_audio}
mock_request_output.multimodal_output = expected_output
output = OmniRequestOutput.from_pipeline(
stage_id=0,
final_output_type="audio",
request_output=mock_request_output,
)
assert output.multimodal_output is expected_output
def test_multimodal_output_prefers_completion_output(self):
"""Test multimodal_output prefers completion output payloads."""
completion_output = Mock()
completion_mm = {"audio": Mock()}
completion_output.multimodal_output = completion_mm
mock_request_output = Mock()
mock_request_output.outputs = [completion_output]
mock_request_output.multimodal_output = {"audio": Mock()}
output = OmniRequestOutput.from_pipeline(
stage_id=0,
final_output_type="audio",
request_output=mock_request_output,
)
assert output.multimodal_output is completion_mm
def test_to_dict_diffusion(self):
"""Test to_dict for diffusion output."""
output = OmniRequestOutput.from_diffusion(
request_id="test-123",
images=[Image.new("RGB", (64, 64), color="red")],
prompt="a cat",
metrics={"steps": 50},
)
result = output.to_dict()
assert result["request_id"] == "test-123"
assert result["finished"] is True
assert result["final_output_type"] == "image"
assert result["num_images"] == 1
assert result["prompt"] == "a cat"
def test_to_dict_pipeline(self):
"""Test to_dict for pipeline output."""
mock_request_output = Mock()
mock_request_output.request_id = "pipeline-123"
output = OmniRequestOutput.from_pipeline(
stage_id=0,
final_output_type="text",
request_output=mock_request_output,
)
result = output.to_dict()
assert result["request_id"] == "pipeline-123"
assert result["finished"] is True
assert result["final_output_type"] == "text"
assert result["stage_id"] == 0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Some functions are copied from vllm/tests/utils.py
import functools
import os
import signal
import subprocess
import sys
import tempfile
import threading
import time
from collections.abc import Callable
from contextlib import ExitStack, contextmanager, suppress
from typing import Any, Literal
import cloudpickle
import pytest
import torch
from typing_extensions import ParamSpec
from vllm.platforms import current_platform
from vllm.utils.torch_utils import cuda_device_count_stateless
_P = ParamSpec("_P")
if current_platform.is_rocm():
from amdsmi import (
amdsmi_get_gpu_vram_usage,
amdsmi_get_processor_handles,
amdsmi_init,
amdsmi_shut_down,
)
@contextmanager
def _nvml():
try:
amdsmi_init()
yield
finally:
amdsmi_shut_down()
elif current_platform.is_cuda():
from vllm.third_party.pynvml import (
nvmlDeviceGetHandleByIndex,
nvmlDeviceGetMemoryInfo,
nvmlInit,
nvmlShutdown,
)
@contextmanager
def _nvml():
try:
nvmlInit()
yield
finally:
nvmlShutdown()
else:
@contextmanager
def _nvml():
yield
def get_physical_device_indices(devices):
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
if visible_devices is None:
return devices
visible_indices = [int(x) for x in visible_devices.split(",")]
index_mapping = {i: physical for i, physical in enumerate(visible_indices)}
return [index_mapping[i] for i in devices if i in index_mapping]
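# Worked example of the mapping above, assuming CUDA_VISIBLE_DEVICES="6,7":
#   visible_indices == [6, 7], index_mapping == {0: 6, 1: 7}
#   get_physical_device_indices([0, 1]) -> [6, 7]
#   get_physical_device_indices([3])    -> []   (logical id 3 is not visible)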
@_nvml()
def wait_for_gpu_memory_to_clear(
*,
devices: list[int],
threshold_bytes: int | None = None,
threshold_ratio: float | None = None,
timeout_s: float = 120,
) -> None:
import gc
assert threshold_bytes is not None or threshold_ratio is not None
# Use nvml instead of pytorch to reduce measurement error from torch cuda
# context.
devices = get_physical_device_indices(devices)
start_time = time.time()
# Print waiting start information
device_list = ", ".join(str(d) for d in devices)
if threshold_bytes is not None:
threshold_str = f"{threshold_bytes / 2**30:.2f} GiB"
condition_str = f"Memory usage ≤ {threshold_str}"
else:
threshold_percent = threshold_ratio * 100
threshold_str = f"{threshold_percent:.1f}%"
condition_str = f"Memory usage ratio ≤ {threshold_str}"
print(f"[GPU Memory Monitor] Waiting for GPU {device_list} to free memory, Condition: {condition_str}")
# Define the is_free function based on threshold type
if threshold_bytes is not None:
def is_free(used, total):
return used <= threshold_bytes / 2**30
else:
def is_free(used, total):
return used / total <= threshold_ratio
while True:
output: dict[int, str] = {}
output_raw: dict[int, tuple[float, float]] = {}
for device in devices:
if current_platform.is_rocm():
dev_handle = amdsmi_get_processor_handles()[device]
mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
gb_used = mem_info["vram_used"] / 2**10
gb_total = mem_info["vram_total"] / 2**10
else:
dev_handle = nvmlDeviceGetHandleByIndex(device)
mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
gb_used = mem_info.used / 2**30
gb_total = mem_info.total / 2**30
output_raw[device] = (gb_used, gb_total)
# Format to more readable form
usage_percent = (gb_used / gb_total) * 100 if gb_total > 0 else 0
output[device] = f"{gb_used:.1f}GiB/{gb_total:.1f}GiB ({usage_percent:.1f}%)"
# Print the current GPU memory status
print("[GPU Memory Status] Current usage:")
for device_id, mem_info in output.items():
print(f" GPU {device_id}: {mem_info}")
# Calculate waiting duration
dur_s = time.time() - start_time
elapsed_minutes = dur_s / 60
# Check if all devices meet the condition
if all(is_free(used, total) for used, total in output_raw.values()):
# Print the completion message
print(f"[GPU Memory Freed] Devices {device_list} meet memory condition")
print(f" Condition: {condition_str}")
print(f" Wait time: {dur_s:.1f} seconds ({elapsed_minutes:.1f} minutes)")
print(" Final status:")
for device_id, mem_info in output.items():
print(f" GPU {device_id}: {mem_info}")
break
# Check timeout
if dur_s >= timeout_s:
raise ValueError(
f"[GPU Memory Timeout] Devices {device_list} still don't meet memory condition after {dur_s:.1f} seconds\n"
f"Condition: {condition_str}\n"
f"Current status:\n" + "\n".join(f" GPU {device}: {output[device]}" for device in devices)
)
# Add waiting hint (optional)
if dur_s > 10 and int(dur_s) % 10 == 0: # Show hint every 10 seconds
print(f"Waiting... Already waited {dur_s:.1f} seconds ({elapsed_minutes:.1f} minutes)")
gc.collect()
torch.cuda.empty_cache()
time.sleep(5)
def fork_new_process_for_each_test(func: Callable[_P, None]) -> Callable[_P, None]:
"""Decorator to fork a new process for each test function.
See https://github.com/vllm-project/vllm/issues/7053 for more details.
"""
@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
# Make the process the leader of its own process group
# to avoid sending SIGTERM to the parent process
os.setpgrp()
from _pytest.outcomes import Skipped
# Create a unique temporary file to store exception info from child
# process. Use test function name and process ID to avoid collisions.
with (
tempfile.NamedTemporaryFile(
delete=False, mode="w+b", prefix=f"vllm_test_{func.__name__}_{os.getpid()}_", suffix=".exc"
) as exc_file,
ExitStack() as delete_after,
):
exc_file_path = exc_file.name
delete_after.callback(os.remove, exc_file_path)
pid = os.fork()
print(f"Fork a new process to run a test {pid}")
if pid == 0:
# Parent process responsible for deleting, don't delete
# in child.
delete_after.pop_all()
try:
func(*args, **kwargs)
except Skipped as e:
# convert Skipped to exit code 0
print(str(e))
os._exit(0)
except Exception as e:
import traceback
tb_string = traceback.format_exc()
# Try to serialize the exception object first
exc_to_serialize: dict[str, Any]
try:
# First, try to pickle the actual exception with
# its traceback.
exc_to_serialize = {"pickled_exception": e}
# Test if it can be pickled
cloudpickle.dumps(exc_to_serialize)
except (Exception, KeyboardInterrupt):
# Fall back to string-based approach.
exc_to_serialize = {
"exception_type": type(e).__name__,
"exception_msg": str(e),
"traceback": tb_string,
}
try:
with open(exc_file_path, "wb") as f:
cloudpickle.dump(exc_to_serialize, f)
except Exception:
# Fallback: just print the traceback.
print(tb_string)
os._exit(1)
else:
os._exit(0)
else:
pgid = os.getpgid(pid)
_pid, _exitcode = os.waitpid(pid, 0)
# ignore SIGTERM signal itself
old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
# kill all child processes
os.killpg(pgid, signal.SIGTERM)
# restore the signal handler
signal.signal(signal.SIGTERM, old_signal_handler)
if _exitcode != 0:
# Try to read the exception from the child process
exc_info = {}
if os.path.exists(exc_file_path):
with suppress(Exception), open(exc_file_path, "rb") as f:
exc_info = cloudpickle.load(f)
if (original_exception := exc_info.get("pickled_exception")) is not None:
# Re-raise the actual exception object if it was
# successfully pickled.
assert isinstance(original_exception, Exception)
raise original_exception
if (original_tb := exc_info.get("traceback")) is not None:
# Use string-based traceback for fallback case
raise AssertionError(
f"Test {func.__name__} failed when called with"
f" args {args} and kwargs {kwargs}"
f" (exit code: {_exitcode}):\n{original_tb}"
) from None
# Fallback to the original generic error
raise AssertionError(
f"function {func.__name__} failed when called with"
f" args {args} and kwargs {kwargs}"
f" (exit code: {_exitcode})"
) from None
return wrapper
def spawn_new_process_for_each_test(f: Callable[_P, None]) -> Callable[_P, None]:
"""Decorator to spawn a new process for each test function."""
@functools.wraps(f)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
# Check if we're already in a subprocess
if os.environ.get("RUNNING_IN_SUBPROCESS") == "1":
# If we are, just run the function directly
return f(*args, **kwargs)
import torch.multiprocessing as mp
with suppress(RuntimeError):
mp.set_start_method("spawn")
# Get the module
module_name = f.__module__
# Create a process with environment variable set
env = os.environ.copy()
env["RUNNING_IN_SUBPROCESS"] = "1"
with tempfile.TemporaryDirectory() as tempdir:
output_filepath = os.path.join(tempdir, "new_process.tmp")
# `cloudpickle` allows pickling complex functions directly
input_bytes = cloudpickle.dumps((f, output_filepath))
cmd = [sys.executable, "-m", f"{module_name}"]
returned = subprocess.run(cmd, input=input_bytes, capture_output=True, env=env)
# check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(f"Error raised in subprocess:\n{returned.stderr.decode()}") from e
return wrapper
def create_new_process_for_each_test(
method: Literal["spawn", "fork"] | None = None,
) -> Callable[[Callable[_P, None]], Callable[_P, None]]:
"""Creates a decorator that runs each test function in a new process.
Args:
method: The process creation method. Can be either "spawn" or "fork".
If not specified, it defaults to "spawn" on XPU platforms and
"fork" otherwise (including ROCm; see the TODO below).
Returns:
A decorator to run test functions in separate processes.
"""
if method is None:
# TODO: Spawn is not working correctly on ROCm
# The test content will not run and tests passed immediately.
# For now, using `fork` for ROCm as it can run with `fork`
# and tests are running correctly.
use_spawn = current_platform.is_xpu()
method = "spawn" if use_spawn else "fork"
assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'"
if method == "fork":
return fork_new_process_for_each_test
return spawn_new_process_for_each_test
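# Example usage of the factory above (illustrative only):
#
#     @create_new_process_for_each_test("spawn")
#     def test_isolated_gpu_setup():
#         ...
#
# Called with no argument it picks "fork" everywhere except XPU, where "spawn" is used.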
def cuda_marks(*, res: str, num_cards: int):
"""
Get a collection of pytest marks to apply for `@cuda_test`.
Args:
res: Resource type, e.g., "L4" or "H100".
num_cards: Number of GPU cards required.
Returns:
List of pytest marks to apply.
"""
test_platform_detail = pytest.mark.cuda
if res == "L4":
test_resource = pytest.mark.L4
elif res == "H100":
test_resource = pytest.mark.H100
else:
raise ValueError(f"Invalid CUDA resource type: {res}. Supported: L4, H100")
marks = [test_resource, test_platform_detail]
if num_cards == 1:
return marks
else:
test_distributed = pytest.mark.distributed_cuda(num_cards=num_cards)
test_skipif = pytest.mark.skipif_cuda(
cuda_device_count_stateless() < num_cards,
reason=f"Need at least {num_cards} CUDA GPUs to run the test.",
)
return marks + [test_distributed, test_skipif]
def rocm_marks(*, res: str, num_cards: int):
"""
Get a collection of pytest marks to apply for `@rocm_test`.
Args:
res: Resource type, e.g., "MI325".
num_cards: Number of GPU cards required.
Returns:
List of pytest marks to apply.
"""
test_platform_detail = pytest.mark.rocm
if res == "MI325":
test_resource = pytest.mark.MI325
else:
raise ValueError(f"Invalid ROCm resource type: {res}. Supported: MI325")
marks = [test_resource, test_platform_detail]
if num_cards == 1:
return marks
else:
test_distributed = pytest.mark.distributed_rocm(num_cards=num_cards)
# TODO: add ROCm support for `skipif_rocm` marker
return marks + [test_distributed]
def gpu_marks(*, res: str, num_cards: int):
"""
Get a collection of pytest marks to apply for `@gpu_test`.
Platform is automatically determined based on resource type.
Args:
res: Resource type, e.g., "L4", "H100" for CUDA, or "MI325" for ROCm.
num_cards: Number of GPU cards required.
Returns:
List of pytest marks to apply.
"""
test_platform = pytest.mark.gpu
if res in ("L4", "H100"):
return [test_platform] + cuda_marks(res=res, num_cards=num_cards)
if res == "MI325":
return [test_platform] + rocm_marks(res=res, num_cards=num_cards)
raise ValueError(f"Invalid resource type: {res}. Supported: L4, H100, MI325")
def npu_marks(*, res: str, num_cards: int):
"""Get a collection of pytest marks to apply for `@npu_test`."""
test_platform = pytest.mark.npu
if res == "A2":
test_resource = pytest.mark.A2
elif res == "A3":
test_resource = pytest.mark.A3
else:
# TODO: Currently we don't have various NPU card types defined
# Use None to skip resource-specific marking for unknown types
test_resource = None
if num_cards == 1:
return [mark for mark in [test_platform, test_resource] if mark is not None]
else:
# Multiple cards scenario needs distributed_npu mark
test_distributed = pytest.mark.distributed_npu(num_cards=num_cards)
# TODO: add NPU support for `skipif_npu` marker
return [mark for mark in [test_platform, test_resource, test_distributed] if mark is not None]
def hardware_test(*, res: dict[str, str], num_cards: int | dict[str, int] = 1):
"""
Decorate a test for multiple hardware platforms with a single call.
Automatically wraps the test with @create_new_process_for_each_test() for distributed tests.
Args:
res: Mapping from platform to resource type. Supported platforms/resources:
- cuda: L4, H100
- rocm: MI325
- npu: A2, A3
num_cards: Number of cards required. Can be:
- int: same card count for all platforms (default: 1)
- dict: per-platform card count, e.g., {"cuda": 2, "rocm": 2}
Example:
@hardware_test(
res={"cuda": "L4", "rocm": "MI325", "npu": "A2"},
num_cards={"cuda": 2, "rocm": 2, "npu": 2},
)
def test_multi_platform():
...
"""
# Validate platforms
# Don't validate platform details in this decorator
for platform, _ in res.items():
if platform not in ("cuda", "rocm", "npu"):
raise ValueError(f"Unsupported platform: {platform}")
# Normalize num_cards
if isinstance(num_cards, int):
num_cards_dict = {platform: num_cards for platform in res.keys()}
else:
num_cards_dict = num_cards
for platform in num_cards_dict.keys():
if platform not in res:
raise ValueError(
f"Platform '{platform}' in num_cards but not in res. Available platforms: {list(res.keys())}"
)
for platform in res.keys():
if platform not in num_cards_dict:
num_cards_dict[platform] = 1
# Collect marks from all platforms
all_marks: list[Callable[[Callable[_P, None]], Callable[_P, None]]] = []
for platform, resource in res.items():
cards = num_cards_dict[platform]
if platform == "cuda" or platform == "rocm":
marks = gpu_marks(res=resource, num_cards=cards)
elif platform == "npu":
marks = npu_marks(res=resource, num_cards=cards)
else:
raise ValueError(f"Unsupported platform: {platform}")
all_marks.extend(marks)
create_new_process_flag = False
for cards in num_cards_dict.values():
if cards > 1:
create_new_process_flag = True
break
def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
if create_new_process_flag:
# only for distributed tests
func = create_new_process_for_each_test()(f)
else:
func = f
for mark in reversed(all_marks):
func = mark(func)
return func
return wrapper
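# Illustrative usage (a sketch, not part of the suite's public API): a
# single-card entry only collects the platform/resource marks, while any
# multi-card entry additionally gets the distributed marks and runs in a
# fresh process per test.
#
# @hardware_test(res={"cuda": "L4"})
# def test_single_card():
#     ...
#
# @hardware_test(res={"cuda": "H100", "npu": "A3"}, num_cards=2)
# def test_two_cards_per_platform():
#     ...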
class GPUMemoryMonitor:
"""Poll global device memory usage via CUDA APIs."""
def __init__(self, device_index: int, interval: float = 0.05):
self.device_index = device_index
self.interval = interval
self._peak_used_mb = 0.0
self._stop_event = threading.Event()
self._thread: threading.Thread | None = None
def start(self) -> None:
def monitor_loop() -> None:
while not self._stop_event.is_set():
try:
with torch.cuda.device(self.device_index):
free_bytes, total_bytes = torch.cuda.mem_get_info()
used_mb = (total_bytes - free_bytes) / (1024**2)
self._peak_used_mb = max(self._peak_used_mb, used_mb)
except Exception:
pass
time.sleep(self.interval)
self._thread = threading.Thread(target=monitor_loop, daemon=False)
self._thread.start()
def stop(self) -> None:
if self._thread is None:
return
self._stop_event.set()
self._thread.join(timeout=2.0)
@property
def peak_used_mb(self) -> float:
fallback_alloc = torch.cuda.max_memory_allocated(device=self.device_index) / (1024**2)
fallback_reserved = torch.cuda.max_memory_reserved(device=self.device_index) / (1024**2)
return max(self._peak_used_mb, fallback_alloc, fallback_reserved)
def __del__(self):
self.stop()
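# Minimal usage sketch for GPUMemoryMonitor (assumes a CUDA device is visible
# to torch); the background thread samples global device memory, and
# `peak_used_mb` also folds in torch's allocator peak counters as a fallback.
#
# monitor = GPUMemoryMonitor(device_index=0, interval=0.05)
# monitor.start()
# ...  # run the workload being measured
# monitor.stop()
# print(f"peak device memory: {monitor.peak_used_mb:.1f} MiB")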
import torch
from vllm_omni.worker.gpu_generation_model_runner import GPUGenerationModelRunner
class _DummyInputBatch:
def __init__(self):
self.req_ids = ["req-1"]
self.req_id_to_index = {"req-1": 0}
self.num_reqs = 1
self.vocab_size = 10
def _make_runner(multimodal_outputs):
runner = object.__new__(GPUGenerationModelRunner)
runner.execute_model_state = (
None,
None,
None,
None,
None,
None,
None,
None,
None,
multimodal_outputs,
)
runner.kv_connector_output = None
runner.input_batch = _DummyInputBatch()
runner.use_async_scheduling = False
runner.device = torch.device("cpu")
runner.supports_mm_inputs = False
return runner
def test_sample_tokens_tensor_output():
multimodal_outputs = torch.randn(1, 2, 3)
runner = _make_runner(multimodal_outputs)
output = GPUGenerationModelRunner.sample_tokens(runner)
assert len(output.pooler_output) == 1
assert output.pooler_output[0]["model_outputs"].shape == (2, 3)
def test_sample_tokens_list_output():
multimodal_outputs = [torch.randn(2, 1)]
runner = _make_runner(multimodal_outputs)
output = GPUGenerationModelRunner.sample_tokens(runner)
assert len(output.pooler_output) == 1
assert output.pooler_output[0]["model_outputs"].shape == (2, 1)
def test_sample_tokens_list_allows_none_output():
multimodal_outputs = [None]
runner = _make_runner(multimodal_outputs)
output = GPUGenerationModelRunner.sample_tokens(runner)
assert len(output.pooler_output) == 1
assert output.pooler_output[0]["model_outputs"] is None
def test_sample_tokens_dict_output():
multimodal_outputs = {"audio": torch.randn(1, 4), "unused": None}
runner = _make_runner(multimodal_outputs)
output = GPUGenerationModelRunner.sample_tokens(runner)
assert len(output.pooler_output) == 1
assert "audio" in output.pooler_output[0]
assert "unused" not in output.pooler_output[0]
assert output.pooler_output[0]["audio"].shape == (1, 4)
from contextlib import contextmanager
from types import SimpleNamespace
import torch
from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner
class DummyBuffer:
"""A minimal buffer wrapper that exposes the `.gpu` attribute."""
def __init__(self, t: torch.Tensor):
self.gpu = t
class DummyInputBatch:
"""A minimal input batch that only provides `req_ids`."""
def __init__(self, req_ids):
self.req_ids = req_ids
class DummyReqState:
"""A minimal request state container."""
pass
class DummyTalkerMTP(torch.nn.Module):
"""A fake talker_mtp module for deterministic CPU testing."""
def forward(self, req_input_ids, req_embeds, last_talker_hidden, text_step):
# Deterministic behavior:
# - output embeds = input embeds + 1
# - output codes = [[0], [1], ...]
bsz = req_embeds.shape[0]
new_embeds = req_embeds + 1.0
codes = torch.arange(bsz, dtype=torch.int64).view(bsz, 1)
return new_embeds, codes
@contextmanager
def _noop_forward_context(*args, **kwargs):
"""A no-op context manager to replace vLLM forward context in CPU tests."""
yield
def _make_runner(req_ids=("r1", "r2"), hidden_size=4):
# Create an instance without calling OmniGPUModelRunner.__init__
runner = object.__new__(OmniGPUModelRunner)
# Minimal attributes used by OmniGPUModelRunner._talker_mtp_forward
runner.input_batch = DummyInputBatch(list(req_ids))
runner.requests = {rid: DummyReqState() for rid in req_ids}
# query_start_loc.cpu[req_index] is used to locate the token position
# in the flattened `inputs_embeds`.
runner.query_start_loc = type("QSL", (), {})()
# Map: r1 -> offset 0, r2 -> offset 3
runner.query_start_loc.cpu = torch.tensor([0, 3], dtype=torch.int32)
bsz = len(req_ids)
runner.talker_mtp_input_ids = DummyBuffer(torch.zeros((bsz,), dtype=torch.int64))
runner.talker_mtp_inputs_embeds = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32))
runner.last_talker_hidden = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32))
runner.text_step = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32))
runner.talker_mtp = DummyTalkerMTP()
runner.vllm_config = object()
    # Provide a minimal implementation that returns the expected 5-tuple.
def _determine_batch_execution_and_padding(**kwargs):
return None, object(), None, None, None
runner._determine_batch_execution_and_padding = _determine_batch_execution_and_padding
# Use the real merge method from OmniGPUModelRunner.
return runner
def test_talker_mtp_forward_cpu_updates_inputs_and_info(monkeypatch):
# Patch the module-level `set_forward_context` symbol used inside
# OmniGPUModelRunner._talker_mtp_forward.
import vllm_omni.worker.gpu_model_runner as mod # Must be the same module that defines OmniGPUModelRunner
monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context)
runner = _make_runner(req_ids=("r1", "r2"), hidden_size=4)
def fake_determine(self, num_tokens, num_reqs, num_scheduled_tokens_np, max_num_scheduled_tokens, use_cascade_attn):
batch_desc = SimpleNamespace(num_tokens=int(num_tokens))
return (False, batch_desc, None, None, None)
monkeypatch.setattr(runner, "_determine_batch_execution_and_padding", fake_determine.__get__(runner, type(runner)))
# Initialize per-request embeds (batch-major inside talker_mtp_inputs_embeds)
runner.talker_mtp_inputs_embeds.gpu[0] = torch.tensor([1.0, 2.0, 3.0, 4.0])
runner.talker_mtp_inputs_embeds.gpu[1] = torch.tensor([10.0, 20.0, 30.0, 40.0])
# Flattened `inputs_embeds`: offsets 0 and 3 will be overwritten
inputs_embeds = torch.zeros((6, 4), dtype=torch.float32)
# Call the original implementation from OmniGPUModelRunner (no re-implementation)
OmniGPUModelRunner._talker_mtp_forward(runner, ["r1", "r2"], inputs_embeds)
# Validate embeds were written back (+1)
assert torch.allclose(inputs_embeds[0], torch.tensor([2.0, 3.0, 4.0, 5.0]))
assert torch.allclose(inputs_embeds[3], torch.tensor([11.0, 21.0, 31.0, 41.0]))
# Validate per-request additional_information_cpu was updated
info_r1 = runner.requests["r1"].additional_information_cpu
info_r2 = runner.requests["r2"].additional_information_cpu
assert int(info_r1["code_predictor_codes"][0, 0]) == 0
assert int(info_r2["code_predictor_codes"][0, 0]) == 1
def test_talker_mtp_forward_cpu_empty_batch_noop(monkeypatch):
import vllm_omni.worker.gpu_model_runner as mod
monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context)
runner = _make_runner(req_ids=("r1",), hidden_size=4)
inputs_embeds = torch.randn((2, 4))
before = inputs_embeds.clone()
OmniGPUModelRunner._talker_mtp_forward(runner, [], inputs_embeds)
# Ensure no changes were made
assert torch.allclose(inputs_embeds, before)
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
import regex as re
# List of files (relative to repo root) that are allowed to import pickle or
# cloudpickle
#
# STOP AND READ BEFORE YOU ADD ANYTHING ELSE TO THIS LIST:
# The pickle and cloudpickle modules are known to be unsafe when deserializing
# data from potentially untrusted parties. They have resulted in multiple CVEs
# for vLLM and numerous vulnerabilities in the Python ecosystem more broadly.
# Before adding new uses of pickle/cloudpickle, please consider safer
# alternatives like msgpack or pydantic that are already in use in vLLM. Only
# add to this list if absolutely necessary and after careful security review.
ALLOWED_FILES = {
"vllm_omni/entrypoints/omni_llm.py",
"tests/e2e/offline_inference/utils.py",
"tests/utils.py",
"vllm_omni/diffusion/distributed/group_coordinator.py",
"tests/diffusion/attention/test_attention_sp.py",
}
PICKLE_RE = re.compile(
r"^\s*(import\s+(pickle|cloudpickle)(\s|$|\sas)"
r"|from\s+(pickle|cloudpickle)\s+import\b)"
)
def scan_file(path: str) -> int:
with open(path, encoding="utf-8") as f:
for i, line in enumerate(f, 1):
if PICKLE_RE.match(line):
print(
f"{path}:{i}: "
"\033[91merror:\033[0m " # red color
"Found pickle/cloudpickle import"
)
return 1
return 0
def main():
returncode = 0
for filename in sys.argv[1:]:
if filename in ALLOWED_FILES:
continue
returncode |= scan_file(filename)
return returncode
def test_regex():
test_cases = [
# Should match
("import pickle", True),
("import cloudpickle", True),
("import pickle as pkl", True),
("import cloudpickle as cpkl", True),
("from pickle import *", True),
("from cloudpickle import dumps", True),
("from pickle import dumps, loads", True),
("from cloudpickle import (dumps, loads)", True),
(" import pickle", True),
("\timport cloudpickle", True),
("from pickle import loads", True),
# Should not match
("import somethingelse", False),
("from somethingelse import pickle", False),
("# import pickle", False),
("print('import pickle')", False),
("import pickleas as asdf", False),
]
for i, (line, should_match) in enumerate(test_cases):
result = bool(PICKLE_RE.match(line))
assert result == should_match, f"Test case {i} failed: '{line}' (expected {should_match}, got {result})"
print("All regex tests passed.")
if __name__ == "__main__":
if "--test-regex" in sys.argv:
test_regex()
else:
sys.exit(main())
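# Example invocation (illustrative; the script path below is hypothetical, and
# in practice such checkers are often wired up as pre-commit hooks that pass
# the changed files as arguments):
#
#   python tools/check_pickle_imports.py vllm_omni/foo.py tests/bar.py
#   python tools/check_pickle_imports.py --test-regex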
"""
vLLM-Omni: Multi-modality models inference and serving with
non-autoregressive structures.
This package extends vLLM beyond traditional text-based, autoregressive
generation to support multi-modality models with non-autoregressive
structures and non-textual outputs.
Architecture:
- 🟡 Modified: vLLM components modified for multimodal support
- 🔴 Added: New components for multimodal and non-autoregressive
processing
"""
try:
from . import patch # noqa: F401
except ModuleNotFoundError as exc: # pragma: no cover - optional dependency
if exc.name != "vllm":
raise
# Allow importing vllm_omni without vllm (e.g., documentation builds)
patch = None # type: ignore
from .config import OmniModelConfig
from .entrypoints.async_omni import AsyncOmni
# Main entry points
from .entrypoints.omni import Omni
from .version import __version__, __version_tuple__ # isort:skip
__all__ = [
"__version__",
"__version_tuple__",
# Main components
"Omni",
"AsyncOmni",
# Configuration
"OmniModelConfig",
# All other components are available through their respective modules
# processors.*, schedulers.*, executors.*, etc.
]
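# Illustrative import surface (a sketch; only the names exported in __all__
# above are shown, and constructor arguments are intentionally omitted):
#
# from vllm_omni import Omni, AsyncOmni, OmniModelConfig, __version__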
import librosa
import numpy as np
from vllm.assets.video import VideoAsset
def extract_video_audio(path: str | None = None, sampling_rate: int = 16000) -> np.ndarray:
    """Extract the audio track from a video file and return it as a numpy array.
    Args:
        path: The path to the video file. Defaults to the bundled "baby_reading" video asset.
        sampling_rate: Target sampling rate of the returned audio, in Hz.
    Returns:
        The audio signal as a numpy array.
"""
if not path:
path = VideoAsset(name="baby_reading").video_path
audio_signal, sr = librosa.load(path, sr=sampling_rate)
return audio_signal
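# Usage sketch (illustrative): pull the audio track out of the bundled
# "baby_reading" demo video at the default 16 kHz sampling rate.
#
# if __name__ == "__main__":
#     audio = extract_video_audio()
#     print(audio.shape, audio.dtype)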
import base64
import io
import logging
from collections.abc import Mapping
from typing import Any
import numpy as np
import soundfile as sf
import torch
from vllm.benchmarks.datasets import RandomMultiModalDataset, process_image, process_video
logger = logging.getLogger(__name__)
def process_audio(audio: Any) -> Mapping[str, Any]:
"""
Process a single audio input and return a multimedia content dictionary.
Supports the following input types:
    1. Dictionary with raw audio bytes:
       - Expects a dict with a 'bytes' key containing raw audio data.
    2. String input:
       - Treats the string as a URL or local file path.
       - Prepends "file://" if the string doesn't start with "http://",
         "https://", or "file://".
       - Returns a dictionary with the audio URL.
Raises:
ValueError: If the input is not a supported type.
"""
if isinstance(audio, dict) and "bytes" in audio:
audio_bytes = audio["bytes"]
audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
return {
"type": "audio_url",
"audio_url": {"url": f"data:audio/mpeg;base64,{audio_base64}"},
}
if isinstance(audio, str):
audio_url = audio if audio.startswith(("http://", "https://", "file://")) else f"file://{audio}"
return {"type": "audio_url", "audio_url": {"url": audio_url}}
    raise ValueError(
        f"Invalid audio input {audio}. Must be a local path or remote URL string, "
        f"or a dictionary with raw audio bytes in the form `{{'bytes': raw_audio_bytes}}`."
    )
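# Illustrative inputs and outputs for process_audio (the paths and URLs below
# are hypothetical):
#
#   process_audio({"bytes": raw_wav_bytes})     # -> base64 data URL payload
#   process_audio("/tmp/sample.wav")            # -> {"audio_url": {"url": "file:///tmp/sample.wav"}, ...}
#   process_audio("https://example.com/a.mp3")  # -> URL passed through unchanged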
# -----------------------------------------------------------------------------
# MultiModalDataset Implementation
# -----------------------------------------------------------------------------
class OmniRandomMultiModalDataset(RandomMultiModalDataset):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def generate_synthetic_audio(
self,
duration: int, # seconds
        num_channels: int,  # 1: mono, 2: stereo, 6: 5.1 surround sound
) -> dict[str, Any]:
"""Generate synthetic audio with random values.
        Uses a fixed 48000 Hz sample rate.
"""
sample_rate = 48000
num_samples = int(sample_rate * duration)
audio_data = self._rng.uniform(-0.5, 0.5, (num_samples, num_channels))
audio_data = np.clip(audio_data, -1.0, 1.0)
audio_tensor = torch.FloatTensor(audio_data.T)
audio_np = audio_tensor.numpy()
buffer = io.BytesIO()
sf.write(buffer, audio_np.T, sample_rate, format="wav")
buffer.seek(0)
audio_bytes = buffer.read()
buffer.close()
return {
"bytes": audio_bytes,
}
def generate_mm_item(
self,
mm_item_config: tuple[int, int, int],
) -> Mapping[str, Any]:
"""
        Create synthetic images, videos, and audio and apply
        process_image/process_video/process_audio respectively.
        This follows the OpenAI API chat completions format:
        https://github.com/openai/openai-python
"""
if self.map_config_to_modality(mm_item_config) == "image":
return process_image(self.generate_synthetic_image(mm_item_config[1], mm_item_config[0]))
elif self.map_config_to_modality(mm_item_config) == "video":
return process_video(self.generate_synthetic_video(mm_item_config[1], mm_item_config[0], mm_item_config[2]))
elif self.map_config_to_modality(mm_item_config) == "audio":
return process_audio(self.generate_synthetic_audio(mm_item_config[1], mm_item_config[2]))
else:
raise ValueError(f"Invalid multimodal item configuration: {mm_item_config}")
def generate_synthetic_video(self, width: int, height: int, num_frames: int) -> Any:
"""Generate synthetic video with random values."""
import imageio
video_data = self._rng.integers(
0,
256,
(num_frames, height, width, 3),
dtype=np.uint8,
)
buffer = io.BytesIO()
writer_kwargs = {
"format": "mp4",
"fps": 30,
"codec": "libx264",
"quality": 7,
"pixelformat": "yuv420p",
"macro_block_size": 16,
"ffmpeg_params": [
"-preset",
"medium",
"-crf",
"23",
"-movflags",
"+faststart",
"-pix_fmt",
"yuv420p",
"-vf",
f"scale={width}:{height}",
],
}
with imageio.get_writer(buffer, **writer_kwargs) as writer:
for frame_idx in range(num_frames):
writer.append_data(video_data[frame_idx])
buffer.seek(0)
video_bytes = buffer.read()
return {
"bytes": video_bytes,
}
def map_config_to_modality(self, config: tuple[int, int, int]) -> str:
"""Map the configuration to the modality."""
if config[0] == 0:
return "audio"
elif config[-1] == 1:
return "image"
elif config[-1] > 1:
return "video"
else:
raise ValueError(f"Invalid multimodal item configuration: {config}")
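# Sketch of the bucket-config convention assumed by map_config_to_modality
# (the concrete tuples below are illustrative):
#
#   (0, 30, 1)     -> "audio"  (leading 0 marks audio; duration 30 s, 1 channel)
#   (256, 256, 1)  -> "image"  (trailing 1 means a single frame)
#   (256, 256, 16) -> "video"  (trailing value > 1 is the frame count)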
import warnings
from dataclasses import dataclass
import numpy as np
from transformers import PreTrainedTokenizerBase
from vllm.benchmarks.datasets import SampleRequest
from vllm.benchmarks.lib.endpoint_request_func import RequestFuncOutput
from vllm.benchmarks.serve import MILLISECONDS_TO_SECONDS_CONVERSION, TERM_PLOTLIB_AVAILABLE, BenchmarkMetrics, TaskType
@dataclass
class MultiModalsBenchmarkMetrics(BenchmarkMetrics):
mean_audio_ttfp_ms: float = 0.0
median_audio_ttfp_ms: float = 0.0
std_audio_ttfp_ms: float = 0.0
    percentiles_audio_ttfp_ms: list[tuple[float, float]] | None = None
    # Note: despite the `_ms` suffix, this field accumulates per-request
    # `audio_duration` values, which are reported in seconds.
    total_audio_duration_ms: float = 0.0
total_audio_frames: int = 0
audio_throughput: float = 0.0
mean_audio_rtf: float = 0.0
median_audio_rtf: float = 0.0
std_audio_rtf: float = 0.0
    percentiles_audio_rtf: list[tuple[float, float]] | None = None
def print_metrics(
task_type,
selected_percentile_metrics,
max_concurrency,
request_rate,
benchmark_duration,
goodput_config_dict,
metrics: MultiModalsBenchmarkMetrics,
):
print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10}".format("Failed requests:", metrics.failed))
if max_concurrency is not None:
print("{:<40} {:<10}".format("Maximum request concurrency:", max_concurrency))
if request_rate != float("inf"):
print("{:<40} {:<10.2f}".format("Request rate configured (RPS):", request_rate))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput))
if goodput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):", metrics.request_goodput))
if isinstance(metrics, MultiModalsBenchmarkMetrics):
print("{:<40} {:<10.2f}".format("Peak concurrent requests:", metrics.max_concurrent_requests))
if task_type != TaskType.GENERATION or "e2el" in selected_percentile_metrics:
process_one_metric("e2el", metrics)
print_text_metrics(task_type, selected_percentile_metrics, metrics)
if task_type == TaskType.GENERATION:
print_audio_metrics(selected_percentile_metrics, metrics)
print("=" * 50)
def print_text_metrics(task_type, selected_percentile_metrics, metrics: MultiModalsBenchmarkMetrics):
print("{s:{c}^{n}}".format(s=" Text Result ", n=50, c="="))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
if isinstance(metrics, MultiModalsBenchmarkMetrics):
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput))
print("{:<40} {:<10.2f}".format("Peak output token throughput (tok/s):", metrics.max_output_tokens_per_s))
print("{:<40} {:<10.2f}".format("Peak concurrent requests:", metrics.max_concurrent_requests))
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", metrics.total_token_throughput))
if task_type == TaskType.GENERATION:
for metric in selected_percentile_metrics:
if metric == "e2el":
continue
if not metric.startswith("audio"):
process_one_metric(metric, metrics)
def print_audio_metrics(selected_percentile_metrics, metrics: MultiModalsBenchmarkMetrics):
print("{s:{c}^{n}}".format(s=" Audio Result ", n=50, c="="))
    print("{:<40} {:<10.2f}".format("Total audio duration generated (s):", metrics.total_audio_duration_ms))
    print("{:<40} {:<10}".format("Total audio frames generated:", metrics.total_audio_frames))
    print("{:<40} {:<10.2f}".format("Audio throughput (audio duration/s):", metrics.audio_throughput))
for metric in selected_percentile_metrics:
if metric.startswith("audio"):
process_one_metric(metric, metrics)
def process_one_metric(
metric_attribute_name: str,
metrics: MultiModalsBenchmarkMetrics,
):
metric_header_map = {
"ttft": "Time to First Token",
"tpot": "Time per Output Token (excl. 1st token)",
"itl": "Inter-token Latency",
"e2el": "End-to-end Latency",
"audio_ttfp": "Time to First Packet",
"audio_rtf": "Real Time Factor",
}
header = metric_header_map.get(metric_attribute_name, metric_attribute_name)
print("{s:{c}^{n}}".format(s=header, n=50, c="-"))
is_audio_rtf = metric_attribute_name == "audio_rtf"
suffix = "" if is_audio_rtf else "_ms"
unit_suffix = "" if is_audio_rtf else " (ms)"
mean_attr_name = f"mean_{metric_attribute_name}{suffix}"
mean_value = getattr(metrics, mean_attr_name, 0.0)
print(f"{f'Mean {metric_attribute_name.upper()}{unit_suffix}:':<40} {mean_value:<10.2f}")
median_attr_name = f"median_{metric_attribute_name}{suffix}"
median_value = getattr(metrics, median_attr_name, 0.0)
print(f"{f'Median {metric_attribute_name.upper()}{unit_suffix}:':<40} {median_value:<10.2f}")
percentiles_attr_name = f"percentiles_{metric_attribute_name}{suffix}"
percentiles = getattr(metrics, percentiles_attr_name, [])
for percentile, value in percentiles:
p_str = str(int(percentile)) if percentile.is_integer() else str(percentile)
label = f"P{p_str} {metric_attribute_name.upper()}{unit_suffix}:"
print(f"{label:<40} {value:<10.2f}")
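# Attribute-naming convention relied on by process_one_metric (illustrative):
# latency-style metrics carry a `_ms` suffix on the dataclass fields, while
# `audio_rtf` is a unitless ratio and is stored without a suffix.
#
#   "ttft"      -> mean_ttft_ms, median_ttft_ms, percentiles_ttft_ms
#   "audio_rtf" -> mean_audio_rtf, median_audio_rtf, percentiles_audio_rtf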
def calculate_metrics(
input_requests: list[SampleRequest],
outputs: list[RequestFuncOutput],
dur_s: float,
tokenizer: PreTrainedTokenizerBase,
selected_percentiles: list[float],
goodput_config_dict: dict[str, float],
task_type,
selected_percentile_metrics,
max_concurrency,
request_rate,
benchmark_duration,
) -> tuple[BenchmarkMetrics, list[int]]:
"""Calculate the metrics for the benchmark.
Args:
input_requests: The input requests.
outputs: The outputs of the requests.
dur_s: The duration of the benchmark.
tokenizer: The tokenizer to use.
selected_percentiles: The percentiles to select.
goodput_config_dict: The goodput configuration.
Returns:
A tuple of the benchmark metrics and the actual output lengths.
"""
actual_output_lens: list[int] = []
total_input = 0
completed = 0
good_completed = 0
itls: list[float] = []
tpots: list[float] = []
all_tpots: list[float] = []
ttfts: list[float] = []
e2els: list[float] = []
audio_ttfps: list[float] = []
audio_rtfs: list[float] = []
audio_duration: list[float] = []
audio_frames: list[int] = []
for i in range(len(outputs)):
if outputs[i].success:
output_len = outputs[i].output_tokens
if not output_len:
# We use the tokenizer to count the number of output tokens
# for some serving backends instead of looking at
# len(outputs[i].itl) since multiple output tokens may be
# bundled together
# Note : this may inflate the output token count slightly
output_len = len(tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids)
actual_output_lens.append(output_len)
total_input += input_requests[i].prompt_len
tpot = 0
if output_len > 1:
latency_minus_ttft = outputs[i].latency - outputs[i].ttft
tpot = latency_minus_ttft / (output_len - 1)
tpots.append(tpot)
# Note: if output_len <= 1, we regard tpot as 0 for goodput
all_tpots.append(tpot)
itls += outputs[i].itl
ttfts.append(outputs[i].ttft)
audio_ttfps.append(getattr(outputs[i], "audio_ttfp", 0.0))
audio_rtfs.append(getattr(outputs[i], "audio_rtf", 0.0))
audio_duration.append(getattr(outputs[i], "audio_duration", 0.0))
            audio_frames.append(getattr(outputs[i], "audio_frames", 0))
e2els.append(outputs[i].latency)
completed += 1
else:
actual_output_lens.append(0)
if goodput_config_dict:
valid_metrics = []
slo_values = []
if "ttft" in goodput_config_dict:
valid_metrics.append(ttfts)
slo_values.append(goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION)
if "audio_ttft" in goodput_config_dict:
valid_metrics.append(audio_ttfps)
slo_values.append(goodput_config_dict["audio_ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION)
if "tpot" in goodput_config_dict:
valid_metrics.append(all_tpots)
slo_values.append(goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION)
if "e2el" in goodput_config_dict:
valid_metrics.append(e2els)
slo_values.append(goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION)
for req_metric in zip(*valid_metrics):
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
if is_good_req:
good_completed += 1
if completed == 0:
warnings.warn(
"All requests failed. This is likely due to a misconfiguration on the benchmark arguments.",
stacklevel=2,
)
# Calculate max output tokens per second metric
max_output_tokens_per_s = 0.0
max_concurrent_requests = 0
# Find the time range across all successful requests
successful_outputs = [output for output in outputs if output.success]
failed_outputs = [output for output in outputs if not output.success]
if successful_outputs:
min_start_time = min(output.start_time for output in successful_outputs)
max_end_time = max(output.start_time + output.latency for output in successful_outputs)
# Create second buckets (ceiling to ensure we capture all time)
duration_seconds = int(np.ceil(max_end_time - min_start_time)) + 1
tokens_per_second = np.zeros(duration_seconds)
concurrent_requests_per_second = np.zeros(duration_seconds)
for i, output in enumerate(successful_outputs):
# Calculate token generation timestamp using
# start_time, ttft, and itl
token_times = [output.start_time + output.ttft]
current_time = token_times[0]
for itl_value in output.itl:
current_time += itl_value
token_times.append(current_time)
# Add tokens to second buckets
for token_time in token_times:
second_bucket = int(token_time - min_start_time)
if 0 <= second_bucket < duration_seconds:
tokens_per_second[second_bucket] += 1
# Track concurrent requests for each second this request was active
request_start_second = int(output.start_time - min_start_time)
request_end_second = int((output.start_time + output.latency) - min_start_time)
for second in range(request_start_second, request_end_second + 1):
concurrent_requests_per_second[second] += 1
# Find the maximum tokens per second and corresponding
# concurrent requests
if len(tokens_per_second) > 0:
max_output_tokens_per_s = float(np.max(tokens_per_second))
max_concurrent_requests = int(np.max(concurrent_requests_per_second))
if TERM_PLOTLIB_AVAILABLE:
import termplotlib as tpl
fig = tpl.figure()
fig.plot(
np.arange(len(tokens_per_second)),
tokens_per_second,
title="Output tokens per second",
)
fig.plot(
np.arange(len(concurrent_requests_per_second)),
concurrent_requests_per_second,
title="Concurrent requests per second",
)
fig.show()
else:
print("tip: install termplotlib and gnuplot to plot the metrics")
metrics = MultiModalsBenchmarkMetrics(
completed=completed,
failed=len(failed_outputs),
total_input=total_input,
total_output=sum(actual_output_lens),
request_throughput=completed / dur_s,
request_goodput=good_completed / dur_s,
output_throughput=sum(actual_output_lens) / dur_s,
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
mean_ttft_ms=np.mean(ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by the endpoint
std_ttft_ms=np.std(ttfts or 0) * 1000,
median_ttft_ms=np.median(ttfts or 0) * 1000,
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles],
mean_audio_ttfp_ms=np.mean(audio_ttfps or 0) * 1000,
std_audio_ttfp_ms=np.std(audio_ttfps or 0) * 1000,
median_audio_ttfp_ms=np.median(audio_ttfps or 0) * 1000,
percentiles_audio_ttfp_ms=[(p, np.percentile(audio_ttfps or 0, p) * 1000) for p in selected_percentiles],
total_audio_duration_ms=sum(audio_duration),
total_audio_frames=sum(audio_frames),
audio_throughput=sum(audio_duration) / dur_s,
mean_audio_rtf=np.mean(audio_rtfs or 0),
std_audio_rtf=np.std(audio_rtfs or 0),
median_audio_rtf=np.median(audio_rtfs or 0),
percentiles_audio_rtf=[(p, np.percentile(audio_rtfs or 0, p)) for p in selected_percentiles],
mean_tpot_ms=np.mean(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000,
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles],
mean_itl_ms=np.mean(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles],
mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles],
max_output_tokens_per_s=max_output_tokens_per_s,
max_concurrent_requests=max_concurrent_requests,
)
print_metrics(
task_type,
selected_percentile_metrics,
max_concurrency,
request_rate,
benchmark_duration,
goodput_config_dict,
metrics,
)
return metrics, actual_output_lens
import asyncio
import base64
import contextlib
import io
import json
import os
import random
import sys
import time
import traceback
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Literal
import aiohttp
from pydub import AudioSegment
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
from vllm.benchmarks import datasets
from vllm.benchmarks.datasets import SampleRequest
from vllm.benchmarks.lib.endpoint_request_func import (
ASYNC_REQUEST_FUNCS,
OPENAI_COMPATIBLE_BACKENDS,
RequestFuncInput,
RequestFuncOutput,
StreamedResponseHandler,
_get_chat_content,
_update_headers_common,
_update_payload_common,
_validate_api_url,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
from vllm_omni.benchmarks.data_modules.random_multi_modal_dataset import OmniRandomMultiModalDataset
get_samples_old = datasets.get_samples
def get_samples(args, tokenizer):
if args.backend not in ["openai-chat-omni"]:
        raise ValueError("This benchmark only supports the 'openai-chat-omni' backend.")
if args.dataset_name == "random-mm":
dataset = OmniRandomMultiModalDataset(random_seed=args.seed, dataset_path=args.dataset_path)
input_requests = dataset.sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
prefix_len=args.random_prefix_len,
range_ratio=args.random_range_ratio,
input_len=args.random_input_len,
output_len=args.random_output_len,
base_items_per_request=args.random_mm_base_items_per_request,
limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt,
num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio,
bucket_config=args.random_mm_bucket_config,
request_id_prefix=args.request_id_prefix,
no_oversample=args.no_oversample,
)
return input_requests
else:
return get_samples_old(args, tokenizer)
datasets.get_samples = get_samples
@dataclass
class MixRequestFuncOutput(RequestFuncOutput):
audio_ttfp: float = 0.0
audio_duration: float = 0.0
audio_frames: int = 0
audio_rtf: float = 0.0
async def async_request_openai_chat_omni_completions(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
mm_position: Literal["first", "last"] = "last",
) -> MixRequestFuncOutput:
api_url = request_func_input.api_url
_validate_api_url(api_url, "OpenAI Chat Completions API", "chat/completions")
content = _get_chat_content(request_func_input, mm_position=mm_position)
payload = {
"model": request_func_input.model_name if request_func_input.model_name else request_func_input.model,
"messages": [
{"role": "user", "content": content},
],
"temperature": 0.0,
"max_tokens": request_func_input.output_len,
"stream": True,
"stream_options": {
"include_usage": True,
},
}
_update_payload_common(payload, request_func_input)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
_update_headers_common(headers, request_func_input)
output = MixRequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
generated_text = ""
generated_audio = None
ttft = 0.0
st = time.perf_counter()
output.start_time = st
most_recent_timestamp = st
audio_generate_time = 0.0
audio_first_timestamp = st
try:
async with session.post(url=api_url, json=payload, headers=headers) as response:
if response.status == 200:
handler = StreamedResponseHandler()
async for chunk_bytes in response.content.iter_any():
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
messages = handler.add_chunk(chunk_bytes)
for message in messages:
                        # NOTE: SSE comments (often used as pings) start with
                        # a colon. These are not JSON data payloads and should
                        # be skipped.
if message.startswith(":"):
continue
chunk = message.removeprefix("data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)
if choices := data.get("choices"):
modality = data.get("modality")
content = choices[0]["delta"].get("content")
if modality == "text":
# First token
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft
else:
output.itl.append(timestamp - most_recent_timestamp)
generated_text += content or ""
elif modality == "audio":
if output.audio_ttfp == 0.0:
audio_first_timestamp = timestamp
output.audio_ttfp = timestamp - st
audio_generate_time = timestamp - audio_first_timestamp
                                    if content:
audio_bytes = base64.b64decode(content)
seg = AudioSegment.from_file(io.BytesIO(audio_bytes))
if seg is not None:
if generated_audio is None:
generated_audio = seg
else:
generated_audio = generated_audio + seg
elif usage := data.get("usage"):
output.output_tokens = usage.get("completion_tokens")
most_recent_timestamp = timestamp
output.generated_text = generated_text
if generated_audio is not None:
output.audio_duration = len(generated_audio) / 1000.0
frame_width = generated_audio.frame_width
if frame_width > 0:
output.audio_frames = len(generated_audio.raw_data) // frame_width
else:
output.audio_frames = 0
logger.warning("Audio frame width is zero")
audio_duration = output.audio_duration
if audio_duration > 0:
output.audio_rtf = audio_generate_time / output.audio_duration
else:
output.audio_rtf = 0
logger.warning("Audio duration is zero")
output.success = True
output.latency = most_recent_timestamp - st
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
        logger.error(f"Request failed: {output.error}")
if pbar:
pbar.update(1)
return output
ASYNC_REQUEST_FUNCS["openai-chat-omni"] = async_request_openai_chat_omni_completions
if "openai-chat-omni" not in OPENAI_COMPATIBLE_BACKENDS:
OPENAI_COMPATIBLE_BACKENDS.append("openai-chat-omni")
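# Sketch of the streamed SSE chunk shapes the handler above consumes (field
# values are illustrative; only the keys the parser actually reads are shown):
#
#   data: {"choices": [{"delta": {"content": "Hello"}}], "modality": "text"}
#   data: {"choices": [{"delta": {"content": "<base64 audio>"}}], "modality": "audio"}
#   data: {"usage": {"completion_tokens": 42}}
#   data: [DONE]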
# ruff: noqa: E402
# Prevent import order from causing patch failures
from vllm.benchmarks import serve
from vllm.benchmarks.serve import TaskType, calculate_metrics_for_embeddings, get_request, wait_for_endpoint
from vllm_omni.benchmarks.metrics.metrics import MultiModalsBenchmarkMetrics, calculate_metrics
# ruff: noqa: E402
benchmark_old = serve.benchmark
async def benchmark(
task_type: TaskType,
endpoint_type: str,
api_url: str,
base_url: str,
model_id: str,
model_name: str,
tokenizer: PreTrainedTokenizerBase,
input_requests: list[SampleRequest],
logprobs: int | None,
request_rate: float,
burstiness: float,
disable_tqdm: bool,
num_warmups: int,
profile: bool,
selected_percentile_metrics: list[str],
selected_percentiles: list[float],
ignore_eos: bool,
goodput_config_dict: dict[str, float],
max_concurrency: int | None,
lora_modules: Iterable[str] | None,
extra_headers: dict | None,
extra_body: dict | None,
ramp_up_strategy: Literal["linear", "exponential"] | None = None,
ramp_up_start_rps: int | None = None,
ramp_up_end_rps: int | None = None,
ready_check_timeout_sec: int = 600,
):
try:
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
except KeyError:
raise ValueError(f"Unknown backend: {endpoint_type}") from None
# Reuses connections across requests to reduce TLS handshake overhead.
connector = aiohttp.TCPConnector(
limit=max_concurrency or 0,
limit_per_host=max_concurrency or 0,
ttl_dns_cache=300,
use_dns_cache=True,
keepalive_timeout=60,
enable_cleanup_closed=True,
force_close=False,
ssl=("https://" in api_url),
)
session = aiohttp.ClientSession(
connector=connector,
trust_env=True,
timeout=aiohttp.ClientTimeout(total=6 * 60 * 60),
)
print("Starting initial single prompt test run...")
test_prompt, test_prompt_len, test_output_len, test_mm_content = (
input_requests[0].prompt,
input_requests[0].prompt_len,
input_requests[0].expected_output_len,
input_requests[0].multi_modal_data,
)
assert (
test_mm_content is None
or isinstance(test_mm_content, dict)
or (isinstance(test_mm_content, list) and all(isinstance(item, dict) for item in test_mm_content))
), "multi_modal_data must be a dict or list[dict]"
test_input = RequestFuncInput(
model=model_id,
model_name=model_name,
prompt=test_prompt,
api_url=api_url,
prompt_len=test_prompt_len,
output_len=test_output_len,
logprobs=logprobs,
multi_modal_content=test_mm_content,
ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body,
)
if ready_check_timeout_sec > 0:
test_output = await wait_for_endpoint(
request_func,
test_input,
session,
timeout_seconds=ready_check_timeout_sec,
)
if not test_output.success:
raise ValueError(
"Initial test run failed - Please make sure benchmark "
"arguments are correctly specified. "
f"Error: {test_output.error}"
)
else:
print("Initial test run completed.")
else:
print("Skipping endpoint ready check.")
if num_warmups > 0:
print(f"Warming up with {num_warmups} requests...")
warmup_pbar = None if disable_tqdm else tqdm(total=num_warmups)
warmup_semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else contextlib.nullcontext()
warmup_tasks = []
async def warmup_limited_request_func():
async with warmup_semaphore:
return await request_func(request_func_input=test_input, session=session, pbar=warmup_pbar)
for _ in range(num_warmups):
request_task = asyncio.create_task(warmup_limited_request_func())
warmup_tasks.append(request_task)
_ = await asyncio.gather(*warmup_tasks)
if warmup_pbar is not None:
warmup_pbar.close()
print("Warmup run completed.")
print("Starting main benchmark run...")
if lora_modules:
# For each input request, choose a LoRA module at random.
lora_modules = iter([random.choice(lora_modules) for _ in range(len(input_requests))])
if profile:
print("Starting profiler...")
profile_input = RequestFuncInput(
model=model_id,
model_name=model_name,
prompt=test_prompt,
api_url=base_url + "/start_profile",
prompt_len=test_prompt_len,
output_len=test_output_len,
logprobs=logprobs,
multi_modal_content=test_mm_content,
ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body,
)
profile_output = await request_func(request_func_input=profile_input, session=session)
if profile_output.success:
print("Profiler started")
distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
if ramp_up_strategy is not None:
print(f"Traffic ramp-up strategy: {ramp_up_strategy}.")
print(
f"Will increase RPS from {ramp_up_start_rps} to {ramp_up_end_rps} RPS over the duration of the benchmark."
)
else:
print(f"Traffic request rate: {request_rate}")
print(f"Burstiness factor: {burstiness} ({distribution})")
print(f"Maximum request concurrency: {max_concurrency}")
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else contextlib.nullcontext()
async def limited_request_func(request_func_input, session, pbar):
async with semaphore:
return await request_func(request_func_input=request_func_input, session=session, pbar=pbar)
benchmark_start_time = time.perf_counter()
tasks: list[asyncio.Task] = []
rps_change_events = []
last_int_rps = -1
if ramp_up_strategy is not None and ramp_up_start_rps is not None:
last_int_rps = ramp_up_start_rps
rps_change_events.append(
{
"rps": last_int_rps,
"timestamp": datetime.now().isoformat(),
}
)
async for request, current_request_rate in get_request(
input_requests,
request_rate,
burstiness,
ramp_up_strategy,
ramp_up_start_rps,
ramp_up_end_rps,
):
if ramp_up_strategy is not None:
current_int_rps = int(current_request_rate)
if current_int_rps > last_int_rps:
timestamp = datetime.now().isoformat()
for rps_val in range(last_int_rps + 1, current_int_rps + 1):
rps_change_events.append({"rps": rps_val, "timestamp": timestamp})
last_int_rps = current_int_rps
prompt, prompt_len, output_len, mm_content, request_id = (
request.prompt,
request.prompt_len,
request.expected_output_len,
request.multi_modal_data,
request.request_id,
)
req_model_id, req_model_name = model_id, model_name
if lora_modules:
req_lora_module = next(lora_modules)
req_model_id, req_model_name = req_lora_module, req_lora_module
request_func_input = RequestFuncInput(
model=req_model_id,
model_name=req_model_name,
prompt=prompt,
api_url=api_url,
prompt_len=prompt_len,
output_len=output_len,
logprobs=logprobs,
multi_modal_content=mm_content,
ignore_eos=ignore_eos,
extra_headers=extra_headers,
extra_body=extra_body,
request_id=request_id,
)
tasks.append(
asyncio.create_task(limited_request_func(request_func_input=request_func_input, session=session, pbar=pbar))
)
outputs: list[MixRequestFuncOutput] = await asyncio.gather(*tasks)
if pbar is not None:
pbar.close()
benchmark_duration = time.perf_counter() - benchmark_start_time
if task_type == TaskType.GENERATION:
metrics, actual_output_lens = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
dur_s=benchmark_duration,
tokenizer=tokenizer,
selected_percentiles=selected_percentiles,
goodput_config_dict=goodput_config_dict,
task_type=task_type,
selected_percentile_metrics=selected_percentile_metrics,
max_concurrency=max_concurrency,
request_rate=request_rate,
benchmark_duration=benchmark_duration,
)
else:
metrics = calculate_metrics_for_embeddings(
outputs=outputs,
dur_s=benchmark_duration,
selected_percentiles=selected_percentiles,
)
actual_output_lens = 0
if isinstance(metrics, MultiModalsBenchmarkMetrics):
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"failed": metrics.failed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"request_goodput": metrics.request_goodput if goodput_config_dict else None,
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
"output_lens": actual_output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"generated_texts": [output.generated_text for output in outputs],
"errors": [output.error for output in outputs],
"max_output_tokens_per_s": metrics.max_output_tokens_per_s,
"max_concurrent_requests": metrics.max_concurrent_requests,
}
else:
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"request_throughput": metrics.request_throughput,
"total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs],
"errors": [output.error for output in outputs],
}
if rps_change_events:
result["rps_change_events"] = rps_change_events
def process_one_metric(
# E.g., "ttft"
metric_attribute_name: str,
):
        # This function adds percentile statistics of the specified
        # metric to the result dict.
if metric_attribute_name not in selected_percentile_metrics:
return
is_audio_rtf = metric_attribute_name == "audio_rtf"
suffix = "" if is_audio_rtf else "_ms"
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}{suffix}"):
p_word = str(int(p)) if int(p) == p else str(p)
result[f"p{p_word}_{metric_attribute_name}{suffix}"] = value
if task_type == TaskType.GENERATION:
for metric in selected_percentile_metrics:
process_one_metric(metric)
else:
process_one_metric("e2el")
if profile:
print("Stopping profiler...")
profile_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
api_url=base_url + "/stop_profile",
prompt_len=test_prompt_len,
output_len=test_output_len,
logprobs=logprobs,
)
profile_output = await request_func(request_func_input=profile_input, session=session)
if profile_output.success:
print("Profiler stopped")
await session.close()
return result
serve.benchmark = benchmark
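# Note: after this module is imported, callers that look up
# vllm.benchmarks.serve.benchmark as a module attribute resolve to the
# omni-aware implementation above; code that bound the original `benchmark`
# symbol before this patch ran keeps the unpatched version.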