sglang · Commit c0e9a36c (Unverified)
Authored Mar 19, 2025 by JieXin Liang; committed by GitHub Mar 18, 2025

Optimize Triton decoding kernel for dynamic workload (#4553)
Parent: 588865f0

Showing 7 changed files with 277 additions and 57 deletions
python/sglang/srt/layers/attention/base_attn_backend.py              +1    -0
python/sglang/srt/layers/attention/flashinfer_backend.py             +2    -0
python/sglang/srt/layers/attention/flashinfer_mla_backend.py         +2    -0
python/sglang/srt/layers/attention/triton_backend.py               +142   -15
python/sglang/srt/layers/attention/triton_ops/decode_attention.py   +92   -34
python/sglang/srt/model_executor/cuda_graph_runner.py               +10    -0
test/srt/test_triton_attention_kernels.py                           +28    -8
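At a high level, the commit makes the number of KV splits used by the Triton decode kernels a per-request quantity computed at runtime instead of a single static value: triton_attention_num_kv_splits now acts only as an upper bound (max_kv_splits), a small Triton kernel chooses num_kv_splits[b] per request from the sequence lengths and the GPU's SM count, and the intermediate buffer changes from one (bs, num_head, num_kv_splits, v_head_dim + 1) tensor with the log-sum-exp packed into its last column to a separate partial-output / partial-LSE pair. A minimal allocation sketch, illustrative only and not taken verbatim from the diff:

import torch

# Illustrative allocation of the new decode intermediates (names follow the diff).
bs, num_head, v_head_dim = 8, 32, 128
max_kv_splits = 8  # server_args.triton_attention_num_kv_splits, now only an upper bound

attn_logits = [
    torch.empty((bs, num_head, max_kv_splits, v_head_dim), dtype=torch.float32, device="cuda"),
    torch.empty((bs, num_head, max_kv_splits), dtype=torch.float32, device="cuda"),
]  # [0]: per-split partial outputs, [1]: per-split log-sum-exp
num_kv_splits = torch.empty((bs,), dtype=torch.int32, device="cuda")  # filled per request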
python/sglang/srt/layers/attention/base_attn_backend.py

@@ -39,6 +39,7 @@ class AttentionBackend(ABC):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
+        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -349,6 +349,7 @@ class FlashInferAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
+        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,

@@ -1062,6 +1063,7 @@ class FlashInferMultiStepDraftBackend:
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
+                -1,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
python/sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -279,6 +279,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
+        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,

@@ -791,6 +792,7 @@ class FlashInferMLAMultiStepDraftBackend:
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
+                -1,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
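The FlashInfer backends do not use the new argument; they only grow the extra parameter so the shared replay interface stays uniform, and their multi-step draft helpers pass -1 as a placeholder. A toy sketch of the pattern (not sglang code):

from abc import ABC, abstractmethod

class ReplayBackend(ABC):
    # Toy sketch of the widened interface; only the Triton backend acts on num_kv_heads.
    @abstractmethod
    def init_forward_metadata_replay_cuda_graph(self, bs: int, num_kv_heads: int) -> None: ...

class IgnoresKvHeads(ReplayBackend):
    def init_forward_metadata_replay_cuda_graph(self, bs: int, num_kv_heads: int) -> None:
        del num_kv_heads  # FlashInfer-style backends simply ignore the placeholder (-1)

IgnoresKvHeads().init_forward_metadata_replay_cuda_graph(bs=4, num_kv_heads=-1)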
python/sglang/srt/layers/attention/triton_backend.py

@@ -4,11 +4,13 @@ from typing import TYPE_CHECKING, Optional, Union

 import torch
+import triton
+import triton.language as tl

 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import get_bool_env_var, get_device_core_count

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -16,6 +18,51 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


+@triton.jit
+def get_num_kv_splits_triton(
+    num_kv_splits_ptr,
+    seq_lens_ptr,
+    bs,
+    num_head,
+    num_kv_head,
+    max_kv_splits,
+    device_core_count,
+    MAX_BS: tl.constexpr,
+):
+    # TODO: this method is tunable
+    offs_b = tl.arange(0, MAX_BS)
+    mask_b = offs_b < bs
+
+    seq_lens = tl.load(seq_lens_ptr + offs_b, mask=mask_b, other=0)
+    max_seq_len = tl.max(seq_lens)
+    seq_lens = tl.load(seq_lens_ptr + offs_b, mask=mask_b, other=max_seq_len)
+    min_seq_len = tl.min(seq_lens)
+    if max_seq_len * 8 < min_seq_len * 10:
+        min_seq_len = max_seq_len
+    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
+    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
+
+    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
+    ext_seq_len = tl.cast(tl.cdiv(max_seq_len, 256), tl.float32)
+    ext_device_core_count = device_core_count * tl.maximum(
+        tl.cast(tl.ceil(tl.log2(ext_seq_len)), tl.int32), 1
+    )
+    block_h, num_kv_group = 16, num_head // num_kv_head
+    if num_kv_group == 1:
+        bh_grid = bs * num_head
+    else:
+        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
+        block_h = tl.minimum(block_h, num_kv_group)
+        bh_grid = bs * tl.cdiv(num_head, block_h)
+    max_kv_splits_2 = tl.minimum(tl.cdiv(ext_device_core_count, bh_grid), max_kv_splits)
+    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
+
+    num_kv_splits = tl.maximum(
+        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
+    )
+    tl.store(num_kv_splits_ptr + offs_b, num_kv_splits, mask=mask_b)
+
+
 class TritonAttnBackend(AttentionBackend):
     def __init__(
         self,
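Read on the host side, the kernel picks, for every request, the larger of two split counts: one derived from a chunk size that balances the longest sequence against the shortest, and one derived from how many (batch x head-block) programs the decode grid needs relative to the GPU's SM count, scaled up gently with sequence length. A rough plain-Python restatement for illustration only; the real computation runs in the Triton kernel above:

import math

def estimate_num_kv_splits(seq_lens, num_head, num_kv_head, max_kv_splits, device_core_count):
    # Host-side mirror of get_num_kv_splits_triton, illustrative values and naming.
    cdiv = lambda a, b: -(-a // b)
    max_len, min_len = max(seq_lens), min(seq_lens)
    if max_len * 8 < min_len * 10:          # batch is roughly uniform
        min_len = max_len
    # Candidate 1: chunk size that balances long and short requests.
    splits_1 = min(cdiv(max_len, min_len), max_kv_splits)
    chunk_1 = cdiv(max_len, splits_1)
    # Candidate 2: scale splits with sequence length and the number of SMs.
    ext_cores = device_core_count * max(math.ceil(math.log2(cdiv(max_len, 256))), 1)
    block_h, kv_group = 16, num_head // num_kv_head
    bh_grid = len(seq_lens) * (num_head if kv_group == 1 else cdiv(num_head, min(block_h, kv_group)))
    splits_2 = min(cdiv(ext_cores, bh_grid), max_kv_splits)
    chunk_2 = cdiv(max_len, splits_2)
    # Per-request split count: whichever candidate yields more splits.
    return [max(cdiv(l, chunk_1), cdiv(l, chunk_2)) for l in seq_lens]

print(estimate_num_kv_splits([512, 2048, 8192], num_head=32, num_kv_head=8,
                             max_kv_splits=16, device_core_count=132))  # e.g. [1, 4, 16]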
@@ -64,7 +111,10 @@ class TritonAttnBackend(AttentionBackend):
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )

-        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+        self.static_kv_splits = get_bool_env_var(
+            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
+        )
+        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]

         self.forward_metadata = None
@@ -72,6 +122,30 @@ class TritonAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len

         self.device = model_runner.device
+        self.device_core_count = get_device_core_count(model_runner.gpu_id)
+
+    def get_num_kv_splits(
+        self,
+        num_kv_splits: torch.Tensor,
+        seq_lens: torch.Tensor,
+        bs: int,
+        num_kv_head: int,
+    ):
+        MAX_SCHEDULE_BS = 4096
+        if self.static_kv_splits or self.device_core_count <= 0 or bs > MAX_SCHEDULE_BS:
+            num_kv_splits.fill_(self.max_kv_splits)
+            return
+
+        get_num_kv_splits_triton[(1,)](
+            num_kv_splits,
+            seq_lens,
+            bs,
+            self.num_head,
+            num_kv_head,
+            self.max_kv_splits,
+            self.device_core_count,
+            MAX_BS=MAX_SCHEDULE_BS,
+        )

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
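The dynamic path can also be bypassed: if the environment variable SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS is truthy, if the device core count is unknown, or if the batch exceeds MAX_SCHEDULE_BS, every entry is simply filled with max_kv_splits, which reproduces the old static behaviour. A small usage sketch (assumed, not taken from the diff):

import os

# Assumed usage: force the pre-change static splitting. The flag is read via
# get_bool_env_var() when TritonAttnBackend is constructed, so set it beforehand.
os.environ["SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS"] = "true"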
@@ -100,15 +174,35 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                 bs = kv_indptr.shape[0] - 1

-            attn_logits = torch.empty(
-                (
-                    bs,
-                    self.num_head,
-                    self.num_kv_splits,
-                    self.v_head_dim + 1,
-                ),
-                dtype=torch.float32,
-                device=self.device,
-            )
+            attn_logits = [
+                torch.empty(
+                    (
+                        bs,
+                        self.num_head,
+                        self.max_kv_splits,
+                        self.v_head_dim,
+                    ),
+                    dtype=torch.float32,
+                    device=self.device,
+                ),
+                torch.empty(
+                    (
+                        bs,
+                        self.num_head,
+                        self.max_kv_splits,
+                    ),
+                    dtype=torch.float32,
+                    device=self.device,
+                ),
+            ]
+
+            num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
+            num_kv_heads = self.num_head
+            if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
+                if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
+                    num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
+            self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens, bs, num_kv_heads)

             qo_indptr = None
@@ -148,6 +242,7 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
             mask_indptr = mask_indptr[: bs + 1]
             max_extend_len = self.num_draft_tokens
+            num_kv_splits = None
             attn_logits = None
         elif forward_batch.forward_mode.is_draft_extend():
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
@@ -160,6 +255,7 @@ class TritonAttnBackend(AttentionBackend):
             )
             mask_indptr = None
             max_extend_len = torch.max(spec_info.accept_length).item()
+            num_kv_splits = None
             attn_logits = None
         else:
             kv_indptr[1 : bs + 1] = torch.cumsum(
@@ -188,10 +284,12 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = None
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
+            num_kv_splits = None

         self.forward_metadata = (
             attn_logits,
             max_extend_len,
+            num_kv_splits,
             kv_indptr,
             kv_indices,
             qo_indptr,
@@ -202,10 +300,20 @@ class TritonAttnBackend(AttentionBackend):
     def init_cuda_graph_state(
         self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
     ):
-        self.cuda_graph_attn_logits = torch.zeros(
-            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
-            dtype=torch.float32,
-            device=self.device,
-        )
+        self.cuda_graph_attn_logits = [
+            torch.zeros(
+                (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+                dtype=torch.float32,
+                device=self.device,
+            ),
+            torch.zeros(
+                (max_bs, self.num_head, self.max_kv_splits),
+                dtype=torch.float32,
+                device=self.device,
+            ),
+        ]
+        self.cuda_graph_num_kv_splits = torch.full(
+            (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+        )
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
@@ -255,6 +363,7 @@ class TritonAttnBackend(AttentionBackend):
             attn_logits = self.cuda_graph_attn_logits
             max_extend_len = None
+            num_kv_splits = self.cuda_graph_num_kv_splits
             qo_indptr = None
             custom_mask = None
             mask_indptr = None
@@ -285,6 +394,7 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = self.mask_indptr[: bs + 1]
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
             max_extend_len = self.num_draft_tokens
+            num_kv_splits = None
             attn_logits = None
         else:
             raise ValueError(
@@ -294,6 +404,7 @@ class TritonAttnBackend(AttentionBackend):
         self.forward_metadata = (
             attn_logits,
             max_extend_len,
+            num_kv_splits,
             kv_indptr,
             kv_indices,
             qo_indptr,
@@ -304,6 +415,7 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
+        num_kv_head: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
@@ -317,6 +429,7 @@ class TritonAttnBackend(AttentionBackend):
             # Update kv_indptr, kv_indices
             kv_indptr = self.kv_indptr
             kv_indices = self.cuda_graph_kv_indices
+            num_kv_splits = self.cuda_graph_num_kv_splits
             if spec_info is None:
                 kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
                 kv_indptr = kv_indptr[: bs + 1]
@@ -332,6 +445,7 @@ class TritonAttnBackend(AttentionBackend):
             else:
                 kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                 kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+            self.get_num_kv_splits(num_kv_splits, seq_lens, bs, num_kv_head)
         elif forward_mode.is_target_verify():
             # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
             bs = len(req_pool_indices)
@@ -391,6 +505,7 @@ class TritonAttnBackend(AttentionBackend):
         (
             _,
             max_extend_len,
+            _,
             kv_indptr,
             kv_indices,
             qo_indptr,
@@ -435,7 +550,9 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

-        attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata
+        attn_logits, _, num_kv_splits, kv_indptr, kv_indices, _, _, _ = (
+            self.forward_metadata
+        )

         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
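For the decode path the metadata tuple now has eight slots, with the per-request split tensor in third position and the buffer pair in the first. A toy illustration of that layout (field order assumed from the unpacking above; the last three fields are elided in the diff and shown here as None):

import torch

# Illustrative only: toy-sized decode-path metadata after this commit.
bs, H, Dv, max_kv_splits = 2, 4, 8, 8
attn_logits = [
    torch.empty(bs, H, max_kv_splits, Dv, dtype=torch.float32),  # partial outputs
    torch.empty(bs, H, max_kv_splits, dtype=torch.float32),      # partial log-sum-exp
]
num_kv_splits = torch.full((bs,), 4, dtype=torch.int32)          # per-request splits
kv_indptr = torch.tensor([0, 5, 12], dtype=torch.int32)
kv_indices = torch.arange(12, dtype=torch.int32)

forward_metadata = (attn_logits, None, num_kv_splits, kv_indptr, kv_indices, None, None, None)
attn_logits, _, num_kv_splits, kv_indptr, kv_indices, _, _, _ = forward_metadata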
@@ -450,7 +567,8 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr,
             kv_indices,
             attn_logits,
-            self.num_kv_splits,
+            num_kv_splits,
+            self.max_kv_splits,
             layer.scaling,
             layer.logit_cap,
         )
@@ -493,6 +611,9 @@ class TritonMultiStepDraftBackend:
             )
         )
         self.max_context_len = self.attn_backends[0].max_context_len
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
         self.device = model_runner.device
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
@@ -579,9 +700,15 @@ class TritonMultiStepDraftBackend:
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
+        num_kv_heads = self.num_head
+        if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
+            if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
+                num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
+
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
+                num_kv_heads,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
python/sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -37,6 +37,9 @@ logger.warning(
 )

+_MIN_BLOCK_KV = 32
+
+
 @triton.jit
 def tanh(x):
     # Tanh is just a scaled sigmoid
@@ -52,6 +55,8 @@ def _fwd_kernel_stage1(
     kv_indptr,
     kv_indices,
     Att_Out,
+    Att_Lse,
+    num_kv_splits,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -65,7 +70,7 @@ def _fwd_kernel_stage1(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     BLOCK_N: tl.constexpr,
-    NUM_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
@@ -83,11 +88,13 @@ def _fwd_kernel_stage1(
     cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
+    kv_splits = tl.load(num_kv_splits + cur_batch)

     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
-    q = tl.load(Q + off_q, mask=mask_d, other=0.0)

-    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )
     split_kv_start = kv_len_per_split * split_kv_id
     split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
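The per-request chunk size is the request's share rounded up to a multiple of MIN_BLOCK_KV, so a small split count never degrades into tiny fragments, and padding the launch to the maximum split count only costs empty program instances. A quick worked example in plain Python with made-up numbers:

# Worked example of the per-request split size used by the kernels above.
MIN_BLOCK_KV = 32
seq_len, kv_splits = 1000, 6          # one request's KV length and its assigned split count

cdiv = lambda a, b: -(-a // b)
kv_len_per_split = cdiv(cdiv(seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
print(kv_len_per_split)               # ceil(ceil(1000 / 6) / 32) * 32 = 192

# Only ceil(1000 / 192) = 6 program instances do work; instances past that see
# split_kv_end <= split_kv_start and exit immediately.
for split_kv_id in range(8):          # pretend the grid was launched with MAX_KV_SPLITS = 8
    start = kv_len_per_split * split_kv_id
    end = min(start + kv_len_per_split, seq_len)
    print(split_kv_id, (start, end) if end > start else "empty")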
@@ -96,6 +103,7 @@ def _fwd_kernel_stage1(
     acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

     if split_kv_end > split_kv_start:
+        q = tl.load(Q + off_q, mask=mask_d, other=0.0)
         for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_loc = tl.load(
@@ -158,11 +166,10 @@ def _fwd_kernel_stage1(
             cur_batch * stride_mid_ob
             + cur_head * stride_mid_oh
             + split_kv_id * stride_mid_os
-            + Lv
-        )
+        ) // Lv

         tl.store(
-            Att_Out + offs_mid_o_1,
+            Att_Lse + offs_mid_o_1,
             e_max + tl.log(e_sum),
         )
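Previously the per-split log-sum-exp was stored in column Lv of the (..., Lv + 1) output buffer; it now lives in the dedicated Att_Lse tensor, which is why the offset is divided by Lv. The second stage still combines splits with the standard online-softmax merge; a small PyTorch sketch of that reduction (illustrative, not the kernel code):

import torch

# Illustrative merge of per-split partial results, mirroring what _fwd_kernel_stage2 does.
# partial_out[s] is a split's value sum normalised by its own softmax mass, and
# partial_lse[s] is that split's log-sum-exp of the attention logits.
def merge_splits(partial_out: torch.Tensor, partial_lse: torch.Tensor) -> torch.Tensor:
    m = partial_lse.max()
    w = torch.exp(partial_lse - m)            # per-split softmax mass, rescaled for stability
    return (w[:, None] * partial_out).sum(0) / w.sum()

splits, Dv = 4, 8
out = merge_splits(torch.randn(splits, Dv), torch.randn(splits))
print(out.shape)  # torch.Size([8])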
@@ -172,9 +179,11 @@ def _decode_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
+    att_lse,
     kv_indptr,
     kv_indices,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap,
 ):
@@ -182,13 +191,13 @@ def _decode_att_m_fwd(
     # [TODO] work around SGPR limit on MI3xx
     if _is_hip:
         BLOCK = 8

-    NUM_KV_SPLITS = num_kv_splits
+    MAX_KV_SPLITS = max_kv_splits
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]

     batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
-    grid = (batch, head_num, NUM_KV_SPLITS)
+    grid = (batch, head_num, MAX_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[1]

     if kv_group_num == 1:
@@ -209,6 +218,8 @@ def _decode_att_m_fwd(
         kv_indptr,
         kv_indices,
         att_out,
+        att_lse,
+        num_kv_splits,
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -222,7 +233,7 @@ def _decode_att_m_fwd(
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DV=BLOCK_DV,
         BLOCK_N=BLOCK,
-        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=2,
@@ -240,6 +251,8 @@ def _fwd_grouped_kernel_stage1(
     kv_indptr,
     kv_indices,
     Att_Out,
+    Att_Lse,
+    num_kv_splits,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -256,7 +269,7 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_DV: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
-    NUM_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
@@ -281,9 +294,9 @@ def _fwd_grouped_kernel_stage1(
     cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
+    kv_splits = tl.load(num_kv_splits + cur_batch)

     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
-    q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)

     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
@@ -291,11 +304,10 @@ def _fwd_grouped_kernel_stage1(
         off_qpe = (
             cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
         )
-        qpe = tl.load(Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0)

-    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )
     split_kv_start = kv_len_per_split * split_kv_id
     split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
@@ -304,6 +316,11 @@ def _fwd_grouped_kernel_stage1(
     acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

     if split_kv_end > split_kv_start:
+        q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(
+                Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
+            )
         for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_loc = tl.load(
@@ -380,11 +397,10 @@ def _fwd_grouped_kernel_stage1(
             cur_batch * stride_mid_ob
             + cur_head * stride_mid_oh
             + split_kv_id * stride_mid_os
-            + Lv
-        )
+        ) // Lv

         tl.store(
-            Att_Out + offs_mid_o_1,
+            Att_Lse + offs_mid_o_1,
             e_max + tl.log(e_sum),
             mask=mask_h,
         )
@@ -395,9 +411,11 @@ def _decode_grouped_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
+    att_lse,
     kv_indptr,
     kv_indices,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap,
 ):
@@ -424,11 +442,11 @@ def _decode_grouped_att_m_fwd(
     kv_group_num = q.shape[1] // k_buffer.shape[1]

     BLOCK_H = 16
-    NUM_KV_SPLITS = num_kv_splits
+    MAX_KV_SPLITS = max_kv_splits
     grid = (
         batch,
         triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
-        NUM_KV_SPLITS,
+        MAX_KV_SPLITS,
     )

     extra_kargs = {}
@@ -447,6 +465,8 @@ def _decode_grouped_att_m_fwd(
         kv_indptr,
         kv_indices,
         att_out,
+        att_lse,
+        num_kv_splits,
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -463,7 +483,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_DV=BLOCK_DV,
         BLOCK_N=BLOCK,
         BLOCK_H=BLOCK_H,
-        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
         num_warps=4,
         num_stages=num_stages,
@@ -476,14 +496,17 @@
 @triton.jit
 def _fwd_kernel_stage2(
     Mid_O,
+    Mid_O_1,
     O,
     kv_indptr,
+    num_kv_splits,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
     stride_obs,
     stride_oh,
-    NUM_KV_SPLITS: tl.constexpr,
+    MAX_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     Lv: tl.constexpr,
 ):
@@ -493,6 +516,7 @@ def _fwd_kernel_stage2(
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch)
+    kv_splits = tl.load(num_kv_splits + cur_batch)

     offs_d = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lv
@@ -502,10 +526,12 @@ def _fwd_kernel_stage2(
     acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

     offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
-    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
+    offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )

-    for split_kv_id in range(0, NUM_KV_SPLITS):
-        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    for split_kv_id in range(0, MAX_KV_SPLITS):
         split_kv_start = kv_len_per_split * split_kv_id
         split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
@@ -513,7 +539,7 @@ def _fwd_kernel_stage2(
         if split_kv_end > split_kv_start:
             tv = tl.load(
                 Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
             )
-            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
+            tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv)
             n_e_max = tl.maximum(tlogic, e_max)

             old_scale = tl.exp(e_max - n_e_max)
@@ -533,17 +559,19 @@
 def _decode_softmax_reducev_fwd(
     logits,
+    lse,
     q,
     o,
     v_buffer,
     kv_indptr,
     num_kv_splits,
+    max_kv_splits,
 ):
     batch, head_num = q.shape[0], q.shape[1]
     Lv = v_buffer.shape[-1]
     BLOCK_DV = triton.next_power_of_2(Lv)

-    NUM_KV_SPLITS = num_kv_splits
+    MAX_KV_SPLITS = max_kv_splits

     extra_kargs = {}
     if _is_hip:
@@ -554,14 +582,17 @@ def _decode_softmax_reducev_fwd(
     grid = (batch, head_num)
     _fwd_kernel_stage2[grid](
         logits,
+        lse,
         o,
         kv_indptr,
+        num_kv_splits,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
         o.stride(0),
         o.stride(1),
-        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        MAX_KV_SPLITS=MAX_KV_SPLITS,
+        MIN_BLOCK_KV=_MIN_BLOCK_KV,
         BLOCK_DV=BLOCK_DV,
         Lv=Lv,
         num_warps=4,
@@ -579,6 +610,7 @@ def decode_attention_fwd_normal(
     kv_indices,
     attn_logits,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
@@ -586,14 +618,25 @@ def decode_attention_fwd_normal(
         q,
         k_buffer,
         v_buffer,
-        attn_logits,
+        attn_logits[0],
+        attn_logits[1],
         kv_indptr,
         kv_indices,
         num_kv_splits,
+        max_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
+    _decode_softmax_reducev_fwd(
+        attn_logits[0],
+        attn_logits[1],
+        q,
+        o,
+        v_buffer,
+        kv_indptr,
+        num_kv_splits,
+        max_kv_splits,
+    )


 def decode_attention_fwd_grouped(
@@ -605,6 +648,7 @@ def decode_attention_fwd_grouped(
     kv_indices,
     attn_logits,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
@@ -612,14 +656,25 @@ def decode_attention_fwd_grouped(
         q,
         k_buffer,
         v_buffer,
-        attn_logits,
+        attn_logits[0],
+        attn_logits[1],
         kv_indptr,
         kv_indices,
         num_kv_splits,
+        max_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
+    _decode_softmax_reducev_fwd(
+        attn_logits[0],
+        attn_logits[1],
+        q,
+        o,
+        v_buffer,
+        kv_indptr,
+        num_kv_splits,
+        max_kv_splits,
+    )


 def decode_attention_fwd(
@@ -631,12 +686,13 @@ def decode_attention_fwd(
     kv_indices,
     attn_logits,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
-    assert num_kv_splits == attn_logits.shape[2]
+    assert max_kv_splits == attn_logits[0].shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1
-    assert q.shape[0] <= attn_logits.shape[0]
+    assert q.shape[0] <= attn_logits[0].shape[0]

     kv_group_num = q.shape[1] // v_buffer.shape[1]
@@ -651,6 +707,7 @@ def decode_attention_fwd(
             kv_indices,
             attn_logits,
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
             logit_cap,
         )

@@ -665,6 +722,7 @@ def decode_attention_fwd(
             kv_indices,
             attn_logits,
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
             logit_cap,
         )
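Taken together, a caller of decode_attention_fwd now allocates the partial-output / partial-LSE pair sized by max_kv_splits plus a per-request num_kv_splits tensor. A hedged end-to-end sketch (toy shapes, assumes sglang is installed and a CUDA device is available; it mirrors how the updated unit test below drives the kernel):

import torch
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd

# Minimal usage sketch of the updated entry point, illustrative sizes only.
B, H_Q, H_KV, D, seq_len, max_kv_splits = 4, 16, 16, 128, 64, 8
total_tokens = B * seq_len

q = torch.randn(B, H_Q, D, dtype=torch.bfloat16, device="cuda")
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=torch.bfloat16, device="cuda")
v_buffer = torch.randn(total_tokens, H_KV, D, dtype=torch.bfloat16, device="cuda")
o = torch.zeros(B, H_Q, D, dtype=torch.bfloat16, device="cuda")

kv_indptr = torch.arange(0, (B + 1) * seq_len, seq_len, dtype=torch.int32, device="cuda")
kv_indices = torch.arange(total_tokens, dtype=torch.int32, device="cuda")

attn_logits = torch.empty(B, H_Q, max_kv_splits, D, dtype=torch.float32, device="cuda")
attn_lse = torch.empty(B, H_Q, max_kv_splits, dtype=torch.float32, device="cuda")
num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")

decode_attention_fwd(
    q, k_buffer, v_buffer, o, kv_indptr, kv_indices,
    (attn_logits, attn_lse),      # intermediate buffers, indexed as attn_logits[0]/[1]
    num_kv_splits, max_kv_splits, 1.0 / D**0.5,
)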
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -26,6 +26,7 @@ import tqdm
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -195,6 +196,9 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -503,9 +507,15 @@ class CudaGraphRunner:
         if hasattr(forward_batch.spec_info, "hidden_states"):
             self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states

+        num_kv_heads = self.num_head
+        if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
+            if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
+                num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
+            num_kv_heads,
             self.req_pool_indices,
             self.seq_lens,
             forward_batch.seq_lens_sum + (bs - raw_bs),
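The same k_buffer probe now appears in three places (here, in TritonAttnBackend.init_forward_metadata, and in TritonMultiStepDraftBackend), because pools that store k_buffer as a list (e.g. MLA-style caches) can have a KV-head count that differs from num_attention_heads // tp_size. A hypothetical helper could factor it out; shown only as a design note, it is not part of the commit:

def infer_num_kv_heads(token_to_kv_pool, default):
    # Hypothetical helper, not part of the commit: mirrors the probe repeated in
    # cuda_graph_runner.py, triton_backend.py and the multi-step draft backend.
    k_buffer = getattr(token_to_kv_pool, "k_buffer", None)
    if isinstance(k_buffer, list) and k_buffer:
        return k_buffer[0].shape[1]
    return default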
test/srt/test_triton_attention_kernels.py

@@ -228,7 +228,8 @@ class TestTritonAttention(unittest.TestCase):
         seq_len = 10  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
-        num_kv_splits = 8
+        max_kv_splits = 8
+        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")

         # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
@@ -247,7 +248,12 @@ class TestTritonAttention(unittest.TestCase):
         kv_indices = torch.arange(total_tokens, device="cuda")

         attn_logits = torch.empty(
-            (B, H_Q, num_kv_splits, D + 1),
+            (B, H_Q, max_kv_splits, D),
+            dtype=torch.float32,
+            device="cuda",
+        )
+        attn_lse = torch.empty(
+            (B, H_Q, max_kv_splits),
             dtype=torch.float32,
             device="cuda",
         )
@@ -259,8 +265,9 @@ class TestTritonAttention(unittest.TestCase):
             o,
             kv_indptr,
             kv_indices,
-            attn_logits,
+            (attn_logits, attn_lse),
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
         )
@@ -284,7 +291,8 @@ class TestTritonAttention(unittest.TestCase):
         seq_len = S  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
-        num_kv_splits = 8
+        max_kv_splits = 8
+        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")

         # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")
@@ -304,7 +312,12 @@ class TestTritonAttention(unittest.TestCase):
         kv_indices = torch.arange(total_tokens, device="cuda")

         attn_logits = torch.empty(
-            (B, H_Q, num_kv_splits, D_V + 1),
+            (B, H_Q, max_kv_splits, D_V),
+            dtype=torch.float32,
+            device="cuda",
+        )
+        attn_lse = torch.empty(
+            (B, H_Q, max_kv_splits),
             dtype=torch.float32,
             device="cuda",
         )
@@ -316,13 +329,19 @@ class TestTritonAttention(unittest.TestCase):
             o,
             kv_indptr,
             kv_indices,
-            attn_logits,
+            (attn_logits, attn_lse),
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
         )

         attn_logits1 = torch.empty(
-            (B, H_Q, num_kv_splits, D_V + 1),
+            (B, H_Q, max_kv_splits, D_V),
+            dtype=torch.float32,
+            device="cuda",
+        )
+        attn_lse1 = torch.empty(
+            (B, H_Q, max_kv_splits, D_V),
             dtype=torch.float32,
             device="cuda",
         )
@@ -334,8 +353,9 @@ class TestTritonAttention(unittest.TestCase):
             o_grouped,
             kv_indptr,
             kv_indices,
-            attn_logits1,
+            (attn_logits1, attn_lse1),
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
         )