zhaoyu6 / sglang · Commits · 4fc5f2f9

Unverified commit 4fc5f2f9, authored Aug 06, 2025 by Ke Bao, committed by GitHub on Aug 06, 2025

Add unit test for triton swa kernel (#8853)
Parent: 168033d5

Showing 1 changed file with 184 additions and 0 deletions.

test/srt/test_triton_attention_kernels.py (+184, −0)
@@ -2,6 +2,7 @@ import random
 import unittest

 import torch
+import torch.nn.functional as F

 from sglang.srt.layers.attention.triton_ops.decode_attention import (
     decode_attention_fwd,
@@ -18,6 +19,80 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
 from sglang.test.test_utils import CustomTestCase


+def extend_attention_fwd_torch(
+    q: torch.Tensor,  # [extend_tokens, H_Q, D]
+    k: torch.Tensor,  # [extend_tokens, H_KV, D]
+    v: torch.Tensor,  # [extend_tokens, H_KV, D]
+    o: torch.Tensor,  # [extend_tokens, H_Q, D]
+    k_cache: torch.Tensor,  # [total_tokens, H_KV, D]
+    v_cache: torch.Tensor,  # [total_tokens, H_KV, D]
+    qo_indptr: torch.Tensor,  # [B+1]
+    kv_indptr: torch.Tensor,  # [B+1]
+    kv_indices: torch.Tensor,  # [prefix_tokens]
+    sliding_window_size: int,
+):
+    B = qo_indptr.size(0) - 1
+    _, H_Q, D = q.shape
+    _, H_KV, _ = k.shape
+    group_size = H_Q // H_KV
+    scale = 1.0 / D**0.5
+
+    for i in range(B):
+        q_start = int(qo_indptr[i].item())
+        q_end = int(qo_indptr[i + 1].item())
+        kv_start = int(kv_indptr[i].item())
+        kv_end = int(kv_indptr[i + 1].item())
+
+        prefix_indices = kv_indices[kv_start:kv_end]
+        k_prefix = k_cache[prefix_indices]  # [prefix_len, H_KV, D]
+        v_prefix = v_cache[prefix_indices]  # [prefix_len, H_KV, D]
+
+        k_extend = k[q_start:q_end]  # [extend_len, H_KV, D]
+        v_extend = v[q_start:q_end]  # [extend_len, H_KV, D]
+        q_extend = q[q_start:q_end]  # [extend_len, H_Q, D]
+
+        k_full = torch.cat([k_prefix, k_extend], dim=0)  # [total_len, H_KV, D]
+        v_full = torch.cat([v_prefix, v_extend], dim=0)  # [total_len, H_KV, D]
+
+        if group_size != 1:
+            k_full_hq = k_full.repeat_interleave(group_size, dim=1)  # [total_len, H_Q, D]
+            v_full_hq = v_full.repeat_interleave(group_size, dim=1)  # [total_len, H_Q, D]
+        else:
+            k_full_hq = k_full
+            v_full_hq = v_full
+
+        prefix_len = k_prefix.size(0)
+        extend_len = k_extend.size(0)
+        total_len = prefix_len + extend_len
+
+        # causal
+        pos_keys = torch.arange(total_len, device=q.device)
+        t = prefix_len + torch.arange(extend_len, device=q.device)  # [extend_len]
+        causal_mask = pos_keys.unsqueeze(0) <= t.unsqueeze(1)
+
+        # sliding window
+        if sliding_window_size is not None and sliding_window_size > 0:
+            start = (t - sliding_window_size).clamp_min(0)  # [extend_len]
+        else:
+            start = torch.zeros_like(t)
+        window_mask = pos_keys.unsqueeze(0) >= start.unsqueeze(1)
+
+        final_mask = causal_mask & window_mask
+
+        attn_scores = (
+            torch.einsum("qhd,khd->qhk", q_extend, k_full_hq) * scale
+        )  # [extend_len, H_Q, total_len]
+        attn_scores = attn_scores.masked_fill(~final_mask.unsqueeze(1), float("-inf"))
+        attn_weights = F.softmax(attn_scores, dim=-1)
+
+        o[q_start:q_end] = torch.einsum("qhk,khd->qhd", attn_weights, v_full_hq)
+
+
 class TestTritonAttention(CustomTestCase):
     def _set_all_seeds(self, seed):
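For intuition, the reference above builds the sliding-window attention mask as the AND of a causal mask and a window mask: a key at absolute position p is visible to a query at absolute position t only when max(t - sliding_window_size, 0) <= p <= t. A minimal standalone sketch (toy sizes chosen here for illustration, not taken from the test):

    import torch

    # Toy sizes (hypothetical): 2 cached prefix tokens, 3 extend tokens, window 2.
    prefix_len, extend_len, window = 2, 3, 2
    total_len = prefix_len + extend_len

    pos_keys = torch.arange(total_len)         # absolute key positions [0..4]
    t = prefix_len + torch.arange(extend_len)  # absolute query positions [2, 3, 4]

    causal = pos_keys.unsqueeze(0) <= t.unsqueeze(1)        # no attending ahead
    start = (t - window).clamp_min(0)
    windowed = pos_keys.unsqueeze(0) >= start.unsqueeze(1)  # no attending too far back

    print(causal & windowed)
    # tensor([[ True,  True,  True, False, False],
    #         [False,  True,  True,  True, False],
    #         [False, False,  True,  True,  True]])

With window = 2, each query sees at most window + 1 keys: itself plus the two positions immediately before it.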
@@ -180,6 +255,115 @@ class TestTritonAttention(CustomTestCase):
         for value in attention_values:
             self._test_extend_attention_once(19, 12331, 12, 4, value)

+    def _test_extend_attention_sliding_window_once(
+        self, B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE
+    ):
+        dtype = torch.bfloat16
+
+        b_seq_len_prefix = torch.randint(
+            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
+        )
+        b_seq_len_extend = torch.randint(
+            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
+        )
+        b_seq_len = b_seq_len_prefix + b_seq_len_extend
+
+        b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
+        b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
+        b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
+        b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
+
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
+        kv_indices = torch.zeros(
+            (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda"
+        )
+        for i in range(B):
+            kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
+                b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
+            )
+
+        total_token_num = torch.sum(b_seq_len).item()
+        extend_token_num = torch.sum(b_seq_len_extend).item()
+        k_buffer = torch.empty(
+            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
+        ).normal_(mean=0.1, std=0.2)
+        v_buffer = torch.empty(
+            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
+        ).normal_(mean=0.1, std=0.2)
+
+        k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
+        v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
+        q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
+        for i in range(B):
+            extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
+            extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
+            extend_start = b_start_loc_extend[i]
+            extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
+            k_extend[extend_start:extend_end] = k_buffer[
+                extend_start_in_buffer:extend_end_in_buffer
+            ]
+            v_extend[extend_start:extend_end] = v_buffer[
+                extend_start_in_buffer:extend_end_in_buffer
+            ]
+            q_extend[extend_start:extend_end] = torch.empty(
+                (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
+            ).normal_(mean=0.1, std=0.2)
+
+        o_extend_triton = torch.empty(
+            (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
+        )
+        o_extend_torch = torch.empty(
+            (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
+        )
+
+        b_seq_len_extend = b_seq_len - b_seq_len_prefix
+        max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
+        qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
+
+        extend_attention_fwd(
+            q_extend,
+            k_extend,
+            v_extend,
+            o_extend_triton,
+            k_buffer,
+            v_buffer,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            custom_mask=None,
+            is_causal=True,
+            mask_indptr=None,
+            max_len_extend=max_len_extend,
+            sliding_window_size=WINDOW_SIZE,
+        )
+
+        extend_attention_fwd_torch(
+            q_extend,
+            k_extend,
+            v_extend,
+            o_extend_torch,
+            k_buffer,
+            v_buffer,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            WINDOW_SIZE,
+        )
+
+        self.assertTrue(
+            torch.allclose(o_extend_triton, o_extend_torch, rtol=1e-3, atol=1e-3)
+        )
+
+    def test_extend_attention_sliding_window(self):
+        window_sizes = [-1, 127]
+        for window_size in window_sizes:
+            self._test_extend_attention_sliding_window_once(
+                19, 12331, 64, 8, 128, window_size
+            )
+
     def _test_context_attention_once(self, head_dim, is_causal):
         # Set up a simple test case
         num_heads = 4
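As a quick usage sketch for the new torch reference (hypothetical tiny shapes, on CPU rather than the CUDA buffers the test allocates; assumes extend_attention_fwd_torch from the file above is in scope):

    import torch

    # One sequence (B=1): 2 cached prefix tokens, 3 new "extend" tokens,
    # H_Q=2 query heads sharing H_KV=1 KV head, head dim D=4.
    H_Q, H_KV, D = 2, 1, 4
    q = torch.randn(3, H_Q, D)         # queries for the extend tokens
    k = torch.randn(3, H_KV, D)        # keys for the extend tokens
    v = torch.randn(3, H_KV, D)        # values for the extend tokens
    o = torch.empty(3, H_Q, D)         # output buffer, filled in place
    k_cache = torch.randn(2, H_KV, D)  # cached prefix keys
    v_cache = torch.randn(2, H_KV, D)  # cached prefix values
    qo_indptr = torch.tensor([0, 3])   # extend-token range per sequence
    kv_indptr = torch.tensor([0, 2])   # prefix-token range per sequence
    kv_indices = torch.arange(2)       # locations of prefix tokens in the cache

    extend_attention_fwd_torch(
        q, k, v, o, k_cache, v_cache, qo_indptr, kv_indptr, kv_indices,
        sliding_window_size=2,
    )
    print(o.shape)  # torch.Size([3, 2, 4])

The Triton kernel extend_attention_fwd is compared against exactly this reference in the test, with rtol = atol = 1e-3 under bfloat16.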