sglang commit 4fc5f2f9 (unverified), authored Aug 06, 2025 by Ke Bao; committed by GitHub on Aug 06, 2025

    Add unit test for triton swa kernel (#8853)

Parent: 168033d5
Showing 1 changed file with 184 additions and 0 deletions

test/srt/test_triton_attention_kernels.py  (+184, -0)
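The commit only adds tests, so one plausible way to exercise just the new sliding-window cases locally is the standard unittest runner. This invocation is my own suggestion, not part of the commit, and it assumes an sglang checkout with its dependencies installed and a CUDA GPU available:

    cd test/srt
    python3 -m unittest test_triton_attention_kernels.TestTritonAttention.test_extend_attention_sliding_window -v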
@@ -2,6 +2,7 @@ import random
 import unittest

 import torch
+import torch.nn.functional as F

 from sglang.srt.layers.attention.triton_ops.decode_attention import (
     decode_attention_fwd,
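The only change in this first hunk is the new torch.nn.functional import; it backs the F.softmax call in the pure-PyTorch reference implementation added in the next hunk.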
@@ -18,6 +19,80 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
 from sglang.test.test_utils import CustomTestCase


+def extend_attention_fwd_torch(
+    q: torch.Tensor,  # [extend_tokens, H_Q, D]
+    k: torch.Tensor,  # [extend_tokens, H_KV, D]
+    v: torch.Tensor,  # [extend_tokens, H_KV, D]
+    o: torch.Tensor,  # [extend_tokens, H_Q, D]
+    k_cache: torch.Tensor,  # [total_tokens, H_KV, D]
+    v_cache: torch.Tensor,  # [total_tokens, H_KV, D]
+    qo_indptr: torch.Tensor,  # [B+1]
+    kv_indptr: torch.Tensor,  # [B+1]
+    kv_indices: torch.Tensor,  # [prefix_tokens]
+    sliding_window_size: int,
+):
+    B = qo_indptr.size(0) - 1
+    _, H_Q, D = q.shape
+    _, H_KV, _ = k.shape
+    group_size = H_Q // H_KV
+    scale = 1.0 / D**0.5
+
+    for i in range(B):
+        q_start = int(qo_indptr[i].item())
+        q_end = int(qo_indptr[i + 1].item())
+        kv_start = int(kv_indptr[i].item())
+        kv_end = int(kv_indptr[i + 1].item())
+
+        prefix_indices = kv_indices[kv_start:kv_end]
+        k_prefix = k_cache[prefix_indices]  # [prefix_len, H_KV, D]
+        v_prefix = v_cache[prefix_indices]  # [prefix_len, H_KV, D]
+
+        k_extend = k[q_start:q_end]  # [extend_len, H_KV, D]
+        v_extend = v[q_start:q_end]  # [extend_len, H_KV, D]
+        q_extend = q[q_start:q_end]  # [extend_len, H_Q, D]
+
+        k_full = torch.cat([k_prefix, k_extend], dim=0)  # [total_len, H_KV, D]
+        v_full = torch.cat([v_prefix, v_extend], dim=0)  # [total_len, H_KV, D]
+
+        if group_size != 1:
+            k_full_hq = k_full.repeat_interleave(group_size, dim=1)  # [total_len, H_Q, D]
+            v_full_hq = v_full.repeat_interleave(group_size, dim=1)  # [total_len, H_Q, D]
+        else:
+            k_full_hq = k_full
+            v_full_hq = v_full
+
+        prefix_len = k_prefix.size(0)
+        extend_len = k_extend.size(0)
+        total_len = prefix_len + extend_len
+
+        # causal
+        pos_keys = torch.arange(total_len, device=q.device)
+        t = prefix_len + torch.arange(extend_len, device=q.device)  # [extend_len]
+        causal_mask = pos_keys.unsqueeze(0) <= t.unsqueeze(1)
+
+        # sliding window
+        if sliding_window_size is not None and sliding_window_size > 0:
+            start = (t - (sliding_window_size)).clamp_min(0)  # [extend_len]
+        else:
+            start = torch.zeros_like(t)
+        window_mask = pos_keys.unsqueeze(0) >= start.unsqueeze(1)
+
+        final_mask = causal_mask & window_mask
+
+        attn_scores = (
+            torch.einsum("qhd,khd->qhk", q_extend, k_full_hq) * scale
+        )  # [extend_len, H_Q, total_len]
+        attn_scores = attn_scores.masked_fill(~final_mask.unsqueeze(1), float("-inf"))
+        attn_weights = F.softmax(attn_scores, dim=-1)
+
+        o[q_start:q_end] = torch.einsum("qhk,khd->qhd", attn_weights, v_full_hq)
+
+
 class TestTritonAttention(CustomTestCase):
     def _set_all_seeds(self, seed):
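To make the masking logic above concrete, here is a tiny standalone illustration (mine, not part of the diff) of the combined causal plus sliding-window mask that the torch reference builds; the sizes and the printed tensor are just example values:

    import torch

    prefix_len, extend_len, window = 3, 2, 2
    pos_keys = torch.arange(prefix_len + extend_len)         # key positions 0..4
    t = prefix_len + torch.arange(extend_len)                # query positions 3 and 4
    causal = pos_keys.unsqueeze(0) <= t.unsqueeze(1)         # no attending to future keys
    start = (t - window).clamp_min(0)                        # earliest visible key per query
    windowed = pos_keys.unsqueeze(0) >= start.unsqueeze(1)
    print(causal & windowed)
    # tensor([[False,  True,  True,  True, False],
    #         [False, False,  True,  True,  True]])

Each query at position t can see keys in [max(0, t - window), t], i.e. itself plus at most `window` earlier tokens, counted across both the cached prefix and the extend segment.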
@@ -180,6 +255,115 @@ class TestTritonAttention(CustomTestCase):
         for value in attention_values:
             self._test_extend_attention_once(19, 12331, 12, 4, value)

+    def _test_extend_attention_sliding_window_once(
+        self, B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE
+    ):
+        dtype = torch.bfloat16
+
+        b_seq_len_prefix = torch.randint(
+            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
+        )
+        b_seq_len_extend = torch.randint(
+            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
+        )
+        b_seq_len = b_seq_len_prefix + b_seq_len_extend
+
+        b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
+        b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
+        b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
+        b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
+
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
+        kv_indices = torch.zeros(
+            (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda"
+        )
+        for i in range(B):
+            kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
+                b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
+            )
+
+        total_token_num = torch.sum(b_seq_len).item()
+        extend_token_num = torch.sum(b_seq_len_extend).item()
+        k_buffer = torch.empty(
+            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
+        ).normal_(mean=0.1, std=0.2)
+        v_buffer = torch.empty(
+            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
+        ).normal_(mean=0.1, std=0.2)
+
+        k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
+        v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
+        q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
+        for i in range(B):
+            extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
+            extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
+            extend_start = b_start_loc_extend[i]
+            extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
+            k_extend[extend_start:extend_end] = k_buffer[
+                extend_start_in_buffer:extend_end_in_buffer
+            ]
+            v_extend[extend_start:extend_end] = v_buffer[
+                extend_start_in_buffer:extend_end_in_buffer
+            ]
+            q_extend[extend_start:extend_end] = torch.empty(
+                (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
+            ).normal_(mean=0.1, std=0.2)
+
+        o_extend_triton = torch.empty(
+            (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
+        )
+        o_extend_torch = torch.empty(
+            (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
+        )
+
+        b_seq_len_extend = b_seq_len - b_seq_len_prefix
+        max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
+        qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
+
+        extend_attention_fwd(
+            q_extend,
+            k_extend,
+            v_extend,
+            o_extend_triton,
+            k_buffer,
+            v_buffer,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            custom_mask=None,
+            is_causal=True,
+            mask_indptr=None,
+            max_len_extend=max_len_extend,
+            sliding_window_size=WINDOW_SIZE,
+        )
+
+        extend_attention_fwd_torch(
+            q_extend,
+            k_extend,
+            v_extend,
+            o_extend_torch,
+            k_buffer,
+            v_buffer,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            WINDOW_SIZE,
+        )
+
+        self.assertTrue(
+            torch.allclose(o_extend_triton, o_extend_torch, rtol=1e-3, atol=1e-3)
+        )
+
+    def test_extend_attention_sliding_window(self):
+        window_sizes = [-1, 127]
+        for window_size in window_sizes:
+            self._test_extend_attention_sliding_window_once(
+                19, 12331, 64, 8, 128, window_size
+            )
+
     def _test_context_attention_once(self, head_dim, is_causal):
         # Set up a simple test case
         num_heads = 4
...
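A reading note (mine, not stated in the commit): the sweep window_sizes = [-1, 127] covers both branches of the reference's `sliding_window_size > 0` check, so -1 exercises plain causal attention while 127 exercises the windowed path, and in both cases the Triton kernel output is required to match the torch reference within rtol = atol = 1e-3.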