Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
41199996
Commit
41199996
authored
Dec 13, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.12.0' into v0.12.0-dev
parents
31021d81
4fd9d6a8
Changes
380
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1041 additions
and
210 deletions
+1041
-210
benchmarks/auto_tune/README.md
benchmarks/auto_tune/README.md
+2
-2
benchmarks/auto_tune/auto_tune.sh
benchmarks/auto_tune/auto_tune.sh
+6
-6
benchmarks/backend_request_func.py
benchmarks/backend_request_func.py
+14
-15
benchmarks/benchmark_batch_invariance.py
benchmarks/benchmark_batch_invariance.py
+380
-0
benchmarks/benchmark_block_pool.py
benchmarks/benchmark_block_pool.py
+2
-2
benchmarks/benchmark_long_document_qa_throughput.py
benchmarks/benchmark_long_document_qa_throughput.py
+1
-1
benchmarks/benchmark_ngram_proposer.py
benchmarks/benchmark_ngram_proposer.py
+7
-4
benchmarks/benchmark_prefix_caching.py
benchmarks/benchmark_prefix_caching.py
+5
-8
benchmarks/benchmark_prioritization.py
benchmarks/benchmark_prioritization.py
+2
-3
benchmarks/benchmark_serving_structured_output.py
benchmarks/benchmark_serving_structured_output.py
+11
-18
benchmarks/benchmark_utils.py
benchmarks/benchmark_utils.py
+8
-8
benchmarks/cutlass_benchmarks/sparse_benchmarks.py
benchmarks/cutlass_benchmarks/sparse_benchmarks.py
+2
-3
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
+9
-9
benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
+2
-6
benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
+4
-12
benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
+155
-94
benchmarks/fused_kernels/layernorm_rms_benchmarks.py
benchmarks/fused_kernels/layernorm_rms_benchmarks.py
+4
-5
benchmarks/kernels/bench_block_fp8_gemm.py
benchmarks/kernels/bench_block_fp8_gemm.py
+29
-14
benchmarks/kernels/bench_mxfp4_qutlass.py
benchmarks/kernels/bench_mxfp4_qutlass.py
+191
-0
benchmarks/kernels/bench_nvfp4_qutlass.py
benchmarks/kernels/bench_nvfp4_qutlass.py
+207
-0
No files found.
Too many changes to show.
To preserve performance only
380 of 380+
files are displayed.
Plain diff
Email patch
benchmarks/auto_tune/README.md
View file @
41199996
...
@@ -83,7 +83,7 @@ MIN_CACHE_HIT_PCT=0
...
@@ -83,7 +83,7 @@ MIN_CACHE_HIT_PCT=0
MAX_LATENCY_ALLOWED_MS
=
100000000000
# A very large number
MAX_LATENCY_ALLOWED_MS
=
100000000000
# A very large number
```
```
###
#
2. Maximize Throughput with a Latency Requirement
### 2. Maximize Throughput with a Latency Requirement
-
**Goal**
: Find the best server parameters when P99 end-to-end latency must be below 500ms.
-
**Goal**
: Find the best server parameters when P99 end-to-end latency must be below 500ms.
-
**Configuration**
:
-
**Configuration**
:
...
@@ -96,7 +96,7 @@ MIN_CACHE_HIT_PCT=0
...
@@ -96,7 +96,7 @@ MIN_CACHE_HIT_PCT=0
MAX_LATENCY_ALLOWED_MS
=
500
MAX_LATENCY_ALLOWED_MS
=
500
```
```
###
#
3. Maximize Throughput with Prefix Caching and Latency Requirements
### 3. Maximize Throughput with Prefix Caching and Latency Requirements
-
**Goal**
: Find the best server parameters assuming a 60% prefix cache hit rate and a latency requirement of 500ms.
-
**Goal**
: Find the best server parameters assuming a 60% prefix cache hit rate and a latency requirement of 500ms.
-
**Configuration**
:
-
**Configuration**
:
...
...
benchmarks/auto_tune/auto_tune.sh
View file @
41199996
...
@@ -74,7 +74,7 @@ start_server() {
...
@@ -74,7 +74,7 @@ start_server() {
local
vllm_log
=
$4
local
vllm_log
=
$4
local
profile_dir
=
$5
local
profile_dir
=
$5
pkill
-if
vllm
pkill
-if
"
vllm
serve"
||
true
# Define the common arguments as a bash array.
# Define the common arguments as a bash array.
# Each argument and its value are separate elements.
# Each argument and its value are separate elements.
...
@@ -96,11 +96,11 @@ start_server() {
...
@@ -96,11 +96,11 @@ start_server() {
# This correctly passes each element as a separate argument.
# This correctly passes each element as a separate argument.
if
[[
-n
"
$profile_dir
"
]]
;
then
if
[[
-n
"
$profile_dir
"
]]
;
then
# Start server with profiling enabled
# Start server with profiling enabled
VLLM_USE_V1
=
1
VLLM_SERVER_DEV_MODE
=
1
VLLM_TORCH_PROFILER_DIR
=
$profile_dir
\
VLLM_SERVER_DEV_MODE
=
1
VLLM_TORCH_PROFILER_DIR
=
$profile_dir
\
vllm serve
"
${
common_args_array
[@]
}
"
>
"
$vllm_log
"
2>&1 &
vllm serve
"
${
common_args_array
[@]
}
"
>
"
$vllm_log
"
2>&1 &
else
else
# Start server without profiling
# Start server without profiling
VLLM_USE_V1
=
1
VLLM_SERVER_DEV_MODE
=
1
\
VLLM_SERVER_DEV_MODE
=
1
\
vllm serve
"
${
common_args_array
[@]
}
"
>
"
$vllm_log
"
2>&1 &
vllm serve
"
${
common_args_array
[@]
}
"
>
"
$vllm_log
"
2>&1 &
fi
fi
local
server_pid
=
$!
local
server_pid
=
$!
...
@@ -139,7 +139,7 @@ run_benchmark() {
...
@@ -139,7 +139,7 @@ run_benchmark() {
echo
"vllm_log:
$vllm_log
"
echo
"vllm_log:
$vllm_log
"
echo
echo
rm
-f
$vllm_log
rm
-f
$vllm_log
pkill
-if
vllm
pkill
-if
"
vllm
serve"
||
true
echo
"starting server..."
echo
"starting server..."
# Call start_server without a profile_dir to avoid profiling overhead
# Call start_server without a profile_dir to avoid profiling overhead
...
@@ -232,7 +232,7 @@ run_benchmark() {
...
@@ -232,7 +232,7 @@ run_benchmark() {
echo
"best_max_num_seqs:
$best_max_num_seqs
, best_num_batched_tokens:
$best_num_batched_tokens
, best_throughput:
$best_throughput
"
echo
"best_max_num_seqs:
$best_max_num_seqs
, best_num_batched_tokens:
$best_num_batched_tokens
, best_throughput:
$best_throughput
"
pkill
-if
vllm
pkill
-if
"
vllm
serve"
||
true
sleep
10
sleep
10
echo
"===================="
echo
"===================="
return
0
return
0
...
@@ -308,6 +308,6 @@ if (( $(echo "$best_throughput > 0" | bc -l) )); then
...
@@ -308,6 +308,6 @@ if (( $(echo "$best_throughput > 0" | bc -l) )); then
else
else
echo
"No configuration met the latency requirements. Skipping final profiling run."
echo
"No configuration met the latency requirements. Skipping final profiling run."
fi
fi
pkill
-if
vllm
pkill
-if
"
vllm
serve"
||
true
echo
"best_max_num_seqs:
$best_max_num_seqs
, best_num_batched_tokens:
$best_num_batched_tokens
, best_throughput:
$best_throughput
, profile saved in:
$PROFILE_PATH
"
echo
"best_max_num_seqs:
$best_max_num_seqs
, best_num_batched_tokens:
$best_num_batched_tokens
, best_throughput:
$best_throughput
, profile saved in:
$PROFILE_PATH
"
echo
"best_max_num_seqs:
$best_max_num_seqs
, best_num_batched_tokens:
$best_num_batched_tokens
, best_throughput:
$best_throughput
, profile saved in:
$PROFILE_PATH
"
>>
"
$RESULT
"
echo
"best_max_num_seqs:
$best_max_num_seqs
, best_num_batched_tokens:
$best_num_batched_tokens
, best_throughput:
$best_throughput
, profile saved in:
$PROFILE_PATH
"
>>
"
$RESULT
"
benchmarks/backend_request_func.py
View file @
41199996
...
@@ -8,7 +8,6 @@ import sys
...
@@ -8,7 +8,6 @@ import sys
import
time
import
time
import
traceback
import
traceback
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
,
Union
import
aiohttp
import
aiohttp
import
huggingface_hub.constants
import
huggingface_hub.constants
...
@@ -28,13 +27,13 @@ class RequestFuncInput:
...
@@ -28,13 +27,13 @@ class RequestFuncInput:
prompt_len
:
int
prompt_len
:
int
output_len
:
int
output_len
:
int
model
:
str
model
:
str
model_name
:
Optional
[
str
]
=
None
model_name
:
str
|
None
=
None
logprobs
:
Optional
[
int
]
=
None
logprobs
:
int
|
None
=
None
extra_body
:
Optional
[
dict
]
=
None
extra_body
:
dict
|
None
=
None
multi_modal_content
:
Optional
[
dict
|
list
[
dict
]
]
=
None
multi_modal_content
:
dict
|
list
[
dict
]
|
None
=
None
ignore_eos
:
bool
=
False
ignore_eos
:
bool
=
False
language
:
Optional
[
str
]
=
None
language
:
str
|
None
=
None
request_id
:
Optional
[
str
]
=
None
request_id
:
str
|
None
=
None
@
dataclass
@
dataclass
...
@@ -52,7 +51,7 @@ class RequestFuncOutput:
...
@@ -52,7 +51,7 @@ class RequestFuncOutput:
async
def
async_request_tgi
(
async
def
async_request_tgi
(
request_func_input
:
RequestFuncInput
,
request_func_input
:
RequestFuncInput
,
pbar
:
Optional
[
tqdm
]
=
None
,
pbar
:
tqdm
|
None
=
None
,
)
->
RequestFuncOutput
:
)
->
RequestFuncOutput
:
api_url
=
request_func_input
.
api_url
api_url
=
request_func_input
.
api_url
assert
api_url
.
endswith
(
"generate_stream"
)
assert
api_url
.
endswith
(
"generate_stream"
)
...
@@ -133,7 +132,7 @@ async def async_request_tgi(
...
@@ -133,7 +132,7 @@ async def async_request_tgi(
async
def
async_request_trt_llm
(
async
def
async_request_trt_llm
(
request_func_input
:
RequestFuncInput
,
request_func_input
:
RequestFuncInput
,
pbar
:
Optional
[
tqdm
]
=
None
,
pbar
:
tqdm
|
None
=
None
,
)
->
RequestFuncOutput
:
)
->
RequestFuncOutput
:
api_url
=
request_func_input
.
api_url
api_url
=
request_func_input
.
api_url
assert
api_url
.
endswith
(
"generate_stream"
)
assert
api_url
.
endswith
(
"generate_stream"
)
...
@@ -204,7 +203,7 @@ async def async_request_trt_llm(
...
@@ -204,7 +203,7 @@ async def async_request_trt_llm(
async
def
async_request_deepspeed_mii
(
async
def
async_request_deepspeed_mii
(
request_func_input
:
RequestFuncInput
,
request_func_input
:
RequestFuncInput
,
pbar
:
Optional
[
tqdm
]
=
None
,
pbar
:
tqdm
|
None
=
None
,
)
->
RequestFuncOutput
:
)
->
RequestFuncOutput
:
api_url
=
request_func_input
.
api_url
api_url
=
request_func_input
.
api_url
assert
api_url
.
endswith
((
"completions"
,
"profile"
)),
(
assert
api_url
.
endswith
((
"completions"
,
"profile"
)),
(
...
@@ -267,7 +266,7 @@ async def async_request_deepspeed_mii(
...
@@ -267,7 +266,7 @@ async def async_request_deepspeed_mii(
async
def
async_request_openai_completions
(
async
def
async_request_openai_completions
(
request_func_input
:
RequestFuncInput
,
request_func_input
:
RequestFuncInput
,
pbar
:
Optional
[
tqdm
]
=
None
,
pbar
:
tqdm
|
None
=
None
,
)
->
RequestFuncOutput
:
)
->
RequestFuncOutput
:
api_url
=
request_func_input
.
api_url
api_url
=
request_func_input
.
api_url
assert
api_url
.
endswith
((
"completions"
,
"profile"
)),
(
assert
api_url
.
endswith
((
"completions"
,
"profile"
)),
(
...
@@ -367,7 +366,7 @@ async def async_request_openai_completions(
...
@@ -367,7 +366,7 @@ async def async_request_openai_completions(
async
def
async_request_openai_chat_completions
(
async
def
async_request_openai_chat_completions
(
request_func_input
:
RequestFuncInput
,
request_func_input
:
RequestFuncInput
,
pbar
:
Optional
[
tqdm
]
=
None
,
pbar
:
tqdm
|
None
=
None
,
)
->
RequestFuncOutput
:
)
->
RequestFuncOutput
:
api_url
=
request_func_input
.
api_url
api_url
=
request_func_input
.
api_url
assert
api_url
.
endswith
((
"chat/completions"
,
"profile"
)),
(
assert
api_url
.
endswith
((
"chat/completions"
,
"profile"
)),
(
...
@@ -476,7 +475,7 @@ async def async_request_openai_chat_completions(
...
@@ -476,7 +475,7 @@ async def async_request_openai_chat_completions(
async
def
async_request_openai_audio
(
async
def
async_request_openai_audio
(
request_func_input
:
RequestFuncInput
,
request_func_input
:
RequestFuncInput
,
pbar
:
Optional
[
tqdm
]
=
None
,
pbar
:
tqdm
|
None
=
None
,
)
->
RequestFuncOutput
:
)
->
RequestFuncOutput
:
# Lazy import without PlaceholderModule to avoid vllm dep.
# Lazy import without PlaceholderModule to avoid vllm dep.
import
soundfile
import
soundfile
...
@@ -610,7 +609,7 @@ def get_tokenizer(
...
@@ -610,7 +609,7 @@ def get_tokenizer(
tokenizer_mode
:
str
=
"auto"
,
tokenizer_mode
:
str
=
"auto"
,
trust_remote_code
:
bool
=
False
,
trust_remote_code
:
bool
=
False
,
**
kwargs
,
**
kwargs
,
)
->
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
:
)
->
PreTrainedTokenizer
|
PreTrainedTokenizerFast
:
if
pretrained_model_name_or_path
is
not
None
and
not
os
.
path
.
exists
(
if
pretrained_model_name_or_path
is
not
None
and
not
os
.
path
.
exists
(
pretrained_model_name_or_path
pretrained_model_name_or_path
):
):
...
@@ -621,7 +620,7 @@ def get_tokenizer(
...
@@ -621,7 +620,7 @@ def get_tokenizer(
kwargs
[
"use_fast"
]
=
False
kwargs
[
"use_fast"
]
=
False
if
tokenizer_mode
==
"mistral"
:
if
tokenizer_mode
==
"mistral"
:
try
:
try
:
from
vllm.
transformers_utils.
tokenizer
import
MistralTokenizer
from
vllm.tokenizer
s
import
MistralTokenizer
except
ImportError
as
e
:
except
ImportError
as
e
:
raise
ImportError
(
raise
ImportError
(
"MistralTokenizer requires vllm package.
\n
"
"MistralTokenizer requires vllm package.
\n
"
...
...
benchmarks/benchmark_batch_invariance.py
0 → 100755
View file @
41199996
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark to measure the performance overhead of VLLM_BATCH_INVARIANT mode.
This benchmark runs the same workload twice:
1. With VLLM_BATCH_INVARIANT=0 (baseline)
2. With VLLM_BATCH_INVARIANT=1 (batch invariant mode)
And reports the timing and throughput metrics for comparison.
Environment variables:
VLLM_BENCH_MODEL: Model to benchmark (default: "Qwen/Qwen3-1.7B")
VLLM_BENCH_TP_SIZE: Tensor parallel size (default: 1, use 8 for deepseek)
VLLM_BENCH_BATCH_SIZE: Max batch size (default: 128)
VLLM_BENCH_NUM_TRIALS: Number of trials to run (default: 5)
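VLLM_BENCH_MIN_PROMPT: Min nominal prompt size in words (default: 1024);
actual prompts are shorter, because the filler adds about 9 words per 50 nominal words
VLLM_BENCH_MAX_PROMPT: Max nominal prompt size in words (default: 2048);
same scaling as VLLM_BENCH_MIN_PROMPT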
VLLM_BENCH_MAX_TOKENS: Max tokens to generate (default: 128)
VLLM_BENCH_TEMPERATURE: Temperature for sampling (default: 0.0)
VLLM_BENCH_GPU_MEMORY_UTILIZATION: GPU memory utilization (default: 0.4)
VLLM_BENCH_MAX_MODEL_LEN: Max model length (default: 5120)
VLLM_BENCH_BACKEND: Attention backend (default: FLASH_ATTN)
Example usage:
# Benchmark qwen3 (default)
python benchmarks/benchmark_batch_invariance.py
# Benchmark deepseek with 8 GPUs
VLLM_BENCH_MODEL="deepseek-ai/DeepSeek-V3" VLLM_BENCH_TP_SIZE=8
\\
python benchmarks/benchmark_batch_invariance.py
# Quick test with fewer trials
VLLM_BENCH_NUM_TRIALS=2 VLLM_BENCH_BATCH_SIZE=32
\\
python benchmarks/benchmark_batch_invariance.py
"""
import
contextlib
import
os
import
random
import
time
from
vllm
import
LLM
,
SamplingParams
from
vllm.platforms
import
current_platform
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
    """Build one pseudo-random prompt for the benchmark workload.

    An opening sentence is drawn from a fixed pool, then a nominal size is
    drawn uniformly from ``[min_words, max_words]`` (an inverted range is
    repaired by raising the upper bound to ``min_words``). For nominal sizes
    above 50 a filler sentence is appended ``size // 50`` times.

    The filler sentence is nine words long, so the returned text is far
    shorter than the nominal size; the bounds control the relative length of
    the prompt, not its exact word count.

    Args:
        min_words: Lower bound of the nominal size.
        max_words: Upper bound of the nominal size.

    Returns:
        The opening sentence, optionally followed by repeated filler.
    """
    openers = (
        "Question: What is the capital of France?\nAnswer: The capital of France is",
        "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
        "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
        "Once upon a time in a distant galaxy, there lived",
        "The old man walked slowly down the street, remembering",
        "In the year 2157, humanity finally discovered",
        "To implement a binary search tree in Python, first we need to",
        "The algorithm works by iterating through the array and",
        "Here's how to optimize database queries using indexing:",
        "The Renaissance was a period in European history that",
        "Climate change is caused by several factors including",
        "The human brain contains approximately 86 billion neurons which",
        "I've been thinking about getting a new laptop because",
        "Yesterday I went to the store and bought",
        "My favorite thing about summer is definitely",
    )
    filler = " This is an interesting topic that deserves more explanation. "

    # Keep the draw order (opener first, size second) so that a fixed seed
    # reproduces the same sequence of prompts.
    opener = random.choice(openers)
    upper_bound = max(min_words, max_words)
    nominal_size = random.randint(min_words, upper_bound)

    # Sizes of 50 or less get no filler at all.
    repeats = nominal_size // 50 if nominal_size > 50 else 0
    return opener + filler * repeats
def run_benchmark_with_batch_invariant(
    model: str,
    tp_size: int,
    max_batch_size: int,
    num_trials: int,
    min_prompt: int,
    max_prompt: int,
    max_tokens: int,
    temperature: float,
    gpu_mem_util: float,
    max_model_len: int,
    backend: str,
    batch_invariant: bool,
    seed: int = 12345,
) -> dict:
    """Run the timed workload once with batch invariance on or off.

    The function seeds ``random``, exports ``VLLM_ATTENTION_BACKEND`` and
    ``VLLM_BATCH_INVARIANT`` into ``os.environ`` (they are left set
    afterwards), builds an ``LLM`` engine, performs one warmup generation and
    then times ``num_trials`` batched generations.

    Args:
        model: Model name or path handed to ``LLM``.
        tp_size: Tensor parallel size.
        max_batch_size: Engine ``max_num_seqs``; each trial uses a random
            batch size between half of this value (at least 1) and this value.
        num_trials: Number of timed batches; must be at least 1.
        min_prompt: Lower nominal prompt size passed to ``_random_prompt``.
        max_prompt: Upper nominal prompt size passed to ``_random_prompt``.
        max_tokens: Maximum number of tokens to generate per prompt.
        temperature: Sampling temperature.
        gpu_mem_util: Engine ``gpu_memory_utilization``.
        max_model_len: Engine ``max_model_len``.
        backend: Value exported as ``VLLM_ATTENTION_BACKEND``.
        batch_invariant: Whether ``VLLM_BATCH_INVARIANT`` is exported as
            ``"1"`` (True) or ``"0"`` (False).
        seed: Seed for the ``random`` module, so that both benchmark phases
            draw the same batch sizes and prompts.

    Returns:
        A dict with ``init_time``, ``avg_time``, ``min_time``, ``max_time``,
        ``total_tokens``, ``total_prompts``, ``throughput``,
        ``prompts_per_sec`` and ``trial_times``.

    Raises:
        ValueError: If ``num_trials`` or ``max_batch_size`` is below 1.
    """
    # Reject degenerate settings before the expensive engine load: zero
    # trials would end in a ZeroDivisionError when averaging, and a batch
    # limit below 1 cannot produce a batch that contains the needle prompt.
    if num_trials < 1:
        raise ValueError(f"num_trials must be >= 1, got {num_trials}")
    if max_batch_size < 1:
        raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")

    random.seed(seed)

    # Set environment variables
    os.environ["VLLM_ATTENTION_BACKEND"] = backend
    os.environ["VLLM_BATCH_INVARIANT"] = "1" if batch_invariant else "0"

    print(f"\n{'=' * 80}")
    print(f"BENCHMARK: VLLM_BATCH_INVARIANT={int(batch_invariant)}")
    print(f"  Model: {model}")
    print(f"  TP Size: {tp_size}")
    print(f"  Backend: {backend}")
    print(f"  Max Batch Size: {max_batch_size}")
    print(f"  Trials: {num_trials}")
    print(f"  Max Tokens: {max_tokens}")
    print(f"{'=' * 80}\n")

    sampling = SamplingParams(
        temperature=temperature,
        top_p=0.95,
        max_tokens=max_tokens,
        seed=20240919,
    )

    needle_prompt = "There once was a "

    llm = None
    try:
        # Create LLM engine
        start_init = time.perf_counter()
        llm = LLM(
            model=model,
            max_num_seqs=max_batch_size,
            gpu_memory_utilization=gpu_mem_util,
            max_model_len=max_model_len,
            dtype="bfloat16",
            tensor_parallel_size=tp_size,
            enable_prefix_caching=False,
        )
        init_time = time.perf_counter() - start_init
        print(f"Engine initialization time: {init_time:.2f}s\n")

        # Generate baseline
        print("Generating baseline (warmup)...")
        baseline_out = llm.generate([needle_prompt], sampling)
        assert len(baseline_out) == 1
        baseline_text = baseline_out[0].outputs[0].text
        print(f"Baseline output: '{baseline_text[:50]}...'\n")

        # Run trials and measure timing
        trial_times: list[float] = []
        total_tokens = 0
        total_prompts = 0

        # A batch limit of 1 would otherwise allow an empty batch (1 // 2 == 0),
        # after which randint(0, -1) below raises.
        min_batch_size = max(1, max_batch_size // 2)

        for trial in range(num_trials):
            # Create a batch with the needle prompt at a random position
            batch_size = random.randint(min_batch_size, max_batch_size)
            needle_pos = random.randint(0, batch_size - 1)
            prompts: list[str] = [
                needle_prompt
                if i == needle_pos
                else _random_prompt(min_prompt, max_prompt)
                for i in range(batch_size)
            ]

            # Measure time for this trial
            start_time = time.perf_counter()
            outputs = llm.generate(prompts, sampling)
            trial_time = time.perf_counter() - start_time

            trial_times.append(trial_time)
            total_prompts += len(prompts)

            # Count tokens of the first completion of every request
            total_tokens += sum(
                len(output.outputs[0].token_ids)
                for output in outputs
                if output.outputs
            )

            print(
                f"Trial {trial + 1}/{num_trials}: "
                f"batch_size={batch_size}, "
                f"time={trial_time:.2f}s"
            )

            # Sanity check only: the result at the needle position must carry
            # the needle prompt. The generated text is not compared with the
            # warmup output here.
            needle_output = outputs[needle_pos]
            assert needle_output.prompt == needle_prompt

        # Compute statistics
        total_time = sum(trial_times)
        avg_time = total_time / len(trial_times)
        min_time = min(trial_times)
        max_time = max(trial_times)
        throughput = total_tokens / total_time
        prompts_per_sec = total_prompts / total_time

        print(f"\n{'=' * 80}")
        print("RESULTS:")
        print(f"  Average time per trial: {avg_time:.2f}s")
        print(f"  Min time: {min_time:.2f}s")
        print(f"  Max time: {max_time:.2f}s")
        print(f"  Total tokens generated: {total_tokens}")
        print(f"  Total prompts processed: {total_prompts}")
        print(f"  Throughput: {throughput:.2f} tokens/s")
        print(f"  Prompts/s: {prompts_per_sec:.2f}")
        print(f"{'=' * 80}\n")

        return {
            "init_time": init_time,
            "avg_time": avg_time,
            "min_time": min_time,
            "max_time": max_time,
            "total_tokens": total_tokens,
            "total_prompts": total_prompts,
            "throughput": throughput,
            "prompts_per_sec": prompts_per_sec,
            "trial_times": trial_times,
        }

    finally:
        # Cleanup: best-effort shutdown so a failure here never masks the
        # benchmark result or the original exception.
        if llm is not None:
            with contextlib.suppress(Exception):
                llm.shutdown()
def main():
    """Run the baseline and batch-invariant phases and print a comparison.

    Settings are read from the ``VLLM_BENCH_*`` environment variables. The
    workload is executed first with batch invariance disabled and then with
    it enabled, after which relative changes in initialization time, average
    trial time and throughput are reported.

    Returns:
        ``1`` when the platform is not CUDA with device capability 90 or
        higher, ``0`` after a completed comparison.
    """
    rule = "=" * 80

    # Check platform support
    supported = current_platform.is_cuda() and current_platform.has_device_capability(
        90
    )
    if not supported:
        print("ERROR: Requires CUDA and >= Hopper (SM90)")
        print(f"Current platform: {current_platform.device_type}")
        if current_platform.is_cuda():
            print(f"Device capability: {current_platform.get_device_capability()}")
        return 1

    # Read configuration from environment
    env = os.getenv
    model = env("VLLM_BENCH_MODEL", "Qwen/Qwen3-1.7B")
    tp_size = int(env("VLLM_BENCH_TP_SIZE", "1"))
    max_batch_size = int(env("VLLM_BENCH_BATCH_SIZE", "128"))
    num_trials = int(env("VLLM_BENCH_NUM_TRIALS", "5"))
    min_prompt = int(env("VLLM_BENCH_MIN_PROMPT", "1024"))
    max_prompt = int(env("VLLM_BENCH_MAX_PROMPT", "2048"))
    max_tokens = int(env("VLLM_BENCH_MAX_TOKENS", "128"))
    temperature = float(env("VLLM_BENCH_TEMPERATURE", "0.0"))
    gpu_mem_util = float(env("VLLM_BENCH_GPU_MEMORY_UTILIZATION", "0.4"))
    max_model_len = int(env("VLLM_BENCH_MAX_MODEL_LEN", "5120"))
    backend = env("VLLM_BENCH_BACKEND", "FLASH_ATTN")

    print("\n" + rule)
    print("VLLM BATCH INVARIANCE BENCHMARK")
    print(rule)
    print("\nConfiguration:")
    print(f"  Model: {model}")
    print(f"  Tensor Parallel Size: {tp_size}")
    print(f"  Attention Backend: {backend}")
    print(f"  Max Batch Size: {max_batch_size}")
    print(f"  Number of Trials: {num_trials}")
    print(f"  Prompt Length Range: {min_prompt}-{max_prompt} words")
    print(f"  Max Tokens to Generate: {max_tokens}")
    print(f"  Temperature: {temperature}")
    print(f"  GPU Memory Utilization: {gpu_mem_util}")
    print(f"  Max Model Length: {max_model_len}")
    print(rule)

    # Both phases share every argument except the batch-invariance switch.
    shared_args = {
        "model": model,
        "tp_size": tp_size,
        "max_batch_size": max_batch_size,
        "num_trials": num_trials,
        "min_prompt": min_prompt,
        "max_prompt": max_prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "gpu_mem_util": gpu_mem_util,
        "max_model_len": max_model_len,
        "backend": backend,
    }

    # Run benchmark WITHOUT batch invariance (baseline)
    print("\n" + rule)
    print("PHASE 1: Running WITHOUT batch invariance (baseline)")
    print(rule)
    baseline_results = run_benchmark_with_batch_invariant(
        **shared_args, batch_invariant=False
    )

    # Run benchmark WITH batch invariance
    print("\n" + rule)
    print("PHASE 2: Running WITH batch invariance")
    print(rule)
    batch_inv_results = run_benchmark_with_batch_invariant(
        **shared_args, batch_invariant=True
    )

    # Compare results
    print("\n" + rule)
    print("COMPARISON: Batch Invariance vs Baseline")
    print(rule)

    def relative_change(metric):
        # Percentage change of the batch-invariant run against the baseline.
        reference = baseline_results[metric]
        return (batch_inv_results[metric] - reference) / reference * 100

    init_overhead_pct = relative_change("init_time")
    time_overhead_pct = relative_change("avg_time")
    throughput_change_pct = relative_change("throughput")

    print("\nInitialization Time:")
    print(f"  Baseline: {baseline_results['init_time']:.2f}s")
    print(f"  Batch Invariant: {batch_inv_results['init_time']:.2f}s")
    print(f"  Overhead: {init_overhead_pct:+.2f}%")

    print("\nAverage Trial Time:")
    print(f"  Baseline: {baseline_results['avg_time']:.2f}s")
    print(f"  Batch Invariant: {batch_inv_results['avg_time']:.2f}s")
    print(f"  Overhead: {time_overhead_pct:+.2f}%")

    print("\nThroughput (tokens/s):")
    print(f"  Baseline: {baseline_results['throughput']:.2f}")
    print(f"  Batch Invariant: {batch_inv_results['throughput']:.2f}")
    print(f"  Change: {throughput_change_pct:+.2f}%")

    print("\nPrompts/s:")
    print(f"  Baseline: {baseline_results['prompts_per_sec']:.2f}")
    print(f"  Batch Invariant: {batch_inv_results['prompts_per_sec']:.2f}")

    print("\n" + rule)
    print("SUMMARY")
    print(rule)

    if time_overhead_pct > 0:
        time_summary = (
            f"Batch invariance mode adds approximately {time_overhead_pct:.1f}% "
            "overhead"
        )
    else:
        time_summary = (
            f"Batch invariance mode is approximately {-time_overhead_pct:.1f}% "
            "faster (unexpected!)"
        )
    print(time_summary)

    if abs(throughput_change_pct) < 1.0:
        throughput_summary = "Throughput difference is negligible (< 1%)"
    elif throughput_change_pct < 0:
        throughput_summary = (
            f"Throughput decreased by {-throughput_change_pct:.1f}% "
            "with batch invariance"
        )
    else:
        throughput_summary = (
            f"Throughput increased by {throughput_change_pct:.1f}% "
            "with batch invariance (unexpected!)"
        )
    print(throughput_summary)

    print(rule + "\n")

    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status. The
    # builtin exit() is installed by the site module for interactive use and
    # is missing under `python -S`, so SystemExit is raised directly.
    raise SystemExit(main())
benchmarks/benchmark_block_pool.py
View file @
41199996
...
@@ -2,10 +2,10 @@
...
@@ -2,10 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
gc
import
gc
from
benchmark_utils
import
TimeCollector
from
tabulate
import
tabulate
from
tabulate
import
tabulate
from
benchmark_utils
import
TimeCollector
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.block_pool
import
BlockPool
...
...
benchmarks/benchmark_long_document_qa_throughput.py
View file @
41199996
...
@@ -46,7 +46,7 @@ import time
...
@@ -46,7 +46,7 @@ import time
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
def
test_long_document_qa
(
llm
=
None
,
sampling_params
=
None
,
prompts
=
None
):
def
test_long_document_qa
(
llm
=
None
,
sampling_params
=
None
,
prompts
=
None
):
...
...
benchmarks/benchmark_ngram_proposer.py
View file @
41199996
...
@@ -5,9 +5,9 @@ import time
...
@@ -5,9 +5,9 @@ import time
from
unittest
import
mock
from
unittest
import
mock
import
numpy
as
np
import
numpy
as
np
from
benchmark_utils
import
TimeCollector
from
tabulate
import
tabulate
from
tabulate
import
tabulate
from
benchmark_utils
import
TimeCollector
from
vllm.config
import
(
from
vllm.config
import
(
CacheConfig
,
CacheConfig
,
DeviceConfig
,
DeviceConfig
,
...
@@ -19,7 +19,7 @@ from vllm.config import (
...
@@ -19,7 +19,7 @@ from vllm.config import (
VllmConfig
,
VllmConfig
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
...
@@ -108,7 +108,10 @@ def benchmark_batched_propose(args):
...
@@ -108,7 +108,10 @@ def benchmark_batched_propose(args):
device_config
=
DeviceConfig
(
device
=
current_platform
.
device_type
),
device_config
=
DeviceConfig
(
device
=
current_platform
.
device_type
),
parallel_config
=
ParallelConfig
(),
parallel_config
=
ParallelConfig
(),
load_config
=
LoadConfig
(),
load_config
=
LoadConfig
(),
scheduler_config
=
SchedulerConfig
(),
scheduler_config
=
SchedulerConfig
(
max_model_len
=
model_config
.
max_model_len
,
is_encoder_decoder
=
model_config
.
is_encoder_decoder
,
),
)
)
# monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group
# monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group
...
@@ -164,7 +167,7 @@ def invoke_main() -> None:
...
@@ -164,7 +167,7 @@ def invoke_main() -> None:
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--batched"
,
action
=
"store_true"
,
help
=
"consider time to prepare batch"
"--batched"
,
action
=
"store_true"
,
help
=
"consider time to prepare batch"
)
# noqa: E501
)
parser
.
add_argument
(
parser
.
add_argument
(
"--num-iteration"
,
"--num-iteration"
,
type
=
int
,
type
=
int
,
...
...
benchmarks/benchmark_prefix_caching.py
View file @
41199996
...
@@ -32,18 +32,15 @@ import dataclasses
...
@@ -32,18 +32,15 @@ import dataclasses
import
json
import
json
import
random
import
random
import
time
import
time
from
typing
import
Optional
from
transformers
import
PreTrainedTokenizerBase
from
transformers
import
PreTrainedTokenizerBase
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
# import triton
try
:
try
:
from
vllm.
transformers_utils.
tokenizer
import
get_tokenizer
from
vllm.tokenizer
s
import
get_tokenizer
except
ImportError
:
except
ImportError
:
from
backend_request_func
import
get_tokenizer
from
backend_request_func
import
get_tokenizer
...
@@ -85,7 +82,7 @@ def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]:
...
@@ -85,7 +82,7 @@ def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]:
# Remove the special tokens.
# Remove the special tokens.
return
random
.
choices
(
return
random
.
choices
(
[
v
for
k
,
v
in
vocab
.
item
s
()
if
k
not
in
all_special_ids
],
[
v
for
v
in
vocab
.
value
s
()
if
v
not
in
all_special_ids
],
k
=
length
,
k
=
length
,
)
)
...
@@ -95,7 +92,7 @@ def sample_requests_from_dataset(
...
@@ -95,7 +92,7 @@ def sample_requests_from_dataset(
num_requests
:
int
,
num_requests
:
int
,
tokenizer
:
PreTrainedTokenizerBase
,
tokenizer
:
PreTrainedTokenizerBase
,
input_length_range
:
tuple
[
int
,
int
],
input_length_range
:
tuple
[
int
,
int
],
fixed_output_len
:
Optional
[
int
]
,
fixed_output_len
:
int
|
None
,
)
->
list
[
Request
]:
)
->
list
[
Request
]:
if
fixed_output_len
is
not
None
and
fixed_output_len
<
4
:
if
fixed_output_len
is
not
None
and
fixed_output_len
<
4
:
raise
ValueError
(
"output_len too small"
)
raise
ValueError
(
"output_len too small"
)
...
@@ -143,7 +140,7 @@ def sample_requests_from_random(
...
@@ -143,7 +140,7 @@ def sample_requests_from_random(
num_requests
:
int
,
num_requests
:
int
,
tokenizer
:
PreTrainedTokenizerBase
,
tokenizer
:
PreTrainedTokenizerBase
,
input_length_range
:
tuple
[
int
,
int
],
input_length_range
:
tuple
[
int
,
int
],
fixed_output_len
:
Optional
[
int
]
,
fixed_output_len
:
int
|
None
,
prefix_len
:
int
,
prefix_len
:
int
,
)
->
list
[
Request
]:
)
->
list
[
Request
]:
requests
=
[]
requests
=
[]
...
...
benchmarks/benchmark_prioritization.py
View file @
41199996
...
@@ -7,12 +7,11 @@ import dataclasses
...
@@ -7,12 +7,11 @@ import dataclasses
import
json
import
json
import
random
import
random
import
time
import
time
from
typing
import
Optional
from
transformers
import
AutoTokenizer
,
PreTrainedTokenizerBase
from
transformers
import
AutoTokenizer
,
PreTrainedTokenizerBase
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
# Select a equi-probable random priority
# Select a equi-probable random priority
...
@@ -24,7 +23,7 @@ def sample_requests(
...
@@ -24,7 +23,7 @@ def sample_requests(
dataset_path
:
str
,
dataset_path
:
str
,
num_requests
:
int
,
num_requests
:
int
,
tokenizer
:
PreTrainedTokenizerBase
,
tokenizer
:
PreTrainedTokenizerBase
,
fixed_output_len
:
Optional
[
int
]
,
fixed_output_len
:
int
|
None
,
)
->
list
[
tuple
[
str
,
int
,
int
,
int
]]:
)
->
list
[
tuple
[
str
,
int
,
int
,
int
]]:
if
fixed_output_len
is
not
None
and
fixed_output_len
<
4
:
if
fixed_output_len
is
not
None
and
fixed_output_len
<
4
:
raise
ValueError
(
"output_len too small"
)
raise
ValueError
(
"output_len too small"
)
...
...
benchmarks/benchmark_serving_structured_output.py
View file @
41199996
...
@@ -31,28 +31,27 @@ import time
...
@@ -31,28 +31,27 @@ import time
import
uuid
import
uuid
import
warnings
import
warnings
from
collections.abc
import
AsyncGenerator
from
collections.abc
import
AsyncGenerator
from
contextlib
import
nullcontext
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
import
datasets
import
datasets
import
numpy
as
np
import
numpy
as
np
import
pandas
as
pd
import
pandas
as
pd
from
tqdm.asyncio
import
tqdm
from
transformers
import
PreTrainedTokenizerBase
from
backend_request_func
import
(
from
backend_request_func
import
(
ASYNC_REQUEST_FUNCS
,
ASYNC_REQUEST_FUNCS
,
RequestFuncInput
,
RequestFuncInput
,
RequestFuncOutput
,
RequestFuncOutput
,
)
)
from
tqdm.asyncio
import
tqdm
from
transformers
import
PreTrainedTokenizerBase
try
:
try
:
from
vllm.
transformers_utils.
tokenizer
import
get_tokenizer
from
vllm.tokenizer
s
import
get_tokenizer
except
ImportError
:
except
ImportError
:
from
backend_request_func
import
get_tokenizer
from
backend_request_func
import
get_tokenizer
try
:
try
:
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
except
ImportError
:
except
ImportError
:
from
argparse
import
ArgumentParser
as
FlexibleArgumentParser
from
argparse
import
ArgumentParser
as
FlexibleArgumentParser
...
@@ -317,7 +316,7 @@ def calculate_metrics(
...
@@ -317,7 +316,7 @@ def calculate_metrics(
tokenizer
:
PreTrainedTokenizerBase
,
tokenizer
:
PreTrainedTokenizerBase
,
selected_percentile_metrics
:
list
[
str
],
selected_percentile_metrics
:
list
[
str
],
selected_percentiles
:
list
[
float
],
selected_percentiles
:
list
[
float
],
goodput_config_dict
:
Optional
[
dict
[
str
,
float
]
]
=
None
,
goodput_config_dict
:
dict
[
str
,
float
]
|
None
=
None
,
)
->
tuple
[
BenchmarkMetrics
,
list
[
int
]]:
)
->
tuple
[
BenchmarkMetrics
,
list
[
int
]]:
actual_output_lens
:
list
[
int
]
=
[]
actual_output_lens
:
list
[
int
]
=
[]
total_input
=
0
total_input
=
0
...
@@ -437,9 +436,9 @@ async def benchmark(
...
@@ -437,9 +436,9 @@ async def benchmark(
selected_percentile_metrics
:
list
[
str
],
selected_percentile_metrics
:
list
[
str
],
selected_percentiles
:
list
[
str
],
selected_percentiles
:
list
[
str
],
ignore_eos
:
bool
,
ignore_eos
:
bool
,
max_concurrency
:
Optional
[
int
]
,
max_concurrency
:
int
|
None
,
structured_output_ratio
:
float
,
structured_output_ratio
:
float
,
goodput_config_dict
:
Optional
[
dict
[
str
,
float
]
]
=
None
,
goodput_config_dict
:
dict
[
str
,
float
]
|
None
=
None
,
):
):
if
backend
in
ASYNC_REQUEST_FUNCS
:
if
backend
in
ASYNC_REQUEST_FUNCS
:
request_func
=
ASYNC_REQUEST_FUNCS
[
backend
]
request_func
=
ASYNC_REQUEST_FUNCS
[
backend
]
...
@@ -503,15 +502,9 @@ async def benchmark(
...
@@ -503,15 +502,9 @@ async def benchmark(
pbar
=
None
if
disable_tqdm
else
tqdm
(
total
=
len
(
input_requests
))
pbar
=
None
if
disable_tqdm
else
tqdm
(
total
=
len
(
input_requests
))
# This can be used once the minimum Python version is 3.10 or higher,
semaphore
=
asyncio
.
Semaphore
(
max_concurrency
)
if
max_concurrency
else
nullcontext
()
# and it will simplify the code in limited_request_func.
# semaphore = (asyncio.Semaphore(max_concurrency)
# if max_concurrency else contextlib.nullcontext())
semaphore
=
asyncio
.
Semaphore
(
max_concurrency
)
if
max_concurrency
else
None
async
def
limited_request_func
(
request_func_input
,
pbar
):
async
def
limited_request_func
(
request_func_input
,
pbar
):
if
semaphore
is
None
:
return
await
request_func
(
request_func_input
=
request_func_input
,
pbar
=
pbar
)
async
with
semaphore
:
async
with
semaphore
:
return
await
request_func
(
request_func_input
=
request_func_input
,
pbar
=
pbar
)
return
await
request_func
(
request_func_input
=
request_func_input
,
pbar
=
pbar
)
...
@@ -910,13 +903,13 @@ def create_argument_parser():
...
@@ -910,13 +903,13 @@ def create_argument_parser():
parser
.
add_argument
(
parser
.
add_argument
(
"--tokenizer"
,
"--tokenizer"
,
type
=
str
,
type
=
str
,
help
=
"Name or path of the tokenizer, if not using the default tokenizer."
,
# noqa: E501
help
=
"Name or path of the tokenizer, if not using the default tokenizer."
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--tokenizer-mode"
,
"--tokenizer-mode"
,
type
=
str
,
type
=
str
,
default
=
"auto"
,
default
=
"auto"
,
help
=
"Name or path of the tokenizer, if not using the default tokenizer."
,
# noqa: E501
help
=
"Name or path of the tokenizer, if not using the default tokenizer."
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--num-prompts"
,
"--num-prompts"
,
...
...
benchmarks/benchmark_utils.py
View file @
41199996
...
@@ -6,7 +6,7 @@ import math
...
@@ -6,7 +6,7 @@ import math
import
os
import
os
import
time
import
time
from
types
import
TracebackType
from
types
import
TracebackType
from
typing
import
Any
,
Optional
,
Union
from
typing
import
Any
def
convert_to_pytorch_benchmark_format
(
def
convert_to_pytorch_benchmark_format
(
...
@@ -92,7 +92,7 @@ class TimeCollector:
...
@@ -92,7 +92,7 @@ class TimeCollector:
def
__init__
(
self
,
scale
:
int
)
->
None
:
def
__init__
(
self
,
scale
:
int
)
->
None
:
self
.
cnt
:
int
=
0
self
.
cnt
:
int
=
0
self
.
_sum
:
int
=
0
self
.
_sum
:
int
=
0
self
.
_max
:
Optional
[
int
]
=
None
self
.
_max
:
int
|
None
=
None
self
.
scale
=
scale
self
.
scale
=
scale
self
.
start_time
:
int
=
time
.
monotonic_ns
()
self
.
start_time
:
int
=
time
.
monotonic_ns
()
...
@@ -104,13 +104,13 @@ class TimeCollector:
...
@@ -104,13 +104,13 @@ class TimeCollector:
else
:
else
:
self
.
_max
=
max
(
self
.
_max
,
v
)
self
.
_max
=
max
(
self
.
_max
,
v
)
def
avg
(
self
)
->
Union
[
float
,
str
]
:
def
avg
(
self
)
->
float
|
str
:
return
self
.
_sum
*
1.0
/
self
.
cnt
/
self
.
scale
if
self
.
cnt
>
0
else
"N/A"
return
self
.
_sum
*
1.0
/
self
.
cnt
/
self
.
scale
if
self
.
cnt
>
0
else
"N/A"
def
max
(
self
)
->
Union
[
float
,
str
]
:
def
max
(
self
)
->
float
|
str
:
return
self
.
_max
/
self
.
scale
if
self
.
_max
else
"N/A"
return
self
.
_max
/
self
.
scale
if
self
.
_max
else
"N/A"
def
dump_avg_max
(
self
)
->
list
[
Union
[
float
,
str
]
]
:
def
dump_avg_max
(
self
)
->
list
[
float
|
str
]:
return
[
self
.
avg
(),
self
.
max
()]
return
[
self
.
avg
(),
self
.
max
()]
def
__enter__
(
self
)
->
None
:
def
__enter__
(
self
)
->
None
:
...
@@ -118,8 +118,8 @@ class TimeCollector:
...
@@ -118,8 +118,8 @@ class TimeCollector:
def
__exit__
(
def
__exit__
(
self
,
self
,
exc_type
:
Optional
[
type
[
BaseException
]
]
,
exc_type
:
type
[
BaseException
]
|
None
,
exc_value
:
Optional
[
BaseException
]
,
exc_value
:
BaseException
|
None
,
exc_traceback
:
Optional
[
TracebackType
]
,
exc_traceback
:
TracebackType
|
None
,
)
->
None
:
)
->
None
:
self
.
collect
(
time
.
monotonic_ns
()
-
self
.
start_time
)
self
.
collect
(
time
.
monotonic_ns
()
-
self
.
start_time
)
benchmarks/cutlass_benchmarks/sparse_benchmarks.py
View file @
41199996
...
@@ -6,8 +6,7 @@ import copy
...
@@ -6,8 +6,7 @@ import copy
import
itertools
import
itertools
import
pickle
as
pkl
import
pickle
as
pkl
import
time
import
time
from
collections.abc
import
Iterable
from
collections.abc
import
Callable
,
Iterable
from
typing
import
Callable
import
torch
import
torch
import
torch.utils.benchmark
as
TBenchmark
import
torch.utils.benchmark
as
TBenchmark
...
@@ -16,7 +15,7 @@ from utils import make_rand_sparse_tensors
...
@@ -16,7 +15,7 @@ from utils import make_rand_sparse_tensors
from
weight_shapes
import
WEIGHT_SHAPES
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
.argparse_utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
]
...
...
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
View file @
41199996
...
@@ -6,8 +6,7 @@ import copy
...
@@ -6,8 +6,7 @@ import copy
import
itertools
import
itertools
import
pickle
as
pkl
import
pickle
as
pkl
import
time
import
time
from
collections.abc
import
Iterable
from
collections.abc
import
Callable
,
Iterable
from
typing
import
Callable
,
Optional
import
torch
import
torch
import
torch.utils.benchmark
as
TBenchmark
import
torch.utils.benchmark
as
TBenchmark
...
@@ -17,9 +16,10 @@ from weight_shapes import WEIGHT_SHAPES
...
@@ -17,9 +16,10 @@ from weight_shapes import WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
w8a8_
block_fp8_matmul
,
w8a8_
triton_block_scaled_mm
,
)
)
from
vllm.utils
import
FlexibleArgumentParser
,
cdiv
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.math_utils
import
cdiv
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
]
...
@@ -53,7 +53,7 @@ def bench_int8(
...
@@ -53,7 +53,7 @@ def bench_int8(
n
:
int
,
n
:
int
,
label
:
str
,
label
:
str
,
sub_label
:
str
,
sub_label
:
str
,
bench_kernels
:
Optional
[
list
[
str
]
]
=
None
,
bench_kernels
:
list
[
str
]
|
None
=
None
,
)
->
Iterable
[
TMeasurement
]:
)
->
Iterable
[
TMeasurement
]:
"""Benchmark INT8-based kernels."""
"""Benchmark INT8-based kernels."""
assert
dtype
==
torch
.
int8
assert
dtype
==
torch
.
int8
...
@@ -108,7 +108,7 @@ def bench_fp8(
...
@@ -108,7 +108,7 @@ def bench_fp8(
n
:
int
,
n
:
int
,
label
:
str
,
label
:
str
,
sub_label
:
str
,
sub_label
:
str
,
bench_kernels
:
Optional
[
list
[
str
]
]
=
None
,
bench_kernels
:
list
[
str
]
|
None
=
None
,
)
->
Iterable
[
TMeasurement
]:
)
->
Iterable
[
TMeasurement
]:
"""Benchmark FP8-based kernels."""
"""Benchmark FP8-based kernels."""
assert
dtype
==
torch
.
float8_e4m3fn
assert
dtype
==
torch
.
float8_e4m3fn
...
@@ -158,7 +158,7 @@ def bench_fp8(
...
@@ -158,7 +158,7 @@ def bench_fp8(
"cutlass_fp8_fp8_fp16_scaled_mm_bias"
:
lambda
:
ops
.
cutlass_scaled_mm
(
"cutlass_fp8_fp8_fp16_scaled_mm_bias"
:
lambda
:
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
,
bias
.
to
(
dtype
=
torch
.
float16
)
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
,
bias
.
to
(
dtype
=
torch
.
float16
)
),
),
"triton_fp8_fp8_fp16_scaled_mm_blockwise"
:
lambda
:
w8a8_
block_fp8_matmul
(
"triton_fp8_fp8_fp16_scaled_mm_blockwise"
:
lambda
:
w8a8_
triton_block_scaled_mm
(
a_cont
,
b
.
t
(),
block_scale_a
,
block_scale_b
.
t
(),
(
128
,
128
)
a_cont
,
b
.
t
(),
block_scale_a
,
block_scale_b
.
t
(),
(
128
,
128
)
),
),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise"
:
lambda
:
ops
.
cutlass_scaled_mm
(
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise"
:
lambda
:
ops
.
cutlass_scaled_mm
(
...
@@ -183,7 +183,7 @@ def bench(
...
@@ -183,7 +183,7 @@ def bench(
n
:
int
,
n
:
int
,
label
:
str
,
label
:
str
,
sub_label
:
str
,
sub_label
:
str
,
bench_kernels
:
Optional
[
list
[
str
]
]
=
None
,
bench_kernels
:
list
[
str
]
|
None
=
None
,
)
->
Iterable
[
TMeasurement
]:
)
->
Iterable
[
TMeasurement
]:
if
dtype
==
torch
.
int8
:
if
dtype
==
torch
.
int8
:
return
bench_int8
(
dtype
,
m
,
k
,
n
,
label
,
sub_label
,
bench_kernels
)
return
bench_int8
(
dtype
,
m
,
k
,
n
,
label
,
sub_label
,
bench_kernels
)
...
@@ -201,7 +201,7 @@ def print_timers(timers: Iterable[TMeasurement]):
...
@@ -201,7 +201,7 @@ def print_timers(timers: Iterable[TMeasurement]):
def
run
(
def
run
(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
MKNs
:
Iterable
[
tuple
[
int
,
int
,
int
]],
MKNs
:
Iterable
[
tuple
[
int
,
int
,
int
]],
bench_kernels
:
Optional
[
list
[
str
]
]
=
None
,
bench_kernels
:
list
[
str
]
|
None
=
None
,
)
->
Iterable
[
TMeasurement
]:
)
->
Iterable
[
TMeasurement
]:
results
=
[]
results
=
[]
for
m
,
k
,
n
in
MKNs
:
for
m
,
k
,
n
in
MKNs
:
...
...
benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
View file @
41199996
...
@@ -55,9 +55,7 @@ benchmark() {
...
@@ -55,9 +55,7 @@ benchmark() {
output_len
=
$2
output_len
=
$2
CUDA_VISIBLE_DEVICES
=
0 python3
\
CUDA_VISIBLE_DEVICES
=
0 vllm serve
$model
\
-m
vllm.entrypoints.openai.api_server
\
--model
$model
\
--port
8100
\
--port
8100
\
--max-model-len
10000
\
--max-model-len
10000
\
--gpu-memory-utilization
0.6
\
--gpu-memory-utilization
0.6
\
...
@@ -65,9 +63,7 @@ benchmark() {
...
@@ -65,9 +63,7 @@ benchmark() {
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}'
&
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}'
&
CUDA_VISIBLE_DEVICES
=
1 python3
\
CUDA_VISIBLE_DEVICES
=
1 vllm serve
$model
\
-m
vllm.entrypoints.openai.api_server
\
--model
$model
\
--port
8200
\
--port
8200
\
--max-model-len
10000
\
--max-model-len
10000
\
--gpu-memory-utilization
0.6
\
--gpu-memory-utilization
0.6
\
...
...
benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
View file @
41199996
...
@@ -38,16 +38,12 @@ wait_for_server() {
...
@@ -38,16 +38,12 @@ wait_for_server() {
launch_chunked_prefill
()
{
launch_chunked_prefill
()
{
model
=
"meta-llama/Meta-Llama-3.1-8B-Instruct"
model
=
"meta-llama/Meta-Llama-3.1-8B-Instruct"
# disagg prefill
# disagg prefill
CUDA_VISIBLE_DEVICES
=
0 python3
\
CUDA_VISIBLE_DEVICES
=
0 vllm serve
$model
\
-m
vllm.entrypoints.openai.api_server
\
--model
$model
\
--port
8100
\
--port
8100
\
--max-model-len
10000
\
--max-model-len
10000
\
--enable-chunked-prefill
\
--enable-chunked-prefill
\
--gpu-memory-utilization
0.6 &
--gpu-memory-utilization
0.6 &
CUDA_VISIBLE_DEVICES
=
1 python3
\
CUDA_VISIBLE_DEVICES
=
1 vllm serve
$model
\
-m
vllm.entrypoints.openai.api_server
\
--model
$model
\
--port
8200
\
--port
8200
\
--max-model-len
10000
\
--max-model-len
10000
\
--enable-chunked-prefill
\
--enable-chunked-prefill
\
...
@@ -62,18 +58,14 @@ launch_chunked_prefill() {
...
@@ -62,18 +58,14 @@ launch_chunked_prefill() {
launch_disagg_prefill
()
{
launch_disagg_prefill
()
{
model
=
"meta-llama/Meta-Llama-3.1-8B-Instruct"
model
=
"meta-llama/Meta-Llama-3.1-8B-Instruct"
# disagg prefill
# disagg prefill
CUDA_VISIBLE_DEVICES
=
0 python3
\
CUDA_VISIBLE_DEVICES
=
0 vllm serve
$model
\
-m
vllm.entrypoints.openai.api_server
\
--model
$model
\
--port
8100
\
--port
8100
\
--max-model-len
10000
\
--max-model-len
10000
\
--gpu-memory-utilization
0.6
\
--gpu-memory-utilization
0.6
\
--kv-transfer-config
\
--kv-transfer-config
\
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}'
&
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}'
&
CUDA_VISIBLE_DEVICES
=
1 python3
\
CUDA_VISIBLE_DEVICES
=
1 vllm serve
$model
\
-m
vllm.entrypoints.openai.api_server
\
--model
$model
\
--port
8200
\
--port
8200
\
--max-model-len
10000
\
--max-model-len
10000
\
--gpu-memory-utilization
0.6
\
--gpu-memory-utilization
0.6
\
...
...
benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
View file @
41199996
...
@@ -5,11 +5,12 @@ import argparse
...
@@ -5,11 +5,12 @@ import argparse
import
asyncio
import
asyncio
import
logging
import
logging
import
os
import
os
import
time
import
uuid
from
urllib.parse
import
urlparse
import
aiohttp
import
aiohttp
from
quart
import
Quart
,
Response
,
make_response
,
request
from
quart
import
Quart
,
Response
,
make_response
,
request
from
rate_limiter
import
RateLimiter
from
request_queue
import
RequestQueue
# Configure logging
# Configure logging
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logging
.
basicConfig
(
level
=
logging
.
INFO
)
...
@@ -24,26 +25,8 @@ def parse_args():
...
@@ -24,26 +25,8 @@ def parse_args():
parser
.
add_argument
(
parser
.
add_argument
(
"--timeout"
,
"--timeout"
,
type
=
float
,
type
=
float
,
default
=
300
,
default
=
6
*
60
*
60
,
help
=
"Timeout for backend service requests in seconds (default: 300)"
,
help
=
"Timeout for backend service requests in seconds (default: 21600)"
,
)
parser
.
add_argument
(
"--max-concurrent"
,
type
=
int
,
default
=
100
,
help
=
"Maximum concurrent requests to backend services (default: 100)"
,
)
parser
.
add_argument
(
"--queue-size"
,
type
=
int
,
default
=
500
,
help
=
"Maximum number of requests in the queue (default: 500)"
,
)
parser
.
add_argument
(
"--rate-limit"
,
type
=
int
,
default
=
40
,
help
=
"Maximum requests per second (default: 40)"
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--port"
,
"--port"
,
...
@@ -54,14 +37,32 @@ def parse_args():
...
@@ -54,14 +37,32 @@ def parse_args():
parser
.
add_argument
(
parser
.
add_argument
(
"--prefill-url"
,
"--prefill-url"
,
type
=
str
,
type
=
str
,
default
=
"http://localhost:8100
/v1/completions
"
,
default
=
"http://localhost:8100"
,
help
=
"Prefill service
endpoint URL
"
,
help
=
"Prefill service
base URL (protocol + host[:port])
"
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--decode-url"
,
"--decode-url"
,
type
=
str
,
type
=
str
,
default
=
"http://localhost:8200/v1/completions"
,
default
=
"http://localhost:8200"
,
help
=
"Decode service endpoint URL"
,
help
=
"Decode service base URL (protocol + host[:port])"
,
)
parser
.
add_argument
(
"--kv-host"
,
type
=
str
,
default
=
"localhost"
,
help
=
"Hostname or IP used by KV transfer (default: localhost)"
,
)
parser
.
add_argument
(
"--prefill-kv-port"
,
type
=
int
,
default
=
14579
,
help
=
"Prefill KV port (default: 14579)"
,
)
parser
.
add_argument
(
"--decode-kv-port"
,
type
=
int
,
default
=
14580
,
help
=
"Decode KV port (default: 14580)"
,
)
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -73,70 +74,129 @@ def main():
...
@@ -73,70 +74,129 @@ def main():
# Initialize configuration using command line parameters
# Initialize configuration using command line parameters
AIOHTTP_TIMEOUT
=
aiohttp
.
ClientTimeout
(
total
=
args
.
timeout
)
AIOHTTP_TIMEOUT
=
aiohttp
.
ClientTimeout
(
total
=
args
.
timeout
)
MAX_CONCURRENT_REQUESTS
=
args
.
max_concurrent
REQUEST_QUEUE_SIZE
=
args
.
queue_size
RATE_LIMIT
=
args
.
rate_limit
PREFILL_SERVICE_URL
=
args
.
prefill_url
PREFILL_SERVICE_URL
=
args
.
prefill_url
DECODE_SERVICE_URL
=
args
.
decode_url
DECODE_SERVICE_URL
=
args
.
decode_url
PORT
=
args
.
port
PORT
=
args
.
port
app
=
Quart
(
__name__
)
PREFILL_KV_ADDR
=
f
"
{
args
.
kv_host
}
:
{
args
.
prefill_kv_port
}
"
DECODE_KV_ADDR
=
f
"
{
args
.
kv_host
}
:
{
args
.
decode_kv_port
}
"
# Initialize the rate limiter and request queue
logger
.
info
(
rate_limiter
=
RateLimiter
(
RATE_LIMIT
)
"Proxy resolved KV addresses -> prefill: %s, decode: %s"
,
request_queue
=
RequestQueue
(
MAX_CONCURRENT_REQUESTS
,
REQUEST_QUEUE_SIZE
)
PREFILL_KV_ADDR
,
DECODE_KV_ADDR
,
)
app
=
Quart
(
__name__
)
# Attach the configuration object to the application instance
# Attach the configuration object to the application instance so helper
# coroutines can read the resolved backend URLs and timeouts without using
# globals.
app
.
config
.
update
(
app
.
config
.
update
(
{
{
"AIOHTTP_TIMEOUT"
:
AIOHTTP_TIMEOUT
,
"AIOHTTP_TIMEOUT"
:
AIOHTTP_TIMEOUT
,
"rate_limiter"
:
rate_limiter
,
"request_queue"
:
request_queue
,
"PREFILL_SERVICE_URL"
:
PREFILL_SERVICE_URL
,
"PREFILL_SERVICE_URL"
:
PREFILL_SERVICE_URL
,
"DECODE_SERVICE_URL"
:
DECODE_SERVICE_URL
,
"DECODE_SERVICE_URL"
:
DECODE_SERVICE_URL
,
"PREFILL_KV_ADDR"
:
PREFILL_KV_ADDR
,
"DECODE_KV_ADDR"
:
DECODE_KV_ADDR
,
}
}
)
)
# Start queue processing on app startup
def
_normalize_base_url
(
url
:
str
)
->
str
:
@
app
.
before_serving
"""Remove any trailing slash so path joins behave predictably."""
async
def
startup
():
return
url
.
rstrip
(
"/"
)
"""Start request processing task when app starts serving"""
asyncio
.
create_task
(
request_queue
.
process
())
def
_get_host_port
(
url
:
str
)
->
str
:
"""Return the hostname:port portion for logging and KV headers."""
async
def
forward_request
(
url
,
data
):
parsed
=
urlparse
(
url
)
"""Forward request to backend service with rate limiting and error handling"""
host
=
parsed
.
hostname
or
"localhost"
headers
=
{
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
}
port
=
parsed
.
port
if
port
is
None
:
# Use rate limiter as context manager
port
=
80
if
parsed
.
scheme
==
"http"
else
443
async
with
(
return
f
"
{
host
}
:
{
port
}
"
rate_limiter
,
aiohttp
.
ClientSession
(
timeout
=
AIOHTTP_TIMEOUT
)
as
session
,
PREFILL_BASE
=
_normalize_base_url
(
PREFILL_SERVICE_URL
)
):
DECODE_BASE
=
_normalize_base_url
(
DECODE_SERVICE_URL
)
try
:
KV_TARGET
=
_get_host_port
(
DECODE_SERVICE_URL
)
async
with
session
.
post
(
url
=
url
,
json
=
data
,
headers
=
headers
def
_build_headers
(
request_id
:
str
)
->
dict
[
str
,
str
]:
)
as
response
:
"""Construct the headers expected by vLLM's P2P disagg connector."""
if
response
.
status
==
200
:
headers
:
dict
[
str
,
str
]
=
{
"X-Request-Id"
:
request_id
,
"X-KV-Target"
:
KV_TARGET
}
# Stream response chunks
api_key
=
os
.
environ
.
get
(
"OPENAI_API_KEY"
)
async
for
chunk_bytes
in
response
.
content
.
iter_chunked
(
1024
):
if
api_key
:
yield
chunk_bytes
headers
[
"Authorization"
]
=
f
"Bearer
{
api_key
}
"
else
:
return
headers
# Handle backend service errors
error_text
=
await
response
.
text
()
async
def
_run_prefill
(
logger
.
error
(
request_path
:
str
,
"Backend service error: %s - %s"
,
payload
:
dict
,
response
.
status
,
headers
:
dict
[
str
,
str
],
error_text
,
request_id
:
str
,
)
):
yield
b
'{"error": "Backend service error"}'
url
=
f
"
{
PREFILL_BASE
}{
request_path
}
"
except
aiohttp
.
ClientError
as
e
:
start_ts
=
time
.
perf_counter
()
# Handle connection errors
logger
.
info
(
"[prefill] start request_id=%s url=%s"
,
request_id
,
url
)
logger
.
error
(
"Connection error to %s: %s"
,
url
,
str
(
e
))
try
:
yield
b
'{"error": "Service unavailable"}'
async
with
(
except
asyncio
.
TimeoutError
:
aiohttp
.
ClientSession
(
timeout
=
AIOHTTP_TIMEOUT
)
as
session
,
# Handle timeout errors
session
.
post
(
url
=
url
,
json
=
payload
,
headers
=
headers
)
as
resp
,
logger
.
error
(
"Timeout connecting to %s"
,
url
)
):
yield
b
'{"error": "Service timeout"}'
if
resp
.
status
!=
200
:
error_text
=
await
resp
.
text
()
raise
RuntimeError
(
f
"Prefill backend error
{
resp
.
status
}
:
{
error_text
}
"
)
await
resp
.
read
()
logger
.
info
(
"[prefill] done request_id=%s status=%s elapsed=%.2fs"
,
request_id
,
resp
.
status
,
time
.
perf_counter
()
-
start_ts
,
)
except
asyncio
.
TimeoutError
as
exc
:
raise
RuntimeError
(
f
"Prefill service timeout at
{
url
}
"
)
from
exc
except
aiohttp
.
ClientError
as
exc
:
raise
RuntimeError
(
f
"Prefill service unavailable at
{
url
}
"
)
from
exc
async
def
_stream_decode
(
request_path
:
str
,
payload
:
dict
,
headers
:
dict
[
str
,
str
],
request_id
:
str
,
):
url
=
f
"
{
DECODE_BASE
}{
request_path
}
"
# Stream tokens from the decode service once the prefill stage has
# materialized KV caches on the target workers.
logger
.
info
(
"[decode] start request_id=%s url=%s"
,
request_id
,
url
)
try
:
async
with
(
aiohttp
.
ClientSession
(
timeout
=
AIOHTTP_TIMEOUT
)
as
session
,
session
.
post
(
url
=
url
,
json
=
payload
,
headers
=
headers
)
as
resp
,
):
if
resp
.
status
!=
200
:
error_text
=
await
resp
.
text
()
logger
.
error
(
"Decode backend error %s - %s"
,
resp
.
status
,
error_text
)
err_msg
=
(
'{"error": "Decode backend error '
+
str
(
resp
.
status
)
+
'"}'
)
yield
err_msg
.
encode
()
return
logger
.
info
(
"[decode] streaming response request_id=%s status=%s"
,
request_id
,
resp
.
status
,
)
async
for
chunk_bytes
in
resp
.
content
.
iter_chunked
(
1024
):
yield
chunk_bytes
logger
.
info
(
"[decode] finished streaming request_id=%s"
,
request_id
)
except
asyncio
.
TimeoutError
:
logger
.
error
(
"Decode service timeout at %s"
,
url
)
yield
b
'{"error": "Decode service timeout"}'
except
aiohttp
.
ClientError
as
exc
:
logger
.
error
(
"Decode service error at %s: %s"
,
url
,
exc
)
yield
b
'{"error": "Decode service unavailable"}'
async
def
process_request
():
async
def
process_request
():
"""Process a single request through prefill and decode stages"""
"""Process a single request through prefill and decode stages"""
...
@@ -146,13 +206,27 @@ def main():
...
@@ -146,13 +206,27 @@ def main():
# Create prefill request (max_tokens=1)
# Create prefill request (max_tokens=1)
prefill_request
=
original_request_data
.
copy
()
prefill_request
=
original_request_data
.
copy
()
prefill_request
[
"max_tokens"
]
=
1
prefill_request
[
"max_tokens"
]
=
1
if
"max_completion_tokens"
in
prefill_request
:
prefill_request
[
"max_completion_tokens"
]
=
1
# Execute prefill stage
# Execute prefill stage
async
for
_
in
forward_request
(
PREFILL_SERVICE_URL
,
prefill_request
):
# The request id encodes both KV socket addresses so the backend can
continue
# shuttle tensors directly via NCCL once the prefill response
# completes.
request_id
=
(
f
"___prefill_addr_
{
PREFILL_KV_ADDR
}
___decode_addr_"
f
"
{
DECODE_KV_ADDR
}
_
{
uuid
.
uuid4
().
hex
}
"
)
headers
=
_build_headers
(
request_id
)
await
_run_prefill
(
request
.
path
,
prefill_request
,
headers
,
request_id
)
# Execute decode stage and stream response
# Execute decode stage and stream response
generator
=
forward_request
(
DECODE_SERVICE_URL
,
original_request_data
)
# Pass the unmodified user request so the decode phase can continue
# sampling with the already-populated KV cache.
generator
=
_stream_decode
(
request
.
path
,
original_request_data
,
headers
,
request_id
)
response
=
await
make_response
(
generator
)
response
=
await
make_response
(
generator
)
response
.
timeout
=
None
# Disable timeout for streaming response
response
.
timeout
=
None
# Disable timeout for streaming response
return
response
return
response
...
@@ -168,23 +242,10 @@ def main():
...
@@ -168,23 +242,10 @@ def main():
@
app
.
route
(
"/v1/completions"
,
methods
=
[
"POST"
])
@
app
.
route
(
"/v1/completions"
,
methods
=
[
"POST"
])
async
def
handle_request
():
async
def
handle_request
():
"""Handle incoming API requests with concurrency and rate limiting"""
"""Handle incoming API requests with concurrency and rate limiting"""
# Create task for request processing
task
=
asyncio
.
create_task
(
process_request
())
# Enqueue request or reject if queue is full
if
not
await
request_queue
.
enqueue
(
task
):
return
Response
(
response
=
b
'{"error": "Server busy, try again later"}'
,
status
=
503
,
content_type
=
"application/json"
,
)
try
:
try
:
# Return the response from the processing task
return
await
process_request
()
return
await
task
except
asyncio
.
CancelledError
:
except
asyncio
.
CancelledError
:
# Handle task cancellation (timeout or queue full)
logger
.
warning
(
"Request cancelled"
)
logger
.
warning
(
"Request cancelled due to timeout or queue full"
)
return
Response
(
return
Response
(
response
=
b
'{"error": "Request cancelled"}'
,
response
=
b
'{"error": "Request cancelled"}'
,
status
=
503
,
status
=
503
,
...
...
benchmarks/fused_kernels/layernorm_rms_benchmarks.py
View file @
41199996
...
@@ -3,10 +3,9 @@
...
@@ -3,10 +3,9 @@
import
pickle
as
pkl
import
pickle
as
pkl
import
time
import
time
from
collections.abc
import
Iterable
from
collections.abc
import
Callable
,
Iterable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
itertools
import
product
from
itertools
import
product
from
typing
import
Callable
,
Optional
import
torch
import
torch
import
torch.utils.benchmark
as
TBenchmark
import
torch.utils.benchmark
as
TBenchmark
...
@@ -51,7 +50,7 @@ def get_bench_params() -> list[bench_params_t]:
...
@@ -51,7 +50,7 @@ def get_bench_params() -> list[bench_params_t]:
def
unfused_int8_impl
(
def
unfused_int8_impl
(
rms_norm_layer
:
RMSNorm
,
rms_norm_layer
:
RMSNorm
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
,
residual
:
torch
.
Tensor
|
None
,
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
):
):
# Norm
# Norm
...
@@ -68,7 +67,7 @@ def unfused_int8_impl(
...
@@ -68,7 +67,7 @@ def unfused_int8_impl(
def
unfused_fp8_impl
(
def
unfused_fp8_impl
(
rms_norm_layer
:
RMSNorm
,
rms_norm_layer
:
RMSNorm
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
,
residual
:
torch
.
Tensor
|
None
,
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
):
):
# Norm
# Norm
...
@@ -85,7 +84,7 @@ def unfused_fp8_impl(
...
@@ -85,7 +84,7 @@ def unfused_fp8_impl(
def
fused_impl
(
def
fused_impl
(
rms_norm_layer
:
RMSNorm
,
# this stores the weights
rms_norm_layer
:
RMSNorm
,
# this stores the weights
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
]
,
residual
:
torch
.
Tensor
|
None
,
quant_dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
):
):
out
,
_
=
ops
.
rms_norm_dynamic_per_token_quant
(
out
,
_
=
ops
.
rms_norm_dynamic_per_token_quant
(
...
...
benchmarks/kernels/bench_block_fp8_gemm.py
View file @
41199996
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
# Disable DeepGEMM for this benchmark to use CUTLASS
os
.
environ
[
"VLLM_USE_DEEP_GEMM"
]
=
"0"
import
torch
import
torch
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
apply_w8a8_block_fp8_linear
,
W8A8BlockFp8LinearOp
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
)
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
CUTLASS_BLOCK_FP8_SUPPORTED
,
CUTLASS_BLOCK_FP8_SUPPORTED
,
...
@@ -39,13 +47,14 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
...
@@ -39,13 +47,14 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_max
,
fp8_min
=
fp8_info
.
max
,
fp8_info
.
min
fp8_max
,
fp8_min
=
fp8_info
.
max
,
fp8_info
.
min
# Create random
FP8
tensor
s
# Create random
input
tensor
(bfloat16, will be quantized by W8A8BlockFp8LinearOp)
A_ref
=
(
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
-
0.5
)
*
2
*
fp8_max
A_ref
=
(
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
-
0.5
)
*
2
*
fp8_max
# Create quantized weight tensor
B_ref
=
(
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
-
0.5
)
*
2
*
fp8_max
B_ref
=
(
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
-
0.5
)
*
2
*
fp8_max
B
=
B_ref
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
torch
.
float8_e4m3fn
)
B
=
B_ref
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
torch
.
float8_e4m3fn
)
# Create scales
# Create
weight
scales
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
n_tiles
=
(
N
+
block_n
-
1
)
//
block_n
n_tiles
=
(
N
+
block_n
-
1
)
//
block_n
k_tiles
=
(
K
+
block_k
-
1
)
//
block_k
k_tiles
=
(
K
+
block_k
-
1
)
//
block_k
...
@@ -55,19 +64,25 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
...
@@ -55,19 +64,25 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
*
factor_for_scale
*
factor_for_scale
)
)
# SM90 CUTLASS requires row-major format for scales
# Create W8A8BlockFp8LinearOp instance
if
use_cutlass
and
current_platform
.
is_device_capability
(
90
):
weight_group_shape
=
GroupShape
(
block_n
,
block_k
)
Bs
=
Bs
.
T
.
contiguous
()
act_quant_group_shape
=
GroupShape
(
1
,
block_k
)
# Per-token, per-group quantization
linear_op
=
W8A8BlockFp8LinearOp
(
weight_group_shape
=
weight_group_shape
,
act_quant_group_shape
=
act_quant_group_shape
,
cutlass_block_fp8_supported
=
use_cutlass
,
use_aiter_and_is_supported
=
False
,
)
def
run
():
def
run
():
if
use_cutlass
:
return
linear_op
.
apply
(
return
apply_w8a8_block_fp8_linear
(
input
=
A_ref
,
A_ref
,
B
,
block_size
,
Bs
,
cutlass_block_fp8_supported
=
True
weight
=
B
,
)
weight_scale
=
Bs
,
else
:
input_scale
=
None
,
return
apply_w8a8_block_fp8_linear
(
bias
=
None
,
A_ref
,
B
,
block_size
,
Bs
,
cutlass_block_fp8_supported
=
False
)
)
return
run
return
run
...
...
benchmarks/kernels/bench_mxfp4_qutlass.py
0 → 100644
View file @
41199996
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import
argparse
import
copy
import
itertools
import
torch
from
compressed_tensors.transform.utils.hadamard
import
deterministic_hadamard_matrix
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm._custom_ops
import
fusedQuantizeMx
,
matmul_mxf4_bf16_tn
from
vllm.model_executor.layers.quantization.qutlass_utils
import
to_blocked
from
vllm.triton_utils
import
triton
# Benchmark providers: the BF16 torch baseline plus two MXFP4 variants.
# "no_a_quant" selects whether the activation is quantized once up front
# (True) or inside the timed callable (False); "enabled" toggles a provider.
PROVIDER_CFGS = {
    "torch-bf16": {"enabled": True},
    "mxfp4": {"no_a_quant": False, "enabled": True},
    "mxfp4-noquant": {"no_a_quant": True, "enabled": True},
}

# Names of the providers that are switched on, in declaration order.
_enabled = [name for name, cfg in PROVIDER_CFGS.items() if cfg["enabled"]]
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
    """Return a Hadamard matrix of order ``group_size`` scaled by ``group_size**-0.5``.

    The unscaled matrix is produced by ``deterministic_hadamard_matrix`` with
    the requested ``dtype`` and ``device``.
    """
    hadamard = deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
    scale = group_size**-0.5
    return hadamard * scale
def _quant_weight_mxfp4(
    b: torch.Tensor, forward_hadamard_matrix: torch.Tensor, device: str
):
    """Quantize the weight tensor ``b`` and lay out its scales for the GEMM.

    ``fusedQuantizeMx`` is invoked with ``method="abs_max"`` and returns the
    quantized values together with their scales; the scales are then passed
    through ``to_blocked`` using the triton backend.

    ``device`` is part of the signature but is not used by this function.

    Returns:
        A ``(quantized_weight, blocked_scales)`` tuple.
    """
    quantized, scales = fusedQuantizeMx(b, forward_hadamard_matrix, method="abs_max")
    blocked_scales = to_blocked(scales, backend="triton")
    return quantized, blocked_scales
def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device):
    """Build a zero-argument callable that performs one MXFP4 GEMM.

    The weight ``b`` is always quantized once here, outside the returned
    callable. The activation ``a`` is handled according to ``cfg``:

    * ``cfg["no_a_quant"]`` truthy: ``a`` is quantized once up front, so the
      callable only runs ``matmul_mxf4_bf16_tn``.
    * otherwise: ``a`` is re-quantized on every call, so the callable covers
      quantization plus the matmul.

    Args:
        cfg: Provider configuration; must contain the key ``"no_a_quant"``.
        a: Activation tensor.
        b: Weight tensor.
        forward_hadamard_matrix: Matrix forwarded to ``fusedQuantizeMx``.
        dtype: Unused; kept for signature compatibility.
        device: Device on which the ``alpha`` scalar tensor is allocated; it is
            also forwarded to ``_quant_weight_mxfp4``.

    Returns:
        A callable taking no arguments that returns the GEMM result.
    """
    weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4(
        b, forward_hadamard_matrix, device
    )
    # Allocate alpha on the caller-supplied device instead of a hard-coded
    # "cuda", so it is placed wherever the caller asked for.
    alpha = torch.tensor([1.0], device=device)

    def quantize_activation():
        # Shared by both variants so the two paths cannot drift apart.
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
            a, forward_hadamard_matrix, method="abs_max"
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")
        return input_hf_e2m1, input_hf_scale_block

    if cfg["no_a_quant"]:
        # Pre-quantize activation
        input_hf_e2m1, input_hf_scale_block = quantize_activation()

        def run():
            return matmul_mxf4_bf16_tn(
                input_hf_e2m1,
                weight_hf_e2m1,
                input_hf_scale_block,
                weight_hf_scale_block,
                alpha,
            )

        return run

    # Quantize activation on-the-fly
    def run():
        input_hf_e2m1, input_hf_scale_block = quantize_activation()
        return matmul_mxf4_bf16_tn(
            input_hf_e2m1,
            weight_hf_e2m1,
            input_hf_scale_block,
            weight_hf_scale_block,
            alpha,
        )

    return run
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[
            1,
            4,
            8,
            16,
            32,
            64,
            128,
            256,
            512,
            1024,
            2048,
            4096,
            8192,
            16384,
            24576,
            32768,
        ],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs MXFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, had_size):
    """Time one (batch_size, N, K) GEMM for ``provider`` and report TFLOP/s.

    Random bfloat16 inputs of shape ``(batch_size, K)`` and ``(N, K)`` are
    created on "cuda". The "torch-bf16" provider times
    ``torch.nn.functional.linear``; any other provider is looked up in
    ``PROVIDER_CFGS`` and timed through ``build_mxfp4_runner``.

    Returns:
        Three TFLOP/s values derived from the 0.5, 0.8 and 0.2 timing
        quantiles, in that order. The last two are swapped relative to the
        timings because a longer time means a lower throughput.
    """
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)
    forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)

    if provider == "torch-bf16":

        def bench_fn():
            return torch.nn.functional.linear(a, b)

    else:
        bench_fn = build_mxfp4_runner(
            PROVIDER_CFGS[provider], a, b, forward_hadamard_matrix, dtype, device
        )

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        bench_fn, rep=200, quantiles=[0.5, 0.2, 0.8]
    )

    def to_tflops(t_ms):
        # 2*M*N*K floating-point operations, time given in milliseconds.
        return (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)

    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
def prepare_shapes(args):
    """Expand the selected models and TP sizes into a flat list of GEMM shapes.

    For every ``(model, tp_size)`` combination of ``args.models`` and
    ``args.tp_sizes``, each ``(KN, tp_dim)`` entry of ``WEIGHT_SHAPES[model]``
    is copied, its ``tp_dim``-th dimension is floor-divided by ``tp_size``, and
    the model name is appended.

    Returns:
        A list of lists, each holding the (possibly sharded) shape values
        followed by the model name.
    """
    shapes = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        # Deep-copy so the shared WEIGHT_SHAPES table is never mutated.
        for kn, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            kn[tp_dim] = kn[tp_dim] // tp_size
            shapes.append([*kn, model])
    return shapes
if __name__ == "__main__":
    # Command-line entry point: benchmark every selected (model, TP size)
    # weight shape for each Hadamard size.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.3-70B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        for had_size in [32, 64, 128]:
            print(
                f"{model}, N={N} K={K}, HAD={had_size}, "
                f"BF16 vs MXFP4 GEMMs TFLOP/s:"
            )
            benchmark.run(
                print_data=True,
                show_plots=True,
                # Include had_size in the output directory: the plot name is
                # fixed, so without it every Hadamard size would overwrite the
                # results of the previous one for the same (N, K).
                save_path=f"bench_mxfp4_res_n{N}_k{K}_had{had_size}",
                N=N,
                K=K,
                had_size=had_size,
            )

    print("Benchmark finished!")
benchmarks/kernels/bench_nvfp4_qutlass.py
0 → 100644
View file @
41199996
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import
argparse
import
copy
import
itertools
import
torch
from
compressed_tensors.transform.utils.hadamard
import
deterministic_hadamard_matrix
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
# use existing nvfp4 gemm in vllm
from
vllm._custom_ops
import
fusedQuantizeNv
from
vllm.model_executor.layers.quantization.qutlass_utils
import
to_blocked
from
vllm.triton_utils
import
triton
# Benchmark providers: the BF16 torch baseline plus two NVFP4 variants.
# "no_a_quant" selects whether the activation is quantized once up front
# (True) or inside the timed callable (False); "enabled" toggles a provider.
PROVIDER_CFGS = {
    "torch-bf16": {"enabled": True},
    "nvfp4": {"no_a_quant": False, "enabled": True},
    "nvfp4-noquant": {"no_a_quant": True, "enabled": True},
}

# Names of the providers that are switched on, in declaration order.
_enabled = [name for name, cfg in PROVIDER_CFGS.items() if cfg["enabled"]]
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
    """Return a Hadamard matrix of order ``group_size`` scaled by ``group_size**-0.5``.

    The unscaled matrix is produced by ``deterministic_hadamard_matrix`` with
    the requested ``dtype`` and ``device``.
    """
    unscaled = deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
    normalizer = group_size**-0.5
    return unscaled * normalizer
def _quant_weight_nvfp4(
    b: torch.Tensor,
    forward_hadamard_matrix: torch.Tensor,
    global_scale: torch.Tensor,
    device: str,
    M: int,
    N: int,
    K: int,
):
    """Quantize the weight tensor ``b`` and lay out its scales for the GEMM.

    ``fusedQuantizeNv`` returns the quantized values together with their
    scales; the scales are passed through ``to_blocked`` (triton backend) and
    viewed as a 2-D tensor with ``K // 16`` columns.

    ``device``, ``M`` and ``N`` are part of the signature but are not used by
    this function.

    Returns:
        A ``(quantized_weight, blocked_scales)`` tuple.
    """
    quantized, scales = fusedQuantizeNv(b, forward_hadamard_matrix, global_scale)
    blocked = to_blocked(scales, backend="triton")
    blocked_scales = blocked.view(-1, K // 16)
    return quantized, blocked_scales
def build_nvfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K):
    """Build a zero-argument callable that performs one NVFP4 GEMM.

    The weight ``b`` is always quantized once here, outside the returned
    callable. The activation ``a`` is handled according to ``cfg``:

    * ``cfg["no_a_quant"]`` truthy: ``a`` is quantized once up front, so the
      callable only runs ``ops.cutlass_scaled_fp4_mm``.
    * otherwise: ``a`` is re-quantized on every call, so the callable covers
      quantization plus the matmul.

    Args:
        cfg: Provider configuration; must contain the key ``"no_a_quant"``.
        a: Activation tensor.
        b: Weight tensor.
        forward_hadamard_matrix: Matrix forwarded to ``fusedQuantizeNv``.
        dtype: Unused; kept for signature compatibility.
        device: Device on which the ``alpha`` and ``global_scale`` scalar
            tensors are allocated; also forwarded to ``_quant_weight_nvfp4``.
        M, N: Forwarded to ``_quant_weight_nvfp4``.
        K: Used to reshape the blocked scales to ``K // 16`` columns.

    Returns:
        A callable taking no arguments that returns the GEMM result, requested
        as ``torch.bfloat16``.
    """
    # Allocate the scalar tensors on the caller-supplied device instead of a
    # hard-coded "cuda", so they are placed wherever the caller asked for.
    alpha = torch.tensor([1.0], device=device)
    global_scale = torch.tensor([1.0], device=device)
    weight_hf_e2m1, weight_hf_scale_block = _quant_weight_nvfp4(
        b, forward_hadamard_matrix, global_scale, device, M, N, K
    )

    def quantize_activation():
        # Shared by both variants so the two paths cannot drift apart.
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(
            a, forward_hadamard_matrix, global_scale
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view(
            -1, K // 16
        )
        return input_hf_e2m1, input_hf_scale_block

    if cfg["no_a_quant"]:
        # Pre-quantize activation
        input_hf_e2m1, input_hf_scale_block = quantize_activation()

        def run():
            return ops.cutlass_scaled_fp4_mm(
                input_hf_e2m1,
                weight_hf_e2m1,
                input_hf_scale_block,
                weight_hf_scale_block,
                alpha,
                torch.bfloat16,
            )

        return run

    # Quantize activation on-the-fly
    def run():
        input_hf_e2m1, input_hf_scale_block = quantize_activation()
        return ops.cutlass_scaled_fp4_mm(
            input_hf_e2m1,
            weight_hf_e2m1,
            input_hf_scale_block,
            weight_hf_scale_block,
            alpha,
            torch.bfloat16,
        )

    return run
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[
            1,
            4,
            8,
            16,
            32,
            64,
            128,
            256,
            512,
            1024,
            2048,
            4096,
            8192,
            16384,
            24576,
            32768,
        ],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs NVFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, had_size):
    """Time one (batch_size, N, K) GEMM for ``provider`` and report TFLOP/s.

    Random bfloat16 inputs of shape ``(batch_size, K)`` and ``(N, K)`` are
    created on "cuda". The "torch-bf16" provider times
    ``torch.nn.functional.linear``; any other provider is looked up in
    ``PROVIDER_CFGS`` and timed through ``build_nvfp4_runner``.

    Returns:
        Three TFLOP/s values derived from the 0.5, 0.8 and 0.2 timing
        quantiles, in that order. The last two are swapped relative to the
        timings because a longer time means a lower throughput.
    """
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)
    forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)

    if provider == "torch-bf16":

        def timed_fn():
            return torch.nn.functional.linear(a, b)

    else:
        timed_fn = build_nvfp4_runner(
            PROVIDER_CFGS[provider],
            a,
            b,
            forward_hadamard_matrix,
            dtype,
            device,
            M,
            N,
            K,
        )

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        timed_fn, rep=200, quantiles=[0.5, 0.2, 0.8]
    )

    def to_tflops(t_ms):
        # 2*M*N*K floating-point operations, time given in milliseconds.
        return (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)

    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
def prepare_shapes(args):
    """Expand the selected models and TP sizes into a flat list of GEMM shapes.

    For every ``(model, tp_size)`` combination of ``args.models`` and
    ``args.tp_sizes``, each ``(KN, tp_dim)`` entry of ``WEIGHT_SHAPES[model]``
    is copied, its ``tp_dim``-th dimension is floor-divided by ``tp_size``, and
    the model name is appended.

    Returns:
        A list of lists, each holding the (possibly sharded) shape values
        followed by the model name.
    """
    result = []
    combos = itertools.product(args.models, args.tp_sizes)
    for model, tp_size in combos:
        # Deep-copy so the shared WEIGHT_SHAPES table is never mutated.
        for dims, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            dims[tp_dim] = dims[tp_dim] // tp_size
            result.append([*dims, model])
    return result
if __name__ == "__main__":
    # Command-line entry point: benchmark every selected (model, TP size)
    # weight shape for each Hadamard size.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.3-70B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        for had_size in [16, 32, 64, 128]:
            print(
                f"{model}, N={N} K={K}, HAD={had_size}, "
                f"BF16 vs NVFP4 GEMMs TFLOP/s:"
            )
            benchmark.run(
                print_data=True,
                show_plots=True,
                # Include had_size in the output directory: the plot name is
                # fixed, so without it every Hadamard size would overwrite the
                # results of the previous one for the same (N, K).
                save_path=f"bench_nvfp4_res_n{N}_k{K}_had{had_size}",
                N=N,
                K=K,
                had_size=had_size,
            )

    print("Benchmark finished!")
Prev
1
2
3
4
5
6
7
8
9
…
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment