change / sglang · Commits

Commit 9e93ef3f (Unverified)
Authored Mar 20, 2025 by JieXin Liang; committed via GitHub on Mar 20, 2025. Parent: fad86a68.

[fix] fix illegal mem access and clean up triton attention backend (#4571)

Showing 7 changed files with 124 additions and 125 deletions (+124 / -125). In summary: the combined attn_logits buffer (previously a [logits, lse] pair) is split into separate attn_logits and attn_lse tensors, the Triton backend's per-batch state is wrapped in a ForwardMetadata dataclass, the num_kv_heads argument is removed from the init_forward_metadata_replay_cuda_graph signatures across backends, and the KV-split scheduling kernel now writes one split count per query token instead of one per batch slot.
Changed files:

python/sglang/srt/layers/attention/base_attn_backend.py            +0    -1
python/sglang/srt/layers/attention/flashinfer_backend.py           +0    -2
python/sglang/srt/layers/attention/flashinfer_mla_backend.py       +0    -2
python/sglang/srt/layers/attention/triton_backend.py               +103  -97
python/sglang/srt/layers/attention/triton_ops/decode_attention.py  +15   -10
python/sglang/srt/model_executor/cuda_graph_runner.py               +0   -10
test/srt/test_triton_attention_kernels.py                           +6   -3
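
The central API change is visible in the allocations: what used to be one list holding both the per-split partial outputs and their log-sum-exp is now two named tensors that flow through the kernels as separate arguments. A minimal before/after sketch (sizes are placeholders, not from the commit):

    import torch

    bs, num_head, max_kv_splits, v_head_dim = 2, 8, 8, 64  # placeholder sizes

    # Before: one list carrying both per-split outputs and their log-sum-exp.
    attn_logits_old = [
        torch.empty(bs, num_head, max_kv_splits, v_head_dim, dtype=torch.float32),
        torch.empty(bs, num_head, max_kv_splits, dtype=torch.float32),
    ]

    # After: two named tensors, passed to the kernels separately.
    attn_logits = torch.empty(bs, num_head, max_kv_splits, v_head_dim, dtype=torch.float32)
    attn_lse = torch.empty(bs, num_head, max_kv_splits, dtype=torch.float32)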
python/sglang/srt/layers/attention/base_attn_backend.py

@@ -39,7 +39,6 @@ class AttentionBackend(ABC):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
-        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -349,7 +349,6 @@ class FlashInferAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
-        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,

@@ -1063,7 +1062,6 @@ class FlashInferMultiStepDraftBackend:
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
-                -1,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
python/sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -279,7 +279,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
-        num_kv_heads: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,

@@ -792,7 +791,6 @@ class FlashInferMLAMultiStepDraftBackend:
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
-                -1,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
python/sglang/srt/layers/attention/triton_backend.py

 from __future__ import annotations

+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union

 import torch
@@ -22,20 +23,21 @@ if TYPE_CHECKING:
 def get_num_kv_splits_triton(
     num_kv_splits_ptr,
     seq_lens_ptr,
-    bs,
+    num_seq,
+    num_group,
     num_head,
     num_kv_head,
     max_kv_splits,
     device_core_count,
-    MAX_BS: tl.constexpr,
+    MAX_NUM_SEQ: tl.constexpr,
 ):
-    # TODO: this method is tunable
-    offs_b = tl.arange(0, MAX_BS)
-    mask_b = offs_b < bs
+    # TODO: this method is tunable, we need more online serving data to tune it
+    offs_seq = tl.arange(0, MAX_NUM_SEQ)
+    mask_seq = offs_seq < num_seq

-    seq_lens = tl.load(seq_lens_ptr + offs_b, mask=mask_b, other=0)
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=0)
     max_seq_len = tl.max(seq_lens)
-    seq_lens = tl.load(seq_lens_ptr + offs_b, mask=mask_b, other=max_seq_len)
+    seq_lens = tl.load(seq_lens_ptr + offs_seq, mask=mask_seq, other=max_seq_len)
     min_seq_len = tl.min(seq_lens)
     if max_seq_len * 8 < min_seq_len * 10:
         min_seq_len = max_seq_len
@@ -43,24 +45,43 @@ def get_num_kv_splits_triton(
     kv_chunk_size_1 = tl.cdiv(max_seq_len, max_kv_splits_1)

     # NOTE: this is a hack to let num_kv_split grows up with seqlen gradually
-    ext_seq_len = tl.cast(tl.cdiv(max_seq_len, 256), tl.float32)
-    ext_device_core_count = device_core_count * tl.maximum(
-        tl.cast(tl.ceil(tl.log2(ext_seq_len)), tl.int32), 1
-    )
+    ext_seq_len = tl.cast(max_seq_len, tl.float32) / 64.0
+    ext_device_core_count = tl.cast(
+        device_core_count * tl.maximum(tl.log2(ext_seq_len), 1.0), tl.int32
+    )
     block_h, num_kv_group = 16, num_head // num_kv_head
     if num_kv_group == 1:
-        bh_grid = bs * num_head
+        token_grid = num_seq * num_group * num_head
     else:
         # from triton_ops/decode_attention.py:_decode_grouped_att_m_fwd
         block_h = tl.minimum(block_h, num_kv_group)
-        bh_grid = bs * tl.cdiv(num_head, block_h)
-    max_kv_splits_2 = tl.minimum(tl.cdiv(ext_device_core_count, bh_grid), max_kv_splits)
+        token_grid = num_seq * num_group * tl.cdiv(num_head, block_h)
+    max_kv_splits_2 = tl.minimum(tl.cdiv(ext_device_core_count, token_grid), max_kv_splits)
     kv_chunk_size_2 = tl.cdiv(max_seq_len, max_kv_splits_2)

     num_kv_splits = tl.maximum(
         tl.cdiv(seq_lens, kv_chunk_size_1), tl.cdiv(seq_lens, kv_chunk_size_2)
     )
-    tl.store(num_kv_splits_ptr + offs_b, num_kv_splits, mask=mask_b)
+
+    offs_token = offs_seq * num_group
+    mask_token = offs_token < num_seq * num_group
+    for i in range(0, num_group):
+        tl.store(num_kv_splits_ptr + i + offs_token, num_kv_splits, mask=mask_token)
+
+
+@dataclass
+class ForwardMetadata:
+    attn_logits: torch.Tensor
+    attn_lse: torch.Tensor
+    max_extend_len: int
+    num_kv_splits: torch.Tensor
+    kv_indptr: torch.Tensor
+    kv_indices: torch.Tensor
+    qo_indptr: torch.Tensor
+    custom_mask: torch.Tensor
+    mask_indptr: torch.Tensor


 class TritonAttnBackend(AttentionBackend):
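
The rewritten kernel schedules KV splits per token rather than per batch slot: under speculative decoding there are num_token = num_seq * num_group query tokens but only num_seq sequence lengths, and each draft token inherits its sequence's split count through the tl.store(num_kv_splits_ptr + i + offs_token, ...) loop above. A rough pure-PyTorch sketch of the same store layout (hypothetical helper, not part of the commit):

    import torch

    def splits_per_token(splits_per_seq: torch.Tensor, num_group: int) -> torch.Tensor:
        # splits_per_seq: (num_seq,) int32, one KV-split count per sequence.
        # Returns (num_seq * num_group,): each sequence's count repeated for its
        # draft tokens, token index = seq * num_group + i, matching the kernel.
        return splits_per_seq.repeat_interleave(num_group)

    print(splits_per_token(torch.tensor([4, 8, 2], dtype=torch.int32), num_group=2))
    # tensor([4, 4, 8, 8, 2, 2], dtype=torch.int32)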
@@ -110,6 +131,9 @@ class TritonAttnBackend(AttentionBackend):
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
+        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )

         self.static_kv_splits = get_bool_env_var(
             "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
@@ -117,7 +141,7 @@ class TritonAttnBackend(AttentionBackend):
         self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]

-        self.forward_metadata = None
+        self.forward_metadata: ForwardMetadata = None

         self.max_context_len = model_runner.model_config.context_len
@@ -128,23 +152,33 @@ class TritonAttnBackend(AttentionBackend):
         self,
         num_kv_splits: torch.Tensor,
         seq_lens: torch.Tensor,
-        bs: int,
-        num_kv_head: int,
     ):
-        MAX_SCHEDULE_BS = 4096
-        if self.static_kv_splits or self.device_core_count <= 0 or bs > MAX_SCHEDULE_BS:
+        num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]
+        num_group = num_token // num_seq
+        assert (
+            num_group * num_seq == num_token
+        ), f"num_seq({num_seq}), num_token({num_token}), something goes wrong!"
+        if self.static_kv_splits or self.device_core_count <= 0:
             num_kv_splits.fill_(self.max_kv_splits)
             return
+
+        if num_seq < 256:
+            SCHEDULE_SEQ = 256
+        else:
+            SCHEDULE_SEQ = triton.next_power_of_2(num_seq)
+
         get_num_kv_splits_triton[(1,)](
             num_kv_splits,
             seq_lens,
-            bs,
+            num_seq,
+            num_group,
             self.num_head,
-            num_kv_head,
+            self.num_kv_head,
             self.max_kv_splits,
             self.device_core_count,
-            MAX_BS=MAX_SCHEDULE_BS,
+            MAX_NUM_SEQ=SCHEDULE_SEQ,
         )

     def init_forward_metadata(self, forward_batch: ForwardBatch):
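
To make the new padding logic concrete (numbers made up, not from the commit): 300 decode sequences with 4 draft tokens each give num_token = 1200 and num_group = 4, and the single-program kernel is sized to the next power of two at or above num_seq:

    import triton

    num_token, num_seq = 1200, 300  # made-up: 300 sequences x 4 draft tokens
    num_group = num_token // num_seq  # 4
    SCHEDULE_SEQ = 256 if num_seq < 256 else triton.next_power_of_2(num_seq)  # 512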
@@ -174,36 +208,19 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                 bs = kv_indptr.shape[0] - 1

-            attn_logits = [
-                torch.empty(
-                    (bs, self.num_head, self.max_kv_splits, self.v_head_dim),
-                    dtype=torch.float32,
-                    device=self.device,
-                ),
-                torch.empty(
-                    (bs, self.num_head, self.max_kv_splits),
-                    dtype=torch.float32,
-                    device=self.device,
-                ),
-            ]
+            attn_logits = torch.empty(
+                (
+                    bs,
+                    self.num_head,
+                    self.max_kv_splits,
+                    self.v_head_dim,
+                ),
+                dtype=torch.float32,
+                device=self.device,
+            )
+            attn_lse = torch.empty(
+                (
+                    bs,
+                    self.num_head,
+                    self.max_kv_splits,
+                ),
+                dtype=torch.float32,
+                device=self.device,
+            )

             num_kv_splits = torch.empty((bs,), dtype=torch.int32, device=self.device)
-            num_kv_heads = self.num_head
-            if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
-                if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
-                    num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
-            self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens, bs, num_kv_heads)
+            self.get_num_kv_splits(num_kv_splits, forward_batch.seq_lens)

             qo_indptr = None
             custom_mask = None
@@ -244,6 +261,7 @@ class TritonAttnBackend(AttentionBackend):
             max_extend_len = self.num_draft_tokens
             num_kv_splits = None
             attn_logits = None
+            attn_lse = None
         elif forward_batch.forward_mode.is_draft_extend():
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
@@ -254,9 +272,13 @@ class TritonAttnBackend(AttentionBackend):
                 )
             )
             mask_indptr = None
+            # TODO(FIXME): This will trigger an invalid Eagle tree when using
+            # `max(spec_info.accept_length_cpu)`.
+            # It might have been forgotten to update somewhere.
             max_extend_len = torch.max(spec_info.accept_length).item()
             num_kv_splits = None
             attn_logits = None
+            attn_lse = None
         else:
             kv_indptr[1 : bs + 1] = torch.cumsum(
                 forward_batch.extend_prefix_lens, dim=0
@@ -283,11 +305,13 @@ class TritonAttnBackend(AttentionBackend):
             custom_mask = None
             mask_indptr = None
             attn_logits = None
+            attn_lse = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
             num_kv_splits = None

-        self.forward_metadata = (
+        self.forward_metadata = ForwardMetadata(
             attn_logits,
+            attn_lse,
             max_extend_len,
             num_kv_splits,
             kv_indptr,
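
Replacing the bare tuple with ForwardMetadata is what lets attn_lse be added without touching every unpacking site: the old code destructured the metadata positionally with `_` placeholders, which silently misaligns when a field is inserted. A toy illustration (stand-in class with a subset of the real fields, not the commit's code):

    from dataclasses import dataclass

    @dataclass
    class Meta:  # stand-in with a subset of ForwardMetadata's fields
        attn_logits: object
        attn_lse: object
        max_extend_len: int

    meta = Meta(attn_logits="logits", attn_lse="lse", max_extend_len=7)
    # Named access stays correct however many fields are added, whereas
    # `_, max_extend_len, _ = meta_tuple` shifts when a field is inserted.
    assert meta.max_extend_len == 7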
@@ -300,18 +324,16 @@ class TritonAttnBackend(AttentionBackend):
     def init_cuda_graph_state(
         self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
     ):
-        self.cuda_graph_attn_logits = [
-            torch.zeros(
-                (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
-                dtype=torch.float32,
-                device=self.device,
-            ),
-            torch.zeros(
-                (max_bs, self.num_head, self.max_kv_splits),
-                dtype=torch.float32,
-                device=self.device,
-            ),
-        ]
+        self.cuda_graph_attn_logits = torch.zeros(
+            (max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
+            dtype=torch.float32,
+            device=self.device,
+        )
+        self.cuda_graph_attn_lse = torch.zeros(
+            (max_bs, self.num_head, self.max_kv_splits),
+            dtype=torch.float32,
+            device=self.device,
+        )
         self.cuda_graph_num_kv_splits = torch.full(
             (max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
         )
@@ -362,6 +384,7 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

             attn_logits = self.cuda_graph_attn_logits
+            attn_lse = self.cuda_graph_attn_lse
             max_extend_len = None
             num_kv_splits = self.cuda_graph_num_kv_splits
             qo_indptr = None
@@ -396,13 +419,15 @@ class TritonAttnBackend(AttentionBackend):
             max_extend_len = self.num_draft_tokens
             num_kv_splits = None
             attn_logits = None
+            attn_lse = None
         else:
             raise ValueError(
                 f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
             )

-        self.forward_metadata = (
+        self.forward_metadata = ForwardMetadata(
             attn_logits,
+            attn_lse,
             max_extend_len,
             num_kv_splits,
             kv_indptr,
@@ -415,7 +440,6 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata_replay_cuda_graph(
         self,
         bs: int,
-        num_kv_head: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
@@ -442,10 +466,12 @@ class TritonAttnBackend(AttentionBackend):
                     kv_indices,
                     self.req_to_token.stride(0),
                 )
+                num_token = bs
             else:
                 kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
                 kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+                num_token = spec_info.kv_indptr.shape[0] - 1

-            self.get_num_kv_splits(num_kv_splits, seq_lens, bs, num_kv_head)
+            self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
         elif forward_mode.is_target_verify():
             # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
             bs = len(req_pool_indices)
@@ -502,17 +528,6 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )

-        (
-            _,
-            max_extend_len,
-            _,
-            kv_indptr,
-            kv_indices,
-            qo_indptr,
-            custom_mask,
-            mask_indptr,
-        ) = self.forward_metadata
-
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),

@@ -520,12 +535,12 @@ class TritonAttnBackend(AttentionBackend):
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            custom_mask,
-            mask_indptr,
-            max_extend_len,
+            self.forward_metadata.qo_indptr,
+            self.forward_metadata.kv_indptr,
+            self.forward_metadata.kv_indices,
+            self.forward_metadata.custom_mask,
+            self.forward_metadata.mask_indptr,
+            self.forward_metadata.max_extend_len,
             layer.scaling,
             layer.logit_cap,
         )
@@ -550,10 +565,6 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

-        attn_logits, _, num_kv_splits, kv_indptr, kv_indices, _, _, _ = (
-            self.forward_metadata
-        )
-
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
                 layer, forward_batch.out_cache_loc, k, v

@@ -564,10 +575,11 @@ class TritonAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            kv_indptr,
-            kv_indices,
-            attn_logits,
-            num_kv_splits,
+            self.forward_metadata.kv_indptr,
+            self.forward_metadata.kv_indices,
+            self.forward_metadata.attn_logits,
+            self.forward_metadata.attn_lse,
+            self.forward_metadata.num_kv_splits,
             self.max_kv_splits,
             layer.scaling,
             layer.logit_cap,
@@ -700,15 +712,9 @@ class TritonMultiStepDraftBackend:
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
-        num_kv_heads = self.num_head
-        if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
-            if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
-                num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
-
         def call_fn(i, forward_batch):
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
-                num_kv_heads,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 seq_lens_sum=-1,
python/sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -609,6 +609,7 @@ def decode_attention_fwd_normal(
     kv_indptr,
     kv_indices,
     attn_logits,
+    attn_lse,
     num_kv_splits,
     max_kv_splits,
     sm_scale,

@@ -618,8 +619,8 @@ def decode_attention_fwd_normal(
         q,
         k_buffer,
         v_buffer,
-        attn_logits[0],
-        attn_logits[1],
+        attn_logits,
+        attn_lse,
         kv_indptr,
         kv_indices,
         num_kv_splits,

@@ -628,8 +629,8 @@ def decode_attention_fwd_normal(
         logit_cap,
     )
     _decode_softmax_reducev_fwd(
-        attn_logits[0],
-        attn_logits[1],
+        attn_logits,
+        attn_lse,
         q,
         o,
         v_buffer,

@@ -647,6 +648,7 @@ def decode_attention_fwd_grouped(
     kv_indptr,
     kv_indices,
     attn_logits,
+    attn_lse,
     num_kv_splits,
     max_kv_splits,
     sm_scale,

@@ -656,8 +658,8 @@ def decode_attention_fwd_grouped(
         q,
         k_buffer,
         v_buffer,
-        attn_logits[0],
-        attn_logits[1],
+        attn_logits,
+        attn_lse,
         kv_indptr,
         kv_indices,
         num_kv_splits,

@@ -666,8 +668,8 @@ def decode_attention_fwd_grouped(
         logit_cap,
     )
     _decode_softmax_reducev_fwd(
-        attn_logits[0],
-        attn_logits[1],
+        attn_logits,
+        attn_lse,
         q,
         o,
         v_buffer,

@@ -685,14 +687,15 @@ def decode_attention_fwd(
     kv_indptr,
     kv_indices,
     attn_logits,
+    attn_lse,
     num_kv_splits,
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
-    assert max_kv_splits == attn_logits[0].shape[2]
+    assert max_kv_splits == attn_logits.shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1
-    assert q.shape[0] <= attn_logits[0].shape[0]
+    assert q.shape[0] <= attn_logits.shape[0]
     kv_group_num = q.shape[1] // v_buffer.shape[1]

@@ -706,6 +709,7 @@ def decode_attention_fwd(
         kv_indptr,
         kv_indices,
         attn_logits,
+        attn_lse,
         num_kv_splits,
         max_kv_splits,
         sm_scale,

@@ -721,6 +725,7 @@ def decode_attention_fwd(
         kv_indptr,
         kv_indices,
         attn_logits,
+        attn_lse,
         num_kv_splits,
         max_kv_splits,
         sm_scale,
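
Callers now allocate the partial-output and log-sum-exp buffers separately and pass both, in the argument order the updated tests use. A minimal sketch of the new calling convention (sizes made up; a single-KV-head MHA case, mirroring the setup in test_triton_attention_kernels.py):

    import torch
    from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd

    bs, num_head, head_dim, seq_len, max_kv_splits = 2, 8, 64, 16, 8
    dev = "cuda"

    q = torch.randn(bs, num_head, head_dim, device=dev)
    o = torch.empty_like(q)
    k_buffer = torch.randn(bs * seq_len, num_head, head_dim, device=dev)
    v_buffer = torch.randn(bs * seq_len, num_head, head_dim, device=dev)
    kv_indptr = torch.arange(0, (bs + 1) * seq_len, seq_len, dtype=torch.int32, device=dev)
    kv_indices = torch.arange(bs * seq_len, dtype=torch.int32, device=dev)
    num_kv_splits = torch.full((bs,), max_kv_splits, dtype=torch.int32, device=dev)

    # New: the two intermediate buffers are allocated and passed separately.
    attn_logits = torch.empty(bs, num_head, max_kv_splits, head_dim, dtype=torch.float32, device=dev)
    attn_lse = torch.empty(bs, num_head, max_kv_splits, dtype=torch.float32, device=dev)

    decode_attention_fwd(
        q, k_buffer, v_buffer, o,
        kv_indptr, kv_indices,
        attn_logits,
        attn_lse,  # previously bundled with attn_logits as one pair
        num_kv_splits, max_kv_splits,
        sm_scale=head_dim**-0.5,
    )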
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -26,7 +26,6 @@ import tqdm
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
-from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache

@@ -196,9 +195,6 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        self.num_head = (
-            model_runner.model_config.num_attention_heads // get_attention_tp_size()
-        )
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()

@@ -507,15 +503,9 @@ class CudaGraphRunner:
         if hasattr(forward_batch.spec_info, "hidden_states"):
             self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states

-        num_kv_heads = self.num_head
-        if hasattr(forward_batch.token_to_kv_pool, "k_buffer"):
-            if isinstance(forward_batch.token_to_kv_pool.k_buffer, list):
-                num_kv_heads = forward_batch.token_to_kv_pool.k_buffer[0].shape[1]
-
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
-            num_kv_heads,
             self.req_pool_indices,
             self.seq_lens,
             forward_batch.seq_lens_sum + (bs - raw_bs),
test/srt/test_triton_attention_kernels.py

@@ -265,7 +265,8 @@ class TestTritonAttention(unittest.TestCase):
             o,
             kv_indptr,
             kv_indices,
-            (attn_logits, attn_lse),
+            attn_logits,
+            attn_lse,
             num_kv_splits,
             max_kv_splits,
             sm_scale,

@@ -329,7 +330,8 @@ class TestTritonAttention(unittest.TestCase):
             o,
             kv_indptr,
             kv_indices,
-            (attn_logits, attn_lse),
+            attn_logits,
+            attn_lse,
             num_kv_splits,
             max_kv_splits,
             sm_scale,

@@ -353,7 +355,8 @@ class TestTritonAttention(unittest.TestCase):
             o_grouped,
             kv_indptr,
             kv_indices,
-            (attn_logits1, attn_lse1),
+            attn_logits1,
+            attn_lse1,
             num_kv_splits,
             max_kv_splits,
             sm_scale,