sglang / Commits / c0e9a36c
"csrc/segment_csr.cpp2" did not exist on "bf1f1014734799cad7db901a117ac43b30c852f0"
Unverified commit c0e9a36c, authored Mar 19, 2025 by JieXin Liang, committed by GitHub Mar 18, 2025

Optimize Triton decoding kernel for dynamic workload (#4553)

Parent: 588865f0
Showing 7 changed files with 277 additions and 57 deletions (+277, -57).
python/sglang/srt/layers/attention/base_attn_backend.py (+1, -0)
python/sglang/srt/layers/attention/flashinfer_backend.py (+2, -0)
python/sglang/srt/layers/attention/flashinfer_mla_backend.py (+2, -0)
python/sglang/srt/layers/attention/triton_backend.py (+142, -15)
python/sglang/srt/layers/attention/triton_ops/decode_attention.py (+92, -34)
python/sglang/srt/model_executor/cuda_graph_runner.py (+10, -0)
test/srt/test_triton_attention_kernels.py (+28, -8)
python/sglang/srt/layers/attention/base_attn_backend.py

@@ -39,6 +39,7 @@ class AttentionBackend(ABC):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
+        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -349,6 +349,7 @@ class FlashInferAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
+        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,

@@ -1062,6 +1063,7 @@ class FlashInferMultiStepDraftBackend:
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
+                -1,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
python/sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -279,6 +279,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
+        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,

@@ -791,6 +792,7 @@ class FlashInferMLAMultiStepDraftBackend:
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
+                -1,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
python/sglang/srt/layers/attention/triton_backend.py

@@ -4,11 +4,13 @@ from typing import TYPE_CHECKING, Optional, Union

 import torch
 import triton
+import triton.language as tl

 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import get_bool_env_var, get_device_core_count

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -16,6 +18,51 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


+@triton.jit
+def get_num_kv_splits_triton(
+    num_kv_splits_ptr,
+    seq_lens_ptr,
+    bs,
+    num_head,
+    num_kv_head,
+    max_kv_splits,
+    device_core_count,
+    MAX_BS: tl.constexpr,
+):
+    # TODO: this method is tunable
+    offs_b = tl.arange(0, MAX_BS)
+    mask_b = offs_b < bs
+
+    seq_lens = tl.load(seq_lens_ptr + offs_b, mask=mask_b, other=0)
+    max_seq_len = tl.max(seq_lens)
+    seq_lens = tl.load(seq_lens_ptr + offs_b, mask=mask_b, other=max_seq_len)
+    min_seq_len = tl.min(seq_lens)
+    if max_seq_len * 8 < min_seq_len * 10:
+        min_seq_len = max_seq_len
+    max_kv_splits_1 = tl.minimum(tl.cdiv(max_seq_len, min_seq_len), max_kv_splits)
+    kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)
+
+    # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
+    ext_seq_len = tl.cast(tl.cdiv(max_seq_len, 256), tl.float32)
+    ext_device_core_count = device_core_count * tl.maximum(
+        tl.cast(tl.ceil(tl.log2(ext_seq_len)), tl.int32), 1
+    )
+    block_h, num_kv_group = 16, num_head // num_kv_head
+    if num_kv_group == 1:
+        bh_grid = bs * num_head
+    else:
+        # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
+        block_h = tl.minimum(block_h, num_kv_group)
+        bh_grid = bs * tl.cdiv(num_head, block_h)
+    max_kv_splits_2 = tl.minimum(tl.cdiv(ext_device_core_count, bh_grid), max_kv_splits)
+    kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)
+
+    num_kv_splits = tl.maximum(
+        tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
+    )
+    tl.store(num_kv_splits_ptr + offs_b, num_kv_splits, mask=mask_b)
+
+
 class TritonAttnBackend(AttentionBackend):
     def __init__(
         self,
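For intuition, here is a hedged plain-Python sketch of the heuristic the new kernel implements: each request's split count is capped both by the longest/shortest sequence-length ratio (so short requests are not over-split) and by an occupancy bound derived from the device core count (so the launch grid roughly fills the GPU). The helper name and the sample numbers are illustrative, not part of the commit.

import math
import torch

def num_kv_splits_reference(seq_lens, num_head, num_kv_head, max_kv_splits, device_core_count):
    # Assumes positive sequence lengths.
    # Bound 1: longest/shortest ratio; lengths within ~20% are treated as equal.
    max_len, min_len = int(seq_lens.max()), int(seq_lens.min())
    if max_len * 8 < min_len * 10:
        min_len = max_len
    chunk_1 = math.ceil(max_len / min(math.ceil(max_len / min_len), max_kv_splits))

    # Bound 2: enough splits to occupy the device, growing with log2 of seqlen.
    ext_seq_len = math.ceil(max_len / 256)
    ext_cores = device_core_count * max(math.ceil(math.log2(max(ext_seq_len, 1))), 1)
    kv_group = num_head // num_kv_head
    if kv_group == 1:
        bh_grid = len(seq_lens) * num_head
    else:
        bh_grid = len(seq_lens) * math.ceil(num_head / min(16, kv_group))
    chunk_2 = math.ceil(max_len / min(math.ceil(ext_cores / bh_grid), max_kv_splits))

    # Per-request split count: whichever chunk size demands more splits.
    lens = seq_lens.to(torch.float32)
    return torch.maximum(torch.ceil(lens / chunk_1), torch.ceil(lens / chunk_2)).to(torch.int32)

# A skewed decode batch, GQA with 32/8 heads, on a hypothetical 132-SM GPU:
print(num_kv_splits_reference(torch.tensor([128, 4096, 32768]), 32, 8, 16, 132))  # -> [1, 2, 16]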
@@ -64,7 +111,10 @@ class TritonAttnBackend(AttentionBackend):
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )

-        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+        self.static_kv_splits = get_bool_env_var(
+            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
+        )
+        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]

         self.forward_metadata = None
@@ -72,6 +122,30 @@ class TritonAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
+        self.device_core_count = get_device_core_count(model_runner.gpu_id)
+
+    def get_num_kv_splits(
+        self,
+        num_kv_splits: torch.Tensor,
+        seq_lens: torch.Tensor,
+        bs: int,
+        num_kv_head: int,
+    ):
+        MAX_SCHEDULE_BS = 4096
+        if self.static_kv_splits or self.device_core_count <= 0 or bs > MAX_SCHEDULE_BS:
+            num_kv_splits.fill_(self.max_kv_splits)
+            return
+
+        get_num_kv_splits_triton[(1,)](
+            num_kv_splits,
+            seq_lens,
+            bs,
+            self.num_head,
+            num_kv_head,
+            self.max_kv_splits,
+            self.device_core_count,
+            MAX_BS=MAX_SCHEDULE_BS,
+        )

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
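The dispatcher keeps a static escape hatch: with SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS set, with an unqueryable device core count, or beyond 4096 requests, every entry is simply filled with the cap. A hedged standalone sketch of that decision (the function name is illustrative):

import torch

MAX_SCHEDULE_BS = 4096  # same constant as in the hunk above

def fill_static_if_needed(num_kv_splits, max_kv_splits, static_kv_splits, device_core_count, bs):
    # Returns True when the pre-patch static behaviour is used.
    if static_kv_splits or device_core_count <= 0 or bs > MAX_SCHEDULE_BS:
        num_kv_splits.fill_(max_kv_splits)
        return True
    return False  # otherwise the Triton heuristic kernel fills the tensor per request

splits = torch.empty(8, dtype=torch.int32)
assert fill_static_if_needed(splits, 16, static_kv_splits=True, device_core_count=132, bs=8)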
@@ -100,15 +174,35 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                 bs = kv_indptr.shape[0] - 1

-            attn_logits = torch.empty(
-                (
-                    bs,
-                    self.num_head,
-                    self.num_kv_splits,
-                    self.v_head_dim + 1,
-                ),
-                dtype=torch.float32,
-                device=self.device,
-            )
+            attn_logits = [
+                torch.empty(
+                    (
+                        bs,
+                        self.num_head,
+                        self.max_kv_splits,
+                        self.v_head_dim,
+                    ),
+                    dtype=torch.float32,
+                    device=self.device,
+                ),
+                torch.empty(
+                    (
+                        bs,
+                        self.num_head,
+                        self.max_kv_splits,
+                    ),
+                    dtype=torch.float32,
+                    device=self.device,
+                ),
+            ]
+            num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
+
+            num_kv_heads = self.num_head
+            if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
+                if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
+                    num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
+            self.get_num_kv_splits(
+                num_kv_splits, forward_batch.seq_lens, bs, num_kv_heads
+            )

             qo_indptr = None
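This path, the multi-step draft backend, and the CUDA-graph replay path below all probe the KV pool for the true KV-head count, since MLA-style pools expose k_buffer as a per-layer list. A hedged standalone sketch of that probing:

def probe_num_kv_heads(token_to_kv_pool, default_num_heads: int) -> int:
    # Mirrors the probing in the hunk above: fall back to the attention-head
    # count unless the pool exposes k_buffer as a list of [tokens, kv_heads, head_dim] tensors.
    k_buffer = getattr(token_to_kv_pool, "k_buffer", None)
    if isinstance(k_buffer, list):
        return k_buffer[0].shape[1]
    return default_num_heads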
@@ -148,6 +242,7 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
             mask_indptr = mask_indptr[: bs + 1]
             max_extend_len = self.num_draft_tokens
+            num_kv_splits = None
             attn_logits = None
         elif forward_batch.forward_mode.is_draft_extend():
             kv_indices, kv_indptr, qo_indptr, custom_mask = (

@@ -160,6 +255,7 @@ class TritonAttnBackend(AttentionBackend):
             )
             mask_indptr = None
             max_extend_len = torch.max(spec_info.accept_length).item()
+            num_kv_splits = None
             attn_logits = None
         else:
             kv_indptr[1 : bs + 1] = torch.cumsum(

@@ -188,10 +284,12 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = None
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
+            num_kv_splits = None

         self.forward_metadata = (
             attn_logits,
             max_extend_len,
+            num_kv_splits,
             kv_indptr,
             kv_indices,
             qo_indptr,
@@ -202,10 +300,20 @@ class TritonAttnBackend(AttentionBackend):
     def init_cuda_graph_state(
         self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
     ):
-        self.cuda_graph_attn_logits = torch.zeros(
-            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
-            dtype=torch.float32,
-            device=self.device,
-        )
+        self.cuda_graph_attn_logits = [
+            torch.zeros(
+                (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+                dtype=torch.float32,
+                device=self.device,
+            ),
+            torch.zeros(
+                (max_bs, self.num_head, self.max_kv_splits),
+                dtype=torch.float32,
+                device=self.device,
+            ),
+        ]
+        self.cuda_graph_num_kv_splits = torch.full(
+            (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
+        )
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
@@ -255,6 +363,7 @@ class TritonAttnBackend(AttentionBackend):
             attn_logits = self.cuda_graph_attn_logits
             max_extend_len = None
+            num_kv_splits = self.cuda_graph_num_kv_splits
             qo_indptr = None
             custom_mask = None
             mask_indptr = None

@@ -285,6 +394,7 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = self.mask_indptr[: bs + 1]
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
             max_extend_len = self.num_draft_tokens
+            num_kv_splits = None
             attn_logits = None
         else:
             raise ValueError(

@@ -294,6 +404,7 @@ class TritonAttnBackend(AttentionBackend):
         self.forward_metadata = (
             attn_logits,
             max_extend_len,
+            num_kv_splits,
             kv_indptr,
             kv_indices,
             qo_indptr,

@@ -304,6 +415,7 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
+        num_kv_head: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
@@ -317,6 +429,7 @@ class TritonAttnBackend(AttentionBackend):
             # Update kv_indptr, kv_indices
             kv_indptr = self.kv_indptr
             kv_indices = self.cuda_graph_kv_indices
+            num_kv_splits = self.cuda_graph_num_kv_splits
             if spec_info is None:
                 kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
                 kv_indptr = kv_indptr[: bs + 1]

@@ -332,6 +445,7 @@ class TritonAttnBackend(AttentionBackend):
             else:
                 kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                 kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+            self.get_num_kv_splits(num_kv_splits, seq_lens, bs, num_kv_head)
         elif forward_mode.is_target_verify():
             # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
             bs = len(req_pool_indices)

@@ -391,6 +505,7 @@ class TritonAttnBackend(AttentionBackend):
         (
             _,
             max_extend_len,
+            _,
             kv_indptr,
             kv_indices,
             qo_indptr,
@@ -435,7 +550,9 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

-        attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata
+        attn_logits, _, num_kv_splits, kv_indptr, kv_indices, _, _, _ = (
+            self.forward_metadata
+        )

         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(

@@ -450,7 +567,8 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr,
             kv_indices,
             attn_logits,
-            self.num_kv_splits,
+            num_kv_splits,
+            self.max_kv_splits,
             layer.scaling,
             layer.logit_cap,
         )
@@ -493,6 +611,9 @@ class TritonMultiStepDraftBackend:
             )
         )
         self.max_context_len = self.attn_backends[0].max_context_len
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
         self.device = model_runner.device
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]

@@ -579,9 +700,15 @@ class TritonMultiStepDraftBackend:
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
+        num_kv_heads = self.num_head
+        if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
+            if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
+                num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
+
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
+                num_kv_heads,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
python/sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -37,6 +37,9 @@ logger.warning(
 )

+_MIN_BLOCK_KV = 32
+
+
 @triton.jit
 def tanh(x):
     # Tanh is just a scaled sigmoid

@@ -52,6 +55,8 @@ def _fwd_kernel_stage1(
     kv_indptr,
     kv_indices,
     Att_Out,
+    Att_Lse,
+    num_kv_splits,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,

@@ -65,7 +70,7 @@ def _fwd_kernel_stage1(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     BLOCK_N: tl.constexpr,
-    NUM_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
@@ -83,11 +88,13 @@ def _fwd_kernel_stage1(
     cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
+    kv_splits = tl.load(num_kv_splits + cur_batch)

     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
-    q = tl.load(Q + off_q, mask=mask_d, other=0.0)

-    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )
     split_kv_start = kv_len_per_split * split_kv_id
     split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
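Chunk sizes are now derived per request from the loaded kv_splits and rounded up to a multiple of MIN_BLOCK_KV, so surplus grid blocks simply see an empty range. A hedged sketch of the index arithmetic:

import math

MIN_BLOCK_KV = 32  # matches _MIN_BLOCK_KV in the patch

def split_bounds(seq_len: int, kv_splits: int, split_kv_id: int):
    # Ceil-divide the sequence across the requested splits, then round the
    # chunk up to a multiple of MIN_BLOCK_KV so KV blocks stay aligned.
    chunk = math.ceil(math.ceil(seq_len / kv_splits) / MIN_BLOCK_KV) * MIN_BLOCK_KV
    start = chunk * split_kv_id
    return start, min(start + chunk, seq_len)  # empty when start >= end

# seq_len=100 with 2 splits -> 64-token chunks: [0, 64), [64, 100); split id 2
# (launched anyway because the grid is sized by max_kv_splits) gets (128, 100), i.e. empty.
print(split_bounds(100, 2, 0), split_bounds(100, 2, 1), split_bounds(100, 2, 2))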
@@ -96,6 +103,7 @@ def _fwd_kernel_stage1(
     acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

     if split_kv_end > split_kv_start:
+        q = tl.load(Q + off_q, mask=mask_d, other=0.0)
         for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_loc = tl.load(

@@ -158,11 +166,10 @@ def _fwd_kernel_stage1(
             cur_batch * stride_mid_ob
             + cur_head * stride_mid_oh
             + split_kv_id * stride_mid_os
-            + Lv
-        )
+        ) // Lv

         tl.store(
-            Att_Out + offs_mid_o_1,
+            Att_Lse + offs_mid_o_1,
             e_max + tl.log(e_sum),
         )
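With the log-sum-exp moved out of Att_Out into the dedicated Att_Lse tensor, the old trick of storing it at column Lv disappears; dividing the logits offset by Lv yields the matching flat index into the per-split LSE tensor. A hedged sanity check of that stride arithmetic:

import torch

B, H, SPLITS, Lv = 2, 4, 8, 128
att_out = torch.empty(B, H, SPLITS, Lv)  # one Lv-wide partial output per split
att_lse = torch.empty(B, H, SPLITS)      # one scalar per split (the new tensor)

b, h, s = 1, 2, 3
off_logits = b * att_out.stride(0) + h * att_out.stride(1) + s * att_out.stride(2)
off_lse = b * att_lse.stride(0) + h * att_lse.stride(1) + s * att_lse.stride(2)
assert off_logits // Lv == off_lse  # why the kernel computes offs_mid_o_1 = (...) // Lv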
@@ -172,9 +179,11 @@ def _decode_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
+    att_lse,
     kv_indptr,
     kv_indices,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap,
 ):

@@ -182,13 +191,13 @@ def _decode_att_m_fwd(
     # [TODO] work around SGPR limit on MI3xx
     if _is_hip:
         BLOCK = 8
-    NUM_KV_SPLITS = num_kv_splits
+    MAX_KV_SPLITS = max_kv_splits
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]

     batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
-    grid = (batch, head_num, NUM_KV_SPLITS)
+    grid = (batch, head_num, MAX_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[1]

     if kv_group_num == 1:
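The launch grid is now sized by the static MAX_KV_SPLITS while the per-request tensor only shrinks the work each block does; a fixed launch shape is what lets the kernel be captured once and replayed in a CUDA graph while the split counts vary between replays. A hedged illustration of the shapes involved:

import torch

batch, head_num, max_kv_splits = 4, 32, 8
grid = (batch, head_num, max_kv_splits)  # fixed across CUDA-graph replays

# num_kv_splits is a device tensor the heuristic rewrites in place before each
# replay; blocks with split_kv_id beyond a request's split count compute an
# empty [start, end) range and contribute nothing.
num_kv_splits = torch.full((batch,), max_kv_splits, dtype=torch.int32)
num_kv_splits[:2] = 2  # e.g. two short requests need only 2 splits each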
@@ -209,6 +218,8 @@ def _decode_att_m_fwd(
         kv_indptr,
         kv_indices,
         att_out,
+        att_lse,
+        num_kv_splits,
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),

@@ -222,7 +233,7 @@ def _decode_att_m_fwd(
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DV=BLOCK_DV,
         BLOCK_N=BLOCK,
-        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=2,

@@ -240,6 +251,8 @@ def _fwd_grouped_kernel_stage1(
     kv_indptr,
     kv_indices,
     Att_Out,
+    Att_Lse,
+    num_kv_splits,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,

@@ -256,7 +269,7 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_DV: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
-    NUM_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
@@ -281,9 +294,9 @@ def _fwd_grouped_kernel_stage1(
     cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
+    kv_splits = tl.load(num_kv_splits + cur_batch)

     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
-    q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)

     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)

@@ -291,11 +304,10 @@ def _fwd_grouped_kernel_stage1(
         off_qpe = (
             cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
         )
-        qpe = tl.load(
-            Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
-        )

-    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )
     split_kv_start = kv_len_per_split * split_kv_id
     split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
@@ -304,6 +316,11 @@ def _fwd_grouped_kernel_stage1(
     acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

     if split_kv_end > split_kv_start:
+        q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(
+                Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
+            )
         for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_loc = tl.load(

@@ -380,11 +397,10 @@ def _fwd_grouped_kernel_stage1(
             cur_batch * stride_mid_ob
             + cur_head * stride_mid_oh
             + split_kv_id * stride_mid_os
-            + Lv
-        )
+        ) // Lv

         tl.store(
-            Att_Out + offs_mid_o_1,
+            Att_Lse + offs_mid_o_1,
             e_max + tl.log(e_sum),
             mask=mask_h,
         )
@@ -395,9 +411,11 @@ def _decode_grouped_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
+    att_lse,
     kv_indptr,
     kv_indices,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap,
 ):

@@ -424,11 +442,11 @@ def _decode_grouped_att_m_fwd(
     kv_group_num = q.shape[1] // k_buffer.shape[1]
     BLOCK_H = 16
-    NUM_KV_SPLITS = num_kv_splits
+    MAX_KV_SPLITS = max_kv_splits
     grid = (
         batch,
         triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
-        NUM_KV_SPLITS,
+        MAX_KV_SPLITS,
     )

     extra_kargs = {}

@@ -447,6 +465,8 @@ def _decode_grouped_att_m_fwd(
         kv_indptr,
         kv_indices,
         att_out,
+        att_lse,
+        num_kv_splits,
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),

@@ -463,7 +483,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_DV=BLOCK_DV,
         BLOCK_N=BLOCK,
         BLOCK_H=BLOCK_H,
-        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
         num_warps=4,
         num_stages=num_stages,
@@ -476,14 +496,17 @@ def _decode_grouped_att_m_fwd(
 @triton.jit
 def _fwd_kernel_stage2(
     Mid_O,
+    Mid_O_1,
     O,
     kv_indptr,
+    num_kv_splits,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
     stride_obs,
     stride_oh,
-    NUM_KV_SPLITS: tl.constexpr,
+    MAX_KV_SPLITS: tl.constexpr,
+    MIN_BLOCK_KV: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     Lv: tl.constexpr,
 ):

@@ -493,6 +516,7 @@ def _fwd_kernel_stage2(
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
         kv_indptr + cur_batch
     )
+    kv_splits = tl.load(num_kv_splits + cur_batch)

     offs_d = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lv

@@ -502,10 +526,12 @@ def _fwd_kernel_stage2(
     acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

     offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
-    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
+    offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv
+    kv_len_per_split = (
+        tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV
+    )

-    for split_kv_id in range(0, NUM_KV_SPLITS):
-        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    for split_kv_id in range(0, MAX_KV_SPLITS):
         split_kv_start = kv_len_per_split * split_kv_id
         split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
@@ -513,7 +539,7 @@ def _fwd_kernel_stage2(
             tv = tl.load(
                 Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
             )
-            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
+            tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv)
             n_e_max = tl.maximum(tlogic, e_max)

             old_scale = tl.exp(e_max - n_e_max)
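Stage 2 still merges the per-split partial results with a running log-sum-exp; only the operands now come from two tensors (Mid_O and the new Mid_O_1) and the loop runs over all MAX_KV_SPLITS slots, of which the trailing ones for a short request are empty. A hedged single-row sketch of that merge:

import torch

def merge_splits(partial_out: torch.Tensor, partial_lse: torch.Tensor) -> torch.Tensor:
    # partial_out: [num_splits, Lv] per-split partial outputs (Mid_O rows);
    # partial_lse: [num_splits] per-split log-sum-exp values (Mid_O_1 entries).
    e_max = torch.tensor(float("-inf"))
    e_sum, acc = torch.tensor(0.0), torch.zeros(partial_out.shape[-1])
    for tv, tlogic in zip(partial_out, partial_lse):
        n_e_max = torch.maximum(tlogic, e_max)   # new running max
        old_scale = torch.exp(e_max - n_e_max)   # rescale the previous accumulator
        exp_logic = torch.exp(tlogic - n_e_max)  # weight of this split
        acc = acc * old_scale + exp_logic * tv
        e_sum = e_sum * old_scale + exp_logic
        e_max = n_e_max
    return acc / e_sum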
@@ -533,17 +559,19 @@ def _decode_softmax_reducev_fwd(
 def _decode_softmax_reducev_fwd(
     logits,
+    lse,
     q,
     o,
     v_buffer,
     kv_indptr,
     num_kv_splits,
+    max_kv_splits,
 ):
     batch, head_num = q.shape[0], q.shape[1]
     Lv = v_buffer.shape[-1]
     BLOCK_DV = triton.next_power_of_2(Lv)

-    NUM_KV_SPLITS = num_kv_splits
+    MAX_KV_SPLITS = max_kv_splits

     extra_kargs = {}
     if _is_hip:

@@ -554,14 +582,17 @@ def _decode_softmax_reducev_fwd(
     grid = (batch, head_num)
     _fwd_kernel_stage2[grid](
         logits,
+        lse,
         o,
         kv_indptr,
+        num_kv_splits,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
         o.stride(0),
         o.stride(1),
-        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        MAX_KV_SPLITS=MAX_KV_SPLITS,
+        MIN_BLOCK_KV=_MIN_BLOCK_KV,
         BLOCK_DV=BLOCK_DV,
         Lv=Lv,
         num_warps=4,
@@ -579,6 +610,7 @@ def decode_attention_fwd_normal(
     kv_indices,
     attn_logits,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):

@@ -586,14 +618,25 @@ def decode_attention_fwd_normal(
         q,
         k_buffer,
         v_buffer,
-        attn_logits,
+        attn_logits[0],
+        attn_logits[1],
         kv_indptr,
         kv_indices,
         num_kv_splits,
+        max_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
+    _decode_softmax_reducev_fwd(
+        attn_logits[0],
+        attn_logits[1],
+        q,
+        o,
+        v_buffer,
+        kv_indptr,
+        num_kv_splits,
+        max_kv_splits,
+    )


 def decode_attention_fwd_grouped(

@@ -605,6 +648,7 @@ def decode_attention_fwd_grouped(
     kv_indices,
     attn_logits,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):

@@ -612,14 +656,25 @@ def decode_attention_fwd_grouped(
         q,
         k_buffer,
         v_buffer,
-        attn_logits,
+        attn_logits[0],
+        attn_logits[1],
         kv_indptr,
         kv_indices,
         num_kv_splits,
+        max_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
+    _decode_softmax_reducev_fwd(
+        attn_logits[0],
+        attn_logits[1],
+        q,
+        o,
+        v_buffer,
+        kv_indptr,
+        num_kv_splits,
+        max_kv_splits,
+    )


 def decode_attention_fwd(
@@ -631,12 +686,13 @@ def decode_attention_fwd(
     kv_indices,
     attn_logits,
     num_kv_splits,
+    max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
-    assert num_kv_splits == attn_logits.shape[2]
+    assert max_kv_splits == attn_logits[0].shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1
-    assert q.shape[0] <= attn_logits.shape[0]
+    assert q.shape[0] <= attn_logits[0].shape[0]

     kv_group_num = q.shape[1] // v_buffer.shape[1]

@@ -651,6 +707,7 @@ def decode_attention_fwd(
         kv_indices,
         attn_logits,
         num_kv_splits,
+        max_kv_splits,
         sm_scale,
         logit_cap,
     )

@@ -665,6 +722,7 @@ def decode_attention_fwd(
         kv_indices,
         attn_logits,
         num_kv_splits,
+        max_kv_splits,
         sm_scale,
         logit_cap,
     )
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -26,6 +26,7 @@ import tqdm
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache

@@ -195,6 +196,9 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -503,9 +507,15 @@ class CudaGraphRunner:
         if hasattr(forward_batch.spec_info, "hidden_states"):
             self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states

+        num_kv_heads = self.num_head
+        if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
+            if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
+                num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
+            num_kv_heads,
             self.req_pool_indices,
             self.seq_lens,
             forward_batch.seq_lens_sum + (bs - raw_bs),
test/srt/test_triton_attention_kernels.py

@@ -228,7 +228,8 @@ class TestTritonAttention(unittest.TestCase):
         seq_len = 10  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
-        num_kv_splits = 8
+        max_kv_splits = 8
+        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")

         # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")

@@ -247,7 +248,12 @@ class TestTritonAttention(unittest.TestCase):
         kv_indices = torch.arange(total_tokens, device="cuda")

         attn_logits = torch.empty(
-            (B, H_Q, num_kv_splits, D + 1),
+            (B, H_Q, max_kv_splits, D),
+            dtype=torch.float32,
+            device="cuda",
+        )
+        attn_lse = torch.empty(
+            (B, H_Q, max_kv_splits),
             dtype=torch.float32,
             device="cuda",
         )

@@ -259,8 +265,9 @@ class TestTritonAttention(unittest.TestCase):
             o,
             kv_indptr,
             kv_indices,
-            attn_logits,
+            (attn_logits, attn_lse),
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
         )
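The updated tests double as a reference for the new calling convention: callers now allocate a (logits, lse) pair sized by the static cap and pass both the per-request split tensor and the cap. A hedged sketch of that allocation pattern (shapes as in the test above; the actual call needs CUDA and the sglang module):

import torch

B, H_Q, D_V = 4, 32, 128
max_kv_splits = 8                                       # static cap, sizes the buffers
num_kv_splits = torch.full((B,), 4, dtype=torch.int32)  # per-request split counts
attn_logits = torch.empty(B, H_Q, max_kv_splits, D_V, dtype=torch.float32)
attn_lse = torch.empty(B, H_Q, max_kv_splits, dtype=torch.float32)
# decode_attention_fwd(q, k_buffer, v_buffer, o, kv_indptr, kv_indices,
#                      (attn_logits, attn_lse), num_kv_splits, max_kv_splits, sm_scale)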
@@ -284,7 +291,8 @@ class TestTritonAttention(unittest.TestCase):
         seq_len = S  # This represents the number of tokens already in the sequence
         total_tokens = B * seq_len
         sm_scale = 1.0 / (D**0.5)
-        num_kv_splits = 8
+        max_kv_splits = 8
+        num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda")

         # q represents the new token being generated, one per batch
         q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda")

@@ -304,7 +312,12 @@ class TestTritonAttention(unittest.TestCase):
         kv_indices = torch.arange(total_tokens, device="cuda")

         attn_logits = torch.empty(
-            (B, H_Q, num_kv_splits, D_V + 1),
+            (B, H_Q, max_kv_splits, D_V),
+            dtype=torch.float32,
+            device="cuda",
+        )
+        attn_lse = torch.empty(
+            (B, H_Q, max_kv_splits),
             dtype=torch.float32,
             device="cuda",
         )

@@ -316,13 +329,19 @@ class TestTritonAttention(unittest.TestCase):
             o,
             kv_indptr,
             kv_indices,
-            attn_logits,
+            (attn_logits, attn_lse),
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
         )

         attn_logits1 = torch.empty(
-            (B, H_Q, num_kv_splits, D_V + 1),
+            (B, H_Q, max_kv_splits, D_V),
+            dtype=torch.float32,
+            device="cuda",
+        )
+        attn_lse1 = torch.empty(
+            (B, H_Q, max_kv_splits, D_V),
             dtype=torch.float32,
             device="cuda",
         )
@@ -334,8 +353,9 @@ class TestTritonAttention(unittest.TestCase):
             o_grouped,
             kv_indptr,
             kv_indices,
-            attn_logits1,
+            (attn_logits1, attn_lse1),
             num_kv_splits,
+            max_kv_splits,
             sm_scale,
         )