change / sglang · Commits · a07364cc

Update Triton decode backend interface (#3292)

Unverified commit a07364cc, authored Feb 04, 2025 by Ke Bao; committed by GitHub on Feb 04, 2025.
Parent: 2c1a695f

Showing 3 changed files with 129 additions and 77 deletions.
- python/sglang/srt/layers/attention/triton_backend.py (+71, -7)
- python/sglang/srt/layers/attention/triton_ops/decode_attention.py (+44, -57)
- test/srt/test_triton_attention_kernels.py (+14, -13)
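In short, the decode path drops the (req_to_token, b_req_idx, b_seq_len) triple and instead passes a flat, CSR-style (kv_indptr, kv_indices) pair describing each request's KV-cache slots. A minimal sketch of that layout (tensor names mirror the diff; the values are made up for illustration):

```python
import torch

# Per-request KV lengths for a batch of 3 decode requests (illustrative values).
seq_lens = torch.tensor([5, 3, 4], dtype=torch.int32)
bs = seq_lens.numel()

# kv_indptr is a CSR-style offset array: request i owns
# kv_indices[kv_indptr[i] : kv_indptr[i + 1]] in the flattened KV pool.
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)

# kv_indices holds the KV-cache slot of every cached token, so its length
# equals the total number of cached tokens across the batch.
kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)

# Both the batch size and each request's length are recoverable from kv_indptr,
# which is why the B_req_idx / B_Seqlen kernel arguments can be dropped.
assert kv_indptr.numel() - 1 == bs
assert torch.equal(kv_indptr[1:] - kv_indptr[:-1], seq_lens)
```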
python/sglang/srt/layers/attention/triton_backend.py
```diff
@@ -5,6 +5,9 @@ from typing import TYPE_CHECKING, Optional
 import torch
 
 from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.attention.flashinfer_backend import (
+    create_flashinfer_kv_indices_triton,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
```
```diff
@@ -29,6 +32,12 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
 
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
```
```diff
@@ -58,11 +67,32 @@ class TritonAttnBackend(AttentionBackend):
             )
             max_extend_len = None
+
+            kv_indptr = self.kv_indptr
+            bs = len(forward_batch.req_pool_indices)
+            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                forward_batch.req_to_token_pool.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                forward_batch.req_to_token_pool.req_to_token.stride(0),
+            )
         else:
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
-        self.forward_metadata = attn_logits, max_extend_len
+            kv_indptr = None
+            kv_indices = None
+
+        self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
```
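Here kv_indptr is built in place with a cumulative sum, and create_flashinfer_kv_indices_triton (reused from the FlashInfer backend) fills kv_indices. Judging by the arguments at this call site, the kernel copies each request's live token slots out of the req_to_token table into its CSR segment; a rough pure-PyTorch equivalent, offered as an assumption about the semantics rather than the kernel's actual implementation:

```python
import torch

def fill_kv_indices_reference(req_to_token, req_pool_indices, seq_lens,
                              kv_indptr, kv_indices):
    # Hypothetical reference for what the Triton kernel appears to compute:
    # request i's segment kv_indices[kv_indptr[i] : kv_indptr[i + 1]] receives
    # the first seq_lens[i] KV-cache slots from its row of the req_to_token table.
    for i in range(len(req_pool_indices)):
        start, end = int(kv_indptr[i]), int(kv_indptr[i + 1])
        kv_indices[start:end] = req_to_token[req_pool_indices[i], : int(seq_lens[i])]
```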
```diff
@@ -73,7 +103,12 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_attn_logits = torch.empty(
             (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
             dtype=torch.float32,
-            device="cuda",
+            device=self.device,
+        )
+        self.cuda_graph_kv_indices = torch.zeros(
+            (max_bs * self.cuda_graph_max_seq_len),
+            dtype=torch.int32,
+            device=self.device,
         )
 
     def init_forward_metadata_capture_cuda_graph(
```
```diff
@@ -90,9 +125,25 @@ class TritonAttnBackend(AttentionBackend):
         assert forward_mode.is_decode(), "Not supported"
         assert spec_info is None, "Not supported"
 
+        kv_indptr = self.kv_indptr
+        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indices = self.cuda_graph_kv_indices
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            seq_lens,
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token.stride(0),
+        )
+
         self.forward_metadata = (
             self.cuda_graph_attn_logits,
             None,
+            kv_indptr,
+            kv_indices,
         )
 
     def init_forward_metadata_replay_cuda_graph(
```
```diff
@@ -109,6 +160,20 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
 
+        kv_indptr = self.kv_indptr
+        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indices = self.cuda_graph_kv_indices
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices[:bs],
+            seq_lens[:bs],
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token.stride(0),
+        )
+
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
```
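Both CUDA-graph paths write into self.cuda_graph_kv_indices rather than allocating a fresh tensor, since graph replay reuses fixed device addresses; only the buffer contents may change between replays. The buffer is sized for the worst case, every request at the graph's maximum sequence length (numbers below are illustrative, not taken from the code):

```python
# Illustrative sizing: max_bs graph requests, each at the longest sequence
# length the captured graph supports.
max_bs, cuda_graph_max_seq_len = 160, 4096
numel = max_bs * cuda_graph_max_seq_len   # 655_360 int32 entries
print(f"{numel * 4 / 2**20:.1f} MiB")     # 2.5 MiB for the index buffer
```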
```diff
@@ -132,7 +197,7 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
 
-        _, max_extend_len = self.forward_metadata
+        _, max_extend_len, _, _ = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
```
@@ -170,7 +235,7 @@ class TritonAttnBackend(AttentionBackend):
else
:
o
=
torch
.
empty_like
(
q
)
attn_logits
,
_
=
self
.
forward_metadata
attn_logits
,
_
,
kv_indptr
,
kv_indices
=
self
.
forward_metadata
if
save_kv_cache
:
forward_batch
.
token_to_kv_pool
.
set_kv_buffer
(
...
...
```diff
@@ -182,9 +247,8 @@ class TritonAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            forward_batch.req_to_token_pool.req_to_token,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             self.num_kv_splits,
             layer.scaling,
```
python/sglang/srt/layers/attention/triton_ops/decode_attention.py
```diff
@@ -49,11 +49,9 @@ def _fwd_kernel_stage1(
     K_Buffer,
     V_Buffer,
     sm_scale,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     Att_Out,
-    stride_req_to_tokens_b,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
```
```diff
@@ -82,8 +80,9 @@ def _fwd_kernel_stage1(
     offs_dv = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lk
     mask_dv = offs_dv < Lv
-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
 
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
     q = tl.load(Q + off_q, mask=mask_d, other=0.0)
```
```diff
@@ -100,7 +99,7 @@ def _fwd_kernel_stage1(
     for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
         offs_n = start_n + tl.arange(0, BLOCK_N)
         kv_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            kv_indices + cur_batch_kv_start_idx + offs_n,
             mask=offs_n < split_kv_end,
             other=0,
         )
```
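This gather is the heart of the change: the old code did a two-level lookup (request row via B_req_idx, then position) through the padded Req_to_tokens table, while the new code reads a precomputed flat slice. A small CPU demonstration that the two schemes address the same token slots (shapes and values are illustrative):

```python
import torch

req_to_tokens = torch.arange(12).reshape(3, 4)   # 3 request slots, 4 tokens each
b_req_idx = torch.tensor([2, 0])                 # batch of 2 requests
seq_lens = torch.tensor([4, 3])

# Old two-level gather: row lookup, then position lookup.
old = [req_to_tokens[b_req_idx[i], : seq_lens[i]] for i in range(2)]

# New one-level gather: the rows were flattened into kv_indices up front.
kv_indptr = torch.tensor([0, 4, 7])              # cumsum of seq_lens
kv_indices = torch.cat(old)                      # what the stage-0 pass precomputes
new = [kv_indices[kv_indptr[i] : kv_indptr[i + 1]] for i in range(2)]

assert all(torch.equal(a, b) for a, b in zip(old, new))
```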
```diff
@@ -173,9 +172,8 @@ def _decode_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     num_kv_splits,
     sm_scale,
     logit_cap,
```
```diff
@@ -188,7 +186,7 @@ def _decode_att_m_fwd(
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
 
-    batch, head_num = B_req_idx.shape[0], q.shape[1]
+    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
 
     grid = (batch, head_num, NUM_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
```
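Dropping B_req_idx works because the CSR offsets already encode the batch: with one leading zero and one entry per request, the grid's batch dimension is just kv_indptr.shape[0] - 1, and per-request lengths are adjacent differences. For example:

```python
import torch

kv_indptr = torch.tensor([0, 5, 8, 12])
batch = kv_indptr.shape[0] - 1              # 3 requests
seq_lens = kv_indptr[1:] - kv_indptr[:-1]   # tensor([5, 3, 4])
```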
```diff
@@ -208,11 +206,9 @@ def _decode_att_m_fwd(
         k_buffer,
         v_buffer,
         sm_scale,
-        Req_to_tokens,
-        B_req_idx,
-        B_Seqlen,
+        kv_indptr,
+        kv_indices,
         att_out,
-        Req_to_tokens.stride(0),
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
```
```diff
@@ -241,11 +237,9 @@ def _fwd_grouped_kernel_stage1(
     K_Buffer,
     V_Buffer,
     sm_scale,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     Att_Out,
-    stride_req_to_tokens_b,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
```
```diff
@@ -284,8 +278,9 @@ def _fwd_grouped_kernel_stage1(
     offs_dv = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lk
     mask_dv = offs_dv < Lv
-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
 
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
     q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
```
```diff
@@ -312,7 +307,7 @@ def _fwd_grouped_kernel_stage1(
     for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
         offs_n = start_n + tl.arange(0, BLOCK_N)
         kv_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            kv_indices + cur_batch_kv_start_idx + offs_n,
             mask=offs_n < split_kv_end,
             other=0,
         )
```
```diff
@@ -400,9 +395,8 @@ def _decode_grouped_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     num_kv_splits,
     sm_scale,
     logit_cap,
```
```diff
@@ -426,7 +420,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    batch, head_num = B_req_idx.shape[0], q.shape[1]
+    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     BLOCK_H = 16
```
```diff
@@ -450,11 +444,9 @@ def _decode_grouped_att_m_fwd(
         k_buffer,
         v_buffer,
         sm_scale,
-        Req_to_tokens,
-        B_req_idx,
-        B_Seqlen,
+        kv_indptr,
+        kv_indices,
         att_out,
-        Req_to_tokens.stride(0),
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
```
```diff
@@ -485,7 +477,7 @@
 def _fwd_kernel_stage2(
     Mid_O,
     O,
-    B_Seqlen,
+    kv_indptr,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
```
```diff
@@ -498,7 +490,9 @@ def _fwd_kernel_stage2(
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
 
-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
+        kv_indptr + cur_batch
+    )
 
     offs_d = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lv
```
```diff
@@ -542,7 +536,7 @@ def _decode_softmax_reducev_fwd(
     q,
     o,
     v_buffer,
-    b_seq_len,
+    kv_indptr,
     num_kv_splits,
 ):
     batch, head_num = q.shape[0], q.shape[1]
```
```diff
@@ -561,7 +555,7 @@ def _decode_softmax_reducev_fwd(
     _fwd_kernel_stage2[grid](
         logits,
         o,
-        b_seq_len,
+        kv_indptr,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
```
```diff
@@ -581,9 +575,8 @@ def decode_attention_fwd_normal(
     k_buffer,
     v_buffer,
     o,
-    req_to_token,
-    b_req_idx,
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
```
```diff
@@ -594,14 +587,13 @@ def decode_attention_fwd_normal(
         k_buffer,
         v_buffer,
         attn_logits,
-        req_to_token,
-        b_req_idx,
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         num_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
 
 
 def decode_attention_fwd_grouped(
```
```diff
@@ -609,9 +601,8 @@ def decode_attention_fwd_grouped(
     k_buffer,
     v_buffer,
     o,
-    req_to_token,
-    b_req_idx,
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
```
```diff
@@ -622,14 +613,13 @@ def decode_attention_fwd_grouped(
         k_buffer,
         v_buffer,
         attn_logits,
-        req_to_token,
-        b_req_idx,
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         num_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
 
 
 def decode_attention_fwd(
```
```diff
@@ -637,9 +627,8 @@ def decode_attention_fwd(
     k_buffer,
     v_buffer,
     o,
-    req_to_token,
-    b_req_idx,
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
```
```diff
@@ -655,9 +644,8 @@ def decode_attention_fwd(
         k_buffer,
         v_buffer,
         o,
-        req_to_token,
-        b_req_idx,
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         attn_logits,
         num_kv_splits,
         sm_scale,
```
```diff
@@ -670,9 +658,8 @@ def decode_attention_fwd(
         k_buffer,
         v_buffer,
         o,
-        req_to_token,
-        b_req_idx,
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         attn_logits,
         num_kv_splits,
         sm_scale,
```
test/srt/test_triton_attention_kernels.py
```diff
@@ -194,10 +194,12 @@ class TestTritonAttention(unittest.TestCase):
         # o will have the same shape as q
         o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
 
-        req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
-        b_req_idx = torch.arange(B, device="cuda")
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
+        kv_indices = torch.arange(total_tokens, device="cuda")
+
         attn_logits = torch.empty(
             (B, H_Q, num_kv_splits, D + 1),
             dtype=torch.float32,
```
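Because every test request has the same length and the KV pool is contiguous, the CSR tensors are trivial here: kv_indices is just arange(total_tokens). A quick CPU consistency check mirroring this construction (the B and seq_len values are illustrative):

```python
import torch

B, seq_len = 4, 128
total_tokens = B * seq_len

b_seq_len = torch.full((B,), seq_len)
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32)
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
kv_indices = torch.arange(total_tokens)

# The last offset must cover the whole flat index buffer.
assert int(kv_indptr[-1]) == total_tokens == kv_indices.numel()
```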
```diff
@@ -209,9 +211,8 @@ class TestTritonAttention(unittest.TestCase):
             k_buffer,
             v_buffer,
             o,
-            req_to_token,
-            b_req_idx,
-            b_seq_len,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             num_kv_splits,
             sm_scale,
```
```diff
@@ -250,10 +251,12 @@ class TestTritonAttention(unittest.TestCase):
         o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
         o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
 
-        req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
-        b_req_idx = torch.arange(B, device="cuda")
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
+        kv_indices = torch.arange(total_tokens, device="cuda")
+
         attn_logits = torch.empty(
             (B, H_Q, num_kv_splits, D_V + 1),
             dtype=torch.float32,
```
```diff
@@ -265,9 +268,8 @@ class TestTritonAttention(unittest.TestCase):
             k_buffer,
             v_buffer,
             o,
-            req_to_token,
-            b_req_idx,
-            b_seq_len,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             num_kv_splits,
             sm_scale,
```
```diff
@@ -284,9 +286,8 @@ class TestTritonAttention(unittest.TestCase):
             k_buffer,
             v_buffer,
             o_grouped,
-            req_to_token,
-            b_req_idx,
-            b_seq_len,
+            kv_indptr,
+            kv_indices,
             attn_logits1,
             num_kv_splits,
             sm_scale,
```