sglang · Commits · c0e9a36c

Unverified commit c0e9a36c, authored Mar 19, 2025 by JieXin Liang, committed by GitHub Mar 18, 2025.

Optimize Triton decoding kernel for dynamic workload (#4553)

parent 588865f0

Showing 7 changed files with 277 additions and 57 deletions (+277 −57).
Changed files:

  python/sglang/srt/layers/attention/base_attn_backend.py            +1    −0
  python/sglang/srt/layers/attention/flashinfer_backend.py           +2    −0
  python/sglang/srt/layers/attention/flashinfer_mla_backend.py       +2    −0
  python/sglang/srt/layers/attention/triton_backend.py               +142  −15
  python/sglang/srt/layers/attention/triton_ops/decode_attention.py  +92   −34
  python/sglang/srt/model_executor/cuda_graph_runner.py              +10   −0
  test/srt/test_triton_attention_kernels.py                          +28   −8
python/sglang/srt/layers/attention/base_attn_backend.py

@@ -39,6 +39,7 @@ class AttentionBackend(ABC):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
+        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -349,6 +349,7 @@ class FlashInferAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
+        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,

@@ -1062,6 +1063,7 @@ class FlashInferMultiStepDraftBackend:
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
+                -1,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
python/sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -279,6 +279,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
+        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,

@@ -791,6 +792,7 @@ class FlashInferMLAMultiStepDraftBackend:
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
+                -1,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
python/sglang/srt/layers/attention/triton_backend.py

@@ -4,11 +4,13 @@ from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 import triton
+import triton.language as tl
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import get_bool_env_var, get_device_core_count
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -16,6 +18,51 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 
+@triton.jit
+def get_num_kv_splits_triton(
+    num_kv_splits_ptr,
+    seq_lens_ptr,
+    bs,
+    num_head,
+    num_kv_head,
+    max_kv_splits,
+    device_core_count,
+    MAX_BS: tl.constexpr,
+):
+    # TODO: this method is tunable
+    offs_b = tl.arange(0, MAX_BS)
+    mask_b = offs_b < bs
+
+    seq_lens = tl.load(seq_lens_ptr + offs_b, mask=mask_b, other=0)
+    max_seq_len = tl.max(seq_lens)
+    seq_lens = tl.load(seq_lens_ptr + offs_b, mask=mask_b, other=max_seq_len)
+    min_seq_len = tl.min(seq_lens)
+    if max_seq_len * 8 < min_seq_len * 10:
+        min_seq_len = max_seq_len
+    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
+    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
+
+    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
+    ext_seq_len = tl.cast(tl.cdiv(max_seq_len, 256), tl.float32)
+    ext_device_core_count = device_core_count * tl.maximum(
+        tl.cast(tl.ceil(tl.log2(ext_seq_len)), tl.int32), 1
+    )
+    block_h, num_kv_group = 16, num_head // num_kv_head
+    if num_kv_group == 1:
+        bh_grid = bs * num_head
+    else:
+        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
+        block_h = tl.minimum(block_h, num_kv_group)
+        bh_grid = bs * tl.cdiv(num_head, block_h)
+    max_kv_splits_2 = tl.minimum(
+        tl.cdiv(ext_device_core_count, bh_grid), max_kv_splits
+    )
+    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
+
+    num_kv_splits = tl.maximum(
+        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
+    )
+    tl.store(num_kv_splits_ptr + offs_b, num_kv_splits, mask=mask_b)
+
+
 class TritonAttnBackend(AttentionBackend):
     def __init__(
         self,
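The kernel above picks, per request, how many KV splits the decode kernels should use: enough splits to keep the GPU busy, but not so many that short requests are sliced into tiny chunks. For readability only, here is a rough host-side PyTorch restatement of the same heuristic; the function name is made up, all sequence lengths are assumed positive, and the caller is assumed to have already checked that device_core_count > 0 (as the backend does before launching the Triton version).

```python
import math

import torch


def cdiv(a: int, b: int) -> int:
    return -(-a // b)


def num_kv_splits_reference(
    seq_lens: torch.Tensor,      # (bs,) int tensor of KV lengths
    num_head: int,
    num_kv_head: int,
    max_kv_splits: int,
    device_core_count: int,
) -> torch.Tensor:
    max_len, min_len = int(seq_lens.max()), int(seq_lens.min())
    # If all requests are within ~20% of each other, treat them as uniform.
    if max_len * 8 < min_len * 10:
        min_len = max_len

    # Heuristic 1: cut the longest request into chunks no shorter than the
    # shortest request, capped at max_kv_splits.
    splits_1 = min(cdiv(max_len, max(min_len, 1)), max_kv_splits)
    chunk_1 = cdiv(max_len, splits_1)

    # Heuristic 2: grow the split count with sequence length until the
    # (batch x head-block x split) grid can saturate the SMs.
    ext_factor = max(math.ceil(math.log2(max(cdiv(max_len, 256), 1))), 1)
    ext_cores = device_core_count * ext_factor
    group = num_head // num_kv_head
    bh_grid = seq_lens.numel() * (
        num_head if group == 1 else cdiv(num_head, min(16, group))
    )
    splits_2 = min(cdiv(ext_cores, bh_grid), max_kv_splits)
    chunk_2 = cdiv(max_len, splits_2)

    # Each request gets just enough splits to cover its own length.
    return torch.maximum(
        (seq_lens + chunk_1 - 1) // chunk_1,
        (seq_lens + chunk_2 - 1) // chunk_2,
    )
```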
@@ -64,7 +111,10 @@ class TritonAttnBackend(AttentionBackend):
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
 
-        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+        self.static_kv_splits = get_bool_env_var(
+            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
+        )
+        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
 
         self.forward_metadata = None
@@ -72,6 +122,30 @@ class TritonAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
 
         self.device = model_runner.device
+        self.device_core_count = get_device_core_count(model_runner.gpu_id)
+
+    def get_num_kv_splits(
+        self,
+        num_kv_splits: torch.Tensor,
+        seq_lens: torch.Tensor,
+        bs: int,
+        num_kv_head: int,
+    ):
+        MAX_SCHEDULE_BS = 4096
+        if self.static_kv_splits or self.device_core_count <= 0 or bs > MAX_SCHEDULE_BS:
+            num_kv_splits.fill_(self.max_kv_splits)
+            return
+
+        get_num_kv_splits_triton[(1,)](
+            num_kv_splits,
+            seq_lens,
+            bs,
+            self.num_head,
+            num_kv_head,
+            self.max_kv_splits,
+            self.device_core_count,
+            MAX_BS=MAX_SCHEDULE_BS,
+        )
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
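The new SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS variable acts as an escape hatch back to the old static scheduling. A minimal sketch of opting out of the dynamic heuristic, assuming get_bool_env_var treats "true" as truthy:

```python
import os

# Set before the process constructs TritonAttnBackend. With this enabled,
# get_num_kv_splits() fills every entry of num_kv_splits with self.max_kv_splits
# (the --triton-attention-num-kv-splits value), matching the pre-commit behaviour.
os.environ["SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS"] = "true"
```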
@@ -100,15 +174,35 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                 bs = kv_indptr.shape[0] - 1
 
-            attn_logits = torch.empty(
-                (
-                    bs,
-                    self.num_head,
-                    self.num_kv_splits,
-                    self.v_head_dim + 1,
-                ),
-                dtype=torch.float32,
-                device=self.device,
-            )
+            attn_logits = [
+                torch.empty(
+                    (
+                        bs,
+                        self.num_head,
+                        self.max_kv_splits,
+                        self.v_head_dim,
+                    ),
+                    dtype=torch.float32,
+                    device=self.device,
+                ),
+                torch.empty(
+                    (
+                        bs,
+                        self.num_head,
+                        self.max_kv_splits,
+                    ),
+                    dtype=torch.float32,
+                    device=self.device,
+                ),
+            ]
+            num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
+
+            num_kv_heads = self.num_head
+            if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
+                if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
+                    num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
+            self.get_num_kv_splits(
+                num_kv_splits, forward_batch.seq_lens, bs, num_kv_heads
+            )
 
             qo_indptr = None
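The intermediate workspace thus changes from one packed tensor, which stored the per-split log-sum-exp in an extra trailing channel (v_head_dim + 1), to an explicit pair of a partial-output buffer and a separate LSE buffer. A quick arithmetic check, with hypothetical sizes that are not taken from the commit, shows that the split changes the layout but not the footprint:

```python
bs, num_head, v_head_dim, max_kv_splits = 32, 32, 128, 8  # hypothetical sizes

# Before: one float32 buffer packing the LSE into an extra trailing channel.
old_elems = bs * num_head * max_kv_splits * (v_head_dim + 1)

# After: separate partial-output and per-split LSE buffers.
new_elems = bs * num_head * max_kv_splits * v_head_dim + bs * num_head * max_kv_splits

assert old_elems == new_elems  # same number of float32 elements, cleaner indexing
```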
@@ -148,6 +242,7 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
             mask_indptr = mask_indptr[: bs + 1]
             max_extend_len = self.num_draft_tokens
+            num_kv_splits = None
             attn_logits = None
         elif forward_batch.forward_mode.is_draft_extend():
             kv_indices, kv_indptr, qo_indptr, custom_mask = (

@@ -160,6 +255,7 @@ class TritonAttnBackend(AttentionBackend):
             )
             mask_indptr = None
             max_extend_len = torch.max(spec_info.accept_length).item()
+            num_kv_splits = None
             attn_logits = None
         else:
             kv_indptr[1 : bs + 1] = torch.cumsum(
@@ -188,10 +284,12 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = None
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
+            num_kv_splits = None
 
         self.forward_metadata = (
             attn_logits,
             max_extend_len,
+            num_kv_splits,
             kv_indptr,
             kv_indices,
             qo_indptr,
@@ -202,10 +300,20 @@ class TritonAttnBackend(AttentionBackend):
     def init_cuda_graph_state(
         self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
     ):
-        self.cuda_graph_attn_logits = torch.zeros(
-            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
-            dtype=torch.float32,
-            device=self.device,
-        )
+        self.cuda_graph_attn_logits = [
+            torch.zeros(
+                (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+                dtype=torch.float32,
+                device=self.device,
+            ),
+            torch.zeros(
+                (max_bs, self.num_head, self.max_kv_splits),
+                dtype=torch.float32,
+                device=self.device,
+            ),
+        ]
+        self.cuda_graph_num_kv_splits = torch.full(
+            (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+        )
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(

@@ -255,6 +363,7 @@ class TritonAttnBackend(AttentionBackend):
             attn_logits = self.cuda_graph_attn_logits
             max_extend_len = None
+            num_kv_splits = self.cuda_graph_num_kv_splits
             qo_indptr = None
             custom_mask = None
             mask_indptr = None
@@ -285,6 +394,7 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = self.mask_indptr[: bs + 1]
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
             max_extend_len = self.num_draft_tokens
+            num_kv_splits = None
             attn_logits = None
         else:
             raise ValueError(

@@ -294,6 +404,7 @@ class TritonAttnBackend(AttentionBackend):
         self.forward_metadata = (
             attn_logits,
             max_extend_len,
+            num_kv_splits,
             kv_indptr,
             kv_indices,
             qo_indptr,

@@ -304,6 +415,7 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
+        num_kv_head: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,

@@ -317,6 +429,7 @@ class TritonAttnBackend(AttentionBackend):
             # Update kv_indptr, kv_indices
             kv_indptr = self.kv_indptr
             kv_indices = self.cuda_graph_kv_indices
+            num_kv_splits = self.cuda_graph_num_kv_splits
             if spec_info is None:
                 kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
                 kv_indptr = kv_indptr[: bs + 1]

@@ -332,6 +445,7 @@ class TritonAttnBackend(AttentionBackend):
             else:
                 kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                 kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+            self.get_num_kv_splits(num_kv_splits, seq_lens, bs, num_kv_head)
         elif forward_mode.is_target_verify():
             # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
             bs = len(req_pool_indices)

@@ -391,6 +505,7 @@ class TritonAttnBackend(AttentionBackend):
         (
             _,
             max_extend_len,
+            _,
             kv_indptr,
             kv_indices,
             qo_indptr,

@@ -435,7 +550,9 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata
+        attn_logits, _, num_kv_splits, kv_indptr, kv_indices, _, _, _ = (
+            self.forward_metadata
+        )
 
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(

@@ -450,7 +567,8 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr,
             kv_indices,
             attn_logits,
-            self.num_kv_splits,
+            num_kv_splits,
+            self.max_kv_splits,
             layer.scaling,
             layer.logit_cap,
         )

@@ -493,6 +611,9 @@ class TritonMultiStepDraftBackend:
             )
         self.max_context_len = self.attn_backends[0].max_context_len
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
         self.device = model_runner.device
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]

@@ -579,9 +700,15 @@ class TritonMultiStepDraftBackend:
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
+        num_kv_heads = self.num_head
+        if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
+            if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
+                num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
+
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
+                num_kv_heads,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
python/sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -37,6 +37,9 @@ logger.warning(
 )
 
+_MIN_BLOCK_KV = 32
+
+
 @triton.jit
 def tanh(x):
     # Tanh is just a scaled sigmoid

@@ -52,6 +55,8 @@ def _fwd_kernel_stage1(
     kv_indptr,
     kv_indices,
     Att_Out,
+    Att_Lse,
+    num_kv_splits,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,

@@ -65,7 +70,7 @@ def _fwd_kernel_stage1(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     BLOCK_N: tl.constexpr,
-    NUM_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,

@@ -83,11 +88,13 @@ def _fwd_kernel_stage1(
     cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
+    kv_splits = tl.load(num_kv_splits + cur_batch)
 
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
-    q = tl.load(Q + off_q, mask=mask_d, other=0.0)
 
-    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )
     split_kv_start = kv_len_per_split * split_kv_id
     split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
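With this change the per-request chunk size is derived from the request's own split count and rounded up to a multiple of _MIN_BLOCK_KV, while the launch grid is still sized by the static maximum. Grid slots whose range falls beyond the sequence simply see an empty range and exit early. A small illustrative helper (the function names are made up, MIN_BLOCK_KV = 32 comes from the hunk above) shows the resulting geometry:

```python
def cdiv(a: int, b: int) -> int:
    return -(-a // b)


def split_ranges(seq_len: int, kv_splits: int, max_kv_splits: int, min_block_kv: int = 32):
    """Which [start, end) KV ranges each of the max_kv_splits grid slots covers."""
    # Chunk size rounded up to a multiple of min_block_kv, as in the kernel above.
    chunk = cdiv(cdiv(seq_len, kv_splits), min_block_kv) * min_block_kv
    ranges = []
    for split_kv_id in range(max_kv_splits):  # the launch grid stays max_kv_splits wide
        start = chunk * split_kv_id
        end = min(start + chunk, seq_len)
        # Empty slots correspond to split_kv_end <= split_kv_start in the kernel.
        ranges.append((start, end) if end > start else None)
    return ranges


# Example: a 100-token request scheduled with 3 splits out of a maximum of 8
# -> chunk = 64, ranges = [(0, 64), (64, 100), None, None, None, None, None, None]
print(split_ranges(100, 3, 8))
```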
@@ -96,6 +103,7 @@ def _fwd_kernel_stage1(
     acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
 
     if split_kv_end > split_kv_start:
+        q = tl.load(Q + off_q, mask=mask_d, other=0.0)
         for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_loc = tl.load(

@@ -158,11 +166,10 @@ def _fwd_kernel_stage1(
             cur_batch * stride_mid_ob
             + cur_head * stride_mid_oh
             + split_kv_id * stride_mid_os
-            + Lv
-        )
+        ) // Lv
 
         tl.store(
-            Att_Out + offs_mid_o_1,
+            Att_Lse + offs_mid_o_1,
             e_max + tl.log(e_sum),
         )

@@ -172,9 +179,11 @@ def _decode_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
+    att_lse,
     kv_indptr,
     kv_indices,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap,
 ):
@@ -182,13 +191,13 @@ def _decode_att_m_fwd(
     # [TODO] work around SGPR limit on MI3xx
     if _is_hip:
         BLOCK = 8
 
-    NUM_KV_SPLITS = num_kv_splits
+    MAX_KV_SPLITS = max_kv_splits
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
 
     batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
-    grid = (batch, head_num, NUM_KV_SPLITS)
+    grid = (batch, head_num, MAX_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     if kv_group_num == 1:

@@ -209,6 +218,8 @@ def _decode_att_m_fwd(
         kv_indptr,
         kv_indices,
         att_out,
+        att_lse,
+        num_kv_splits,
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),

@@ -222,7 +233,7 @@ def _decode_att_m_fwd(
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DV=BLOCK_DV,
         BLOCK_N=BLOCK,
-        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=2,

@@ -240,6 +251,8 @@ def _fwd_grouped_kernel_stage1(
     kv_indptr,
     kv_indices,
     Att_Out,
+    Att_Lse,
+    num_kv_splits,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,

@@ -256,7 +269,7 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_DV: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
-    NUM_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
@@ -281,9 +294,9 @@ def _fwd_grouped_kernel_stage1(
     cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
+    kv_splits = tl.load(num_kv_splits + cur_batch)
 
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
-    q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
 
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)

@@ -291,11 +304,10 @@ def _fwd_grouped_kernel_stage1(
         off_qpe = (
             cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
         )
-        qpe = tl.load(Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0)
 
-    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )
     split_kv_start = kv_len_per_split * split_kv_id
     split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

@@ -304,6 +316,11 @@ def _fwd_grouped_kernel_stage1(
     acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
 
     if split_kv_end > split_kv_start:
+        q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(
+                Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
+            )
         for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_loc = tl.load(

@@ -380,11 +397,10 @@ def _fwd_grouped_kernel_stage1(
             cur_batch * stride_mid_ob
             + cur_head * stride_mid_oh
             + split_kv_id * stride_mid_os
-            + Lv
-        )
+        ) // Lv
 
         tl.store(
-            Att_Out + offs_mid_o_1,
+            Att_Lse + offs_mid_o_1,
             e_max + tl.log(e_sum),
             mask=mask_h,
         )

@@ -395,9 +411,11 @@ def _decode_grouped_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
+    att_lse,
     kv_indptr,
     kv_indices,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap,
 ):
@@ -424,11 +442,11 @@ def _decode_grouped_att_m_fwd(
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     BLOCK_H = 16
-    NUM_KV_SPLITS = num_kv_splits
+    MAX_KV_SPLITS = max_kv_splits
     grid = (
         batch,
         triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
-        NUM_KV_SPLITS,
+        MAX_KV_SPLITS,
     )
 
     extra_kargs = {}

@@ -447,6 +465,8 @@ def _decode_grouped_att_m_fwd(
         kv_indptr,
         kv_indices,
         att_out,
+        att_lse,
+        num_kv_splits,
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),

@@ -463,7 +483,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_DV=BLOCK_DV,
         BLOCK_N=BLOCK,
         BLOCK_H=BLOCK_H,
-        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
         num_warps=4,
         num_stages=num_stages,

@@ -476,14 +496,17 @@ def _decode_grouped_att_m_fwd(
 @triton.jit
 def _fwd_kernel_stage2(
     Mid_O,
+    Mid_O_1,
     O,
     kv_indptr,
+    num_kv_splits,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
     stride_obs,
     stride_oh,
-    NUM_KV_SPLITS: tl.constexpr,
+    MAX_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     Lv: tl.constexpr,
 ):

@@ -493,6 +516,7 @@ def _fwd_kernel_stage2(
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
         kv_indptr + cur_batch
     )
+    kv_splits = tl.load(num_kv_splits + cur_batch)
 
     offs_d = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lv

@@ -502,10 +526,12 @@ def _fwd_kernel_stage2(
     acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
 
     offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
-    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
+    offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )
 
-    for split_kv_id in range(0, NUM_KV_SPLITS):
-        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    for split_kv_id in range(0, MAX_KV_SPLITS):
         split_kv_start = kv_len_per_split * split_kv_id
         split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)

@@ -513,7 +539,7 @@ def _fwd_kernel_stage2(
             tv = tl.load(
                 Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
             )
-            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
+            tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv)
             n_e_max = tl.maximum(tlogic, e_max)
 
             old_scale = tl.exp(e_max - n_e_max)
...
@@ -533,17 +559,19 @@ def _fwd_kernel_stage2(
def
_decode_softmax_reducev_fwd
(
def
_decode_softmax_reducev_fwd
(
logits
,
logits
,
lse
,
q
,
q
,
o
,
o
,
v_buffer
,
v_buffer
,
kv_indptr
,
kv_indptr
,
num_kv_splits
,
num_kv_splits
,
max_kv_splits
,
):
):
batch
,
head_num
=
q
.
shape
[
0
],
q
.
shape
[
1
]
batch
,
head_num
=
q
.
shape
[
0
],
q
.
shape
[
1
]
Lv
=
v_buffer
.
shape
[
-
1
]
Lv
=
v_buffer
.
shape
[
-
1
]
BLOCK_DV
=
triton
.
next_power_of_2
(
Lv
)
BLOCK_DV
=
triton
.
next_power_of_2
(
Lv
)
NUM
_KV_SPLITS
=
num
_kv_splits
MAX
_KV_SPLITS
=
max
_kv_splits
extra_kargs
=
{}
extra_kargs
=
{}
if
_is_hip
:
if
_is_hip
:
...
@@ -554,14 +582,17 @@ def _decode_softmax_reducev_fwd(
...
@@ -554,14 +582,17 @@ def _decode_softmax_reducev_fwd(
grid
=
(
batch
,
head_num
)
grid
=
(
batch
,
head_num
)
_fwd_kernel_stage2
[
grid
](
_fwd_kernel_stage2
[
grid
](
logits
,
logits
,
lse
,
o
,
o
,
kv_indptr
,
kv_indptr
,
num_kv_splits
,
logits
.
stride
(
0
),
logits
.
stride
(
0
),
logits
.
stride
(
1
),
logits
.
stride
(
1
),
logits
.
stride
(
2
),
logits
.
stride
(
2
),
o
.
stride
(
0
),
o
.
stride
(
0
),
o
.
stride
(
1
),
o
.
stride
(
1
),
NUM_KV_SPLITS
=
NUM_KV_SPLITS
,
MAX_KV_SPLITS
=
MAX_KV_SPLITS
,
MIN_BLOCK_KV
=
_MIN_BLOCK_KV
,
BLOCK_DV
=
BLOCK_DV
,
BLOCK_DV
=
BLOCK_DV
,
Lv
=
Lv
,
Lv
=
Lv
,
num_warps
=
4
,
num_warps
=
4
,
...
@@ -579,6 +610,7 @@ def decode_attention_fwd_normal(
...
@@ -579,6 +610,7 @@ def decode_attention_fwd_normal(
kv_indices
,
kv_indices
,
attn_logits
,
attn_logits
,
num_kv_splits
,
num_kv_splits
,
max_kv_splits
,
sm_scale
,
sm_scale
,
logit_cap
=
0.0
,
logit_cap
=
0.0
,
):
):
...
@@ -586,14 +618,25 @@ def decode_attention_fwd_normal(
...
@@ -586,14 +618,25 @@ def decode_attention_fwd_normal(
q
,
q
,
k_buffer
,
k_buffer
,
v_buffer
,
v_buffer
,
attn_logits
,
attn_logits
[
0
],
attn_logits
[
1
],
kv_indptr
,
kv_indptr
,
kv_indices
,
kv_indices
,
num_kv_splits
,
num_kv_splits
,
max_kv_splits
,
sm_scale
,
sm_scale
,
logit_cap
,
logit_cap
,
)
)
_decode_softmax_reducev_fwd
(
attn_logits
,
q
,
o
,
v_buffer
,
kv_indptr
,
num_kv_splits
)
_decode_softmax_reducev_fwd
(
attn_logits
[
0
],
attn_logits
[
1
],
q
,
o
,
v_buffer
,
kv_indptr
,
num_kv_splits
,
max_kv_splits
,
)
def
decode_attention_fwd_grouped
(
def
decode_attention_fwd_grouped
(
...
@@ -605,6 +648,7 @@ def decode_attention_fwd_grouped(
...
@@ -605,6 +648,7 @@ def decode_attention_fwd_grouped(
kv_indices
,
kv_indices
,
attn_logits
,
attn_logits
,
num_kv_splits
,
num_kv_splits
,
max_kv_splits
,
sm_scale
,
sm_scale
,
logit_cap
=
0.0
,
logit_cap
=
0.0
,
):
):
...
@@ -612,14 +656,25 @@ def decode_attention_fwd_grouped(
...
@@ -612,14 +656,25 @@ def decode_attention_fwd_grouped(
q
,
q
,
k_buffer
,
k_buffer
,
v_buffer
,
v_buffer
,
attn_logits
,
attn_logits
[
0
],
attn_logits
[
1
],
kv_indptr
,
kv_indptr
,
kv_indices
,
kv_indices
,
num_kv_splits
,
num_kv_splits
,
max_kv_splits
,
sm_scale
,
sm_scale
,
logit_cap
,
logit_cap
,
)
)
_decode_softmax_reducev_fwd
(
attn_logits
,
q
,
o
,
v_buffer
,
kv_indptr
,
num_kv_splits
)
_decode_softmax_reducev_fwd
(
attn_logits
[
0
],
attn_logits
[
1
],
q
,
o
,
v_buffer
,
kv_indptr
,
num_kv_splits
,
max_kv_splits
,
)
def
decode_attention_fwd
(
def
decode_attention_fwd
(
...
@@ -631,12 +686,13 @@ def decode_attention_fwd(
...
@@ -631,12 +686,13 @@ def decode_attention_fwd(
kv_indices
,
kv_indices
,
attn_logits
,
attn_logits
,
num_kv_splits
,
num_kv_splits
,
max_kv_splits
,
sm_scale
,
sm_scale
,
logit_cap
=
0.0
,
logit_cap
=
0.0
,
):
):
assert
num
_kv_splits
==
attn_logits
.
shape
[
2
]
assert
max
_kv_splits
==
attn_logits
[
0
]
.
shape
[
2
]
assert
q
.
shape
[
0
]
<=
kv_indptr
.
shape
[
0
]
-
1
assert
q
.
shape
[
0
]
<=
kv_indptr
.
shape
[
0
]
-
1
assert
q
.
shape
[
0
]
<=
attn_logits
.
shape
[
0
]
assert
q
.
shape
[
0
]
<=
attn_logits
[
0
]
.
shape
[
0
]
kv_group_num
=
q
.
shape
[
1
]
//
v_buffer
.
shape
[
1
]
kv_group_num
=
q
.
shape
[
1
]
//
v_buffer
.
shape
[
1
]
...
@@ -651,6 +707,7 @@ def decode_attention_fwd(
...
@@ -651,6 +707,7 @@ def decode_attention_fwd(
kv_indices
,
kv_indices
,
attn_logits
,
attn_logits
,
num_kv_splits
,
num_kv_splits
,
max_kv_splits
,
sm_scale
,
sm_scale
,
logit_cap
,
logit_cap
,
)
)
...
@@ -665,6 +722,7 @@ def decode_attention_fwd(
...
@@ -665,6 +722,7 @@ def decode_attention_fwd(
kv_indices
,
kv_indices
,
attn_logits
,
attn_logits
,
num_kv_splits
,
num_kv_splits
,
max_kv_splits
,
sm_scale
,
sm_scale
,
logit_cap
,
logit_cap
,
)
)
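Putting the API change together: callers of the decode entry points now pass a (partial output, partial LSE) buffer pair, a per-request int32 tensor of split counts, and the static upper bound max_kv_splits. A hedged call sketch follows; the leading positional parameters not visible in this diff (q, k/v buffers, o, kv_indptr, kv_indices) are assumed from the updated tests, and all sizes, dtypes, and split values are illustrative.

```python
import torch

from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd

B, H_Q, H_KV, D, D_V = 4, 16, 16, 128, 128   # hypothetical sizes
seq_len, max_kv_splits = 64, 8
total_tokens = B * seq_len

q = torch.randn(B, H_Q, D, dtype=torch.float16, device="cuda")
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=torch.float16, device="cuda")
v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=torch.float16, device="cuda")
o = torch.empty(B, H_Q, D_V, dtype=torch.float16, device="cuda")

kv_indptr = torch.arange(0, (B + 1) * seq_len, seq_len, dtype=torch.int32, device="cuda")
kv_indices = torch.arange(total_tokens, dtype=torch.int32, device="cuda")

# Workspace: a (partial output, partial lse) pair sized by max_kv_splits, and a
# per-request split-count tensor (values at most max_kv_splits) instead of a constexpr.
attn_logits = torch.empty(B, H_Q, max_kv_splits, D_V, dtype=torch.float32, device="cuda")
attn_lse = torch.empty(B, H_Q, max_kv_splits, dtype=torch.float32, device="cuda")
num_kv_splits = torch.full((B,), 2, dtype=torch.int32, device="cuda")

decode_attention_fwd(
    q, k_buffer, v_buffer, o,
    kv_indptr, kv_indices,
    (attn_logits, attn_lse),
    num_kv_splits,
    max_kv_splits,
    1.0 / D**0.5,  # sm_scale
)
```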
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -26,6 +26,7 @@ import tqdm
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache

@@ -195,6 +196,9 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()

@@ -503,9 +507,15 @@ class CudaGraphRunner:
         if hasattr(forward_batch.spec_info, "hidden_states"):
             self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
 
+        num_kv_heads = self.num_head
+        if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
+            if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
+                num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
+            num_kv_heads,
             self.req_pool_indices,
             self.seq_lens,
             forward_batch.seq_lens_sum + (bs - raw_bs),
test/srt/test_triton_attention_kernels.py

@@ -228,7 +228,8 @@ class TestTritonAttention(unittest.TestCase):
         seq_len = 10  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
-        num_kv_splits = 8
+        max_kv_splits = 8
+        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
 
         # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")

@@ -247,7 +248,12 @@ class TestTritonAttention(unittest.TestCase):
         kv_indices = torch.arange(total_tokens, device="cuda")
 
         attn_logits = torch.empty(
-            (B, H_Q, num_kv_splits, D + 1),
+            (B, H_Q, max_kv_splits, D),
+            dtype=torch.float32,
+            device="cuda",
+        )
+        attn_lse = torch.empty(
+            (B, H_Q, max_kv_splits),
             dtype=torch.float32,
             device="cuda",
         )

@@ -259,8 +265,9 @@ class TestTritonAttention(unittest.TestCase):
             o,
             kv_indptr,
             kv_indices,
-            attn_logits,
+            (attn_logits, attn_lse),
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
         )

@@ -284,7 +291,8 @@ class TestTritonAttention(unittest.TestCase):
         seq_len = S  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
-        num_kv_splits = 8
+        max_kv_splits = 8
+        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")
 
         # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")

@@ -304,7 +312,12 @@ class TestTritonAttention(unittest.TestCase):
         kv_indices = torch.arange(total_tokens, device="cuda")
 
         attn_logits = torch.empty(
-            (B, H_Q, num_kv_splits, D_V + 1),
+            (B, H_Q, max_kv_splits, D_V),
+            dtype=torch.float32,
+            device="cuda",
+        )
+        attn_lse = torch.empty(
+            (B, H_Q, max_kv_splits),
             dtype=torch.float32,
             device="cuda",
         )

@@ -316,13 +329,19 @@ class TestTritonAttention(unittest.TestCase):
             o,
             kv_indptr,
             kv_indices,
-            attn_logits,
+            (attn_logits, attn_lse),
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
         )
 
         attn_logits1 = torch.empty(
-            (B, H_Q, num_kv_splits, D_V + 1),
+            (B, H_Q, max_kv_splits, D_V),
             dtype=torch.float32,
             device="cuda",
         )
+        attn_lse1 = torch.empty(
+            (B, H_Q, max_kv_splits, D_V),
+            dtype=torch.float32,
+            device="cuda",
+        )

@@ -334,8 +353,9 @@ class TestTritonAttention(unittest.TestCase):
             o_grouped,
             kv_indptr,
             kv_indices,
-            attn_logits1,
+            (attn_logits1, attn_lse1),
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
         )