change / sglang · Commits

Commit de553334 (unverified)
Update Triton extend backend interface (#3309)
Authored Feb 05, 2025 by Ke Bao; committed by GitHub on Feb 05, 2025
Parent: 7aad8d18

Showing 5 changed files with 427 additions and 69 deletions (+427, -69)
python/sglang/srt/layers/attention/double_sparsity_backend.py               +1   -3
python/sglang/srt/layers/attention/triton_backend.py                        +52  -16
python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py  +337 -3
python/sglang/srt/layers/attention/triton_ops/extend_attention.py           +22  -35
test/srt/test_triton_attention_kernels.py                                   +15  -12
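The commit swaps the extend path's (req_to_tokens, b_req_idx, b_seq_len, b_seq_len_extend, b_start_loc_extend) bookkeeping for flashinfer-style index arrays: qo_indptr over the new ("extend") tokens and kv_indptr/kv_indices over each request's cached prefix. A minimal sketch of what these arrays encode, with illustrative toy values only (not code from the commit):

# Illustrative sketch (not from the commit): what the new index arrays encode.
import torch

extend_seq_lens = torch.tensor([3, 1, 4], dtype=torch.int32)  # new tokens per request
prefix_lens = torch.tensor([5, 0, 2], dtype=torch.int32)      # cached tokens per request
bs = extend_seq_lens.numel()

qo_indptr = torch.zeros(bs + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(extend_seq_lens, dim=0)  # tensor([0, 3, 4, 8])

kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)      # tensor([0, 5, 5, 7])

# kv_indices is a flat array of KV-cache slot ids; request i owns
# kv_indices[kv_indptr[i] : kv_indptr[i + 1]].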
python/sglang/srt/layers/attention/double_sparsity_backend.py

@@ -17,12 +17,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
+            extend_attention_fwd,
             flash_decode_attention_fwd,
             flash_decode_sparse_attention_fwd,
         )
-        from sglang.srt.layers.attention.triton_ops.extend_attention import (
-            extend_attention_fwd,
-        )
 
         super().__init__()
python/sglang/srt/layers/attention/triton_backend.py

@@ -37,6 +37,9 @@ class TritonAttnBackend(AttentionBackend):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.qo_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )
 
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -54,6 +57,9 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
 
+        bs = forward_batch.batch_size
+        kv_indptr = self.kv_indptr
+
         if forward_batch.forward_mode.is_decode():
             attn_logits = torch.empty(
                 (
@@ -68,31 +74,59 @@ class TritonAttnBackend(AttentionBackend):
             max_extend_len = None
 
-            kv_indptr = self.kv_indptr
-            bs = len(forward_batch.req_pool_indices)
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
+                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
-                forward_batch.req_to_token_pool.req_to_token,
+                self.req_to_token,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 kv_indptr,
                 None,
                 kv_indices,
-                forward_batch.req_to_token_pool.req_to_token.stride(0),
+                self.req_to_token.stride(0),
             )
+
+            qo_indptr = None
+            custom_mask = None
         else:
+            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_prefix_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                forward_batch.extend_prefix_lens.sum().item(),
+                dtype=torch.int32,
+                device=self.device,
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.extend_prefix_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            qo_indptr = self.qo_indptr
+            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
-            kv_indptr = None
-            kv_indices = None
 
-        self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices
+        self.forward_metadata = (
+            attn_logits,
+            max_extend_len,
+            kv_indptr,
+            kv_indices,
+            qo_indptr,
+            custom_mask,
+        )
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
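For orientation, the create_flashinfer_kv_indices_triton launches above fill kv_indices by gathering each request's cached token locations out of the req_to_token pool. A rough CPU-side equivalent of that gather, with hypothetical toy tensors (the commit itself does this work inside the Triton helper):

# Conceptual sketch (not the Triton helper): gather per-request KV locations
# from a req_to_token table into a flat kv_indices array.
import torch

req_to_token = torch.arange(32, dtype=torch.int32).reshape(4, 8)  # 4 request slots, 8 tokens max
req_pool_indices = torch.tensor([2, 0])                           # slots used by this batch
seq_lens = torch.tensor([3, 5], dtype=torch.int32)                # cached tokens per request

kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)

kv_indices = torch.empty(int(seq_lens.sum()), dtype=torch.int32)
for i, req_idx in enumerate(req_pool_indices):
    # request i contributes its first seq_lens[i] cache slots
    kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = req_to_token[req_idx, : seq_lens[i]]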
@@ -144,6 +178,8 @@ class TritonAttnBackend(AttentionBackend):
             None,
             kv_indptr,
             kv_indices,
+            None,
+            None,
         )
 
     def init_forward_metadata_replay_cuda_graph(
@@ -197,7 +233,9 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
 
-        _, max_extend_len, _, _ = self.forward_metadata
+        _, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask = (
+            self.forward_metadata
+        )
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -205,11 +243,9 @@ class TritonAttnBackend(AttentionBackend):
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
-            forward_batch.req_to_token_pool.req_to_token,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
-            forward_batch.extend_seq_lens,
-            forward_batch.extend_start_loc,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
@@ -235,7 +271,7 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
+        attn_logits, _, kv_indptr, kv_indices, _, _ = self.forward_metadata
 
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py

@@ -3,6 +3,13 @@ import triton
 import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.utils import is_hip
+
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
+
+is_hip_ = is_hip()
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -274,9 +281,6 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq):
     return
 
 
-import torch
-
-
 def flash_decode_attention_fwd(
     q,
     k_buffer,
@@ -770,3 +774,333 @@ def flash_decode_sparse_attention_fwd(
    )

    sparse_flash_decode_stage3(heavy_token_num, mid_out, mid_o_logexpsum, o, BLOCK_SEQ)


# Extend attention kernel for Double Sparsity
# Moved from https://github.com/sgl-project/sglang/blob/v0.4.2.post1/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@triton.jit
def _fwd_kernel(
    Q_Extend,
    K_Extend,
    V_Extend,
    O_Extend,
    K_Buffer,
    V_Buffer,
    Req_to_tokens,
    B_req_idx,
    B_Seq_Len,
    B_Start_Loc_Extend,
    B_Seq_Len_Extend,
    sm_scale,
    kv_group_num,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_req_to_tokens_b,
    logit_cap: tl.constexpr,
    Lq: tl.constexpr,
    Lv: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    cur_seq = tl.program_id(0)
    cur_head = tl.program_id(1)
    cur_block_m = tl.program_id(2)
    cur_kv_head = cur_head // kv_group_num

    cur_seq_len = tl.load(B_Seq_Len + cur_seq)
    cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
    cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
    cur_seq_prefix_start_in_loc = 0
    cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
    cur_batch_req_idx = tl.load(B_req_idx + cur_seq)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    offs_m = tl.arange(0, BLOCK_M)
    mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend

    mask_d = offs_d < Lq
    mask_dv = offs_dv < Lv

    offs_q = (
        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
        * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    q = tl.load(
        Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
    )

    if BLOCK_DPE > 0:
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        offs_qpe = (
            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
            * stride_qbs
            + cur_head * stride_qh
            + offs_dpe[None, :]
        )
        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)

    # stage 1: compute scores with prefix
    offs_n = tl.arange(0, BLOCK_N)

    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
    deno = tl.zeros([BLOCK_M], dtype=tl.float32)
    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")

    for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        mask_n = (start_n + offs_n) < cur_seq_len_prefix
        offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
            cur_seq_prefix_start_in_loc + start_n + offs_n
        )
        offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)

        # load k in transposed way
        offs_buf_k = (
            offs_kv_loc[None, :] * stride_buf_kbs
            + cur_kv_head * stride_buf_kh
            + offs_d[:, None]
        )
        k = tl.load(
            K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
        )

        qk = tl.dot(q.to(k.dtype), k)
        if BLOCK_DPE > 0:
            offs_kpe = (
                offs_kv_loc[None, :] * stride_buf_kbs
                + cur_kv_head * stride_buf_kh
                + offs_dpe[:, None]
            )
            kpe = tl.load(
                K_Buffer + offs_kpe,
                mask=mask_n[None, :],
                other=0.0,
            )
            qk += tl.dot(qpe.to(kpe.dtype), kpe)
        qk *= sm_scale

        if logit_cap > 0:
            qk = logit_cap * tanh(qk / logit_cap)

        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        deno = deno * re_scale + tl.sum(p, 1)

        offs_buf_v = (
            offs_kv_loc[:, None] * stride_buf_vbs
            + cur_kv_head * stride_buf_vh
            + offs_dv[None, :]
        )
        v = tl.load(
            V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
        )
        p = p.to(v.dtype)
        acc = acc * re_scale[:, None] + tl.dot(p, v)

        e_max = n_e_max

    # stage 2: compute the triangle part
    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
    for start_n in range(0, cur_block_m_end, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        mask_n = (start_n + offs_n) < cur_block_m_end

        # load k in transposed way
        offs_k = (
            (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
            + cur_kv_head * stride_kh
            + offs_d[:, None]
        )
        k = tl.load(
            K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
        )

        qk = tl.dot(q, k, out_dtype=tl.float32)
        if BLOCK_DPE > 0:
            offs_kpe = (
                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
                * stride_kbs
                + cur_kv_head * stride_kh
                + offs_dpe[:, None]
            )
            kpe = tl.load(
                K_Extend + offs_kpe,
                mask=mask_n[None, :],
                other=0.0,
            )
            qk += tl.dot(qpe, kpe)

        qk *= sm_scale

        if logit_cap > 0:
            qk = logit_cap * tanh(qk / logit_cap)

        mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
            start_n + offs_n[None, :]
        )
        mask_causual &= mask_m[:, None] & mask_n[None, :]
        qk = tl.where(mask_causual, qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        deno = deno * re_scale + tl.sum(p, 1)

        offs_v = (
            (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
            + cur_kv_head * stride_vh
            + offs_dv[None, :]
        )
        v = tl.load(
            V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
        )
        p = p.to(v.dtype)
        acc = acc * re_scale[:, None] + tl.dot(p, v)

        e_max = n_e_max

    offs_o = (
        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
        * stride_obs
        + cur_head * stride_oh
        + offs_dv[None, :]
    )
    tl.store(
        O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
    )
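Both stages of the kernel above use the same streaming ("online") softmax recurrence: a running row max e_max, a running denominator deno, and an accumulator acc, all rescaled whenever a new block raises the max. A compact single-query sketch of that recurrence in plain PyTorch, with hypothetical toy shapes (not part of the commit):

# Illustrative sketch of the online-softmax recurrence used by _fwd_kernel
# (plain PyTorch, one query vector, toy sizes; not the kernel itself).
import torch

torch.manual_seed(0)
d, n, block = 16, 40, 8
q = torch.randn(d)
k = torch.randn(n, d)
v = torch.randn(n, d)

acc = torch.zeros(d)                 # rescaled weighted sum of values
deno = torch.tensor(0.0)             # running softmax denominator
e_max = torch.tensor(float("-inf"))  # running max of the scores

for start in range(0, n, block):
    kb, vb = k[start : start + block], v[start : start + block]
    qk = kb @ q / d**0.5                      # scores for this block
    n_e_max = torch.maximum(qk.max(), e_max)  # new running max
    re_scale = torch.exp(e_max - n_e_max)     # rescale old partial results
    p = torch.exp(qk - n_e_max)
    deno = deno * re_scale + p.sum()
    acc = acc * re_scale + p @ vb
    e_max = n_e_max

out = acc / deno
ref = torch.softmax(k @ q / d**0.5, dim=0) @ v
assert torch.allclose(out, ref, atol=1e-5)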
def extend_attention_fwd(
    q_extend,
    k_extend,
    v_extend,
    o_extend,
    k_buffer,
    v_buffer,
    req_to_tokens,
    b_req_idx,
    b_seq_len,
    b_seq_len_extend,
    b_start_loc_extend,
    max_len_extend,
    sm_scale=None,
    logit_cap=0.0,
):
    """
    q_extend, k_extend, v_extend, o_extend: contiguous tensors
    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
    """
    Lq, Lk, Lv = (
        q_extend.shape[-1],
        k_extend.shape[-1],
        v_extend.shape[-1],
    )

    if Lq == 576:
        BLOCK_DMODEL = 512
        BLOCK_DPE = 64
    elif Lq == 288:
        BLOCK_DMODEL = 256
        BLOCK_DPE = 32
    elif Lq == 192:
        BLOCK_DMODEL = 128
        BLOCK_DPE = 64
    else:
        BLOCK_DMODEL = triton.next_power_of_2(Lq)
        BLOCK_DPE = 0
    BLOCK_DV = triton.next_power_of_2(Lv)

    if is_hip_:
        BLOCK_M, BLOCK_N = (64, 64)
        num_warps = 4
    else:
        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
            if Lq <= 256:
                BLOCK_M, BLOCK_N = (128, 64)
            else:
                BLOCK_M, BLOCK_N = (32, 64)
        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
            if Lq <= 128:
                BLOCK_M, BLOCK_N = (128, 128)
            elif Lq <= 256:
                BLOCK_M, BLOCK_N = (64, 64)
            else:
                BLOCK_M, BLOCK_N = (32, 64)
        else:
            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)

        num_warps = 4 if Lk <= 64 else 8

    sm_scale = sm_scale or 1.0 / (Lq**0.5)
    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
    kv_group_num = q_extend.shape[1] // k_extend.shape[1]

    grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
    num_stages = 1

    extra_kargs = {}
    if is_hip_:
        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}

    _fwd_kernel[grid](
        q_extend,
        k_extend,
        v_extend,
        o_extend,
        k_buffer,
        v_buffer,
        req_to_tokens,
        b_req_idx,
        b_seq_len,
        b_start_loc_extend,
        b_seq_len_extend,
        sm_scale,
        kv_group_num,
        q_extend.stride(0),
        q_extend.stride(1),
        k_extend.stride(0),
        k_extend.stride(1),
        v_extend.stride(0),
        v_extend.stride(1),
        o_extend.stride(0),
        o_extend.stride(1),
        k_buffer.stride(0),
        k_buffer.stride(1),
        v_buffer.stride(0),
        v_buffer.stride(1),
        req_to_tokens.stride(0),
        logit_cap=logit_cap,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        Lq=Lq,
        Lv=Lv,
        num_warps=num_warps,
        num_stages=num_stages,
        **extra_kargs,
    )
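The BLOCK_DMODEL/BLOCK_DPE dispatch above keeps the main dot product on a power-of-two head slice and handles the remainder separately; the special-cased 576 most likely corresponds to MLA-style heads (a 512-dim latent part plus a 64-dim rotary part), which would otherwise be padded all the way to 1024. A standalone sketch of that selection, with next_power_of_2 reimplemented here so the snippet runs without Triton (illustrative only):

# Illustrative sketch (not from the commit): head-dim block selection.
def next_power_of_2(n: int) -> int:
    return 1 << (n - 1).bit_length()

def pick_blocks(Lq: int):
    # split unusual head dims into a power-of-two part plus a "PE" remainder
    if Lq == 576:
        return 512, 64
    if Lq == 288:
        return 256, 32
    if Lq == 192:
        return 128, 64
    return next_power_of_2(Lq), 0

assert pick_blocks(128) == (128, 0)
assert pick_blocks(576) == (512, 64)
assert pick_blocks(80) == (128, 0)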
python/sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -46,11 +46,9 @@ def _fwd_kernel(
     O_Extend,
     K_Buffer,
     V_Buffer,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seq_Len,
-    B_Start_Loc_Extend,
-    B_Seq_Len_Extend,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -65,7 +63,6 @@ def _fwd_kernel(
     stride_buf_kh,
     stride_buf_vbs,
     stride_buf_vh,
-    stride_req_to_tokens_b,
     logit_cap: tl.constexpr,
     Lq: tl.constexpr,
     Lv: tl.constexpr,
@@ -80,13 +77,10 @@ def _fwd_kernel(
     cur_block_m = tl.program_id(2)
     cur_kv_head = cur_head // kv_group_num
 
-    cur_seq_len = tl.load(B_Seq_Len + cur_seq)
-    cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
-    cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
-    cur_seq_prefix_start_in_loc = 0
-    cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
+    cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq)
+    cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
+    cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
+    cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)
@@ -97,7 +91,7 @@ def _fwd_kernel(
     mask_dv = offs_dv < Lv
 
     offs_q = (
-        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
         + cur_head * stride_qh
         + offs_d[None, :]
@@ -109,7 +103,7 @@ def _fwd_kernel(
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
         offs_qpe = (
-            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
             * stride_qbs
             + cur_head * stride_qh
             + offs_dpe[None, :]
@@ -126,10 +120,9 @@ def _fwd_kernel(
     for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_seq_len_prefix
-        offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
-            cur_seq_prefix_start_in_loc + start_n + offs_n
+        offs_kv_loc = tl.load(
+            kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
         )
-        offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
 
         # load k in transposed way
         offs_buf_k = (
@@ -188,7 +181,7 @@ def _fwd_kernel(
         # load k in transposed way
         offs_k = (
-            (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
+            (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
             + cur_kv_head * stride_kh
             + offs_d[:, None]
         )
@@ -199,8 +192,7 @@ def _fwd_kernel(
         qk = tl.dot(q, k, out_dtype=tl.float32)
         if BLOCK_DPE > 0:
             offs_kpe = (
-                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
-                * stride_kbs
+                (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
                 + cur_kv_head * stride_kh
                 + offs_dpe[:, None]
             )
@@ -228,7 +220,7 @@ def _fwd_kernel(
         deno = deno * re_scale + tl.sum(p, 1)
 
         offs_v = (
-            (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+            (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
             + cur_kv_head * stride_vh
             + offs_dv[None, :]
         )
@@ -241,7 +233,7 @@ def _fwd_kernel(
         e_max = n_e_max
 
     offs_o = (
-        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_obs
         + cur_head * stride_oh
         + offs_dv[None, :]
@@ -258,11 +250,9 @@ def extend_attention_fwd(
     o_extend,
     k_buffer,
     v_buffer,
-    req_to_tokens,
-    b_req_idx,
-    b_seq_len,
-    b_seq_len_extend,
-    b_start_loc_extend,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
     max_len_extend,
     sm_scale=None,
     logit_cap=0.0,
@@ -315,7 +305,7 @@ def extend_attention_fwd(
        num_warps = 4 if Lk <= 64 else 8
 
     sm_scale = sm_scale or 1.0 / (Lq**0.5)
-    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
+    batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]
 
     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
@@ -332,11 +322,9 @@ def extend_attention_fwd(
         o_extend,
         k_buffer,
         v_buffer,
-        req_to_tokens,
-        b_req_idx,
-        b_seq_len,
-        b_start_loc_extend,
-        b_seq_len_extend,
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
@@ -351,7 +339,6 @@ def extend_attention_fwd(
         k_buffer.stride(1),
         v_buffer.stride(0),
         v_buffer.stride(1),
-        req_to_tokens.stride(0),
         logit_cap=logit_cap,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,
test/srt/test_triton_attention_kernels.py
@@ -45,16 +45,20 @@ class TestTritonAttention(unittest.TestCase):
         max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
 
         b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
         req_to_tokens = torch.empty(
             (B, max_len_in_batch), dtype=torch.int32, device="cuda"
         )
         b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
         b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
         b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
         b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
+        kv_indices = torch.zeros(
+            (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda"
+        )
+
         for i in range(B):
-            req_to_tokens[i, : b_seq_len[i]] = torch.arange(
-                b_start_loc[i], b_start_loc[i] + b_seq_len[i]
+            kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
+                b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
             )
 
         total_token_num = torch.sum(b_seq_len).item()
@@ -90,9 +94,10 @@ class TestTritonAttention(unittest.TestCase):
         )
 
         b_seq_len_extend = b_seq_len - b_seq_len_prefix
         b_start_loc_extend = torch.zeros_like(b_seq_len)
         b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
         max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
+        qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
 
         extend_attention_fwd(
             q_extend,
             k_extend,
@@ -100,11 +105,9 @@ class TestTritonAttention(unittest.TestCase):
             o_extend,
             k_buffer,
             v_buffer,
-            req_to_tokens,
-            b_req_idx,
-            b_seq_len,
-            b_seq_len_extend,
-            b_start_loc_extend,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
             max_len_extend,
         )
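The updated test drives extend_attention_fwd through the new qo_indptr/kv_indptr/kv_indices arguments and checks the output against a reference. As a mental model of what the kernel computes, here is a naive single-request reference in plain PyTorch (hypothetical shapes, CPU only; not the reference implementation used by the test itself):

# Illustrative sketch: each new ("extend") token attends to the request's
# cached prefix plus all earlier new tokens (causal within the extend part).
import torch

torch.manual_seed(0)
H, D = 2, 16                      # heads, head dim
prefix_len, extend_len = 5, 3

k_prefix = torch.randn(prefix_len, H, D)
v_prefix = torch.randn(prefix_len, H, D)
q_extend = torch.randn(extend_len, H, D)
k_extend = torch.randn(extend_len, H, D)
v_extend = torch.randn(extend_len, H, D)

sm_scale = D ** -0.5
o = torch.empty(extend_len, H, D)
for t in range(extend_len):
    # keys/values visible to new token t: full prefix + extend tokens 0..t
    k_vis = torch.cat([k_prefix, k_extend[: t + 1]], dim=0)
    v_vis = torch.cat([v_prefix, v_extend[: t + 1]], dim=0)
    scores = torch.einsum("hd,shd->hs", q_extend[t], k_vis) * sm_scale
    o[t] = torch.einsum("hs,shd->hd", torch.softmax(scores, dim=-1), v_vis)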