Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a99300bd
Commit
a99300bd
authored
Sep 09, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.2rc1' into v0.10.2rc1-dev
parents
cc3e01c7
5438967f
Changes
512
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1571 additions
and
170 deletions
+1571
-170
tests/evals/gsm8k/conftest.py
tests/evals/gsm8k/conftest.py
+66
-0
tests/evals/gsm8k/gsm8k_eval.py
tests/evals/gsm8k/gsm8k_eval.py
+252
-0
tests/evals/gsm8k/test_gsm8k_correctness.py
tests/evals/gsm8k/test_gsm8k_correctness.py
+90
-0
tests/kernels/attention/test_cache.py
tests/kernels/attention/test_cache.py
+93
-5
tests/kernels/attention/test_flashinfer_trtllm_attention.py
tests/kernels/attention/test_flashinfer_trtllm_attention.py
+233
-129
tests/kernels/attention/test_flashmla.py
tests/kernels/attention/test_flashmla.py
+51
-18
tests/kernels/attention/untest_attention_selector.py
tests/kernels/attention/untest_attention_selector.py
+3
-0
tests/kernels/attention/untest_flashinfer.py
tests/kernels/attention/untest_flashinfer.py
+2
-4
tests/kernels/mamba/untest_mamba_ssm_ssd.py
tests/kernels/mamba/untest_mamba_ssm_ssd.py
+0
-0
tests/kernels/moe/test_deepep_deepgemm_moe.py
tests/kernels/moe/test_deepep_deepgemm_moe.py
+6
-4
tests/kernels/moe/test_deepep_moe.py
tests/kernels/moe/test_deepep_moe.py
+3
-0
tests/kernels/moe/test_flashinfer.py
tests/kernels/moe/test_flashinfer.py
+248
-0
tests/kernels/moe/test_grouped_topk.py
tests/kernels/moe/test_grouped_topk.py
+76
-0
tests/kernels/moe/test_modular_kernel_combinations.py
tests/kernels/moe/test_modular_kernel_combinations.py
+2
-0
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+1
-1
tests/kernels/moe/test_mxfp4_moe.py
tests/kernels/moe/test_mxfp4_moe.py
+419
-1
tests/kernels/moe/test_pplx_moe.py
tests/kernels/moe/test_pplx_moe.py
+5
-0
tests/kernels/moe/untest_block_fp8.py
tests/kernels/moe/untest_block_fp8.py
+2
-3
tests/kernels/moe/untest_cutlass_moe.py
tests/kernels/moe/untest_cutlass_moe.py
+14
-4
tests/kernels/moe/untest_moe_permute_unpermute.py
tests/kernels/moe/untest_moe_permute_unpermute.py
+5
-1
No files found.
Too many changes to show.
To preserve performance only
512 of 512+
files are displayed.
Plain diff
Email patch
tests/evals/gsm8k/conftest.py
0 → 100644
View file @
a99300bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
pathlib
import
Path
def pytest_addoption(parser):
    """Register the command line options used by the GSM8K evaluation tests.

    Adds ``--config-list-file`` (default ``configs/models-small.txt``) and
    ``--tp-size`` (integer, default ``1``) to the given pytest option parser.
    """
    option_specs = (
        ("--config-list-file", {
            "default": "configs/models-small.txt",
            "help": "File containing list of config files to test",
        }),
        ("--tp-size", {
            "default": 1,
            "type": int,
            "help": "Tensor parallel size",
        }),
    )
    for flag, settings in option_specs:
        parser.addoption(flag, **settings)
def pytest_generate_tests(metafunc):
    """Parametrize tests that request ``config_filename`` from a list file.

    The list file named by ``--config-list-file`` is read line by line; blank
    lines and lines starting with ``#`` are ignored. Every remaining entry is
    resolved against the directory of the list file, and each entry that
    exists becomes one ``(config_filename, tp_size)`` parameter set with the
    id ``<stem>-tp<tp_size>``. Progress is reported with ``print``.
    """
    if "config_filename" not in metafunc.fixturenames:
        return

    get_option = metafunc.config.getoption
    config_list_file = get_option("--config-list-file")
    tp_size = get_option("--tp-size")

    # A relative list file is looked up next to this module first and
    # falls back to the current working directory.
    list_path = Path(config_list_file)
    if not list_path.is_absolute():
        beside_tests = Path(__file__).parent / config_list_file
        if beside_tests.exists():
            list_path = beside_tests
        else:
            list_path = Path.cwd() / config_list_file

    print(f"Looking for config list at: {list_path}")

    found = []
    if not list_path.exists():
        print(f"Config list file not found: {list_path}")
    else:
        # Entries are interpreted relative to the list file's own directory.
        config_dir = list_path.parent
        with open(list_path) as listing:
            for raw_line in listing:
                entry = raw_line.strip()
                if not entry or entry.startswith("#"):
                    continue
                candidate = config_dir / entry
                print(f"Checking config file: {candidate}")
                if candidate.exists():
                    found.append(candidate)
                    print(f" ✓ Found: {candidate}")
                else:
                    print(f" ✗ Missing: {candidate}")

    if not found:
        print("No config files found, test will be skipped")
        return

    metafunc.parametrize(
        ["config_filename", "tp_size"],
        [(path, int(tp_size)) for path in found],
        ids=[f"{path.stem}-tp{tp_size}" for path in found],
    )
tests/evals/gsm8k/gsm8k_eval.py
0 → 100644
View file @
a99300bd
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Isolated GSM8K evaluation script for vLLM serve endpoint.
"""
import
argparse
import
ast
import
asyncio
import
json
import
os
import
time
from
collections.abc
import
Generator
from
typing
import
Optional
,
Union
import
aiohttp
import
numpy
as
np
import
regex
as
re
import
requests
from
tqdm.asyncio
import
tqdm
# Sentinel returned by get_answer_value when no usable number is found.
INVALID = -9_999_999
def download_and_cache_file(url: str, filename: Optional[str] = None) -> str:
    """Download ``url`` to a local file, reusing an existing copy if present.

    Args:
        url: Location of the file to fetch.
        filename: Destination path. When ``None``, the last ``/``-separated
            segment of ``url`` is stored under ``/tmp``.

    Returns:
        The path of the local file.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server stops responding.
    """
    if filename is None:
        filename = os.path.join("/tmp", url.split("/")[-1])

    # Any file already at the destination is treated as a valid cache entry.
    if os.path.exists(filename):
        return filename

    print(f"Downloading from {url} to {filename}")

    # Stream into a sibling temporary file and rename it only once complete,
    # so an interrupted download never leaves a truncated file that a later
    # call would mistake for a finished cache entry.
    partial = f"{filename}.{os.getpid()}.part"
    try:
        # The context manager releases the streamed connection; the timeout
        # keeps a stalled server from hanging the evaluation forever.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(partial, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024):
                    f.write(chunk)
        os.replace(partial, filename)
    finally:
        # Still present only when the download did not complete.
        if os.path.exists(partial):
            os.remove(partial)

    return filename
def load_gsm8k_data() -> tuple[list[dict], list[dict]]:
    """Fetch (or reuse cached copies of) the GSM8K splits and parse them.

    Returns:
        A ``(train_data, test_data)`` pair, each a list of the records read
        from the corresponding JSONL file.
    """
    split_urls = (
        "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl",
        "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl",
    )

    # Both files are obtained before either one is parsed.
    local_paths = [download_and_cache_file(split_url) for split_url in split_urls]
    train_data, test_data = (list(read_jsonl(path)) for path in local_paths)
    return train_data, test_data
def read_jsonl(filename: str) -> Generator[dict, None, None]:
    """Lazily yield the JSON records stored in a JSONL file.

    Lines beginning with ``#`` and lines that contain only whitespace are
    skipped; every other line is parsed with ``json.loads``.

    Args:
        filename: Path of the JSONL file, read as UTF-8.

    Yields:
        One parsed object per non-empty, non-comment line.

    Raises:
        json.JSONDecodeError: If a remaining line is not valid JSON.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(filename, encoding="utf-8") as fin:
        for line in fin:
            # Blank lines (e.g. trailing padding) would make json.loads fail.
            if not line.strip() or line.startswith("#"):
                continue
            yield json.loads(line)
def get_answer_value(answer_str: str) -> int:
    """Return the last run of digits in ``answer_str`` as a number.

    Commas are removed first, so ``"1,234"`` is read as ``1234``. ``INVALID``
    is returned when the text has no digits or when the last digit run cannot
    be evaluated as a Python literal.
    """
    digit_runs = re.findall(r"\d+", answer_str.replace(",", ""))
    if not digit_runs:
        return INVALID

    final_run = digit_runs[-1]
    try:
        return ast.literal_eval(final_run)
    except SyntaxError:
        return INVALID
async def call_vllm_api(session: aiohttp.ClientSession,
                        prompt: str,
                        temperature: float,
                        max_tokens: int,
                        stop: Optional[list[str]] = None,
                        url: Optional[str] = None,
                        seed: Optional[int] = None) -> str:
    """POST a completion request to ``{url}/v1/completions``.

    Args:
        session: Client session used to send the request.
        prompt: Prompt text sent as ``"prompt"``.
        temperature: Value sent as ``"temperature"``.
        max_tokens: Value sent as ``"max_tokens"``.
        stop: Value sent as ``"stop"``.
        url: Base URL that ``/v1/completions`` is appended to.
        seed: Sent as ``"seed"`` only when it is not ``None``.

    Returns:
        The ``text`` of the first returned choice, or ``""`` when anything
        goes wrong; the error is printed rather than raised.
    """
    payload = {
        "prompt": prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stop": stop,
    }
    if seed is not None:
        payload["seed"] = seed

    try:
        async with session.post(f"{url}/v1/completions",
                                json=payload) as response:
            response.raise_for_status()
            body = await response.json()
            return body["choices"][0]["text"]
    except Exception as e:
        # Best effort: a failed request is reported and counted as an empty
        # answer instead of aborting the whole run.
        print(f"Error calling vLLM API: {e}")
        return ""
def evaluate_gsm8k(num_questions: int = 1319,
                   num_shots: int = 5,
                   max_tokens: int = 256,
                   host: str = "http://127.0.0.1",
                   port: int = 8000,
                   temperature: float = 0.0,
                   seed: Optional[int] = 42) -> dict[str, Union[float, int]]:
    """Measure GSM8K accuracy against a server at ``{host}:{port}``.

    The first ``num_shots`` training records form the few-shot prefix, and the
    first ``num_questions`` test records (capped at the size of the test
    split) are sent concurrently through ``call_vllm_api``.

    Returns:
        A dict with the keys ``accuracy``, ``invalid_rate``, ``latency``,
        ``questions_per_second``, ``num_questions``, ``num_shots``,
        ``max_tokens`` and ``timestamp``.
    """
    base_url = f"{host}:{port}"

    train_data, test_data = load_gsm8k_data()

    # Never ask for more questions than the test split holds.
    num_questions = min(num_questions, len(test_data))

    # Few-shot prefix taken from the head of the train split.
    few_shot_examples = "".join(
        f"Question: {train_data[shot]['question']}\n"
        f"Answer: {train_data[shot]['answer']}\n\n"
        for shot in range(num_shots))

    question_ids = range(num_questions)
    questions = [
        f"Question: {test_data[idx]['question']}\nAnswer:"
        for idx in question_ids
    ]
    labels = [get_answer_value(test_data[idx]["answer"]) for idx in question_ids]
    assert INVALID not in labels, "Some labels are invalid"

    async def _collect_responses():
        # One slot per question, filled in as the individual requests finish.
        states: list[str] = [""] * num_questions

        async def _ask(session: aiohttp.ClientSession, i: int) -> str:
            answer = await call_vllm_api(
                session=session,
                prompt=few_shot_examples + questions[i],
                temperature=temperature,
                max_tokens=max_tokens,
                stop=["Question", "Assistant:", "<|separator|>"],
                url=base_url,
                seed=seed,
            )
            states[i] = answer
            return answer

        client_timeout = aiohttp.ClientTimeout(total=600)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            pending = [_ask(session, i) for i in range(num_questions)]
            await tqdm.gather(*pending, desc="Evaluating")

        return states

    print(f"Running GSM8K evaluation: {num_questions} questions, "
          f"{num_shots}-shot")

    started = time.perf_counter()
    states = asyncio.run(_collect_responses())
    latency = time.perf_counter() - started

    # Score the extracted numbers against the reference labels.
    pred_array = np.array([get_answer_value(state) for state in states])
    accuracy = np.mean(pred_array == np.array(labels))
    invalid_rate = np.mean(pred_array == INVALID)

    return {
        "accuracy": accuracy,
        "invalid_rate": invalid_rate,
        "latency": latency,
        "questions_per_second": num_questions / latency,
        "num_questions": num_questions,
        "num_shots": num_shots,
        "max_tokens": max_tokens,
        "timestamp": time.time(),
    }
def main() -> None:
    """Parse the command line, run the evaluation and report the metrics.

    The metrics are printed to the terminal and, when ``--save-results`` is
    given, also written to that path as JSON.
    """
    parser = argparse.ArgumentParser(
        description="GSM8K evaluation for vLLM serve")

    # (flag, type, default, help) for every option that has a default.
    defaulted_options = (
        ("--num-shots", int, 5, "Number of few-shot examples"),
        ("--num-questions", int, 1319, "Number of questions to evaluate"),
        ("--max-tokens", int, 256, "Max tokens for generation"),
        ("--host", str, "http://127.0.0.1", "Host URL"),
        ("--port", int, 8000, "Port number"),
        ("--temperature", float, 0.0, "Temperature for generation"),
        ("--seed", int, 42, "Random seed for reproducibility"),
    )
    for flag, value_type, default, help_text in defaulted_options:
        parser.add_argument(flag,
                            type=value_type,
                            default=default,
                            help=help_text)
    parser.add_argument("--save-results",
                        type=str,
                        help="Save results to JSON file")

    args = parser.parse_args()

    result = evaluate_gsm8k(
        num_questions=args.num_questions,
        num_shots=args.num_shots,
        max_tokens=args.max_tokens,
        host=args.host,
        port=args.port,
        temperature=args.temperature,
        seed=args.seed,
    )

    # Terminal summary.
    print("\nResults:")
    print(f"Accuracy: {result['accuracy']:.3f}")
    print(f"Invalid responses: {result['invalid_rate']:.3f}")
    print(f"Total latency: {result['latency']:.3f} s")
    print(f"Questions per second: {result['questions_per_second']:.3f}")

    # Optional JSON dump of the full result dict.
    output_path = args.save_results
    if output_path:
        with open(output_path, "w") as f:
            json.dump(result, f, indent=2)
        print(f"Results saved to {output_path}")


if __name__ == "__main__":
    main()
tests/evals/gsm8k/test_gsm8k_correctness.py
0 → 100644
View file @
a99300bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GSM8K evaluation using vLLM server and isolated GSM8K script.
Replacement for lm-eval-harness with better performance and control.
Usage:
pytest -s -v test_gsm8k_correctness.py
\
--config-list-file=configs/models-small.txt
\
--tp-size=1
"""
import
yaml
from
tests.utils
import
RemoteOpenAIServer
from
.gsm8k_eval
import
evaluate_gsm8k
RTOL
=
0.08
# Tolerance for accuracy comparison (subtracted from the expected accuracy as an absolute margin)
def launch_gsm8k_eval(eval_config, server_url, tp_size):
    """Run the GSM8K evaluation against the server at ``server_url``.

    Args:
        eval_config: Mapping that provides ``num_questions`` and
            ``num_fewshot``.
        server_url: Server address, with or without a scheme and path, e.g.
            ``http://localhost:8000/v1``. ``http`` is assumed when no scheme
            is given and port ``8000`` when no port is given.
        tp_size: Accepted for call-site symmetry; not used here.

    Returns:
        The result dict produced by ``evaluate_gsm8k``.

    Raises:
        ValueError: If ``server_url`` contains a port that is not a number.
    """
    from urllib.parse import urlsplit

    # urlsplit only recognises the host part when a scheme is present.
    if "://" not in server_url:
        server_url = f"http://{server_url}"
    parts = urlsplit(server_url)

    # Keep the caller's scheme so an https endpoint is not downgraded.
    scheme = parts.scheme or "http"
    hostname = parts.hostname or ""
    if ":" in hostname:
        # IPv6 literals need their brackets back before a port is appended.
        hostname = f"[{hostname}]"
    host = f"{scheme}://{hostname}"
    port = parts.port if parts.port is not None else 8000

    # Run GSM8K evaluation
    results = evaluate_gsm8k(
        num_questions=eval_config["num_questions"],
        num_shots=eval_config["num_fewshot"],
        host=host,
        port=port,
    )

    return results
def test_gsm8k_correctness_param(config_filename, tp_size):
    """Serve the configured model and check its GSM8K accuracy.

    The YAML file at ``config_filename`` supplies ``model_name``,
    ``accuracy_threshold`` and the evaluation settings; ``max_model_len``
    falls back to 4096 when absent. The test fails when the measured accuracy
    is below ``accuracy_threshold - RTOL``.
    """
    eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))

    # Command line passed to the server.
    max_model_len = str(eval_config.get("max_model_len", 4096))
    server_args = [
        "--max-model-len",
        max_model_len,
        "--enforce-eager",
        "--trust-remote-code",
        "--tensor-parallel-size",
        str(tp_size),
    ]

    with RemoteOpenAIServer(eval_config["model_name"],
                            server_args,
                            max_wait_seconds=480) as remote_server:
        results = launch_gsm8k_eval(eval_config, remote_server.url_for("v1"),
                                    tp_size)

        measured_accuracy = results["accuracy"]
        expected_accuracy = eval_config["accuracy_threshold"]

        print(f"GSM8K Results for {eval_config['model_name']}:")
        print(f" Accuracy: {measured_accuracy:.3f}")
        print(f" Expected: {expected_accuracy:.3f}")
        print(f" Questions: {results['num_questions']}")
        print(f" Invalid rate: {results['invalid_rate']:.3f}")
        print(f" Latency: {results['latency']:.1f}s")
        print(f" QPS: {results['questions_per_second']:.1f}")

        # Lowest accuracy that still counts as a pass.
        accuracy_floor = expected_accuracy - RTOL
        assert measured_accuracy >= accuracy_floor, (
            f"Accuracy too low: {measured_accuracy:.3f} < "
            f"{expected_accuracy:.3f} - {RTOL:.3f}")

        print(f"✅ GSM8K test passed for {eval_config['model_name']}")
tests/kernels/attention/test_cache.py
View file @
a99300bd
...
@@ -705,6 +705,94 @@ def test_swap_blocks_mla(
...
@@ -705,6 +705,94 @@ def test_swap_blocks_mla(
f
"
{
dst
}
in dst_cache."
)
f
"
{
dst
}
in dst_cache."
)
@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
                                            block_size, num_blocks,
                                            max_seq_len, batch_size, dtype,
                                            kv_cache_dtype, device):
    """Compare ``ops.gather_and_maybe_dequant_cache`` with a manual gather.

    A reference result is assembled block by block from the source cache
    (passing each piece through ``ops.convert_fp8`` when ``kv_cache_dtype``
    is ``"fp8"``) and must match what the op writes into ``dst``.
    """
    entry_size = kv_lora_rank + qk_rope_head_dim
    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                  kv_cache_dtype, device)
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    def _as_reference(chunk):
        # Reference rows for an fp8 cache go through convert_fp8 with the
        # same scale that is handed to the op; other caches are used as is.
        if kv_cache_dtype != "fp8":
            return chunk
        converted = torch.empty_like(chunk, dtype=dtype)
        ops.convert_fp8(converted, chunk, scale.item())
        return converted

    # Random per-sequence lengths in [0, max_seq_len].
    seq_len_tensor = torch.randint(0,
                                   max_seq_len + 1, (batch_size, ),
                                   device=device)

    total_tokens = seq_len_tensor.sum()
    # Cumulative lengths with a leading zero.
    cu_seq_lens = torch.empty((batch_size + 1),
                              dtype=torch.int32,
                              device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
    print("seq_len_tensor", seq_len_tensor)

    # Number of blocks each sequence occupies (ceiling division).
    tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
    block_table = torch.empty((batch_size, num_blocks),
                              dtype=torch.int32,
                              device=device)

    # Every row of the block table is an independent random permutation.
    for row in range(batch_size):
        block_table[row, :] = torch.randperm(num_blocks, device=device)

    dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)

    expected_batches = []
    for seq_idx in range(batch_size):
        seq_len = seq_len_tensor[seq_idx]
        if seq_len == 0:
            continue
        n_blocks = tot_blocks_tensor[seq_idx]
        blocks = block_table[seq_idx, :n_blocks].tolist()

        # All blocks but the last contribute every one of their rows.
        rows = [
            _as_reference(src_cache[blocks[j]]) for j in range(n_blocks - 1)
        ]

        # The last block only contributes the rows that are left over.
        tail_len = seq_len - (n_blocks - 1) * block_size
        rows.append(_as_reference(src_cache[blocks[-1], :tail_len, :]))

        expected_batches.append(torch.cat(rows, dim=0))
    expected = torch.cat(expected_batches, dim=0)

    op_args = (src_cache, dst, block_table, cu_seq_lens, batch_size,
               kv_cache_dtype, scale, None)
    opcheck(
        torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
        op_args,
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
                                       cu_seq_lens, batch_size, kv_cache_dtype,
                                       scale, None)
    torch.testing.assert_close(dst, expected)
@
pytest
.
mark
.
parametrize
(
"kv_lora_rank"
,
[
512
])
@
pytest
.
mark
.
parametrize
(
"kv_lora_rank"
,
[
512
])
@
pytest
.
mark
.
parametrize
(
"qk_rope_head_dim"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"qk_rope_head_dim"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
...
@@ -716,9 +804,9 @@ def test_swap_blocks_mla(
...
@@ -716,9 +804,9 @@ def test_swap_blocks_mla(
[
"auto"
])
# You can also test "fp8" if needed.
[
"auto"
])
# You can also test "fp8" if needed.
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_gather_cache_mla
(
kv_lora_rank
,
qk_rope_head_dim
,
block_size
,
def
test_
cp_
gather_cache_mla
(
kv_lora_rank
,
qk_rope_head_dim
,
block_size
,
num_blocks
,
max_seq_len
,
batch_size
,
dtype
,
num_blocks
,
max_seq_len
,
batch_size
,
dtype
,
kv_cache_dtype
,
device
):
kv_cache_dtype
,
device
):
entry_size
=
kv_lora_rank
+
qk_rope_head_dim
entry_size
=
kv_lora_rank
+
qk_rope_head_dim
src_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
src_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
)
kv_cache_dtype
,
device
)
...
@@ -768,12 +856,12 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
...
@@ -768,12 +856,12 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
expected
=
torch
.
cat
(
expected_batches
,
dim
=
0
)
expected
=
torch
.
cat
(
expected_batches
,
dim
=
0
)
opcheck
(
opcheck
(
torch
.
ops
.
_C_cache_ops
.
gather_cache
,
torch
.
ops
.
_C_cache_ops
.
cp_
gather_cache
,
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
,
None
),
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
,
None
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
)
)
ops
.
gather_cache
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
)
ops
.
cp_
gather_cache
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
)
torch
.
testing
.
assert_close
(
dst
,
expected
)
torch
.
testing
.
assert_close
(
dst
,
expected
)
...
...
tests/kernels/attention/test_flashinfer_trtllm_attention.py
View file @
a99300bd
...
@@ -6,28 +6,19 @@ import flashinfer
...
@@ -6,28 +6,19 @@ import flashinfer
import
pytest
import
pytest
import
torch
import
torch
from
tests.kernels.quantization.nvfp4_utils
import
(
FLOAT4_E2M1_MAX
,
FLOAT8_E4M3_MAX
,
dequantize_nvfp4_to_dtype
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
round_up
if
not
current_platform
.
is_device_capability
(
100
):
if
not
current_platform
.
is_device_capability
(
100
):
pytest
.
skip
(
"This TRTLLM kernel requires NVIDIA Blackwell."
,
pytest
.
skip
(
"This TRTLLM kernel requires NVIDIA Blackwell."
,
allow_module_level
=
True
)
allow_module_level
=
True
)
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
# KV Cache Layout for TRT-LLM
FP4_DTYPE
=
torch
.
uint8
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
MAX_Q_LEN
=
1024
MAX_KV_LEN
=
4096
BATCH_SIZES
=
[
4
,
12
]
NUM_HEADS
=
[(
16
,
16
),
(
40
,
8
)]
HEAD_SIZES
=
[
128
]
BLOCK_SIZES
=
[
16
]
KV_LAYOUTS
=
[
"HND"
]
DTYPES
=
[
torch
.
bfloat16
]
KV_CACHE_DTYPES
=
[
None
,
current_platform
.
fp8_dtype
()]
NUM_BLOCKS
=
32768
# Large enough to test overflow in index calculation.
SOFT_CAPS
=
[
None
,
50.0
]
def
to_float8
(
x
,
dtype
=
torch
.
float8_e4m3fn
):
def
to_float8
(
x
,
dtype
=
torch
.
float8_e4m3fn
):
...
@@ -39,42 +30,61 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
...
@@ -39,42 +30,61 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
return
x_scl_sat
.
to
(
dtype
),
scale
.
float
().
reciprocal
()
return
x_scl_sat
.
to
(
dtype
),
scale
.
float
().
reciprocal
()
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
DTYPE
=
[
torch
.
bfloat16
]
QUANT_DTYPES
=
[
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(
None
,
None
,
None
),
(
None
,
FP8_DTYPE
,
None
),
(
FP8_DTYPE
,
FP8_DTYPE
,
FP8_DTYPE
),
(
FP8_DTYPE
,
FP8_DTYPE
,
FP4_DTYPE
),
]
BATCH_SIZE
=
[
4
,
12
]
MAX_SEQ_LENS
=
[(
1024
,
4096
)]
NUM_HEADS
=
[(
64
,
8
),
(
40
,
8
)]
HEAD_SIZE
=
[
128
]
KV_LAYOUT
=
[
"HND"
]
# currently only HND is supported
BLOCK_SIZE
=
[
16
]
SOFT_CAP
=
[
None
,
50.0
]
NUM_BLOCKS
=
32768
# Large enough to test overflow in index calculation.
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPE
)
@
pytest
.
mark
.
parametrize
(
"quant_dtypes"
,
QUANT_DTYPES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZE
)
@
pytest
.
mark
.
parametrize
(
"max_seq_lens"
,
MAX_SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZE
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"kv_layout"
,
KV_LAYOUT
)
@
pytest
.
mark
.
parametrize
(
"kv_layout"
,
KV_LAYOUTS
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZE
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
SOFT_CAP
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
SOFT_CAPS
)
@
torch
.
inference_mode
@
torch
.
inference_mode
def
test_flashinfer_trtllm_decode_with_baseline
(
def
test_flashinfer_trtllm_decode_with_baseline
(
dtype
:
torch
.
dtype
,
quant_dtypes
:
tuple
[
Optional
[
torch
.
dtype
],
Optional
[
torch
.
dtype
],
Optional
[
torch
.
dtype
]],
batch_size
:
int
,
batch_size
:
int
,
max_seq_lens
:
tuple
[
int
,
int
],
num_heads
:
tuple
[
int
,
int
],
num_heads
:
tuple
[
int
,
int
],
head_size
:
int
,
head_size
:
int
,
block_size
:
int
,
kv_layout
:
str
,
kv_layout
:
str
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
kv_cache_dtype
:
Optional
[
torch
.
dtype
],
soft_cap
:
Optional
[
float
],
soft_cap
:
Optional
[
float
],
)
->
None
:
)
->
None
:
kv_cache_dtype
=
dtype
if
kv_cache_dtype
is
None
else
kv_cache_dtype
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
current_platform
.
seed_everything
(
0
)
kv_lens
=
torch
.
randint
(
1
,
MAX_KV_LEN
,
(
batch_size
,
),
dtype
=
torch
.
int32
)
q_quant_dtype
,
kv_quant_dtype
,
o_quant_dtype
=
quant_dtypes
kv_lens
[
-
1
]
=
MAX_KV_LEN
q_quant_dtype
=
q_quant_dtype
or
dtype
max_kv_len
=
torch
.
max
(
kv_lens
).
item
()
kv_quant_dtype
=
kv_quant_dtype
or
dtype
num_seqs
=
len
(
kv_lens
)
o_quant_dtype
=
o_quant_dtype
or
dtype
num_query_heads
=
num_heads
[
0
]
_
,
max_kv_len
=
max_seq_lens
num_kv_heads
=
num_heads
[
1
]
assert
num_query_heads
%
num_kv_heads
==
0
scale
=
head_size
**-
0.5
num_qo_heads
,
num_kv_heads
=
num_heads
assert
num_qo_heads
%
num_kv_heads
==
0
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
sm_scale
=
float
(
1.0
/
(
head_size
**
0.5
)
)
kv_cache_shape
=
None
kv_cache_shape
=
None
if
kv_layout
==
"NHD"
:
if
kv_layout
==
"NHD"
:
...
@@ -83,23 +93,40 @@ def test_flashinfer_trtllm_decode_with_baseline(
...
@@ -83,23 +93,40 @@ def test_flashinfer_trtllm_decode_with_baseline(
kv_cache_shape
=
(
NUM_BLOCKS
,
2
,
num_kv_heads
,
block_size
,
head_size
)
kv_cache_shape
=
(
NUM_BLOCKS
,
2
,
num_kv_heads
,
block_size
,
head_size
)
else
:
else
:
raise
ValueError
(
f
"Invalid kv_layout:
{
kv_layout
}
"
)
raise
ValueError
(
f
"Invalid kv_layout:
{
kv_layout
}
"
)
key_value_cache
=
torch
.
randn
(
kv_cache_shape
,
dtype
=
dtype
)
kv_scale
=
1.0
if
kv_cache_dtype
is
current_platform
.
fp8_dtype
():
key_value_cache
,
kv_scale
=
to_float8
(
key_value_cache
,
current_platform
.
fp8_dtype
())
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
query
=
torch
.
randn
(
batch_size
,
num_qo_heads
,
head_size
,
dtype
=
dtype
)
if
q_quant_dtype
==
FP8_DTYPE
:
query
,
q_scale
=
to_float8
(
query
)
ref_query
=
query
.
to
(
dtype
)
*
q_scale
else
:
q_scale
=
1.0
ref_query
=
query
kv_lens
=
torch
.
randint
(
1
,
max_kv_len
,
(
batch_size
,
),
dtype
=
torch
.
int32
)
kv_lens
[
-
1
]
=
max_kv_len
seq_lens
=
kv_lens
max_seq_len
=
torch
.
max
(
seq_lens
).
item
()
kv_cache
=
torch
.
randn
(
kv_cache_shape
,
dtype
=
dtype
)
if
kv_quant_dtype
==
FP8_DTYPE
:
kv_cache
,
kv_scale
=
to_float8
(
kv_cache
)
ref_kv_cache
=
kv_cache
.
to
(
dtype
)
*
kv_scale
else
:
kv_scale
=
1.0
ref_kv_cache
=
kv_cache
k_scale
=
v_scale
=
kv_scale
max_num_blocks_per_seq
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
block_tables
=
torch
.
randint
(
0
,
NUM_BLOCKS
,
NUM_BLOCKS
,
(
num_seqs
,
max_num_blocks_per_seq
),
(
batch_size
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
k_scale
=
v_scale
=
kv_scale
kv_indptr
=
[
0
]
kv_indptr
=
[
0
]
kv_indices
=
[]
kv_indices
=
[]
kv_last_page_lens
=
[]
kv_last_page_lens
=
[]
for
i
in
range
(
num_seqs
):
for
i
in
range
(
batch_size
):
seq_len
=
kv
_lens
[
i
]
seq_len
=
seq
_lens
[
i
]
assert
seq_len
>
0
assert
seq_len
>
0
num_blocks
=
(
seq_len
+
block_size
-
1
)
//
block_size
num_blocks
=
(
seq_len
+
block_size
-
1
)
//
block_size
kv_indices
.
extend
(
block_tables
[
i
,
:
num_blocks
])
kv_indices
.
extend
(
block_tables
[
i
,
:
num_blocks
])
...
@@ -112,103 +139,120 @@ def test_flashinfer_trtllm_decode_with_baseline(
...
@@ -112,103 +139,120 @@ def test_flashinfer_trtllm_decode_with_baseline(
kv_indptr
=
torch
.
tensor
(
kv_indptr
,
dtype
=
torch
.
int32
)
kv_indptr
=
torch
.
tensor
(
kv_indptr
,
dtype
=
torch
.
int32
)
kv_indices
=
torch
.
tensor
(
kv_indices
,
dtype
=
torch
.
int32
)
kv_indices
=
torch
.
tensor
(
kv_indices
,
dtype
=
torch
.
int32
)
kv_last_page_lens
=
torch
.
tensor
(
kv_last_page_lens
,
dtype
=
torch
.
int32
)
kv_last_page_lens
=
torch
.
tensor
(
kv_last_page_lens
,
dtype
=
torch
.
int32
)
workspace_buffer
=
torch
.
zeros
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
workspace_buffer
=
torch
.
zeros
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
# Baseline Decode
wrapper
=
flashinfer
.
BatchDecodeWithPagedKVCacheWrapper
(
wrapper
=
flashinfer
.
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffer
,
workspace_buffer
,
kv_layout
,
use_tensor_cores
=
True
)
kv_layout
,
use_tensor_cores
=
((
num_query_heads
//
num_kv_heads
)
>
4
))
wrapper
.
plan
(
kv_indptr
,
wrapper
.
plan
(
kv_indptr
,
kv_indices
,
kv_indices
,
kv_last_page_lens
,
kv_last_page_lens
,
num_q
uery
_heads
,
num_q
o
_heads
,
num_kv_heads
,
num_kv_heads
,
head_size
,
head_size
,
block_size
,
block_size
,
"NONE"
,
"NONE"
,
sm_scale
=
scale
,
sm_scale
=
sm_
scale
,
q_data_type
=
dtype
,
q_data_type
=
dtype
,
kv_data_type
=
kv_cache_
dtype
,
kv_data_type
=
dtype
,
logits_soft_cap
=
soft_cap
)
logits_soft_cap
=
soft_cap
)
output
=
torch
.
empty
(
query
.
shape
,
dtype
=
dtype
)
output
=
torch
.
empty
(
ref_query
.
shape
,
dtype
=
dtype
)
wrapper
.
run
(
query
,
wrapper
.
run
(
ref_query
,
ref_kv_cache
,
out
=
output
)
key_value_cache
,
o_scale
=
1.0
k_scale
=
k_scale
,
o_sf_scale
=
None
v_scale
=
v_scale
,
if
o_quant_dtype
==
FP8_DTYPE
:
out
=
output
)
_
,
o_scale
=
to_float8
(
output
)
elif
o_quant_dtype
==
FP4_DTYPE
:
o_sf_scale
=
((
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
amax
(
output
.
flatten
(),
dim
=-
1
)).
to
(
torch
.
float32
)
# TRTLLM Decode
# TRTLLM Decode
kv_lens_tensor
=
torch
.
tensor
(
kv_lens
,
dtype
=
torch
.
int32
)
if
o_quant_dtype
==
FP4_DTYPE
:
output_trtllm
=
torch
.
empty
(
query
.
shape
,
dtype
=
dtype
)
output_trtllm
=
flashinfer
.
utils
.
FP4Tensor
(
torch
.
empty
(
query
.
shape
[:
-
1
]
+
(
query
.
shape
[
-
1
]
//
2
,
),
dtype
=
torch
.
uint8
),
torch
.
empty
((
round_up
(
query
.
shape
[
0
],
128
),
round_up
(
query
.
shape
[
1
]
*
query
.
shape
[
2
]
//
16
,
4
)),
dtype
=
torch
.
float8_e4m3fn
),
)
else
:
output_trtllm
=
torch
.
empty
(
query
.
shape
,
dtype
=
o_quant_dtype
)
flashinfer
.
decode
.
trtllm_batch_decode_with_kv_cache
(
flashinfer
.
decode
.
trtllm_batch_decode_with_kv_cache
(
query
=
query
.
contiguous
()
,
query
=
query
,
kv_cache
=
k
ey_value
_cache
,
kv_cache
=
k
v
_cache
,
workspace_buffer
=
workspace_buffer
,
workspace_buffer
=
workspace_buffer
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
seq_lens
=
kv_lens_tensor
,
seq_lens
=
seq_lens
,
max_seq_len
=
max_kv_len
,
max_seq_len
=
max_seq_len
,
bmm1_scale
=
k_scale
*
scale
,
bmm1_scale
=
q_scale
*
k_scale
*
sm_scale
,
bmm2_scale
=
v_scale
,
bmm2_scale
=
v_scale
/
o_scale
,
o_sf_scale
=
o_sf_scale
,
out
=
output_trtllm
,
out
=
output_trtllm
,
)
)
if
o_quant_dtype
==
FP8_DTYPE
:
output_trtllm
=
output_trtllm
.
to
(
dtype
)
*
o_scale
elif
o_quant_dtype
==
FP4_DTYPE
:
output_trtllm
.
data
=
output_trtllm
.
data
.
reshape
(
-
1
,
query
.
shape
[
1
]
*
query
.
shape
[
2
]
//
2
)
output_trtllm
=
dequantize_nvfp4_to_dtype
(
output_trtllm
.
data
,
output_trtllm
.
scale
,
o_sf_scale
,
dtype
,
query
.
device
)
output_trtllm
=
output_trtllm
.
reshape
(
-
1
,
query
.
shape
[
1
],
query
.
shape
[
2
])
if
q_quant_dtype
==
FP8_DTYPE
and
o_quant_dtype
==
FP4_DTYPE
:
rtol
,
atol
=
3e-1
,
1e0
elif
q_quant_dtype
==
FP8_DTYPE
and
o_quant_dtype
==
FP8_DTYPE
:
rtol
,
atol
=
5e-2
,
7e-2
else
:
rtol
,
atol
=
1e-2
,
2e-2
torch
.
testing
.
assert_close
(
output
,
output_trtllm
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
torch
.
testing
.
assert_close
(
output
,
output_trtllm
,
atol
=
atol
,
rtol
=
rtol
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
output_trtllm
))
}
"
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
output_trtllm
))
}
"
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPE
)
@
pytest
.
mark
.
parametrize
(
"quant_dtypes"
,
QUANT_DTYPES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZE
)
@
pytest
.
mark
.
parametrize
(
"max_seq_lens"
,
MAX_SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZE
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"kv_layout"
,
KV_LAYOUT
)
@
pytest
.
mark
.
parametrize
(
"kv_layout"
,
KV_LAYOUTS
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZE
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
])
@
torch
.
inference_mode
@
torch
.
inference_mode
def
test_flashinfer_trtllm_prefill_with_baseline
(
def
test_flashinfer_trtllm_prefill_with_baseline
(
dtype
:
torch
.
dtype
,
quant_dtypes
:
tuple
[
Optional
[
torch
.
dtype
],
Optional
[
torch
.
dtype
],
Optional
[
torch
.
dtype
]],
batch_size
:
int
,
batch_size
:
int
,
max_seq_lens
:
tuple
[
int
,
int
],
num_heads
:
tuple
[
int
,
int
],
num_heads
:
tuple
[
int
,
int
],
head_size
:
int
,
head_size
:
int
,
block_size
:
int
,
kv_layout
:
str
,
kv_layout
:
str
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
kv_cache_dtype
:
Optional
[
torch
.
dtype
],
soft_cap
:
Optional
[
float
],
soft_cap
:
Optional
[
float
],
)
->
None
:
)
->
None
:
kv_cache_dtype
=
dtype
if
kv_cache_dtype
is
None
else
kv_cache_dtype
if
dtype
!=
kv_cache_dtype
:
pytest
.
skip
(
f
"Not supported dtype(
{
dtype
}
) with "
"kv_cache_dtype({kv_cache_dtype})"
)
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
current_platform
.
seed_everything
(
0
)
q_lens
=
torch
.
randint
(
1
,
MAX_Q_LEN
,
(
batch_size
,
),
dtype
=
torch
.
int32
)
q_quant_dtype
,
kv_quant_dtype
,
o_quant_dtype
=
quant_dtypes
q_lens
[
-
1
]
=
MAX_Q_LEN
q_quant_dtype
=
q_quant_dtype
or
dtype
max_q_len
=
torch
.
max
(
q_lens
).
item
()
kv_quant_dtype
=
kv_quant_dtype
or
dtype
q_indptr
=
torch
.
cat
([
o_quant_dtype
=
o_quant_dtype
or
dtype
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
torch
.
cumsum
(
q_lens
,
dim
=
0
,
dtype
=
torch
.
int32
),
])
kv_lens
=
torch
.
randint
(
0
,
MAX_KV_LEN
,
(
batch_size
,
),
dtype
=
torch
.
int32
)
if
q_quant_dtype
!=
kv_quant_dtype
:
kv_lens
[
-
1
]
=
MAX_KV_LEN
pytest
.
skip
(
"Skipped mixed QKV dtypes for prefill"
)
seq_lens
=
kv_lens
+
q_lens
max_q_len
,
max_kv_len
=
max_seq_lens
max_seq_len
=
torch
.
max
(
seq_lens
).
item
()
num_seqs
=
len
(
seq_lens
)
num_query_heads
=
num_heads
[
0
]
num_qo_heads
,
num_kv_heads
=
num_heads
num_kv_heads
=
num_heads
[
1
]
assert
num_qo_heads
%
num_kv_heads
==
0
assert
num_query_heads
%
num_kv_heads
==
0
scale
=
head_size
**-
0.5
sm_scale
=
float
(
1.0
/
(
head_size
**
0.5
))
query
=
torch
.
randn
(
torch
.
sum
(
q_lens
).
item
(),
num_query_heads
,
head_size
,
dtype
=
dtype
)
kv_cache_shape
=
None
kv_cache_shape
=
None
if
kv_layout
==
"NHD"
:
if
kv_layout
==
"NHD"
:
...
@@ -217,22 +261,49 @@ def test_flashinfer_trtllm_prefill_with_baseline(
...
@@ -217,22 +261,49 @@ def test_flashinfer_trtllm_prefill_with_baseline(
kv_cache_shape
=
(
NUM_BLOCKS
,
2
,
num_kv_heads
,
block_size
,
head_size
)
kv_cache_shape
=
(
NUM_BLOCKS
,
2
,
num_kv_heads
,
block_size
,
head_size
)
else
:
else
:
raise
ValueError
(
f
"Invalid kv_layout:
{
kv_layout
}
"
)
raise
ValueError
(
f
"Invalid kv_layout:
{
kv_layout
}
"
)
key_value_cache
=
torch
.
randn
(
kv_cache_shape
,
dtype
=
dtype
)
kv_scale
=
1.0
q_lens
=
torch
.
randint
(
1
,
max_q_len
,
(
batch_size
,
),
dtype
=
torch
.
int32
)
if
kv_cache_dtype
is
current_platform
.
fp8_dtype
():
q_lens
[
-
1
]
=
max_q_len
key_value_cache
,
kv_scale
=
to_float8
(
key_value_cache
,
q_indptr
=
torch
.
cat
([
current_platform
.
fp8_dtype
())
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
torch
.
cumsum
(
q_lens
,
dim
=
0
,
dtype
=
torch
.
int32
),
])
query
=
torch
.
randn
(
torch
.
sum
(
q_lens
).
item
(),
num_qo_heads
,
head_size
,
dtype
=
dtype
)
if
q_quant_dtype
==
FP8_DTYPE
:
query
,
q_scale
=
to_float8
(
query
)
ref_query
=
query
.
to
(
dtype
)
*
q_scale
else
:
q_scale
=
1.0
ref_query
=
query
kv_lens
=
torch
.
randint
(
0
,
max_kv_len
,
(
batch_size
,
),
dtype
=
torch
.
int32
)
kv_lens
[
-
1
]
=
max_kv_len
seq_lens
=
kv_lens
+
q_lens
max_seq_len
=
torch
.
max
(
seq_lens
).
item
()
kv_cache
=
torch
.
randn
(
kv_cache_shape
,
dtype
=
dtype
)
if
kv_quant_dtype
==
FP8_DTYPE
:
kv_cache
,
kv_scale
=
to_float8
(
kv_cache
)
ref_kv_cache
=
kv_cache
.
to
(
dtype
)
*
kv_scale
else
:
kv_scale
=
1.0
ref_kv_cache
=
kv_cache
k_scale
=
v_scale
=
kv_scale
max_num_blocks_per_seq
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
max_num_blocks_per_seq
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
block_tables
=
torch
.
randint
(
0
,
NUM_BLOCKS
,
NUM_BLOCKS
,
(
num_seqs
,
max_num_blocks_per_seq
),
(
batch_size
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
k_scale
=
v_scale
=
kv_scale
kv_indptr
=
[
0
]
kv_indptr
=
[
0
]
kv_indices
=
[]
kv_indices
=
[]
kv_last_page_lens
=
[]
kv_last_page_lens
=
[]
for
i
in
range
(
num_seqs
):
for
i
in
range
(
batch_size
):
seq_len
=
seq_lens
[
i
]
seq_len
=
seq_lens
[
i
]
assert
seq_len
>
0
assert
seq_len
>
0
num_blocks
=
(
seq_len
+
block_size
-
1
)
//
block_size
num_blocks
=
(
seq_len
+
block_size
-
1
)
//
block_size
...
@@ -246,48 +317,81 @@ def test_flashinfer_trtllm_prefill_with_baseline(
...
@@ -246,48 +317,81 @@ def test_flashinfer_trtllm_prefill_with_baseline(
kv_indptr
=
torch
.
tensor
(
kv_indptr
,
dtype
=
torch
.
int32
)
kv_indptr
=
torch
.
tensor
(
kv_indptr
,
dtype
=
torch
.
int32
)
kv_indices
=
torch
.
tensor
(
kv_indices
,
dtype
=
torch
.
int32
)
kv_indices
=
torch
.
tensor
(
kv_indices
,
dtype
=
torch
.
int32
)
kv_last_page_lens
=
torch
.
tensor
(
kv_last_page_lens
,
dtype
=
torch
.
int32
)
kv_last_page_lens
=
torch
.
tensor
(
kv_last_page_lens
,
dtype
=
torch
.
int32
)
workspace_buffer
=
torch
.
zeros
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
workspace_buffer
=
torch
.
zeros
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
# Baseline Prefill
wrapper
=
flashinfer
.
BatchPrefillWithPagedKVCacheWrapper
(
wrapper
=
flashinfer
.
BatchPrefillWithPagedKVCacheWrapper
(
workspace_buffer
,
kv_layout
)
workspace_buffer
,
kv_layout
)
wrapper
.
plan
(
q_indptr
,
wrapper
.
plan
(
q_indptr
,
kv_indptr
,
kv_indptr
,
kv_indices
,
kv_indices
,
kv_last_page_lens
,
kv_last_page_lens
,
num_q
uery
_heads
,
num_q
o
_heads
,
num_kv_heads
,
num_kv_heads
,
head_size
,
head_size
,
block_size
,
block_size
,
causal
=
True
,
causal
=
True
,
sm_scale
=
scale
,
sm_scale
=
sm_
scale
,
q_data_type
=
dtype
,
q_data_type
=
dtype
,
kv_data_type
=
kv_cache_
dtype
,
kv_data_type
=
dtype
,
logits_soft_cap
=
soft_cap
)
logits_soft_cap
=
soft_cap
)
output
=
torch
.
empty
(
query
.
shape
,
dtype
=
dtype
)
output
=
torch
.
empty
(
ref_query
.
shape
,
dtype
=
dtype
)
wrapper
.
run
(
query
,
wrapper
.
run
(
ref_query
,
ref_kv_cache
,
out
=
output
)
key_value_cache
,
o_scale
=
1.0
k_scale
=
k_scale
,
o_sf_scale
=
None
v_scale
=
v_scale
,
if
o_quant_dtype
==
FP8_DTYPE
:
out
=
output
)
_
,
o_scale
=
to_float8
(
output
)
elif
o_quant_dtype
==
FP4_DTYPE
:
o_sf_scale
=
((
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
amax
(
output
.
flatten
(),
dim
=-
1
)).
to
(
torch
.
float32
)
# TRTLLM Prefill
if
o_quant_dtype
==
FP4_DTYPE
:
output_trtllm
=
flashinfer
.
utils
.
FP4Tensor
(
torch
.
empty
(
query
.
shape
[:
-
1
]
+
(
query
.
shape
[
-
1
]
//
2
,
),
dtype
=
torch
.
uint8
),
torch
.
empty
((
round_up
(
query
.
shape
[
0
],
128
),
round_up
(
query
.
shape
[
1
]
*
query
.
shape
[
2
]
//
16
,
4
)),
dtype
=
torch
.
float8_e4m3fn
),
)
else
:
output_trtllm
=
torch
.
empty
(
query
.
shape
,
dtype
=
o_quant_dtype
)
# TRTLLM Decode
output_trtllm
=
torch
.
empty
(
query
.
shape
,
dtype
=
dtype
)
flashinfer
.
prefill
.
trtllm_batch_context_with_kv_cache
(
flashinfer
.
prefill
.
trtllm_batch_context_with_kv_cache
(
query
=
query
.
contiguous
()
,
query
=
query
,
kv_cache
=
k
ey_value
_cache
,
kv_cache
=
k
v
_cache
,
workspace_buffer
=
workspace_buffer
,
workspace_buffer
=
workspace_buffer
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
max_q_len
=
max_q_len
,
max_q_len
=
max_q_len
,
max_kv_len
=
max_seq_len
,
max_kv_len
=
max_seq_len
,
bmm1_scale
=
k_scale
*
scale
,
bmm1_scale
=
q_scale
*
k_scale
*
sm_
scale
,
bmm2_scale
=
v_scale
,
bmm2_scale
=
v_scale
/
o_scale
,
batch_size
=
num_seqs
,
batch_size
=
batch_size
,
cum_seq_lens_q
=
q_indptr
,
cum_seq_lens_q
=
q_indptr
,
cum_seq_lens_kv
=
kv_indptr
,
cum_seq_lens_kv
=
kv_indptr
,
o_sf_scale
=
o_sf_scale
,
out
=
output_trtllm
,
out
=
output_trtllm
,
)
)
if
o_quant_dtype
==
FP8_DTYPE
:
output_trtllm
=
output_trtllm
.
to
(
dtype
)
*
o_scale
elif
o_quant_dtype
==
FP4_DTYPE
:
output_trtllm
.
data
=
output_trtllm
.
data
.
reshape
(
-
1
,
query
.
shape
[
1
]
*
query
.
shape
[
2
]
//
2
)
output_trtllm
=
dequantize_nvfp4_to_dtype
(
output_trtllm
.
data
,
output_trtllm
.
scale
,
o_sf_scale
,
dtype
,
query
.
device
)
output_trtllm
=
output_trtllm
.
reshape
(
-
1
,
query
.
shape
[
1
],
query
.
shape
[
2
])
if
q_quant_dtype
==
FP8_DTYPE
and
o_quant_dtype
==
FP4_DTYPE
:
rtol
,
atol
=
4e-1
,
1e0
elif
q_quant_dtype
==
FP8_DTYPE
and
o_quant_dtype
==
FP8_DTYPE
:
rtol
,
atol
=
5e-2
,
7e-2
else
:
rtol
,
atol
=
1e-2
,
1e-2
torch
.
testing
.
assert_close
(
output
,
output_trtllm
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
torch
.
testing
.
assert_close
(
output
,
output_trtllm
,
atol
=
atol
,
rtol
=
rtol
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
output_trtllm
))
}
"
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
output_trtllm
))
}
"
tests/kernels/attention/test_flashmla.py
View file @
a99300bd
...
@@ -13,11 +13,17 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
...
@@ -13,11 +13,17 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
from
vllm.triton_utils
import
triton
from
vllm.triton_utils
import
triton
def
cal_diff
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
name
:
str
)
->
None
:
def
cal_diff
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
name
:
str
,
use_fp8
:
bool
=
False
)
->
None
:
x
,
y
=
x
.
double
(),
y
.
double
()
x
,
y
=
x
.
double
(),
y
.
double
()
cos_diff
=
1
-
2
*
(
x
*
y
).
sum
().
item
()
/
max
(
cos_diff
=
1
-
2
*
(
x
*
y
).
sum
().
item
()
/
max
(
(
x
*
x
+
y
*
y
).
sum
().
item
(),
1e-12
)
(
x
*
x
+
y
*
y
).
sum
().
item
(),
1e-12
)
assert
cos_diff
<
1e-4
if
(
use_fp8
):
assert
cos_diff
<
1e-4
else
:
assert
cos_diff
<
1e-4
#1e-5
FLASH_MLA_UNSUPPORTED_REASON
=
is_flashmla_supported
()[
1
]
\
FLASH_MLA_UNSUPPORTED_REASON
=
is_flashmla_supported
()[
1
]
\
if
not
is_flashmla_supported
()[
0
]
else
"FlashMLA is supported"
if
not
is_flashmla_supported
()[
0
]
else
"FlashMLA is supported"
...
@@ -27,7 +33,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
...
@@ -27,7 +33,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
reason
=
FLASH_MLA_UNSUPPORTED_REASON
)
reason
=
FLASH_MLA_UNSUPPORTED_REASON
)
@
pytest
.
mark
.
parametrize
(
"b"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"b"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"s_q"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"s_q"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"mean_sk"
,
[
4096
,
8192
])
@
pytest
.
mark
.
parametrize
(
"mean_sk"
,
[
4096
,
8192
,
16384
])
@
pytest
.
mark
.
parametrize
(
"h_q"
,
[
16
,
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"h_q"
,
[
16
,
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"h_kv"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"h_kv"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
576
])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
576
])
...
@@ -35,20 +41,26 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
...
@@ -35,20 +41,26 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"varlen"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"varlen"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"torch_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
,
torch
.
float8_e4m3fn
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_flash_mla
(
b
,
s_q
,
mean_sk
,
h_q
,
h_kv
,
d
,
dv
,
block_size
,
causal
,
def
test_flash_mla
(
b
,
s_q
,
mean_sk
,
h_q
,
h_kv
,
d
,
dv
,
block_size
,
causal
,
varlen
,
dtype
):
varlen
,
torch_
dtype
):
device
=
torch
.
device
(
"cuda:0"
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_dtype
(
dtype
)
if
torch_dtype
==
torch
.
float8_e4m3fn
:
init_dtype
=
torch
.
bfloat16
else
:
init_dtype
=
torch_dtype
torch
.
set_default_dtype
(
init_dtype
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
random
.
seed
(
0
)
random
.
seed
(
0
)
print
(
f
"
{
b
=
}
,
{
s_q
=
}
,
{
mean_sk
=
}
,
{
h_q
=
}
,
{
h_kv
=
}
, "
print
(
f
"
{
b
=
}
,
{
s_q
=
}
,
{
mean_sk
=
}
,
{
h_q
=
}
,
{
h_kv
=
}
, "
f
"
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
varlen
=
}
,
{
dtype
=
}
"
)
f
"
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
varlen
=
}
,
{
torch_
dtype
=
}
"
)
use_fp8
=
torch_dtype
==
torch
.
float8_e4m3fn
cache_seqlens
=
torch
.
full
((
b
,
),
mean_sk
,
dtype
=
torch
.
int32
)
cache_seqlens
=
torch
.
full
((
b
,
),
mean_sk
,
dtype
=
torch
.
int32
)
if
varlen
:
if
varlen
:
for
i
in
range
(
b
):
for
i
in
range
(
b
):
...
@@ -71,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
...
@@ -71,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
)
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
)
init_dtype
=
q
.
dtype
if
use_fp8
:
fp8_dtype
=
torch
.
float8_e4m3fn
descale_q
=
torch
.
ones
((
1
),
dtype
=
torch
.
float32
)
descale_k
=
torch
.
ones
((
1
),
dtype
=
torch
.
float32
)
q
=
q
.
to
(
fp8_dtype
)
blocked_k
=
blocked_k
.
to
(
fp8_dtype
)
blocked_v
=
blocked_v
.
to
(
fp8_dtype
)
else
:
descale_q
=
None
descale_k
=
None
def
flash_mla
():
def
flash_mla
():
return
flash_mla_with_kvcache
(
return
flash_mla_with_kvcache
(
q
,
q
,
...
@@ -81,6 +106,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
...
@@ -81,6 +106,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
tile_scheduler_metadata
,
tile_scheduler_metadata
,
num_splits
,
num_splits
,
causal
=
causal
,
causal
=
causal
,
descale_q
=
descale_q
,
descale_k
=
descale_k
,
)
)
def
scaled_dot_product_attention
(
query
,
key
,
value
,
is_causal
=
False
):
def
scaled_dot_product_attention
(
query
,
key
,
value
,
is_causal
=
False
):
...
@@ -104,29 +131,35 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
...
@@ -104,29 +131,35 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
return
attn_weight
@
value
,
lse
return
attn_weight
@
value
,
lse
def
ref_mla
():
def
ref_mla
():
q_
=
(
q
.
to
(
torch
.
float
)
*
descale_q
).
to
(
init_dtype
)
if
use_fp8
else
q
blocked_k_
=
(
blocked_k
.
to
(
torch
.
float
)
*
descale_k
).
to
(
init_dtype
)
if
use_fp8
else
blocked_k
blocked_v_
=
(
blocked_v
.
to
(
torch
.
float
)
*
descale_k
).
to
(
init_dtype
)
if
use_fp8
else
blocked_v
out
=
torch
.
empty
(
b
,
s_q
,
h_q
,
dv
,
dtype
=
torch
.
float32
)
out
=
torch
.
empty
(
b
,
s_q
,
h_q
,
dv
,
dtype
=
torch
.
float32
)
lse
=
torch
.
empty
(
b
,
h_q
,
s_q
,
dtype
=
torch
.
float32
)
lse
=
torch
.
empty
(
b
,
h_q
,
s_q
,
dtype
=
torch
.
float32
)
for
i
in
range
(
b
):
for
i
in
range
(
b
):
begin
=
i
*
max_seqlen_pad
begin
=
i
*
max_seqlen_pad
end
=
begin
+
cache_seqlens
[
i
]
end
=
begin
+
cache_seqlens
[
i
]
ref_O
,
LSE
=
scaled_dot_product_attention
(
out_i
,
lse_i
=
scaled_dot_product_attention
(
q
[
i
].
transpose
(
0
,
1
),
q
_
[
i
].
transpose
(
0
,
1
),
blocked_k
.
view
(
-
1
,
h_kv
,
d
)[
begin
:
end
].
transpose
(
0
,
1
),
blocked_k
_
.
view
(
-
1
,
h_kv
,
d
)[
begin
:
end
].
transpose
(
0
,
1
),
blocked_v
.
view
(
-
1
,
h_kv
,
dv
)[
begin
:
end
].
transpose
(
0
,
1
),
blocked_v
_
.
view
(
-
1
,
h_kv
,
dv
)[
begin
:
end
].
transpose
(
0
,
1
),
is_causal
=
causal
,
is_causal
=
causal
,
)
)
out
[
i
]
=
ref_O
.
transpose
(
0
,
1
)
out
[
i
]
=
out_i
.
transpose
(
0
,
1
)
lse
[
i
]
=
LSE
lse
[
i
]
=
lse_i
return
out
,
lse
return
out
,
lse
out_flash
,
lse_flash
=
flash_mla
()
out_flash
,
lse_flash
=
flash_mla
()
out_torch
,
lse_torch
=
ref_mla
()
out_torch
,
lse_torch
=
ref_mla
()
cal_diff
(
out_flash
,
out_torch
,
"out"
)
cal_diff
(
out_flash
,
out_torch
,
"out"
,
use_fp8
)
cal_diff
(
lse_flash
,
lse_torch
,
"lse"
)
cal_diff
(
lse_flash
,
lse_torch
,
"lse"
)
t
=
triton
.
testing
.
do_bench
(
flash_mla
)
t
=
triton
.
testing
.
do_bench
(
flash_mla
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
b
*
s_q
*
h_q
*
d
)
*
(
torch
.
finfo
(
torch_dtype
).
bits
//
8
)
+
(
print
(
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
"
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
init_dtype
).
bits
//
8
)
f
"TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
print
(
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,"
,
f
"
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
tests/kernels/attention/untest_attention_selector.py
View file @
a99300bd
...
@@ -80,6 +80,9 @@ def test_env(
...
@@ -80,6 +80,9 @@ def test_env(
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
name
)
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
name
)
m
.
setenv
(
"VLLM_MLA_DISABLE"
,
"1"
if
use_mla
else
"0"
)
m
.
setenv
(
"VLLM_MLA_DISABLE"
,
"1"
if
use_mla
else
"0"
)
if
name
==
"FLASHINFER"
and
not
use_v1
:
pytest
.
skip
(
"FlashInfer backend is only available on V1 engine"
)
if
device
==
"cpu"
:
if
device
==
"cpu"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
with
patch
(
"vllm.attention.selector.current_platform"
,
CpuPlatform
()):
CpuPlatform
()):
...
...
tests/kernels/attention/untest_flashinfer.py
View file @
a99300bd
...
@@ -137,9 +137,7 @@ def test_flashinfer_decode_with_paged_kv(
...
@@ -137,9 +137,7 @@ def test_flashinfer_decode_with_paged_kv(
workspace_buffer
=
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
workspace_buffer
=
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
wrapper
=
flashinfer
.
\
wrapper
=
flashinfer
.
\
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffer
,
"NHD"
,
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffer
,
"NHD"
,
use_tensor_cores
=
(
use_tensor_cores
=
True
)
(
num_query_heads
//
num_kv_heads
)
>
4
)
)
wrapper
.
plan
(
wrapper
.
plan
(
kv_indptr
,
kv_indptr
,
kv_indices
,
kv_indices
,
...
@@ -411,7 +409,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
...
@@ -411,7 +409,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
assert
num_query_heads
%
num_kv_heads
==
0
assert
num_query_heads
%
num_kv_heads
==
0
max_kv_len
=
max
(
kv_lens
)
max_kv_len
=
max
(
kv_lens
)
scale
=
head_size
**-
0.5
scale
=
head_size
**-
0.5
use_tensor_cores
=
(
num_query_heads
//
num_kv_heads
)
>
4
use_tensor_cores
=
True
kv_cache_dtype
=
torch
.
float8_e4m3fn
kv_cache_dtype
=
torch
.
float8_e4m3fn
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
...
...
tests/kernels/mamba/test_mamba_ssm_ssd.py
→
tests/kernels/mamba/
un
test_mamba_ssm_ssd.py
View file @
a99300bd
File moved
tests/kernels/moe/test_deepep_deepgemm_moe.py
View file @
a99300bd
...
@@ -20,9 +20,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
...
@@ -20,9 +20,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel
)
FusedMoEModularKernel
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
has_deep_ep
,
has_deep_gemm
from
vllm.utils
import
has_deep_ep
,
has_deep_gemm
from
vllm.utils.deep_gemm
import
(
is_blackwell_deep_gemm_e8m0_used
,
from
vllm.utils.deep_gemm
import
is_deep_gemm_e8m0_used
,
is_deep_gemm_supported
is_deep_gemm_supported
)
from
...utils
import
multi_gpu_test
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
from
.utils
import
make_test_weights
from
.utils
import
make_test_weights
...
@@ -370,9 +370,10 @@ NUM_EXPERTS = [32]
...
@@ -370,9 +370,10 @@ NUM_EXPERTS = [32]
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOPKS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOPKS
)
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
multi_gpu_test
(
num_gpus
=
2
)
@
requires_deep_ep
@
requires_deep_ep
@
requires_deep_gemm
@
requires_deep_gemm
@
pytest
.
mark
.
skipif
(
is_
blackwell_
deep_gemm_e8m0_used
(),
@
pytest
.
mark
.
skipif
(
is_deep_gemm_e8m0_used
(),
reason
=
"Skipping test for Blackwell DeepGEMM"
)
reason
=
"Skipping test for Blackwell DeepGEMM"
)
def
test_ht_deepep_deepgemm_moe
(
mnk
:
tuple
[
int
,
int
,
int
],
num_experts
:
int
,
def
test_ht_deepep_deepgemm_moe
(
mnk
:
tuple
[
int
,
int
,
int
],
num_experts
:
int
,
topk
:
int
,
world_dp_size
:
tuple
[
int
,
int
]):
topk
:
int
,
world_dp_size
:
tuple
[
int
,
int
]):
...
@@ -427,9 +428,10 @@ USE_FP8_DISPATCH = [False]
...
@@ -427,9 +428,10 @@ USE_FP8_DISPATCH = [False]
@
pytest
.
mark
.
parametrize
(
"use_fp8_dispatch"
,
USE_FP8_DISPATCH
)
@
pytest
.
mark
.
parametrize
(
"use_fp8_dispatch"
,
USE_FP8_DISPATCH
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[[
128
,
128
]])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[[
128
,
128
]])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
multi_gpu_test
(
num_gpus
=
2
)
@
requires_deep_ep
@
requires_deep_ep
@
requires_deep_gemm
@
requires_deep_gemm
@
pytest
.
mark
.
skipif
(
is_
blackwell_
deep_gemm_e8m0_used
(),
@
pytest
.
mark
.
skipif
(
is_deep_gemm_e8m0_used
(),
reason
=
"Skipping test for Blackwell DeepGEMM"
)
reason
=
"Skipping test for Blackwell DeepGEMM"
)
def
test_ll_deepep_deepgemm_moe
(
def
test_ll_deepep_deepgemm_moe
(
mnk
:
tuple
[
int
,
int
,
int
],
mnk
:
tuple
[
int
,
int
,
int
],
...
...
tests/kernels/moe/test_deepep_moe.py
View file @
a99300bd
...
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
...
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
has_deep_ep
from
vllm.utils
import
has_deep_ep
from
...utils
import
multi_gpu_test
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
if
has_deep_ep
():
if
has_deep_ep
():
...
@@ -411,6 +412,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
...
@@ -411,6 +412,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
6
])
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
6
])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
pytest
.
mark
.
parametrize
(
"per_act_token_quant"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"per_act_token_quant"
,
[
False
,
True
])
@
multi_gpu_test
(
num_gpus
=
2
)
@
requires_deep_ep
@
requires_deep_ep
def
test_deep_ep_moe
(
def
test_deep_ep_moe
(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
...
@@ -459,6 +461,7 @@ USE_FP8_DISPATCH = [True, False]
...
@@ -459,6 +461,7 @@ USE_FP8_DISPATCH = [True, False]
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
6
])
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
6
])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
pytest
.
mark
.
parametrize
(
"use_fp8_dispatch"
,
USE_FP8_DISPATCH
)
@
pytest
.
mark
.
parametrize
(
"use_fp8_dispatch"
,
USE_FP8_DISPATCH
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
requires_deep_ep
@
requires_deep_ep
def
test_low_latency_deep_ep_moe
(
dtype
:
torch
.
dtype
,
mnk
:
tuple
[
int
,
int
,
int
],
def
test_low_latency_deep_ep_moe
(
dtype
:
torch
.
dtype
,
mnk
:
tuple
[
int
,
int
,
int
],
num_experts
:
int
,
topk
:
int
,
num_experts
:
int
,
topk
:
int
,
...
...
tests/kernels/moe/test_flashinfer.py
0 → 100644
View file @
a99300bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
import
pytest
import
torch
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
apply_flashinfer_per_tensor_scale_fp8
,
flashinfer_cutlass_moe_fp8
,
register_moe_scaling_factors
,
rotate_flashinfer_fp8_moe_weights
,
swap_w13_to_w31
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
input_to_float8
)
from
vllm.model_executor.models.llama4
import
Llama4MoE
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
if
not
has_flashinfer_cutlass_fused_moe
(
)
or
not
current_platform
.
has_device_capability
(
100
):
pytest
.
skip
(
"Requires flashinfer_cutlass_fused_moe and nvfp4 support"
,
allow_module_level
=
True
)
NUM_EXPERTS
=
[
16
]
TOP_KS
=
[
1
]
MNK_FACTORS
=
[
(
256
,
8192
,
5120
),
(
256
,
4096
,
5120
),
(
127
,
8192
,
5120
),
(
127
,
4096
,
5120
),
(
10
,
8192
,
5120
),
(
10
,
4096
,
5120
),
(
1
,
8192
,
5120
),
(
1
,
4096
,
5120
),
]
vllm_config
=
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
def
quant_fp8_per_tensor_batches
(
a
):
num_batches
=
a
.
size
(
0
)
a_quant
=
[]
a_scales
=
[]
for
i
in
range
(
num_batches
):
a_fp8
,
a_global_sf
=
input_to_float8
(
a
[
i
])
a_global_sf
=
1.0
/
a_global_sf
a_quant
.
append
(
a_fp8
)
a_scales
.
append
(
a_global_sf
)
result_a_quant
=
torch
.
stack
(
a_quant
)
result_a_scales
=
torch
.
stack
(
a_scales
)
return
result_a_quant
,
result_a_scales
@
dataclass
class
TestData
:
hidden_states
:
torch
.
Tensor
w13_quantized
:
torch
.
Tensor
w2_quantized
:
torch
.
Tensor
a1_scale
:
torch
.
Tensor
a2_scale
:
torch
.
Tensor
w13_weight_scale
:
torch
.
Tensor
w2_weight_scale
:
torch
.
Tensor
layer
:
torch
.
nn
.
Module
@
staticmethod
def
make_moe_tensors_8bit
(
m
:
int
,
k
:
int
,
n
:
int
,
e
:
int
,
reorder
:
bool
)
->
"TestData"
:
hidden_states
=
torch
.
randn
(
(
m
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
/
10
w13
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
# Scale to fp8
_
,
a1_scale
=
input_to_float8
(
hidden_states
)
a1_scale
=
1.0
/
a1_scale
a2_scale
=
torch
.
scalar_tensor
(
1.0
).
to
(
device
=
"cuda"
).
to
(
dtype
=
torch
.
float32
)
w13_quantized
,
w13_weight_scale
=
quant_fp8_per_tensor_batches
(
w13
)
w2_quantized
,
w2_weight_scale
=
quant_fp8_per_tensor_batches
(
w2
)
layer
=
torch
.
nn
.
Module
()
layer
.
w13_weight
=
w13_quantized
.
clone
()
layer
.
w2_weight
=
w2_quantized
.
clone
()
layer
.
w13_input_scale
=
a1_scale
layer
.
w2_input_scale
=
a2_scale
layer
.
w13_weight_scale
=
w13_weight_scale
layer
.
w2_weight_scale
=
w2_weight_scale
register_moe_scaling_factors
(
layer
)
# flashinfer expects swapped rows for w13
layer
.
w13_weight
.
data
=
swap_w13_to_w31
(
layer
.
w13_weight
.
data
)
if
reorder
:
rotate_flashinfer_fp8_moe_weights
(
layer
.
w13_weight
,
layer
.
w2_weight
)
layer
.
custom_routing_function
=
Llama4MoE
.
custom_routing_function
layer
.
intermediate_size_per_partition
=
n
layer
.
ep_rank
=
0
layer
.
local_num_experts
=
e
return
TestData
(
hidden_states
=
hidden_states
,
w13_quantized
=
w13_quantized
,
w2_quantized
=
w2_quantized
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
w13_weight_scale
=
w13_weight_scale
,
w2_weight_scale
=
w2_weight_scale
,
layer
=
layer
,
)
@
pytest
.
mark
.
parametrize
(
"m,n,k"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
def
test_flashinfer_per_tensor_moe_fp8_no_graph
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
monkeypatch
,
):
current_platform
.
seed_everything
(
7
)
monkeypatch
.
setenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"8192"
)
with
set_current_vllm_config
(
vllm_config
):
td
=
TestData
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
reorder
=
True
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
td
.
hidden_states
,
router_logits
=
score
,
use_grouped_topk
=
False
,
top_k
=
topk
,
renormalize
=
False
,
custom_routing_function
=
Llama4MoE
.
custom_routing_function
,
scoring_func
=
"softmax"
)
output
=
fused_experts
(
td
.
hidden_states
,
td
.
w13_quantized
,
td
.
w2_quantized
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
False
,
activation
=
"silu"
,
use_fp8_w8a8
=
True
,
per_channel_quant
=
False
,
global_num_experts
=
e
,
expert_map
=
None
,
w1_scale
=
td
.
w13_weight_scale
,
w2_scale
=
td
.
w2_weight_scale
,
a1_scale
=
td
.
a1_scale
,
a2_scale
=
td
.
a2_scale
,
apply_router_weight_on_input
=
True
,
)
flashinfer_output
=
apply_flashinfer_per_tensor_scale_fp8
(
layer
=
td
.
layer
,
hidden_states
=
td
.
hidden_states
,
router_logits
=
score
,
routing_bias
=
None
,
global_num_experts
=
e
,
top_k
=
topk
,
num_expert_group
=
None
,
topk_group
=
None
,
apply_router_weight_on_input
=
True
)
torch
.
testing
.
assert_close
(
output
,
flashinfer_output
,
atol
=
5.5e-2
,
rtol
=
1e-2
)
@
pytest
.
mark
.
skip
(
"Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472"
)
@
pytest
.
mark
.
parametrize
(
"m,n,k"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
def
test_flashinfer_cutlass_moe_fp8_no_graph
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
monkeypatch
,
):
current_platform
.
seed_everything
(
7
)
monkeypatch
.
setenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"8192"
)
with
set_current_vllm_config
(
vllm_config
):
td
=
TestData
.
make_moe_tensors_8bit
(
m
,
k
,
n
,
e
,
reorder
=
False
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
td
.
hidden_states
,
router_logits
=
score
,
use_grouped_topk
=
False
,
top_k
=
topk
,
renormalize
=
False
,
custom_routing_function
=
Llama4MoE
.
custom_routing_function
,
scoring_func
=
"softmax"
)
output
=
fused_experts
(
td
.
hidden_states
,
td
.
w13_quantized
,
td
.
w2_quantized
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
False
,
activation
=
"silu"
,
use_fp8_w8a8
=
True
,
per_channel_quant
=
False
,
global_num_experts
=
e
,
expert_map
=
None
,
w1_scale
=
td
.
w13_weight_scale
,
w2_scale
=
td
.
w2_weight_scale
,
a1_scale
=
td
.
a1_scale
,
a2_scale
=
td
.
a2_scale
,
apply_router_weight_on_input
=
True
,
)
td
.
layer
.
dp_size
=
1
flashinfer_cutlass_output
=
flashinfer_cutlass_moe_fp8
(
td
.
hidden_states
,
td
.
layer
,
topk_weights
,
topk_ids
,
activation
=
"silu"
,
global_num_experts
=
e
,
expert_map
=
None
,
apply_router_weight_on_input
=
True
,
)
torch
.
testing
.
assert_close
(
output
,
flashinfer_cutlass_output
,
atol
=
5.5e-2
,
rtol
=
1e-2
)
tests/kernels/moe/test_grouped_topk.py
0 → 100644
View file @
a99300bd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MoE grouped topk kernel
Run `pytest tests/kernels/moe/test_grouped_topk.py`.
"""
import
pytest
import
torch
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_grouped_topk
,
grouped_topk
)
from
vllm.platforms
import
current_platform
@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="This test is skipped on non-CUDA platform.")
@pytest.mark.parametrize("n_token", [1, 33, 64])
@pytest.mark.parametrize("n_hidden", [1024, 2048])
@pytest.mark.parametrize("n_expert", [16])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("num_expert_group", [8])
@pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
@pytest.mark.parametrize("dtype",
                         [torch.float16, torch.bfloat16, torch.float32])
def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int,
                      n_hidden: int, n_expert: int, topk: int,
                      renormalize: bool, num_expert_group: int,
                      topk_group: int, scoring_func: str,
                      routed_scaling_factor: float, dtype: torch.dtype):
    """Check that ``fused_grouped_topk`` agrees with ``grouped_topk``.

    Both functions receive exactly the same random inputs and routing
    options. The selected expert ids must match exactly; the routing
    weights are only compared (absolute tolerance 2e-2) when
    ``renormalize`` is set.
    """
    current_platform.seed_everything(0)

    # The three random draws are kept in this order so the seeded data
    # stays the same.
    device = "cuda"
    hidden_states = torch.randn((n_token, n_hidden),
                                dtype=dtype,
                                device=device)
    gating_output = torch.randn((n_token, n_expert),
                                dtype=dtype,
                                device=device)
    e_score_correction_bias = torch.randn((n_expert, ),
                                          dtype=torch.float32,
                                          device=device)

    # One keyword set shared by the baseline and the kernel under test.
    routing_kwargs = dict(
        hidden_states=hidden_states,
        gating_output=gating_output,
        topk=topk,
        renormalize=renormalize,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        e_score_correction_bias=e_score_correction_bias,
    )

    with monkeypatch.context() as patched_env:
        # NOTE(review): presumably this keeps grouped_topk on its
        # non-fused path so it can serve as the baseline - confirm.
        patched_env.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
        expected_weights, expected_ids = grouped_topk(**routing_kwargs)
        fused_weights, fused_ids = fused_grouped_topk(**routing_kwargs)

        if renormalize:
            torch.testing.assert_close(expected_weights,
                                       fused_weights,
                                       atol=2e-2,
                                       rtol=0)
        torch.testing.assert_close(expected_ids,
                                   fused_ids,
                                   atol=0,
                                   rtol=0)
tests/kernels/moe/test_modular_kernel_combinations.py
View file @
a99300bd
...
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
...
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from
vllm.utils
import
has_deep_ep
,
has_deep_gemm
,
has_pplx
from
vllm.utils
import
has_deep_ep
,
has_deep_gemm
,
has_pplx
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
from
...utils
import
multi_gpu_test
from
.modular_kernel_tools.common
import
(
Config
,
RankTensors
,
WeightTensors
,
from
.modular_kernel_tools.common
import
(
Config
,
RankTensors
,
WeightTensors
,
reference_moe_impl
,
reference_moe_impl
,
run_modular_kernel
)
run_modular_kernel
)
...
@@ -162,6 +163,7 @@ def is_nyi_config(config: Config) -> bool:
...
@@ -162,6 +163,7 @@ def is_nyi_config(config: Config) -> bool:
product
(
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
,
MK_FUSED_EXPERT_TYPES
))
product
(
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
,
MK_FUSED_EXPERT_TYPES
))
@
pytest
.
mark
.
parametrize
(
"fused_moe_chunk_size"
,
FUSED_MOE_CHUNK_SIZEs
)
@
pytest
.
mark
.
parametrize
(
"fused_moe_chunk_size"
,
FUSED_MOE_CHUNK_SIZEs
)
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
])
@
multi_gpu_test
(
num_gpus
=
2
)
@
meets_multi_gpu_requirements
@
meets_multi_gpu_requirements
def
test_modular_kernel_combinations_multigpu
(
def
test_modular_kernel_combinations_multigpu
(
k
:
int
,
n
:
int
,
e
:
int
,
dtype
:
torch
.
dtype
,
k
:
int
,
n
:
int
,
e
:
int
,
dtype
:
torch
.
dtype
,
...
...
tests/kernels/moe/test_moe.py
View file @
a99300bd
...
@@ -438,11 +438,11 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
...
@@ -438,11 +438,11 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
vllm_moe
.
experts
.
w13_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
vllm_moe
.
experts
.
w13_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
0
:
-
128
],
requires_grad
=
False
)
requires_grad
=
False
)
torch
.
cuda
.
empty_cache
()
vllm_moe
.
experts
.
w2_weight
=
Parameter
(
F
.
pad
(
vllm_moe
.
experts
.
w2_weight
=
Parameter
(
F
.
pad
(
vllm_moe
.
experts
.
w2_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
vllm_moe
.
experts
.
w2_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
0
:
-
128
],
requires_grad
=
False
)
requires_grad
=
False
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
# Run forward passes for both MoE blocks
# Run forward passes for both MoE blocks
...
...
tests/kernels/moe/test_mxfp4_moe.py
View file @
a99300bd
...
@@ -4,15 +4,27 @@
...
@@ -4,15 +4,27 @@
import
importlib
import
importlib
import
importlib.metadata
import
importlib.metadata
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
import
pytest
import
pytest
import
torch
import
torch
from
packaging
import
version
from
packaging
import
version
from
vllm.platforms
import
current_platform
QUARK_MXFP4_AVAILABLE
=
importlib
.
util
.
find_spec
(
QUARK_MXFP4_AVAILABLE
=
importlib
.
util
.
find_spec
(
"quark"
)
is
not
None
and
version
.
parse
(
"quark"
)
is
not
None
and
version
.
parse
(
importlib
.
metadata
.
version
(
"amd-quark"
))
>=
version
.
parse
(
'0.8.99'
)
importlib
.
metadata
.
version
(
"amd-quark"
))
>=
version
.
parse
(
'0.8.99'
)
TRTLLM_GEN_MXFP4_AVAILABLE
=
current_platform
.
is_cuda
(
)
and
current_platform
.
is_device_capability
(
100
)
if
TRTLLM_GEN_MXFP4_AVAILABLE
:
from
flashinfer
import
(
fp4_quantize
,
mxfp8_quantize
,
next_positive_power_of_2
,
reorder_rows_for_gated_act_gemm
,
shuffle_matrix_a
,
shuffle_matrix_sf_a
,
trtllm_fp4_block_scale_moe
)
@
dataclass
@
dataclass
class
ModelCase
:
class
ModelCase
:
...
@@ -54,4 +66,410 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
...
@@ -54,4 +66,410 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
output
=
llm
.
generate_greedy
(
"Today I am in the French Alps and"
,
output
=
llm
.
generate_greedy
(
"Today I am in the French Alps and"
,
max_tokens
=
20
)
max_tokens
=
20
)
assert
output
assert
output
\ No newline at end of file
def swiglu(x,
           alpha: float = 1.702,
           beta: float = 1.0,
           limit: Optional[float] = None):
    """Apply a gated activation to ``x`` split in two along the last dim.

    The first half of the last dimension is the gate and the second half
    is the linear branch. The result is
    ``gate * sigmoid(alpha * gate) * (linear + beta)``.

    Args:
        x: Tensor whose last dimension holds the gate and linear halves.
        alpha: Scale applied to the gate inside the sigmoid.
        beta: Bias added to the linear half (the default adds 1).
        limit: If given, the gate is clamped from above at ``limit`` and
            the linear half is clamped to ``[-limit, limit]``.

    Returns:
        Tensor with the last dimension halved.
    """
    gate, linear = x.chunk(2, dim=-1)
    if limit is not None:
        gate = torch.clamp(gate, max=limit)
        linear = torch.clamp(linear, min=-limit, max=limit)
    activated_gate = gate * torch.sigmoid(alpha * gate)
    # Note we add an extra bias (beta, 1 by default) to the linear branch.
    return activated_gate * (linear + beta)
# Value decoded from each 4-bit code (list index == code). Codes 8..15 are
# the negated values of codes 0..7.
fp4_lookup_table = [
    0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6
]


def mxfp4_dequantize(x, scale):
    """Expand packed 4-bit codes to float32 and apply per-block scales.

    Args:
        x: ``torch.uint8`` tensor; every byte carries two 4-bit codes, the
            low nibble first and the high nibble second.
        scale: One-byte-per-element tensor of scale exponents, one per 32
            unpacked values. Each byte is shifted into the exponent field
            of a float32 to form the multiplier.

    Returns:
        float32 tensor shaped like ``x`` with the last dimension doubled.
    """
    assert x.dtype == torch.uint8
    packed = x.view(torch.uint8).to(torch.int32)

    # Interleave the nibbles: low nibble -> even slot, high nibble -> odd.
    nibbles = torch.stack((packed & 0xF, (packed >> 4) & 0xF), dim=-1)
    codes = nibbles.reshape(*packed.shape[:-1], packed.shape[-1] * 2)

    # Decode all codes with one table gather.
    decode_table = torch.tensor(fp4_lookup_table,
                                dtype=torch.float32,
                                device=x.device)
    x_float = decode_table[codes.long()]

    # Reinterpret (byte << 23) as float32, then repeat each block scale
    # over the 32 values it covers.
    exponent_bits = scale.view(torch.uint8).to(torch.int32) << 23
    block_scale = exponent_bits.view(torch.float32).reshape(
        *x.shape[:-1], -1)
    per_value_scale = block_scale.unsqueeze(-1).expand(
        *block_scale.shape, 32).reshape(*x_float.shape)
    return x_float * per_value_scale
def
mxfp8_dequantize
(
x
,
scale
):
assert
x
.
dtype
==
torch
.
float8_e4m3fn
x_float
=
x
.
to
(
torch
.
float32
)
scale
=
scale
.
view
(
torch
.
uint8
).
to
(
torch
.
int32
)
scale
=
(
scale
<<
23
).
view
(
torch
.
float32
)
scale
=
scale
.
reshape
(
*
x
.
shape
[:
-
1
],
-
1
)
scale
=
torch
.
stack
([
scale
]
*
32
,
dim
=-
1
).
reshape
(
*
x_float
.
shape
)
return
x_float
*
scale
def reference_moe(
    roouting_logits,
    topk,
    num_experts,
    hidden_states,
    w13,
    bias13,
    w2,
    bias2,
    alpha,
    beta,
    limit,
    act_type,
):
    """PyTorch reference for the MoE layer under test.

    Picks the ``topk`` largest routing logits per token, softmaxes only
    those logits to obtain the expert weights, runs both expert MLPs with
    ``swiglu`` in between, and returns the weighted sum over the selected
    experts as bfloat16.

    When ``act_type`` is ``'mxfp8'`` the activation between the two MLPs
    is passed through ``mxfp8_quantize`` and ``mxfp8_dequantize``.
    ``num_experts`` is accepted for signature compatibility but is not
    used here.
    """
    # Renormalized routing: softmax over the selected logits only.
    top_logits, expert_indices = torch.topk(roouting_logits,
                                            k=topk,
                                            dim=-1,
                                            sorted=True)
    expert_weights = torch.nn.functional.softmax(top_logits, dim=1)

    activations = hidden_states.clone()

    # MLP #1: gather the selected experts' weights and biases per token.
    gathered_w13 = w13[expert_indices, ...]
    gathered_b13 = bias13[expert_indices, ...]
    activations = torch.einsum("beck,bk->bec", gathered_w13,
                               activations) + gathered_b13
    activations = swiglu(activations, alpha=alpha, beta=beta, limit=limit)

    if act_type == 'mxfp8':
        quantized, quantized_scale = mxfp8_quantize(
            activations.to(torch.bfloat16), is_sf_swizzled_layout=False)
        activations = mxfp8_dequantize(quantized, quantized_scale)

    # MLP #2
    gathered_w2 = w2[expert_indices, ...]
    gathered_b2 = bias2[expert_indices, ...]
    activations = torch.einsum("beck,bek->bec", gathered_w2,
                               activations) + gathered_b2

    # Weighted sum of experts
    combined = torch.einsum("bec,be->bc", activations, expert_weights)
    assert combined.shape == hidden_states.shape
    return combined.to(torch.bfloat16)
def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int):
    """Choose the tokens-per-CTA tile size for ``x``.

    Takes the per-expert token count under a perfectly even distribution,
    inflates it by a fixed imbalance factor, rounds up with
    ``next_positive_power_of_2`` and clamps the result to ``[8, 64]``.

    Args:
        x: Input tensor; only ``x.shape[0]`` (the token count) is read.
        top_k: Number of experts selected per token.
        num_experts: Total number of experts.

    Returns:
        The tile size as an int between 8 and 64.
    """
    # Ratio max_real_num_tokens_per_expert / perfect_num_tokens_per_expert:
    #   1.0 means a perfect expert distribution, > 1.0 means some experts
    #   receive more tokens than that, < 1.0 does not make sense.
    imbalance_factor = 1.3

    # Tokens per expert assuming a perfect distribution, then inflated.
    balanced_load = (x.shape[0] * top_k) // num_experts
    expected_peak_load = int(balanced_load * imbalance_factor)

    # Pad to the next power of two.
    padded = next_positive_power_of_2(expected_peak_load)

    # The kernel supports 8-64 tokens per CTA tile.
    if padded < 8:
        return 8
    if padded > 64:
        return 64
    return padded
def _swap_dim1_halves(tensor: torch.Tensor, half: int) -> torch.Tensor:
    """Return a new tensor with the two halves of dim 1 exchanged."""
    swapped = torch.empty_like(tensor)
    swapped[:, :half].copy_(tensor[:, half:])
    swapped[:, half:].copy_(tensor[:, :half])
    return swapped


def tg_mxfp4_moe(
    router_logits,
    topk,
    num_experts,
    intermediate_size,
    hidden_size,
    hidden_states,
    hidden_states_scale,
    w13_weight,
    w13_weight_scale,
    w13_bias,
    w2_weight,
    w2_weight_scale,
    w2_bias,
    act_type,
    alpha,
    beta,
    limit,
) -> torch.Tensor:
    """Run ``trtllm_fp4_block_scale_moe`` on weights in reference layout.

    The weights, scales and biases are rearranged for the kernel in three
    steps: the two halves of the fused w13 tensors are swapped, the w13
    rows are interleaved with ``reorder_rows_for_gated_act_gemm``, and
    every per-expert matrix is shuffled with ``shuffle_matrix_a`` /
    ``shuffle_matrix_sf_a``.

    The caller's tensors are not modified; all rearrangement happens on
    copies.

    Args:
        router_logits: Routing logits; cast to bfloat16 for the kernel.
        topk: Number of experts selected per token.
        num_experts: Number of experts (also used as the local count).
        intermediate_size: Size of one half of the fused w13 output dim.
        hidden_size: Model hidden size.
        hidden_states: Activations passed straight to the kernel.
        hidden_states_scale: Scale for ``hidden_states``, or ``None``.
        w13_weight: ``(E, 2 * intermediate_size, hidden_size // 2)``.
        w13_weight_scale: ``(E, 2 * intermediate_size, hidden_size // 32)``.
        w13_bias: ``(E, 2 * intermediate_size)``.
        w2_weight: ``(E, hidden_size, intermediate_size // 2)``.
        w2_weight_scale: ``(E, hidden_size, intermediate_size // 32)``.
        w2_bias: ``(E, hidden_size)``.
        act_type: Unused here; kept for signature compatibility.
        alpha, beta, limit: Forwarded as ``gemm1_alpha``, ``gemm1_beta``
            and ``gemm1_clamp_limit``.

    Returns:
        The first element of the kernel's result.
    """
    sf_block_size = 32
    assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts
            and w13_weight.shape[1] == intermediate_size * 2
            and w13_weight.shape[2] == hidden_size // 2)
    assert (w13_weight_scale.dim() == 3
            and w13_weight_scale.shape[0] == num_experts
            and w13_weight_scale.shape[1] == intermediate_size * 2
            and w13_weight_scale.shape[2] == hidden_size // sf_block_size)
    assert (w2_weight.dim() == 3 and w2_weight.shape[0] == num_experts
            and w2_weight.shape[1] == hidden_size
            and w2_weight.shape[2] == intermediate_size // 2)
    # The expert-count check was missing here; added to match the others.
    assert (w2_weight_scale.dim() == 3
            and w2_weight_scale.shape[0] == num_experts
            and w2_weight_scale.shape[1] == hidden_size
            and w2_weight_scale.shape[2] == intermediate_size // sf_block_size)
    assert (w13_bias.dim() == 2 and w13_bias.shape[0] == num_experts
            and w13_bias.shape[1] == intermediate_size * 2)
    assert (w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts
            and w2_bias.shape[1] == hidden_size)

    # Swap w1 and w3 as the definition of swiglu is different in
    # trtllm-gen. Done on fresh tensors so the caller's data is not
    # overwritten (an in-place swap would flip the halves back on a
    # second call).
    w13_weight = _swap_dim1_halves(w13_weight, intermediate_size)
    w13_weight_scale = _swap_dim1_halves(w13_weight_scale, intermediate_size)
    w13_bias = _swap_dim1_halves(w13_bias, intermediate_size)

    # Interleave the weights and scaling factors for activation
    w13_weight = torch.stack([
        reorder_rows_for_gated_act_gemm(w13_weight[i].clone())
        for i in range(num_experts)
    ]).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
    w13_weight_scale = torch.stack([
        reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())
        for i in range(num_experts)
    ]).reshape(num_experts, 2 * intermediate_size,
               hidden_size // sf_block_size)
    w13_bias = torch.stack([
        reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, 1))
        for i in range(num_experts)
    ]).reshape(num_experts, 2 * intermediate_size)

    # Shuffle weights and scaling factors for transposed mma output
    epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
    gemm1_weights_shuffled = []
    gemm1_scales_shuffled = []
    gemm2_weights_shuffled = []
    gemm2_scales_shuffled = []
    gemm1_bias_shuffled = []
    gemm2_bias_shuffled = []
    for i in range(num_experts):
        gemm1_weights_shuffled.append(
            shuffle_matrix_a(w13_weight[i].view(torch.uint8),
                             epilogue_tile_m))
        gemm1_scales_shuffled.append(
            shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
                                epilogue_tile_m))
        gemm2_weights_shuffled.append(
            shuffle_matrix_a(w2_weight[i].view(torch.uint8),
                             epilogue_tile_m))
        gemm2_scales_shuffled.append(
            shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
                                epilogue_tile_m))
        gemm1_bias_shuffled.append(
            shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m))
        gemm2_bias_shuffled.append(
            shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m))

    w13_weight = torch.stack(gemm1_weights_shuffled)
    w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape(
        num_experts, 2 * intermediate_size,
        hidden_size // sf_block_size).view(torch.float8_e4m3fn)
    w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)

    w2_weight = torch.stack(gemm2_weights_shuffled)
    w2_weight_scale = torch.stack(gemm2_scales_shuffled).reshape(
        num_experts, hidden_size,
        intermediate_size // sf_block_size).view(torch.float8_e4m3fn)
    w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)

    tg_result = trtllm_fp4_block_scale_moe(
        routing_logits=router_logits.to(torch.bfloat16),
        routing_bias=None,
        hidden_states=hidden_states,
        hidden_states_scale=hidden_states_scale,
        gemm1_weights=w13_weight,
        gemm1_weights_scale=w13_weight_scale,
        gemm1_bias=w13_bias,
        gemm1_alpha=alpha,
        gemm1_beta=beta,
        gemm1_clamp_limit=limit,
        gemm2_weights=w2_weight,
        gemm2_weights_scale=w2_weight_scale,
        gemm2_bias=w2_bias,
        output1_scale_scalar=None,
        output1_scale_gate_scalar=None,
        output2_scale_scalar=None,
        num_experts=num_experts,
        top_k=topk,
        n_group=None,
        topk_group=None,
        intermediate_size=intermediate_size,
        local_expert_offset=0,
        local_num_experts=num_experts,
        routed_scaling_factor=None,
        tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk,
                                            num_experts),
        routing_method_type=1,  # renormalize
        do_finalize=True)[0]
    return tg_result
def check_accuracy(a, b, atol, rtol, percent):
    """Allow a mismatch percentage of 1 - percent.

    An element mismatches when ``|a - b| > atol + rtol * |b|``.

    Args:
        a: Reference output.
        b: Actual output.
        atol: Absolute tolerance.
        rtol: Relative tolerance, applied to ``|b|``.
        percent: Required fraction of elements within tolerance.

    Raises:
        Exception: If either tensor contains NaN or Inf, or if the
            fraction of mismatching elements exceeds ``1 - percent``.
        AssertionError: If the two shapes differ.
    """
    # Checked in this order so the first offending condition is reported.
    sanity_checks = (
        (torch.isnan, a, "NaN in reference output"),
        (torch.isnan, b, "NaN in actual output"),
        (torch.isinf, a, "Inf in reference output"),
        (torch.isinf, b, "Inf in actual output"),
    )
    for is_bad, tensor, message in sanity_checks:
        if is_bad(tensor).any():
            raise Exception(message)

    assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"

    out_of_tolerance = torch.abs(a - b) > atol + rtol * torch.abs(b)
    mismatch_percent = out_of_tolerance.sum() / a.numel()
    if mismatch_percent > 1 - percent:
        raise Exception(
            f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
            f"(threshold: {1-percent:.4f})")
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("num_tokens", [1, 128, 1024])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
                                              (1.702, 1.0, 7.0)])
@pytest.mark.parametrize("act_type", ['mxfp8', 'bf16'])
@pytest.mark.skipif(
    not TRTLLM_GEN_MXFP4_AVAILABLE,
    reason="nvidia gpu and compute capability sm100 is required for this test")
def test_trtllm_gen_mxfp4_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: float,
    beta: float,
    limit: Optional[float],
    act_type: str,
):
    """Compare ``tg_mxfp4_moe`` against the ``reference_moe`` baseline.

    Random bf16 weights are quantized with ``fp4_quantize`` (block size
    32). The reference path runs on the dequantized copies of those same
    quantized weights, so both paths see identical quantized weight
    values. With ``act_type == 'mxfp8'`` the hidden states are also
    quantized and the reference uses their dequantized form.

    The check passes when at least 80% of the output elements are within
    a relative tolerance of 0.3.
    """
    # Fixed seed; the order of the random draws below determines the data.
    seed = 42
    torch.manual_seed(seed)
    hidden_states = torch.randn(num_tokens,
                                hidden_size,
                                device="cuda:0",
                                dtype=torch.bfloat16)
    w13 = (torch.randn(num_experts,
                       intermediate_size * 2,
                       hidden_size,
                       device="cuda:0",
                       dtype=torch.bfloat16))
    w2 = (torch.randn(num_experts,
                      hidden_size,
                      intermediate_size,
                      device="cuda:0",
                      dtype=torch.bfloat16))
    bias13 = torch.randn(num_experts, intermediate_size * 2,
                         device="cuda:0") * 10
    bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10
    router_logits = torch.rand(num_tokens, num_experts,
                               dtype=torch.float32).cuda()

    # Quantize the weights; from here on w13 / w2 name the quantized
    # tensors and the scales are reshaped to one entry per 32-wide block.
    w13, w13_scale = fp4_quantize(w13,
                                  torch.tensor(1.0, device="cuda:0"),
                                  32,
                                  sf_use_ue8m0=True,
                                  is_sf_swizzled_layout=False)
    w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
        num_experts, intermediate_size * 2, hidden_size // 32)
    w2, w2_scale = fp4_quantize(w2,
                                torch.tensor(1.0, device="cuda:0"),
                                32,
                                sf_use_ue8m0=True,
                                is_sf_swizzled_layout=False)
    w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
        num_experts, hidden_size, intermediate_size // 32)
    if act_type == 'mxfp8':
        # hidden_states is rebound to its quantized form here.
        hidden_states, hidden_states_scale = mxfp8_quantize(
            hidden_states, is_sf_swizzled_layout=False)
        hidden_states_scale = hidden_states_scale.view(
            torch.float8_e4m3fn).reshape(-1)
    else:
        hidden_states_scale = None

    # reference result
    ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16)
    w13_ref = mxfp4_dequantize(w13.clone(), w13_scale.clone())
    w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone())
    # The biases are not quantized; the reference shares the same tensors.
    bias13_ref = bias13
    bias2_ref = bias2
    if act_type == 'mxfp8':
        hidden_states_ref = mxfp8_dequantize(
            hidden_states, hidden_states_scale).to(torch.float32)
    else:
        hidden_states_ref = hidden_states.to(torch.float32)
    # Process tokens in chunks of 32 to reduce memory usage
    chunk_size = 32
    num_chunks = (num_tokens + chunk_size - 1) // chunk_size
    for i in range(num_chunks):
        start_idx = i * chunk_size
        end_idx = min(start_idx + chunk_size, num_tokens)
        chunk_result = reference_moe(
            router_logits[start_idx:end_idx].to(torch.float32),
            topk,
            num_experts,
            hidden_states_ref[start_idx:end_idx],
            w13_ref,
            bias13_ref,
            w2_ref,
            bias2_ref,
            alpha,
            beta,
            limit,
            act_type,
        )
        ref_result[start_idx:end_idx].copy_(chunk_result)

    # trtllm-gen result
    # The scalar activation parameters are expanded to one value per
    # expert before being handed to tg_mxfp4_moe.
    if alpha is not None:
        alpha = torch.full((num_experts, ),
                           alpha,
                           device=hidden_states.device)
    if limit is not None:
        limit = torch.full((num_experts, ),
                           limit,
                           device=hidden_states.device)
    if beta is not None:
        beta = torch.full((num_experts, ), beta, device=hidden_states.device)
    tg_result = tg_mxfp4_moe(router_logits,
                             topk,
                             num_experts,
                             intermediate_size,
                             hidden_size,
                             hidden_states,
                             hidden_states_scale,
                             w13,
                             w13_scale,
                             bias13,
                             w2,
                             w2_scale,
                             bias2,
                             act_type,
                             alpha=alpha,
                             beta=beta,
                             limit=limit)
    # relatively loose check since the mxfp4 quantization is less accurate
    check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
tests/kernels/moe/test_pplx_moe.py
View file @
a99300bd
...
@@ -37,6 +37,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
...
@@ -37,6 +37,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
round_up
from
vllm.utils
import
round_up
from
...utils
import
multi_gpu_test
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
requires_pplx
=
pytest
.
mark
.
skipif
(
requires_pplx
=
pytest
.
mark
.
skipif
(
...
@@ -452,6 +453,7 @@ def _pplx_prepare_finalize(
...
@@ -452,6 +453,7 @@ def _pplx_prepare_finalize(
@
pytest
.
mark
.
parametrize
(
"use_internode"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"use_internode"
,
[
False
])
@
pytest
.
mark
.
optional
@
pytest
.
mark
.
optional
@
requires_pplx
@
requires_pplx
@
multi_gpu_test
(
num_gpus
=
2
)
def
test_pplx_prepare_finalize_slow
(
def
test_pplx_prepare_finalize_slow
(
mnk
:
tuple
[
int
,
int
,
int
],
mnk
:
tuple
[
int
,
int
,
int
],
e
:
int
,
e
:
int
,
...
@@ -740,6 +742,7 @@ def _pplx_moe(
...
@@ -740,6 +742,7 @@ def _pplx_moe(
@
pytest
.
mark
.
parametrize
(
"use_internode"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"use_internode"
,
[
False
])
@
pytest
.
mark
.
optional
@
pytest
.
mark
.
optional
@
requires_pplx
@
requires_pplx
@
multi_gpu_test
(
num_gpus
=
2
)
def
test_pplx_moe_slow
(
def
test_pplx_moe_slow
(
mnk
:
tuple
[
int
,
int
,
int
],
mnk
:
tuple
[
int
,
int
,
int
],
e
:
int
,
e
:
int
,
...
@@ -880,6 +883,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
...
@@ -880,6 +883,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[[
2
,
1
]])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[[
2
,
1
]])
@
pytest
.
mark
.
parametrize
(
"use_internode"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"use_internode"
,
[
False
])
@
requires_pplx
@
requires_pplx
@
multi_gpu_test
(
num_gpus
=
2
)
def
test_pplx_prepare_finalize
(
def
test_pplx_prepare_finalize
(
world_dp_size
:
tuple
[
int
,
int
],
world_dp_size
:
tuple
[
int
,
int
],
use_internode
:
bool
,
use_internode
:
bool
,
...
@@ -893,6 +897,7 @@ def test_pplx_prepare_finalize(
...
@@ -893,6 +897,7 @@ def test_pplx_prepare_finalize(
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[[
2
,
1
]])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[[
2
,
1
]])
@
pytest
.
mark
.
parametrize
(
"use_internode"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"use_internode"
,
[
False
])
@
requires_pplx
@
requires_pplx
@
multi_gpu_test
(
num_gpus
=
2
)
def
test_pplx_moe
(
def
test_pplx_moe
(
world_dp_size
:
tuple
[
int
,
int
],
world_dp_size
:
tuple
[
int
,
int
],
use_internode
:
bool
,
use_internode
:
bool
,
...
...
tests/kernels/moe/untest_block_fp8.py
View file @
a99300bd
...
@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
...
@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk
,
modular_triton_fused_moe
)
fused_topk
,
modular_triton_fused_moe
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
has_deep_gemm
from
vllm.utils
import
has_deep_gemm
from
vllm.utils.deep_gemm
import
is_
blackwell_
deep_gemm_e8m0_used
from
vllm.utils.deep_gemm
import
is_deep_gemm_e8m0_used
dg_available
=
has_deep_gemm
()
dg_available
=
has_deep_gemm
()
...
@@ -226,8 +226,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
...
@@ -226,8 +226,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
skipif
(
not
dg_available
,
reason
=
"DeepGemm kernels not available."
)
@
pytest
.
mark
.
skipif
(
not
dg_available
,
reason
=
"DeepGemm kernels not available."
)
@
pytest
.
mark
.
skipif
(
is_blackwell_deep_gemm_e8m0_used
(),
@
pytest
.
mark
.
skipif
(
is_deep_gemm_e8m0_used
(),
reason
=
"Not E8M0 scale MOE"
)
reason
=
"Not E8M0 scale MOE"
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_w8a8_block_fp8_deep_gemm_fused_moe
(
M
,
N
,
K
,
E
,
topk
,
seed
,
def
test_w8a8_block_fp8_deep_gemm_fused_moe
(
M
,
N
,
K
,
E
,
topk
,
seed
,
monkeypatch
):
monkeypatch
):
...
...
tests/kernels/moe/untest_cutlass_moe.py
View file @
a99300bd
...
@@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
...
@@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids'
:
topk_ids
,
'topk_ids'
:
topk_ids
,
'w1_scale'
:
moe_tensors
.
w1_scale
,
'w1_scale'
:
moe_tensors
.
w1_scale
,
'w2_scale'
:
moe_tensors
.
w2_scale
,
'w2_scale'
:
moe_tensors
.
w2_scale
,
'ab_strides1'
:
moe_tensors
.
ab_strides1
,
'ab_strides2'
:
moe_tensors
.
ab_strides2
,
'c_strides1'
:
moe_tensors
.
c_strides1
,
'c_strides2'
:
moe_tensors
.
c_strides2
,
'per_act_token'
:
per_act_token
,
'per_act_token'
:
per_act_token
,
'a1_scale'
:
None
#moe_tensors.a_scale
'a1_scale'
:
None
#moe_tensors.a_scale
}
}
...
@@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8(
...
@@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8(
topk_ids
[
0
][
1
]
=
1
topk_ids
[
0
][
1
]
=
1
workspace13_shape
=
(
m
*
topk
,
max
(
2
*
n
,
k
))
workspace13_shape
=
(
m
*
topk
,
max
(
2
*
n
,
k
))
workspace2_shape
=
(
m
*
topk
,
n
)
workspace2_shape
=
(
m
*
topk
,
max
(
n
,
k
)
)
output_shape
=
(
m
*
topk
,
k
)
output_shape
=
(
m
,
k
)
workspace13
=
torch
.
empty
(
prod
(
workspace13_shape
),
workspace13
=
torch
.
empty
(
prod
(
workspace13_shape
),
device
=
"cuda"
,
device
=
"cuda"
,
...
@@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8(
...
@@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8(
expert_map
[
start
:
end
]
=
list
(
range
(
num_local_experts
))
expert_map
[
start
:
end
]
=
list
(
range
(
num_local_experts
))
expert_map
=
torch
.
tensor
(
expert_map
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
expert_map
=
torch
.
tensor
(
expert_map
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
ab_strides1
=
torch
.
full
((
e
,
),
k
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
e
,
),
n
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
e
,
),
2
*
n
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides2
=
torch
.
full
((
e
,
),
k
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
activation
=
lambda
o
,
i
:
torch
.
ops
.
_C
.
silu_and_mul
(
o
,
i
)
activation
=
lambda
o
,
i
:
torch
.
ops
.
_C
.
silu_and_mul
(
o
,
i
)
a1q
,
a1q_scale
=
moe_kernel_quantize_input
(
mt
.
a
,
mt
.
a_scale
,
a1q
,
a1q_scale
=
moe_kernel_quantize_input
(
mt
.
a
,
mt
.
a_scale
,
torch
.
float8_e4m3fn
,
torch
.
float8_e4m3fn
,
...
@@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8(
...
@@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8(
func
=
lambda
output
:
run_cutlass_moe_fp8
(
func
=
lambda
output
:
run_cutlass_moe_fp8
(
output
,
a1q
,
mt
.
w1_q
,
mt
.
w2_q
,
topk_ids
,
activation
,
output
,
a1q
,
mt
.
w1_q
,
mt
.
w2_q
,
topk_ids
,
activation
,
global_num_experts
,
expert_map
,
mt
.
w1_scale
,
mt
.
w2_scale
,
global_num_experts
,
expert_map
,
mt
.
w1_scale
,
mt
.
w2_scale
,
a1q_scale
,
None
,
workspace13
,
workspace2
,
None
,
mt
.
a
.
dtype
,
a1q_scale
,
None
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
per_act_token
,
per_out_channel
,
False
)
workspace13
,
workspace2
,
None
,
mt
.
a
.
dtype
,
per_act_token
,
per_out_channel
,
False
,
topk_weights
)
workspace13
.
random_
()
workspace13
.
random_
()
output_random_workspace
=
torch
.
empty
(
output_shape
,
output_random_workspace
=
torch
.
empty
(
output_shape
,
...
...
tests/kernels/moe/untest_moe_permute_unpermute.py
View file @
a99300bd
...
@@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
...
@@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
atol
=
0
,
atol
=
0
,
rtol
=
0
)
rtol
=
0
)
# check mindice
# check mindice
torch
.
testing
.
assert_close
(
gold_m_indices
,
m_indices
,
atol
=
0
,
rtol
=
0
)
# current kernel usage assumes deepgemm requires align_block_size
# when it's not provided then we don't compute m_indices (for cutlass)
if
align_block_size
is
not
None
:
torch
.
testing
.
assert_close
(
gold_m_indices
,
m_indices
,
atol
=
0
,
rtol
=
0
)
# check permuted_hidden_states, only valid token
# check permuted_hidden_states, only valid token
torch
.
testing
.
assert_close
(
gold_permuted_hidden_states
[
valid_row_idx
],
torch
.
testing
.
assert_close
(
gold_permuted_hidden_states
[
valid_row_idx
],
permuted_hidden_states
[
valid_row_idx
],
permuted_hidden_states
[
valid_row_idx
],
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment