Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
79028d43
Unverified
Commit
79028d43
authored
Feb 05, 2026
by
Xin Yang
Committed by
GitHub
Feb 05, 2026
Browse files
[Perf] Disable clean_logits in deepgemm fp8_mqa_logits kernel (#33568)
parent
325ab6b0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
61 additions
and
27 deletions
+61
-27
tests/kernels/attention/test_deepgemm_attention.py
tests/kernels/attention/test_deepgemm_attention.py
+13
-7
tests/kernels/test_top_k_per_row.py
tests/kernels/test_top_k_per_row.py
+38
-18
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+2
-0
vllm/utils/deep_gemm.py
vllm/utils/deep_gemm.py
+8
-2
No files found.
tests/kernels/attention/test_deepgemm_attention.py
View file @
79028d43
...
@@ -95,7 +95,8 @@ def _ref_fp8_mqa_logits(
...
@@ -95,7 +95,8 @@ def _ref_fp8_mqa_logits(
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
90
),
reason
=
"SM90 and SM100 only"
not
current_platform
.
has_device_capability
(
90
),
reason
=
"SM90 and SM100 only"
)
)
def
test_deepgemm_fp8_mqa_logits
():
@
pytest
.
mark
.
parametrize
(
"clean_logits"
,
[
True
,
False
])
def
test_deepgemm_fp8_mqa_logits
(
clean_logits
:
bool
):
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
random
.
seed
(
0
)
random
.
seed
(
0
)
num_heads
,
head_dim
=
32
,
128
num_heads
,
head_dim
=
32
,
128
...
@@ -126,7 +127,9 @@ def test_deepgemm_fp8_mqa_logits():
...
@@ -126,7 +127,9 @@ def test_deepgemm_fp8_mqa_logits():
q_fp8
=
q
.
to
(
torch
.
float8_e4m3fn
)
q_fp8
=
q
.
to
(
torch
.
float8_e4m3fn
)
kv_fp8
=
per_custom_dims_cast_to_fp8
(
kv
,
(
0
,),
False
)
kv_fp8
=
per_custom_dims_cast_to_fp8
(
kv
,
(
0
,),
False
)
logits
=
fp8_mqa_logits
(
q_fp8
,
kv_fp8
,
weights
,
ks
,
ke
)
logits
=
fp8_mqa_logits
(
q_fp8
,
kv_fp8
,
weights
,
ks
,
ke
,
clean_logits
=
clean_logits
)
ref_logits
=
_ref_fp8_mqa_logits
(
ref_logits
=
_ref_fp8_mqa_logits
(
q
=
q
,
q
=
q
,
...
@@ -135,13 +138,14 @@ def test_deepgemm_fp8_mqa_logits():
...
@@ -135,13 +138,14 @@ def test_deepgemm_fp8_mqa_logits():
cu_seqlen_ks
=
ks
,
cu_seqlen_ks
=
ks
,
cu_seqlen_ke
=
ke
,
cu_seqlen_ke
=
ke
,
)
)
ref_neginf_mask
=
ref_logits
==
float
(
"-inf"
)
ref_neginf_mask
=
ref_logits
==
float
(
"-inf"
)
neginf_mask
=
logits
==
float
(
"-inf"
)
assert
torch
.
equal
(
neginf_mask
,
ref_neginf_mask
)
if
clean_logits
:
neginf_mask
=
logits
==
float
(
"-inf"
)
assert
torch
.
equal
(
neginf_mask
,
ref_neginf_mask
)
ref_logits
=
ref_logits
.
masked_fill
(
ref_neginf_mask
,
0
)
ref_logits
=
ref_logits
.
masked_fill
(
ref_neginf_mask
,
0
)
logits
=
logits
.
masked_fill
(
neginf_mask
,
0
)
logits
=
logits
.
masked_fill
(
ref_
neginf_mask
,
0
)
diff
=
calc_diff
(
logits
,
ref_logits
)
diff
=
calc_diff
(
logits
,
ref_logits
)
assert
diff
<
1e-3
,
f
"
{
diff
=
}
"
assert
diff
<
1e-3
,
f
"
{
diff
=
}
"
...
@@ -201,7 +205,8 @@ def _ref_fp8_paged_mqa_logits(
...
@@ -201,7 +205,8 @@ def _ref_fp8_paged_mqa_logits(
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
90
),
reason
=
"SM90 and SM100 only"
not
current_platform
.
has_device_capability
(
90
),
reason
=
"SM90 and SM100 only"
)
)
def
test_deepgemm_fp8_paged_mqa_logits
():
@
pytest
.
mark
.
parametrize
(
"clean_logits"
,
[
True
,
False
])
def
test_deepgemm_fp8_paged_mqa_logits
(
clean_logits
:
bool
):
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
random
.
seed
(
0
)
random
.
seed
(
0
)
...
@@ -264,6 +269,7 @@ def test_deepgemm_fp8_paged_mqa_logits():
...
@@ -264,6 +269,7 @@ def test_deepgemm_fp8_paged_mqa_logits():
block_tables
,
block_tables
,
schedule_metadata
,
schedule_metadata
,
max_model_len
,
max_model_len
,
clean_logits
=
clean_logits
,
)
)
ref_logits
=
_ref_fp8_paged_mqa_logits
(
ref_logits
=
_ref_fp8_paged_mqa_logits
(
...
...
tests/kernels/test_top_k_per_row.py
View file @
79028d43
...
@@ -6,6 +6,7 @@ import pytest
...
@@ -6,6 +6,7 @@ import pytest
import
torch
import
torch
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
# Test parameters
# Test parameters
NUM_ROWS
=
[
1
,
32
,
2050
]
NUM_ROWS
=
[
1
,
32
,
2050
]
...
@@ -20,6 +21,7 @@ def create_random_logits(
...
@@ -20,6 +21,7 @@ def create_random_logits(
row_ends
:
torch
.
Tensor
,
row_ends
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seed
:
int
,
seed
:
int
,
clean_logits
:
bool
,
data_generation
:
str
,
data_generation
:
str
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Create random logits tensor for testing."""
"""Create random logits tensor for testing."""
...
@@ -48,8 +50,9 @@ def create_random_logits(
...
@@ -48,8 +50,9 @@ def create_random_logits(
)
)
logits
=
logits_bits
.
view
(
dtype
)
logits
=
logits_bits
.
view
(
dtype
)
for
i
,
end
in
enumerate
(
row_ends
):
if
clean_logits
:
logits
[
i
,
end
:]
=
float
(
"-inf"
)
for
i
,
end
in
enumerate
(
row_ends
):
logits
[
i
,
end
:]
=
float
(
"-inf"
)
return
logits
return
logits
...
@@ -121,21 +124,26 @@ def compare_top_k_results(
...
@@ -121,21 +124,26 @@ def compare_top_k_results(
@
pytest
.
mark
.
parametrize
(
"num_rows"
,
NUM_ROWS
)
@
pytest
.
mark
.
parametrize
(
"num_rows"
,
NUM_ROWS
)
@
pytest
.
mark
.
parametrize
(
"top_k"
,
TOP_K_VALUES
)
@
pytest
.
mark
.
parametrize
(
"top_k"
,
TOP_K_VALUES
)
@
pytest
.
mark
.
parametrize
(
"clean_logits"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"This test requires CUDA"
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"This test requires CUDA"
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_top_k_per_row
(
def
test_top_k_per_row
(
num_rows
:
int
,
num_rows
:
int
,
top_k
:
int
,
top_k
:
int
,
clean_logits
:
bool
,
)
->
None
:
)
->
None
:
"""
"""
Test top_k_per_row.
Test top_k_per_row.
"""
"""
set_random_seed
(
0
)
torch
.
set_default_device
(
"cuda:0"
)
torch
.
set_default_device
(
"cuda:0"
)
# Create test data
# Create test data
vocab_size
=
20000
vocab_size
=
20000
row_starts
,
row_ends
=
create_row_boundaries
(
num_rows
,
vocab_size
)
row_starts
,
row_ends
=
create_row_boundaries
(
num_rows
,
vocab_size
)
logits
=
create_random_logits
(
row_starts
,
row_ends
,
torch
.
float32
,
42
,
"random"
)
logits
=
create_random_logits
(
row_starts
,
row_ends
,
torch
.
float32
,
42
,
clean_logits
,
"random"
)
# Create output tensors
# Create output tensors
indices
=
torch
.
empty
((
num_rows
,
top_k
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
indices
=
torch
.
empty
((
num_rows
,
top_k
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
...
@@ -153,11 +161,12 @@ def test_top_k_per_row(
...
@@ -153,11 +161,12 @@ def test_top_k_per_row(
)
)
# Run reference implementation
# Run reference implementation
torch_indices
=
logits
.
topk
(
min
(
top_k
,
max
(
row_ends
)),
dim
=-
1
)[
1
]
torch_indices
=
torch
.
empty
((
num_rows
,
top_k
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
mask_lo
=
torch_indices
>=
0
for
i
in
range
(
num_rows
):
mask_hi
=
(
torch_indices
-
(
row_ends
-
row_starts
)[:,
None
])
<
0
row_end
=
int
(
row_ends
[
i
])
mask
=
mask_lo
&
mask_hi
k_i
=
min
(
top_k
,
row_end
)
torch_indices
=
torch_indices
.
masked_fill
(
~
mask
,
-
1
)
idx
=
logits
[
i
,
:
row_end
].
topk
(
k_i
,
dim
=-
1
)[
1
]
torch_indices
[
i
,
:
k_i
]
=
idx
# Compare results
# Compare results
assert
compare_top_k_results
(
assert
compare_top_k_results
(
...
@@ -170,6 +179,7 @@ def _run_top_k_per_row_decode_test(
...
@@ -170,6 +179,7 @@ def _run_top_k_per_row_decode_test(
batch_size
:
int
,
batch_size
:
int
,
next_n
:
int
,
next_n
:
int
,
vocab_size
:
int
,
vocab_size
:
int
,
clean_logits
:
bool
,
data_generation
:
str
,
data_generation
:
str
,
)
->
None
:
)
->
None
:
"""
"""
...
@@ -180,14 +190,18 @@ def _run_top_k_per_row_decode_test(
...
@@ -180,14 +190,18 @@ def _run_top_k_per_row_decode_test(
# Create test data
# Create test data
num_rows
=
batch_size
*
next_n
num_rows
=
batch_size
*
next_n
seq_lens
=
torch
.
randint
(
seq_lens
=
torch
.
randint
(
vocab_size
,
(
batch_size
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
low
=
next_n
,
high
=
vocab_size
,
size
=
(
batch_size
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
,
)
)
row_starts
=
torch
.
zeros
(
num_rows
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
row_starts
=
torch
.
zeros
(
num_rows
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
row_indices
=
torch
.
arange
(
num_rows
,
device
=
"cuda"
)
//
next_n
row_indices
=
torch
.
arange
(
num_rows
,
device
=
"cuda"
)
//
next_n
next_n_offset
=
torch
.
arange
(
num_rows
,
device
=
"cuda"
)
%
next_n
next_n_offset
=
torch
.
arange
(
num_rows
,
device
=
"cuda"
)
%
next_n
row_ends
=
seq_lens
[
row_indices
]
-
next_n
+
next_n_offset
+
1
row_ends
=
seq_lens
[
row_indices
]
-
next_n
+
next_n_offset
+
1
logits
=
create_random_logits
(
logits
=
create_random_logits
(
row_starts
,
row_ends
,
torch
.
float32
,
42
,
data_generation
row_starts
,
row_ends
,
torch
.
float32
,
42
,
clean_logits
,
data_generation
)
)
# Create output tensors
# Create output tensors
...
@@ -208,11 +222,12 @@ def _run_top_k_per_row_decode_test(
...
@@ -208,11 +222,12 @@ def _run_top_k_per_row_decode_test(
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
# Run reference implementation
# Run reference implementation
torch_indices
=
logits
.
topk
(
min
(
top_k
,
max
(
row_ends
)),
dim
=-
1
)[
1
]
torch_indices
=
torch
.
empty
((
num_rows
,
top_k
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
mask_lo
=
torch_indices
>=
0
for
i
in
range
(
num_rows
):
mask_hi
=
(
torch_indices
-
(
row_ends
-
row_starts
)[:,
None
])
<
0
row_end
=
int
(
row_ends
[
i
])
mask
=
mask_lo
&
mask_hi
k_i
=
min
(
top_k
,
row_end
)
torch_indices
=
torch_indices
.
masked_fill
(
~
mask
,
-
1
)
idx
=
logits
[
i
,
:
row_end
].
topk
(
k_i
,
dim
=-
1
)[
1
]
torch_indices
[
i
,
:
k_i
]
=
idx
# Compare results
# Compare results
assert
compare_top_k_results
(
assert
compare_top_k_results
(
...
@@ -223,6 +238,7 @@ def _run_top_k_per_row_decode_test(
...
@@ -223,6 +238,7 @@ def _run_top_k_per_row_decode_test(
@
pytest
.
mark
.
parametrize
(
"top_k"
,
TOP_K_VALUES
)
@
pytest
.
mark
.
parametrize
(
"top_k"
,
TOP_K_VALUES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZE
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZE
)
@
pytest
.
mark
.
parametrize
(
"next_n"
,
NEXT_N
)
@
pytest
.
mark
.
parametrize
(
"next_n"
,
NEXT_N
)
@
pytest
.
mark
.
parametrize
(
"clean_logits"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"data_generation"
,
DATA_GENERATION
)
@
pytest
.
mark
.
parametrize
(
"data_generation"
,
DATA_GENERATION
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"This test requires CUDA"
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"This test requires CUDA"
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
@@ -230,28 +246,32 @@ def test_top_k_per_row_decode(
...
@@ -230,28 +246,32 @@ def test_top_k_per_row_decode(
top_k
:
int
,
top_k
:
int
,
batch_size
:
int
,
batch_size
:
int
,
next_n
:
int
,
next_n
:
int
,
clean_logits
:
bool
,
data_generation
:
str
,
data_generation
:
str
,
)
->
None
:
)
->
None
:
"""
"""
Test top_k_per_row with seq_lens tensor.
Test top_k_per_row with seq_lens tensor.
"""
"""
set_random_seed
(
0
)
vocab_size
=
20000
vocab_size
=
20000
_run_top_k_per_row_decode_test
(
_run_top_k_per_row_decode_test
(
top_k
,
batch_size
,
next_n
,
vocab_size
,
data_generation
top_k
,
batch_size
,
next_n
,
vocab_size
,
clean_logits
,
data_generation
)
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"This test requires CUDA"
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"This test requires CUDA"
)
@
pytest
.
mark
.
parametrize
(
"clean_logits"
,
[
True
,
False
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_top_k_per_row_decode_large_vocab_size
()
->
None
:
def
test_top_k_per_row_decode_large_vocab_size
(
clean_logits
:
bool
)
->
None
:
"""
"""
Test top_k_per_row_decode with large vocabulary size.
Test top_k_per_row_decode with large vocabulary size.
"""
"""
set_random_seed
(
0
)
top_k
=
2048
top_k
=
2048
batch_size
=
2
batch_size
=
2
next_n
=
2
next_n
=
2
vocab_size
=
300000
vocab_size
=
300000
data_generation
=
"random"
data_generation
=
"random"
_run_top_k_per_row_decode_test
(
_run_top_k_per_row_decode_test
(
top_k
,
batch_size
,
next_n
,
vocab_size
,
data_generation
top_k
,
batch_size
,
next_n
,
vocab_size
,
clean_logits
,
data_generation
)
)
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
79028d43
...
@@ -108,6 +108,7 @@ def sparse_attn_indexer(
...
@@ -108,6 +108,7 @@ def sparse_attn_indexer(
weights
[
chunk
.
token_start
:
chunk
.
token_end
],
weights
[
chunk
.
token_start
:
chunk
.
token_end
],
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ks
,
chunk
.
cu_seqlen_ke
,
chunk
.
cu_seqlen_ke
,
clean_logits
=
False
,
)
)
num_rows
=
logits
.
shape
[
0
]
num_rows
=
logits
.
shape
[
0
]
...
@@ -157,6 +158,7 @@ def sparse_attn_indexer(
...
@@ -157,6 +158,7 @@ def sparse_attn_indexer(
decode_metadata
.
block_table
,
decode_metadata
.
block_table
,
decode_metadata
.
schedule_metadata
,
decode_metadata
.
schedule_metadata
,
max_model_len
=
max_model_len
,
max_model_len
=
max_model_len
,
clean_logits
=
False
,
)
)
num_rows
=
logits
.
shape
[
0
]
num_rows
=
logits
.
shape
[
0
]
...
...
vllm/utils/deep_gemm.py
View file @
79028d43
...
@@ -242,6 +242,7 @@ def fp8_mqa_logits(
...
@@ -242,6 +242,7 @@ def fp8_mqa_logits(
weights
:
torch
.
Tensor
,
weights
:
torch
.
Tensor
,
cu_seqlen_ks
:
torch
.
Tensor
,
cu_seqlen_ks
:
torch
.
Tensor
,
cu_seqlen_ke
:
torch
.
Tensor
,
cu_seqlen_ke
:
torch
.
Tensor
,
clean_logits
:
bool
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Compute FP8 MQA logits for a single sequence without KV paging.
"""Compute FP8 MQA logits for a single sequence without KV paging.
...
@@ -256,6 +257,7 @@ def fp8_mqa_logits(
...
@@ -256,6 +257,7 @@ def fp8_mqa_logits(
shape [M], dtype int32.
shape [M], dtype int32.
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
shape [M], dtype int32.
shape [M], dtype int32.
clean_logits: Whether to clean the unfilled logits into `-inf`.
Returns:
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
Logits tensor of shape [M, N], dtype `torch.float32`.
...
@@ -263,7 +265,9 @@ def fp8_mqa_logits(
...
@@ -263,7 +265,9 @@ def fp8_mqa_logits(
_lazy_init
()
_lazy_init
()
if
_fp8_mqa_logits_impl
is
None
:
if
_fp8_mqa_logits_impl
is
None
:
return
_missing
()
return
_missing
()
return
_fp8_mqa_logits_impl
(
q
,
kv
,
weights
,
cu_seqlen_ks
,
cu_seqlen_ke
)
return
_fp8_mqa_logits_impl
(
q
,
kv
,
weights
,
cu_seqlen_ks
,
cu_seqlen_ke
,
clean_logits
=
clean_logits
)
def
get_paged_mqa_logits_metadata
(
def
get_paged_mqa_logits_metadata
(
...
@@ -295,6 +299,7 @@ def fp8_paged_mqa_logits(
...
@@ -295,6 +299,7 @@ def fp8_paged_mqa_logits(
block_tables
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
schedule_metadata
:
torch
.
Tensor
,
schedule_metadata
:
torch
.
Tensor
,
max_model_len
:
int
,
max_model_len
:
int
,
clean_logits
:
bool
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Compute FP8 MQA logits using paged KV-cache.
"""Compute FP8 MQA logits using paged KV-cache.
...
@@ -312,6 +317,7 @@ def fp8_paged_mqa_logits(
...
@@ -312,6 +317,7 @@ def fp8_paged_mqa_logits(
schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
used to distribute work across SMs.
used to distribute work across SMs.
max_model_len: Maximum sequence length used to size the logits output.
max_model_len: Maximum sequence length used to size the logits output.
clean_logits: Whether to clean the unfilled logits into `-inf`.
Returns:
Returns:
Logits tensor of shape [B * next_n, max_model_len], dtype
Logits tensor of shape [B * next_n, max_model_len], dtype
...
@@ -328,7 +334,7 @@ def fp8_paged_mqa_logits(
...
@@ -328,7 +334,7 @@ def fp8_paged_mqa_logits(
block_tables
,
block_tables
,
schedule_metadata
,
schedule_metadata
,
max_model_len
,
max_model_len
,
clean_logits
=
True
,
clean_logits
=
clean_logits
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment