zhaoyu6 / sglang · Commit a322051e (Unverified)

Authored Feb 06, 2025 by Ke Bao; committed by GitHub on Feb 06, 2025

Support custom mask for Triton attention (#3317)
Parent: de553334

Showing 3 changed files with 107 additions and 11 deletions (+107, -11)
python/sglang/srt/layers/attention/triton_backend.py (+16, -4)
python/sglang/srt/layers/attention/triton_ops/extend_attention.py (+48, -7)
test/srt/test_triton_attention_kernels.py (+43, -0)
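This commit threads an optional flattened boolean mask (custom_mask) and per-sequence offsets (mask_offsets) from the Triton attention backend into the extend-attention kernel. For sequence i the mask is a row-major block of shape (extend_len_i, prefix_len_i + extend_len_i) flattened to 1-D; the blocks of all sequences are concatenated, and mask_offsets[i] records where block i starts. Below is a minimal sketch of that layout using hypothetical lengths (names mirror the new test); a real caller would build these tensors on the GPU and pass them to extend_attention_fwd together with qo_indptr, kv_indptr, and kv_indices.

import torch

# Hypothetical per-sequence lengths, for illustration only.
b_seq_len_prefix = torch.tensor([3, 0])
b_seq_len_extend = torch.tensor([2, 4])
b_seq_len = b_seq_len_prefix + b_seq_len_extend
B = b_seq_len.numel()

# Each sequence contributes extend_len * (prefix_len + extend_len) mask bits.
b_seq_mask_len = b_seq_len_extend * b_seq_len
mask_offsets = torch.zeros(B + 1, dtype=torch.int64)
mask_offsets[1:] = torch.cumsum(b_seq_mask_len, dim=0)

# Flattened mask: prefix columns fully visible, extend columns causal
# (this particular pattern reproduces the kernel's default behaviour).
custom_mask = torch.zeros(int(b_seq_mask_len.sum()), dtype=torch.bool)
for i in range(B):
    e, p = int(b_seq_len_extend[i]), int(b_seq_len_prefix[i])
    prefix_block = torch.ones(e, p, dtype=torch.bool)
    causal_block = torch.tril(torch.ones(e, e)) == 1
    custom_mask[mask_offsets[i] : mask_offsets[i + 1]] = torch.cat(
        [prefix_block, causal_block], dim=1
    ).flatten()

print(mask_offsets.tolist())  # [0, 10, 26] for the lengths above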
python/sglang/srt/layers/attention/triton_backend.py

@@ -91,6 +91,7 @@ class TritonAttnBackend(AttentionBackend):
             qo_indptr = None
             custom_mask = None
+            mask_offsets = None
         else:
             kv_indptr[1 : bs + 1] = torch.cumsum(
                 forward_batch.extend_prefix_lens, dim=0
@@ -115,6 +116,7 @@ class TritonAttnBackend(AttentionBackend):
             qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
+            mask_offsets = None
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
@@ -126,6 +128,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indices,
             qo_indptr,
             custom_mask,
+            mask_offsets,
         )

     def init_cuda_graph_state(self, max_bs: int):
@@ -180,6 +183,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indices,
             None,
             None,
+            None,
         )

     def init_forward_metadata_replay_cuda_graph(
@@ -233,9 +237,15 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )

-        _, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask = (
-            self.forward_metadata
-        )
+        (
+            _,
+            max_extend_len,
+            kv_indptr,
+            kv_indices,
+            qo_indptr,
+            custom_mask,
+            mask_offsets,
+        ) = self.forward_metadata

         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -246,6 +256,8 @@ class TritonAttnBackend(AttentionBackend):
             qo_indptr,
             kv_indptr,
             kv_indices,
+            custom_mask,
+            mask_offsets,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
@@ -271,7 +283,7 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

-        attn_logits, _, kv_indptr, kv_indices, _, _ = self.forward_metadata
+        attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata

         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
python/sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -49,6 +49,8 @@ def _fwd_kernel(
     qo_indptr,
     kv_indptr,
     kv_indices,
+    mask_ptr,
+    mask_offsets,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -71,6 +73,7 @@ def _fwd_kernel(
     BLOCK_DV: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    USE_CUSTOM_MASK: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -81,6 +84,10 @@ def _fwd_kernel(
     cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
     cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
     cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
+    cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend
+
+    if USE_CUSTOM_MASK:
+        cur_seq_mask_start_idx = tl.load(mask_offsets + cur_seq)

     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)
@@ -152,7 +159,20 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)

-        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
+        if USE_CUSTOM_MASK:
+            custom_mask = tl.load(
+                mask_ptr
+                + cur_seq_mask_start_idx
+                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + start_n
+                + offs_n[None, :],
+                mask=(mask_m[:, None] & mask_n[None, :]),
+                other=0,
+            )
+            custom_mask &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(custom_mask, qk, float("-inf"))
+        else:
+            qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))

         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
@@ -172,7 +192,7 @@ def _fwd_kernel(
         e_max = n_e_max

-    # stage 2: compute the trianlge part
+    # stage 2: compute the triangle part
     cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
     for start_n in range(0, cur_block_m_end, BLOCK_N):
@@ -208,11 +228,25 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)

-        mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
-            start_n + offs_n[None, :]
-        )
-        mask_causual &= mask_m[:, None] & mask_n[None, :]
-        qk = tl.where(mask_causual, qk, float("-inf"))
+        if USE_CUSTOM_MASK:
+            custom_mask = tl.load(
+                mask_ptr
+                + cur_seq_mask_start_idx
+                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + cur_seq_len_prefix
+                + start_n
+                + offs_n[None, :],
+                mask=(mask_m[:, None] & mask_n[None, :]),
+                other=0,
+            )
+            custom_mask &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(custom_mask, qk, float("-inf"))
+        else:
+            mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
+                start_n + offs_n[None, :]
+            )
+            mask_causual &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(mask_causual, qk, float("-inf"))

         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
@@ -253,6 +287,8 @@ def extend_attention_fwd(
     qo_indptr,
     kv_indptr,
     kv_indices,
+    custom_mask,
+    mask_offsets,
     max_len_extend,
     sm_scale=None,
     logit_cap=0.0,
@@ -308,6 +344,8 @@ def extend_attention_fwd(
     batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]

+    USE_CUSTOM_MASK = custom_mask is not None
+
     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
     num_stages = 1
@@ -325,6 +363,8 @@ def extend_attention_fwd(
         qo_indptr,
         kv_indptr,
         kv_indices,
+        custom_mask,
+        mask_offsets,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
@@ -347,6 +387,7 @@ def extend_attention_fwd(
         BLOCK_N=BLOCK_N,
         Lq=Lq,
         Lv=Lv,
+        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
         num_warps=num_warps,
         num_stages=num_stages,
         **extra_kargs,
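When USE_CUSTOM_MASK is set, each (query row, key column) pair in _fwd_kernel reads one bit of the flattened mask: the row index is the query position within the extend segment, and the column index is the absolute key position across prefix + extend, which is why the prefix-stage load uses start_n + offs_n directly while the triangle-stage load adds cur_seq_len_prefix. A plain-Python rendering of that addressing follows, as a hypothetical helper for illustration only (assuming the layout sketched above, not code from the diff):

def mask_bit(custom_mask, mask_offsets, seq, q_row, kv_col, seq_len):
    # Mirrors the kernel's load address:
    #   mask_ptr + mask_offsets[seq] + q_row * cur_seq_len + kv_col
    # where kv_col = start_n + offs_n in the prefix stage and
    #       kv_col = cur_seq_len_prefix + start_n + offs_n in the triangle stage.
    return bool(custom_mask[mask_offsets[seq] + q_row * seq_len + kv_col])

For the toy batch above, mask_bit(custom_mask, mask_offsets, 1, 0, 3, 4) is False: sequence 1 has no prefix, and key column 3 lies ahead of query row 0 in its causal block.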
test/srt/test_triton_attention_kernels.py

@@ -89,6 +89,9 @@ class TestTritonAttention(unittest.TestCase):
         ).normal_(mean=0.1, std=0.2)

         o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
+        o_extend_mask = torch.empty(
+            (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
+        )
         o_redundant = torch.empty(
             (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
         )
@@ -98,6 +101,9 @@ class TestTritonAttention(unittest.TestCase):
         qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)

+        custom_mask = None
+        mask_offsets = None
+
         extend_attention_fwd(
             q_extend,
             k_extend,
@@ -108,6 +114,42 @@ class TestTritonAttention(unittest.TestCase):
             qo_indptr,
             kv_indptr,
             kv_indices,
+            custom_mask,
+            mask_offsets,
             max_len_extend,
         )

+        b_seq_mask_len = b_seq_len_extend * b_seq_len
+        custom_mask = torch.ones(
+            (b_seq_mask_len.sum().item(),), dtype=torch.bool, device="cuda"
+        )
+        mask_offsets = torch.zeros((B + 1,), dtype=torch.int64, device="cuda")
+        mask_offsets[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0)
+        for i in range(B):
+            causal_mask = (
+                torch.tril(
+                    torch.ones(b_seq_len_extend[i], b_seq_len_extend[i]), diagonal=0
+                )
+                == 1
+            )
+            prefix_mask = torch.ones(
+                b_seq_len_extend[i], b_seq_len_prefix[i], dtype=torch.bool
+            )
+            mask_flatten = torch.cat([prefix_mask, causal_mask], dim=1).flatten()
+            custom_mask[mask_offsets[i] : mask_offsets[i + 1]] = mask_flatten
+
+        extend_attention_fwd(
+            q_extend,
+            k_extend,
+            v_extend,
+            o_extend_mask,
+            k_buffer,
+            v_buffer,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            custom_mask,
+            mask_offsets,
+            max_len_extend,
+        )
@@ -124,6 +166,7 @@ class TestTritonAttention(unittest.TestCase):
         )

         self.assertTrue(torch.allclose(o_extend, o_redundant, rtol=1e-2))
+        self.assertTrue(torch.allclose(o_extend_mask, o_redundant, rtol=1e-2))

     def test_extend_attention(self):
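The test fills the mask with an all-true prefix block followed by a lower-triangular extend block, which is exactly the causal rule the kernel applies when no custom mask is given, so o_extend_mask is expected to match o_redundant within the same tolerance as the unmasked call. A quick sanity check of that equivalence for one hypothetical sequence (illustrative lengths only, not part of the diff):

import torch

prefix_len, extend_len = 3, 2            # hypothetical lengths
seq_len = prefix_len + extend_len

# Mask as the test builds it: prefix fully visible, extend causal.
built = torch.cat(
    [
        torch.ones(extend_len, prefix_len, dtype=torch.bool),
        torch.tril(torch.ones(extend_len, extend_len)) == 1,
    ],
    dim=1,
)

# Default rule without a custom mask: query row q (absolute position
# prefix_len + q) may attend to key column k iff k <= prefix_len + q.
q = torch.arange(extend_len)[:, None]
k = torch.arange(seq_len)[None, :]
default = k <= (prefix_len + q)

assert torch.equal(built, default)       # hence o_extend_mask should match o_redundant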