Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
ddcf9fe3
Unverified
Commit
ddcf9fe3
authored
Feb 21, 2025
by
Ke Bao
Committed by
GitHub
Feb 21, 2025
Browse files
Optimize triton attention custom mask (#3731)
parent
6252ade9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
1 deletion
+6
-1
python/sglang/srt/layers/attention/triton_ops/extend_attention.py
...glang/srt/layers/attention/triton_ops/extend_attention.py
+6
-1
No files found.
python/sglang/srt/layers/attention/triton_ops/extend_attention.py
View file @
ddcf9fe3
...
@@ -74,6 +74,7 @@ def _fwd_kernel(
...
@@ -74,6 +74,7 @@ def _fwd_kernel(
BLOCK_M
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
USE_CUSTOM_MASK
:
tl
.
constexpr
,
USE_CUSTOM_MASK
:
tl
.
constexpr
,
SKIP_PREFIX_CUSTOM_MASK
:
tl
.
constexpr
,
STORE_TRANSPOSE
:
tl
.
constexpr
,
STORE_TRANSPOSE
:
tl
.
constexpr
,
):
):
cur_seq
=
tl
.
program_id
(
0
)
cur_seq
=
tl
.
program_id
(
0
)
...
@@ -160,7 +161,7 @@ def _fwd_kernel(
...
@@ -160,7 +161,7 @@ def _fwd_kernel(
if
logit_cap
>
0
:
if
logit_cap
>
0
:
qk
=
logit_cap
*
tanh
(
qk
/
logit_cap
)
qk
=
logit_cap
*
tanh
(
qk
/
logit_cap
)
if
USE_CUSTOM_MASK
:
if
USE_CUSTOM_MASK
and
not
SKIP_PREFIX_CUSTOM_MASK
:
custom_mask
=
tl
.
load
(
custom_mask
=
tl
.
load
(
mask_ptr
mask_ptr
+
cur_seq_mask_start_idx
+
cur_seq_mask_start_idx
...
@@ -302,6 +303,7 @@ def extend_attention_fwd(
...
@@ -302,6 +303,7 @@ def extend_attention_fwd(
max_len_extend
,
max_len_extend
,
sm_scale
=
None
,
sm_scale
=
None
,
logit_cap
=
0.0
,
logit_cap
=
0.0
,
skip_prefix_custom_mask
=
True
,
):
):
"""
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
q_extend, k_extend, v_extend, o_extend: contiguous tensors
...
@@ -355,6 +357,8 @@ def extend_attention_fwd(
...
@@ -355,6 +357,8 @@ def extend_attention_fwd(
kv_group_num
=
q_extend
.
shape
[
1
]
//
k_extend
.
shape
[
1
]
kv_group_num
=
q_extend
.
shape
[
1
]
//
k_extend
.
shape
[
1
]
USE_CUSTOM_MASK
=
custom_mask
is
not
None
USE_CUSTOM_MASK
=
custom_mask
is
not
None
# Skip custom mask for prefix part
SKIP_PREFIX_CUSTOM_MASK
=
skip_prefix_custom_mask
grid
=
(
batch_size
,
head_num
,
triton
.
cdiv
(
max_len_extend
,
BLOCK_M
))
grid
=
(
batch_size
,
head_num
,
triton
.
cdiv
(
max_len_extend
,
BLOCK_M
))
num_stages
=
1
num_stages
=
1
...
@@ -398,6 +402,7 @@ def extend_attention_fwd(
...
@@ -398,6 +402,7 @@ def extend_attention_fwd(
Lq
=
Lq
,
Lq
=
Lq
,
Lv
=
Lv
,
Lv
=
Lv
,
USE_CUSTOM_MASK
=
USE_CUSTOM_MASK
,
USE_CUSTOM_MASK
=
USE_CUSTOM_MASK
,
SKIP_PREFIX_CUSTOM_MASK
=
SKIP_PREFIX_CUSTOM_MASK
,
STORE_TRANSPOSE
=
is_hip_
,
STORE_TRANSPOSE
=
is_hip_
,
num_warps
=
num_warps
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
num_stages
=
num_stages
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment