Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e8f62b20
"tests/vscode:/vscode.git/clone" did not exist on "d493a5d523f4731457435548b255781367c1c92f"
Unverified
Commit
e8f62b20
authored
Apr 15, 2025
by
Trevor Morris
Committed by
GitHub
Apr 15, 2025
Browse files
BLackwell cutlass mla: Add check for bad page size/block num combinations (#5431)
parent
88defc4d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
4 deletions
+11
-4
sgl-kernel/python/sgl_kernel/attention.py
sgl-kernel/python/sgl_kernel/attention.py
+5
-3
sgl-kernel/tests/test_cutlass_mla.py
sgl-kernel/tests/test_cutlass_mla.py
+6
-1
No files found.
sgl-kernel/python/sgl_kernel/attention.py
View file @
e8f62b20
...
...
@@ -74,9 +74,11 @@ def cutlass_mla_decode(
f
"but got D_q =
{
D_q
}
, D_ckv =
{
D_ckv
}
, D_latent =
{
D_latent
}
, D_rope =
{
D_rope
}
"
)
assert
H
==
128
,
f
"H must be 128, but got
{
H
}
"
# TODO: There is currently an illegal memory access issue with page size !=
# 128. Change this when it is fixed.
assert
PAGE_SIZE
==
128
,
f
"PAGE_SIZE must be 128, but got
{
PAGE_SIZE
}
"
assert
len
(
page_table
.
shape
)
==
2
B_block_table
,
block_num
=
page_table
.
shape
assert
B_block_table
==
B_q
assert
block_num
%
(
128
/
PAGE_SIZE
)
==
0
# TODO(kaixih@nvidia): support fp8
assert
q_nope_and_q_pe
.
dtype
in
(
...
...
sgl-kernel/tests/test_cutlass_mla.py
View file @
e8f62b20
...
...
@@ -39,7 +39,7 @@ def ref_mla(
@
pytest
.
mark
.
parametrize
(
"mean_seq_len"
,
[
128
,
1024
,
4096
])
@
pytest
.
mark
.
parametrize
(
"bs"
,
[
1
,
2
,
4
])
@
pytest
.
mark
.
parametrize
(
"varlen"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
1
,
16
,
64
,
128
])
def
test_cutlass_mla_decode
(
dtype
:
torch
.
dtype
,
mean_seq_len
:
int
,
bs
:
int
,
varlen
:
bool
,
block_size
:
int
):
...
...
@@ -62,6 +62,11 @@ def test_cutlass_mla_decode(
max_seq_len
=
seq_lens
.
max
().
item
()
block_num
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
# Pad block_num so that small blocks can be packed into full 128-sized CUTLASS tiles.
# One 128-wide tile can hold (128 // block_size) small blocks.
pack_factor
=
128
//
block_size
block_num
=
((
block_num
+
pack_factor
-
1
)
//
pack_factor
)
*
pack_factor
q
=
torch
.
randn
(
bs
,
h_q
,
d
)
block_table
=
torch
.
randint
(
0
,
bs
*
block_num
,
(
bs
,
block_num
),
dtype
=
torch
.
int32
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment