"""Benchmark offline inference throughput.""" import argparse import json import random import time from typing import List, Optional, Tuple import numpy as np import torch import uvloop from tqdm import tqdm from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) from vllm.inputs import PromptInputs from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser, merge_async_iterators from vllm.lora.request import LoRARequest def nullable_str(val: str): if not val or val == "None": return None return val def sample_requests( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int], ) -> List[Tuple[str, int, int]]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") # Load the dataset. with open(dataset_path) as f: dataset = json.load(f) # Only keep the first two turns of each conversation. dataset = [data["prompt"] for data in dataset] # Shuffle the dataset. random.shuffle(dataset) # Filter out sequences that are too long or too short filtered_dataset: List[Tuple[str, int, int]] = [] for i in range(len(dataset)): if len(filtered_dataset) == num_requests: break # Tokenize the prompts and completions. prompt = dataset[i] prompt_token_ids = tokenizer(prompt).input_ids prompt_len = len(prompt_token_ids) output_len = fixed_output_len filtered_dataset.append((prompt, prompt_len, output_len)) return filtered_dataset def run_vllm( warmup_requests: List[Tuple[str, int, int]], requests: List[Tuple[str, int, int]], model: str, tokenizer: str, quantization: Optional[str], tensor_parallel_size: int, seed: int, n: int, use_beam_search: bool, trust_remote_code: bool, dtype: str, max_model_len: Optional[int], enforce_eager: bool, kv_cache_dtype: str, quantization_param_path: Optional[str], device: str, enable_prefix_caching: bool, enable_chunked_prefill: bool, max_num_batched_tokens: int, distributed_executor_backend: Optional[str], gpu_memory_utilization: float = 0.9, num_scheduler_steps: int = 1, use_v2_block_manager: bool = False, download_dir: Optional[str] = None, load_format: str = EngineArgs.load_format, disable_async_output_proc: bool = False, max_num_seqs: int = 8, speculative_model: str=None, speculative_draft_tensor_parallel_size: int = 1, speculative_disable_by_batch_size: int = 4, spec_decoding_acceptance_method: str = None, enable_lora: bool = False, max_lora_rank: int = 32, merge_lora: bool = False, lora_extra_vocab_size: int = 0, lora_target_modules: List[str] = None, num_speculative_heads: int = 5, num_speculative_tokens: int = 64, use_new_beam_search_impl: bool = False, lora_modules: str = None ) -> float: from vllm import LLM, SamplingParams llm = LLM( model=model, tokenizer=tokenizer, quantization=quantization, tensor_parallel_size=tensor_parallel_size, seed=seed, trust_remote_code=trust_remote_code, dtype=dtype, max_model_len=max_model_len, gpu_memory_utilization=gpu_memory_utilization, enforce_eager=enforce_eager, kv_cache_dtype=kv_cache_dtype, quantization_param_path=quantization_param_path, device=device, enable_prefix_caching=enable_prefix_caching, download_dir=download_dir, enable_chunked_prefill=enable_chunked_prefill, max_num_batched_tokens=max_num_batched_tokens, distributed_executor_backend=distributed_executor_backend, 
        load_format=load_format,
        num_scheduler_steps=num_scheduler_steps,
        use_v2_block_manager=use_v2_block_manager,
        disable_async_output_proc=disable_async_output_proc,
        max_num_seqs=max_num_seqs,
        speculative_model=speculative_model,
        speculative_draft_tensor_parallel_size=(
            speculative_draft_tensor_parallel_size),
        speculative_disable_by_batch_size=speculative_disable_by_batch_size,
        spec_decoding_acceptance_method=spec_decoding_acceptance_method,
        enable_lora=enable_lora,
        max_lora_rank=max_lora_rank,
        merge_lora=merge_lora,
        lora_extra_vocab_size=lora_extra_vocab_size,
        lora_target_modules=lora_target_modules,
        num_speculative_heads=num_speculative_heads,
        num_speculative_tokens=num_speculative_tokens,
    )

    # Build the prompts and sampling parameters for the benchmark requests.
    prompts: List[str] = []
    sampling_params: List[SamplingParams] = []
    for prompt, _, output_len in requests:
        prompts.append(prompt)
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=0.0,
                top_p=1.0,
                use_beam_search=use_beam_search,
                ignore_eos=False,
                max_tokens=output_len,
            ))

    # Build the warmup prompts and sampling parameters.
    warmup_prompts = []
    warmup_sampling_params = []
    for prompt, _, output_len in warmup_requests:
        warmup_prompts.append(prompt)
        warmup_sampling_params.append(
            SamplingParams(
                n=n,
                temperature=0.0,
                top_p=1.0,
                use_beam_search=use_beam_search,
                ignore_eos=False,
                max_tokens=output_len,
            ))

    print("Warming up...")
    # NOTE: num_iters_warmup is read from the module-level `args` parsed in
    # the __main__ block rather than from a function argument.
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        if lora_modules is None:
            llm.generate(warmup_prompts, warmup_sampling_params, use_tqdm=True)
        else:
            llm.generate(warmup_prompts,
                         warmup_sampling_params,
                         use_tqdm=True,
                         lora_request=LoRARequest("medusa-lora", 1,
                                                  lora_modules))

    total_out_tokens = 0
    start = time.perf_counter()
    if lora_modules is None:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
    else:
        outputs = llm.generate(prompts,
                               sampling_params,
                               use_tqdm=False,
                               lora_request=LoRARequest(
                                   "medusa-lora", 1, lora_modules))
    for output in outputs:
        print("token_ids len:{} text:{}".format(
            len(output.outputs[0].token_ids), output.outputs[0].text))
        total_out_tokens += len(output.outputs[0].token_ids)
    end = time.perf_counter()
    return end - start, total_out_tokens


async def run_vllm_async(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    distributed_executor_backend: Optional[str],
    gpu_memory_utilization: float = 0.9,
    num_scheduler_steps: int = 1,
    use_v2_block_manager: bool = False,
    download_dir: Optional[str] = None,
    load_format: str = EngineArgs.load_format,
    disable_async_output_proc: bool = False,
    disable_frontend_multiprocessing: bool = False,
    max_num_seqs: int = 8,
    speculative_model: Optional[str] = None,
    speculative_draft_tensor_parallel_size: int = 1,
    speculative_disable_by_batch_size: int = 4,
    spec_decoding_acceptance_method: Optional[str] = None,
    enable_lora: bool = False,
    max_lora_rank: int = 32,
    merge_lora: bool = False,
    lora_extra_vocab_size: int = 0,
    lora_target_modules: Optional[List[str]] = None,
    num_speculative_heads: int = 5,
    num_speculative_tokens: int = 64,
    use_new_beam_search_impl: bool = False,
    lora_modules: Optional[str] = None,
) -> Tuple[float, int]:
    from vllm import SamplingParams
    engine_args = AsyncEngineArgs(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
        load_format=load_format,
        num_scheduler_steps=num_scheduler_steps,
        use_v2_block_manager=use_v2_block_manager,
        disable_async_output_proc=disable_async_output_proc,
        worker_use_ray=False,
        disable_log_requests=True,
        max_num_seqs=max_num_seqs,
        speculative_model=speculative_model,
        speculative_draft_tensor_parallel_size=(
            speculative_draft_tensor_parallel_size),
        speculative_disable_by_batch_size=speculative_disable_by_batch_size,
        spec_decoding_acceptance_method=spec_decoding_acceptance_method,
        enable_lora=enable_lora,
        max_lora_rank=max_lora_rank,
        merge_lora=merge_lora,
        lora_extra_vocab_size=lora_extra_vocab_size,
        lora_target_modules=lora_target_modules,
        num_speculative_heads=num_speculative_heads,
        num_speculative_tokens=num_speculative_tokens,
    )

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:

        # Build the prompts and sampling parameters for the requests.
        prompts: List[str] = []
        sampling_params: List[SamplingParams] = []
        for prompt, _, output_len in requests:
            prompts.append(prompt)
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=0.0 if use_beam_search else 1.0,
                    top_p=1.0,
                    use_beam_search=use_beam_search,
                    ignore_eos=False,
                    max_tokens=output_len,
                ))

        generators = []
        start = time.perf_counter()
        for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
            generator = llm.generate(prompt, sp, request_id=f"test{i}")
            generators.append(generator)
        all_gens = merge_async_iterators(*generators)
        out_dict = {}
        async for i, res in all_gens:
            # print("res:", res)
            out_dict[res.request_id] = len(res.outputs[0].token_ids)
        end = time.perf_counter()

        total_out_tokens = 0
        for token_num in out_dict.values():
            total_out_tokens += token_num
        return end - start, total_out_tokens


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    warmup_prompt = "hi" * 10
    warmup_requests = [(warmup_prompt, 10, 10) for _ in range(1)]
    if args.dataset is None:
        # Synthesize a prompt with the given input length.
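        # The repeated "hi" prompt only approximates the requested length:
        # the actual prompt token count depends on the tokenizer, so
        # input_len is treated as a nominal value here.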
prompt = "hi" * (args.input_len - 1) requests = [(prompt, args.input_len, args.output_len) for _ in range(args.num_prompts)] else: requests = sample_requests(args.dataset, args.num_prompts, tokenizer, args.output_len) if args.async_engine: run_args = [ requests, args.model, args.tokenizer, args.quantization, args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.trust_remote_code, args.dtype, args.max_model_len, args.enforce_eager, args.kv_cache_dtype, args.quantization_param_path, args.device, args.enable_prefix_caching, args.enable_chunked_prefill, args.max_num_batched_tokens, args.distributed_executor_backend, args.gpu_memory_utilization, args.num_scheduler_steps, args.use_v2_block_manager, args.download_dir, args.load_format, args.disable_async_output_proc, False, args.max_num_seqs, args.speculative_model, args.speculative_draft_tensor_parallel_size, args.speculative_disable_by_batch_size, args.spec_decoding_acceptance_method, args.enable_lora, args.max_lora_rank, args.merge_lora, args.lora_extra_vocab_size, args.lora_target_modules, args.num_speculative_heads, args.num_speculative_tokens ] else: run_args = [ warmup_requests, requests, args.model, args.tokenizer, args.quantization, args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.trust_remote_code, args.dtype, args.max_model_len, args.enforce_eager, args.kv_cache_dtype, args.quantization_param_path, args.device, args.enable_prefix_caching, args.enable_chunked_prefill, args.max_num_batched_tokens, args.distributed_executor_backend, args.gpu_memory_utilization, args.num_scheduler_steps, args.use_v2_block_manager, args.download_dir, args.load_format, args.disable_async_output_proc, args.max_num_seqs, args.speculative_model, args.speculative_draft_tensor_parallel_size, args.speculative_disable_by_batch_size, args.spec_decoding_acceptance_method, args.enable_lora, args.max_lora_rank, args.merge_lora, args.lora_extra_vocab_size, args.lora_target_modules, args.num_speculative_heads, args.num_speculative_tokens ] if args.async_engine: run_args.append(args.disable_frontend_multiprocessing) elapsed_time, total_out_tokens = uvloop.run(run_vllm_async(*run_args)) else: elapsed_time, total_out_tokens = run_vllm(*run_args, args.use_new_beam_search_impl, args.lora_modules) total_num_tokens = total_out_tokens + sum(prompt_len for _, prompt_len, _ in requests) print(f"Latency: {elapsed_time:.2f} s") print(f"All Throughput: {len(requests) / elapsed_time:.2f} requests/s, " f"{total_num_tokens / elapsed_time:.2f} tokens/s") print(f"Generate Throughput: {total_out_tokens / elapsed_time:.2f} tokens/s") # Output JSON results if specified if args.output_json: results = { "elapsed_time": elapsed_time, "num_requests": len(requests), "total_num_tokens": total_num_tokens, "requests_per_second": len(requests) / elapsed_time, "tokens_per_second": total_num_tokens / elapsed_time, } with open(args.output_json, "w") as f: json.dump(results, f, indent=4) if __name__ == "__main__": parser = FlexibleArgumentParser(description="Benchmark the throughput.") parser.add_argument("--dataset", type=str, default=None, help="Path to the dataset.") parser.add_argument("--input-len", type=int, default=None, help="Input prompt length for each request") parser.add_argument("--output-len", type=int, default=256, help="Output length for each request. 
Overrides the " "output length from the dataset.") parser.add_argument("--model", type=str, default="facebook/opt-125m") parser.add_argument("--tokenizer", type=str, default=None) parser.add_argument('--quantization', '-q', choices=[*QUANTIZATION_METHODS, None], default=None) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", type=int, default=1, help="Number of generated sequences per prompt.") parser.add_argument("--use-beam-search", action="store_true") parser.add_argument('--num-iters-warmup', type=int, default=1, help='Number of iterations to run for warmup.') parser.add_argument("--use-new-beam-search-impl", action="store_true") parser.add_argument("--num-prompts", type=int, default=1000, help="Number of prompts to process.") parser.add_argument("--seed", type=int, default=0) parser.add_argument('--trust-remote-code', action='store_true', help='trust remote code from huggingface') parser.add_argument( '--max-model-len', type=int, default=None, help='Maximum length of a sequence (including prompt and output). ' 'If None, will be derived from the model.') parser.add_argument( '--dtype', type=str, default='auto', choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], help='data type for model weights and activations. ' 'The "auto" option will use FP16 precision ' 'for FP32 and FP16 models, and BF16 precision ' 'for BF16 models.') parser.add_argument('--gpu-memory-utilization', type=float, default=0.9, help='the fraction of GPU memory to be used for ' 'the model executor, which can range from 0 to 1.' 'If unspecified, will use the default value of 0.9.') parser.add_argument("--enforce-eager", action="store_true", help="enforce eager execution") parser.add_argument( '--kv-cache-dtype', type=str, choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], default="auto", help='Data type for kv cache storage. If "auto", will use model ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' 'ROCm (hcu) supports fp8 (=fp8_e4m3)') parser.add_argument( '--quantization-param-path', type=str, default=None, help='Path to the JSON file containing the KV cache scaling factors. ' 'This should generally be supplied, when KV cache dtype is FP8. ' 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' 'cuda version greater than 11.8. 
On ROCm (hcu), FP8_E4M3 is ' 'instead supported for common inference criteria.') parser.add_argument("--device", type=str, default="auto", choices=DEVICE_OPTIONS, help='device type for vLLM execution') parser.add_argument( "--num-scheduler-steps", type=int, default=1, help="Maximum number of forward steps per scheduler call.") parser.add_argument("--use-v2-block-manager", action='store_true', help="Enable block manager v2.") parser.add_argument( "--enable-prefix-caching", action='store_true', help="Enable automatic prefix caching for vLLM backend.") parser.add_argument("--enable-chunked-prefill", action='store_true', help="enable chunked prefill for vLLM backend.") parser.add_argument('--max-num-batched-tokens', type=int, default=None, help='maximum number of batched tokens per ' 'iteration') parser.add_argument('--download-dir', type=str, default=None, help='directory to download and load the weights, ' 'default to the default cache dir of huggingface') parser.add_argument( '--output-json', type=str, default=None, help='Path to save the throughput results in JSON format.') parser.add_argument( '--distributed-executor-backend', choices=['ray', 'mp'], default=None, help='Backend to use for distributed serving. When more than 1 GPU ' 'is used, will be automatically set to "ray" if installed ' 'or "mp" (multiprocessing) otherwise.') parser.add_argument( '--load-format', type=str, default=EngineArgs.load_format, choices=[ 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', 'bitsandbytes' ], help='The format of the model weights to load.\n\n' '* "auto" will try to load the weights in the safetensors format ' 'and fall back to the pytorch bin format if safetensors format ' 'is not available.\n' '* "pt" will load the weights in the pytorch bin format.\n' '* "safetensors" will load the weights in the safetensors format.\n' '* "npcache" will load the weights in pytorch format and store ' 'a numpy cache to speed up the loading.\n' '* "dummy" will initialize the weights with random values, ' 'which is mainly for profiling.\n' '* "tensorizer" will load the weights using tensorizer from ' 'CoreWeave. 
See the Tensorize vLLM Model script in the Examples' 'section for more information.\n' '* "bitsandbytes" will load the weights using bitsandbytes ' 'quantization.\n') parser.add_argument( "--disable-async-output-proc", action='store_true', default=False, help="Disable async output processor for vLLM backend.") parser.add_argument("--async-engine", action='store_true', default=False, help="Use vLLM async engine rather than LLM class.") parser.add_argument("--disable-frontend-multiprocessing", action='store_true', default=False, help="Disable decoupled async engine frontend.") parser.add_argument('--max-num-seqs', type=int, default=EngineArgs.max_num_seqs, help='Maximum number of sequences per iteration.') parser.add_argument( '--speculative-model', type=nullable_str, default=EngineArgs.speculative_model, help= 'The name of the draft model to be used in speculative decoding.') parser.add_argument( '--speculative-draft-tensor-parallel-size', '-spec-draft-tp', type=int, default=EngineArgs.speculative_draft_tensor_parallel_size, help='Number of tensor parallel replicas for ' 'the draft model in speculative decoding.') parser.add_argument( '--speculative-disable-by-batch-size', type=int, default=EngineArgs.speculative_disable_by_batch_size, help='Disable speculative decoding for new incoming requests ' 'if the number of enqueue requests is larger than this value.') parser.add_argument( '--spec-decoding-acceptance-method', type=str, default=EngineArgs.spec_decoding_acceptance_method, choices=['rejection_sampler', 'typical_acceptance_sampler'], help='Specify the acceptance method to use during draft token ' 'verification in speculative decoding. Two types of acceptance ' 'routines are supported: ' '1) RejectionSampler which does not allow changing the ' 'acceptance rate of draft tokens, ' '2) TypicalAcceptanceSampler which is configurable, allowing for ' 'a higher acceptance rate at the cost of lower quality, ' 'and vice versa.') # LoRA related configs parser.add_argument('--enable-lora', action='store_true', help='If True, enable handling of LoRA adapters.') parser.add_argument('--max-lora-rank', type=int, default=EngineArgs.max_lora_rank, help='Max LoRA rank.') parser.add_argument('--merge-lora', type=bool, default=False, help='If set to True, the weights of the base layer will be merged with the weights of Lora.') parser.add_argument( '--lora-extra-vocab-size', type=int, default=EngineArgs.lora_extra_vocab_size, help=('Maximum size of extra vocabulary that can be ' 'present in a LoRA adapter (added to the base ' 'model vocabulary).')) parser.add_argument('--lora-target-modules', nargs='*', default=None, help='List of lora module name, If not specified, modules will be chosen according to the model architecture.') parser.add_argument( '--num-speculative-heads', type=int, default=EngineArgs.num_speculative_heads, help='The number of speculative heads to sample from ' 'the draft model in speculative decoding.') parser.add_argument( '--num-speculative-tokens', type=int, default=EngineArgs.num_speculative_tokens, help='The number of speculative tokens to sample from ' 'the draft model in speculative decoding.') parser.add_argument( '--lora-modules', type=nullable_str, default=None, help= 'Path of lora model.') args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model if args.dataset is None: assert args.input_len is not None assert args.output_len is not None else: assert args.input_len is None main(args)
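# Example invocations (illustrative sketch; the script filename, model, and
# dataset path are placeholders, and only flags defined above are used):
#
#   # Synthetic prompts with the offline LLM engine:
#   python benchmark_throughput.py --model facebook/opt-125m \
#       --input-len 128 --output-len 256 --num-prompts 100
#
#   # Dataset-driven prompts (a JSON list of objects with a "prompt" field)
#   # using the async engine:
#   python benchmark_throughput.py --model facebook/opt-125m \
#       --dataset prompts.json --output-len 256 --num-prompts 100 \
#       --async-engine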