sglang · Commit a07364cc (unverified)

Update Triton decode backend interface (#3292)

Authored Feb 04, 2025 by Ke Bao; committed via GitHub on Feb 04, 2025.
Parent commit: 2c1a695f
Showing 3 changed files with 129 additions and 77 deletions:

  python/sglang/srt/layers/attention/triton_backend.py              (+71, -7)
  python/sglang/srt/layers/attention/triton_ops/decode_attention.py (+44, -57)
  test/srt/test_triton_attention_kernels.py                         (+14, -13)
python/sglang/srt/layers/attention/triton_backend.py

@@ -5,6 +5,9 @@ from typing import TYPE_CHECKING, Optional

 import torch

 from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.attention.flashinfer_backend import (
+    create_flashinfer_kv_indices_triton,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -29,6 +32,12 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd

+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
@@ -58,11 +67,32 @@ class TritonAttnBackend(AttentionBackend):
             )
             max_extend_len = None

+            kv_indptr = self.kv_indptr
+            bs = len(forward_batch.req_pool_indices)
+            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                forward_batch.req_to_token_pool.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                forward_batch.req_to_token_pool.req_to_token.stride(0),
+            )
         else:
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
+            kv_indptr = None
+            kv_indices = None

-        self.forward_metadata = attn_logits, max_extend_len
+        self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices

     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
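
Net effect of the backend changes: instead of handing the decode kernels the dense req_to_token table plus b_req_idx/b_seq_len, the backend now precomputes a CSR-style (kv_indptr, kv_indices) pair per batch, so each kernel can find a request's KV tokens with a single offset. Below is a pure-PyTorch sketch of the layout being built; this is illustrative only (the real index array is filled on-GPU by create_flashinfer_kv_indices_triton, and the helper name build_kv_indices is invented for this note).

import torch

def build_kv_indices(req_to_token, req_pool_indices, seq_lens):
    bs = len(req_pool_indices)
    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
    kv_indices = torch.empty(int(seq_lens.sum()), dtype=torch.int32)
    for i in range(bs):
        start, end = int(kv_indptr[i]), int(kv_indptr[i + 1])
        # Copy the first seq_lens[i] KV-cache slots of this request's row.
        kv_indices[start:end] = req_to_token[req_pool_indices[i], : int(seq_lens[i])]
    return kv_indptr, kv_indices

# Two requests of lengths 3 and 2, living in rows 2 and 0 of a 4-row table.
req_to_token = torch.arange(32, dtype=torch.int32).reshape(4, 8)
kv_indptr, kv_indices = build_kv_indices(
    req_to_token,
    req_pool_indices=torch.tensor([2, 0]),
    seq_lens=torch.tensor([3, 2]),
)
print(kv_indptr.tolist())   # [0, 3, 5]
print(kv_indices.tolist())  # [16, 17, 18, 0, 1]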
@@ -73,7 +103,12 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_attn_logits = torch.empty(
             (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
             dtype=torch.float32,
-            device="cuda",
+            device=self.device,
+        )
+        self.cuda_graph_kv_indices = torch.zeros(
+            (max_bs * self.cuda_graph_max_seq_len),
+            dtype=torch.int32,
+            device=self.device,
         )

     def init_forward_metadata_capture_cuda_graph(
@@ -90,9 +125,25 @@ class TritonAttnBackend(AttentionBackend):
         assert forward_mode.is_decode(), "Not supported"
         assert spec_info is None, "Not supported"

+        kv_indptr = self.kv_indptr
+        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indices = self.cuda_graph_kv_indices
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            seq_lens,
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token.stride(0),
+        )
+
         self.forward_metadata = (
             self.cuda_graph_attn_logits,
             None,
+            kv_indptr,
+            kv_indices,
         )

     def init_forward_metadata_replay_cuda_graph(
@@ -109,6 +160,20 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)

+        kv_indptr = self.kv_indptr
+        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indices = self.cuda_graph_kv_indices
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices[:bs],
+            seq_lens[:bs],
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token.stride(0),
+        )
+
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
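
Note that the replay path overwrites self.cuda_graph_kv_indices in place rather than allocating a fresh tensor: CUDA graphs capture fixed device pointers, so the index buffer must keep a stable address across replays. A host-side sketch of that update pattern, with assumed shapes and names:

import torch

max_bs, max_seq_len = 8, 16
cuda_graph_kv_indices = torch.zeros(max_bs * max_seq_len, dtype=torch.int32)
kv_indptr_buf = torch.zeros(max_bs + 1, dtype=torch.int32)

def replay_update(seq_lens, token_rows):
    bs = len(seq_lens)
    kv_indptr_buf[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
    # Rewrite only the front of the persistent buffer; its address never changes.
    for i in range(bs):
        s, e = int(kv_indptr_buf[i]), int(kv_indptr_buf[i + 1])
        cuda_graph_kv_indices[s:e] = token_rows[i][: int(seq_lens[i])]

replay_update(
    torch.tensor([3, 2]),
    [torch.arange(16, 24, dtype=torch.int32), torch.arange(8, dtype=torch.int32)],
)
print(cuda_graph_kv_indices[:5].tolist())  # [16, 17, 18, 0, 1]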
@@ -132,7 +197,7 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )

-        _, max_extend_len = self.forward_metadata
+        _, max_extend_len, _, _ = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -170,7 +235,7 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

-        attn_logits, _ = self.forward_metadata
+        attn_logits, _, kv_indptr, kv_indices = self.forward_metadata

         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -182,9 +247,8 @@ class TritonAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            forward_batch.req_to_token_pool.req_to_token,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             self.num_kv_splits,
             layer.scaling,
python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -49,11 +49,9 @@ def _fwd_kernel_stage1(
     K_Buffer,
     V_Buffer,
     sm_scale,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     Att_Out,
-    stride_req_to_tokens_b,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -82,8 +80,9 @@ def _fwd_kernel_stage1(
     offs_dv = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lk
     mask_dv = offs_dv < Lv

-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+    cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx

     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
     q = tl.load(Q + off_q, mask=mask_d, other=0.0)
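
With B_Seqlen and B_req_idx gone, adjacent kv_indptr entries encode both the start offset and the KV length of each request. A host-side mirror of the two tl.load lines above, with assumed values:

kv_indptr = [0, 3, 5]  # two requests with KV lengths 3 and 2
cur_batch = 1
cur_batch_kv_start_idx = kv_indptr[cur_batch]                          # 3
cur_batch_seq_len = kv_indptr[cur_batch + 1] - cur_batch_kv_start_idx  # 2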
@@ -100,7 +99,7 @@ def _fwd_kernel_stage1(
     for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
         offs_n = start_n + tl.arange(0, BLOCK_N)
         kv_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            kv_indices + cur_batch_kv_start_idx + offs_n,
             mask=offs_n < split_kv_end,
             other=0,
         )
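
This gather is the heart of the interface change: the two-level lookup through Req_to_tokens and B_req_idx becomes one contiguous read from the prebuilt flat array. A small host-side contrast with assumed tensors (request 0 is taken to occupy row 2 of the old table):

import torch

req_to_token = torch.arange(32).reshape(4, 8)  # old dense 2D table
kv_indices = torch.tensor([16, 17, 18, 0, 1])  # new flat index array
kv_indptr = torch.tensor([0, 3, 5])

cur_batch, offs_n = 0, torch.arange(3)
old_loc = req_to_token[2, offs_n]                    # row lookup via B_req_idx
new_loc = kv_indices[kv_indptr[cur_batch] + offs_n]  # contiguous slice
assert torch.equal(old_loc, new_loc)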
@@ -173,9 +172,8 @@ def _decode_att_m_fwd(
    k_buffer,
    v_buffer,
    att_out,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
    num_kv_splits,
    sm_scale,
    logit_cap,
@@ -188,7 +186,7 @@ def _decode_att_m_fwd(
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]

-    batch, head_num = B_req_idx.shape[0], q.shape[1]
+    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]

     grid = (batch, head_num, NUM_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
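
The batch size now falls out of the usual CSR convention, where an indptr of length B + 1 encodes B requests:

import torch

kv_indptr = torch.tensor([0, 3, 5], dtype=torch.int32)
batch = kv_indptr.shape[0] - 1  # 2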
@@ -208,11 +206,9 @@ def _decode_att_m_fwd(
         k_buffer,
         v_buffer,
         sm_scale,
-        Req_to_tokens,
-        B_req_idx,
-        B_Seqlen,
+        kv_indptr,
+        kv_indices,
         att_out,
-        Req_to_tokens.stride(0),
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -241,11 +237,9 @@ def _fwd_grouped_kernel_stage1(
     K_Buffer,
     V_Buffer,
     sm_scale,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     Att_Out,
-    stride_req_to_tokens_b,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -284,8 +278,9 @@ def _fwd_grouped_kernel_stage1(
     offs_dv = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lk
     mask_dv = offs_dv < Lv

-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+    cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx

     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
     q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
@@ -312,7 +307,7 @@ def _fwd_grouped_kernel_stage1(
     for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
         offs_n = start_n + tl.arange(0, BLOCK_N)
         kv_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            kv_indices + cur_batch_kv_start_idx + offs_n,
             mask=offs_n < split_kv_end,
             other=0,
         )
@@ -400,9 +395,8 @@ def _decode_grouped_att_m_fwd(
    k_buffer,
    v_buffer,
    att_out,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
    num_kv_splits,
    sm_scale,
    logit_cap,
@@ -426,7 +420,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)

-    batch, head_num = B_req_idx.shape[0], q.shape[1]
+    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
     kv_group_num = q.shape[1] // k_buffer.shape[1]

     BLOCK_H = 16
@@ -450,11 +444,9 @@ def _decode_grouped_att_m_fwd(
         k_buffer,
         v_buffer,
         sm_scale,
-        Req_to_tokens,
-        B_req_idx,
-        B_Seqlen,
+        kv_indptr,
+        kv_indices,
         att_out,
-        Req_to_tokens.stride(0),
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -485,7 +477,7 @@ def _decode_grouped_att_m_fwd(
 def _fwd_kernel_stage2(
     Mid_O,
     O,
-    B_Seqlen,
+    kv_indptr,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
@@ -498,7 +490,9 @@ def _fwd_kernel_stage2(
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)

-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
+        kv_indptr + cur_batch
+    )

     offs_d = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lv
@@ -542,7 +536,7 @@ def _decode_softmax_reducev_fwd(
     q,
     o,
     v_buffer,
-    b_seq_len,
+    kv_indptr,
     num_kv_splits,
 ):
     batch, head_num = q.shape[0], q.shape[1]
@@ -561,7 +555,7 @@ def _decode_softmax_reducev_fwd(
     _fwd_kernel_stage2[grid](
         logits,
         o,
-        b_seq_len,
+        kv_indptr,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
@@ -581,9 +575,8 @@ def decode_attention_fwd_normal(
     k_buffer,
     v_buffer,
     o,
-    req_to_token,
-    b_req_idx,
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
@@ -594,14 +587,13 @@ def decode_attention_fwd_normal(
         k_buffer,
         v_buffer,
         attn_logits,
-        req_to_token,
-        b_req_idx,
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         num_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)


 def decode_attention_fwd_grouped(
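
For intuition, a naive PyTorch reference of what this pipeline computes end to end: gather each request's KV rows through (kv_indptr, kv_indices), then run plain softmax attention for the single decode token. A testing sketch only, assuming one KV head per query head and matching K/V head dims; the Triton version additionally splits each KV range into num_kv_splits partial passes and merges them in _fwd_kernel_stage2.

import torch

def decode_attention_ref(q, k_buffer, v_buffer, kv_indptr, kv_indices, sm_scale):
    # q: [B, H, D]; k_buffer, v_buffer: [total_tokens, H, D]
    o = torch.empty_like(q)
    for b in range(q.shape[0]):
        loc = kv_indices[kv_indptr[b] : kv_indptr[b + 1]].long()
        k = k_buffer[loc].float()  # [S, H, D] for this request
        v = v_buffer[loc].float()
        logits = torch.einsum("hd,shd->hs", q[b].float(), k) * sm_scale
        probs = torch.softmax(logits, dim=-1)
        o[b] = torch.einsum("hs,shd->hd", probs, v).to(q.dtype)
    return o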
@@ -609,9 +601,8 @@ def decode_attention_fwd_grouped(
     k_buffer,
     v_buffer,
     o,
-    req_to_token,
-    b_req_idx,
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
@@ -622,14 +613,13 @@ def decode_attention_fwd_grouped(
         k_buffer,
         v_buffer,
         attn_logits,
-        req_to_token,
-        b_req_idx,
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         num_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)


 def decode_attention_fwd(
@@ -637,9 +627,8 @@ def decode_attention_fwd(
     k_buffer,
     v_buffer,
     o,
-    req_to_token,
-    b_req_idx,
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
@@ -655,9 +644,8 @@ def decode_attention_fwd(
         k_buffer,
         v_buffer,
         o,
-        req_to_token,
-        b_req_idx,
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         attn_logits,
         num_kv_splits,
         sm_scale,
@@ -670,9 +658,8 @@ def decode_attention_fwd(
         k_buffer,
         v_buffer,
         o,
-        req_to_token,
-        b_req_idx,
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         attn_logits,
         num_kv_splits,
         sm_scale,
test/srt/test_triton_attention_kernels.py
@@ -194,10 +194,12 @@ class TestTritonAttention(unittest.TestCase):
         # o will have the same shape as q
         o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")

-        req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
-        b_req_idx = torch.arange(B, device="cuda")
         b_seq_len = torch.full((B,), seq_len, device="cuda")
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
+        kv_indices = torch.arange(total_tokens, device="cuda")

         attn_logits = torch.empty(
             (B, H_Q, num_kv_splits, D + 1),
             dtype=torch.float32,
@@ -209,9 +211,8 @@ class TestTritonAttention(unittest.TestCase):
             k_buffer,
             v_buffer,
             o,
-            req_to_token,
-            b_req_idx,
-            b_seq_len,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             num_kv_splits,
             sm_scale,
@@ -250,10 +251,12 @@ class TestTritonAttention(unittest.TestCase):
         o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
         o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")

-        req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
-        b_req_idx = torch.arange(B, device="cuda")
         b_seq_len = torch.full((B,), seq_len, device="cuda")
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
+        kv_indices = torch.arange(total_tokens, device="cuda")

         attn_logits = torch.empty(
             (B, H_Q, num_kv_splits, D_V + 1),
             dtype=torch.float32,
@@ -265,9 +268,8 @@ class TestTritonAttention(unittest.TestCase):
             k_buffer,
             v_buffer,
             o,
-            req_to_token,
-            b_req_idx,
-            b_seq_len,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             num_kv_splits,
             sm_scale,
@@ -284,9 +286,8 @@ class TestTritonAttention(unittest.TestCase):
             k_buffer,
             v_buffer,
             o_grouped,
-            req_to_token,
-            b_req_idx,
-            b_seq_len,
+            kv_indptr,
+            kv_indices,
             attn_logits1,
             num_kv_splits,
             sm_scale,
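
Putting it together, a minimal sketch of driving the updated decode_attention_fwd interface, mirroring the test setup above (shapes are assumed, logit_cap is left at its default, and a CUDA device is required):

import torch
from sglang.srt.layers.attention.triton_ops.decode_attention import (
    decode_attention_fwd,
)

B, H_Q, H_KV, D, seq_len, num_kv_splits = 4, 16, 16, 128, 128, 8
total_tokens = B * seq_len
dtype = torch.bfloat16
sm_scale = 1.0 / (D**0.5)

q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda")
o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")

# CSR metadata replacing the old (req_to_token, b_req_idx, b_seq_len) triplet.
b_seq_len = torch.full((B,), seq_len, device="cuda")
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len, dim=0)
kv_indices = torch.arange(total_tokens, dtype=torch.int32, device="cuda")

attn_logits = torch.empty(
    (B, H_Q, num_kv_splits, D + 1), dtype=torch.float32, device="cuda"
)
decode_attention_fwd(
    q, k_buffer, v_buffer, o, kv_indptr, kv_indices,
    attn_logits, num_kv_splits, sm_scale,
)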