Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e8f62b20
Unverified
Commit
e8f62b20
authored
Apr 15, 2025
by
Trevor Morris
Committed by
GitHub
Apr 15, 2025
Browse files
Blackwell cutlass mla: Add check for bad page size/block num combinations (#5431)
parent
88defc4d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
4 deletions
+11
-4
sgl-kernel/python/sgl_kernel/attention.py
sgl-kernel/python/sgl_kernel/attention.py
+5
-3
sgl-kernel/tests/test_cutlass_mla.py
sgl-kernel/tests/test_cutlass_mla.py
+6
-1
No files found.
sgl-kernel/python/sgl_kernel/attention.py
View file @
e8f62b20
...
...
@@ -74,9 +74,11 @@ def cutlass_mla_decode(
f
"but got D_q =
{
D_q
}
, D_ckv =
{
D_ckv
}
, D_latent =
{
D_latent
}
, D_rope =
{
D_rope
}
"
)
assert
H
==
128
,
f
"H must be 128, but got
{
H
}
"
# TODO: There is currently an illegal memory access issue with page size !=
# 128. Change this when it is fixed.
assert
PAGE_SIZE
==
128
,
f
"PAGE_SIZE must be 128, but got
{
PAGE_SIZE
}
"
assert
len
(
page_table
.
shape
)
==
2
B_block_table
,
block_num
=
page_table
.
shape
assert
B_block_table
==
B_q
assert
block_num
%
(
128
/
PAGE_SIZE
)
==
0
# TODO(kaixih@nvidia): support fp8
assert
q_nope_and_q_pe
.
dtype
in
(
...
...
sgl-kernel/tests/test_cutlass_mla.py
View file @
e8f62b20
...
...
@@ -39,7 +39,7 @@ def ref_mla(
@
pytest
.
mark
.
parametrize
(
"mean_seq_len"
,
[
128
,
1024
,
4096
])
@
pytest
.
mark
.
parametrize
(
"bs"
,
[
1
,
2
,
4
])
@
pytest
.
mark
.
parametrize
(
"varlen"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
1
,
16
,
64
,
128
])
def
test_cutlass_mla_decode
(
dtype
:
torch
.
dtype
,
mean_seq_len
:
int
,
bs
:
int
,
varlen
:
bool
,
block_size
:
int
):
...
...
@@ -62,6 +62,11 @@ def test_cutlass_mla_decode(
max_seq_len
=
seq_lens
.
max
().
item
()
block_num
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
# Pad block_num so that small blocks can be packed into full 128-sized CUTLASS tiles.
# One 128-wide tile can hold (128 // block_size) small blocks.
pack_factor
=
128
//
block_size
block_num
=
((
block_num
+
pack_factor
-
1
)
//
pack_factor
)
*
pack_factor
q
=
torch
.
randn
(
bs
,
h_q
,
d
)
block_table
=
torch
.
randint
(
0
,
bs
*
block_num
,
(
bs
,
block_num
),
dtype
=
torch
.
int32
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment