sglang · Commits · de553334

Unverified commit de553334, authored Feb 05, 2025 by Ke Bao, committed via GitHub on Feb 05, 2025.

    Update Triton extend backend interface (#3309)

Parent: 7aad8d18

Showing 5 changed files with 427 additions and 69 deletions (+427 / -69).
Changed files:

    python/sglang/srt/layers/attention/double_sparsity_backend.py               +1    -3
    python/sglang/srt/layers/attention/triton_backend.py                        +52   -16
    python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py  +337  -3
    python/sglang/srt/layers/attention/triton_ops/extend_attention.py           +22   -35
    test/srt/test_triton_attention_kernels.py                                   +15   -12
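The thread running through all five files is an interface change: the Triton extend kernels stop taking per-request bookkeeping tensors (req_to_tokens, b_req_idx, b_seq_len, b_seq_len_extend, b_start_loc_extend) and instead take a FlashInfer-style CSR layout (qo_indptr, kv_indptr, kv_indices). A minimal sketch of that layout, with illustrative lengths that are not from the commit:

    # Illustrative only: how an indptr/indices (CSR-style) layout encodes
    # variable-length per-request token ranges. All values are hypothetical.
    import torch

    extend_seq_lens = torch.tensor([3, 1, 4], dtype=torch.int32)  # new tokens per request
    prefix_lens = torch.tensor([5, 2, 0], dtype=torch.int32)      # cached tokens per request

    qo_indptr = torch.zeros(len(extend_seq_lens) + 1, dtype=torch.int32)
    qo_indptr[1:] = torch.cumsum(extend_seq_lens, dim=0)          # [0, 3, 4, 8]

    kv_indptr = torch.zeros(len(prefix_lens) + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)              # [0, 5, 7, 7]

    # Request i's queries occupy rows qo_indptr[i]:qo_indptr[i+1] of the packed
    # q_extend tensor; its prefix KV-cache slots are listed in
    # kv_indices[kv_indptr[i]:kv_indptr[i+1]].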
python/sglang/srt/layers/attention/double_sparsity_backend.py  (view file @ de553334)

@@ -17,12 +17,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
+            extend_attention_fwd,
             flash_decode_attention_fwd,
             flash_decode_sparse_attention_fwd,
         )
-        from sglang.srt.layers.attention.triton_ops.extend_attention import (
-            extend_attention_fwd,
-        )
 
         super().__init__()
python/sglang/srt/layers/attention/triton_backend.py  (view file @ de553334)
@@ -37,6 +37,9 @@ class TritonAttnBackend(AttentionBackend):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.qo_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )
 
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -54,6 +57,9 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
 
+        bs = forward_batch.batch_size
+        kv_indptr = self.kv_indptr
+
         if forward_batch.forward_mode.is_decode():
             attn_logits = torch.empty(
                 (
@@ -68,31 +74,59 @@ class TritonAttnBackend(AttentionBackend):
             max_extend_len = None
-            kv_indptr = self.kv_indptr
-            bs = len(forward_batch.req_pool_indices)
+
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
+                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
-                forward_batch.req_to_token_pool.req_to_token,
+                self.req_to_token,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 kv_indptr,
                 None,
                 kv_indices,
-                forward_batch.req_to_token_pool.req_to_token.stride(0),
+                self.req_to_token.stride(0),
             )
+            qo_indptr = None
+            custom_mask = None
         else:
+            kv_indptr[1 : bs + 1] = torch.cumsum(
+                forward_batch.extend_prefix_lens, dim=0
+            )
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                forward_batch.extend_prefix_lens.sum().item(),
+                dtype=torch.int32,
+                device=self.device,
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.extend_prefix_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            qo_indptr = self.qo_indptr
+            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
-            kv_indptr = None
-            kv_indices = None
 
-        self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices
+        self.forward_metadata = (
+            attn_logits,
+            max_extend_len,
+            kv_indptr,
+            kv_indices,
+            qo_indptr,
+            custom_mask,
+        )
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
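In both branches above, kv_indices is filled by the create_flashinfer_kv_indices_triton kernel. A rough PyTorch equivalent of what that gather produces, for illustration only (the helper name build_kv_indices is ours, not the commit's; the real kernel launches one program per request):

    import torch

    def build_kv_indices(req_to_token, req_pool_indices, seq_lens, kv_indptr):
        # Gather, per request, the KV-cache slot numbers of its cached tokens
        # into one flat array, so kv_indices[kv_indptr[i]:kv_indptr[i+1]]
        # lists request i's slots.
        kv_indices = torch.empty(
            int(seq_lens.sum()), dtype=torch.int32, device=req_to_token.device
        )
        for i in range(len(req_pool_indices)):
            row = req_to_token[req_pool_indices[i]]
            kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = row[: int(seq_lens[i])]
        return kv_indices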
@@ -144,6 +178,8 @@ class TritonAttnBackend(AttentionBackend):
             None,
             kv_indptr,
             kv_indices,
+            None,
+            None,
         )
 
     def init_forward_metadata_replay_cuda_graph(
@@ -197,7 +233,9 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
 
-        _, max_extend_len, _, _ = self.forward_metadata
+        _, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask = (
+            self.forward_metadata
+        )
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -205,11 +243,9 @@ class TritonAttnBackend(AttentionBackend):
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
-            forward_batch.req_to_token_pool.req_to_token,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
-            forward_batch.extend_seq_lens,
-            forward_batch.extend_start_loc,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
@@ -235,7 +271,7 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
+        attn_logits, _, kv_indptr, kv_indices, _, _ = self.forward_metadata
 
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
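Since forward_metadata widened from four fields to six, every consumer now unpacks all six, as the two hunks above show. A small illustrative helper (hypothetical, not part of the commit) that names the fields:

    def unpack_forward_metadata(metadata):
        # Hypothetical helper: names the six fields init_forward_metadata stores.
        attn_logits, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask = (
            metadata
        )
        return {
            "attn_logits": attn_logits,        # decode-only scratch buffer
            "max_extend_len": max_extend_len,  # extend-only
            "kv_indptr": kv_indptr,
            "kv_indices": kv_indices,
            "qo_indptr": qo_indptr,            # extend-only, None for decode
            "custom_mask": custom_mask,        # reserved; always None in this commit
        }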
python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py  (view file @ de553334)
@@ -3,6 +3,13 @@ import triton
 import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.utils import is_hip
+
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
+
+is_hip_ = is_hip()
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -274,9 +281,6 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq):
     return
 
-
-import torch
-
 
 def flash_decode_attention_fwd(
     q,
     k_buffer,
@@ -770,3 +774,333 @@ def flash_decode_sparse_attention_fwd(
     )
 
     sparse_flash_decode_stage3(heavy_token_num, mid_out, mid_o_logexpsum, o, BLOCK_SEQ)

(The 330 lines below are all additions: the extend-attention kernel that Double Sparsity keeps using through the old req_to_tokens-based interface, moved here so the shared extend_attention.py is free to change its signature.)

# Extend attention kernel for Double Sparsity
# Moved from https://github.com/sgl-project/sglang/blob/v0.4.2.post1/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@triton.jit
def _fwd_kernel(
    Q_Extend,
    K_Extend,
    V_Extend,
    O_Extend,
    K_Buffer,
    V_Buffer,
    Req_to_tokens,
    B_req_idx,
    B_Seq_Len,
    B_Start_Loc_Extend,
    B_Seq_Len_Extend,
    sm_scale,
    kv_group_num,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_req_to_tokens_b,
    logit_cap: tl.constexpr,
    Lq: tl.constexpr,
    Lv: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # One program per (sequence, query head, BLOCK_M-sized query tile).
    cur_seq = tl.program_id(0)
    cur_head = tl.program_id(1)
    cur_block_m = tl.program_id(2)
    cur_kv_head = cur_head // kv_group_num

    cur_seq_len = tl.load(B_Seq_Len + cur_seq)
    cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
    cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
    cur_seq_prefix_start_in_loc = 0
    cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
    cur_batch_req_idx = tl.load(B_req_idx + cur_seq)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    offs_m = tl.arange(0, BLOCK_M)
    mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend

    mask_d = offs_d < Lq
    mask_dv = offs_dv < Lv

    offs_q = (
        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
        * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    q = tl.load(
        Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
    )

    if BLOCK_DPE > 0:
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        offs_qpe = (
            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
            * stride_qbs
            + cur_head * stride_qh
            + offs_dpe[None, :]
        )
        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)

    # stage 1: compute scores with prefix
    offs_n = tl.arange(0, BLOCK_N)

    # Online-softmax accumulators: running output, denominator, and row max.
    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
    deno = tl.zeros([BLOCK_M], dtype=tl.float32)
    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")

    for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        mask_n = (start_n + offs_n) < cur_seq_len_prefix
        offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
            cur_seq_prefix_start_in_loc + start_n + offs_n
        )
        offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)

        # load k in transposed way
        offs_buf_k = (
            offs_kv_loc[None, :] * stride_buf_kbs
            + cur_kv_head * stride_buf_kh
            + offs_d[:, None]
        )
        k = tl.load(
            K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
        )

        qk = tl.dot(q.to(k.dtype), k)
        if BLOCK_DPE > 0:
            offs_kpe = (
                offs_kv_loc[None, :] * stride_buf_kbs
                + cur_kv_head * stride_buf_kh
                + offs_dpe[:, None]
            )
            kpe = tl.load(
                K_Buffer + offs_kpe,
                mask=mask_n[None, :],
                other=0.0,
            )
            qk += tl.dot(qpe.to(kpe.dtype), kpe)
        qk *= sm_scale

        if logit_cap > 0:
            qk = logit_cap * tanh(qk / logit_cap)

        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        deno = deno * re_scale + tl.sum(p, 1)

        offs_buf_v = (
            offs_kv_loc[:, None] * stride_buf_vbs
            + cur_kv_head * stride_buf_vh
            + offs_dv[None, :]
        )
        v = tl.load(
            V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
        )
        p = p.to(v.dtype)
        acc = acc * re_scale[:, None] + tl.dot(p, v)

        e_max = n_e_max

    # stage 2: compute the triangle part
    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
    for start_n in range(0, cur_block_m_end, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        mask_n = (start_n + offs_n) < cur_block_m_end

        # load k in transposed way
        offs_k = (
            (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
            + cur_kv_head * stride_kh
            + offs_d[:, None]
        )
        k = tl.load(
            K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
        )

        qk = tl.dot(q, k, out_dtype=tl.float32)
        if BLOCK_DPE > 0:
            offs_kpe = (
                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
                * stride_kbs
                + cur_kv_head * stride_kh
                + offs_dpe[:, None]
            )
            kpe = tl.load(
                K_Extend + offs_kpe,
                mask=mask_n[None, :],
                other=0.0,
            )
            qk += tl.dot(qpe, kpe)

        qk *= sm_scale

        if logit_cap > 0:
            qk = logit_cap * tanh(qk / logit_cap)

        # Causal mask: query position must be >= key position within the extend part.
        mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
            start_n + offs_n[None, :]
        )
        mask_causual &= mask_m[:, None] & mask_n[None, :]
        qk = tl.where(mask_causual, qk, float("-inf"))

        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        deno = deno * re_scale + tl.sum(p, 1)

        offs_v = (
            (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
            + cur_kv_head * stride_vh
            + offs_dv[None, :]
        )
        v = tl.load(
            V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
        )
        p = p.to(v.dtype)
        acc = acc * re_scale[:, None] + tl.dot(p, v)

        e_max = n_e_max

    offs_o = (
        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
        * stride_obs
        + cur_head * stride_oh
        + offs_dv[None, :]
    )
    tl.store(
        O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
    )


def extend_attention_fwd(
    q_extend,
    k_extend,
    v_extend,
    o_extend,
    k_buffer,
    v_buffer,
    req_to_tokens,
    b_req_idx,
    b_seq_len,
    b_seq_len_extend,
    b_start_loc_extend,
    max_len_extend,
    sm_scale=None,
    logit_cap=0.0,
):
    """
    q_extend, k_extend, v_extend, o_extend: contiguous tensors

    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
    """
    Lq, Lk, Lv = (
        q_extend.shape[-1],
        k_extend.shape[-1],
        v_extend.shape[-1],
    )

    if Lq == 576:
        BLOCK_DMODEL = 512
        BLOCK_DPE = 64
    elif Lq == 288:
        BLOCK_DMODEL = 256
        BLOCK_DPE = 32
    elif Lq == 192:
        BLOCK_DMODEL = 128
        BLOCK_DPE = 64
    else:
        BLOCK_DMODEL = triton.next_power_of_2(Lq)
        BLOCK_DPE = 0
    BLOCK_DV = triton.next_power_of_2(Lv)

    if is_hip_:
        BLOCK_M, BLOCK_N = (64, 64)
        num_warps = 4
    else:
        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
            if Lq <= 256:
                BLOCK_M, BLOCK_N = (128, 64)
            else:
                BLOCK_M, BLOCK_N = (32, 64)
        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
            if Lq <= 128:
                BLOCK_M, BLOCK_N = (128, 128)
            elif Lq <= 256:
                BLOCK_M, BLOCK_N = (64, 64)
            else:
                BLOCK_M, BLOCK_N = (32, 64)
        else:
            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)

        num_warps = 4 if Lk <= 64 else 8

    sm_scale = sm_scale or 1.0 / (Lq**0.5)
    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
    kv_group_num = q_extend.shape[1] // k_extend.shape[1]

    # One program per (sequence, head, BLOCK_M tile of extend tokens).
    grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
    num_stages = 1

    extra_kargs = {}
    if is_hip_:
        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}

    _fwd_kernel[grid](
        q_extend,
        k_extend,
        v_extend,
        o_extend,
        k_buffer,
        v_buffer,
        req_to_tokens,
        b_req_idx,
        b_seq_len,
        b_start_loc_extend,
        b_seq_len_extend,
        sm_scale,
        kv_group_num,
        q_extend.stride(0),
        q_extend.stride(1),
        k_extend.stride(0),
        k_extend.stride(1),
        v_extend.stride(0),
        v_extend.stride(1),
        o_extend.stride(0),
        o_extend.stride(1),
        k_buffer.stride(0),
        k_buffer.stride(1),
        v_buffer.stride(0),
        v_buffer.stride(1),
        req_to_tokens.stride(0),
        logit_cap=logit_cap,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        Lq=Lq,
        Lv=Lv,
        num_warps=num_warps,
        num_stages=num_stages,
        **extra_kargs,
    )
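A minimal smoke-test sketch of calling this legacy interface (not from the commit; all sizes are hypothetical, and it mirrors the contiguous-token layout used by test/srt/test_triton_attention_kernels.py; requires a CUDA device with Triton installed):

    import torch
    from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
        extend_attention_fwd,
    )

    B, H, D = 2, 4, 128
    prefix_lens = torch.tensor([4, 2], dtype=torch.int32, device="cuda")
    extend_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
    b_seq_len = prefix_lens + extend_lens

    total_extend = int(extend_lens.sum())
    total_tokens = int(b_seq_len.sum())
    q_extend = torch.randn(total_extend, H, D, dtype=torch.float16, device="cuda")
    k_extend = torch.randn_like(q_extend)
    v_extend = torch.randn_like(q_extend)
    o_extend = torch.empty_like(q_extend)
    k_buffer = torch.randn(total_tokens, H, D, dtype=torch.float16, device="cuda")
    v_buffer = torch.randn_like(k_buffer)

    # Contiguous token layout: request i's tokens occupy slots
    # b_start_loc[i] .. b_start_loc[i] + b_seq_len[i] of the KV buffers.
    b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
    b_start_loc = torch.zeros(B, dtype=torch.int32, device="cuda")
    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
    req_to_tokens = torch.zeros(
        B, int(b_seq_len.max()), dtype=torch.int32, device="cuda"
    )
    for i in range(B):
        req_to_tokens[i, : b_seq_len[i]] = torch.arange(
            int(b_start_loc[i]), int(b_start_loc[i] + b_seq_len[i]), device="cuda"
        )
    b_start_loc_extend = torch.zeros(B, dtype=torch.int32, device="cuda")
    b_start_loc_extend[1:] = torch.cumsum(extend_lens[:-1], 0)

    extend_attention_fwd(
        q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer,
        req_to_tokens, b_req_idx, b_seq_len, extend_lens, b_start_loc_extend,
        int(extend_lens.max()),
    )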
python/sglang/srt/layers/attention/triton_ops/extend_attention.py  (view file @ de553334)
@@ -46,11 +46,9 @@ def _fwd_kernel(
     O_Extend,
     K_Buffer,
     V_Buffer,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seq_Len,
-    B_Start_Loc_Extend,
-    B_Seq_Len_Extend,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -65,7 +63,6 @@ def _fwd_kernel(
     stride_buf_kh,
     stride_buf_vbs,
     stride_buf_vh,
-    stride_req_to_tokens_b,
     logit_cap: tl.constexpr,
     Lq: tl.constexpr,
     Lv: tl.constexpr,
@@ -80,13 +77,10 @@ def _fwd_kernel(
     cur_block_m = tl.program_id(2)
     cur_kv_head = cur_head // kv_group_num
 
-    cur_seq_len = tl.load(B_Seq_Len + cur_seq)
-    cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
-    cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
-    cur_seq_prefix_start_in_loc = 0
-    cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
+    cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq)
+    cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
+    cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
+    cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)
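The kernel now derives each request's extend range and prefix length from the two indptr arrays instead of loading B_Seq_Len / B_Start_Loc_Extend. A worked example with hypothetical values:

    # What the four indptr-derived quantities evaluate to for cur_seq = 1,
    # given the illustrative layouts from earlier:
    qo_indptr = [0, 3, 4, 8]
    kv_indptr = [0, 5, 7, 7]
    cur_seq = 1
    cur_seq_extend_start_idx = qo_indptr[cur_seq]                           # 3
    cur_seq_len_extend = qo_indptr[cur_seq + 1] - cur_seq_extend_start_idx  # 1
    cur_seq_kv_start_idx = kv_indptr[cur_seq]                               # 5
    cur_seq_len_prefix = kv_indptr[cur_seq + 1] - cur_seq_kv_start_idx      # 2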
@@ -97,7 +91,7 @@ def _fwd_kernel(
     mask_dv = offs_dv < Lv
 
     offs_q = (
-        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
         + cur_head * stride_qh
         + offs_d[None, :]
@@ -109,7 +103,7 @@ def _fwd_kernel(
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
         offs_qpe = (
-            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
             * stride_qbs
             + cur_head * stride_qh
             + offs_dpe[None, :]
@@ -126,10 +120,9 @@ def _fwd_kernel(
     for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_seq_len_prefix
-        offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
-            cur_seq_prefix_start_in_loc + start_n + offs_n
+        offs_kv_loc = tl.load(
+            kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
         )
-        offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
 
         # load k in transposed way
         offs_buf_k = (
@@ -188,7 +181,7 @@ def _fwd_kernel(
         # load k in transposed way
         offs_k = (
-            (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
+            (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
             + cur_kv_head * stride_kh
             + offs_d[:, None]
         )
@@ -199,8 +192,7 @@ def _fwd_kernel(
         qk = tl.dot(q, k, out_dtype=tl.float32)
         if BLOCK_DPE > 0:
             offs_kpe = (
-                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
-                * stride_kbs
+                (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
                 + cur_kv_head * stride_kh
                 + offs_dpe[:, None]
             )
@@ -228,7 +220,7 @@ def _fwd_kernel(
         deno = deno * re_scale + tl.sum(p, 1)
 
         offs_v = (
-            (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+            (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
             + cur_kv_head * stride_vh
             + offs_dv[None, :]
         )
@@ -241,7 +233,7 @@ def _fwd_kernel(
         e_max = n_e_max
 
     offs_o = (
-        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_obs
         + cur_head * stride_oh
         + offs_dv[None, :]
@@ -258,11 +250,9 @@ def extend_attention_fwd(
     o_extend,
     k_buffer,
     v_buffer,
-    req_to_tokens,
-    b_req_idx,
-    b_seq_len,
-    b_seq_len_extend,
-    b_start_loc_extend,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
     max_len_extend,
     sm_scale=None,
     logit_cap=0.0,
@@ -315,7 +305,7 @@ def extend_attention_fwd(
     num_warps = 4 if Lk <= 64 else 8
 
     sm_scale = sm_scale or 1.0 / (Lq**0.5)
-    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
+    batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]
 
     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
@@ -332,11 +322,9 @@ def extend_attention_fwd(
         o_extend,
         k_buffer,
         v_buffer,
-        req_to_tokens,
-        b_req_idx,
-        b_seq_len,
-        b_start_loc_extend,
-        b_seq_len_extend,
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
@@ -351,7 +339,6 @@ def extend_attention_fwd(
         k_buffer.stride(1),
         v_buffer.stride(0),
         v_buffer.stride(1),
-        req_to_tokens.stride(0),
         logit_cap=logit_cap,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,
test/srt/test_triton_attention_kernels.py  (view file @ de553334)
@@ -45,16 +45,20 @@ class TestTritonAttention(unittest.TestCase):
         max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
 
         b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
         req_to_tokens = torch.empty(
             (B, max_len_in_batch), dtype=torch.int32, device="cuda"
         )
         b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
         b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
         b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
         b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
+        kv_indices = torch.zeros(
+            (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda"
+        )
         for i in range(B):
             req_to_tokens[i, : b_seq_len[i]] = torch.arange(
                 b_start_loc[i], b_start_loc[i] + b_seq_len[i]
             )
+            kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
+                b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
+            )
 
         total_token_num = torch.sum(b_seq_len).item()
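Because the test now builds both req_to_tokens and the flat kv_indices, a quick consistency check could be added; this sketch (not part of the commit) reuses the locals defined in the hunk above:

    # The flat kv_indices must agree with the prefix slice of req_to_tokens
    # for every request.
    for i in range(B):
        assert torch.equal(
            kv_indices[kv_indptr[i] : kv_indptr[i + 1]],
            req_to_tokens[i, : b_seq_len_prefix[i]],
        )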
@@ -90,9 +94,10 @@ class TestTritonAttention(unittest.TestCase):
         )
 
         b_seq_len_extend = b_seq_len - b_seq_len_prefix
         b_start_loc_extend = torch.zeros_like(b_seq_len)
         b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
         max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
+        qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
 
         extend_attention_fwd(
             q_extend,
             k_extend,
@@ -100,11 +105,9 @@ class TestTritonAttention(unittest.TestCase):
             o_extend,
             k_buffer,
             v_buffer,
-            req_to_tokens,
-            b_req_idx,
-            b_seq_len,
-            b_seq_len_extend,
-            b_start_loc_extend,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
             max_len_extend,
         )