sglang · Commit 8e6bdf85 (Unverified)
Authored Sep 09, 2024 by Byron Hsu; committed by GitHub on Sep 09, 2024.
[triton] Support head_dim not 2^n in triton extend and decode attention (#1281)
Parent: 05bea688

Showing 3 changed files with 84 additions and 37 deletions (+84 -37):
  python/sglang/srt/layers/decode_attention.py   +35 -15
  python/sglang/srt/layers/extend_attention.py   +35 -16
  python/sglang/srt/layers/prefill_attention.py  +14 -6
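Every hunk below applies the same pattern: pad the head dimension up to the next power of two for the Triton tile shape (BLOCK_DMODEL), and mask every load and store with the true head dim (Lk/Lq/Lv) so the padded lanes read 0.0 and are never written back. A minimal sketch in plain PyTorch (not the kernels themselves, just the padding argument) of why zero-padding leaves the attention scores unchanged:

import torch
import torch.nn.functional as F

def next_power_of_2(n: int) -> int:
    # Same contract as triton.next_power_of_2.
    return 1 << (n - 1).bit_length()

def padded_scores(q, k):
    Lk = q.shape[-1]                    # e.g. 96, not a power of two
    pad = next_power_of_2(Lk) - Lk      # 128 - 96 = 32
    q_pad = F.pad(q, (0, pad))          # zeros in the padded lanes
    k_pad = F.pad(k, (0, pad))
    return q_pad @ k_pad.transpose(-1, -2)

q, k = torch.randn(4, 96), torch.randn(7, 96)
# The padded lanes contribute 0 to every dot product.
assert torch.allclose(padded_scores(q, k), q @ k.transpose(-1, -2), atol=1e-5)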
python/sglang/srt/layers/decode_attention.py
@@ -60,6 +60,7 @@ def _fwd_kernel_stage1(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
+    Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -97,7 +98,7 @@ def _fwd_kernel_stage1(
         )
         k = tl.load(
             K_Buffer + offs_buf_k,
-            mask=offs_n_new[:, None] < cur_batch_end_index,
+            mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk),
             other=0.0,
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
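This hunk is the load half of the pattern: lanes with offs_d >= Lk read other=0.0, so the tl.sum over the padded BLOCK_DMODEL lanes equals the dot product over the true Lk lanes. A toy kernel (hypothetical names, not from this commit; needs a GPU with triton installed) demonstrating that a masked load makes block-wide reductions exact:

import torch
import triton
import triton.language as tl

@triton.jit
def masked_sum_kernel(x_ptr, out_ptr, L: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)                      # BLOCK must be a power of two
    x = tl.load(x_ptr + offs, mask=offs < L, other=0.0)
    tl.store(out_ptr, tl.sum(x, 0))                 # padded lanes contribute 0.0

x = torch.randn(96, device="cuda")
out = torch.empty(1, device="cuda")
masked_sum_kernel[(1,)](x, out, L=96, BLOCK=128)
assert torch.allclose(out[0], x.sum(), atol=1e-4)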
@@ -128,6 +129,7 @@ def _fwd_kernel_stage2(
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    Lv: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -170,14 +172,16 @@ def _fwd_kernel_stage2(
         old_scale = tl.exp(e_max - n_e_max)
         p = tl.exp(qk - n_e_max)
         e_sum = e_sum * old_scale + tl.sum(p, 0)
-        v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
+        v = tl.load(
+            v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
+        )
         acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
         e_max = n_e_max

     acc = acc / e_sum
     off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d
     out_ptrs = Out + off_o
-    tl.store(out_ptrs, acc)
+    tl.store(out_ptrs, acc, mask=(offs_d < Lv))


 def _decode_att_m_fwd(
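The store mask added above is the counterpart of the load mask: acc is BLOCK_DMODEL lanes wide after padding, but the output row holds only Lv values, so lanes with offs_d >= Lv must be skipped on the way out or the store would write past the end of the row.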
@@ -196,7 +200,7 @@ def _decode_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256}
+    assert Lk in {16, 32, 64, 96, 128, 256}
     batch, head_num = B_req_idx.shape[0], q.shape[1]
@@ -208,6 +212,8 @@ def _decode_att_m_fwd(
     else:
         num_warps = 2

+    BLOCK_DMODEL = triton.next_power_of_2(Lk)
+
     _fwd_kernel_stage1[grid](
         q,
         k_buffer,
@@ -224,11 +230,12 @@ def _decode_att_m_fwd(
         k_buffer.stride(1),
         att_out.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=Lk,
+        BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_N=BLOCK,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
+        Lk=Lk,
     )
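With this change the launch passes two related values: BLOCK_DMODEL, the padded compile-time tile width used by tl.arange inside the kernel, and Lk, the true head dim used only in the masks. For the supported sizes the padding only ever kicks in for 96 (a quick check, assuming a local triton install):

import triton

for lk in (16, 32, 64, 96, 128, 256):
    print(lk, "->", triton.next_power_of_2(lk))
# prints 96 -> 128; every other supported size is already a power of two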
@@ -248,6 +255,9 @@ def _decode_softmax_reducev_fwd(
     num_warps = 1

+    Lv = v_buffer.shape[-1]
+    BLOCK_DMODEL = triton.next_power_of_2(Lv)
+
     _fwd_kernel_stage2[grid](
         logics,
         v_buffer,
@@ -263,10 +273,11 @@ def _decode_softmax_reducev_fwd(
         o.stride(1),
         req_to_tokens.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=v_buffer.shape[-1],
+        BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_N=BLOCK,
         num_warps=num_warps,
         num_stages=3,
+        Lv=Lv,
     )
@@ -293,6 +304,7 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
     logit_cap: tl.constexpr,
+    Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_kv_head = tl.program_id(1)
@@ -324,9 +336,9 @@ def _fwd_grouped_kernel_stage1(
     block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)

     for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to(
-            REDUCE_TRITON_TYPE
-        )
+        q = tl.load(
+            Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk)
+        ).to(REDUCE_TRITON_TYPE)
         offs_n_new = cur_batch_start_index + offs_n
         k_loc = tl.load(
             Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
@@ -340,7 +352,7 @@ def _fwd_grouped_kernel_stage1(
         )
         k = tl.load(
             K_Buffer + offs_buf_k,
-            mask=offs_n_new[None, :] < cur_batch_end_index,
+            mask=(offs_n_new[None, :] < cur_batch_end_index) & (offs_d[:, None] < Lk),
             other=0.0,
         ).to(REDUCE_TRITON_TYPE)
         qk = tl.dot(q, k)
@@ -395,6 +407,7 @@ def _fwd_grouped_kernel_stage2(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
+    Lv: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_kv_head = tl.program_id(1)
@@ -441,7 +454,9 @@ def _fwd_grouped_kernel_stage2(
         old_scale = tl.exp(e_max - n_e_max)
         p = tl.exp(qk - n_e_max[:, None])
         e_sum = e_sum * old_scale + tl.sum(p, 1)
-        v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
+        v = tl.load(
+            v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
+        )
         p = p.to(v.dtype)
         acc = acc * old_scale[:, None] + tl.dot(p, v)
         e_max = n_e_max
@@ -449,7 +464,7 @@ def _fwd_grouped_kernel_stage2(
     acc = acc / e_sum[:, None]
     off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
     out_ptrs = Out + off_o
-    tl.store(out_ptrs, acc, mask=mask_h[:, None])
+    tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv))


 def _decode_grouped_att_m_fwd(
@@ -468,13 +483,13 @@ def _decode_grouped_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256, 576}
+    assert Lk in {16, 32, 64, 96, 128, 256, 576}

     if Lk == 576:
         BLOCK_DMODEL = 512
         BLOCK_DPE = 64
     else:
-        BLOCK_DMODEL = Lk
+        BLOCK_DMODEL = triton.next_power_of_2(Lk)
         BLOCK_DPE = 0
     batch, head_num = B_req_idx.shape[0], q.shape[1]
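The grouped path keeps its special case: Lk == 576 is not padded but split into a 512-wide model block plus a 64-wide BLOCK_DPE block, so only the else branch changes. Restated as a standalone helper (illustrative, not part of the file):

import triton

def pick_blocks(Lk: int):
    # Mirrors the dispatch in _decode_grouped_att_m_fwd.
    if Lk == 576:
        return 512, 64                        # BLOCK_DMODEL, BLOCK_DPE
    return triton.next_power_of_2(Lk), 0      # e.g. 96 -> (128, 0)

assert pick_blocks(576) == (512, 64)
assert pick_blocks(96) == (128, 0)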
@@ -513,6 +528,7 @@ def _decode_grouped_att_m_fwd(
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
+        Lk=Lk,
     )
@@ -533,6 +549,9 @@ def _decode_grouped_softmax_reducev_fwd(
     num_warps = 8

+    Lv = v_buffer.shape[-1]
+    BLOCK_DMODEL = triton.next_power_of_2(Lv)
+
     _fwd_grouped_kernel_stage2[grid](
         logics,
         v_buffer,
@@ -549,11 +568,12 @@ def _decode_grouped_softmax_reducev_fwd(
         req_to_tokens.stride(0),
         kv_group_num=kv_group_num,
         q_head_num=head_num,
-        BLOCK_DMODEL=v_buffer.shape[-1],
+        BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_N=BLOCK,
         BLOCK_H=BLOCK_H,
         num_warps=num_warps,
         num_stages=1,
+        Lv=Lv,
     )
python/sglang/srt/layers/extend_attention.py
@@ -15,7 +15,7 @@ limitations under the License.
 """
 Memory-efficient attention for prefill.
-It supporst page size = 1 and prefill with KV cache (i.e. extend).
+It supports page size = 1 and prefill with KV cache (i.e. extend).
 """

 import torch
@@ -67,6 +67,8 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
+    Lq: tl.constexpr,
+    Lv: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -86,13 +88,18 @@ def _fwd_kernel(
     offs_m = tl.arange(0, BLOCK_M)
     mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend

+    mask_d = offs_d < Lq
+    mask_dv = offs_dv < Lv
+
     offs_q = (
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
         + cur_head * stride_qh
         + offs_d[None, :]
     )
-    q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)
+    q = tl.load(
+        Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
+    )

     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
@@ -125,7 +132,9 @@ def _fwd_kernel(
             + cur_kv_head * stride_buf_kh
             + offs_d[:, None]
         )
-        k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
+        k = tl.load(
+            K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )

         qk = tl.dot(q.to(k.dtype), k)
         if BLOCK_DPE > 0:
@@ -157,7 +166,9 @@ def _fwd_kernel(
             + cur_kv_head * stride_buf_vh
             + offs_dv[None, :]
         )
-        v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
+        v = tl.load(
+            V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
         p = p.to(v.dtype)
         acc = acc * re_scale[:, None] + tl.dot(p, v)
@@ -176,7 +187,9 @@ def _fwd_kernel(
             + cur_kv_head * stride_kh
             + offs_d[:, None]
         )
-        k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0)
+        k = tl.load(
+            K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )

         qk = tl.dot(q, k, out_dtype=tl.float32)
         if BLOCK_DPE > 0:
@@ -214,7 +227,9 @@ def _fwd_kernel(
             + cur_kv_head * stride_vh
             + offs_dv[None, :]
         )
-        v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
+        v = tl.load(
+            V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
         p = p.to(v.dtype)
         acc = acc * re_scale[:, None] + tl.dot(p, v)
@@ -226,7 +241,9 @@ def _fwd_kernel(
         + cur_head * stride_oh
         + offs_dv[None, :]
     )
-    tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
+    tl.store(
+        O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
+    )


 def extend_attention_fwd(
@@ -261,16 +278,18 @@ def extend_attention_fwd(
     )

     assert Lq == Lk and Lv == Lo
-    assert Lq in {16, 32, 64, 128, 256, 576}
-    assert Lv in {16, 32, 64, 128, 256, 512}
+
+    # TODO: is the assertion necessary?
+    assert Lq in {16, 32, 64, 96, 128, 256, 576}
+    assert Lv in {16, 32, 64, 96, 128, 256, 512}

     if Lq == 576:
         BLOCK_DMODEL = 512
         BLOCK_DPE = 64
     else:
-        BLOCK_DMODEL = Lq
+        BLOCK_DMODEL = triton.next_power_of_2(Lq)
         BLOCK_DPE = 0
-    BLOCK_DV = Lv
+    BLOCK_DV = triton.next_power_of_2(Lv)

     if CUDA_CAPABILITY[0] >= 9:
         if Lq <= 256:
@@ -330,6 +349,8 @@ def extend_attention_fwd(
         num_warps=num_warps,
         num_stages=num_stages,
         logit_cap=logit_cap,
+        Lq=Lq,
+        Lv=Lv,
     )
@@ -373,10 +394,7 @@ def redundant_attention(
         pt += cur_seq_len_extend


-def test():
+def test_once(B, N_CTX, H_Q, H_KV, D):
     torch.manual_seed(0)
-
-    B, N_CTX, H_Q, H_KV, D = 19, 12331, 12, 4, 128
     dtype = torch.float16

     b_seq_len_prefix = torch.randint(
@@ -473,4 +491,5 @@ def test():

 if __name__ == "__main__":
-    test()
+    test_once(19, 12331, 12, 4, 128)
+    test_once(19, 12331, 12, 4, 96)
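Of the two calls, the second is the one that exercises the new path: with D = 96 the kernels run with BLOCK_DMODEL = 128 and depend on the masks above to ignore the 32 padded lanes, while the first call (D = 128) covers the pre-existing power-of-two case.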
python/sglang/srt/layers/prefill_attention.py
@@ -48,6 +48,7 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -72,7 +73,11 @@ def _fwd_kernel(
     off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
     off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]

-    q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
+    mask_d = offs_d < Lk
+
+    q = tl.load(
+        Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d), other=0.0
+    )

     k_ptrs = K + off_k
     v_ptrs = V + off_v
@@ -89,7 +94,7 @@ def _fwd_kernel(
         # -- compute qk ----
         k = tl.load(
             k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
-            mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
+            mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
             other=0.0,
         )
         # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)
@@ -118,7 +123,7 @@ def _fwd_kernel(
         # update acc
         v = tl.load(
             v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
-            mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
+            mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
             other=0.0,
         )
@@ -134,7 +139,9 @@ def _fwd_kernel(
         + offs_d[None, :]
     )
     out_ptrs = Out + off_o
-    tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
+    tl.store(
+        out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])
+    )


 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
@@ -145,7 +152,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk and Lk == Lv
-    assert Lk in {16, 32, 64, 128, 256}
+    assert Lk in {16, 32, 64, 96, 128, 256}

     sm_scale = 1.0 / (Lq**0.5)
     batch, head = b_seq_len.shape[0], q.shape[1]
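Note that the softmax scale stays 1.0 / (Lq ** 0.5) with the true head dim: padding changes only the tile shapes the kernel iterates over, never the arithmetic, so a head dim of 96 is scaled by 1/sqrt(96), not 1/sqrt(128).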
@@ -172,8 +179,9 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
         o.stride(1),
         kv_group_num=kv_group_num,
         BLOCK_M=BLOCK,
-        BLOCK_DMODEL=Lk,
+        BLOCK_DMODEL=triton.next_power_of_2(Lk),
         BLOCK_N=BLOCK,
         num_warps=num_warps,
         num_stages=1,
+        Lk=Lk,
     )