Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d2b52805
Commit
d2b52805
authored
Sep 07, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.2rc1' into v0.10.2rc1-ori
parents
9a521c23
5438967f
Changes
501
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1594 additions
and
171 deletions
+1594
-171
tests/evals/gsm8k/conftest.py
tests/evals/gsm8k/conftest.py
+66
-0
tests/evals/gsm8k/gsm8k_eval.py
tests/evals/gsm8k/gsm8k_eval.py
+252
-0
tests/evals/gsm8k/test_gsm8k_correctness.py
tests/evals/gsm8k/test_gsm8k_correctness.py
+90
-0
tests/kernels/attention/test_attention_selector.py
tests/kernels/attention/test_attention_selector.py
+3
-0
tests/kernels/attention/test_cache.py
tests/kernels/attention/test_cache.py
+93
-5
tests/kernels/attention/test_flashinfer.py
tests/kernels/attention/test_flashinfer.py
+2
-4
tests/kernels/attention/test_flashinfer_trtllm_attention.py
tests/kernels/attention/test_flashinfer_trtllm_attention.py
+233
-129
tests/kernels/attention/test_flashmla.py
tests/kernels/attention/test_flashmla.py
+51
-18
tests/kernels/moe/test_block_fp8.py
tests/kernels/moe/test_block_fp8.py
+2
-3
tests/kernels/moe/test_cutlass_moe.py
tests/kernels/moe/test_cutlass_moe.py
+14
-4
tests/kernels/moe/test_deepep_deepgemm_moe.py
tests/kernels/moe/test_deepep_deepgemm_moe.py
+6
-4
tests/kernels/moe/test_deepep_moe.py
tests/kernels/moe/test_deepep_moe.py
+3
-0
tests/kernels/moe/test_flashinfer.py
tests/kernels/moe/test_flashinfer.py
+248
-0
tests/kernels/moe/test_grouped_topk.py
tests/kernels/moe/test_grouped_topk.py
+76
-0
tests/kernels/moe/test_modular_kernel_combinations.py
tests/kernels/moe/test_modular_kernel_combinations.py
+2
-0
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+1
-1
tests/kernels/moe/test_moe_permute_unpermute.py
tests/kernels/moe/test_moe_permute_unpermute.py
+5
-1
tests/kernels/moe/test_mxfp4_moe.py
tests/kernels/moe/test_mxfp4_moe.py
+419
-1
tests/kernels/moe/test_pplx_cutlass_moe.py
tests/kernels/moe/test_pplx_cutlass_moe.py
+23
-1
tests/kernels/moe/test_pplx_moe.py
tests/kernels/moe/test_pplx_moe.py
+5
-0
No files found.
Too many changes to show.
To preserve performance only
501 of 501+
files are displayed.
Plain diff
Email patch
tests/evals/gsm8k/conftest.py
0 → 100644
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
pathlib
import
Path
def pytest_addoption(parser):
    """Register the command line options consumed by the GSM8K tests.

    Two options are added to ``parser``: ``--config-list-file`` (path of the
    file listing the configs to run) and ``--tp-size`` (integer tensor
    parallel size).
    """
    option_specs = (
        ("--config-list-file", {
            "default": "configs/models-small.txt",
            "help": "File containing list of config files to test",
        }),
        ("--tp-size", {
            "default": 1,
            "type": int,
            "help": "Tensor parallel size",
        }),
    )
    for flag, settings in option_specs:
        parser.addoption(flag, **settings)
def pytest_generate_tests(metafunc):
    """Generate test parameters from config files.

    Only tests that request the ``config_filename`` fixture are touched.
    The list file named by ``--config-list-file`` is read line by line;
    blank lines and lines starting with ``#`` are ignored, and every other
    line is taken as a config file name relative to the list file's own
    directory. Each existing config file yields one
    ``(config_filename, tp_size)`` parameter set.

    A relative list-file path is resolved against this file's directory
    first and against the current working directory otherwise.

    When no config file is found the test is parametrized with an empty
    parameter set, which pytest reports as skipped (instead of failing on a
    missing ``config_filename`` fixture).
    """
    if "config_filename" not in metafunc.fixturenames:
        return

    config_list_file = metafunc.config.getoption("--config-list-file")
    tp_size = int(metafunc.config.getoption("--tp-size"))

    # Handle both relative and absolute paths
    config_list_path = Path(config_list_file)
    if not config_list_path.is_absolute():
        # If relative, try relative to test directory first
        test_dir_path = Path(__file__).parent / config_list_file
        if test_dir_path.exists():
            config_list_path = test_dir_path
        else:
            # Try relative to current working directory
            config_list_path = Path.cwd() / config_list_file

    print(f"Looking for config list at: {config_list_path}")

    config_files = []
    if config_list_path.exists():
        # Config files live in the same directory as the list file
        config_dir = config_list_path.parent
        with open(config_list_path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                config_path = config_dir / line
                print(f"Checking config file: {config_path}")
                if config_path.exists():
                    config_files.append(config_path)
                    print(f" ✓ Found: {config_path}")
                else:
                    print(f" ✗ Missing: {config_path}")
    else:
        print(f"Config list file not found: {config_list_path}")

    argnames = ["config_filename", "tp_size"]
    if config_files:
        metafunc.parametrize(
            argnames,
            [(config_file, tp_size) for config_file in config_files],
            ids=[
                f"{config_file.stem}-tp{tp_size}"
                for config_file in config_files
            ])
    else:
        print("No config files found, test will be skipped")
        # An empty parameter set makes pytest skip the test; without it the
        # test would error out because the fixtures cannot be resolved.
        metafunc.parametrize(argnames, [])
tests/evals/gsm8k/gsm8k_eval.py
0 → 100644
View file @
d2b52805
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Isolated GSM8K evaluation script for vLLM serve endpoint.
"""
import
argparse
import
ast
import
asyncio
import
json
import
os
import
time
from
collections.abc
import
Generator
from
typing
import
Optional
,
Union
import
aiohttp
import
numpy
as
np
import
regex
as
re
import
requests
from
tqdm.asyncio
import
tqdm
# Sentinel returned by get_answer_value() when no number can be extracted
# from a string; evaluate_gsm8k() counts predictions equal to it as invalid.
INVALID = -9999999
def download_and_cache_file(url: str, filename: Optional[str] = None) -> str:
    """Download and cache a file from a URL.

    Args:
        url: Address of the file to fetch.
        filename: Destination path. Defaults to ``/tmp/<last URL segment>``.

    Returns:
        The path of the cached file. If the file already exists it is
        returned as-is and nothing is downloaded.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    if filename is None:
        filename = os.path.join("/tmp", url.split("/")[-1])

    # Check if the cache file already exists
    if os.path.exists(filename):
        return filename

    print(f"Downloading from {url} to {filename}")

    # Stream into a sibling temporary file and rename it only on success, so
    # an interrupted download can never leave a truncated file that later
    # calls would mistake for a valid cache entry.
    partial = f"{filename}.{os.getpid()}.part"
    try:
        # The timeout keeps a stalled server from hanging the test run.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(partial, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024):
                    f.write(chunk)
        os.replace(partial, filename)
    finally:
        # Only still present if the download or the rename failed.
        if os.path.exists(partial):
            os.remove(partial)

    return filename
def load_gsm8k_data() -> tuple[list[dict], list[dict]]:
    """Return the GSM8K ``(train, test)`` splits as lists of records.

    Both JSONL files are obtained through ``download_and_cache_file`` and
    parsed with ``read_jsonl``.
    """
    split_urls = (
        "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl",
        "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl",
    )

    # Fetch both files before parsing either of them.
    cached_paths = [download_and_cache_file(split_url) for split_url in split_urls]
    train_data, test_data = (list(read_jsonl(path)) for path in cached_paths)
    return train_data, test_data
def read_jsonl(filename: str) -> Generator[dict, None, None]:
    """Read a JSONL file, yielding one decoded object per line.

    Blank lines and lines whose first non-blank character is ``#`` are
    skipped, so trailing newlines or comment headers do not abort parsing.

    Args:
        filename: Path of the UTF-8 encoded JSONL file.

    Yields:
        The JSON value decoded from each remaining line.

    Raises:
        json.JSONDecodeError: If a non-skipped line is not valid JSON.
    """
    with open(filename, encoding="utf-8") as fin:
        for line in fin:
            record = line.strip()
            if not record or record.startswith("#"):
                continue
            yield json.loads(record)
def get_answer_value(answer_str: str) -> int:
    """Return the last run of digits in ``answer_str`` as an integer.

    Commas are removed first so that thousands separators do not split a
    number. ``INVALID`` is returned when the text holds no digits or when
    the last digit run is not a valid Python integer literal.
    """
    digit_runs = re.findall(r"\d+", answer_str.replace(",", ""))
    if not digit_runs:
        return INVALID
    try:
        value = ast.literal_eval(digit_runs[-1])
    except SyntaxError:
        return INVALID
    return value
async def call_vllm_api(session: aiohttp.ClientSession,
                        prompt: str,
                        temperature: float,
                        max_tokens: int,
                        stop: Optional[list[str]] = None,
                        url: Optional[str] = None,
                        seed: Optional[int] = None) -> str:
    """Call vLLM's OpenAI-compatible completions endpoint.

    Posts the sampling parameters to ``{url}/v1/completions`` and returns
    the text of the first choice. ``seed`` is only sent when it is not
    ``None``. Any exception is reported on stdout and turned into an empty
    string, so a failed request never aborts the caller.
    """
    payload = dict(
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        stop=stop,
    )
    if seed is not None:
        payload["seed"] = seed

    try:
        endpoint = f"{url}/v1/completions"
        async with session.post(endpoint, json=payload) as response:
            response.raise_for_status()
            body = await response.json()
            return body["choices"][0]["text"]
    except Exception as e:
        print(f"Error calling vLLM API: {e}")
        return ""
def evaluate_gsm8k(num_questions: int = 1319,
                   num_shots: int = 5,
                   max_tokens: int = 256,
                   host: str = "http://127.0.0.1",
                   port: int = 8000,
                   temperature: float = 0.0,
                   seed: Optional[int] = 42) -> dict[str, Union[float, int]]:
    """
    Evaluate GSM8K accuracy using vLLM serve endpoint.

    The first ``num_shots`` training records form the few-shot prefix and
    the first ``num_questions`` test records (capped at the size of the test
    split) are sent concurrently to ``{host}:{port}``.

    Returns dict with accuracy, invalid_rate, latency, etc.
    """
    base_url = f"{host}:{port}"

    train_data, test_data = load_gsm8k_data()

    # Never ask for more questions than the test split holds.
    num_questions = min(num_questions, len(test_data))

    # Few-shot prefix built from the train split (like lm-eval does).
    few_shot_examples = "".join(
        f"Question: {train_data[shot]['question']}\n"
        f"Answer: {train_data[shot]['answer']}\n\n"
        for shot in range(num_shots))

    # Prompts and reference answers from the test split.
    selected = [test_data[idx] for idx in range(num_questions)]
    questions = [
        f"Question: {record['question']}\nAnswer:" for record in selected
    ]
    labels = [get_answer_value(record["answer"]) for record in selected]
    assert all(label != INVALID for label in labels), "Some labels are invalid"

    async def run_async_evaluation():
        # One slot per question, filled in as the responses arrive.
        states: list[str] = [""] * num_questions

        async def get_answer(session: aiohttp.ClientSession, i: int) -> str:
            states[i] = await call_vllm_api(
                session=session,
                prompt=few_shot_examples + questions[i],
                temperature=temperature,
                max_tokens=max_tokens,
                stop=["Question", "Assistant:", "<|separator|>"],
                url=base_url,
                seed=seed,
            )
            return states[i]

        client_timeout = aiohttp.ClientTimeout(total=600)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            pending = [get_answer(session, i) for i in range(num_questions)]
            await tqdm.gather(*pending, desc="Evaluating")

        return states

    print(f"Running GSM8K evaluation: {num_questions} questions, "
          f"{num_shots}-shot")

    started = time.perf_counter()
    states = asyncio.run(run_async_evaluation())
    latency = time.perf_counter() - started

    # Metrics: share of exact matches and share of unparseable answers.
    preds = np.array([get_answer_value(state) for state in states])
    accuracy = np.mean(preds == np.array(labels))
    invalid_rate = np.mean(preds == INVALID)

    return {
        "accuracy": accuracy,
        "invalid_rate": invalid_rate,
        "latency": latency,
        "questions_per_second": num_questions / latency,
        "num_questions": num_questions,
        "num_shots": num_shots,
        "max_tokens": max_tokens,
        "timestamp": time.time(),
    }
def main() -> None:
    """Parse the command line, run the evaluation and report the results.

    The metrics are printed to stdout and, when ``--save-results`` is given,
    also written to that path as JSON.
    """
    parser = argparse.ArgumentParser(
        description="GSM8K evaluation for vLLM serve")
    cli_arguments = (
        ("--num-shots",
         dict(type=int, default=5, help="Number of few-shot examples")),
        ("--num-questions",
         dict(type=int, default=1319, help="Number of questions to evaluate")),
        ("--max-tokens",
         dict(type=int, default=256, help="Max tokens for generation")),
        ("--host", dict(type=str, default="http://127.0.0.1",
                        help="Host URL")),
        ("--port", dict(type=int, default=8000, help="Port number")),
        ("--temperature",
         dict(type=float, default=0.0, help="Temperature for generation")),
        ("--seed",
         dict(type=int, default=42, help="Random seed for reproducibility")),
        ("--save-results", dict(type=str, help="Save results to JSON file")),
    )
    for flag, settings in cli_arguments:
        parser.add_argument(flag, **settings)
    args = parser.parse_args()

    result = evaluate_gsm8k(
        num_questions=args.num_questions,
        num_shots=args.num_shots,
        max_tokens=args.max_tokens,
        host=args.host,
        port=args.port,
        temperature=args.temperature,
        seed=args.seed,
    )

    # Terminal report: (label, result key, unit suffix).
    print("\nResults:")
    report_rows = (
        ("Accuracy", "accuracy", ""),
        ("Invalid responses", "invalid_rate", ""),
        ("Total latency", "latency", " s"),
        ("Questions per second", "questions_per_second", ""),
    )
    for label, key, suffix in report_rows:
        print(f"{label}: {result[key]:.3f}{suffix}")

    # Optional file saving
    output_path = args.save_results
    if output_path:
        with open(output_path, "w") as f:
            json.dump(result, f, indent=2)
        print(f"Results saved to {output_path}")


if __name__ == "__main__":
    main()
tests/evals/gsm8k/test_gsm8k_correctness.py
0 → 100644
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GSM8K evaluation using vLLM server and isolated GSM8K script.
Replacement for lm-eval-harness with better performance and control.
Usage:
pytest -s -v test_gsm8k_correctness.py
\
--config-list-file=configs/models-small.txt
\
--tp-size=1
"""
import
yaml
from
tests.utils
import
RemoteOpenAIServer
from
.gsm8k_eval
import
evaluate_gsm8k
# Tolerance for the accuracy comparison. NOTE: despite the name, the check in
# test_gsm8k_correctness_param subtracts this value from the expected
# accuracy, i.e. it acts as an absolute margin rather than a relative one.
RTOL = 0.08
def launch_gsm8k_eval(eval_config, server_url, tp_size):
    """Launch GSM8K evaluation using our isolated script.

    Args:
        eval_config: Mapping providing ``num_questions`` and ``num_fewshot``.
        server_url: Address of the running server. A scheme and a path are
            optional (``host:port``, ``http://host:port/v1`` ...); the port
            defaults to 8000 when absent.
        tp_size: Not used by this function; kept so callers need not change.

    Returns:
        The result dictionary produced by ``evaluate_gsm8k``.

    Raises:
        ValueError: If no host name or no valid port can be extracted from
            ``server_url``.
    """
    from urllib.parse import urlsplit

    # urlsplit only recognises the network location when a scheme is
    # present, so bare "host:port" inputs get a default one.
    if "://" not in server_url:
        server_url = f"http://{server_url}"
    parts = urlsplit(server_url)

    hostname = parts.hostname
    if not hostname:
        raise ValueError(f"Cannot extract host from server URL: {server_url}")
    if ":" in hostname:
        # IPv6 literals must keep their brackets inside a URL.
        hostname = f"[{hostname}]"

    port = parts.port if parts.port is not None else 8000

    # Keep https when the caller asked for it; anything else maps to http.
    scheme = parts.scheme if parts.scheme in ("http", "https") else "http"
    host = f"{scheme}://{hostname}"

    # Run GSM8K evaluation
    results = evaluate_gsm8k(
        num_questions=eval_config["num_questions"],
        num_shots=eval_config["num_fewshot"],
        host=host,
        port=port,
    )

    return results
def test_gsm8k_correctness_param(config_filename, tp_size):
    """Check GSM8K accuracy of one model configuration against its threshold.

    The YAML config is loaded, a server is started for ``model_name`` with
    the requested tensor parallel size, the evaluation is run against it and
    the measured accuracy must not fall more than ``RTOL`` below
    ``accuracy_threshold``.
    """
    eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))

    max_model_len = eval_config.get("max_model_len", 4096)
    server_args = [
        "--max-model-len",
        str(max_model_len),
        "--enforce-eager",
        "--trust-remote-code",
        "--tensor-parallel-size",
        str(tp_size),
    ]

    # The evaluation runs while the server context is open.
    with RemoteOpenAIServer(eval_config["model_name"],
                            server_args,
                            max_wait_seconds=480) as remote_server:
        results = launch_gsm8k_eval(eval_config, remote_server.url_for("v1"),
                                    tp_size)

    measured_accuracy = results["accuracy"]
    expected_accuracy = eval_config["accuracy_threshold"]

    report = [
        f"GSM8K Results for {eval_config['model_name']}:",
        f" Accuracy: {measured_accuracy:.3f}",
        f" Expected: {expected_accuracy:.3f}",
        f" Questions: {results['num_questions']}",
        f" Invalid rate: {results['invalid_rate']:.3f}",
        f" Latency: {results['latency']:.1f}s",
        f" QPS: {results['questions_per_second']:.1f}",
    ]
    for report_line in report:
        print(report_line)

    # Lowest accuracy that still counts as a pass.
    accuracy_floor = expected_accuracy - RTOL
    assert measured_accuracy >= accuracy_floor, (
        f"Accuracy too low: {measured_accuracy:.3f} < "
        f"{expected_accuracy:.3f} - {RTOL:.3f}")

    print(f"✅ GSM8K test passed for {eval_config['model_name']}")
tests/kernels/attention/test_attention_selector.py
View file @
d2b52805
...
...
@@ -80,6 +80,9 @@ def test_env(
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
name
)
m
.
setenv
(
"VLLM_MLA_DISABLE"
,
"1"
if
use_mla
else
"0"
)
if
name
==
"FLASHINFER"
and
not
use_v1
:
pytest
.
skip
(
"FlashInfer backend is only available on V1 engine"
)
if
device
==
"cpu"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
CpuPlatform
()):
...
...
tests/kernels/attention/test_cache.py
View file @
d2b52805
...
...
@@ -702,6 +702,94 @@ def test_swap_blocks_mla(
f
"
{
dst
}
in dst_cache."
)
@pytest.mark.parametrize("kv_lora_rank", [512])
@pytest.mark.parametrize("qk_rope_head_dim", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
                                            block_size, num_blocks,
                                            max_seq_len, batch_size, dtype,
                                            kv_cache_dtype, device):
    """Compare ops.gather_and_maybe_dequant_cache with a Python reference.

    A cache is built and filled, random sequence lengths and a random block
    table are drawn, the expected gathered rows are assembled block by block
    in Python (going through ops.convert_fp8 for the "fp8" cache dtype), and
    the op's output must match that reference.
    """
    # Width of one cache entry.
    entry_size = kv_lora_rank + qk_rope_head_dim
    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
    src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                  kv_cache_dtype, device)
    _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

    # Per-sequence lengths in [0, max_seq_len]; zero-length sequences occur.
    seq_len_tensor = torch.randint(0,
                                   max_seq_len + 1, (batch_size, ),
                                   device=device)

    total_tokens = seq_len_tensor.sum()
    # Prefix sums of the sequence lengths with a leading 0.
    cu_seq_lens = torch.empty((batch_size + 1),
                              dtype=torch.int32,
                              device=device)
    cu_seq_lens[0] = 0
    cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
    print("seq_len_tensor", seq_len_tensor)

    # Number of blocks each sequence occupies (ceiling division).
    tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
    block_table = torch.empty((batch_size, num_blocks),
                              dtype=torch.int32,
                              device=device)

    # Every row of the block table is a random permutation of the block ids.
    for b in range(batch_size):
        perm = torch.randperm(num_blocks, device=device)
        block_table[b, :] = perm

    dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)

    # Build the reference output one sequence at a time.
    expected_batches = []
    for b in range(batch_size):
        s = seq_len_tensor[b]
        if s == 0:
            continue
        tot = tot_blocks_tensor[b]
        blocks = block_table[b, :tot].tolist()

        gathered_rows = []
        # All blocks but the last are taken whole.
        for i in range(tot - 1):
            block_data = src_cache[blocks[i]]
            if kv_cache_dtype == "fp8":
                dequantized_block = torch.empty_like(block_data, dtype=dtype)
                ops.convert_fp8(dequantized_block, block_data, scale.item())
                gathered_rows.append(dequantized_block)
            else:
                gathered_rows.append(block_data)
        # The last block contributes only the rows that are still needed.
        remaining = s - (tot - 1) * block_size
        last_block_data = src_cache[blocks[-1], :remaining, :]
        if kv_cache_dtype == "fp8":
            dequantized_last_block = torch.empty_like(last_block_data,
                                                      dtype=dtype)
            ops.convert_fp8(dequantized_last_block, last_block_data,
                            scale.item())
            gathered_rows.append(dequantized_last_block)
        else:
            gathered_rows.append(last_block_data)

        batch_expected = torch.cat(gathered_rows, dim=0)
        expected_batches.append(batch_expected)
    expected = torch.cat(expected_batches, dim=0)

    # Run opcheck on the registered op first, then call the wrapper itself.
    opcheck(
        torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
        (src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
         scale, None),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
                                       cu_seq_lens, batch_size, kv_cache_dtype,
                                       scale, None)
    torch.testing.assert_close(dst, expected)
@
pytest
.
mark
.
parametrize
(
"kv_lora_rank"
,
[
512
])
@
pytest
.
mark
.
parametrize
(
"qk_rope_head_dim"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
...
...
@@ -713,9 +801,9 @@ def test_swap_blocks_mla(
[
"auto"
])
# You can also test "fp8" if needed.
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_gather_cache_mla
(
kv_lora_rank
,
qk_rope_head_dim
,
block_size
,
num_blocks
,
max_seq_len
,
batch_size
,
dtype
,
kv_cache_dtype
,
device
):
def
test_
cp_
gather_cache_mla
(
kv_lora_rank
,
qk_rope_head_dim
,
block_size
,
num_blocks
,
max_seq_len
,
batch_size
,
dtype
,
kv_cache_dtype
,
device
):
entry_size
=
kv_lora_rank
+
qk_rope_head_dim
src_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
)
...
...
@@ -765,12 +853,12 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
expected
=
torch
.
cat
(
expected_batches
,
dim
=
0
)
opcheck
(
torch
.
ops
.
_C_cache_ops
.
gather_cache
,
torch
.
ops
.
_C_cache_ops
.
cp_
gather_cache
,
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
,
None
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
)
ops
.
gather_cache
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
)
ops
.
cp_
gather_cache
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
)
torch
.
testing
.
assert_close
(
dst
,
expected
)
...
...
tests/kernels/attention/test_flashinfer.py
View file @
d2b52805
...
...
@@ -137,9 +137,7 @@ def test_flashinfer_decode_with_paged_kv(
workspace_buffer
=
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
wrapper
=
flashinfer
.
\
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffer
,
"NHD"
,
use_tensor_cores
=
(
(
num_query_heads
//
num_kv_heads
)
>
4
)
)
use_tensor_cores
=
True
)
wrapper
.
plan
(
kv_indptr
,
kv_indices
,
...
...
@@ -411,7 +409,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
assert
num_query_heads
%
num_kv_heads
==
0
max_kv_len
=
max
(
kv_lens
)
scale
=
head_size
**-
0.5
use_tensor_cores
=
(
num_query_heads
//
num_kv_heads
)
>
4
use_tensor_cores
=
True
kv_cache_dtype
=
torch
.
float8_e4m3fn
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
...
...
tests/kernels/attention/test_flashinfer_trtllm_attention.py
View file @
d2b52805
...
...
@@ -6,28 +6,19 @@ import flashinfer
import
pytest
import
torch
from
tests.kernels.quantization.nvfp4_utils
import
(
FLOAT4_E2M1_MAX
,
FLOAT8_E4M3_MAX
,
dequantize_nvfp4_to_dtype
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
round_up
if
not
current_platform
.
is_device_capability
(
100
):
pytest
.
skip
(
"This TRTLLM kernel requires NVIDIA Blackwell."
,
allow_module_level
=
True
)
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
MAX_Q_LEN
=
1024
MAX_KV_LEN
=
4096
BATCH_SIZES
=
[
4
,
12
]
NUM_HEADS
=
[(
16
,
16
),
(
40
,
8
)]
HEAD_SIZES
=
[
128
]
BLOCK_SIZES
=
[
16
]
KV_LAYOUTS
=
[
"HND"
]
DTYPES
=
[
torch
.
bfloat16
]
KV_CACHE_DTYPES
=
[
None
,
current_platform
.
fp8_dtype
()]
NUM_BLOCKS
=
32768
# Large enough to test overflow in index calculation.
SOFT_CAPS
=
[
None
,
50.0
]
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP4_DTYPE
=
torch
.
uint8
def
to_float8
(
x
,
dtype
=
torch
.
float8_e4m3fn
):
...
...
@@ -39,42 +30,61 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
return
x_scl_sat
.
to
(
dtype
),
scale
.
float
().
reciprocal
()
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
DTYPE
=
[
torch
.
bfloat16
]
QUANT_DTYPES
=
[
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(
None
,
None
,
None
),
(
None
,
FP8_DTYPE
,
None
),
(
FP8_DTYPE
,
FP8_DTYPE
,
FP8_DTYPE
),
(
FP8_DTYPE
,
FP8_DTYPE
,
FP4_DTYPE
),
]
BATCH_SIZE
=
[
4
,
12
]
MAX_SEQ_LENS
=
[(
1024
,
4096
)]
NUM_HEADS
=
[(
64
,
8
),
(
40
,
8
)]
HEAD_SIZE
=
[
128
]
KV_LAYOUT
=
[
"HND"
]
# currently only HND is supported
BLOCK_SIZE
=
[
16
]
SOFT_CAP
=
[
None
,
50.0
]
NUM_BLOCKS
=
32768
# Large enough to test overflow in index calculation.
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPE
)
@
pytest
.
mark
.
parametrize
(
"quant_dtypes"
,
QUANT_DTYPES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZE
)
@
pytest
.
mark
.
parametrize
(
"max_seq_lens"
,
MAX_SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"kv_layout"
,
KV_LAYOUTS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
SOFT_CAPS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZE
)
@
pytest
.
mark
.
parametrize
(
"kv_layout"
,
KV_LAYOUT
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZE
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
SOFT_CAP
)
@
torch
.
inference_mode
def
test_flashinfer_trtllm_decode_with_baseline
(
dtype
:
torch
.
dtype
,
quant_dtypes
:
tuple
[
Optional
[
torch
.
dtype
],
Optional
[
torch
.
dtype
],
Optional
[
torch
.
dtype
]],
batch_size
:
int
,
max_seq_lens
:
tuple
[
int
,
int
],
num_heads
:
tuple
[
int
,
int
],
head_size
:
int
,
block_size
:
int
,
kv_layout
:
str
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
torch
.
dtype
],
block_size
:
int
,
soft_cap
:
Optional
[
float
],
)
->
None
:
kv_cache_dtype
=
dtype
if
kv_cache_dtype
is
None
else
kv_cache_dtype
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
kv_lens
=
torch
.
randint
(
1
,
MAX_KV_LEN
,
(
batch_size
,
),
dtype
=
torch
.
int32
)
kv_lens
[
-
1
]
=
MAX_KV_LEN
max_kv_len
=
torch
.
max
(
kv_lens
).
item
()
num_seqs
=
len
(
kv_lens
)
q_quant_dtype
,
kv_quant_dtype
,
o_quant_dtype
=
quant_dtypes
q_quant_dtype
=
q_quant_dtype
or
dtype
kv_quant_dtype
=
kv_quant_dtype
or
dtype
o_quant_dtype
=
o_quant_dtype
or
dtype
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
assert
num_query_heads
%
num_kv_heads
==
0
_
,
max_kv_len
=
max_seq_lens
scale
=
head_size
**-
0.5
num_qo_heads
,
num_kv_heads
=
num_heads
assert
num_qo_heads
%
num_kv_heads
==
0
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
sm_scale
=
float
(
1.0
/
(
head_size
**
0.5
)
)
kv_cache_shape
=
None
if
kv_layout
==
"NHD"
:
...
...
@@ -83,23 +93,40 @@ def test_flashinfer_trtllm_decode_with_baseline(
kv_cache_shape
=
(
NUM_BLOCKS
,
2
,
num_kv_heads
,
block_size
,
head_size
)
else
:
raise
ValueError
(
f
"Invalid kv_layout:
{
kv_layout
}
"
)
key_value_cache
=
torch
.
randn
(
kv_cache_shape
,
dtype
=
dtype
)
kv_scale
=
1.0
if
kv_cache_dtype
is
current_platform
.
fp8_dtype
():
key_value_cache
,
kv_scale
=
to_float8
(
key_value_cache
,
current_platform
.
fp8_dtype
())
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
query
=
torch
.
randn
(
batch_size
,
num_qo_heads
,
head_size
,
dtype
=
dtype
)
if
q_quant_dtype
==
FP8_DTYPE
:
query
,
q_scale
=
to_float8
(
query
)
ref_query
=
query
.
to
(
dtype
)
*
q_scale
else
:
q_scale
=
1.0
ref_query
=
query
kv_lens
=
torch
.
randint
(
1
,
max_kv_len
,
(
batch_size
,
),
dtype
=
torch
.
int32
)
kv_lens
[
-
1
]
=
max_kv_len
seq_lens
=
kv_lens
max_seq_len
=
torch
.
max
(
seq_lens
).
item
()
kv_cache
=
torch
.
randn
(
kv_cache_shape
,
dtype
=
dtype
)
if
kv_quant_dtype
==
FP8_DTYPE
:
kv_cache
,
kv_scale
=
to_float8
(
kv_cache
)
ref_kv_cache
=
kv_cache
.
to
(
dtype
)
*
kv_scale
else
:
kv_scale
=
1.0
ref_kv_cache
=
kv_cache
k_scale
=
v_scale
=
kv_scale
max_num_blocks_per_seq
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
NUM_BLOCKS
,
(
num_seqs
,
max_num_blocks_per_seq
),
(
batch_size
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
k_scale
=
v_scale
=
kv_scale
kv_indptr
=
[
0
]
kv_indices
=
[]
kv_last_page_lens
=
[]
for
i
in
range
(
num_seqs
):
seq_len
=
kv
_lens
[
i
]
for
i
in
range
(
batch_size
):
seq_len
=
seq
_lens
[
i
]
assert
seq_len
>
0
num_blocks
=
(
seq_len
+
block_size
-
1
)
//
block_size
kv_indices
.
extend
(
block_tables
[
i
,
:
num_blocks
])
...
...
@@ -112,103 +139,120 @@ def test_flashinfer_trtllm_decode_with_baseline(
kv_indptr
=
torch
.
tensor
(
kv_indptr
,
dtype
=
torch
.
int32
)
kv_indices
=
torch
.
tensor
(
kv_indices
,
dtype
=
torch
.
int32
)
kv_last_page_lens
=
torch
.
tensor
(
kv_last_page_lens
,
dtype
=
torch
.
int32
)
workspace_buffer
=
torch
.
zeros
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
# Baseline Decode
wrapper
=
flashinfer
.
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffer
,
kv_layout
,
use_tensor_cores
=
((
num_query_heads
//
num_kv_heads
)
>
4
))
workspace_buffer
,
kv_layout
,
use_tensor_cores
=
True
)
wrapper
.
plan
(
kv_indptr
,
kv_indices
,
kv_last_page_lens
,
num_q
uery
_heads
,
num_q
o
_heads
,
num_kv_heads
,
head_size
,
block_size
,
"NONE"
,
sm_scale
=
scale
,
sm_scale
=
sm_
scale
,
q_data_type
=
dtype
,
kv_data_type
=
kv_cache_
dtype
,
kv_data_type
=
dtype
,
logits_soft_cap
=
soft_cap
)
output
=
torch
.
empty
(
query
.
shape
,
dtype
=
dtype
)
wrapper
.
run
(
query
,
key_value_cache
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
out
=
output
)
output
=
torch
.
empty
(
ref_query
.
shape
,
dtype
=
dtype
)
wrapper
.
run
(
ref_query
,
ref_kv_cache
,
out
=
output
)
o_scale
=
1.0
o_sf_scale
=
None
if
o_quant_dtype
==
FP8_DTYPE
:
_
,
o_scale
=
to_float8
(
output
)
elif
o_quant_dtype
==
FP4_DTYPE
:
o_sf_scale
=
((
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
amax
(
output
.
flatten
(),
dim
=-
1
)).
to
(
torch
.
float32
)
# TRTLLM Decode
kv_lens_tensor
=
torch
.
tensor
(
kv_lens
,
dtype
=
torch
.
int32
)
output_trtllm
=
torch
.
empty
(
query
.
shape
,
dtype
=
dtype
)
if
o_quant_dtype
==
FP4_DTYPE
:
output_trtllm
=
flashinfer
.
utils
.
FP4Tensor
(
torch
.
empty
(
query
.
shape
[:
-
1
]
+
(
query
.
shape
[
-
1
]
//
2
,
),
dtype
=
torch
.
uint8
),
torch
.
empty
((
round_up
(
query
.
shape
[
0
],
128
),
round_up
(
query
.
shape
[
1
]
*
query
.
shape
[
2
]
//
16
,
4
)),
dtype
=
torch
.
float8_e4m3fn
),
)
else
:
output_trtllm
=
torch
.
empty
(
query
.
shape
,
dtype
=
o_quant_dtype
)
flashinfer
.
decode
.
trtllm_batch_decode_with_kv_cache
(
query
=
query
.
contiguous
()
,
kv_cache
=
k
ey_value
_cache
,
query
=
query
,
kv_cache
=
k
v
_cache
,
workspace_buffer
=
workspace_buffer
,
block_tables
=
block_tables
,
seq_lens
=
kv_lens_tensor
,
max_seq_len
=
max_kv_len
,
bmm1_scale
=
k_scale
*
scale
,
bmm2_scale
=
v_scale
,
seq_lens
=
seq_lens
,
max_seq_len
=
max_seq_len
,
bmm1_scale
=
q_scale
*
k_scale
*
sm_scale
,
bmm2_scale
=
v_scale
/
o_scale
,
o_sf_scale
=
o_sf_scale
,
out
=
output_trtllm
,
)
if
o_quant_dtype
==
FP8_DTYPE
:
output_trtllm
=
output_trtllm
.
to
(
dtype
)
*
o_scale
elif
o_quant_dtype
==
FP4_DTYPE
:
output_trtllm
.
data
=
output_trtllm
.
data
.
reshape
(
-
1
,
query
.
shape
[
1
]
*
query
.
shape
[
2
]
//
2
)
output_trtllm
=
dequantize_nvfp4_to_dtype
(
output_trtllm
.
data
,
output_trtllm
.
scale
,
o_sf_scale
,
dtype
,
query
.
device
)
output_trtllm
=
output_trtllm
.
reshape
(
-
1
,
query
.
shape
[
1
],
query
.
shape
[
2
])
if
q_quant_dtype
==
FP8_DTYPE
and
o_quant_dtype
==
FP4_DTYPE
:
rtol
,
atol
=
3e-1
,
1e0
elif
q_quant_dtype
==
FP8_DTYPE
and
o_quant_dtype
==
FP8_DTYPE
:
rtol
,
atol
=
5e-2
,
7e-2
else
:
rtol
,
atol
=
1e-2
,
2e-2
torch
.
testing
.
assert_close
(
output
,
output_trtllm
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
torch
.
testing
.
assert_close
(
output
,
output_trtllm
,
atol
=
atol
,
rtol
=
rtol
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
output_trtllm
))
}
"
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPE
)
@
pytest
.
mark
.
parametrize
(
"quant_dtypes"
,
QUANT_DTYPES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZE
)
@
pytest
.
mark
.
parametrize
(
"max_seq_lens"
,
MAX_SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"kv_layout"
,
KV_LAYOUTS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
KV_CACHE_DTYPES
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZE
)
@
pytest
.
mark
.
parametrize
(
"kv_layout"
,
KV_LAYOUT
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZE
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
])
@
torch
.
inference_mode
def
test_flashinfer_trtllm_prefill_with_baseline
(
dtype
:
torch
.
dtype
,
quant_dtypes
:
tuple
[
Optional
[
torch
.
dtype
],
Optional
[
torch
.
dtype
],
Optional
[
torch
.
dtype
]],
batch_size
:
int
,
max_seq_lens
:
tuple
[
int
,
int
],
num_heads
:
tuple
[
int
,
int
],
head_size
:
int
,
block_size
:
int
,
kv_layout
:
str
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
torch
.
dtype
],
block_size
:
int
,
soft_cap
:
Optional
[
float
],
)
->
None
:
kv_cache_dtype
=
dtype
if
kv_cache_dtype
is
None
else
kv_cache_dtype
if
dtype
!=
kv_cache_dtype
:
pytest
.
skip
(
f
"Not supported dtype(
{
dtype
}
) with "
"kv_cache_dtype({kv_cache_dtype})"
)
torch
.
set_default_device
(
"cuda"
)
current_platform
.
seed_everything
(
0
)
q_lens
=
torch
.
randint
(
1
,
MAX_Q_LEN
,
(
batch_size
,
),
dtype
=
torch
.
int32
)
q_lens
[
-
1
]
=
MAX_Q_LEN
max_q_len
=
torch
.
max
(
q_lens
).
item
()
q_indptr
=
torch
.
cat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
torch
.
cumsum
(
q_lens
,
dim
=
0
,
dtype
=
torch
.
int32
),
])
q_quant_dtype
,
kv_quant_dtype
,
o_quant_dtype
=
quant_dtypes
q_quant_dtype
=
q_quant_dtype
or
dtype
kv_quant_dtype
=
kv_quant_dtype
or
dtype
o_quant_dtype
=
o_quant_dtype
or
dtype
kv_lens
=
torch
.
randint
(
0
,
MAX_KV_LEN
,
(
batch_size
,
),
dtype
=
torch
.
int32
)
kv_lens
[
-
1
]
=
MAX_KV_LEN
if
q_quant_dtype
!=
kv_quant_dtype
:
pytest
.
skip
(
"Skipped mixed QKV dtypes for prefill"
)
seq_lens
=
kv_lens
+
q_lens
max_seq_len
=
torch
.
max
(
seq_lens
).
item
()
num_seqs
=
len
(
seq_lens
)
max_q_len
,
max_kv_len
=
max_seq_lens
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
assert
num_query_heads
%
num_kv_heads
==
0
num_qo_heads
,
num_kv_heads
=
num_heads
assert
num_qo_heads
%
num_kv_heads
==
0
scale
=
head_size
**-
0.5
query
=
torch
.
randn
(
torch
.
sum
(
q_lens
).
item
(),
num_query_heads
,
head_size
,
dtype
=
dtype
)
sm_scale
=
float
(
1.0
/
(
head_size
**
0.5
))
kv_cache_shape
=
None
if
kv_layout
==
"NHD"
:
...
...
@@ -217,22 +261,49 @@ def test_flashinfer_trtllm_prefill_with_baseline(
kv_cache_shape
=
(
NUM_BLOCKS
,
2
,
num_kv_heads
,
block_size
,
head_size
)
else
:
raise
ValueError
(
f
"Invalid kv_layout:
{
kv_layout
}
"
)
key_value_cache
=
torch
.
randn
(
kv_cache_shape
,
dtype
=
dtype
)
kv_scale
=
1.0
if
kv_cache_dtype
is
current_platform
.
fp8_dtype
():
key_value_cache
,
kv_scale
=
to_float8
(
key_value_cache
,
current_platform
.
fp8_dtype
())
q_lens
=
torch
.
randint
(
1
,
max_q_len
,
(
batch_size
,
),
dtype
=
torch
.
int32
)
q_lens
[
-
1
]
=
max_q_len
q_indptr
=
torch
.
cat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
torch
.
cumsum
(
q_lens
,
dim
=
0
,
dtype
=
torch
.
int32
),
])
query
=
torch
.
randn
(
torch
.
sum
(
q_lens
).
item
(),
num_qo_heads
,
head_size
,
dtype
=
dtype
)
if
q_quant_dtype
==
FP8_DTYPE
:
query
,
q_scale
=
to_float8
(
query
)
ref_query
=
query
.
to
(
dtype
)
*
q_scale
else
:
q_scale
=
1.0
ref_query
=
query
kv_lens
=
torch
.
randint
(
0
,
max_kv_len
,
(
batch_size
,
),
dtype
=
torch
.
int32
)
kv_lens
[
-
1
]
=
max_kv_len
seq_lens
=
kv_lens
+
q_lens
max_seq_len
=
torch
.
max
(
seq_lens
).
item
()
kv_cache
=
torch
.
randn
(
kv_cache_shape
,
dtype
=
dtype
)
if
kv_quant_dtype
==
FP8_DTYPE
:
kv_cache
,
kv_scale
=
to_float8
(
kv_cache
)
ref_kv_cache
=
kv_cache
.
to
(
dtype
)
*
kv_scale
else
:
kv_scale
=
1.0
ref_kv_cache
=
kv_cache
k_scale
=
v_scale
=
kv_scale
max_num_blocks_per_seq
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
NUM_BLOCKS
,
(
num_seqs
,
max_num_blocks_per_seq
),
(
batch_size
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
k_scale
=
v_scale
=
kv_scale
kv_indptr
=
[
0
]
kv_indices
=
[]
kv_last_page_lens
=
[]
for
i
in
range
(
num_seqs
):
for
i
in
range
(
batch_size
):
seq_len
=
seq_lens
[
i
]
assert
seq_len
>
0
num_blocks
=
(
seq_len
+
block_size
-
1
)
//
block_size
...
...
@@ -246,48 +317,81 @@ def test_flashinfer_trtllm_prefill_with_baseline(
kv_indptr
=
torch
.
tensor
(
kv_indptr
,
dtype
=
torch
.
int32
)
kv_indices
=
torch
.
tensor
(
kv_indices
,
dtype
=
torch
.
int32
)
kv_last_page_lens
=
torch
.
tensor
(
kv_last_page_lens
,
dtype
=
torch
.
int32
)
workspace_buffer
=
torch
.
zeros
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
# Baseline Prefill
wrapper
=
flashinfer
.
BatchPrefillWithPagedKVCacheWrapper
(
workspace_buffer
,
kv_layout
)
wrapper
.
plan
(
q_indptr
,
kv_indptr
,
kv_indices
,
kv_last_page_lens
,
num_q
uery
_heads
,
num_q
o
_heads
,
num_kv_heads
,
head_size
,
block_size
,
causal
=
True
,
sm_scale
=
scale
,
sm_scale
=
sm_
scale
,
q_data_type
=
dtype
,
kv_data_type
=
kv_cache_
dtype
,
kv_data_type
=
dtype
,
logits_soft_cap
=
soft_cap
)
output
=
torch
.
empty
(
query
.
shape
,
dtype
=
dtype
)
wrapper
.
run
(
query
,
key_value_cache
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
out
=
output
)
output
=
torch
.
empty
(
ref_query
.
shape
,
dtype
=
dtype
)
wrapper
.
run
(
ref_query
,
ref_kv_cache
,
out
=
output
)
o_scale
=
1.0
o_sf_scale
=
None
if
o_quant_dtype
==
FP8_DTYPE
:
_
,
o_scale
=
to_float8
(
output
)
elif
o_quant_dtype
==
FP4_DTYPE
:
o_sf_scale
=
((
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
amax
(
output
.
flatten
(),
dim
=-
1
)).
to
(
torch
.
float32
)
# TRTLLM Prefill
if
o_quant_dtype
==
FP4_DTYPE
:
output_trtllm
=
flashinfer
.
utils
.
FP4Tensor
(
torch
.
empty
(
query
.
shape
[:
-
1
]
+
(
query
.
shape
[
-
1
]
//
2
,
),
dtype
=
torch
.
uint8
),
torch
.
empty
((
round_up
(
query
.
shape
[
0
],
128
),
round_up
(
query
.
shape
[
1
]
*
query
.
shape
[
2
]
//
16
,
4
)),
dtype
=
torch
.
float8_e4m3fn
),
)
else
:
output_trtllm
=
torch
.
empty
(
query
.
shape
,
dtype
=
o_quant_dtype
)
# TRTLLM Decode
output_trtllm
=
torch
.
empty
(
query
.
shape
,
dtype
=
dtype
)
flashinfer
.
prefill
.
trtllm_batch_context_with_kv_cache
(
query
=
query
.
contiguous
()
,
kv_cache
=
k
ey_value
_cache
,
query
=
query
,
kv_cache
=
k
v
_cache
,
workspace_buffer
=
workspace_buffer
,
block_tables
=
block_tables
,
seq_lens
=
seq_lens
,
max_q_len
=
max_q_len
,
max_kv_len
=
max_seq_len
,
bmm1_scale
=
k_scale
*
scale
,
bmm2_scale
=
v_scale
,
batch_size
=
num_seqs
,
bmm1_scale
=
q_scale
*
k_scale
*
sm_
scale
,
bmm2_scale
=
v_scale
/
o_scale
,
batch_size
=
batch_size
,
cum_seq_lens_q
=
q_indptr
,
cum_seq_lens_kv
=
kv_indptr
,
o_sf_scale
=
o_sf_scale
,
out
=
output_trtllm
,
)
if
o_quant_dtype
==
FP8_DTYPE
:
output_trtllm
=
output_trtllm
.
to
(
dtype
)
*
o_scale
elif
o_quant_dtype
==
FP4_DTYPE
:
output_trtllm
.
data
=
output_trtllm
.
data
.
reshape
(
-
1
,
query
.
shape
[
1
]
*
query
.
shape
[
2
]
//
2
)
output_trtllm
=
dequantize_nvfp4_to_dtype
(
output_trtllm
.
data
,
output_trtllm
.
scale
,
o_sf_scale
,
dtype
,
query
.
device
)
output_trtllm
=
output_trtllm
.
reshape
(
-
1
,
query
.
shape
[
1
],
query
.
shape
[
2
])
if
q_quant_dtype
==
FP8_DTYPE
and
o_quant_dtype
==
FP4_DTYPE
:
rtol
,
atol
=
4e-1
,
1e0
elif
q_quant_dtype
==
FP8_DTYPE
and
o_quant_dtype
==
FP8_DTYPE
:
rtol
,
atol
=
5e-2
,
7e-2
else
:
rtol
,
atol
=
1e-2
,
1e-2
torch
.
testing
.
assert_close
(
output
,
output_trtllm
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
torch
.
testing
.
assert_close
(
output
,
output_trtllm
,
atol
=
atol
,
rtol
=
rtol
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
output_trtllm
))
}
"
tests/kernels/attention/test_flashmla.py
View file @
d2b52805
...
...
@@ -13,11 +13,17 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
from
vllm.triton_utils
import
triton
def
cal_diff
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
name
:
str
)
->
None
:
def
cal_diff
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
name
:
str
,
use_fp8
:
bool
=
False
)
->
None
:
x
,
y
=
x
.
double
(),
y
.
double
()
cos_diff
=
1
-
2
*
(
x
*
y
).
sum
().
item
()
/
max
(
(
x
*
x
+
y
*
y
).
sum
().
item
(),
1e-12
)
assert
cos_diff
<
1e-5
if
(
use_fp8
):
assert
cos_diff
<
1e-4
else
:
assert
cos_diff
<
1e-5
FLASH_MLA_UNSUPPORTED_REASON
=
is_flashmla_supported
()[
1
]
\
if
not
is_flashmla_supported
()[
0
]
else
"FlashMLA is supported"
...
...
@@ -27,7 +33,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
reason
=
FLASH_MLA_UNSUPPORTED_REASON
)
@
pytest
.
mark
.
parametrize
(
"b"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"s_q"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"mean_sk"
,
[
4096
,
8192
])
@
pytest
.
mark
.
parametrize
(
"mean_sk"
,
[
4096
,
8192
,
16384
])
@
pytest
.
mark
.
parametrize
(
"h_q"
,
[
16
,
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"h_kv"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
576
])
...
...
@@ -35,20 +41,26 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"varlen"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"torch_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
,
torch
.
float8_e4m3fn
])
@
torch
.
inference_mode
()
def
test_flash_mla
(
b
,
s_q
,
mean_sk
,
h_q
,
h_kv
,
d
,
dv
,
block_size
,
causal
,
varlen
,
dtype
):
varlen
,
torch_
dtype
):
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_dtype
(
dtype
)
if
torch_dtype
==
torch
.
float8_e4m3fn
:
init_dtype
=
torch
.
bfloat16
else
:
init_dtype
=
torch_dtype
torch
.
set_default_dtype
(
init_dtype
)
torch
.
set_default_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
manual_seed
(
0
)
random
.
seed
(
0
)
print
(
f
"
{
b
=
}
,
{
s_q
=
}
,
{
mean_sk
=
}
,
{
h_q
=
}
,
{
h_kv
=
}
, "
f
"
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
varlen
=
}
,
{
dtype
=
}
"
)
f
"
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
varlen
=
}
,
{
torch_
dtype
=
}
"
)
use_fp8
=
torch_dtype
==
torch
.
float8_e4m3fn
cache_seqlens
=
torch
.
full
((
b
,
),
mean_sk
,
dtype
=
torch
.
int32
)
if
varlen
:
for
i
in
range
(
b
):
...
...
@@ -71,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
)
init_dtype
=
q
.
dtype
if
use_fp8
:
fp8_dtype
=
torch
.
float8_e4m3fn
descale_q
=
torch
.
ones
((
1
),
dtype
=
torch
.
float32
)
descale_k
=
torch
.
ones
((
1
),
dtype
=
torch
.
float32
)
q
=
q
.
to
(
fp8_dtype
)
blocked_k
=
blocked_k
.
to
(
fp8_dtype
)
blocked_v
=
blocked_v
.
to
(
fp8_dtype
)
else
:
descale_q
=
None
descale_k
=
None
def
flash_mla
():
return
flash_mla_with_kvcache
(
q
,
...
...
@@ -81,6 +106,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
tile_scheduler_metadata
,
num_splits
,
causal
=
causal
,
descale_q
=
descale_q
,
descale_k
=
descale_k
,
)
def
scaled_dot_product_attention
(
query
,
key
,
value
,
is_causal
=
False
):
...
...
@@ -104,29 +131,35 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
return
attn_weight
@
value
,
lse
def
ref_mla
():
q_
=
(
q
.
to
(
torch
.
float
)
*
descale_q
).
to
(
init_dtype
)
if
use_fp8
else
q
blocked_k_
=
(
blocked_k
.
to
(
torch
.
float
)
*
descale_k
).
to
(
init_dtype
)
if
use_fp8
else
blocked_k
blocked_v_
=
(
blocked_v
.
to
(
torch
.
float
)
*
descale_k
).
to
(
init_dtype
)
if
use_fp8
else
blocked_v
out
=
torch
.
empty
(
b
,
s_q
,
h_q
,
dv
,
dtype
=
torch
.
float32
)
lse
=
torch
.
empty
(
b
,
h_q
,
s_q
,
dtype
=
torch
.
float32
)
for
i
in
range
(
b
):
begin
=
i
*
max_seqlen_pad
end
=
begin
+
cache_seqlens
[
i
]
ref_O
,
LSE
=
scaled_dot_product_attention
(
q
[
i
].
transpose
(
0
,
1
),
blocked_k
.
view
(
-
1
,
h_kv
,
d
)[
begin
:
end
].
transpose
(
0
,
1
),
blocked_v
.
view
(
-
1
,
h_kv
,
dv
)[
begin
:
end
].
transpose
(
0
,
1
),
out_i
,
lse_i
=
scaled_dot_product_attention
(
q
_
[
i
].
transpose
(
0
,
1
),
blocked_k
_
.
view
(
-
1
,
h_kv
,
d
)[
begin
:
end
].
transpose
(
0
,
1
),
blocked_v
_
.
view
(
-
1
,
h_kv
,
dv
)[
begin
:
end
].
transpose
(
0
,
1
),
is_causal
=
causal
,
)
out
[
i
]
=
ref_O
.
transpose
(
0
,
1
)
lse
[
i
]
=
LSE
out
[
i
]
=
out_i
.
transpose
(
0
,
1
)
lse
[
i
]
=
lse_i
return
out
,
lse
out_flash
,
lse_flash
=
flash_mla
()
out_torch
,
lse_torch
=
ref_mla
()
cal_diff
(
out_flash
,
out_torch
,
"out"
)
cal_diff
(
out_flash
,
out_torch
,
"out"
,
use_fp8
)
cal_diff
(
lse_flash
,
lse_torch
,
"lse"
)
t
=
triton
.
testing
.
do_bench
(
flash_mla
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
"
f
"TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
)
*
(
torch
.
finfo
(
torch_dtype
).
bits
//
8
)
+
(
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
init_dtype
).
bits
//
8
)
print
(
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,"
,
f
"
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
tests/kernels/moe/test_block_fp8.py
View file @
d2b52805
...
...
@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk
,
modular_triton_fused_moe
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
has_deep_gemm
from
vllm.utils.deep_gemm
import
is_
blackwell_
deep_gemm_e8m0_used
from
vllm.utils.deep_gemm
import
is_deep_gemm_e8m0_used
dg_available
=
has_deep_gemm
()
...
...
@@ -226,8 +226,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
skipif
(
not
dg_available
,
reason
=
"DeepGemm kernels not available."
)
@
pytest
.
mark
.
skipif
(
is_blackwell_deep_gemm_e8m0_used
(),
reason
=
"Not E8M0 scale MOE"
)
@
pytest
.
mark
.
skipif
(
is_deep_gemm_e8m0_used
(),
reason
=
"Not E8M0 scale MOE"
)
@
torch
.
inference_mode
()
def
test_w8a8_block_fp8_deep_gemm_fused_moe
(
M
,
N
,
K
,
E
,
topk
,
seed
,
monkeypatch
):
...
...
tests/kernels/moe/test_cutlass_moe.py
View file @
d2b52805
...
...
@@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids'
:
topk_ids
,
'w1_scale'
:
moe_tensors
.
w1_scale
,
'w2_scale'
:
moe_tensors
.
w2_scale
,
'ab_strides1'
:
moe_tensors
.
ab_strides1
,
'ab_strides2'
:
moe_tensors
.
ab_strides2
,
'c_strides1'
:
moe_tensors
.
c_strides1
,
'c_strides2'
:
moe_tensors
.
c_strides2
,
'per_act_token'
:
per_act_token
,
'a1_scale'
:
None
#moe_tensors.a_scale
}
...
...
@@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8(
topk_ids
[
0
][
1
]
=
1
workspace13_shape
=
(
m
*
topk
,
max
(
2
*
n
,
k
))
workspace2_shape
=
(
m
*
topk
,
n
)
output_shape
=
(
m
*
topk
,
k
)
workspace2_shape
=
(
m
*
topk
,
max
(
n
,
k
)
)
output_shape
=
(
m
,
k
)
workspace13
=
torch
.
empty
(
prod
(
workspace13_shape
),
device
=
"cuda"
,
...
...
@@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8(
expert_map
[
start
:
end
]
=
list
(
range
(
num_local_experts
))
expert_map
=
torch
.
tensor
(
expert_map
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
ab_strides1
=
torch
.
full
((
e
,
),
k
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
e
,
),
n
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
e
,
),
2
*
n
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides2
=
torch
.
full
((
e
,
),
k
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
activation
=
lambda
o
,
i
:
torch
.
ops
.
_C
.
silu_and_mul
(
o
,
i
)
a1q
,
a1q_scale
=
moe_kernel_quantize_input
(
mt
.
a
,
mt
.
a_scale
,
torch
.
float8_e4m3fn
,
...
...
@@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8(
func
=
lambda
output
:
run_cutlass_moe_fp8
(
output
,
a1q
,
mt
.
w1_q
,
mt
.
w2_q
,
topk_ids
,
activation
,
global_num_experts
,
expert_map
,
mt
.
w1_scale
,
mt
.
w2_scale
,
a1q_scale
,
None
,
workspace13
,
workspace2
,
None
,
mt
.
a
.
dtype
,
per_act_token
,
per_out_channel
,
False
)
a1q_scale
,
None
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
,
workspace13
,
workspace2
,
None
,
mt
.
a
.
dtype
,
per_act_token
,
per_out_channel
,
False
,
topk_weights
)
workspace13
.
random_
()
output_random_workspace
=
torch
.
empty
(
output_shape
,
...
...
tests/kernels/moe/test_deepep_deepgemm_moe.py
View file @
d2b52805
...
...
@@ -20,9 +20,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
has_deep_ep
,
has_deep_gemm
from
vllm.utils.deep_gemm
import
(
is_blackwell_deep_gemm_e8m0_used
,
is_deep_gemm_supported
)
from
vllm.utils.deep_gemm
import
is_deep_gemm_e8m0_used
,
is_deep_gemm_supported
from
...utils
import
multi_gpu_test
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
from
.utils
import
make_test_weights
...
...
@@ -370,9 +370,10 @@ NUM_EXPERTS = [32]
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOPKS
)
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
multi_gpu_test
(
num_gpus
=
2
)
@
requires_deep_ep
@
requires_deep_gemm
@
pytest
.
mark
.
skipif
(
is_
blackwell_
deep_gemm_e8m0_used
(),
@
pytest
.
mark
.
skipif
(
is_deep_gemm_e8m0_used
(),
reason
=
"Skipping test for Blackwell DeepGEMM"
)
def
test_ht_deepep_deepgemm_moe
(
mnk
:
tuple
[
int
,
int
,
int
],
num_experts
:
int
,
topk
:
int
,
world_dp_size
:
tuple
[
int
,
int
]):
...
...
@@ -427,9 +428,10 @@ USE_FP8_DISPATCH = [False]
@
pytest
.
mark
.
parametrize
(
"use_fp8_dispatch"
,
USE_FP8_DISPATCH
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[[
128
,
128
]])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
multi_gpu_test
(
num_gpus
=
2
)
@
requires_deep_ep
@
requires_deep_gemm
@
pytest
.
mark
.
skipif
(
is_
blackwell_
deep_gemm_e8m0_used
(),
@
pytest
.
mark
.
skipif
(
is_deep_gemm_e8m0_used
(),
reason
=
"Skipping test for Blackwell DeepGEMM"
)
def
test_ll_deepep_deepgemm_moe
(
mnk
:
tuple
[
int
,
int
,
int
],
...
...
tests/kernels/moe/test_deepep_moe.py
View file @
d2b52805
...
...
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from
vllm.platforms
import
current_platform
from
vllm.utils
import
has_deep_ep
from
...utils
import
multi_gpu_test
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
if
has_deep_ep
():
...
...
@@ -411,6 +412,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
6
])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
pytest
.
mark
.
parametrize
(
"per_act_token_quant"
,
[
False
,
True
])
@
multi_gpu_test
(
num_gpus
=
2
)
@
requires_deep_ep
def
test_deep_ep_moe
(
dtype
:
torch
.
dtype
,
...
...
@@ -459,6 +461,7 @@ USE_FP8_DISPATCH = [True, False]
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
6
])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[(
2
,
1
)])
@
pytest
.
mark
.
parametrize
(
"use_fp8_dispatch"
,
USE_FP8_DISPATCH
)
@
multi_gpu_test
(
num_gpus
=
2
)
@
requires_deep_ep
def
test_low_latency_deep_ep_moe
(
dtype
:
torch
.
dtype
,
mnk
:
tuple
[
int
,
int
,
int
],
num_experts
:
int
,
topk
:
int
,
...
...
tests/kernels/moe/test_flashinfer.py
0 → 100644
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
import
pytest
import
torch
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
apply_flashinfer_per_tensor_scale_fp8
,
flashinfer_cutlass_moe_fp8
,
register_moe_scaling_factors
,
rotate_flashinfer_fp8_moe_weights
,
swap_w13_to_w31
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
input_to_float8
)
from
vllm.model_executor.models.llama4
import
Llama4MoE
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
# Skip every test in this module unless the FlashInfer CUTLASS fused-MoE
# kernels are importable and the device reports capability 100 or newer.
if not has_flashinfer_cutlass_fused_moe(
) or not current_platform.has_device_capability(100):
    pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
                allow_module_level=True)

# Parameter grids consumed by the pytest.mark.parametrize decorators below.
NUM_EXPERTS = [16]
TOP_KS = [1]

# Each entry is unpacked as "m,n,k" by the tests in this file.
MNK_FACTORS = [
    (256, 8192, 5120),
    (256, 4096, 5120),
    (127, 8192, 5120),
    (127, 4096, 5120),
    (10, 8192, 5120),
    (10, 4096, 5120),
    (1, 8192, 5120),
    (1, 4096, 5120),
]

# Shared config installed with set_current_vllm_config() inside each test.
vllm_config = VllmConfig(parallel_config=ParallelConfig(
    pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
def quant_fp8_per_tensor_batches(a):
    """Quantize each slice along dim 0 of ``a`` independently with
    ``input_to_float8``.

    Returns a pair of stacked tensors: the quantized slices, and the
    reciprocal of the second value returned by ``input_to_float8`` for each
    slice.
    """
    per_slice = [input_to_float8(a[idx]) for idx in range(a.size(0))]

    result_a_quant = torch.stack([quantized for quantized, _ in per_slice])
    # The reciprocal of the value returned by input_to_float8 is what gets
    # stored as the per-slice scale.
    result_a_scales = torch.stack([1.0 / sf for _, sf in per_slice])

    return result_a_quant, result_a_scales
@dataclass
class TestData:
    """Bundle of inputs shared by the FlashInfer FP8 MoE tests in this file.

    Instances are normally built with :meth:`make_moe_tensors_8bit`.
    """

    # The class name matches pytest's default ``Test*`` collection pattern and
    # the dataclass generates ``__init__``; opt out of collection explicitly so
    # pytest does not emit a PytestCollectionWarning for this helper.
    # (Unannotated, so it is a plain class attribute, not a dataclass field.)
    __test__ = False

    # bf16 activations of shape (m, k).
    hidden_states: torch.Tensor
    # Per-expert quantized weights as returned by
    # quant_fp8_per_tensor_batches (before any swap/rotation of the layer copy).
    w13_quantized: torch.Tensor
    w2_quantized: torch.Tensor
    # Reciprocal of the scale returned by input_to_float8(hidden_states).
    a1_scale: torch.Tensor
    # Constant 1.0 float32 scalar.
    a2_scale: torch.Tensor
    # Per-expert scales returned by quant_fp8_per_tensor_batches.
    w13_weight_scale: torch.Tensor
    w2_weight_scale: torch.Tensor
    # Bare torch.nn.Module carrying the attributes the FlashInfer helpers read.
    layer: torch.nn.Module

    @staticmethod
    def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
                              reorder: bool) -> "TestData":
        """Create random MoE inputs/weights on CUDA and quantize them to FP8.

        Args:
            m: number of rows of ``hidden_states``.
            k: hidden size (last dim of ``hidden_states`` and of ``w13``).
            n: intermediate size; ``w13`` has ``2 * n`` rows per expert.
            e: number of experts.
            reorder: when true, additionally apply
                ``rotate_flashinfer_fp8_moe_weights`` to the layer's weights.

        Returns:
            A populated :class:`TestData`.
        """
        hidden_states = torch.randn(
            (m, k), device="cuda", dtype=torch.bfloat16) / 10
        w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
        w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)

        # Scale to fp8
        _, a1_scale = input_to_float8(hidden_states)
        a1_scale = 1.0 / a1_scale
        a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(
            dtype=torch.float32)
        w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
        w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)

        # The layer gets clones so that the swap/rotation below does not
        # touch the tensors handed to the reference implementation.
        layer = torch.nn.Module()
        layer.w13_weight = w13_quantized.clone()
        layer.w2_weight = w2_quantized.clone()
        layer.w13_input_scale = a1_scale
        layer.w2_input_scale = a2_scale
        layer.w13_weight_scale = w13_weight_scale
        layer.w2_weight_scale = w2_weight_scale

        register_moe_scaling_factors(layer)

        # flashinfer expects swapped rows for w13
        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        if reorder:
            rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
                                              layer.w2_weight)
        layer.custom_routing_function = Llama4MoE.custom_routing_function
        layer.intermediate_size_per_partition = n
        layer.ep_rank = 0
        layer.local_num_experts = e

        return TestData(
            hidden_states=hidden_states,
            w13_quantized=w13_quantized,
            w2_quantized=w2_quantized,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            layer=layer,
        )
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_per_tensor_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    monkeypatch,
):
    """Compare ``apply_flashinfer_per_tensor_scale_fp8`` against the
    ``fused_experts`` FP8 path on identical random inputs."""
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        routing_fn = Llama4MoE.custom_routing_function
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=td.hidden_states,
            router_logits=score,
            use_grouped_topk=False,
            top_k=topk,
            renormalize=False,
            custom_routing_function=routing_fn,
            scoring_func="softmax")

        # Reference result from the generic fused-experts implementation.
        reference_kwargs = dict(
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation="silu",
            use_fp8_w8a8=True,
            per_channel_quant=False,
            global_num_experts=e,
            expert_map=None,
            w1_scale=td.w13_weight_scale,
            w2_scale=td.w2_weight_scale,
            a1_scale=td.a1_scale,
            a2_scale=td.a2_scale,
            apply_router_weight_on_input=True,
        )
        output = fused_experts(td.hidden_states, td.w13_quantized,
                               td.w2_quantized, **reference_kwargs)

        # Result under test: routing is done from the raw router logits.
        flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
            layer=td.layer,
            hidden_states=td.hidden_states,
            router_logits=score,
            routing_bias=None,
            global_num_experts=e,
            top_k=topk,
            num_expert_group=None,
            topk_group=None,
            apply_router_weight_on_input=True)

        torch.testing.assert_close(output,
                                   flashinfer_output,
                                   atol=5.5e-2,
                                   rtol=1e-2)
@pytest.mark.skip(
    "Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472"
)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_cutlass_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    monkeypatch,
):
    """Compare ``flashinfer_cutlass_moe_fp8`` against the ``fused_experts``
    FP8 path on identical random inputs (weights are not rotated here)."""
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        routing_fn = Llama4MoE.custom_routing_function
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=td.hidden_states,
            router_logits=score,
            use_grouped_topk=False,
            top_k=topk,
            renormalize=False,
            custom_routing_function=routing_fn,
            scoring_func="softmax")

        # Reference result from the generic fused-experts implementation.
        reference_kwargs = dict(
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation="silu",
            use_fp8_w8a8=True,
            per_channel_quant=False,
            global_num_experts=e,
            expert_map=None,
            w1_scale=td.w13_weight_scale,
            w2_scale=td.w2_weight_scale,
            a1_scale=td.a1_scale,
            a2_scale=td.a2_scale,
            apply_router_weight_on_input=True,
        )
        output = fused_experts(td.hidden_states, td.w13_quantized,
                               td.w2_quantized, **reference_kwargs)

        # Set on the mock layer right before the call under test.
        td.layer.dp_size = 1

        flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
            td.hidden_states,
            td.layer,
            topk_weights,
            topk_ids,
            activation="silu",
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
        )

        torch.testing.assert_close(output,
                                   flashinfer_cutlass_output,
                                   atol=5.5e-2,
                                   rtol=1e-2)
tests/kernels/moe/test_grouped_topk.py
0 → 100644
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MoE grouped topk kernel
Run `pytest tests/kernels/moe/test_grouped_topk.py`.
"""
import
pytest
import
torch
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_grouped_topk
,
grouped_topk
)
from
vllm.platforms
import
current_platform
@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="This test is skipped on non-CUDA platform.")
@pytest.mark.parametrize("n_token", [1, 33, 64])
@pytest.mark.parametrize("n_hidden", [1024, 2048])
@pytest.mark.parametrize("n_expert", [16])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("num_expert_group", [8])
@pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
@pytest.mark.parametrize("dtype",
                         [torch.float16, torch.bfloat16, torch.float32])
def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int,
                      n_hidden: int, n_expert: int, topk: int,
                      renormalize: bool, num_expert_group: int,
                      topk_group: int, scoring_func: str,
                      routed_scaling_factor: float, dtype: torch.dtype):
    """Check that ``fused_grouped_topk`` agrees with ``grouped_topk``.

    Expert ids must match exactly; weights are compared (atol=2e-2) only when
    ``renormalize`` is set.
    """
    current_platform.seed_everything(0)

    # Both implementations receive exactly the same keyword arguments.
    shared_kwargs = dict(
        hidden_states=torch.randn((n_token, n_hidden),
                                  dtype=dtype,
                                  device="cuda"),
        gating_output=torch.randn((n_token, n_expert),
                                  dtype=dtype,
                                  device="cuda"),
        topk=topk,
        renormalize=renormalize,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        e_score_correction_bias=torch.randn((n_expert, ),
                                            dtype=torch.float32,
                                            device="cuda"),
    )

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
        baseline_topk_weights, baseline_topk_ids = grouped_topk(
            **shared_kwargs)

        test_topk_weights, test_topk_ids = fused_grouped_topk(
            **shared_kwargs)

    if renormalize:
        torch.testing.assert_close(baseline_topk_weights,
                                   test_topk_weights,
                                   atol=2e-2,
                                   rtol=0)
    torch.testing.assert_close(baseline_topk_ids,
                               test_topk_ids,
                               atol=0,
                               rtol=0)
tests/kernels/moe/test_modular_kernel_combinations.py
View file @
d2b52805
...
...
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from
vllm.utils
import
has_deep_ep
,
has_deep_gemm
,
has_pplx
from
vllm.utils.flashinfer
import
has_flashinfer_cutlass_fused_moe
from
...utils
import
multi_gpu_test
from
.modular_kernel_tools.common
import
(
Config
,
RankTensors
,
WeightTensors
,
reference_moe_impl
,
run_modular_kernel
)
...
...
@@ -162,6 +163,7 @@ def is_nyi_config(config: Config) -> bool:
product
(
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
,
MK_FUSED_EXPERT_TYPES
))
@
pytest
.
mark
.
parametrize
(
"fused_moe_chunk_size"
,
FUSED_MOE_CHUNK_SIZEs
)
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
])
@
multi_gpu_test
(
num_gpus
=
2
)
@
meets_multi_gpu_requirements
def
test_modular_kernel_combinations_multigpu
(
k
:
int
,
n
:
int
,
e
:
int
,
dtype
:
torch
.
dtype
,
...
...
tests/kernels/moe/test_moe.py
View file @
d2b52805
...
...
@@ -429,11 +429,11 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
vllm_moe
.
experts
.
w13_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
requires_grad
=
False
)
torch
.
cuda
.
empty_cache
()
vllm_moe
.
experts
.
w2_weight
=
Parameter
(
F
.
pad
(
vllm_moe
.
experts
.
w2_weight
,
(
0
,
128
),
"constant"
,
0
)[...,
0
:
-
128
],
requires_grad
=
False
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
empty_cache
()
# Run forward passes for both MoE blocks
...
...
tests/kernels/moe/test_moe_permute_unpermute.py
View file @
d2b52805
...
...
@@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
atol
=
0
,
rtol
=
0
)
# check mindice
torch
.
testing
.
assert_close
(
gold_m_indices
,
m_indices
,
atol
=
0
,
rtol
=
0
)
# current kernel usage assumes deepgemm requires align_block_size
# when it's not provided then we don't compute m_indices (for cutlass)
if
align_block_size
is
not
None
:
torch
.
testing
.
assert_close
(
gold_m_indices
,
m_indices
,
atol
=
0
,
rtol
=
0
)
# check permuted_hidden_states, only valid token
torch
.
testing
.
assert_close
(
gold_permuted_hidden_states
[
valid_row_idx
],
permuted_hidden_states
[
valid_row_idx
],
...
...
tests/kernels/moe/test_mxfp4_moe.py
View file @
d2b52805
...
...
@@ -4,15 +4,27 @@
import
importlib
import
importlib.metadata
from
dataclasses
import
dataclass
from
typing
import
Optional
import
pytest
import
torch
from
packaging
import
version
from
vllm.platforms
import
current_platform
QUARK_MXFP4_AVAILABLE
=
importlib
.
util
.
find_spec
(
"quark"
)
is
not
None
and
version
.
parse
(
importlib
.
metadata
.
version
(
"amd-quark"
))
>=
version
.
parse
(
'0.8.99'
)
TRTLLM_GEN_MXFP4_AVAILABLE
=
current_platform
.
is_cuda
(
)
and
current_platform
.
is_device_capability
(
100
)
if
TRTLLM_GEN_MXFP4_AVAILABLE
:
from
flashinfer
import
(
fp4_quantize
,
mxfp8_quantize
,
next_positive_power_of_2
,
reorder_rows_for_gated_act_gemm
,
shuffle_matrix_a
,
shuffle_matrix_sf_a
,
trtllm_fp4_block_scale_moe
)
@
dataclass
class
ModelCase
:
...
...
@@ -54,4 +66,410 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
output
=
llm
.
generate_greedy
(
"Today I am in the French Alps and"
,
max_tokens
=
20
)
assert
output
\ No newline at end of file
assert
output
def swiglu(x,
           alpha: float = 1.702,
           beta: float = 1.0,
           limit: Optional[float] = None):
    """Gated activation over the two halves of the last dimension of ``x``.

    The first half is the gate, the second half the linear branch. Returns
    ``gate * sigmoid(alpha * gate) * (linear + beta)``. When ``limit`` is
    given, the gate is clamped from above and the linear branch is clamped to
    ``[-limit, limit]`` first.
    """
    gate, linear = x.chunk(2, dim=-1)
    if limit is not None:
        gate = gate.clamp(max=limit)
        linear = linear.clamp(min=-limit, max=limit)
    # Note we add an extra bias (beta, 1 by default) to the linear branch.
    return gate * torch.sigmoid(alpha * gate) * (linear + beta)
# Value for each 4-bit code, indexed by the code (0..15); consumed by
# mxfp4_dequantize. Entries 8..15 are the negated counterparts of 0..7.
fp4_lookup_table = [
    0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6
]
def mxfp4_dequantize(x, scale):
    """Dequantize packed 4-bit values to float32.

    ``x`` is a uint8 tensor holding two 4-bit codes per byte (low nibble
    first). Each code is mapped through ``fp4_lookup_table`` and multiplied by
    a scale shared by 32 consecutive output elements. Each scale byte is
    shifted left by 23 bits and reinterpreted as float32, i.e. it is used as
    the float's exponent field.
    """
    assert x.dtype == torch.uint8
    packed = x.view(torch.uint8).to(torch.int32)

    # Interleave low and high nibbles along the last dim: lo0, hi0, lo1, ...
    codes = torch.stack((packed & 0xF, (packed >> 4) & 0xF),
                        dim=-1).reshape(*packed.shape[:-1],
                                        packed.shape[-1] * 2)

    # Translate every code to its float value with a single gather.
    table = torch.tensor(fp4_lookup_table,
                         dtype=torch.float32,
                         device=packed.device)
    x_float = table[codes.long()]

    exponent_bits = scale.view(torch.uint8).to(torch.int32) << 23
    block_scale = exponent_bits.view(torch.float32).reshape(
        *packed.shape[:-1], -1)
    # Repeat each block scale for the 32 elements it covers.
    block_scale = block_scale.unsqueeze(-1).expand(
        *block_scale.shape, 32).reshape(*x_float.shape)
    return x_float * block_scale
def
mxfp8_dequantize
(
x
,
scale
):
assert
x
.
dtype
==
torch
.
float8_e4m3fn
x_float
=
x
.
to
(
torch
.
float32
)
scale
=
scale
.
view
(
torch
.
uint8
).
to
(
torch
.
int32
)
scale
=
(
scale
<<
23
).
view
(
torch
.
float32
)
scale
=
scale
.
reshape
(
*
x
.
shape
[:
-
1
],
-
1
)
scale
=
torch
.
stack
([
scale
]
*
32
,
dim
=-
1
).
reshape
(
*
x_float
.
shape
)
return
x_float
*
scale
def reference_moe(
    roouting_logits,
    topk,
    num_experts,
    hidden_states,
    w13,
    bias13,
    w2,
    bias2,
    alpha,
    beta,
    limit,
    act_type,
):
    """Plain-PyTorch MoE forward pass used as the reference result.

    Selects the ``topk`` largest routing logits per row, softmaxes just those
    values to get expert weights, runs the two expert MLPs with ``swiglu`` in
    between, and returns the weighted sum over experts as bfloat16. When
    ``act_type == 'mxfp8'`` the intermediate activation goes through an
    mxfp8 quantize/dequantize round trip.
    """
    # renormalize routing
    top_values, expert_indices = torch.topk(roouting_logits,
                                            k=topk,
                                            dim=-1,
                                            sorted=True)
    expert_weights = torch.nn.functional.softmax(top_values, dim=1)

    activations = hidden_states.clone()

    # MLP #1
    activations = torch.einsum("beck,bk->bec", w13[expert_indices, ...],
                               activations) + bias13[expert_indices, ...]
    activations = swiglu(activations, alpha=alpha, beta=beta, limit=limit)

    if act_type == 'mxfp8':
        quantized, quant_scale = mxfp8_quantize(
            activations.to(torch.bfloat16), is_sf_swizzled_layout=False)
        activations = mxfp8_dequantize(quantized, quant_scale)

    # MLP #2
    activations = torch.einsum("beck,bek->bec", w2[expert_indices, ...],
                               activations) + bias2[expert_indices, ...]

    # Weighted sum of experts
    combined = torch.einsum("bec,be->bc", activations, expert_weights)
    assert combined.shape == hidden_states.shape
    return combined.to(torch.bfloat16)
def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int):
    """Choose the per-CTA token tile size from the number of rows in ``x``.

    The expected tokens per expert under a perfectly even distribution is
    scaled by a fixed imbalance factor, padded with
    ``next_positive_power_of_2`` and clamped to the range 8..64.
    """
    # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert:
    #   1.0 is a perfectly even expert distribution, > 1.0 means some experts
    #   receive more tokens than that, < 1.0 does not make sense.
    imbalance_factor = 1.3

    # Tokens per expert assuming an even split, then with the imbalance
    # factor applied.
    even_share = (x.shape[0] * top_k) // num_experts
    padded_share = int(even_share * imbalance_factor)

    # Pad to the next power of 2 and cap to 8-64 tokens per CTA tile, as
    # that is the range supported by the kernel.
    tile = next_positive_power_of_2(padded_share)
    return max(8, min(tile, 64))
def _swap_w1_w3(tensor: torch.Tensor, half: int) -> torch.Tensor:
    """Return a copy of ``tensor`` with the two ``half``-row halves of dim 1
    exchanged; the input tensor is left untouched."""
    swapped = tensor.clone()
    swapped[:, :half].copy_(tensor[:, half:])
    swapped[:, half:].copy_(tensor[:, :half])
    return swapped


def tg_mxfp4_moe(
    router_logits,
    topk,
    num_experts,
    intermediate_size,
    hidden_size,
    hidden_states,
    hidden_states_scale,
    w13_weight,
    w13_weight_scale,
    w13_bias,
    w2_weight,
    w2_weight_scale,
    w2_bias,
    act_type,
    alpha,
    beta,
    limit,
) -> torch.Tensor:
    """Prepare MXFP4 MoE weights and run ``trtllm_fp4_block_scale_moe``.

    The weights, scales and biases are re-laid-out in three steps before the
    kernel call: (1) the two ``intermediate_size`` halves of the fused w13
    tensors are swapped, (2) each expert's w13 rows are reordered with
    ``reorder_rows_for_gated_act_gemm``, and (3) every expert matrix is
    shuffled with ``shuffle_matrix_a`` / ``shuffle_matrix_sf_a``.

    None of the tensors passed in are modified; all re-layouts operate on
    copies.

    Args:
        router_logits: Router scores; cast to bfloat16 for the kernel.
        topk: Number of experts selected per token.
        num_experts: Number of experts (dim 0 of every weight tensor).
        intermediate_size: MLP intermediate width.
        hidden_size: Model hidden width.
        hidden_states: Activations passed straight to the kernel.
        hidden_states_scale: Activation scales, or ``None``.
        w13_weight: ``(num_experts, 2 * intermediate_size, hidden_size // 2)``.
        w13_weight_scale: ``(num_experts, 2 * intermediate_size,
            hidden_size // 32)``.
        w13_bias: ``(num_experts, 2 * intermediate_size)``.
        w2_weight: ``(num_experts, hidden_size, intermediate_size // 2)``.
        w2_weight_scale: ``(num_experts, hidden_size,
            intermediate_size // 32)``.
        w2_bias: ``(num_experts, hidden_size)``.
        act_type: Not read in this function.
        alpha, beta, limit: Passed as ``gemm1_alpha``, ``gemm1_beta`` and
            ``gemm1_clamp_limit``.

    Returns:
        Element ``[0]`` of the value returned by
        ``trtllm_fp4_block_scale_moe``.

    Raises:
        AssertionError: If any weight, scale or bias has an unexpected shape.
    """
    sf_block_size = 32

    assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts
            and w13_weight.shape[1] == intermediate_size * 2
            and w13_weight.shape[2] == hidden_size // 2)
    assert (w13_weight_scale.dim() == 3
            and w13_weight_scale.shape[0] == num_experts
            and w13_weight_scale.shape[1] == intermediate_size * 2
            and w13_weight_scale.shape[2] == hidden_size // sf_block_size)
    assert (w2_weight.dim() == 3 and w2_weight.shape[0] == num_experts
            and w2_weight.shape[1] == hidden_size
            and w2_weight.shape[2] == intermediate_size // 2)
    # The expert dimension is checked here too, matching the other tensors.
    assert (w2_weight_scale.dim() == 3
            and w2_weight_scale.shape[0] == num_experts
            and w2_weight_scale.shape[1] == hidden_size
            and w2_weight_scale.shape[2] == intermediate_size // sf_block_size)
    assert (w13_bias.dim() == 2 and w13_bias.shape[0] == num_experts
            and w13_bias.shape[1] == intermediate_size * 2)
    assert (w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts
            and w2_bias.shape[1] == hidden_size)

    # Swap w1 and w3 as the definition of swiglu is different in trtllm-gen.
    # The swap is done on copies so the caller's tensors stay intact (the
    # caller may still need them, e.g. for a reference computation).
    w13_weight = _swap_w1_w3(w13_weight, intermediate_size)
    w13_weight_scale = _swap_w1_w3(w13_weight_scale, intermediate_size)
    w13_bias = _swap_w1_w3(w13_bias, intermediate_size)

    # Interleave the weights and scaling factors for activation
    w13_weight = torch.stack([
        reorder_rows_for_gated_act_gemm(w13_weight[i].clone())
        for i in range(num_experts)
    ]).reshape(num_experts, 2 * intermediate_size, hidden_size // 2)
    w13_weight_scale = torch.stack([
        reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone())
        for i in range(num_experts)
    ]).reshape(num_experts, 2 * intermediate_size,
               hidden_size // sf_block_size)
    # The bias is reordered as a single-column matrix, then flattened back.
    w13_bias = torch.stack([
        reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1, 1))
        for i in range(num_experts)
    ]).reshape(num_experts, 2 * intermediate_size)

    # Shuffle weights and scaling factors for transposed mma output
    epilogue_tile_m = 128  # FIXME: this depends on the kernel internals

    gemm1_weights_shuffled = [
        shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
        for i in range(num_experts)
    ]
    gemm1_scales_shuffled = [
        shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
                            epilogue_tile_m) for i in range(num_experts)
    ]
    gemm2_weights_shuffled = [
        shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
        for i in range(num_experts)
    ]
    gemm2_scales_shuffled = [
        shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
                            epilogue_tile_m) for i in range(num_experts)
    ]
    gemm1_bias_shuffled = [
        shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m)
        for i in range(num_experts)
    ]
    gemm2_bias_shuffled = [
        shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m)
        for i in range(num_experts)
    ]

    w13_weight = torch.stack(gemm1_weights_shuffled)
    # Scales were shuffled as uint8; reinterpret them as float8_e4m3fn for
    # the kernel.
    w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape(
        num_experts, 2 * intermediate_size,
        hidden_size // sf_block_size).view(torch.float8_e4m3fn)
    w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)

    w2_weight = torch.stack(gemm2_weights_shuffled)
    w2_weight_scale = torch.stack(gemm2_scales_shuffled).reshape(
        num_experts, hidden_size,
        intermediate_size // sf_block_size).view(torch.float8_e4m3fn)
    w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)

    tg_result = trtllm_fp4_block_scale_moe(
        routing_logits=router_logits.to(torch.bfloat16),
        routing_bias=None,
        hidden_states=hidden_states,
        hidden_states_scale=hidden_states_scale,
        gemm1_weights=w13_weight,
        gemm1_weights_scale=w13_weight_scale,
        gemm1_bias=w13_bias,
        gemm1_alpha=alpha,
        gemm1_beta=beta,
        gemm1_clamp_limit=limit,
        gemm2_weights=w2_weight,
        gemm2_weights_scale=w2_weight_scale,
        gemm2_bias=w2_bias,
        output1_scale_scalar=None,
        output1_scale_gate_scalar=None,
        output2_scale_scalar=None,
        num_experts=num_experts,
        top_k=topk,
        n_group=None,
        topk_group=None,
        intermediate_size=intermediate_size,
        local_expert_offset=0,
        local_num_experts=num_experts,
        routed_scaling_factor=None,
        tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk,
                                            num_experts),
        routing_method_type=1,  # renormalize
        do_finalize=True)[0]
    return tg_result
def check_accuracy(a, b, atol, rtol, percent):
    """Allow a mismatch percentage of 1 - percent.

    Args:
        a: Reference output tensor.
        b: Actual output tensor; the relative tolerance is scaled by ``|b|``.
        atol: Absolute tolerance.
        rtol: Relative tolerance.
        percent: Fraction of elements that must lie within tolerance.

    Raises:
        Exception: If either tensor contains NaN or Inf, or if the fraction
            of out-of-tolerance elements exceeds ``1 - percent``.
        AssertionError: If the two tensors differ in shape.
    """
    # Non-finite values are rejected first, in a fixed order.
    finiteness_checks = (
        (torch.isnan, a, "NaN in reference output"),
        (torch.isnan, b, "NaN in actual output"),
        (torch.isinf, a, "Inf in reference output"),
        (torch.isinf, b, "Inf in actual output"),
    )
    for is_bad, tensor, message in finiteness_checks:
        if torch.any(is_bad(tensor)):
            raise Exception(message)

    assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"

    # An element mismatches when |a - b| > atol + rtol * |b|.
    abs_error = (a - b).abs()
    tolerance = atol + rtol * b.abs()
    num_violations = (abs_error > tolerance).sum()
    mismatch_percent = num_violations / a.numel()

    if mismatch_percent > 1 - percent:
        raise Exception(
            f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
            f"(threshold: {1-percent:.4f})")
@pytest.mark.parametrize("topk", [1, 4])
@pytest.mark.parametrize("num_experts", [32, 128])
@pytest.mark.parametrize("num_tokens", [1, 128, 1024])
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
                                              (1.702, 1.0, 7.0)])
@pytest.mark.parametrize("act_type", ['mxfp8', 'bf16'])
@pytest.mark.skipif(
    not TRTLLM_GEN_MXFP4_AVAILABLE,
    reason="nvidia gpu and compute capability sm100 is required for this test"
)
def test_trtllm_gen_mxfp4_fused_moe(
    topk: int,
    num_experts: int,
    num_tokens: int,
    intermediate_size: int,
    hidden_size: int,
    alpha: float,
    beta: float,
    limit: Optional[float],
    act_type: str,
):
    """Compare the trtllm-gen MXFP4 fused MoE against the eager reference.

    Random bf16 weights are quantized with ``fp4_quantize``; the reference
    path dequantizes them and runs ``reference_moe`` in chunks of tokens,
    while the kernel path goes through ``tg_mxfp4_moe``. The two outputs are
    compared with ``check_accuracy`` using a loose tolerance.
    """
    device = "cuda:0"
    torch.manual_seed(42)

    # Random inputs. The order of these calls fixes the generated values for
    # the seed above, so it must not be changed.
    hidden_states = torch.randn(num_tokens,
                                hidden_size,
                                device=device,
                                dtype=torch.bfloat16)
    w13 = torch.randn(num_experts,
                      intermediate_size * 2,
                      hidden_size,
                      device=device,
                      dtype=torch.bfloat16)
    w2 = torch.randn(num_experts,
                     hidden_size,
                     intermediate_size,
                     device=device,
                     dtype=torch.bfloat16)
    bias13 = torch.randn(num_experts, intermediate_size * 2,
                         device=device) * 10
    bias2 = torch.randn(num_experts, hidden_size, device=device) * 10
    router_logits = torch.rand(num_tokens, num_experts,
                               dtype=torch.float32).cuda()

    # Quantize both weight matrices with blocks of 32 elements; the scales
    # are reinterpreted as float8_e4m3fn and reshaped to one row of scales
    # per weight row.
    w13, w13_scale = fp4_quantize(w13,
                                  torch.tensor(1.0, device=device),
                                  32,
                                  sf_use_ue8m0=True,
                                  is_sf_swizzled_layout=False)
    w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
        num_experts, intermediate_size * 2, hidden_size // 32)
    w2, w2_scale = fp4_quantize(w2,
                                torch.tensor(1.0, device=device),
                                32,
                                sf_use_ue8m0=True,
                                is_sf_swizzled_layout=False)
    w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
        num_experts, hidden_size, intermediate_size // 32)

    # Activations: optionally quantize to mxfp8; the reference consumes the
    # dequantized float32 view of whatever the kernel is given.
    if act_type == 'mxfp8':
        hidden_states, hidden_states_scale = mxfp8_quantize(
            hidden_states, is_sf_swizzled_layout=False)
        hidden_states_scale = hidden_states_scale.view(
            torch.float8_e4m3fn).reshape(-1)
        hidden_states_ref = mxfp8_dequantize(
            hidden_states, hidden_states_scale).to(torch.float32)
    else:
        hidden_states_scale = None
        hidden_states_ref = hidden_states.to(torch.float32)

    # reference result
    ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16)
    w13_ref = mxfp4_dequantize(w13.clone(), w13_scale.clone())
    w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone())

    # Process tokens in chunks of 32 to reduce memory usage
    chunk_size = 32
    for start_idx in range(0, num_tokens, chunk_size):
        end_idx = min(start_idx + chunk_size, num_tokens)
        chunk_result = reference_moe(
            router_logits[start_idx:end_idx].to(torch.float32),
            topk,
            num_experts,
            hidden_states_ref[start_idx:end_idx],
            w13_ref,
            bias13,
            w2_ref,
            bias2,
            alpha,
            beta,
            limit,
            act_type,
        )
        ref_result[start_idx:end_idx].copy_(chunk_result)

    # trtllm-gen result: the scalar activation parameters are expanded to
    # one value per expert.
    if alpha is not None:
        alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
    if limit is not None:
        limit = torch.full((num_experts, ), limit, device=hidden_states.device)
    if beta is not None:
        beta = torch.full((num_experts, ), beta, device=hidden_states.device)
    tg_result = tg_mxfp4_moe(router_logits,
                             topk,
                             num_experts,
                             intermediate_size,
                             hidden_size,
                             hidden_states,
                             hidden_states_scale,
                             w13,
                             w13_scale,
                             bias13,
                             w2,
                             w2_scale,
                             bias2,
                             act_type,
                             alpha=alpha,
                             beta=beta,
                             limit=limit)

    # relatively loose check since the mxfp4 quantization is less accurate
    check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
tests/kernels/moe/test_pplx_cutlass_moe.py
View file @
d2b52805
...
...
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
...utils
import
multi_gpu_test
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
try
:
...
...
@@ -76,6 +77,7 @@ def pplx_cutlass_moe(
assert
torch
.
cuda
.
current_device
()
==
pgi
.
local_rank
num_tokens
,
hidden_dim
=
a
.
shape
intermediate_dim
=
w2
.
shape
[
2
]
num_experts
=
w1
.
shape
[
0
]
block_size
=
hidden_dim
# TODO support more cases
device
=
pgi
.
device
...
...
@@ -124,8 +126,27 @@ def pplx_cutlass_moe(
num_local_experts
=
num_local_experts
,
num_dispatchers
=
num_dispatchers
)
ab_strides1
=
torch
.
full
((
num_local_experts
,
),
hidden_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
ab_strides2
=
torch
.
full
((
num_local_experts
,
),
intermediate_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides1
=
torch
.
full
((
num_local_experts
,
),
2
*
intermediate_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
c_strides2
=
torch
.
full
((
num_local_experts
,
),
hidden_dim
,
device
=
"cuda"
,
dtype
=
torch
.
int64
)
experts
=
CutlassBatchedExpertsFp8
(
num_local_experts
,
num_dispatchers
,
out_dtype
,
per_act_token
,
per_out_ch
)
out_dtype
,
per_act_token
,
per_out_ch
,
ab_strides1
,
ab_strides2
,
c_strides1
,
c_strides2
)
fused_cutlass_experts
=
FusedMoEModularKernel
(
prepare_finalize
,
...
...
@@ -227,6 +248,7 @@ def _pplx_moe(
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[[
2
,
1
]])
#, [4, 2]])
@
pytest
.
mark
.
parametrize
(
"use_internode"
,
[
False
])
@
multi_gpu_test
(
num_gpus
=
2
)
@
pytest
.
mark
.
skipif
(
(
lambda
x
:
x
is
None
or
not
ops
.
cutlass_group_gemm_supported
(
x
.
to_int
()))(
current_platform
.
get_device_capability
()),
...
...
tests/kernels/moe/test_pplx_moe.py
View file @
d2b52805
...
...
@@ -37,6 +37,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from
vllm.platforms
import
current_platform
from
vllm.utils
import
round_up
from
...utils
import
multi_gpu_test
from
.parallel_utils
import
ProcessGroupInfo
,
parallel_launch
requires_pplx
=
pytest
.
mark
.
skipif
(
...
...
@@ -452,6 +453,7 @@ def _pplx_prepare_finalize(
@
pytest
.
mark
.
parametrize
(
"use_internode"
,
[
False
])
@
pytest
.
mark
.
optional
@
requires_pplx
@
multi_gpu_test
(
num_gpus
=
2
)
def
test_pplx_prepare_finalize_slow
(
mnk
:
tuple
[
int
,
int
,
int
],
e
:
int
,
...
...
@@ -740,6 +742,7 @@ def _pplx_moe(
@
pytest
.
mark
.
parametrize
(
"use_internode"
,
[
False
])
@
pytest
.
mark
.
optional
@
requires_pplx
@
multi_gpu_test
(
num_gpus
=
2
)
def
test_pplx_moe_slow
(
mnk
:
tuple
[
int
,
int
,
int
],
e
:
int
,
...
...
@@ -880,6 +883,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[[
2
,
1
]])
@
pytest
.
mark
.
parametrize
(
"use_internode"
,
[
False
])
@
requires_pplx
@
multi_gpu_test
(
num_gpus
=
2
)
def
test_pplx_prepare_finalize
(
world_dp_size
:
tuple
[
int
,
int
],
use_internode
:
bool
,
...
...
@@ -893,6 +897,7 @@ def test_pplx_prepare_finalize(
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[[
2
,
1
]])
@
pytest
.
mark
.
parametrize
(
"use_internode"
,
[
False
])
@
requires_pplx
@
multi_gpu_test
(
num_gpus
=
2
)
def
test_pplx_moe
(
world_dp_size
:
tuple
[
int
,
int
],
use_internode
:
bool
,
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment