sglang commit be2d985d (Unverified)
Authored Jun 13, 2025 by Lianmin Zheng; committed by GitHub on Jun 13, 2025
Parent: 5b1afa78

Minor style change of triton backend (#7165)
Showing 1 changed file with 113 additions and 113 deletions.

python/sglang/srt/layers/attention/triton_backend.py  +113 -113  (view file @ be2d985d)
@@ -20,117 +20,6 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


-@triton.jit
-def get_num_kv_splits_triton(
-    num_kv_splits_ptr,
-    seq_lens_ptr,
-    num_seq,
-    num_group,
-    num_head,
-    num_kv_head,
-    max_kv_splits,
-    device_core_count,
-    MAX_NUM_SEQ: tl.constexpr,
-):
-    # TODO: this method is tunable, we need more online serving data to tune it
-    offs_seq = tl.arange(0, MAX_NUM_SEQ)
-    mask_seq = offs_seq < num_seq
-
-    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
-    max_seq_len = tl.max(seq_lens)
-    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
-    min_seq_len = tl.min(seq_lens)
-    if max_seq_len * 8 < min_seq_len * 10:
-        min_seq_len = max_seq_len
-    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
-    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
-
-    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
-    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
-    ext_device_core_count = tl.cast(
-        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
-    )
-    block_h, num_kv_group = 16, num_head // num_kv_head
-    if num_kv_group == 1:
-        token_grid = num_seq * num_group * num_head
-    else:
-        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
-        block_h = tl.minimum(block_h, num_kv_group)
-        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
-    max_kv_splits_2 = tl.minimum(
-        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
-    )
-    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
-
-    num_kv_splits = tl.maximum(
-        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
-    )
-
-    offs_token = offs_seq * num_group
-    mask_token = offs_token < num_seq * num_group
-    for i in range(0, num_group):
-        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
-
-
-def update_sliding_window_buffer(
-    window_kv_indptr,
-    req_to_token,
-    sliding_window_size,
-    seq_lens,
-    req_pool_indices,
-    bs,
-    device,
-):
-    window_kv_lens = torch.minimum(
-        seq_lens,
-        torch.tensor(sliding_window_size + 1),
-    )
-    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
-    window_kv_indptr = window_kv_indptr[: bs + 1]
-    window_kv_indices = torch.empty(
-        window_kv_indptr[-1], dtype=torch.int32, device=device
-    )
-    window_kv_start_idx = seq_lens - window_kv_lens
-    create_flashinfer_kv_indices_triton[(bs,)](
-        req_to_token,
-        req_pool_indices,
-        window_kv_lens,
-        window_kv_indptr,
-        window_kv_start_idx,
-        window_kv_indices,
-        req_to_token.stride(0),
-    )
-    return window_kv_indptr, window_kv_indices, window_kv_lens
-
-
-def update_sliding_window_buffer_cuda_graph(
-    window_kv_indptr,
-    window_kv_indices,
-    req_to_token,
-    sliding_window_size,
-    seq_lens,
-    req_pool_indices,
-    bs,
-):
-    window_kv_lens = torch.minimum(
-        seq_lens,
-        torch.tensor(sliding_window_size + 1),
-    )
-    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
-    window_kv_indptr = window_kv_indptr[: bs + 1]
-    window_kv_start_idx = seq_lens - window_kv_lens
-    create_flashinfer_kv_indices_triton[(bs,)](
-        req_to_token,
-        req_pool_indices,
-        window_kv_lens,
-        window_kv_indptr,
-        window_kv_start_idx,
-        window_kv_indices,
-        req_to_token.stride(0),
-    )
-    return window_kv_indptr, window_kv_lens
-
-
 @dataclass
 class ForwardMetadata:
     attn_logits: torch.Tensor
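Note that the block deleted above is not dropped: the same three helpers reappear verbatim at the bottom of the file in the final hunk of this commit. For orientation, a hypothetical host-side launch of get_num_kv_splits_triton could look like the sketch below; the sizes, the (1,) grid, and the use of triton.next_power_of_2 for MAX_NUM_SEQ are illustrative assumptions, not code from this commit.

import torch
import triton

# Hypothetical usage sketch (not part of this commit), assuming the kernel
# is importable from the module touched here.
from sglang.srt.layers.attention.triton_backend import get_num_kv_splits_triton

num_seq, num_group, num_head, num_kv_head = 4, 1, 32, 8
max_kv_splits = 16
seq_lens = torch.tensor([17, 512, 64, 2048], dtype=torch.int32, device="cuda")
# One entry per (sequence, group) token; the kernel fills this buffer.
num_kv_splits = torch.empty(num_seq * num_group, dtype=torch.int32, device="cuda")
device_core_count = torch.cuda.get_device_properties(0).multi_processor_count

# A single program instance scans all sequence lengths, so the grid is (1,).
get_num_kv_splits_triton[(1,)](
    num_kv_splits,
    seq_lens,
    num_seq,
    num_group,
    num_head,
    num_kv_head,
    max_kv_splits,
    device_core_count,
    MAX_NUM_SEQ=triton.next_power_of_2(num_seq),
)
print(num_kv_splits)  # longer sequences get larger split counts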
@@ -165,8 +54,8 @@ class TritonAttnBackend(AttentionBackend):
         super().__init__()

-        self.decode_attention_fwd = decode_attention_fwd
-        self.extend_attention_fwd = extend_attention_fwd
+        self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
+        self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)

         self.skip_prefill = skip_prefill
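The hunk above is the only behavioral edit in the commit: the decode and extend attention entry points are wrapped in torch.compiler.disable so torch.compile does not trace into them. A minimal, self-contained sketch of that pattern follows; Backend and decode_step are illustrative names, not from this file.

import torch

def decode_step(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a Triton kernel launcher that Dynamo should not trace.
    return x * 2

class Backend:
    def __init__(self):
        # torch.compiler.disable returns a callable that torch.compile
        # skips (it graph-breaks around the call instead of tracing it).
        self.decode_step = torch.compiler.disable(decode_step)

backend = Backend()
out = backend.decode_step(torch.ones(4))  # still runs eagerly, as before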
@@ -973,3 +862,114 @@ class TritonMultiStepDraftBackend:
         )
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+
+@triton.jit
+def get_num_kv_splits_triton(
+    num_kv_splits_ptr,
+    seq_lens_ptr,
+    num_seq,
+    num_group,
+    num_head,
+    num_kv_head,
+    max_kv_splits,
+    device_core_count,
+    MAX_NUM_SEQ: tl.constexpr,
+):
+    # TODO: this method is tunable, we need more online serving data to tune it
+    offs_seq = tl.arange(0, MAX_NUM_SEQ)
+    mask_seq = offs_seq < num_seq
+
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
+    max_seq_len = tl.max(seq_lens)
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
+    min_seq_len = tl.min(seq_lens)
+    if max_seq_len * 8 < min_seq_len * 10:
+        min_seq_len = max_seq_len
+    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
+    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
+
+    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
+    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
+    ext_device_core_count = tl.cast(
+        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
+    )
+    block_h, num_kv_group = 16, num_head // num_kv_head
+    if num_kv_group == 1:
+        token_grid = num_seq * num_group * num_head
+    else:
+        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
+        block_h = tl.minimum(block_h, num_kv_group)
+        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
+    max_kv_splits_2 = tl.minimum(
+        tl.cdiv(ext_device_core_count, token_grid), max_kv_splits
+    )
+    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
+
+    num_kv_splits = tl.maximum(
+        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
+    )
+
+    offs_token = offs_seq * num_group
+    mask_token = offs_token < num_seq * num_group
+    for i in range(0, num_group):
+        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
+
+
+def update_sliding_window_buffer(
+    window_kv_indptr,
+    req_to_token,
+    sliding_window_size,
+    seq_lens,
+    req_pool_indices,
+    bs,
+    device,
+):
+    window_kv_lens = torch.minimum(
+        seq_lens,
+        torch.tensor(sliding_window_size + 1),
+    )
+    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+    window_kv_indptr = window_kv_indptr[: bs + 1]
+    window_kv_indices = torch.empty(
+        window_kv_indptr[-1], dtype=torch.int32, device=device
+    )
+    window_kv_start_idx = seq_lens - window_kv_lens
+    create_flashinfer_kv_indices_triton[(bs,)](
+        req_to_token,
+        req_pool_indices,
+        window_kv_lens,
+        window_kv_indptr,
+        window_kv_start_idx,
+        window_kv_indices,
+        req_to_token.stride(0),
+    )
+    return window_kv_indptr, window_kv_indices, window_kv_lens
+
+
+def update_sliding_window_buffer_cuda_graph(
+    window_kv_indptr,
+    window_kv_indices,
+    req_to_token,
+    sliding_window_size,
+    seq_lens,
+    req_pool_indices,
+    bs,
+):
+    window_kv_lens = torch.minimum(
+        seq_lens,
+        torch.tensor(sliding_window_size + 1),
+    )
+    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
+    window_kv_indptr = window_kv_indptr[: bs + 1]
+    window_kv_start_idx = seq_lens - window_kv_lens
+    create_flashinfer_kv_indices_triton[(bs,)](
+        req_to_token,
+        req_pool_indices,
+        window_kv_lens,
+        window_kv_indptr,
+        window_kv_start_idx,
+        window_kv_indices,
+        req_to_token.stride(0),
+    )
+    return window_kv_indptr, window_kv_lens
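
The two update_sliding_window_buffer helpers re-added above build a CSR-style (indptr, indices) layout that keeps only the last sliding_window_size + 1 tokens of each request. A small CPU-only walkthrough of the indptr arithmetic, with made-up lengths, is sketched below; it omits the create_flashinfer_kv_indices_triton launch that fills the actual index buffer.

import torch

# Illustrative numbers only; this mirrors the indptr math in
# update_sliding_window_buffer without touching req_to_token.
sliding_window_size = 4
seq_lens = torch.tensor([3, 10, 6])  # per-request token counts
bs = seq_lens.numel()

# Each request keeps at most sliding_window_size + 1 KV entries.
window_kv_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))
window_kv_indptr = torch.zeros(bs + 1, dtype=torch.int64)
window_kv_indptr[1:] = torch.cumsum(window_kv_lens, dim=0)
# First kept token of each request within its full sequence.
window_kv_start_idx = seq_lens - window_kv_lens

print(window_kv_lens)       # tensor([3, 5, 5])
print(window_kv_indptr)     # tensor([ 0,  3,  8, 13])
print(window_kv_start_idx)  # tensor([0, 5, 1])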