Commit fbeb8a6f authored by raojy

raw_vllm

parent 2ca8867f
Pipeline #3454 canceled with stages
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Offline benchmark to test the long document QA throughput.
Example usage:
# This workload samples 8 different prompts with a default input
# length of 20000 tokens, then replicates each prompt 2 times
# in random order.
python benchmark_long_document_qa_throughput.py \
--model meta-llama/Llama-2-7b-chat-hf \
--enable-prefix-caching \
--num-documents 8 \
--repeat-count 2
Command-line arguments:
--num-documents: The number of documents to sample prompts from.
--document-length: The length of each document in tokens.
(Optional, default: 20000)
--output-len: The number of tokens to generate for each prompt.
(Optional, default: 10)
--repeat-count: The number of times to repeat each prompt.
(Optional, default: 2)
--repeat-mode: The mode to repeat prompts. The supported modes are:
- 'random': shuffle the prompts randomly. (Default)
- 'tile': the entire prompt list is repeated in sequence. (Potentially
lowest cache hit)
- 'interleave': each prompt is repeated consecutively before
moving to the next element. (Highest cache hit)
--shuffle-seed: Random seed when the repeat mode is "random".
(Optional, default: 0)
In addition, the script accepts all vLLM engine arguments used to initialize
the LLM engine; refer to `vllm.engine.arg_utils.EngineArgs` for details.
"""
import dataclasses
import random
import time
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
def test_long_document_qa(llm=None, sampling_params=None, prompts=None):
"""
Test long document QA with the given prompts and sampling parameters.
Print the time spent in processing all the prompts.
Args:
llm: The language model used for generating responses.
sampling_params: Sampling parameter used to generate the response.
prompts: A list of prompt strings to be processed by the LLM.
"""
start_time = time.time()
llm.generate(prompts, sampling_params=sampling_params)
end_time = time.time()
print(f"Time to execute all requests: {end_time - start_time:.4f} secs")
def repeat_prompts(prompts, repeat_count, mode: str):
"""
Repeat each prompt in the list for a specified number of times.
The order of prompts in the output list depends on the mode.
Args:
prompts: A list of prompts to be repeated.
repeat_count: The number of times each prompt is repeated.
mode: The mode of repetition. Supported modes are:
- 'random': Shuffle the prompts randomly after repetition.
- 'tile': Repeat the entire prompt list in sequence.
Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
- 'interleave': Repeat each prompt consecutively before moving to
the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
Returns:
A list of repeated prompts in the specified order.
Raises:
ValueError: If an invalid mode is provided.
"""
print("Repeat mode: ", mode)
if mode == "random":
repeated_prompts = prompts * repeat_count
random.shuffle(repeated_prompts)
return repeated_prompts
elif mode == "tile":
return prompts * repeat_count
elif mode == "interleave":
repeated_prompts = []
for prompt in prompts:
repeated_prompts.extend([prompt] * repeat_count)
return repeated_prompts
else:
raise ValueError(
f"Invalid mode: {mode}, only support 'random', 'tile', 'interleave'"
)
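# Illustrative only (not used by the benchmark): a tiny self-check of the
# repeat modes documented above. The helper name is ours, not part of vLLM.
def _repeat_modes_example() -> None:
    assert repeat_prompts(["a", "b"], 2, mode="tile") == ["a", "b", "a", "b"]
    assert repeat_prompts(["a", "b"], 2, mode="interleave") == ["a", "a", "b", "b"]
    shuffled = repeat_prompts(["a", "b"], 2, mode="random")
    assert sorted(shuffled) == ["a", "a", "b", "b"]  # same items, shuffled order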
def main(args):
random.seed(args.shuffle_seed)
# Prepare the prompts:
# we append the document id at the beginning to avoid any of the document
# being the prefix of other documents
prompts = [
str(i) + " ".join(["hi"] * args.document_length)
for i in range(args.num_documents)
]
prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode)
warmup_prompts = [
"This is warm up request " + str(i) + " ".join(["hi"] * args.document_length)
for i in range(args.num_documents)
]
# Create the LLM engine
engine_args = EngineArgs.from_cli_args(args)
llm = LLM(**dataclasses.asdict(engine_args))
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
print("------warm up------")
test_long_document_qa(
llm=llm,
prompts=warmup_prompts,
sampling_params=sampling_params,
)
print("------start generating------")
test_long_document_qa(
llm=llm,
prompts=prompts,
sampling_params=sampling_params,
)
def create_argument_parser():
parser = FlexibleArgumentParser(
description="Benchmark the performance with or "
"without automatic prefix caching."
)
parser.add_argument(
"--document-length",
type=int,
# Roughly the number of tokens for a system paper,
# excluding images
default=20000,
help="Range of input lengths for sampling prompts, "
'specified as "min:max" (e.g., "128:256").',
)
parser.add_argument(
"--num-documents",
type=int,
default=8,
help="Range of input lengths for sampling prompts, "
'specified as "min:max" (e.g., "128:256").',
)
parser.add_argument("--output-len", type=int, default=10)
parser.add_argument(
"--repeat-count",
type=int,
default=2,
help="Number of times to repeat each prompt",
)
parser.add_argument(
"--repeat-mode",
type=str,
default="random",
help="The mode to repeat prompts. The supported "
'modes are "random", "tile", and "interleave". '
"See repeat_prompts() in the source code for details.",
)
parser.add_argument(
"--shuffle-seed",
type=int,
default=0,
help='Random seed when the repeat mode is "random"',
)
parser = EngineArgs.add_cli_args(parser)
return parser
if __name__ == "__main__":
parser = create_argument_parser()
args = parser.parse_args()
main(args)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import time
from unittest import mock
import numpy as np
from benchmark_utils import TimeCollector
from tabulate import tabulate
from vllm.config import (
CacheConfig,
DeviceConfig,
LoadConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
)
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
def benchmark_propose(args):
rows = []
for max_ngram in args.max_ngram:
collector = TimeCollector(TimeCollector.US)
model_config = ModelConfig(
model="facebook/opt-125m",
max_model_len=args.num_token + args.num_spec_token,
tokenizer="facebook/opt-125m",
tokenizer_mode="auto",
dtype="auto",
seed=0,
trust_remote_code=False,
)
proposer = NgramProposer(
vllm_config=VllmConfig(
model_config=model_config,
speculative_config=SpeculativeConfig(
prompt_lookup_min=args.min_ngram,
prompt_lookup_max=max_ngram,
num_speculative_tokens=args.num_spec_token,
method="ngram",
),
)
)
# Warm up
proposer.propose(np.random.randint(0, 20, (args.num_token,)))
gc.collect()
for _ in range(args.num_iteration):
tokens = np.random.randint(0, 20, (args.num_req, args.num_token))
with collector:
for i in range(args.num_req):
proposer.propose(tokens[i, :])
rows.append(
[args.num_req, args.num_token, args.min_ngram, max_ngram]
+ collector.dump_avg_max()
)
print(
tabulate(
rows,
headers=[
"# Request",
"# Token",
"Min Ngram",
"Max Ngram",
"Avg (us)",
"Max (us)",
],
tablefmt="grid",
floatfmt=".3f",
)
)
def benchmark_batched_propose(args):
NUM_SPECULATIVE_TOKENS_NGRAM = 10
PROMPT_LOOKUP_MIN = 5
PROMPT_LOOKUP_MAX = 15
MAX_MODEL_LEN = int(1e7)
DEVICE = current_platform.device_type
model_config = ModelConfig(model="facebook/opt-125m", runner="generate")
speculative_config = SpeculativeConfig(
target_model_config=model_config,
target_parallel_config=ParallelConfig(),
method="ngram",
num_speculative_tokens=NUM_SPECULATIVE_TOKENS_NGRAM,
prompt_lookup_max=PROMPT_LOOKUP_MAX,
prompt_lookup_min=PROMPT_LOOKUP_MIN,
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
speculative_config=speculative_config,
device_config=DeviceConfig(device=current_platform.device_type),
parallel_config=ParallelConfig(),
load_config=LoadConfig(),
scheduler_config=SchedulerConfig(
max_model_len=model_config.max_model_len,
is_encoder_decoder=model_config.is_encoder_decoder,
),
)
# monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group
mock_pp_group = mock.MagicMock()
mock_pp_group.world_size = 1
with mock.patch(
"vllm.v1.worker.gpu_model_runner.get_pp_group", return_value=mock_pp_group
):
runner = GPUModelRunner(vllm_config, DEVICE)
# hack max model len
runner.max_model_len = MAX_MODEL_LEN
runner.drafter.max_model_len = MAX_MODEL_LEN
dummy_input_batch = InputBatch(
max_num_reqs=args.num_req,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=args.num_req * args.num_token,
device=DEVICE,
pin_memory=False,
vocab_size=256000,
block_sizes=[16],
)
dummy_input_batch._req_ids = list(str(id) for id in range(args.num_req))
dummy_input_batch.num_tokens_no_spec = [args.num_token] * args.num_req
dummy_input_batch.token_ids_cpu = np.random.randint(
0, 20, (args.num_req, args.num_token)
)
runner.input_batch = dummy_input_batch
sampled_token_ids = [[0]] * args.num_req
print("Starting benchmark")
# first run is warmup so ignore it
for _ in range(args.num_iteration):
start = time.time()
runner.drafter.propose(
sampled_token_ids,
dummy_input_batch.num_tokens_no_spec,
dummy_input_batch.token_ids_cpu,
)
end = time.time()
print(f"Iteration time (s): {end - start}")
def invoke_main() -> None:
parser = FlexibleArgumentParser(
description="Benchmark the performance of N-gram speculative decode drafting"
)
parser.add_argument(
"--batched", action="store_true", help="consider time to prepare batch"
)
parser.add_argument(
"--num-iteration",
type=int,
default=100,
help="Number of iterations to run to stabilize final data readings",
)
parser.add_argument(
"--num-req", type=int, default=128, help="Number of requests in the batch"
)
parser.add_argument(
"--num-token", type=int, default=1500, help="Number of tokens for each request"
)
parser.add_argument(
"--min-ngram",
type=int,
default=3,
help="Minimum n-gram to match",
)
parser.add_argument(
"--max-ngram",
type=int,
nargs="*",
default=[5, 7, 10, 15, 20],
help="Maximum n-gram to match",
)
parser.add_argument(
"--num-spec-token",
type=int,
default=3,
help="Number of speculative tokens to generate",
)
args = parser.parse_args()
if not args.batched:
benchmark_propose(args)
else:
benchmark_batched_propose(args)
"""
# Example command lines:
# time python3 benchmarks/benchmark_ngram_proposer.py
# time python3 benchmarks/benchmark_ngram_proposer.py --batched --num-iteration 4 --num-token 1000000 --num-req 128
""" # noqa: E501
if __name__ == "__main__":
invoke_main() # pragma: no cover
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Simple benchmark to compare prefix-cache block hashing algorithms.
Example:
python benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32
"""
from __future__ import annotations
import argparse
import random
import statistics
import sys
import time
from collections.abc import Callable, Iterable, Sequence
from vllm.utils.hashing import get_hash_fn_by_name
from vllm.v1.core.kv_cache_utils import BlockHash, hash_block_tokens, init_none_hash
SUPPORTED_ALGOS = ("sha256", "sha256_cbor", "xxhash", "xxhash_cbor")
def _generate_blocks(
num_blocks: int, block_size: int, vocab_size: int, seed: int
) -> list[list[int]]:
rng = random.Random(seed)
return [
[rng.randrange(vocab_size) for _ in range(block_size)]
for _ in range(num_blocks)
]
def _hash_all_blocks(
hash_fn: Callable[[object], bytes],
blocks: Iterable[Sequence[int]],
) -> float:
parent_hash: BlockHash | None = None
start = time.perf_counter()
for block in blocks:
parent_hash = hash_block_tokens(hash_fn, parent_hash, block, extra_keys=None)
end = time.perf_counter()
return end - start
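# Illustrative sketch (not used by the benchmark): hash_block_tokens chains a
# parent hash into each block's hash, so identical leading blocks hash the same
# while later blocks also depend on everything before them. Assumes
# init_none_hash(hash_fn) has already been called, as _benchmark does below.
def _chained_hash_example(hash_fn: Callable[[object], bytes]) -> None:
    first = hash_block_tokens(hash_fn, None, [1, 2, 3], extra_keys=None)
    second = hash_block_tokens(hash_fn, first, [4, 5, 6], extra_keys=None)
    # Re-hashing the same prefix reproduces the same chain of hashes.
    assert hash_block_tokens(hash_fn, None, [1, 2, 3], extra_keys=None) == first
    assert hash_block_tokens(hash_fn, first, [4, 5, 6], extra_keys=None) == second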
def _benchmark(
hash_algo: str,
blocks: list[list[int]],
trials: int,
) -> tuple[float, float, float] | None:
try:
hash_fn = get_hash_fn_by_name(hash_algo)
init_none_hash(hash_fn)
timings = [_hash_all_blocks(hash_fn, blocks) for _ in range(trials)]
except ModuleNotFoundError as exc:
print(f"Skipping {hash_algo}: {exc}", file=sys.stderr)
return None
avg = statistics.mean(timings)
best = min(timings)
# throughput: tokens / second
tokens_hashed = len(blocks) * len(blocks[0])
throughput = tokens_hashed / best
return avg, best, throughput
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--num-blocks", type=int, default=10000, help="Block count.")
parser.add_argument("--block-size", type=int, default=32, help="Tokens per block.")
parser.add_argument(
"--vocab-size", type=int, default=32000, help="Token id range [0, vocab_size)."
)
parser.add_argument("--seed", type=int, default=0, help="Random seed.")
parser.add_argument(
"--trials", type=int, default=5, help="Number of timed trials per algorithm."
)
parser.add_argument(
"--algorithms",
nargs="+",
default=SUPPORTED_ALGOS,
choices=SUPPORTED_ALGOS,
help="Hash algorithms to benchmark.",
)
args = parser.parse_args()
blocks = _generate_blocks(
args.num_blocks, args.block_size, args.vocab_size, args.seed
)
print(
f"Benchmarking {len(args.algorithms)} algorithms on "
f"{args.num_blocks} blocks (block size={args.block_size})."
)
for algo in args.algorithms:
result = _benchmark(algo, blocks, args.trials)
if result is None:
continue
avg, best, throughput = result
print(
f"{algo:14s} avg: {avg:.6f}s best: {best:.6f}s "
f"throughput: {throughput / 1e6:.2f}M tokens/s"
)
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark the efficiency of prefix caching.
This script allows you to benchmark the performance of
a model with and without prefix caching using either fixed prompts
or prompts sampled from the ShareGPT dataset.
Fixed example usage:
python benchmark_prefix_caching.py \
--model meta-llama/Llama-2-7b-chat-hf \
--enable-prefix-caching \
--num-prompts 1 \
--repeat-count 100 \
--input-length-range 128:256
ShareGPT example usage:
# This command samples 20 prompts with input lengths
# between 128 and 256 tokens from the ShareGPT dataset,
# then replicates each prompt 5 times.
python benchmark_prefix_caching.py \
--model meta-llama/Llama-2-7b-chat-hf \
--dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \
--enable-prefix-caching \
--num-prompts 20 \
--repeat-count 5 \
--input-length-range 128:256
"""
import dataclasses
import json
import random
import time
from transformers import PreTrainedTokenizerBase
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
try:
from vllm.tokenizers import get_tokenizer
except ImportError:
from backend_request_func import get_tokenizer
PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501
def test_prefix(llm=None, sampling_params=None, prompts=None):
start_time = time.time()
llm.generate(prompts, sampling_params=sampling_params)
end_time = time.time()
print(f"cost time {end_time - start_time}")
@dataclasses.dataclass
class Request:
prompt: str
prompt_len: int
output_len: int
def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]:
vocab = tokenizer.get_vocab()
all_special_ids = set(tokenizer.all_special_ids)
# Remove the special tokens.
return random.choices(
[v for v in vocab.values() if v not in all_special_ids],
k=length,
)
def sample_requests_from_dataset(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
input_length_range: tuple[int, int],
fixed_output_len: int | None,
) -> list[Request]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
# Shuffle the dataset.
random.shuffle(dataset)
min_len, max_len = input_length_range
assert min_len >= 0 and max_len >= min_len, "input_length_range too small"
# Filter out sequences that are too long or too short
filtered_requests: list[Request] = []
for i in range(len(dataset)):
if len(filtered_requests) == num_requests:
break
# Tokenize the prompts and completions.
prompt_token_ids = tokenizer(dataset[i][0]).input_ids
prompt = tokenizer.decode(prompt_token_ids)
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = (
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
)
if min_len <= prompt_len <= max_len:
filtered_requests.append(Request(prompt, prompt_len, output_len))
return filtered_requests
def sample_requests_from_random(
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
input_length_range: tuple[int, int],
fixed_output_len: int | None,
prefix_len: int,
) -> list[Request]:
requests = []
prefix_token_ids = sample_tokens(tokenizer, prefix_len)
min_len, max_len = input_length_range
for i in range(num_requests):
unique_part_token_ids = sample_tokens(
tokenizer, random.randint(min_len - prefix_len, max_len - prefix_len)
)
prompt_token_ids = prefix_token_ids + unique_part_token_ids
prompt = tokenizer.decode(prompt_token_ids)
prompt_len = len(prompt_token_ids)
assert min_len <= prompt_len <= max_len, (
f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
)
requests.append(Request(prompt, prompt_len, fixed_output_len))
return requests
def repeat_and_sort_requests(
requests: list[Request], repeat_count: int, sort: bool = False
) -> list[str]:
repeated_requests = requests * repeat_count
if sort:
        repeated_requests.sort(key=lambda x: x.prompt_len)
else:
random.shuffle(repeated_requests)
return [req.prompt for req in repeated_requests]
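# Illustrative only: with sort=True the repeated prompts come back ordered by
# prompt_len (shortest first); otherwise they are shuffled. Helper name is ours.
def _repeat_and_sort_example() -> list[str]:
    reqs = [Request("aa bb cc", 3, 10), Request("aa", 1, 10)]
    # Returns ["aa", "aa bb cc"] once sorted by prompt length.
    return repeat_and_sort_requests(reqs, repeat_count=1, sort=True)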
def main(args):
tokenizer = get_tokenizer(args.model, trust_remote_code=True)
input_length_range = tuple(map(int, args.input_length_range.split(":")))
random.seed(args.seed)
if args.dataset_path is not None:
if args.prefix_len > 0:
raise ValueError(
"prefix-len is not supported when dataset-path is provided."
)
print(f"Start to sample {args.num_prompts} prompts from {args.dataset_path}")
filtered_requests = sample_requests_from_dataset(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
input_length_range=input_length_range,
fixed_output_len=args.output_len,
)
else:
print(f"Start to sample {args.num_prompts} prompts from random")
filtered_requests = sample_requests_from_random(
num_requests=args.num_prompts,
tokenizer=tokenizer,
input_length_range=input_length_range,
fixed_output_len=args.output_len,
prefix_len=args.prefix_len,
)
# Print some helpful stats of the requests.
print(f"Sampled {len(filtered_requests)} requests.")
prompt_lens = [req.prompt_len for req in filtered_requests]
print(f"Average input length: {sum(prompt_lens) / len(prompt_lens)}")
print(f"P50 input length: {sorted(prompt_lens)[len(prompt_lens) // 2]}")
print(f"Min Prompt Length: {min(prompt_lens)}")
print(f"Max Prompt Length: {max(prompt_lens)}")
engine_args = EngineArgs.from_cli_args(args)
llm = LLM(**dataclasses.asdict(engine_args))
sampling_params = SamplingParams(
temperature=0,
max_tokens=args.output_len,
detokenize=not args.disable_detokenize,
)
print("Testing filtered requests")
prompts = repeat_and_sort_requests(
filtered_requests, repeat_count=args.repeat_count, sort=args.sort
)
print("------start generating------")
test_prefix(
llm=llm,
prompts=prompts,
sampling_params=sampling_params,
)
def create_argument_parser():
parser = FlexibleArgumentParser(
description="Benchmark the performance with or without "
"automatic prefix caching."
)
parser.add_argument(
"--dataset-path", type=str, default=None, help="Path to the dataset."
)
parser.add_argument("--output-len", type=int, default=10)
parser.add_argument(
"--num-prompts",
type=int,
required=True,
help="Number of the prompts sampled from dataset",
)
parser.add_argument(
"--repeat-count",
type=int,
default=1,
help="Number of times to repeat each prompt",
)
parser.add_argument(
"--sort", action="store_true", help="Sort prompts by input length"
)
parser.add_argument(
"--input-length-range",
type=str,
required=True,
help="Range of input lengths for sampling prompts,"
'specified as "min:max" (e.g., "128:256").',
)
parser.add_argument(
"--prefix-len",
type=int,
default=0,
help="Specifies the length of a common prefix to be "
"added to the input prompt. The input-length-range will "
"subtract this length when filtering prompts. Only used "
"when dataset-path is not provided.",
)
parser.add_argument(
"--disable-detokenize",
action="store_true",
help=(
"Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"
),
)
parser = EngineArgs.add_cli_args(parser)
return parser
if __name__ == "__main__":
parser = create_argument_parser()
args = parser.parse_args()
main(args)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark offline prioritization."""
import argparse
import dataclasses
import json
import random
import time
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.engine.arg_utils import EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
# Select an equiprobable random priority (0 or 1)
def get_random_flag():
return 0 if random.random() < 0.5 else 1
def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: int | None,
) -> list[tuple[str, int, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
# Shuffle the dataset.
random.shuffle(dataset)
# Filter out sequences that are too long or too short
    filtered_dataset: list[tuple[str, int, int, int]] = []
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break
# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = (
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
)
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
priority = get_random_flag()
filtered_dataset.append((prompt, prompt_len, output_len, priority))
return filtered_dataset
def run_vllm(
    requests: list[tuple[str, int, int, int]],
n: int,
engine_args: EngineArgs,
disable_detokenize: bool = False,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(**dataclasses.asdict(engine_args))
assert all(
llm.llm_engine.model_config.max_model_len >= (request[1] + request[2])
for request in requests
), (
"Please ensure that max_model_len is greater than the sum of"
" input_len and output_len for all requests."
)
# Add the requests to the engine.
prompts = []
sampling_params = []
priority = []
for prompt, _, output_len, _priority in requests:
prompts.append(prompt)
priority.append(_priority)
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=output_len,
detokenize=not disable_detokenize,
)
)
start = time.perf_counter()
llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True)
end = time.perf_counter()
return end - start
def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code
)
if args.dataset is None:
# Synthesize a prompt with the given input length.
prompt = "hi" * (args.input_len - 1)
requests = [
(prompt, args.input_len, args.output_len, get_random_flag())
for _ in range(args.num_prompts)
]
else:
requests = sample_requests(
args.dataset, args.num_prompts, tokenizer, args.output_len
)
if args.backend == "vllm":
elapsed_time = run_vllm(
requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
)
else:
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(
prompt_len + output_len for _, prompt_len, output_len, priority in requests
)
print(
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s"
)
# Output JSON results if specified
if args.output_json:
results = {
"elapsed_time": elapsed_time,
"num_requests": len(requests),
"total_num_tokens": total_num_tokens,
"requests_per_second": len(requests) / elapsed_time,
"tokens_per_second": total_num_tokens / elapsed_time,
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
def create_argument_parser():
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument(
"--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm"
)
parser.add_argument(
"--dataset", type=str, default=None, help="Path to the dataset."
)
parser.add_argument(
"--input-len",
type=int,
default=None,
help="Input prompt length for each request",
)
parser.add_argument(
"--output-len",
type=int,
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.",
)
parser.add_argument(
"--n", type=int, default=1, help="Number of generated sequences per prompt."
)
parser.add_argument(
"--num-prompts", type=int, default=200, help="Number of prompts to process."
)
parser.add_argument(
"--output-json",
type=str,
default=None,
help="Path to save the throughput results in JSON format.",
)
parser.add_argument(
"--disable-detokenize",
action="store_true",
help=(
"Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"
),
)
parser = EngineArgs.add_cli_args(parser)
return parser
if __name__ == "__main__":
parser = create_argument_parser()
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
if args.dataset is None:
assert args.input_len is not None
assert args.output_len is not None
else:
assert args.input_len is None
main(args)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
if __name__ == "__main__":
print("""DEPRECATED: This script has been moved to the vLLM CLI.
Please use the following command instead:
vllm bench serve
For help with the new command, run:
vllm bench serve --help
Alternatively, you can run the new command directly with:
python -m vllm.entrypoints.cli.main bench serve --help
""")
sys.exit(1)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
r"""Benchmark online serving throughput with structured outputs.
On the server side, run one of the following commands:
(vLLM OpenAI API server)
vllm serve <your_model>
On the client side, run:
python benchmarks/benchmark_serving_structured_output.py \
--backend <backend> \
--model <your_model> \
--dataset json \
--structured-output-ratio 1.0 \
--request-rate 10 \
--num-prompts 1000
when using tgi backend, add
--endpoint /generate_stream
to the end of the command above.
"""
import argparse
import asyncio
import copy
import dataclasses
import json
import os
import random
import time
import uuid
import warnings
from collections.abc import AsyncGenerator
from contextlib import nullcontext
from dataclasses import dataclass
import datasets
import numpy as np
import pandas as pd
from backend_request_func import (
ASYNC_REQUEST_FUNCS,
RequestFuncInput,
RequestFuncOutput,
)
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
try:
from vllm.tokenizers import get_tokenizer
except ImportError:
from backend_request_func import get_tokenizer
try:
from vllm.utils.argparse_utils import FlexibleArgumentParser
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser
from vllm.v1.structured_output.backend_xgrammar import (
has_xgrammar_unsupported_json_features,
)
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@dataclass
class BenchmarkMetrics:
completed: int
total_input: int
total_output: int
request_throughput: float
request_goodput: float
output_throughput: float
total_token_throughput: float
mean_ttft_ms: float
median_ttft_ms: float
std_ttft_ms: float
percentiles_ttft_ms: list[tuple[float, float]]
mean_tpot_ms: float
median_tpot_ms: float
std_tpot_ms: float
percentiles_tpot_ms: list[tuple[float, float]]
mean_itl_ms: float
median_itl_ms: float
std_itl_ms: float
percentiles_itl_ms: list[tuple[float, float]]
# E2EL stands for end-to-end latency per request.
# It is the time taken on the client side from sending
# a request to receiving a complete response.
mean_e2el_ms: float
median_e2el_ms: float
std_e2el_ms: float
percentiles_e2el_ms: list[tuple[float, float]]
@dataclasses.dataclass
class SampleRequest:
"""A class representing a single inference request for benchmarking.
Attributes:
prompt: The input text prompt for the model.
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
images).
prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens.
"""
prompt: str
prompt_len: int
expected_output_len: int
schema: dict
structure_type: str
    completion: str | None = None
def sample_requests(
tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace
) -> list[SampleRequest]:
if args.dataset == "json" or args.dataset == "json-unique":
if args.json_schema_path is None:
dir_path = os.path.dirname(os.path.realpath(__file__))
args.json_schema_path = os.path.join(
dir_path, "structured_schemas", "structured_schema_1.json"
)
json_schemas = []
with open(args.json_schema_path) as f:
schema = json.load(f)
if args.dataset == "json-unique":
json_schemas = [copy.deepcopy(schema) for _ in range(args.num_prompts)]
for i in range(len(json_schemas)):
if "properties" not in json_schemas[i]:
json_schemas[i]["properties"] = {}
json_schemas[i]["properties"][f"__optional_field_{uuid.uuid4()}"] = {
"type": "string",
"description": "An unique optional field to avoid cached schemas",
}
else:
json_schemas = [schema] * args.num_prompts
def gen_prompt(index: int):
return f"Generate an example of a brief user profile given the following schema: {json.dumps(get_schema(index))}" # noqa: E501
def get_schema(index: int):
return json_schemas[index % len(json_schemas)]
requests = [
SampleRequest(
prompt=gen_prompt(i),
prompt_len=len(tokenizer(gen_prompt(i)).input_ids),
expected_output_len=args.output_len,
schema=get_schema(i),
structure_type=args.structure_type,
)
for i in range(args.num_prompts)
]
elif args.dataset == "grammar":
schema = """
root ::= select_statement
select_statement ::= "SELECT " column " from " table " where " condition
column ::= "col_1 " | "col_2 "
table ::= "table_1 " | "table_2 "
condition ::= column "= " number
number ::= "1 " | "2 "
"""
prompt = "Generate an SQL query to show the 'username' \
and 'email' from the 'users' table."
input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens")
requests = [
SampleRequest(
prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=schema,
structure_type=args.structure_type,
)
for _ in range(args.num_prompts)
]
elif args.dataset == "regex":
regex = r"\w+@\w+\.com\n"
args.regex = regex
prompt = "Generate an email address for Alan Turing, \
who works in Enigma. End in .com and new line. \
Example result: alan.turing@enigma.com\n"
input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens")
requests = [
SampleRequest(
prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=regex,
structure_type=args.structure_type,
)
for _ in range(args.num_prompts)
]
elif args.dataset == "choice":
choice = ["Positive", "Negative"]
args.choice = choice
prompt = "Classify this sentiment: vLLM is wonderful!"
input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens")
requests = [
SampleRequest(
prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=choice,
structure_type=args.structure_type,
)
for _ in range(args.num_prompts)
]
elif args.dataset == "xgrammar_bench":
requests: list[SampleRequest] = []
dataset = datasets.load_dataset("NousResearch/json-mode-eval", split="train")
full_dataset_len = len(dataset)
def _filter_func(item):
import json
schema = json.loads(item["schema"])
return not has_xgrammar_unsupported_json_features(schema)
dataset = dataset.filter(_filter_func)
num_filtered_out = full_dataset_len - len(dataset)
print(
f"dataset has {len(dataset)} entries after filtering "
f"out {num_filtered_out} entries with unsupported features"
)
len_dataset = len(dataset)
for data_point_idx in range(args.num_prompts):
idx = data_point_idx
while idx >= len_dataset:
idx -= len_dataset
schema = dataset["schema"][idx]
prompt = tokenizer.apply_chat_template(
dataset["prompt"][idx], tokenize=False, add_generation_prompt=True
)
input_len = len(tokenizer(prompt).input_ids)
completion = dataset["completion"][idx]
requests.append(
SampleRequest(
prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=schema,
structure_type=args.structure_type,
completion=completion,
)
)
return requests
async def get_request(
input_requests: list[SampleRequest],
request_rate: float,
burstiness: float = 1.0,
) -> AsyncGenerator[tuple[int, SampleRequest], None]:
"""
Asynchronously generates requests at a specified rate
with OPTIONAL burstiness.
Args:
input_requests:
A list of input requests, each represented as a tuple.
request_rate:
The rate at which requests are generated (requests/s).
burstiness (optional):
The burstiness factor of the request generation.
Only takes effect when request_rate is not inf.
Default value is 1, which follows a Poisson process.
Otherwise, the request intervals follow a gamma distribution.
A lower burstiness value (0 < burstiness < 1) results
in more bursty requests, while a higher burstiness value
(burstiness > 1) results in a more uniform arrival of requests.
"""
input_requests = iter(input_requests)
# Calculate scale parameter theta to maintain the desired request_rate.
assert burstiness > 0, (
f"A positive burstiness factor is expected, but given {burstiness}."
)
theta = 1.0 / (request_rate * burstiness)
for i, request in enumerate(input_requests):
yield i, request
if request_rate == float("inf"):
# If the request rate is infinity, then we don't need to wait.
continue
# Sample the request interval from the gamma distribution.
# If burstiness is 1, it follows exponential distribution.
interval = np.random.gamma(shape=burstiness, scale=theta)
# The next request will be sent after the interval.
await asyncio.sleep(interval)
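# Worked check of the interval math above (illustrative, not used by the
# benchmark): a gamma distribution with shape=burstiness and scale=theta has
# mean shape * scale = burstiness / (request_rate * burstiness) = 1 / request_rate,
# so the average arrival rate matches request_rate for any positive burstiness.
def _mean_interval_example(request_rate: float = 10.0, burstiness: float = 0.5) -> float:
    theta = 1.0 / (request_rate * burstiness)
    intervals = np.random.gamma(shape=burstiness, scale=theta, size=100_000)
    return float(intervals.mean())  # approximately 1 / request_rate = 0.1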
def calculate_metrics(
input_requests: list[tuple[str, int, int]],
outputs: list[RequestFuncOutput],
dur_s: float,
tokenizer: PreTrainedTokenizerBase,
selected_percentile_metrics: list[str],
selected_percentiles: list[float],
goodput_config_dict: dict[str, float] | None = None,
) -> tuple[BenchmarkMetrics, list[int]]:
actual_output_lens: list[int] = []
total_input = 0
completed = 0
good_completed = 0
itls: list[float] = []
tpots: list[float] = []
all_tpots: list[float] = []
ttfts: list[float] = []
e2els: list[float] = []
for i in range(len(outputs)):
if outputs[i].success:
# We use the tokenizer to count the number of output tokens for all
# serving backends instead of looking at len(outputs[i].itl) since
# multiple output tokens may be bundled together
# Note : this may inflate the output token count slightly
output_len = len(
tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
)
actual_output_lens.append(output_len)
total_input += input_requests[i].prompt_len
tpot = 0
if output_len > 1:
latency_minus_ttft = outputs[i].latency - outputs[i].ttft
tpot = latency_minus_ttft / (output_len - 1)
tpots.append(tpot)
outputs[i].tpot = tpot
# Note: if output_len <= 1, we regard tpot as 0 for goodput
all_tpots.append(tpot)
itls += outputs[i].itl
ttfts.append(outputs[i].ttft)
e2els.append(outputs[i].latency)
completed += 1
else:
actual_output_lens.append(0)
if goodput_config_dict:
valid_metrics = []
slo_values = []
if "ttft" in goodput_config_dict:
valid_metrics.append(ttfts)
slo_values.append(
goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
if "tpot" in goodput_config_dict:
valid_metrics.append(all_tpots)
slo_values.append(
goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
if "e2el" in goodput_config_dict:
valid_metrics.append(e2els)
slo_values.append(
goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION
)
for req_metric in zip(*valid_metrics):
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
if is_good_req:
good_completed += 1
if completed == 0:
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
stacklevel=2,
)
metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
total_output=sum(actual_output_lens),
request_throughput=completed / dur_s,
request_goodput=good_completed / dur_s,
output_throughput=sum(actual_output_lens) / dur_s,
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
mean_ttft_ms=np.mean(ttfts or 0)
* 1000, # ttfts is empty if streaming is not supported by backend
std_ttft_ms=np.std(ttfts or 0) * 1000,
median_ttft_ms=np.median(ttfts or 0) * 1000,
percentiles_ttft_ms=[
(p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles
],
mean_tpot_ms=np.mean(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000,
percentiles_tpot_ms=[
(p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles
],
mean_itl_ms=np.mean(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
percentiles_itl_ms=[
(p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles
],
mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000,
percentiles_e2el_ms=[
(p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles
],
)
return metrics, actual_output_lens
async def benchmark(
backend: str,
api_url: str,
base_url: str,
model_id: str,
tokenizer: PreTrainedTokenizerBase,
input_requests: list[SampleRequest],
request_rate: float,
burstiness: float,
disable_tqdm: bool,
profile: bool,
selected_percentile_metrics: list[str],
selected_percentiles: list[str],
ignore_eos: bool,
max_concurrency: int | None,
structured_output_ratio: float,
goodput_config_dict: dict[str, float] | None = None,
):
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend]
else:
raise ValueError(f"Unknown backend: {backend}")
def prepare_extra_body(request) -> dict:
extra_body = {}
# Add the schema to the extra_body
extra_body["structured_outputs"] = {}
extra_body["structured_outputs"][request.structure_type] = request.schema
return extra_body
print("Starting initial single prompt test run...")
structured_output_req_idx = random.sample(
range(len(input_requests)), int(len(input_requests) * structured_output_ratio)
)
test_request = input_requests[0]
test_req_extra_body = (
prepare_extra_body(test_request) if 0 in structured_output_req_idx else None
)
test_input = RequestFuncInput(
model=model_id,
prompt=test_request.prompt,
api_url=api_url,
prompt_len=test_request.prompt_len,
output_len=test_request.expected_output_len,
ignore_eos=ignore_eos,
extra_body=test_req_extra_body,
)
test_output = await request_func(request_func_input=test_input)
if not test_output.success:
raise ValueError(
"Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {test_output.error}"
)
else:
print("Initial test run completed. Starting main benchmark run...")
if profile:
print("Starting profiler...")
profile_input = RequestFuncInput(
model=model_id,
prompt=test_request.prompt,
api_url=base_url + "/start_profile",
prompt_len=test_request.prompt_len,
output_len=test_request.expected_output_len,
ignore_eos=ignore_eos,
extra_body=test_req_extra_body,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
print("Profiler started")
distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
print(f"Traffic request rate: {request_rate}")
print(f"Burstiness factor: {burstiness} ({distribution})")
print(f"Maximum request concurrency: {max_concurrency}")
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else nullcontext()
async def limited_request_func(request_func_input, pbar):
async with semaphore:
return await request_func(request_func_input=request_func_input, pbar=pbar)
benchmark_start_time = time.perf_counter()
tasks: list[asyncio.Task] = []
expected: list[str] = []
async for i, request in get_request(input_requests, request_rate, burstiness):
extra_body = (
prepare_extra_body(request) if i in structured_output_req_idx else None
)
request_func_input = RequestFuncInput(
model=model_id,
prompt=request.prompt,
api_url=api_url,
prompt_len=request.prompt_len,
output_len=request.expected_output_len,
ignore_eos=ignore_eos,
extra_body=extra_body,
)
expected.append(request.completion)
tasks.append(
asyncio.create_task(
limited_request_func(request_func_input=request_func_input, pbar=pbar)
)
)
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
if pbar is not None:
pbar.close()
benchmark_duration = time.perf_counter() - benchmark_start_time
metrics, actual_output_lens = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
dur_s=benchmark_duration,
tokenizer=tokenizer,
selected_percentile_metrics=selected_percentile_metrics,
selected_percentiles=selected_percentiles,
goodput_config_dict=goodput_config_dict,
)
print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
if max_concurrency is not None:
print("{:<40} {:<10}".format("Maximum request concurrency:", max_concurrency))
if request_rate != float("inf"):
print("{:<40} {:<10.2f}".format("Request rate configured (RPS):", request_rate))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
print(
"{:<40} {:<10.2f}".format(
"Request throughput (req/s):", metrics.request_throughput
)
)
if goodput_config_dict:
print(
"{:<40} {:<10.2f}".format(
"Request goodput (req/s):", metrics.request_goodput
)
)
print(
"{:<40} {:<10.2f}".format(
"Output token throughput (tok/s):", metrics.output_throughput
)
)
print(
"{:<40} {:<10.2f}".format(
"Total token throughput (tok/s):", metrics.total_token_throughput
)
)
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput,
"output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput,
"ttft_description": pd.Series([output.ttft for output in outputs])
.describe()
.to_dict(),
"tpot_description": pd.Series([output.tpot for output in outputs])
.describe()
.to_dict(),
"input_lens": [output.prompt_len for output in outputs],
"output_lens": actual_output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"errors": [output.error for output in outputs],
}
ret = [
{"generated": output.generated_text, "expected": gt}
for output, gt in zip(outputs, expected)
]
def process_one_metric(
# E.g., "ttft"
metric_attribute_name: str,
# E.g., "TTFT"
metric_name: str,
# E.g., "Time to First Token"
metric_header: str,
):
# This function prints and adds statistics of the specified
# metric.
if metric_attribute_name not in selected_percentile_metrics:
return
print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
print(
"{:<40} {:<10.2f}".format(
f"Mean {metric_name} (ms):",
getattr(metrics, f"mean_{metric_attribute_name}_ms"),
)
)
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name} (ms):",
getattr(metrics, f"median_{metric_attribute_name}_ms"),
)
)
result[f"mean_{metric_attribute_name}_ms"] = getattr(
metrics, f"mean_{metric_attribute_name}_ms"
)
result[f"median_{metric_attribute_name}_ms"] = getattr(
metrics, f"median_{metric_attribute_name}_ms"
)
result[f"std_{metric_attribute_name}_ms"] = getattr(
metrics, f"std_{metric_attribute_name}_ms"
)
for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
process_one_metric("ttft", "TTFT", "Time to First Token")
process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
process_one_metric("itl", "ITL", "Inter-token Latency")
process_one_metric("e2el", "E2EL", "End-to-end Latency")
print("=" * 50)
if profile:
print("Stopping profiler...")
profile_input = RequestFuncInput(
model=model_id,
prompt=test_request.prompt,
api_url=base_url + "/stop_profile",
prompt_len=test_request.prompt_len,
output_len=test_request.expected_output_len,
extra_body={test_request.structure_type: test_request.schema},
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
print("Profiler stopped")
return result, ret
def evaluate(ret, args):
def _eval_correctness_json(expected, actual):
# extract json string from string using regex
import regex as re
actual = actual.replace("\n", "").replace(" ", "").strip()
try:
actual = re.search(r"\{.*\}", actual).group()
actual = json.loads(actual)
except Exception:
return False
return True
def _eval_correctness_choice(expected, actual):
return actual in args.choice
def _eval_correctness_regex(expected, actual):
import regex as re
return re.match(args.regex, actual) is not None
def _eval_correctness(expected, actual):
if args.structure_type == "json":
return _eval_correctness_json(expected, actual)
elif args.structure_type == "regex":
return _eval_correctness_regex(expected, actual)
elif args.structure_type == "choice":
return _eval_correctness_choice(expected, actual)
else:
return None
scores = []
for res in ret:
score = _eval_correctness(res["expected"], res["generated"])
res["correctness"] = score
scores.append(score)
not_none_scores = [score for score in scores if score is not None]
return (
(sum(not_none_scores) / len(not_none_scores) * 100)
if len(not_none_scores) > 0
else None
)
def parse_goodput(slo_pairs):
goodput_config_dict = {}
try:
for slo_pair in slo_pairs:
slo_name, slo_val = slo_pair.split(":")
goodput_config_dict[slo_name] = float(slo_val)
except ValueError as err:
raise argparse.ArgumentTypeError(
"Invalid format found for service level objectives. "
'Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is a "
"number in milliseconds."
) from err
return goodput_config_dict
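# Illustrative only: parse_goodput turns CLI "KEY:VALUE" pairs into a dict of
# millisecond SLOs, e.g. parse_goodput(["ttft:500", "e2el:2000"]) returns
# {"ttft": 500.0, "e2el": 2000.0}.
def _parse_goodput_example() -> dict[str, float]:
    return parse_goodput(["ttft:500", "e2el:2000"])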
def check_goodput_args(args):
goodput_config_dict = {}
VALID_NAMES = ["ttft", "tpot", "e2el"]
if args.goodput:
goodput_config_dict = parse_goodput(args.goodput)
for slo_name, slo_val in goodput_config_dict.items():
if slo_name not in VALID_NAMES:
raise ValueError(
f"Invalid metric name found, {slo_name}: {slo_val}. "
"The service level objective name should be one of "
f"{str(VALID_NAMES)}. "
)
if slo_val < 0:
raise ValueError(
f"Invalid value found, {slo_name}: {slo_val}. "
"The service level objective value should be "
"non-negative."
)
return goodput_config_dict
def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
np.random.seed(args.seed)
backend = args.backend
model_id = args.model
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
if args.base_url is not None:
api_url = f"{args.base_url}{args.endpoint}"
base_url = f"{args.base_url}"
else:
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
base_url = f"http://{args.host}:{args.port}"
tokenizer = get_tokenizer(
tokenizer_id,
trust_remote_code=args.trust_remote_code,
tokenizer_mode=args.tokenizer_mode,
)
if args.dataset == "grammar":
args.structure_type = "grammar"
elif args.dataset == "regex":
args.structure_type = "regex"
elif args.dataset == "choice":
args.structure_type = "choice"
else:
args.structure_type = "json"
if args.no_structured_output:
args.structured_output_ratio = 0
if args.save_results:
result_file_name = f"{args.structured_output_ratio}so"
result_file_name += f"_{backend}"
result_file_name += f"_{args.request_rate}qps"
result_file_name += f"_{args.model.split('/')[-1]}"
result_file_name += f"_{args.dataset}"
result_file_name += f"_{args.num_prompts}"
result_file_name += f"_out{args.output_len}"
result_file_name += ".txt"
else:
result_file_name = None
input_requests = sample_requests(tokenizer, args)
goodput_config_dict = check_goodput_args(args)
benchmark_result, ret = asyncio.run(
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=args.request_rate,
burstiness=args.burstiness,
disable_tqdm=args.disable_tqdm,
profile=args.profile,
selected_percentile_metrics=args.percentile_metrics.split(","),
selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
ignore_eos=args.ignore_eos,
max_concurrency=args.max_concurrency,
structured_output_ratio=args.structured_output_ratio,
goodput_config_dict=goodput_config_dict,
)
)
# Save config and results to json
score = evaluate(ret, args)
print("correct_rate(%)", score, "\n")
if args.save_results:
results = {
"backend": backend,
"model_id": model_id,
"tokenizer_id": tokenizer_id,
"num_prompts": args.num_prompts,
"request_rate": args.request_rate
if args.request_rate < float("inf")
else "inf",
"burstiness": args.burstiness,
"max_concurrency": args.max_concurrency,
"correct_rate(%)": score,
}
results = {"outputs": ret, **results, **benchmark_result}
# Save to file
if args.result_filename:
result_file_name = args.result_filename
if args.result_dir:
result_file_name = os.path.join(args.result_dir, result_file_name)
with open(result_file_name, "w", encoding="utf-8") as outfile:
json.dump(results, outfile, indent=4)
def create_argument_parser():
parser = FlexibleArgumentParser(
description="Benchmark the online serving throughput."
)
parser.add_argument(
"--backend",
type=str,
default="vllm",
choices=list(ASYNC_REQUEST_FUNCS.keys()),
)
parser.add_argument(
"--base-url",
type=str,
default=None,
help="Server or API base url if not using http host and port.",
)
# Use 127.0.0.1 here instead of localhost to force the use of ipv4
parser.add_argument("--host", type=str, default="127.0.0.1")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument(
"--endpoint",
type=str,
default="/v1/completions",
help="API endpoint.",
)
parser.add_argument(
"--dataset",
default="json",
choices=["json", "json-unique", "grammar", "regex", "choice", "xgrammar_bench"],
)
parser.add_argument(
"--json-schema-path", type=str, default=None, help="Path to json schema."
)
parser.add_argument(
"--max-concurrency",
type=int,
default=None,
help="Maximum number of concurrent requests. This can be used "
"to help simulate an environment where a higher level component "
"is enforcing a maximum number of concurrent requests. While the "
"--request-rate argument controls the rate at which requests are "
"initiated, this argument will control how many are actually allowed "
"to execute at a time. This means that when used in combination, the "
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.",
)
parser.add_argument(
"--model",
type=str,
required=True,
help="Name of the model.",
)
parser.add_argument(
"--tokenizer",
type=str,
help="Name or path of the tokenizer, if not using the default tokenizer.",
)
parser.add_argument(
"--tokenizer-mode",
type=str,
default="auto",
help="Name or path of the tokenizer, if not using the default tokenizer.",
)
parser.add_argument(
"--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.",
)
parser.add_argument(
"--output-len",
type=int,
default=128,
help="Number of output tokens.",
)
parser.add_argument(
"--request-rate",
type=float,
default=float("inf"),
help="Number of requests per second. If this is inf, "
"then all the requests are sent at time 0. "
"Otherwise, we use Poisson process or gamma distribution "
"to synthesize the request arrival times.",
)
parser.add_argument(
"--burstiness",
type=float,
default=1.0,
help="Burstiness factor of the request generation. "
"Only take effect when request_rate is not inf. "
"Default value is 1, which follows Poisson process. "
"Otherwise, the request intervals follow a gamma distribution. "
"A lower burstiness value (0 < burstiness < 1) results in more "
"bursty requests. A higher burstiness value (burstiness > 1) "
"results in a more uniform arrival of requests.",
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Trust remote code from huggingface",
)
parser.add_argument(
"--disable-tqdm",
action="store_true",
help="Specify to disable tqdm progress bar.",
)
parser.add_argument(
"--save-results",
action="store_true",
help="Specify to save benchmark results to a json file",
)
parser.add_argument(
"--profile",
action="store_true",
help="Use vLLM Profiling. --profiler-config must be provided on the server.",
)
parser.add_argument(
"--result-dir",
type=str,
default=None,
help="Specify directory to save benchmark json results."
"If not specified, results are saved in the current directory.",
)
parser.add_argument(
"--result-filename",
type=str,
default=None,
help="Specify the filename to save benchmark json results."
"If not specified, results will be saved in "
"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
" format.",
)
parser.add_argument(
"--ignore-eos",
action="store_true",
help="Set ignore_eos flag when sending the benchmark request."
"Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
)
parser.add_argument(
"--percentile-metrics",
type=str,
default="ttft,tpot,itl",
help="Comma-separated list of selected metrics to report percentiles. "
"This argument specifies the metrics to report percentiles. "
'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
'Default value is "ttft,tpot,itl".',
)
parser.add_argument(
"--metric-percentiles",
type=str,
default="99",
help="Comma-separated list of percentiles for selected metrics. "
'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
'Default value is "99". '
'Use "--percentile-metrics" to select metrics.',
)
parser.add_argument(
"--goodput",
nargs="+",
required=False,
help='Specify service level objectives for goodput as "KEY:VALUE" '
"pairs, where the key is a metric name, and the value is in "
'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
"separated by spaces. Allowed request level metric names are "
'"ttft", "tpot", "e2el". For more context on the definition of '
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
)
parser.add_argument(
"--no-structured-output",
action="store_true",
default=False,
help="Whether to disable JSON decoding or not.",
)
parser.add_argument(
"--structured-output-ratio",
type=float,
default=1.0,
help="Ratio of Structured Outputs requests",
)
return parser
if __name__ == "__main__":
parser = create_argument_parser()
args = parser.parse_args()
main(args)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
if __name__ == "__main__":
print("""DEPRECATED: This script has been moved to the vLLM CLI.
Please use the following command instead:
vllm bench throughput
For help with the new command, run:
vllm bench throughput --help
Alternatively, you can run the new command directly with:
python -m vllm.entrypoints.cli.main bench throughput --help
""")
sys.exit(1)
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark comparing Triton vs PyTorch sort-based top-k/top-p implementations.
Compares:
- apply_top_k_top_p_triton (Triton binary search)
- apply_top_k_top_p_pytorch (PyTorch sort-based)
Scenarios:
- top_k only (whole batch, partial batch)
- top_p only (whole batch, partial batch)
- mix of top_k and top_p
"""
import argparse
import gc
from dataclasses import dataclass
import torch
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch
from vllm.v1.sample.ops.topk_topp_triton import (
apply_top_k_top_p_triton,
reset_buffer_cache,
)
@dataclass
class BenchmarkConfig:
"""Configuration for a benchmark run."""
name: str
batch_size: int
vocab_size: int
# k and p can be tensors or None
k_values: torch.Tensor | None # [batch_size] or None
p_values: torch.Tensor | None # [batch_size] or None
description: str
ops_pct: float = 0.0 # Percentage of ops relative to batch size
def calculate_ops_pct(
k_values: torch.Tensor | None,
p_values: torch.Tensor | None,
vocab_size: int,
batch_size: int,
) -> float:
"""
Calculate the percentage of active top-k and top-p operations.
Returns percentage where 100% = batch_size ops.
E.g., if all rows have both top-k and top-p active, returns 200%.
"""
active_ops = 0
if k_values is not None:
# Count rows where k < vocab_size (active top-k filtering)
active_ops += (k_values < vocab_size).sum().item()
if p_values is not None:
# Count rows where p < 1.0 (active top-p filtering)
active_ops += (p_values < 1.0).sum().item()
return (active_ops / batch_size) * 100 if batch_size > 0 else 0.0
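# Worked example: with batch_size=4, top-k active on all 4 rows and top-p active on 2
# of them, active_ops = 4 + 2 = 6, so this returns 150.0 (i.e. 150%).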
def create_logits(
batch_size: int, vocab_size: int, device: str = "cuda"
) -> torch.Tensor:
"""Create random logits mimicking a realistic LLM distribution.
Uses a Zipf-like probability distribution (rank^-1.1) converted to logits
via log, then randomly permuted per row. This produces a peaked distribution
where a small number of tokens capture most probability mass, similar to
real model outputs.
"""
# Create Zipf-like probabilities: p(rank) ~ rank^(-alpha)
ranks = torch.arange(1, vocab_size + 1, dtype=torch.float32, device=device)
probs = ranks.pow(-1.1)
probs = probs / probs.sum()
# Convert to logits (log-probabilities, unnormalized is fine)
base_logits = probs.log()
# Broadcast to batch and randomly permute each row
logits = base_logits.unsqueeze(0).expand(batch_size, -1).clone()
for i in range(batch_size):
logits[i] = logits[i, torch.randperm(vocab_size, device=device)]
return logits
def measure_memory() -> tuple[int, int]:
"""Return (allocated, reserved) memory in bytes."""
torch.cuda.synchronize()
return torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated()
def reset_memory_stats():
"""Reset peak memory statistics."""
reset_buffer_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
gc.collect()
def benchmark_function(
func,
logits: torch.Tensor,
k: torch.Tensor | None,
p: torch.Tensor | None,
warmup_iters: int = 5,
benchmark_iters: int = 20,
) -> tuple[float, int]:
"""
Benchmark a function and return (avg_time_ms, peak_memory_bytes).
Returns average time in milliseconds and peak memory usage.
"""
# Warmup
for _ in range(warmup_iters):
logits_copy = logits.clone()
func(logits_copy, k, p)
torch.cuda.synchronize()
# Reset memory stats before benchmark
reset_memory_stats()
# Benchmark
start_events = [
torch.cuda.Event(enable_timing=True) for _ in range(benchmark_iters)
]
end_events = [torch.cuda.Event(enable_timing=True) for _ in range(benchmark_iters)]
for i in range(benchmark_iters):
logits_copy = logits.clone()
start_events[i].record()
func(logits_copy, k, p)
end_events[i].record()
torch.cuda.synchronize()
# Calculate timing
times = [
start_events[i].elapsed_time(end_events[i]) for i in range(benchmark_iters)
]
avg_time = sum(times) / len(times)
# Get peak memory
_, peak_memory = measure_memory()
return avg_time, peak_memory
def create_benchmark_configs(
batch_sizes: list[int],
vocab_sizes: list[int],
device: str = "cuda",
) -> list[BenchmarkConfig]:
"""Create all benchmark configurations."""
configs = []
for vocab_size in vocab_sizes:
for batch_size in batch_sizes:
# 1. Top-k only - whole batch (all rows have k < vocab_size)
k_all = torch.full((batch_size,), 50, dtype=torch.int32, device=device)
configs.append(
BenchmarkConfig(
name=f"topk_whole_b{batch_size}_v{vocab_size // 1000}k",
batch_size=batch_size,
vocab_size=vocab_size,
k_values=k_all,
p_values=None,
description=f"Top-k only (whole batch, k=50), "
f"batch={batch_size}, vocab={vocab_size}",
ops_pct=calculate_ops_pct(k_all, None, vocab_size, batch_size),
)
)
# 2. Top-k only - partial batch (half have k=50, half have k=vocab_size)
k_partial = torch.full((batch_size,), 50, dtype=torch.int32, device=device)
k_partial[batch_size // 2 :] = vocab_size # No filtering for second half
configs.append(
BenchmarkConfig(
name=f"topk_partial_b{batch_size}_v{vocab_size // 1000}k",
batch_size=batch_size,
vocab_size=vocab_size,
k_values=k_partial,
p_values=None,
description=f"Top-k only (partial batch, 50% k=50, 50% k=vocab), "
f"batch={batch_size}, vocab={vocab_size}",
ops_pct=calculate_ops_pct(k_partial, None, vocab_size, batch_size),
)
)
# 3. Top-p only - whole batch (all rows have p < 1.0)
p_all = torch.full((batch_size,), 0.9, dtype=torch.float32, device=device)
configs.append(
BenchmarkConfig(
name=f"topp_whole_b{batch_size}_v{vocab_size // 1000}k",
batch_size=batch_size,
vocab_size=vocab_size,
k_values=None,
p_values=p_all,
description=f"Top-p only (whole batch, p=0.9), "
f"batch={batch_size}, vocab={vocab_size}",
ops_pct=calculate_ops_pct(None, p_all, vocab_size, batch_size),
)
)
# 4. Top-p only - partial batch (half have p=0.9, half have p=1.0)
p_partial = torch.full(
(batch_size,), 0.9, dtype=torch.float32, device=device
)
p_partial[batch_size // 2 :] = 1.0 # No filtering for second half
configs.append(
BenchmarkConfig(
name=f"topp_partial_b{batch_size}_v{vocab_size // 1000}k",
batch_size=batch_size,
vocab_size=vocab_size,
k_values=None,
p_values=p_partial,
description=f"Top-p only (partial batch, 50% p=0.9, 50% p=1.0), "
f"batch={batch_size}, vocab={vocab_size}",
ops_pct=calculate_ops_pct(None, p_partial, vocab_size, batch_size),
)
)
# 5. Mix of top-k and top-p (both applied to whole batch)
k_mix = torch.full((batch_size,), 100, dtype=torch.int32, device=device)
p_mix = torch.full((batch_size,), 0.9, dtype=torch.float32, device=device)
configs.append(
BenchmarkConfig(
name=f"topk_topp_whole_b{batch_size}_v{vocab_size // 1000}k",
batch_size=batch_size,
vocab_size=vocab_size,
k_values=k_mix,
p_values=p_mix,
description=f"Top-k + Top-p (whole batch, k=100, p=0.9), "
f"batch={batch_size}, vocab={vocab_size}",
ops_pct=calculate_ops_pct(k_mix, p_mix, vocab_size, batch_size),
)
)
# 6. Mix with partial application (some rows k only, some p only, some both)
k_mixed = torch.full(
(batch_size,), vocab_size, dtype=torch.int32, device=device
)
p_mixed = torch.full((batch_size,), 1.0, dtype=torch.float32, device=device)
# First third: k only
third = batch_size // 3
k_mixed[:third] = 50
# Second third: p only
p_mixed[third : 2 * third] = 0.5
# Last third: both k and p
k_mixed[2 * third :] = 100
p_mixed[2 * third :] = 0.9
configs.append(
BenchmarkConfig(
name=f"mixed_partial_b{batch_size}_v{vocab_size // 1000}k",
batch_size=batch_size,
vocab_size=vocab_size,
k_values=k_mixed,
p_values=p_mixed,
description=f"Mixed partial (1/3 k=50, 1/3 p=0.9, 1/3 both), "
f"batch={batch_size}, vocab={vocab_size}",
ops_pct=calculate_ops_pct(k_mixed, p_mixed, vocab_size, batch_size),
)
)
return configs
def format_memory(bytes_val: int) -> str:
"""Format memory in human-readable form."""
if bytes_val >= 1024**3:
return f"{bytes_val / (1024**3):.2f} GB"
elif bytes_val >= 1024**2:
return f"{bytes_val / (1024**2):.2f} MB"
elif bytes_val >= 1024:
return f"{bytes_val / 1024:.2f} KB"
return f"{bytes_val} B"
def run_benchmark(
configs: list[BenchmarkConfig],
warmup_iters: int = 5,
benchmark_iters: int = 20,
verbose: bool = True,
):
"""Run all benchmarks and print results."""
results = []
print("=" * 100)
print("Top-k/Top-p Benchmark: Triton vs PyTorch Sort-based")
print("=" * 100)
print()
for config in configs:
if verbose:
print(f"Running: {config.description}")
# Create fresh logits for this config
logits = create_logits(config.batch_size, config.vocab_size)
# Benchmark Triton
reset_memory_stats()
triton_time, triton_mem = benchmark_function(
apply_top_k_top_p_triton,
logits,
config.k_values,
config.p_values,
warmup_iters,
benchmark_iters,
)
# Benchmark PyTorch
reset_memory_stats()
pytorch_time, pytorch_mem = benchmark_function(
apply_top_k_top_p_pytorch,
logits,
config.k_values,
config.p_values,
warmup_iters,
benchmark_iters,
)
speedup = pytorch_time / triton_time if triton_time > 0 else float("inf")
mem_ratio = pytorch_mem / triton_mem if triton_mem > 0 else float("inf")
result = {
"config": config,
"triton_time_ms": triton_time,
"pytorch_time_ms": pytorch_time,
"triton_mem": triton_mem,
"pytorch_mem": pytorch_mem,
"speedup": speedup,
"mem_ratio": mem_ratio,
}
results.append(result)
if verbose:
print(f" Triton: {triton_time:.3f} ms, {format_memory(triton_mem)}")
print(f" PyTorch: {pytorch_time:.3f} ms, {format_memory(pytorch_mem)}")
print(f" Speedup: {speedup:.2f}x, Memory ratio: {mem_ratio:.2f}x")
print()
# Clean up
del logits
reset_memory_stats()
return results
def print_summary_table(results: list[dict]):
"""Print a summary table of results."""
print()
print("=" * 130)
print("SUMMARY TABLE")
print("=" * 130)
print()
# Header
header = (
f"{'Scenario':<40} {'Batch':>6} {'Vocab':>7} {'Ops%':>6} "
f"{'Triton (ms)':>12} {'PyTorch (ms)':>13} {'Speedup':>8} "
f"{'Tri Mem':>10} {'Pyt Mem':>10}"
)
print(header)
print("-" * 130)
# Group by scenario type
current_vocab = None
for result in results:
config = result["config"]
# Add separator between vocab sizes
if current_vocab != config.vocab_size:
if current_vocab is not None:
print("-" * 130)
current_vocab = config.vocab_size
scenario = config.name.split("_b")[0] # Extract scenario name
print(
f"{scenario:<40} {config.batch_size:>6} {config.vocab_size:>7} "
f"{config.ops_pct:>5.0f}% "
f"{result['triton_time_ms']:>12.3f} {result['pytorch_time_ms']:>13.3f} "
f"{result['speedup']:>7.2f}x "
f"{format_memory(result['triton_mem']):>10} "
f"{format_memory(result['pytorch_mem']):>10}"
)
print("=" * 130)
def main():
parser = argparse.ArgumentParser(
description="Benchmark Triton vs PyTorch sort-based top-k/top-p implementations"
)
parser.add_argument(
"--batch-sizes",
type=int,
nargs="+",
default=[1, 4, 16, 64, 128, 512, 1024, 2048],
help="Batch sizes to test (default: 1 4 16 64)",
)
parser.add_argument(
"--vocab-sizes",
type=int,
nargs="+",
default=[32768, 131072], # 32k, 128k
help="Vocabulary sizes to test (default: 32768 131072)",
)
parser.add_argument(
"--warmup-iters",
type=int,
default=5,
help="Number of warmup iterations (default: 5)",
)
parser.add_argument(
"--benchmark-iters",
type=int,
default=20,
help="Number of benchmark iterations (default: 20)",
)
parser.add_argument(
"--quiet",
action="store_true",
help="Only print summary table",
)
args = parser.parse_args()
# Print configuration
print(f"Batch sizes: {args.batch_sizes}")
print(f"Vocab sizes: {args.vocab_sizes}")
print(f"Warmup iterations: {args.warmup_iters}")
print(f"Benchmark iterations: {args.benchmark_iters}")
print()
# Check CUDA
if not torch.cuda.is_available():
print("ERROR: CUDA is not available. This benchmark requires a GPU.")
return
device_name = torch.cuda.get_device_name(0)
print(f"GPU: {device_name}")
print()
# Create configs
configs = create_benchmark_configs(
args.batch_sizes,
args.vocab_sizes,
)
# Run benchmarks
results = run_benchmark(
configs,
warmup_iters=args.warmup_iters,
benchmark_iters=args.benchmark_iters,
verbose=not args.quiet,
)
# Print summary
print_summary_table(results)
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from types import TracebackType
# Collect time and generate time metrics
#
# Example Usage:
# collector = TimeCollector(TimeCollector.US)
# for _ in range(total_iteration):
# with collector:
# ...
# collector.dump_avg_max()
class TimeCollector:
NS: int = 1
US: int = NS * 1000
MS: int = US * 1000
S: int = MS * 1000
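# These scale constants are divisors applied to time.monotonic_ns() deltas: a collector
# constructed with TimeCollector.US reports microseconds, TimeCollector.MS milliseconds,
# and TimeCollector.S seconds.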
def __init__(self, scale: int) -> None:
self.cnt: int = 0
self._sum: int = 0
self._max: int | None = None
self.scale = scale
self.start_time: int = time.monotonic_ns()
def collect(self, v: int) -> None:
self.cnt += 1
self._sum += v
if self._max is None:
self._max = v
else:
self._max = max(self._max, v)
def avg(self) -> float | str:
return self._sum * 1.0 / self.cnt / self.scale if self.cnt > 0 else "N/A"
def max(self) -> float | str:
return self._max / self.scale if self._max is not None else "N/A"
def dump_avg_max(self) -> list[float | str]:
return [self.avg(), self.max()]
def __enter__(self) -> None:
self.start_time = time.monotonic_ns()
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
) -> None:
self.collect(time.monotonic_ns() - self.start_time)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools
import pickle as pkl
import time
from collections.abc import Callable, Iterable
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from utils import make_rand_sparse_tensors
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.utils.argparse_utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]
# bench
def bench_fn(
label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
) -> TMeasurement:
min_run_time = 1
globals = {
"args": args,
"kwargs": kwargs,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(*args, **kwargs)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
def bench_int8(
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
assert dtype == torch.int8
b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(
a, b_compressed, e, scale_a, scale_b, torch.bfloat16
)
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
if not torch.allclose(out, out_ref):
print("Incorrect results")
print(out)
print(out_ref)
else:
print("Correct results")
timers = []
# pytorch impl - bfloat16
timers.append(
bench_fn(
label,
sub_label,
"pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.bfloat16),
b.to(dtype=torch.bfloat16),
)
)
# pytorch impl - float16
timers.append(
bench_fn(
label,
sub_label,
"pytorch_fp16_fp16_fp16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.float16),
b.to(dtype=torch.float16),
)
)
# cutlass impl
timers.append(
bench_fn(
label,
sub_label,
"cutlass_i8_i8_bf16_scaled_mm",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass with bias
timers.append(
bench_fn(
label,
sub_label,
"cutlass_i8_i8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
# cutlass sparse impl
timers.append(
bench_fn(
label,
sub_label,
"cutlass_i8_i8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass sparse with bias
timers.append(
bench_fn(
label,
sub_label,
"cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
return timers
def bench_fp8(
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
assert dtype == torch.float8_e4m3fn
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(
a, b_compressed, e, scale_a, scale_b, torch.bfloat16
)
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
if not torch.allclose(out, out_ref):
print("Incorrect results")
print(out)
print(out_ref)
else:
print("Correct results")
timers = []
# pytorch impl w. bf16
timers.append(
bench_fn(
label,
sub_label,
"pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"),
)
)
# pytorch impl: bf16 output, without fp8 fast accum
timers.append(
bench_fn(
label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16,
)
)
# pytorch impl: bf16 output, with fp8 fast accum
timers.append(
bench_fn(
label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16,
use_fast_accum=True,
)
)
# pytorch impl: fp16 output, without fp8 fast accum
timers.append(
bench_fn(
label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16,
)
)
# pytorch impl: fp16 output, with fp8 fast accum
timers.append(
bench_fn(
label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16,
use_fast_accum=True,
)
)
# cutlass impl: bf16 output
timers.append(
bench_fn(
label,
sub_label,
"cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass impl: bf16 output
timers.append(
bench_fn(
label,
sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass impl: fp16 output
timers.append(
bench_fn(
label,
sub_label,
"cutlass_fp8_fp8_fp16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.float16,
)
)
# cutlass impl: bf16 output, with bias
timers.append(
bench_fn(
label,
sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
# cutlass impl: fp16 output, with bias
timers.append(
bench_fn(
label,
sub_label,
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.float16,
bias.to(dtype=torch.float16),
)
)
return timers
def bench(
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label)
if dtype == torch.float8_e4m3fn:
return bench_fp8(dtype, m, k, n, label, sub_label)
raise ValueError(
f"Unsupported dtype {dtype}: should be one of torch.int8, torch.float8_e4m3fn."
)
# runner
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
def run(
dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]]
) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", f"MKN=({m}x{k}x{n})")
print_timers(timers)
results.extend(timers)
return results
# output makers
def make_output(
data: Iterable[TMeasurement],
MKNs: Iterable[tuple[int, int, int]],
base_description: str,
timestamp=None,
):
print(f"== All Results {base_description} ====")
print_timers(data)
# pickle all the results
timestamp = int(time.time()) if timestamp is None else timestamp
with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
pkl.dump(data, f)
# argparse runners
def run_square_bench(args):
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs)
make_output(data, MKNs, f"square_bench-{args.dtype}")
def run_range_bench(args):
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
n = len(dim_sizes)
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
MKNs = list(zip(Ms, Ks, Ns))
data = run(args.dtype, MKNs)
make_output(data, MKNs, f"range_bench-{args.dtype}")
def run_model_bench(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
KNs = []
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KNs.append(KN)
return KNs
model_bench_data = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
Ms = args.batch_sizes
KNs = model_shapes(model, tp_size)
MKNs = []
for m in Ms:
for k, n in KNs:
MKNs.append((m, k, n))
data = run(args.dtype, MKNs)
model_bench_data.append(data)
# Print all results
for data, model_tp in zip(model_bench_data, models_tps):
model, tp_size = model_tp
print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
print_timers(data)
timestamp = int(time.time())
all_data = []
for d in model_bench_data:
all_data.extend(d)
# pickle all data
with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
pkl.dump(all_data, f)
if __name__ == "__main__":
def to_torch_dtype(dt):
if dt == "int8":
return torch.int8
if dt == "fp8":
return torch.float8_e4m3fn
raise ValueError("unsupported dtype")
parser = FlexibleArgumentParser(
description="""
Benchmark Cutlass GEMM.
To run square GEMMs:
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
To run constant N and K and sweep M:
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
To run dimensions from a model:
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
Output:
- a .pkl file containing a list of raw torch.utils.benchmark Measurement objects for the pytorch and cutlass implementations of the various GEMMs.
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--dtype",
type=to_torch_dtype,
required=True,
help="Available options are ['int8', 'fp8']",
)
subparsers = parser.add_subparsers(dest="cmd")
square_parser = subparsers.add_parser("square_bench")
square_parser.add_argument("--dim-start", type=int, required=True)
square_parser.add_argument("--dim-end", type=int, required=True)
square_parser.add_argument("--dim-increment", type=int, required=True)
square_parser.set_defaults(func=run_square_bench)
range_parser = subparsers.add_parser("range_bench")
range_parser.add_argument("--dim-start", type=int, required=True)
range_parser.add_argument("--dim-end", type=int, required=True)
range_parser.add_argument("--dim-increment", type=int, required=True)
range_parser.add_argument("--m-constant", type=int, default=None)
range_parser.add_argument("--n-constant", type=int, default=None)
range_parser.add_argument("--k-constant", type=int, default=None)
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument(
"--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys(),
)
model_parser.add_argument(
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
)
model_parser.add_argument(
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
)
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()
args.func(args)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Cutlass bench utils
import torch
import vllm._custom_ops as ops
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
dtype=torch.float8_e4m3fn
)
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.bfloat16)
def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.float16)
def make_rand_tensors(
dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device="cuda") * 5
b = torch.randn((n, k), device="cuda").t() * 5
if dtype == torch.int8:
return to_int8(a), to_int8(b)
if dtype == torch.float8_e4m3fn:
return to_fp8(a), to_fp8(b)
raise ValueError("unsupported dtype")
def prune_to_2_4(tensor):
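# 2:4 structured sparsity: within every contiguous group of 4 values, keep the 2 with
# the largest magnitude and zero out the other 2.
# Example: [0.1, -2.0, 0.3, 1.5] -> [0.0, -2.0, 0.0, 1.5]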
# Reshape tensor to [N, 4] where N is number of groups of 4
original_shape = tensor.shape
reshaped = tensor.reshape(-1, 4)
# Get indices of top 2 absolute values in each group of 4
_, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
# Create binary mask
mask = torch.zeros_like(reshaped)
mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))
# Apply mask and reshape back
pruned = reshaped * mask
# Turn all -0.0 to 0.0
pruned[pruned == -0.0] = 0.0
return pruned.reshape(original_shape)
def make_rand_sparse_tensors(
dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device="cuda") * 5
b = torch.randn((n, k), device="cuda").t() * 5
b = prune_to_2_4(b.t()).t()
if dtype == torch.int8:
a, b = to_int8(a), to_int8(b)
elif dtype == torch.float8_e4m3fn:
a, b = to_fp8(a), to_fp8(b)
elif dtype == torch.float16:
a, b = to_fp16(a), to_fp16(b)
elif dtype == torch.bfloat16:
a, b = to_bf16(a), to_bf16(b)
else:
raise ValueError("unsupported dtype")
b_compressed, e = ops.cutlass_sparse_compress(b.t())
# Compressed B, Metadata, Original A, B
return b_compressed, e, a, b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools
import pickle as pkl
import time
from collections.abc import Callable, Iterable
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from utils import make_rand_tensors
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_triton_block_scaled_mm,
)
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.math_utils import cdiv
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]
# bench
def bench_fn(
label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
) -> TMeasurement:
min_run_time = 1
globals = {
"args": args,
"kwargs": kwargs,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(*args, **kwargs)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
def bench_int8(
dtype: torch.dtype,
m: int,
k: int,
n: int,
label: str,
sub_label: str,
bench_kernels: list[str] | None = None,
) -> Iterable[TMeasurement]:
"""Benchmark INT8-based kernels."""
assert dtype == torch.int8
a, b = make_rand_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
azp = torch.zeros((m,), device="cuda", dtype=torch.int32)
azp_adj = torch.zeros((n,), device="cuda", dtype=torch.int32)
bench_fns = {
"pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
),
"pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
a.to(dtype=torch.float16), b.to(dtype=torch.float16)
),
"cutlass_i8_i8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.bfloat16
),
"cutlass_i8_i8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.bfloat16, bias
),
"cutlass_i8_i8_bf16_scaled_mm_azp": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj
),
"cutlass_i8_i8_bf16_scaled_mm_azp_bias": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias
),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp
),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": lambda: ops.cutlass_scaled_mm_azp(
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias
),
}
timers = []
for name, fn in bench_fns.items():
# If bench_kernels is None, run all. Otherwise, run only exact matches.
if bench_kernels is None or name in bench_kernels:
print(f"Running {name}")
timers.append(bench_fn(label, sub_label, name, fn))
return timers
def bench_fp8(
dtype: torch.dtype,
m: int,
k: int,
n: int,
label: str,
sub_label: str,
bench_kernels: list[str] | None = None,
) -> Iterable[TMeasurement]:
"""Benchmark FP8-based kernels."""
assert dtype == torch.float8_e4m3fn
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
a_cont = a.contiguous()
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
block_scale_a = torch.rand((m, cdiv(k, 128)), device="cuda", dtype=torch.float32)
block_scale_b = torch.rand(
cdiv(k, 128), cdiv(n, 128), device="cuda", dtype=torch.float32
)
block_scale_a_M_major = block_scale_a.t().contiguous().t()
block_scale_b_K_major = block_scale_b.t().contiguous().t()
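# The .t().contiguous().t() round trip keeps the logical shape but makes the storage
# column-major, i.e. M-major for block_scale_a and K-major for block_scale_b. Based on
# the variable names, this is assumed to be the layout the cutlass blockwise kernel
# below expects.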
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
print(m, k, n)
bench_fns = {
"pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
),
"pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
a.to(dtype=torch.float16), b.to(dtype=torch.float16)
),
"pytorch_fp8_fp8_fp16_scaled_mm": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.float16
),
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.float16, use_fast_accum=True
),
"pytorch_fp8_fp8_bf16_scaled_mm": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.bfloat16
),
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True
),
"cutlass_fp8_fp8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.bfloat16
),
"cutlass_fp8_fp8_fp16_scaled_mm": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.float16
),
"cutlass_fp8_fp8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.bfloat16, bias
),
"cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
),
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm(
a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
a, b, block_scale_a_M_major, block_scale_b_K_major, torch.float16
),
}
timers = []
for name, fn in bench_fns.items():
# If bench_kernels is None, run all. Otherwise, run only exact matches.
if bench_kernels is None or name in bench_kernels:
print(f"Running {name}")
timers.append(bench_fn(label, sub_label, name, fn))
return timers
def bench(
dtype: torch.dtype,
m: int,
k: int,
n: int,
label: str,
sub_label: str,
bench_kernels: list[str] | None = None,
) -> Iterable[TMeasurement]:
if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
if dtype == torch.float8_e4m3fn:
return bench_fp8(dtype, m, k, n, label, sub_label, bench_kernels)
raise ValueError("unsupported type")
# runner
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
def run(
dtype: torch.dtype,
MKNs: Iterable[tuple[int, int, int]],
bench_kernels: list[str] | None = None,
) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
timers = bench(
dtype,
m,
k,
n,
f"scaled-{dtype}-gemm",
f"MKN=({m}x{k}x{n})",
bench_kernels=bench_kernels,
)
print_timers(timers)
results.extend(timers)
return results
def make_output(
data: Iterable[TMeasurement],
MKNs: Iterable[tuple[int, int, int]],
base_description: str,
timestamp=None,
):
print(f"== All Results {base_description} ====")
print_timers(data)
# pickle all the results
timestamp = int(time.time()) if timestamp is None else timestamp
with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
pkl.dump(data, f)
def run_square_bench(args):
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
make_output(data, MKNs, f"square_bench-{args.dtype}")
def run_range_bench(args):
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
n = len(dim_sizes)
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
MKNs = list(zip(Ms, Ks, Ns))
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
make_output(data, MKNs, f"range_bench-{args.dtype}")
def run_model_bench(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
KNs = []
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KNs.append(KN)
return KNs
model_bench_data = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
Ms = args.batch_sizes
KNs = model_shapes(model, tp_size)
MKNs = []
for m in Ms:
for k, n in KNs:
MKNs.append((m, k, n))
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
model_bench_data.append(data)
# Print all results
for data, model_tp in zip(model_bench_data, models_tps):
model, tp_size = model_tp
print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
print_timers(data)
timestamp = int(time.time())
all_data = []
for d in model_bench_data:
all_data.extend(d)
# pickle all data
with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
pkl.dump(all_data, f)
if __name__ == "__main__":
def to_torch_dtype(dt):
if dt == "int8":
return torch.int8
if dt == "fp8":
return torch.float8_e4m3fn
raise ValueError("unsupported dtype")
parser = FlexibleArgumentParser(
description="""
Benchmark Cutlass GEMM.
To run square GEMMs:
python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
To run constant N and K and sweep M:
python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
To run dimensions from a model:
python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
Output:
- a .pkl file containing a list of raw torch.utils.benchmark Measurement objects for the pytorch and cutlass implementations of the various GEMMs.
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--dtype",
type=to_torch_dtype,
required=True,
help="Available options are ['int8', 'fp8']",
)
parser.add_argument(
"--kernels",
nargs="+",
type=str,
default=None,
help="Exact names of the kernels to benchmark. If not set, runs all kernels.",
)
subparsers = parser.add_subparsers(dest="cmd")
square_parser = subparsers.add_parser("square_bench")
square_parser.add_argument("--dim-start", type=int, required=True)
square_parser.add_argument("--dim-end", type=int, required=True)
square_parser.add_argument("--dim-increment", type=int, required=True)
square_parser.set_defaults(func=run_square_bench)
range_parser = subparsers.add_parser("range_bench")
range_parser.add_argument("--dim-start", type=int, required=True)
range_parser.add_argument("--dim-end", type=int, required=True)
range_parser.add_argument("--dim-increment", type=int, required=True)
range_parser.add_argument("--m-constant", type=int, default=None)
range_parser.add_argument("--n-constant", type=int, default=None)
range_parser.add_argument("--k-constant", type=int, default=None)
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument(
"--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys(),
)
model_parser.add_argument(
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
)
model_parser.add_argument(
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
)
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()
args.func(args)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536
# TP1 shapes
WEIGHT_SHAPES = {
"mistralai/Mistral-7B-v0.1": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-2-7b-hf": [
([4096, 12288], 1),
([4096, 4096], 0),
([4096, 22016], 1),
([11008, 4096], 0),
],
"meta-llama/Llama-3-8b": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-2-13b-hf": [
([5120, 15360], 1),
([5120, 5120], 0),
([5120, 27648], 1),
([13824, 5120], 0),
],
"meta-llama/Llama-2-70b-hf": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
}
#!/bin/bash
# benchmark the overhead of disaggregated prefill.
# methodology:
# - send all requests to the prefill vLLM instance. It will buffer the KV cache.
# - then send all requests to the decode instance.
# - The TTFT of the decode instance is the overhead.
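# Hypothetical usage (run from the disagg_benchmarks directory; the script name below is
# illustrative):
#   bash disagg_prefill_overhead_benchmark.sh
# Results are written to ./results/disagg_prefill_tp1.json (prefill pass) and
# ./results/disagg_prefill_tp1_overhead.json (decode pass, whose TTFT is the overhead).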
set -ex
kill_gpu_processes() {
# kill all processes on GPU.
pgrep pt_main_thread | xargs -r kill -9
pgrep python3 | xargs -r kill -9
# vLLM now names the process with VLLM prefix after https://github.com/vllm-project/vllm/pull/21445
pgrep VLLM | xargs -r kill -9
sleep 10
# remove vllm config file
rm -rf ~/.config/vllm
# Print the GPU memory usage
# so that we know if all GPU processes are killed.
gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0)
# The memory usage should be 0 MB.
echo "GPU 0 Memory Usage: $gpu_memory_usage MB"
}
wait_for_server() {
# wait for vllm server to start
# return 1 if vllm server crashes
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
benchmark() {
export VLLM_LOGGING_LEVEL=DEBUG
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
# compare chunked prefill with disaggregated prefill
results_folder="./results"
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
dataset_name="sonnet"
dataset_path="../sonnet_4x.txt"
num_prompts=10
qps=$1
prefix_len=50
input_len=2048
output_len=$2
CUDA_VISIBLE_DEVICES=0 vllm serve $model \
--port 8100 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
--kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
CUDA_VISIBLE_DEVICES=1 vllm serve $model \
--port 8200 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
--kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
wait_for_server 8100
wait_for_server 8200
# let the prefill instance finish prefill
vllm bench serve \
--backend vllm \
--model $model \
--dataset-name $dataset_name \
--dataset-path $dataset_path \
--sonnet-input-len $input_len \
--sonnet-output-len "$output_len" \
--sonnet-prefix-len $prefix_len \
--num-prompts $num_prompts \
--port 8100 \
--save-result \
--result-dir $results_folder \
--result-filename disagg_prefill_tp1.json \
--request-rate "inf"
# send the request to decode.
# The TTFT of this command will be the overhead of disagg prefill impl.
vllm bench serve \
--backend vllm \
--model $model \
--dataset-name $dataset_name \
--dataset-path $dataset_path \
--sonnet-input-len $input_len \
--sonnet-output-len "$output_len" \
--sonnet-prefix-len $prefix_len \
--num-prompts $num_prompts \
--port 8200 \
--save-result \
--result-dir $results_folder \
--result-filename disagg_prefill_tp1_overhead.json \
--request-rate "$qps"
kill_gpu_processes
}
main() {
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
(which jq) || (apt-get -y install jq)
(which socat) || (apt-get -y install socat)
pip install quart httpx datasets
cd "$(dirname "$0")"
cd ..
# create sonnet_4x.txt
echo "" > sonnet_4x.txt
for _ in {1..4}
do
cat sonnet.txt >> sonnet_4x.txt
done
cd disagg_benchmarks
rm -rf results
mkdir results
default_qps=1
default_output_len=1
benchmark $default_qps $default_output_len
}
main "$@"
#!/bin/bash
# Requirement: 2x GPUs.
# Model: meta-llama/Meta-Llama-3.1-8B-Instruct
# Query: 1024 input tokens, 6 output tokens, QPS 2/4/6/8, 100 requests
# Resource: 2x GPU
# Approaches:
# 1. Chunked prefill: 2 vllm instances with tp=4, equivalent to 1 tp=4 instance with QPS 4
# 2. Disaggregated prefill: 1 prefilling instance and 1 decoding instance
#    Prefilling instance: max_output_token=1
#    Decoding instance: force the input tokens to be the same across requests to bypass prefilling
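# In both setups the benchmark client below targets a local proxy on port 8000
# (round_robin_proxy.py for chunked prefill, disagg_prefill_proxy_server.py for
# disaggregated prefill), which forwards to the vLLM instances on ports 8100/8200.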
set -ex
kill_gpu_processes() {
# kill all processes on GPU.
pgrep pt_main_thread | xargs -r kill -9
pgrep python3 | xargs -r kill -9
# vLLM now names the process with VLLM prefix after https://github.com/vllm-project/vllm/pull/21445
pgrep VLLM | xargs -r kill -9
for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done
sleep 1
}
wait_for_server() {
# wait for vllm server to start
# return 1 if vllm server crashes
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
launch_chunked_prefill() {
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
# disagg prefill
CUDA_VISIBLE_DEVICES=0 vllm serve $model \
--port 8100 \
--max-model-len 10000 \
--enable-chunked-prefill \
--gpu-memory-utilization 0.6 &
CUDA_VISIBLE_DEVICES=1 vllm serve $model \
--port 8200 \
--max-model-len 10000 \
--enable-chunked-prefill \
--gpu-memory-utilization 0.6 &
wait_for_server 8100
wait_for_server 8200
python3 round_robin_proxy.py &
sleep 1
}
launch_disagg_prefill() {
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
# disagg prefill
CUDA_VISIBLE_DEVICES=0 vllm serve $model \
--port 8100 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
--kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
CUDA_VISIBLE_DEVICES=1 vllm serve $model \
--port 8200 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
--kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
wait_for_server 8100
wait_for_server 8200
python3 disagg_prefill_proxy_server.py &
sleep 1
}
benchmark() {
results_folder="./results"
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
dataset_name="sonnet"
dataset_path="../sonnet_4x.txt"
num_prompts=100
qps=$1
prefix_len=50
input_len=1024
output_len=$2
tag=$3
vllm bench serve \
--backend vllm \
--model $model \
--dataset-name $dataset_name \
--dataset-path $dataset_path \
--sonnet-input-len $input_len \
--sonnet-output-len "$output_len" \
--sonnet-prefix-len $prefix_len \
--num-prompts $num_prompts \
--port 8000 \
--save-result \
--result-dir $results_folder \
--result-filename "$tag"-qps-"$qps".json \
--request-rate "$qps"
sleep 2
}
main() {
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
(which jq) || (apt-get -y install jq)
(which socat) || (apt-get -y install socat)
(which lsof) || (apt-get -y install lsof)
pip install quart httpx matplotlib aiohttp datasets
cd "$(dirname "$0")"
cd ..
# create sonnet_4x.txt so that we can sample 2048 tokens for input
echo "" > sonnet_4x.txt
for _ in {1..4}
do
cat sonnet.txt >> sonnet_4x.txt
done
cd disagg_benchmarks
rm -rf results
mkdir results
default_output_len=6
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
launch_chunked_prefill
for qps in 2 4 6 8; do
benchmark $qps $default_output_len chunked_prefill
done
kill_gpu_processes
launch_disagg_prefill
for qps in 2 4 6 8; do
benchmark $qps $default_output_len disagg_prefill
done
kill_gpu_processes
python3 visualize_benchmark_results.py
}
main "$@"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import asyncio
import logging
import os
import time
import uuid
from urllib.parse import urlparse
import aiohttp
from quart import Quart, Response, make_response, request
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_args():
"""parse command line arguments"""
parser = argparse.ArgumentParser(description="vLLM P/D disaggregation proxy server")
# Add args
parser.add_argument(
"--timeout",
type=float,
default=6 * 60 * 60,
help="Timeout for backend service requests in seconds (default: 21600)",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="Port to run the server on (default: 8000)",
)
parser.add_argument(
"--prefill-url",
type=str,
default="http://localhost:8100",
help="Prefill service base URL (protocol + host[:port])",
)
parser.add_argument(
"--decode-url",
type=str,
default="http://localhost:8200",
help="Decode service base URL (protocol + host[:port])",
)
parser.add_argument(
"--kv-host",
type=str,
default="localhost",
help="Hostname or IP used by KV transfer (default: localhost)",
)
parser.add_argument(
"--prefill-kv-port",
type=int,
default=14579,
help="Prefill KV port (default: 14579)",
)
parser.add_argument(
"--decode-kv-port",
type=int,
default=14580,
help="Decode KV port (default: 14580)",
)
return parser.parse_args()
def main():
"""parse command line arguments"""
args = parse_args()
# Initialize configuration using command line parameters
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
PREFILL_SERVICE_URL = args.prefill_url
DECODE_SERVICE_URL = args.decode_url
PORT = args.port
PREFILL_KV_ADDR = f"{args.kv_host}:{args.prefill_kv_port}"
DECODE_KV_ADDR = f"{args.kv_host}:{args.decode_kv_port}"
logger.info(
"Proxy resolved KV addresses -> prefill: %s, decode: %s",
PREFILL_KV_ADDR,
DECODE_KV_ADDR,
)
app = Quart(__name__)
# Attach the configuration object to the application instance so helper
# coroutines can read the resolved backend URLs and timeouts without using
# globals.
app.config.update(
{
"AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
"PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
"DECODE_SERVICE_URL": DECODE_SERVICE_URL,
"PREFILL_KV_ADDR": PREFILL_KV_ADDR,
"DECODE_KV_ADDR": DECODE_KV_ADDR,
}
)
def _normalize_base_url(url: str) -> str:
"""Remove any trailing slash so path joins behave predictably."""
return url.rstrip("/")
def _get_host_port(url: str) -> str:
"""Return the hostname:port portion for logging and KV headers."""
parsed = urlparse(url)
host = parsed.hostname or "localhost"
port = parsed.port
if port is None:
port = 80 if parsed.scheme == "http" else 443
return f"{host}:{port}"
PREFILL_BASE = _normalize_base_url(PREFILL_SERVICE_URL)
DECODE_BASE = _normalize_base_url(DECODE_SERVICE_URL)
KV_TARGET = _get_host_port(DECODE_SERVICE_URL)
def _build_headers(request_id: str) -> dict[str, str]:
"""Construct the headers expected by vLLM's P2P disagg connector."""
headers: dict[str, str] = {"X-Request-Id": request_id, "X-KV-Target": KV_TARGET}
api_key = os.environ.get("OPENAI_API_KEY")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
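# With the default CLI arguments this yields, for example (the Authorization header
# is only added when OPENAI_API_KEY is set):
#   {"X-Request-Id": "<request id>", "X-KV-Target": "localhost:8200"}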
async def _run_prefill(
request_path: str,
payload: dict,
headers: dict[str, str],
request_id: str,
):
url = f"{PREFILL_BASE}{request_path}"
start_ts = time.perf_counter()
logger.info("[prefill] start request_id=%s url=%s", request_id, url)
try:
async with (
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
session.post(url=url, json=payload, headers=headers) as resp,
):
if resp.status != 200:
error_text = await resp.text()
raise RuntimeError(
f"Prefill backend error {resp.status}: {error_text}"
)
await resp.read()
logger.info(
"[prefill] done request_id=%s status=%s elapsed=%.2fs",
request_id,
resp.status,
time.perf_counter() - start_ts,
)
except asyncio.TimeoutError as exc:
raise RuntimeError(f"Prefill service timeout at {url}") from exc
except aiohttp.ClientError as exc:
raise RuntimeError(f"Prefill service unavailable at {url}") from exc
async def _stream_decode(
request_path: str,
payload: dict,
headers: dict[str, str],
request_id: str,
):
url = f"{DECODE_BASE}{request_path}"
# Stream tokens from the decode service once the prefill stage has
# materialized KV caches on the target workers.
logger.info("[decode] start request_id=%s url=%s", request_id, url)
try:
async with (
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
session.post(url=url, json=payload, headers=headers) as resp,
):
if resp.status != 200:
error_text = await resp.text()
logger.error(
"Decode backend error %s - %s", resp.status, error_text
)
err_msg = (
'{"error": "Decode backend error ' + str(resp.status) + '"}'
)
yield err_msg.encode()
return
logger.info(
"[decode] streaming response request_id=%s status=%s",
request_id,
resp.status,
)
async for chunk_bytes in resp.content.iter_chunked(1024):
yield chunk_bytes
logger.info("[decode] finished streaming request_id=%s", request_id)
except asyncio.TimeoutError:
logger.error("Decode service timeout at %s", url)
yield b'{"error": "Decode service timeout"}'
except aiohttp.ClientError as exc:
logger.error("Decode service error at %s: %s", url, exc)
yield b'{"error": "Decode service unavailable"}'
async def process_request():
"""Process a single request through prefill and decode stages"""
try:
original_request_data = await request.get_json()
# Create prefill request (max_tokens=1)
prefill_request = original_request_data.copy()
prefill_request["max_tokens"] = 1
if "max_completion_tokens" in prefill_request:
prefill_request["max_completion_tokens"] = 1
# Execute prefill stage
# The request id encodes both KV socket addresses so the backend can
# shuttle tensors directly via NCCL once the prefill response
# completes.
request_id = (
f"___prefill_addr_{PREFILL_KV_ADDR}___decode_addr_"
f"{DECODE_KV_ADDR}_{uuid.uuid4().hex}"
)
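# With the default CLI arguments the generated id looks like:
#   ___prefill_addr_localhost:14579___decode_addr_localhost:14580_<uuid4 hex>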
headers = _build_headers(request_id)
await _run_prefill(request.path, prefill_request, headers, request_id)
# Execute decode stage and stream response
# Pass the unmodified user request so the decode phase can continue
# sampling with the already-populated KV cache.
generator = _stream_decode(
request.path, original_request_data, headers, request_id
)
response = await make_response(generator)
response.timeout = None # Disable timeout for streaming response
return response
except Exception:
logger.exception("Error processing request")
return Response(
response=b'{"error": "Internal server error"}',
status=500,
content_type="application/json",
)
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
"""Handle incoming API requests with concurrency and rate limiting"""
try:
return await process_request()
except asyncio.CancelledError:
logger.warning("Request cancelled")
return Response(
response=b'{"error": "Request cancelled"}',
status=503,
content_type="application/json",
)
# Start the Quart server (pass host="0.0.0.0" to listen on all interfaces)
app.run(port=PORT)
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import itertools
import aiohttp
from aiohttp import web
class RoundRobinProxy:
def __init__(self, target_ports):
self.target_ports = target_ports
self.port_cycle = itertools.cycle(self.target_ports)
async def handle_request(self, request):
target_port = next(self.port_cycle)
target_url = f"http://localhost:{target_port}{request.path_qs}"
async with aiohttp.ClientSession() as session:
try:
# Forward the request
async with session.request(
method=request.method,
url=target_url,
headers=request.headers,
data=request.content,
) as response:
# Start sending the response
resp = web.StreamResponse(
status=response.status, headers=response.headers
)
await resp.prepare(request)
# Stream the response content
async for chunk in response.content.iter_any():
await resp.write(chunk)
await resp.write_eof()
return resp
except Exception as e:
return web.Response(text=f"Error: {str(e)}", status=500)
async def main():
proxy = RoundRobinProxy([8100, 8200])
app = web.Application()
app.router.add_route("*", "/{path:.*}", proxy.handle_request)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "localhost", 8000)
await site.start()
print("Proxy server started on http://localhost:8000")
# Keep the server running
await asyncio.Event().wait()
if __name__ == "__main__":
asyncio.run(main())
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import matplotlib.pyplot as plt
import pandas as pd
if __name__ == "__main__":
data = []
for name in ["disagg_prefill", "chunked_prefill"]:
for qps in [2, 4, 6, 8]:
with open(f"results/{name}-qps-{qps}.json") as f:
x = json.load(f)
x["name"] = name
x["qps"] = qps
data.append(x)
df = pd.DataFrame.from_dict(data)
dis_df = df[df["name"] == "disagg_prefill"]
chu_df = df[df["name"] == "chunked_prefill"]
plt.style.use("bmh")
plt.rcParams["font.size"] = 20
for key in [
"mean_ttft_ms",
"median_ttft_ms",
"p99_ttft_ms",
"mean_itl_ms",
"median_itl_ms",
"p99_itl_ms",
]:
fig, ax = plt.subplots(figsize=(11, 7))
plt.plot(
dis_df["qps"], dis_df[key], label="disagg_prefill", marker="o", linewidth=4
)
plt.plot(
chu_df["qps"], chu_df[key], label="chunked_prefill", marker="o", linewidth=4
)
ax.legend()
ax.set_xlabel("QPS")
ax.set_ylabel(key)
ax.set_ylim(bottom=0)
fig.savefig(f"results/{key}.png")
plt.close(fig)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pickle as pkl
import time
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from itertools import product
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm
import vllm._custom_ops as ops
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
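# One benchmark configuration per combination of token count, hidden size,
# residual usage, dtype, and quantization group size.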
@dataclass
class bench_params_t:
num_tokens: int
hidden_size: int
add_residual: bool
dtype: torch.dtype
group_size: list[int]
def description(self):
return (
f"N {self.num_tokens} "
f"x D {self.hidden_size} "
f"x R {self.add_residual} "
f"x DT {self.dtype}"
f"x GS {self.group_size}"
)
def get_bench_params() -> list[bench_params_t]:
## Test Fixtures
NUM_TOKENS = [2**x for x in range(11)]
HIDDEN_SIZES = list(range(1024, 8129, 1024))
ADD_RESIDUAL = [True, False]
DTYPES = [torch.bfloat16, torch.float]
GROUP_SIZES = [[1, 64], [1, 128]]
combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES, GROUP_SIZES)
bench_params = list(
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3], x[4]), combinations)
)
return bench_params
# Reference impls
def unfused_int8_impl(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
):
# Norm
torch_out = None
if residual is None:
torch_out = rms_norm_layer.forward_cuda(x, residual)
else:
torch_out, _ = rms_norm_layer.forward_cuda(x, residual)
# Quant
torch_out, _, _ = ops.scaled_int8_quant(torch_out)
def unfused_fp8_impl(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
):
# Norm
torch_out = None
if residual is None:
torch_out = rms_norm_layer.forward_cuda(x, residual)
else:
torch_out, _ = rms_norm_layer.forward_cuda(x, residual)
# Quant
torch_out, _ = ops.scaled_fp8_quant(torch_out)
def unfused_groupwise_fp8_impl(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
):
# Norm
torch_out = None
if residual is None:
torch_out = rms_norm_layer.forward_cuda(x, residual)
else:
torch_out, _ = rms_norm_layer.forward_cuda(x, residual)
# Quant
torch_out, _ = per_token_group_quant_fp8(
torch_out, group_size=group_size[1], use_ue8m0=False
)
def fused_impl(
rms_norm_layer: RMSNorm, # this stores the weights
x: torch.Tensor,
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
):
out, _ = ops.rms_norm_dynamic_per_token_quant(
x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual
)
def fused_groupwise_impl(
rms_norm_layer: RMSNorm, # this stores the weights
x: torch.Tensor,
residual: torch.Tensor | None,
quant_dtype: torch.dtype,
group_size: list[int],
):
out, _ = ops.rms_norm_per_block_quant(
x,
rms_norm_layer.weight,
1e-6,
quant_dtype,
group_size,
residual=residual,
is_scale_transposed=True,
)
# Bench functions
def bench_fn(
rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: torch.Tensor,
quant_dtype: torch.dtype,
group_size: list[int],
label: str,
sub_label: str,
fn: Callable,
description: str,
) -> TMeasurement:
min_run_time = 1
globals = {
"rms_norm_layer": rms_norm_layer,
"x": x,
"residual": residual,
"quant_dtype": quant_dtype,
"group_size": group_size,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(rms_norm_layer, x, residual, quant_dtype, group_size)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
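# Benchmark every implementation for a single parameter combination and
# return the collected measurements.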
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
# Make inputs
layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
# Make weights
layer.weight.data.normal_(mean=1.0, std=0.1)
# Make inputs
scale = 1 / params.hidden_size
x = (
torch.randn(
params.num_tokens, params.hidden_size, dtype=params.dtype, device="cuda"
)
* scale
)
residual = (
(torch.randn_like(x) * scale).to(device="cuda") if params.add_residual else None
)
timers = []
# unfused int8 impl.
timers.append(
bench_fn(
layer,
x,
residual,
torch.int8,
params.group_size,
label,
sub_label,
unfused_int8_impl,
"unfused_int8_impl",
)
)
# unfused fp8 impl.
timers.append(
bench_fn(
layer,
x,
residual,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_fp8_impl,
"unfused_fp8_impl",
)
)
# fused int8 impl.
timers.append(
bench_fn(
layer,
x,
residual,
torch.int8,
params.group_size,
label,
sub_label,
fused_impl,
"fused_int8_impl",
)
)
# fused fp8 impl.
timers.append(
bench_fn(
layer,
x,
residual,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
fused_impl,
"fused_fp8_impl",
)
)
# unfused groupwise fp8 impl.
timers.append(
bench_fn(
layer,
x,
residual,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
unfused_groupwise_fp8_impl,
"unfused_groupwise_fp8_impl",
)
)
# fused groupwise fp8 impl.
timers.append(
bench_fn(
layer,
x,
residual,
torch.float8_e4m3fn,
params.group_size,
label,
sub_label,
fused_groupwise_impl,
"fused_groupwise_fp8_impl",
)
)
print_timers(timers)
return timers
# Reporting and runner
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
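# Entry point: sweep all parameter combinations, print a comparison table,
# and pickle the raw measurements for later analysis.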
@default_vllm_config()
def main():
torch.set_default_device("cuda")
bench_params = get_bench_params()
timers = []
for bp in tqdm(bench_params):
timers.extend(bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
print_timers(timers)
# pickle all the results
timestamp = int(time.time())
with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f:
pkl.dump(timers, f)
if __name__ == "__main__":
main()