import asyncio
import base64
import contextlib
import io
import json
import os
import random
import sys
import time
import traceback
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Literal

import aiohttp
from pydub import AudioSegment
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase

from vllm.benchmarks import datasets
from vllm.benchmarks.datasets import SampleRequest
from vllm.benchmarks.lib.endpoint_request_func import (
    ASYNC_REQUEST_FUNCS,
    OPENAI_COMPATIBLE_BACKENDS,
    RequestFuncInput,
    RequestFuncOutput,
    StreamedResponseHandler,
    _get_chat_content,
    _update_headers_common,
    _update_payload_common,
    _validate_api_url,
)
from vllm.logger import init_logger

logger = init_logger(__name__)

from vllm_omni.benchmarks.data_modules.random_multi_modal_dataset import (
    OmniRandomMultiModalDataset,
)

# Keep a reference to the original implementation so that every dataset other
# than "random-mm" falls through to it unchanged.
get_samples_old = datasets.get_samples


def get_samples(args, tokenizer):
    if args.backend != "openai-chat-omni":
        raise ValueError(
            "This benchmark is only supported on the 'openai-chat-omni' "
            f"backend, got {args.backend!r}."
        )
    if args.dataset_name == "random-mm":
        dataset = OmniRandomMultiModalDataset(
            random_seed=args.seed, dataset_path=args.dataset_path
        )
        input_requests = dataset.sample(
            tokenizer=tokenizer,
            num_requests=args.num_prompts,
            prefix_len=args.random_prefix_len,
            range_ratio=args.random_range_ratio,
            input_len=args.random_input_len,
            output_len=args.random_output_len,
            base_items_per_request=args.random_mm_base_items_per_request,
            limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt,
            num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio,
            bucket_config=args.random_mm_bucket_config,
            request_id_prefix=args.request_id_prefix,
            no_oversample=args.no_oversample,
        )
        return input_requests
    else:
        return get_samples_old(args, tokenizer)


# Patch vLLM's sample loading so the omni random multi-modal dataset is
# picked up transparently.
datasets.get_samples = get_samples
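
# A minimal sketch of how the patched hook is driven; `args` stands in for the
# argparse namespace that `vllm bench serve` builds, and the values below are
# illustrative assumptions rather than defaults:
#
#   from argparse import Namespace
#   args = Namespace(
#       backend="openai-chat-omni",
#       dataset_name="random-mm",
#       dataset_path=None,
#       seed=0,
#       num_prompts=16,
#       ...,  # plus the random_* / random_mm_* fields read above
#   )
#   requests = datasets.get_samples(args, tokenizer)  # -> list[SampleRequest]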


@dataclass
class MixRequestFuncOutput(RequestFuncOutput):
    """RequestFuncOutput extended with streamed-audio metrics."""

    # Time from request start to the first audio packet.
    audio_ttfp: float = 0.0
    # Duration of the decoded audio, in seconds.
    audio_duration: float = 0.0
    # Number of PCM frames in the decoded audio.
    audio_frames: int = 0
    # Real-time factor: audio generation time / audio duration.
    audio_rtf: float = 0.0


async def async_request_openai_chat_omni_completions(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: tqdm | None = None,
    mm_position: Literal["first", "last"] = "last",
) -> MixRequestFuncOutput:
    api_url = request_func_input.api_url
    _validate_api_url(api_url, "OpenAI Chat Completions API", "chat/completions")

    content = _get_chat_content(request_func_input, mm_position=mm_position)
    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "messages": [
            {"role": "user", "content": content},
        ],
        "temperature": 0.0,
        "max_tokens": request_func_input.output_len,
        "stream": True,
        "stream_options": {
            "include_usage": True,
        },
    }
    _update_payload_common(payload, request_func_input)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    _update_headers_common(headers, request_func_input)

    output = MixRequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    generated_text = ""
    generated_audio = None
    ttft = 0.0
    st = time.perf_counter()
    output.start_time = st
    most_recent_timestamp = st
    audio_generate_time = 0.0
    audio_first_timestamp = st
    try:
        async with session.post(url=api_url, json=payload, headers=headers) as response:
            if response.status == 200:
                handler = StreamedResponseHandler()
                async for chunk_bytes in response.content.iter_any():
                    chunk_bytes = chunk_bytes.strip()
                    if not chunk_bytes:
                        continue

                    messages = handler.add_chunk(chunk_bytes)
                    for message in messages:
                        # NOTE: SSE comments (often used as pings) start with
                        # a colon. These are not JSON data payloads and
                        # should be skipped.
                        if message.startswith(":"):
                            continue

                        chunk = message.removeprefix("data: ")
                        if chunk != "[DONE]":
                            timestamp = time.perf_counter()
                            data = json.loads(chunk)
                            if choices := data.get("choices"):
                                modality = data.get("modality")
                                content = choices[0]["delta"].get("content")
                                if modality == "text":
                                    # First token
                                    if ttft == 0.0:
                                        ttft = timestamp - st
                                        output.ttft = ttft
                                    # Decoding phase
                                    else:
                                        output.itl.append(
                                            timestamp - most_recent_timestamp
                                        )
                                    generated_text += content or ""
                                elif modality == "audio":
                                    if output.audio_ttfp == 0.0:
                                        audio_first_timestamp = timestamp
                                        output.audio_ttfp = timestamp - st
                                    audio_generate_time = (
                                        timestamp - audio_first_timestamp
                                    )
                                    # Guard against empty or missing deltas
                                    # before base64-decoding.
                                    if content:
                                        audio_bytes = base64.b64decode(content)
                                        seg = AudioSegment.from_file(
                                            io.BytesIO(audio_bytes)
                                        )
                                        if seg is not None:
                                            if generated_audio is None:
                                                generated_audio = seg
                                            else:
                                                generated_audio = generated_audio + seg
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get("completion_tokens")

                            most_recent_timestamp = timestamp

                output.generated_text = generated_text
                if generated_audio is not None:
                    # pydub reports length in milliseconds.
                    output.audio_duration = len(generated_audio) / 1000.0
                    frame_width = generated_audio.frame_width
                    if frame_width > 0:
                        output.audio_frames = (
                            len(generated_audio.raw_data) // frame_width
                        )
                    else:
                        output.audio_frames = 0
                        logger.warning("Audio frame width is zero")
                    audio_duration = output.audio_duration
                    if audio_duration > 0:
                        output.audio_rtf = audio_generate_time / output.audio_duration
                    else:
                        output.audio_rtf = 0.0
                        logger.warning("Audio duration is zero")

                output.success = True
                output.latency = most_recent_timestamp - st
            else:
                output.error = response.reason or ""
                output.success = False
    except Exception:
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))
        logger.error("Request failed: %s", output.error)

    if pbar:
        pbar.update(1)
    return output
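
# For reference, a sketch of the SSE stream the parser above expects. The
# top-level "modality" field is inferred from the branching logic, not from a
# published spec, so treat the exact payload shapes as assumptions:
#
#   : ping                                                        <- skipped
#   data: {"choices": [{"delta": {"content": "Hello"}}], "modality": "text"}
#   data: {"choices": [{"delta": {"content": "<base64>"}}], "modality": "audio"}
#   data: {"usage": {"completion_tokens": 42}}
#   data: [DONE]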


ASYNC_REQUEST_FUNCS["openai-chat-omni"] = async_request_openai_chat_omni_completions
if "openai-chat-omni" not in OPENAI_COMPATIBLE_BACKENDS:
    OPENAI_COMPATIBLE_BACKENDS.append("openai-chat-omni")

# ruff: noqa: E402
# Prevent import order from causing patch failures.
from vllm.benchmarks import serve
from vllm.benchmarks.serve import (
    TaskType,
    calculate_metrics_for_embeddings,
    get_request,
    wait_for_endpoint,
)

from vllm_omni.benchmarks.metrics.metrics import (
    MultiModalsBenchmarkMetrics,
    calculate_metrics,
)

benchmark_old = serve.benchmark


async def benchmark(
    task_type: TaskType,
    endpoint_type: str,
    api_url: str,
    base_url: str,
    model_id: str,
    model_name: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: list[SampleRequest],
    logprobs: int | None,
    request_rate: float,
    burstiness: float,
    disable_tqdm: bool,
    num_warmups: int,
    profile: bool,
    selected_percentile_metrics: list[str],
    selected_percentiles: list[float],
    ignore_eos: bool,
    goodput_config_dict: dict[str, float],
    max_concurrency: int | None,
    lora_modules: Iterable[str] | None,
    extra_headers: dict | None,
    extra_body: dict | None,
    ramp_up_strategy: Literal["linear", "exponential"] | None = None,
    ramp_up_start_rps: int | None = None,
    ramp_up_end_rps: int | None = None,
    ready_check_timeout_sec: int = 600,
):
    try:
        request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
    except KeyError:
        raise ValueError(f"Unknown backend: {endpoint_type}") from None

    # Reuse connections across requests to reduce TLS handshake overhead.
    connector = aiohttp.TCPConnector(
        limit=max_concurrency or 0,
        limit_per_host=max_concurrency or 0,
        ttl_dns_cache=300,
        use_dns_cache=True,
        keepalive_timeout=60,
        enable_cleanup_closed=True,
        force_close=False,
        ssl=api_url.startswith("https://"),
    )

    session = aiohttp.ClientSession(
        connector=connector,
        trust_env=True,
        timeout=aiohttp.ClientTimeout(total=6 * 60 * 60),
    )

    print("Starting initial single prompt test run...")
    test_prompt, test_prompt_len, test_output_len, test_mm_content = (
        input_requests[0].prompt,
        input_requests[0].prompt_len,
        input_requests[0].expected_output_len,
        input_requests[0].multi_modal_data,
    )
    assert (
        test_mm_content is None
        or isinstance(test_mm_content, dict)
        or (
            isinstance(test_mm_content, list)
            and all(isinstance(item, dict) for item in test_mm_content)
        )
    ), "multi_modal_data must be a dict or list[dict]"
    test_input = RequestFuncInput(
        model=model_id,
        model_name=model_name,
        prompt=test_prompt,
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=test_output_len,
        logprobs=logprobs,
        multi_modal_content=test_mm_content,
        ignore_eos=ignore_eos,
        extra_headers=extra_headers,
        extra_body=extra_body,
    )

    if ready_check_timeout_sec > 0:
        test_output = await wait_for_endpoint(
            request_func,
            test_input,
            session,
            timeout_seconds=ready_check_timeout_sec,
        )
        if not test_output.success:
            raise ValueError(
                "Initial test run failed - Please make sure benchmark "
                "arguments are correctly specified. "
                f"Error: {test_output.error}"
            )
        else:
            print("Initial test run completed.")
    else:
        print("Skipping endpoint ready check.")

    if num_warmups > 0:
        print(f"Warming up with {num_warmups} requests...")
        warmup_pbar = None if disable_tqdm else tqdm(total=num_warmups)
        # nullcontext supports `async with` on Python 3.10+, so it can stand
        # in for the semaphore when concurrency is unlimited.
        warmup_semaphore = (
            asyncio.Semaphore(max_concurrency)
            if max_concurrency
            else contextlib.nullcontext()
        )
        warmup_tasks = []

        async def warmup_limited_request_func():
            async with warmup_semaphore:
                return await request_func(
                    request_func_input=test_input, session=session, pbar=warmup_pbar
                )

        for _ in range(num_warmups):
            request_task = asyncio.create_task(warmup_limited_request_func())
            warmup_tasks.append(request_task)
        _ = await asyncio.gather(*warmup_tasks)

        if warmup_pbar is not None:
            warmup_pbar.close()
        print("Warmup run completed.")

    print("Starting main benchmark run...")

    if lora_modules:
        # For each input request, choose a LoRA module at random.
        lora_modules = iter(
            [random.choice(lora_modules) for _ in range(len(input_requests))]
        )

    if profile:
        print("Starting profiler...")
        profile_input = RequestFuncInput(
            model=model_id,
            model_name=model_name,
            prompt=test_prompt,
            api_url=base_url + "/start_profile",
            prompt_len=test_prompt_len,
            output_len=test_output_len,
            logprobs=logprobs,
            multi_modal_content=test_mm_content,
            ignore_eos=ignore_eos,
            extra_headers=extra_headers,
            extra_body=extra_body,
        )
        profile_output = await request_func(
            request_func_input=profile_input, session=session
        )
        if profile_output.success:
            print("Profiler started")

    distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
    if ramp_up_strategy is not None:
        print(f"Traffic ramp-up strategy: {ramp_up_strategy}.")
        print(
            f"Will increase RPS from {ramp_up_start_rps} to {ramp_up_end_rps} "
            "RPS over the duration of the benchmark."
        )
    else:
        print(f"Traffic request rate: {request_rate}")
        print(f"Burstiness factor: {burstiness} ({distribution})")
    print(f"Maximum request concurrency: {max_concurrency}")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))

    semaphore = (
        asyncio.Semaphore(max_concurrency)
        if max_concurrency
        else contextlib.nullcontext()
    )

    async def limited_request_func(request_func_input, session, pbar):
        async with semaphore:
            return await request_func(
                request_func_input=request_func_input, session=session, pbar=pbar
            )

    benchmark_start_time = time.perf_counter()
    tasks: list[asyncio.Task] = []

    rps_change_events = []
    last_int_rps = -1
    if ramp_up_strategy is not None and ramp_up_start_rps is not None:
        last_int_rps = ramp_up_start_rps
        rps_change_events.append(
            {
                "rps": last_int_rps,
                "timestamp": datetime.now().isoformat(),
            }
        )
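
    # A sketch of the assumed arrival process inside vLLM's get_request (not
    # this module's code): inter-arrival gaps are drawn from a Gamma
    # distribution with shape=burstiness and scale=1/(request_rate*burstiness),
    # so burstiness=1.0 degenerates to exponential gaps, i.e. Poisson arrivals.
    #
    #   theta = 1.0 / (request_rate * burstiness)
    #   interval = np.random.gamma(shape=burstiness, scale=theta)
    #   await asyncio.sleep(interval)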
    async for request, current_request_rate in get_request(
        input_requests,
        request_rate,
        burstiness,
        ramp_up_strategy,
        ramp_up_start_rps,
        ramp_up_end_rps,
    ):
        if ramp_up_strategy is not None:
            current_int_rps = int(current_request_rate)
            if current_int_rps > last_int_rps:
                timestamp = datetime.now().isoformat()
                for rps_val in range(last_int_rps + 1, current_int_rps + 1):
                    rps_change_events.append({"rps": rps_val, "timestamp": timestamp})
                last_int_rps = current_int_rps
        prompt, prompt_len, output_len, mm_content, request_id = (
            request.prompt,
            request.prompt_len,
            request.expected_output_len,
            request.multi_modal_data,
            request.request_id,
        )
        req_model_id, req_model_name = model_id, model_name
        if lora_modules:
            req_lora_module = next(lora_modules)
            req_model_id, req_model_name = req_lora_module, req_lora_module

        request_func_input = RequestFuncInput(
            model=req_model_id,
            model_name=req_model_name,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
            logprobs=logprobs,
            multi_modal_content=mm_content,
            ignore_eos=ignore_eos,
            extra_headers=extra_headers,
            extra_body=extra_body,
            request_id=request_id,
        )
        tasks.append(
            asyncio.create_task(
                limited_request_func(
                    request_func_input=request_func_input, session=session, pbar=pbar
                )
            )
        )
    outputs: list[MixRequestFuncOutput] = await asyncio.gather(*tasks)

    if pbar is not None:
        pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time
    if task_type == TaskType.GENERATION:
        metrics, actual_output_lens = calculate_metrics(
            input_requests=input_requests,
            outputs=outputs,
            dur_s=benchmark_duration,
            tokenizer=tokenizer,
            selected_percentiles=selected_percentiles,
            goodput_config_dict=goodput_config_dict,
            task_type=task_type,
            selected_percentile_metrics=selected_percentile_metrics,
            max_concurrency=max_concurrency,
            request_rate=request_rate,
            benchmark_duration=benchmark_duration,
        )
    else:
        metrics = calculate_metrics_for_embeddings(
            outputs=outputs,
            dur_s=benchmark_duration,
            selected_percentiles=selected_percentiles,
        )
        actual_output_lens = 0
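
    # Interpretation note for the audio metrics gathered above: audio_rtf is
    # wall-clock generation time divided by decoded audio duration. For
    # example, spending 2.5 s to stream 5.0 s of audio gives
    # audio_rtf = 2.5 / 5.0 = 0.5 (2x faster than real time); values above
    # 1.0 mean playback would stall.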
    if isinstance(metrics, MultiModalsBenchmarkMetrics):
        result = {
            "duration": benchmark_duration,
            "completed": metrics.completed,
            "failed": metrics.failed,
            "total_input_tokens": metrics.total_input,
            "total_output_tokens": metrics.total_output,
            "request_throughput": metrics.request_throughput,
            "request_goodput": metrics.request_goodput if goodput_config_dict else None,
            "output_throughput": metrics.output_throughput,
            "total_token_throughput": metrics.total_token_throughput,
            "input_lens": [output.prompt_len for output in outputs],
            "output_lens": actual_output_lens,
            "ttfts": [output.ttft for output in outputs],
            "itls": [output.itl for output in outputs],
            "generated_texts": [output.generated_text for output in outputs],
            "errors": [output.error for output in outputs],
            "max_output_tokens_per_s": metrics.max_output_tokens_per_s,
            "max_concurrent_requests": metrics.max_concurrent_requests,
        }
    else:
        result = {
            "duration": benchmark_duration,
            "completed": metrics.completed,
            "total_input_tokens": metrics.total_input,
            "request_throughput": metrics.request_throughput,
            "total_token_throughput": metrics.total_token_throughput,
            "input_lens": [output.prompt_len for output in outputs],
            "errors": [output.error for output in outputs],
        }

    if rps_change_events:
        result["rps_change_events"] = rps_change_events

    def process_one_metric(
        # E.g., "ttft"
        metric_attribute_name: str,
    ):
        # Add percentile statistics for the selected metric to `result`.
        # (Unlike upstream vLLM, this variant does not print them.)
        if metric_attribute_name not in selected_percentile_metrics:
            return
        # audio_rtf is a dimensionless ratio; all other metrics are in ms.
        is_audio_rtf = metric_attribute_name == "audio_rtf"
        suffix = "" if is_audio_rtf else "_ms"
        for p, value in getattr(
            metrics, f"percentiles_{metric_attribute_name}{suffix}"
        ):
            p_word = str(int(p)) if int(p) == p else str(p)
            result[f"p{p_word}_{metric_attribute_name}{suffix}"] = value

    if task_type == TaskType.GENERATION:
        for metric in selected_percentile_metrics:
            process_one_metric(metric)
    else:
        process_one_metric("e2el")

    if profile:
        print("Stopping profiler...")
        profile_input = RequestFuncInput(
            model=model_id,
            prompt=test_prompt,
            api_url=base_url + "/stop_profile",
            prompt_len=test_prompt_len,
            output_len=test_output_len,
            logprobs=logprobs,
        )
        profile_output = await request_func(
            request_func_input=profile_input, session=session
        )
        if profile_output.success:
            print("Profiler stopped")

    await session.close()

    return result


# Replace vLLM's serve-benchmark entry point with the omni-aware version.
serve.benchmark = benchmark
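
# A hypothetical invocation once this module has been imported; the
# --random-mm-* flags are assumed to match vLLM's random-mm dataset options:
#
#   vllm bench serve \
#       --backend openai-chat-omni \
#       --dataset-name random-mm \
#       --model <model> \
#       --num-prompts 100 \
#       --random-mm-base-items-per-request 1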