Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
ddcf9fe3
"vscode:/vscode.git/clone" did not exist on "1d391bba132ac2cb6077ee10bc4138a7260d39f2"
Unverified
Commit
ddcf9fe3
authored
Feb 21, 2025
by
Ke Bao
Committed by
GitHub
Feb 21, 2025
Browse files
Optimize triton attention custom mask (#3731)
parent
6252ade9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
1 deletion
+6
-1
python/sglang/srt/layers/attention/triton_ops/extend_attention.py
...glang/srt/layers/attention/triton_ops/extend_attention.py
+6
-1
No files found.
python/sglang/srt/layers/attention/triton_ops/extend_attention.py
View file @
ddcf9fe3
...
...
@@ -74,6 +74,7 @@ def _fwd_kernel(
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
USE_CUSTOM_MASK
:
tl
.
constexpr
,
SKIP_PREFIX_CUSTOM_MASK
:
tl
.
constexpr
,
STORE_TRANSPOSE
:
tl
.
constexpr
,
):
cur_seq
=
tl
.
program_id
(
0
)
...
...
@@ -160,7 +161,7 @@ def _fwd_kernel(
if
logit_cap
>
0
:
qk
=
logit_cap
*
tanh
(
qk
/
logit_cap
)
if
USE_CUSTOM_MASK
:
if
USE_CUSTOM_MASK
and
not
SKIP_PREFIX_CUSTOM_MASK
:
custom_mask
=
tl
.
load
(
mask_ptr
+
cur_seq_mask_start_idx
...
...
@@ -302,6 +303,7 @@ def extend_attention_fwd(
max_len_extend
,
sm_scale
=
None
,
logit_cap
=
0.0
,
skip_prefix_custom_mask
=
True
,
):
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
...
...
@@ -355,6 +357,8 @@ def extend_attention_fwd(
kv_group_num
=
q_extend
.
shape
[
1
]
//
k_extend
.
shape
[
1
]
USE_CUSTOM_MASK
=
custom_mask
is
not
None
# Skip custom mask for prefix part
SKIP_PREFIX_CUSTOM_MASK
=
skip_prefix_custom_mask
grid
=
(
batch_size
,
head_num
,
triton
.
cdiv
(
max_len_extend
,
BLOCK_M
))
num_stages
=
1
...
...
@@ -398,6 +402,7 @@ def extend_attention_fwd(
Lq
=
Lq
,
Lv
=
Lv
,
USE_CUSTOM_MASK
=
USE_CUSTOM_MASK
,
SKIP_PREFIX_CUSTOM_MASK
=
SKIP_PREFIX_CUSTOM_MASK
,
STORE_TRANSPOSE
=
is_hip_
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment