change / sglang · Commits · a322051e
Unverified commit a322051e, authored Feb 06, 2025 by Ke Bao, committed by GitHub on Feb 06, 2025
Support custom mask for Triton attention (#3317)
Parent: de553334
Showing 3 changed files with 107 additions and 11 deletions (+107, -11)

python/sglang/srt/layers/attention/triton_backend.py  (+16, -4)
python/sglang/srt/layers/attention/triton_ops/extend_attention.py  (+48, -7)
test/srt/test_triton_attention_kernels.py  (+43, -0)
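For orientation before the diffs: the custom mask added here is a flattened boolean tensor. For each sequence in the batch the mask covers every (extend query row, kv column) pair, stored row-major over prefix-plus-extend columns, and mask_offsets holds the prefix-sum start of each sequence's slice. A minimal sketch of that layout, mirroring the construction in the new test code (the concrete lengths below are made up for illustration):

import torch

# Illustrative sketch, not part of this commit: the flattened custom_mask /
# mask_offsets layout that the new kernel arguments expect.
b_seq_len_prefix = torch.tensor([2, 3])          # cached prefix tokens per sequence
b_seq_len_extend = torch.tensor([3, 2])          # new (extend) tokens per sequence
b_seq_len = b_seq_len_prefix + b_seq_len_extend  # total kv length per sequence

# One mask row per extend token, one column per kv token, flattened row-major.
b_seq_mask_len = b_seq_len_extend * b_seq_len
custom_mask = torch.ones((b_seq_mask_len.sum().item(),), dtype=torch.bool)

# mask_offsets[i] is where sequence i's flattened mask starts.
mask_offsets = torch.zeros((b_seq_len.numel() + 1,), dtype=torch.int64)
mask_offsets[1:] = torch.cumsum(b_seq_mask_len, dim=0)

# Visibility of kv column k to extend query row q of sequence i:
i, q, k = 1, 0, 3
visible = custom_mask[mask_offsets[i] + q * b_seq_len[i] + k]
print(bool(visible))  # True: an all-ones mask hides nothing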
python/sglang/srt/layers/attention/triton_backend.py (view file @ a322051e)
@@ -91,6 +91,7 @@ class TritonAttnBackend(AttentionBackend):
             qo_indptr = None
             custom_mask = None
+            mask_offsets = None
         else:
             kv_indptr[1 : bs + 1] = torch.cumsum(
                 forward_batch.extend_prefix_lens, dim=0
@@ -115,6 +116,7 @@ class TritonAttnBackend(AttentionBackend):
             qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
+            mask_offsets = None

             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
@@ -126,6 +128,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indices,
             qo_indptr,
             custom_mask,
+            mask_offsets,
         )

     def init_cuda_graph_state(self, max_bs: int):
@@ -180,6 +183,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indices,
             None,
             None,
+            None,
         )

     def init_forward_metadata_replay_cuda_graph(
@@ -233,9 +237,15 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )

-        _, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask = (
-            self.forward_metadata
-        )
+        (
+            _,
+            max_extend_len,
+            kv_indptr,
+            kv_indices,
+            qo_indptr,
+            custom_mask,
+            mask_offsets,
+        ) = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
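With this change forward_metadata carries seven fields: (attn_logits, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask, mask_offsets). The extend path above unpacks all of them, while the decode path below keeps ignoring the extend-only entries.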
@@ -246,6 +256,8 @@ class TritonAttnBackend(AttentionBackend):
             qo_indptr,
             kv_indptr,
             kv_indices,
+            custom_mask,
+            mask_offsets,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
@@ -271,7 +283,7 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

-        attn_logits, _, kv_indptr, kv_indices, _, _ = self.forward_metadata
+        attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata

         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
python/sglang/srt/layers/attention/triton_ops/extend_attention.py (view file @ a322051e)
@@ -49,6 +49,8 @@ def _fwd_kernel(
     qo_indptr,
     kv_indptr,
     kv_indices,
+    mask_ptr,
+    mask_offsets,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -71,6 +73,7 @@ def _fwd_kernel(
     BLOCK_DV: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    USE_CUSTOM_MASK: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -81,6 +84,10 @@ def _fwd_kernel(
     cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
     cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
     cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
+    cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend
+
+    if USE_CUSTOM_MASK:
+        cur_seq_mask_start_idx = tl.load(mask_offsets + cur_seq)

     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)
@@ -152,7 +159,20 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)

-        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
+        if USE_CUSTOM_MASK:
+            custom_mask = tl.load(
+                mask_ptr
+                + cur_seq_mask_start_idx
+                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + start_n
+                + offs_n[None, :],
+                mask=(mask_m[:, None] & mask_n[None, :]),
+                other=0,
+            )
+            custom_mask &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(custom_mask, qk, float("-inf"))
+        else:
+            qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))

         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
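The stage-1 load above gathers a BLOCK_M x BLOCK_N boolean tile from the flattened per-sequence mask: the flat index is mask start + query row * cur_seq_len + prefix kv column, with out-of-range lanes masked off and defaulted to 0 (hidden). A rough PyTorch restatement of that gather, written as a sketch rather than the kernel itself (the function and argument names here are illustrative):

import torch

# Illustrative sketch of the stage-1 (prefix) mask gather performed by tl.load.
def load_prefix_mask_tile(custom_mask, mask_start, cur_seq_len, q_rows, k_cols):
    # q_rows: extend query rows covered by this BLOCK_M tile
    # k_cols: prefix kv columns covered by this BLOCK_N tile
    flat_idx = mask_start + q_rows[:, None] * cur_seq_len + k_cols[None, :]
    return custom_mask[flat_idx]  # boolean tile of shape (len(q_rows), len(k_cols))

# Example: rows 0..2 of a sequence whose total kv length is 5, prefix columns 0..1
tile = load_prefix_mask_tile(
    torch.ones(15, dtype=torch.bool), mask_start=0, cur_seq_len=5,
    q_rows=torch.arange(3), k_cols=torch.arange(2),
)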
@@ -172,7 +192,7 @@ def _fwd_kernel(
         e_max = n_e_max

-    # stage 2: compute the trianlge part
+    # stage 2: compute the triangle part
     cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
     for start_n in range(0, cur_block_m_end, BLOCK_N):
@@ -208,11 +228,25 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)

-        mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
-            start_n + offs_n[None, :]
-        )
-        mask_causual &= mask_m[:, None] & mask_n[None, :]
-        qk = tl.where(mask_causual, qk, float("-inf"))
+        if USE_CUSTOM_MASK:
+            custom_mask = tl.load(
+                mask_ptr
+                + cur_seq_mask_start_idx
+                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + cur_seq_len_prefix
+                + start_n
+                + offs_n[None, :],
+                mask=(mask_m[:, None] & mask_n[None, :]),
+                other=0,
+            )
+            custom_mask &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(custom_mask, qk, float("-inf"))
+        else:
+            mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
+                start_n + offs_n[None, :]
+            )
+            mask_causual &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(mask_causual, qk, float("-inf"))

         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
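The stage-2 load differs only in the column term: extend kv columns sit after the prefix inside each flattened mask row, so the index adds cur_seq_len_prefix before start_n + offs_n. When the custom mask is an all-true prefix block followed by a lower-triangular extend block, this path reproduces the default causal masking that the else branch still implements.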
@@ -253,6 +287,8 @@ def extend_attention_fwd(
     qo_indptr,
     kv_indptr,
     kv_indices,
+    custom_mask,
+    mask_offsets,
     max_len_extend,
     sm_scale=None,
     logit_cap=0.0,
@@ -308,6 +344,8 @@ def extend_attention_fwd(
     batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]

+    USE_CUSTOM_MASK = custom_mask is not None
+
     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
     num_stages = 1
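USE_CUSTOM_MASK is derived from custom_mask is not None and forwarded as a tl.constexpr, so Triton compiles and caches separate kernel variants for the masked and unmasked cases; when no custom mask is supplied, the new branches are specialized away and the original causal code path runs unchanged.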
@@ -325,6 +363,8 @@ def extend_attention_fwd(
         qo_indptr,
         kv_indptr,
         kv_indices,
+        custom_mask,
+        mask_offsets,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
@@ -347,6 +387,7 @@ def extend_attention_fwd(
         BLOCK_N=BLOCK_N,
         Lq=Lq,
         Lv=Lv,
+        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
         num_warps=num_warps,
         num_stages=num_stages,
         **extra_kargs,
test/srt/test_triton_attention_kernels.py (view file @ a322051e)
@@ -89,6 +89,9 @@ class TestTritonAttention(unittest.TestCase):
         ).normal_(mean=0.1, std=0.2)

         o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
+        o_extend_mask = torch.empty(
+            (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
+        )
         o_redundant = torch.empty(
             (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
         )
@@ -98,6 +101,9 @@ class TestTritonAttention(unittest.TestCase):
         qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)

+        custom_mask = None
+        mask_offsets = None
+
         extend_attention_fwd(
             q_extend,
             k_extend,
@@ -108,6 +114,42 @@ class TestTritonAttention(unittest.TestCase):
             qo_indptr,
             kv_indptr,
             kv_indices,
+            custom_mask,
+            mask_offsets,
             max_len_extend,
         )

+        b_seq_mask_len = b_seq_len_extend * b_seq_len
+        custom_mask = torch.ones(
+            (b_seq_mask_len.sum().item(),), dtype=torch.bool, device="cuda"
+        )
+        mask_offsets = torch.zeros((B + 1,), dtype=torch.int64, device="cuda")
+        mask_offsets[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0)
+        for i in range(B):
+            causal_mask = (
+                torch.tril(
+                    torch.ones(b_seq_len_extend[i], b_seq_len_extend[i]), diagonal=0
+                )
+                == 1
+            )
+            prefix_mask = torch.ones(
+                b_seq_len_extend[i], b_seq_len_prefix[i], dtype=torch.bool
+            )
+            mask_flatten = torch.cat([prefix_mask, causal_mask], dim=1).flatten()
+            custom_mask[mask_offsets[i] : mask_offsets[i + 1]] = mask_flatten
+
+        extend_attention_fwd(
+            q_extend,
+            k_extend,
+            v_extend,
+            o_extend_mask,
+            k_buffer,
+            v_buffer,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            custom_mask,
+            mask_offsets,
+            max_len_extend,
+        )
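The constructed mask marks every prefix column visible and applies a lower-triangular pattern over the extend columns, so the masked call should agree with the unmasked causal reference; the assertion added in the next hunk compares o_extend_mask against o_redundant on that basis.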
@@ -124,6 +166,7 @@ class TestTritonAttention(unittest.TestCase):
         )

         self.assertTrue(torch.allclose(o_extend, o_redundant, rtol=1e-2))
+        self.assertTrue(torch.allclose(o_extend_mask, o_redundant, rtol=1e-2))

     def test_extend_attention(self):