Commit be2d985d: Minor style change of triton backend (#7165)
Authored Jun 13, 2025 by Lianmin Zheng; committed via GitHub on Jun 13, 2025 (unverified signature)
Parent: 5b1afa78
Showing 1 changed file with 113 additions and 113 deletions.
python/sglang/srt/layers/attention/triton_backend.py
@@ -20,117 +20,6 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
-
-@triton.jit
-def get_num_kv_splits_triton(
-    num_kv_splits_ptr,
-    seq_lens_ptr,
-    num_seq,
-    num_group,
-    num_head,
-    num_kv_head,
-    max_kv_splits,
-    device_core_count,
-    MAX_NUM_SEQ: tl.constexpr,
-):
-    # TODO: this method is tunable, we need more online serving data to tune it
-    offs_seq = tl.arange(0, MAX_NUM_SEQ)
-    mask_seq = offs_seq < num_seq
-
-    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
-    max_seq_len = tl.max(seq_lens)
-    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
-    min_seq_len = tl.min(seq_lens)
-    if max_seq_len * 8 < min_seq_len * 10:
-        min_seq_len = max_seq_len
-    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
-    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
-
-    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
-    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
-    ext_device_core_count = tl.cast(
-        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
-    )
-    block_h, num_kv_group = 16, num_head // num_kv_head
-    if num_kv_group == 1:
-        token_grid = num_seq * num_group * num_head
-    else:
-        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
-        block_h = tl.minimum(block_h, num_kv_group)
-        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
-    max_kv_splits_2 = tl.minimum(
-        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
-    )
-    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
-
-    num_kv_splits = tl.maximum(
-        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
-    )
-
-    offs_token = offs_seq * num_group
-    mask_token = offs_token < num_seq * num_group
-    for i in range(0, num_group):
-        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
-
-
-def update_sliding_window_buffer(
-    window_kv_indptr,
-    req_to_token,
-    sliding_window_size,
-    seq_lens,
-    req_pool_indices,
-    bs,
-    device,
-):
-    window_kv_lens = torch.minimum(
-        seq_lens,
-        torch.tensor(sliding_window_size + 1),
-    )
-    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
-    window_kv_indptr = window_kv_indptr[: bs + 1]
-    window_kv_indices = torch.empty(
-        window_kv_indptr[-1], dtype=torch.int32, device=device
-    )
-    window_kv_start_idx = seq_lens - window_kv_lens
-    create_flashinfer_kv_indices_triton[(bs,)](
-        req_to_token,
-        req_pool_indices,
-        window_kv_lens,
-        window_kv_indptr,
-        window_kv_start_idx,
-        window_kv_indices,
-        req_to_token.stride(0),
-    )
-    return window_kv_indptr, window_kv_indices, window_kv_lens
-
-
-def update_sliding_window_buffer_cuda_graph(
-    window_kv_indptr,
-    window_kv_indices,
-    req_to_token,
-    sliding_window_size,
-    seq_lens,
-    req_pool_indices,
-    bs,
-):
-    window_kv_lens = torch.minimum(
-        seq_lens,
-        torch.tensor(sliding_window_size + 1),
-    )
-    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
-    window_kv_indptr = window_kv_indptr[: bs + 1]
-    window_kv_start_idx = seq_lens - window_kv_lens
-    create_flashinfer_kv_indices_triton[(bs,)](
-        req_to_token,
-        req_pool_indices,
-        window_kv_lens,
-        window_kv_indptr,
-        window_kv_start_idx,
-        window_kv_indices,
-        req_to_token.stride(0),
-    )
-    return window_kv_indptr, window_kv_lens
-
-
 @dataclass
 class ForwardMetadata:
     attn_logits: torch.Tensor
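
The kernel deleted above is only moved by this commit, not modified: it picks, per sequence, how many KV-cache splits the split-KV decode kernel should use, as the larger of two heuristics. The following is a minimal host-side restatement in plain Python, for illustration only (not code from this commit; cdiv is a local helper, the final per-group store loop is replaced by a returned list, and sequence lengths are assumed positive):

import math

def cdiv(a, b):
    # Ceiling division, mirroring tl.cdiv.
    return -(-a // b)

def num_kv_splits_reference(seq_lens, num_head, num_kv_head,
                            max_kv_splits, device_core_count, num_group=1):
    # Plain-Python paraphrase of get_num_kv_splits_triton (illustrative only).
    num_seq = len(seq_lens)
    max_seq_len = max(seq_lens)
    min_seq_len = min(seq_lens)
    # If the batch is nearly uniform (min > 0.8 * max), treat it as uniform.
    if max_seq_len * 8 < min_seq_len * 10:
        min_seq_len = max_seq_len

    # Heuristic 1: split the longest sequence into chunks of roughly
    # the shortest sequence's length.
    max_kv_splits_1 = min(cdiv(max_seq_len, min_seq_len), max_kv_splits)
    kv_chunk_size_1 = cdiv(max_seq_len, max_kv_splits_1)

    # Heuristic 2: add splits until the launch grid roughly fills the
    # device; the core count is inflated with log2 of the sequence length.
    ext_seq_len = max_seq_len / 64.0
    ext_core_count = int(device_core_count * max(math.log2(ext_seq_len), 1.0))
    num_kv_group = num_head // num_kv_head
    if num_kv_group == 1:
        token_grid = num_seq * num_group * num_head
    else:
        block_h = min(16, num_kv_group)
        token_grid = num_seq * num_group * cdiv(num_head, block_h)
    max_kv_splits_2 = min(cdiv(ext_core_count, token_grid), max_kv_splits)
    kv_chunk_size_2 = cdiv(max_seq_len, max_kv_splits_2)

    # Each sequence takes the larger split count of the two heuristics.
    return [max(cdiv(s, kv_chunk_size_1), cdiv(s, kv_chunk_size_2))
            for s in seq_lens]

For example, seq_lens = [1024, 512] with 32 query heads, 8 KV heads, max_kv_splits = 16, and a hypothetical device_core_count of 132 yields [16, 8] splits.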
@@ -165,8 +54,8 @@ class TritonAttnBackend(AttentionBackend):
         super().__init__()
 
-        self.decode_attention_fwd = decode_attention_fwd
-        self.extend_attention_fwd = extend_attention_fwd
+        self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
+        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
 
         self.skip_prefill = skip_prefill
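
The two replaced lines are the only functional change in this hunk: the decode and extend attention entry points are wrapped in torch.compiler.disable, so torch.compile skips over them (a graph break) instead of attempting to trace the Triton kernel launches. A minimal standalone illustration of that PyTorch API (available in PyTorch 2.1+), unrelated to sglang's code:

import torch

@torch.compiler.disable
def opaque_op(x):
    # torch.compile will not trace into this function; it runs eagerly
    # as a single opaque call, causing a graph break at the call site.
    return x * 2 + 1

@torch.compile
def model_step(x):
    y = torch.relu(x)   # traced and compiled
    y = opaque_op(y)    # executed eagerly, outside the compiled graph
    return y.sum()      # compiled as a new graph after the break

print(model_step(torch.randn(4)))

Wrapping at assignment time, as the diff does (fn = torch.compiler.disable(fn)), is equivalent to decorating the function definition.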
@@ -973,3 +862,114 @@ class TritonMultiStepDraftBackend:
             )
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+
+@triton.jit
+def get_num_kv_splits_triton(
+    num_kv_splits_ptr,
+    seq_lens_ptr,
+    num_seq,
+    num_group,
+    num_head,
+    num_kv_head,
+    max_kv_splits,
+    device_core_count,
+    MAX_NUM_SEQ: tl.constexpr,
+):
+    # TODO: this method is tunable, we need more online serving data to tune it
+    offs_seq = tl.arange(0, MAX_NUM_SEQ)
+    mask_seq = offs_seq < num_seq
+
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
+    max_seq_len = tl.max(seq_lens)
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
+    min_seq_len = tl.min(seq_lens)
+    if max_seq_len * 8 < min_seq_len * 10:
+        min_seq_len = max_seq_len
+    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
+    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
+
+    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
+    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
+    ext_device_core_count = tl.cast(
+        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
+    )
+    block_h, num_kv_group = 16, num_head // num_kv_head
+    if num_kv_group == 1:
+        token_grid = num_seq * num_group * num_head
+    else:
+        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
+        block_h = tl.minimum(block_h, num_kv_group)
+        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
+    max_kv_splits_2 = tl.minimum(
+        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
+    )
+    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
+
+    num_kv_splits = tl.maximum(
+        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
+    )
+
+    offs_token = offs_seq * num_group
+    mask_token = offs_token < num_seq * num_group
+    for i in range(0, num_group):
+        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
+
+
+def update_sliding_window_buffer(
+    window_kv_indptr,
+    req_to_token,
+    sliding_window_size,
+    seq_lens,
+    req_pool_indices,
+    bs,
+    device,
+):
+    window_kv_lens = torch.minimum(
+        seq_lens,
+        torch.tensor(sliding_window_size + 1),
+    )
+    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+    window_kv_indptr = window_kv_indptr[: bs + 1]
+    window_kv_indices = torch.empty(
+        window_kv_indptr[-1], dtype=torch.int32, device=device
+    )
+    window_kv_start_idx = seq_lens - window_kv_lens
+    create_flashinfer_kv_indices_triton[(bs,)](
+        req_to_token,
+        req_pool_indices,
+        window_kv_lens,
+        window_kv_indptr,
+        window_kv_start_idx,
+        window_kv_indices,
+        req_to_token.stride(0),
+    )
+    return window_kv_indptr, window_kv_indices, window_kv_lens
+
+
+def update_sliding_window_buffer_cuda_graph(
+    window_kv_indptr,
+    window_kv_indices,
+    req_to_token,
+    sliding_window_size,
+    seq_lens,
+    req_pool_indices,
+    bs,
+):
+    window_kv_lens = torch.minimum(
+        seq_lens,
+        torch.tensor(sliding_window_size + 1),
+    )
+    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+    window_kv_indptr = window_kv_indptr[: bs + 1]
+    window_kv_start_idx = seq_lens - window_kv_lens
+    create_flashinfer_kv_indices_triton[(bs,)](
+        req_to_token,
+        req_pool_indices,
+        window_kv_lens,
+        window_kv_indptr,
+        window_kv_start_idx,
+        window_kv_indices,
+        req_to_token.stride(0),
+    )
+    return window_kv_indptr, window_kv_lens
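
The two update_sliding_window_buffer helpers re-added above build a CSR-style (indptr, indices) view over the KV cache for sliding-window attention: each request keeps only its last sliding_window_size + 1 tokens, and window_kv_indptr[i] : window_kv_indptr[i + 1] delimits request i's slot indices in the flat window_kv_indices array. A rough pure-PyTorch sketch of the same layout, with the create_flashinfer_kv_indices_triton gather kernel replaced by an explicit Python loop (illustrative only, assuming req_to_token maps [request pool index, token position] to a KV-cache slot):

import torch

def sliding_window_indices_reference(req_to_token, req_pool_indices,
                                     seq_lens, sliding_window_size):
    # Keep at most the last (sliding_window_size + 1) tokens per request.
    window_kv_lens = torch.minimum(
        seq_lens, torch.tensor(sliding_window_size + 1)
    )
    bs = seq_lens.numel()
    window_kv_indptr = torch.zeros(bs + 1, dtype=torch.int64)
    window_kv_indptr[1:] = torch.cumsum(window_kv_lens, dim=0)
    window_kv_indices = torch.empty(int(window_kv_indptr[-1]), dtype=torch.int32)

    # Explicit loop standing in for the Triton gather kernel.
    window_kv_start_idx = seq_lens - window_kv_lens
    for i in range(bs):
        pool = int(req_pool_indices[i])
        start = int(window_kv_start_idx[i])
        length = int(window_kv_lens[i])
        lo = int(window_kv_indptr[i])
        window_kv_indices[lo : lo + length] = req_to_token[pool, start : start + length]
    return window_kv_indptr, window_kv_indices, window_kv_lens

# Two requests of lengths 5 and 9 with a window of 3 (keep the last 4 tokens):
req_to_token = torch.arange(2 * 16, dtype=torch.int32).reshape(2, 16)
indptr, indices, lens = sliding_window_indices_reference(
    req_to_token, torch.tensor([0, 1]), torch.tensor([5, 9]), sliding_window_size=3
)
# lens == [4, 4]; indices holds req_to_token[0, 1:5] followed by req_to_token[1, 5:9].

The _cuda_graph variant differs only in writing into a preallocated window_kv_indices buffer rather than allocating one, since CUDA graph replay requires tensors at fixed addresses.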