sglang · Commits · 399e7ec8

Unverified commit 399e7ec8, authored Aug 06, 2025 by Ke Bao, committed by GitHub on Aug 06, 2025
Parent: 1bd53168

Refine naming (#8868)

Showing 4 changed files with 30 additions and 30 deletions (+30 / -30)
python/sglang/srt/layers/attention/triton_backend.py               +4   -4
python/sglang/srt/layers/attention/triton_ops/decode_attention.py  +16  -16
python/sglang/srt/layers/attention/triton_ops/extend_attention.py   +9   -9
python/sglang/srt/models/gpt_oss.py                                 +1   -1
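
For context, the commit renames the kernel-side `sk` argument to `sinks`, aligning it with the `self.sinks` attribute used in gpt_oss.py. The argument carries one attention-sink logit per head, and the hunks below fold it into the softmax denominator only; it never contributes a value vector. A minimal PyTorch sketch of that normalization, assuming contiguous per-head logits for a single query; the function and variable names are illustrative, not taken from this repository:

    import torch

    def softmax_denominator_with_sink(logits: torch.Tensor, sink: torch.Tensor) -> torch.Tensor:
        # logits: [num_heads, num_keys] attention scores for one query.
        # sink:   [num_heads] per-head sink logit (the tensor passed as `sinks`).
        e_max = logits.max(dim=-1).values                       # per-head max, as the kernels track e_max
        e_sum = torch.exp(logits - e_max[:, None]).sum(dim=-1)  # normalizer over the real keys
        e_sum = e_sum + torch.exp(sink - e_max)                 # the sink adds mass to the denominator only
        return e_sum

The accumulated values are then divided by this denominator, so a large sink simply down-weights all real keys rather than adding an output of its own.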
python/sglang/srt/layers/attention/triton_backend.py

@@ -686,7 +686,7 @@ class TritonAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
-        sk=None,
+        sinks=None,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -731,7 +731,7 @@ class TritonAttnBackend(AttentionBackend):
             layer.scaling,
             layer.logit_cap,
             sliding_window_size=sliding_window_size,
-            sk=sk,
+            sinks=sinks,
         )
         return o
@@ -743,7 +743,7 @@ class TritonAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
-        sk=None,
+        sinks=None,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -780,7 +780,7 @@ class TritonAttnBackend(AttentionBackend):
             self.max_kv_splits,
             layer.scaling,
             layer.logit_cap,
-            sk=sk,
+            sinks=sinks,
         )
         return o
python/sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -495,7 +495,7 @@ def _fwd_kernel_stage2(
     O,
     kv_indptr,
     num_kv_splits,
-    sk_ptr,
+    sink_ptr,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
@@ -505,7 +505,7 @@ def _fwd_kernel_stage2(
     MIN_BLOCK_KV: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     Lv: tl.constexpr,
-    HAS_SK: tl.constexpr,
+    HAS_SINK: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -547,9 +547,9 @@ def _fwd_kernel_stage2(
             e_sum = e_sum * old_scale + exp_logic
             e_max = n_e_max

-    if HAS_SK:
-        cur_sk = tl.load(sk_ptr + cur_head)
-        e_sum += tl.exp(cur_sk - e_max)
+    if HAS_SINK:
+        cur_sink = tl.load(sink_ptr + cur_head)
+        e_sum += tl.exp(cur_sink - e_max)

     tl.store(
         O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
@@ -567,14 +567,14 @@ def _decode_softmax_reducev_fwd(
     kv_indptr,
     num_kv_splits,
     max_kv_splits,
-    sk=None,
+    sinks=None,
 ):
     batch, head_num = q.shape[0], q.shape[1]
     Lv = v_buffer.shape[-1]
     BLOCK_DV = triton.next_power_of_2(Lv)

     MAX_KV_SPLITS = max_kv_splits
-    HAS_SK = sk is not None
+    HAS_SINK = sinks is not None

     extra_kargs = {}
     if _is_hip:
@@ -589,7 +589,7 @@ def _decode_softmax_reducev_fwd(
         o,
         kv_indptr,
         num_kv_splits,
-        sk,
+        sinks,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
@@ -599,7 +599,7 @@ def _decode_softmax_reducev_fwd(
         MIN_BLOCK_KV=_MIN_BLOCK_KV,
         BLOCK_DV=BLOCK_DV,
         Lv=Lv,
-        HAS_SK=HAS_SK,
+        HAS_SINK=HAS_SINK,
         num_warps=4,
         num_stages=2,
         **extra_kargs,
@@ -619,7 +619,7 @@ def decode_attention_fwd_normal(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
-    sk=None,
+    sinks=None,
 ):
     _decode_att_m_fwd(
         q,
@@ -643,7 +643,7 @@ def decode_attention_fwd_normal(
         kv_indptr,
         num_kv_splits,
         max_kv_splits,
-        sk,
+        sinks,
     )
@@ -660,7 +660,7 @@ def decode_attention_fwd_grouped(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
-    sk=None,
+    sinks=None,
 ):
     _decode_grouped_att_m_fwd(
         q,
@@ -684,7 +684,7 @@ def decode_attention_fwd_grouped(
         kv_indptr,
         num_kv_splits,
         max_kv_splits,
-        sk,
+        sinks,
     )
@@ -701,7 +701,7 @@ def decode_attention_fwd(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
-    sk=None,
+    sinks=None,
 ):
     assert max_kv_splits == attn_logits.shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1
@@ -724,7 +724,7 @@ def decode_attention_fwd(
             max_kv_splits,
             sm_scale,
             logit_cap=logit_cap,
-            sk=sk,
+            sinks=sinks,
         )
     else:
         # GQA/MQA/MLA
@@ -741,5 +741,5 @@ def decode_attention_fwd(
             max_kv_splits,
             sm_scale,
             logit_cap=logit_cap,
-            sk=sk,
+            sinks=sinks,
         )
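
The `HAS_SINK: tl.constexpr` flag follows the existing pattern in these kernels: Triton specializes the compiled kernel on constexpr values, so the sink branch disappears entirely when no sinks are passed. A standalone sketch of that pattern below; the kernel and all names are hypothetical, not part of sglang, and it assumes a CUDA/ROCm device and a power-of-two row length:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _denominator_kernel(
        logits_ptr,              # [H, N] attention logits, one contiguous row per head
        sink_ptr,                # [H] per-head sink logits (ignored when HAS_SINK=False)
        out_ptr,                 # [H] softmax denominators
        N: tl.constexpr,         # row length, power of two for tl.arange
        HAS_SINK: tl.constexpr,  # compile-time flag, mirrors the kernels above
    ):
        head = tl.program_id(0)
        offs = tl.arange(0, N)
        x = tl.load(logits_ptr + head * N + offs)
        m = tl.max(x, axis=0)
        denom = tl.sum(tl.exp(x - m), axis=0)
        if HAS_SINK:
            # The sink adds mass to the normalizer only; it carries no value vector.
            sink = tl.load(sink_ptr + head)
            denom += tl.exp(sink - m)
        tl.store(out_ptr + head, denom)

    # Hypothetical launch: one program per head; HAS_SINK selects the specialized variant.
    H, N = 8, 128
    logits = torch.randn(H, N, device="cuda")
    sinks = torch.randn(H, device="cuda")
    out = torch.empty(H, device="cuda")
    _denominator_kernel[(H,)](logits, sinks, out, N=N, HAS_SINK=True)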
python/sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -51,7 +51,7 @@ def _fwd_kernel(
     kv_indices,
     mask_ptr,
     mask_indptr,
-    sk_ptr,
+    sink_ptr,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -79,7 +79,7 @@ def _fwd_kernel(
     IS_CAUSAL: tl.constexpr,
     SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
     STORE_TRANSPOSE: tl.constexpr,
-    HAS_SK: tl.constexpr,
+    HAS_SINK: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -302,9 +302,9 @@ def _fwd_kernel(
         e_max = n_e_max

-    if HAS_SK:
-        cur_sk = tl.load(sk_ptr + cur_head)
-        deno += tl.exp(cur_sk - e_max)
+    if HAS_SINK:
+        cur_sink = tl.load(sink_ptr + cur_head)
+        deno += tl.exp(cur_sink - e_max)

     offs_o = (
         (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
@@ -344,7 +344,7 @@ def extend_attention_fwd(
     logit_cap=0.0,
     skip_prefix_custom_mask=True,
     sliding_window_size=-1,
-    sk=None,
+    sinks=None,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -410,7 +410,7 @@ def extend_attention_fwd(
     # Skip custom mask for prefix part
     SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
-    HAS_SK = sk is not None
+    HAS_SINK = sinks is not None

     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
     num_stages = 1
@@ -431,7 +431,7 @@ def extend_attention_fwd(
         kv_indices,
         custom_mask,
         mask_indptr,
-        sk,
+        sinks,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
@@ -458,7 +458,7 @@ def extend_attention_fwd(
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
         IS_CAUSAL=is_causal,
         SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
-        HAS_SK=HAS_SK,
+        HAS_SINK=HAS_SINK,
         STORE_TRANSPOSE=_is_hip,
         num_warps=num_warps,
         num_stages=num_stages,
python/sglang/srt/models/gpt_oss.py

@@ -301,7 +301,7 @@ class GptOssAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state, sk=self.sinks)
+        attn_output = self.attn(*inner_state, sinks=self.sinks)
         output, _ = self.o_proj(attn_output)
         return output