zhaoyu6 / sglang · Commits

Commit f44d1439 (unverified)
Support target model verification in the attention backend (#2678)

Authored Dec 30, 2024 by Lianmin Zheng; committed by GitHub on Dec 30, 2024.
Co-authored-by: yukavio <kavioyu@gmail.com>
Parent: b6b57fc2

Showing 7 changed files with 309 additions and 226 deletions (+309 −226).
Changed files:
- python/sglang/srt/layers/attention/__init__.py (+14 −5)
- python/sglang/srt/layers/attention/double_sparsity_backend.py (+0 −52)
- python/sglang/srt/layers/attention/flashinfer_backend.py (+209 −81)
- python/sglang/srt/layers/attention/torch_native_backend.py (+1 −38)
- python/sglang/srt/layers/attention/triton_backend.py (+20 −11)
- python/sglang/srt/model_executor/cuda_graph_runner.py (+64 −38)
- python/sglang/srt/model_executor/forward_batch_info.py (+1 −1)
python/sglang/srt/layers/attention/__init__.py  (view file @ f44d1439)

 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

 import torch

-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+    from sglang.srt.speculative.spec_info import SpecInfo


 class AttentionBackend(ABC):
     ...

@@ -22,9 +26,12 @@ class AttentionBackend(ABC):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_token: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor] = None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()

@@ -35,7 +42,9 @@ class AttentionBackend(ABC):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor] = None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()
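The two CUDA-graph hooks now take the forward mode and any speculative-decoding info instead of assuming a plain decode batch. Below is a minimal sketch of how a concrete backend might consume the widened signature; the class name, the stored metadata, and the draft-token bookkeeping are illustrative assumptions, not code from this commit.

```python
# Illustrative sketch only; not part of the commit.
from typing import Optional

import torch

from sglang.srt.layers.attention import AttentionBackend


class ToySpecAwareBackend(AttentionBackend):
    """Hypothetical backend showing how the new arguments can be used."""

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_token: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode,  # ForwardMode
        spec_info,     # Optional[SpecInfo]
    ):
        if forward_mode.is_decode():
            # Plain decoding: one query token per request, so num_token == bs.
            self.capture_metadata = ("decode", req_pool_indices[:bs], seq_lens[:bs])
        elif forward_mode.is_target_verify():
            # Target verification: num_token = bs * draft_len query tokens are
            # scored in one graph replay; spec_info describes the draft layout.
            self.capture_metadata = ("verify", bs, num_token, spec_info)
        else:
            raise ValueError(f"Unsupported CUDA-graph capture mode: {forward_mode=}")
```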
python/sglang/srt/layers/attention/double_sparsity_backend.py  (view file @ f44d1439)

@@ -3,7 +3,6 @@ from __future__ import annotations
 from typing import TYPE_CHECKING

 import torch
 import torch.nn as nn

 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict

@@ -52,8 +51,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
         self.forward_metadata = None
         self.cuda_graph_max_seq_len = model_runner.model_config.context_len

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""

@@ -115,55 +112,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
             ds_req_to_token,
         )

-    def init_cuda_graph_state(self, max_bs: int):
-        # TODO(Andy): Support CUDA graph for double sparse attention
-        raise ValueError(
-            "Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
-        self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device="cuda"
-        )
-        self.cuda_graph_attn_logits = torch.empty(
-            (
-                self.num_head,
-                self.cuda_graph_max_total_num_tokens,
-            ),
-            dtype=self.reduce_dtype,
-            device="cuda",
-        )
-
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens=None,
-    ):
-        # NOTE: encoder_lens expected to be zeros or None
-        self.forward_metadata = (
-            self.cuda_graph_start_loc,
-            self.cuda_graph_attn_logits,
-            self.cuda_graph_max_seq_len,
-            None,
-        )
-
-    def init_forward_metadata_replay_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens=None,
-    ):
-        # NOTE: encoder_lens expected to be zeros or None
-        self.cuda_graph_start_loc.zero_()
-        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        return 1

     def forward_extend(
         self,
         q,
python/sglang/srt/layers/attention/flashinfer_backend.py  (view file @ f44d1439)

@@ -10,7 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) and
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
-from typing import TYPE_CHECKING, List, Union
+from typing import TYPE_CHECKING, List, Optional, Union

 import torch
 import triton

@@ -18,12 +18,13 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo

 if is_flashinfer_available():
     from flashinfer import (
@@ -113,11 +114,15 @@ class FlashInferAttnBackend(AttentionBackend):
         # Two wrappers: one for sliding window attention and one for full attention.
         # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
         self.prefill_wrappers_paged = []
+        self.prefill_wrappers_verify = []
         self.decode_wrappers = []
         for _ in range(self.num_wrappers):
             self.prefill_wrappers_paged.append(
                 BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
             )
+            self.prefill_wrappers_verify.append(
+                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+            )
             self.decode_wrappers.append(
                 BatchDecodeWithPagedKVCacheWrapper(
                     self.workspace_buffer,

@@ -135,6 +140,7 @@ class FlashInferAttnBackend(AttentionBackend):
         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
         self.decode_cuda_graph_metadata = {}
+        self.prefill_cuda_graph_metadata = {}

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():

@@ -144,8 +150,37 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.seq_lens_sum,
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
             )
             self.forward_metadata = DecodeMetadata(self.decode_wrappers)
+        elif forward_batch.forward_mode.is_draft_extend():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_wrappers_paged,
+                use_ragged=False,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrappers_paged, False, False
+            )
+        elif forward_batch.forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_wrappers_verify,
+                use_ragged=False,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrappers_verify, False, False
+            )
         else:
             prefix_lens = forward_batch.extend_prefix_lens

@@ -165,6 +200,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 prefill_wrappers=self.prefill_wrappers_paged,
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
+                spec_info=None,
             )
             self.forward_metadata = PrefillMetadata(
                 self.prefill_wrappers_paged, use_ragged, extend_no_prefix
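`init_forward_metadata` now chooses a wrapper set per forward mode: decode batches keep the decode wrappers, draft-extend and ordinary extend reuse `prefill_wrappers_paged`, and target verification gets the new `prefill_wrappers_verify` so its plan never overwrites the regular prefill plan. The dictionary below is only an illustration of that mapping, not how the code is written.

```python
# Illustration only: wrapper set chosen for each ForwardMode after this commit.
wrapper_for_mode = {
    "DECODE": "decode_wrappers",                 # DecodeMetadata
    "DRAFT_EXTEND": "prefill_wrappers_paged",    # PrefillMetadata(..., False, False)
    "TARGET_VERIFY": "prefill_wrappers_verify",  # PrefillMetadata(..., False, False)
    "EXTEND": "prefill_wrappers_paged",          # PrefillMetadata(..., use_ragged, ...)
}
print(wrapper_for_mode["TARGET_VERIFY"])  # -> prefill_wrappers_verify
```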
@@ -180,37 +216,80 @@ class FlashInferAttnBackend(AttentionBackend):
             cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]
+        self.cuda_graph_custom_mask = torch.zeros(
+            (max_bs * self.max_context_len),
+            dtype=torch.uint8,
+            device="cuda",
+        )
+        self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
+        self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]

     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_token: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens: torch.Tensor = None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-        decode_wrappers = []
-        for i in range(self.num_wrappers):
-            decode_wrappers.append(
-                BatchDecodeWithPagedKVCacheWrapper(
-                    self.workspace_buffer,
-                    "NHD",
-                    use_cuda_graph=True,
-                    use_tensor_cores=self.decode_use_tensor_cores,
-                    paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1],
-                    paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                    paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs],
-                )
-            )
-        seq_lens_sum = seq_lens.sum().item()
-        self.indices_updater_decode.update(
-            req_pool_indices,
-            seq_lens,
-            seq_lens_sum,
-            decode_wrappers=decode_wrappers,
-            encoder_lens=encoder_lens,
-        )
-        self.decode_cuda_graph_metadata[bs] = decode_wrappers
-        self.forward_metadata = DecodeMetadata(decode_wrappers)
+        if forward_mode.is_decode():
+            decode_wrappers = []
+            for i in range(self.num_wrappers):
+                decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_token + 1],
+                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_token],
+                    )
+                )
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_decode.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                decode_wrappers=decode_wrappers,
+                encoder_lens=encoder_lens,
+                spec_info=spec_info,
+            )
+            self.decode_cuda_graph_metadata[bs] = decode_wrappers
+            self.forward_metadata = DecodeMetadata(decode_wrappers)
+        elif forward_mode.is_target_verify():
+            prefill_wrappers = []
+            for i in range(self.num_wrappers):
+                prefill_wrappers.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],
+                        paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
+                        paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
+                        custom_mask_buf=self.cuda_graph_custom_mask,
+                        qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
+                    )
+                )
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_prefill.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=prefill_wrappers,
+                use_ragged=False,
+                encoder_lens=encoder_lens,
+                spec_info=spec_info,
+            )
+            self.prefill_cuda_graph_metadata[bs] = prefill_wrappers
+            self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)
+        else:
+            raise ValueError(f"Invalid mode: {forward_mode=}")

     def init_forward_metadata_replay_cuda_graph(
         self,
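`init_cuda_graph_state` now also reserves `cuda_graph_custom_mask`, a fixed uint8 buffer of `max_bs * max_context_len` entries, plus cloned `qo_indptr`/`qk_indptr` buffers, so the packed verification mask can be written into CUDA-graph-safe storage instead of being reallocated every step. A small illustrative calculation; the concrete numbers are assumptions, not values from the commit.

```python
# Illustration only: size of the reserved mask buffer vs. one request's footprint.
max_bs = 160               # example capture batch size
max_context_len = 16384    # example model context length
mask_slots = max_bs * max_context_len  # uint8 entries reserved once at init
# A single request with a 1000-token prefix and 4 draft tokens would write
# 4 * (1000 + 4) packed mask entries into that buffer at replay time.
per_request = 4 * (1000 + 4)
print(mask_slots, per_request)  # 2621440 4016
```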
@@ -218,24 +297,41 @@ class FlashInferAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens: torch.Tensor = None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-        self.indices_updater_decode.update(
-            req_pool_indices[:bs],
-            seq_lens[:bs],
-            seq_lens_sum,
-            decode_wrappers=self.decode_cuda_graph_metadata[bs],
-            encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
-        )
+        if forward_mode.is_decode():
+            self.indices_updater_decode.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                decode_wrappers=self.decode_cuda_graph_metadata[bs],
+                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+                spec_info=spec_info,
+            )
+        elif forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
+                use_ragged=False,
+                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+                spec_info=spec_info,
+            )
+        else:
+            raise ValueError("Invalid forward mode")

     def get_cuda_graph_seq_len_fill_value(self):
         return 0

     def forward_extend(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
@@ -293,9 +389,9 @@ class FlashInferAttnBackend(AttentionBackend):
     def forward_decode(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
@@ -348,7 +444,6 @@ class FlashInferIndicesUpdaterDecode:
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
         self.sliding_window_size = model_runner.sliding_window_size
         self.attn_backend = attn_backend

         # Buffers and wrappers

@@ -371,7 +466,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -382,7 +478,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(

@@ -392,6 +489,7 @@ class FlashInferIndicesUpdaterDecode:
             seq_lens_sum,
             self.kv_indptr[0],
             None,
+            spec_info,
         )

     def update_sliding_window(

@@ -400,7 +498,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:

@@ -424,6 +523,7 @@ class FlashInferIndicesUpdaterDecode:
                 paged_kernel_lens_sum_tmp,
                 self.kv_indptr[wrapper_id],
                 kv_start_idx_tmp,
+                spec_info,
             )

     def update_cross_attention(

@@ -432,7 +532,8 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:

@@ -452,6 +553,7 @@ class FlashInferIndicesUpdaterDecode:
                 seq_lens_sum,
                 self.kv_indptr[wrapper_id],
                 kv_start_idx,
+                spec_info,
             )

     def call_begin_forward(
@@ -462,23 +564,30 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
+        spec_info: Optional[SpecInfo],
     ):
-        bs = len(req_pool_indices)
-        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices,
-            paged_kernel_lens,
-            kv_indptr,
-            kv_start_idx,
-            kv_indices,
-            self.req_to_token.shape[1],
-        )
+        if spec_info is None:
+            bs = len(req_pool_indices)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                kv_start_idx,
+                kv_indices,
+                self.req_to_token.shape[1],
+            )
+        else:
+            bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
+                req_pool_indices,
+                paged_kernel_lens,
+                self.req_to_token,
+            )

         wrapper.end_forward()
         wrapper.begin_forward(
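When `spec_info` is present, the decode-side index construction is delegated to it entirely. Judging only from this call site, `generate_attn_arg_decode` returns the effective batch size together with flat KV indices and an indptr. Below is a hedged sketch of an object that would satisfy that contract; the simple linear layout is an assumption for illustration, not the actual SpecInfo implementation.

```python
# Hedged sketch of the interface used above; not code from this commit.
import torch


class ToyDecodeSpecInfo:
    def generate_attn_arg_decode(self, req_pool_indices, paged_kernel_lens, req_to_token):
        bs = len(req_pool_indices)
        kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
        # Gather each request's token slots from the req->token map.
        kv_indices = torch.cat(
            [
                req_to_token[req_pool_indices[i], : int(paged_kernel_lens[i])]
                for i in range(bs)
            ]
        ).to(torch.int32)
        return bs, kv_indices, kv_indptr
```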
@@ -507,7 +616,6 @@ class FlashInferIndicesUpdaterPrefill:
         self.data_type = model_runner.kv_cache_dtype
         self.q_data_type = model_runner.dtype
         self.sliding_window_size = model_runner.sliding_window_size
         self.attn_backend = attn_backend

         # Buffers and wrappers

@@ -534,7 +642,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()

@@ -547,7 +656,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens

@@ -568,6 +678,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.kv_indptr[0],
             self.qo_indptr[0],
             use_ragged,
+            spec_info,
         )

     def update_sliding_window(

@@ -578,7 +689,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:

@@ -607,6 +719,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.kv_indptr[wrapper_id],
                 self.qo_indptr[wrapper_id],
                 use_ragged,
+                spec_info,
             )

     def update_cross_attention(

@@ -617,7 +730,8 @@ class FlashInferIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
-        encoder_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:

@@ -643,6 +757,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.kv_indptr[wrapper_id],
                 self.qo_indptr[wrapper_id],
                 use_ragged,
+                spec_info,
             )

     def call_begin_forward(
@@ -658,25 +773,37 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
+        spec_info: Optional[SpecInfo],
     ):
         bs = len(req_pool_indices)
-        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices,
-            paged_kernel_lens,
-            kv_indptr,
-            kv_start_idx,
-            kv_indices,
-            self.req_to_token.shape[1],
-        )
-        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-        qo_indptr = qo_indptr[: bs + 1]
+        if spec_info is None:
+            # Normal extend
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                kv_start_idx,
+                kv_indices,
+                self.req_to_token.shape[1],
+            )
+            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+        else:
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    req_pool_indices,
+                    paged_kernel_lens,
+                    self.req_to_token,
+                )
+            )

         # extend part
         if use_ragged:
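The verify/prefill path hands even more over to `spec_info`: besides KV indices and indptr it also supplies the query indptr and the packed custom mask whose buffers were registered at capture time. Judging only from this call site, a compatible object looks roughly like the sketch below; the constructor, the method body, and the plain causal layout over draft tokens are illustrative assumptions rather than the real SpecInfo.

```python
# Hedged sketch of generate_attn_arg_prefill as used above; not code from this commit.
import torch


class ToyVerifySpecInfo:
    def __init__(self, num_draft_tokens: int):
        self.num_draft_tokens = num_draft_tokens

    def generate_attn_arg_prefill(self, req_pool_indices, paged_kernel_lens, req_to_token):
        bs = len(req_pool_indices)
        kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
        kv_indices = torch.cat(
            [
                req_to_token[req_pool_indices[i], : int(paged_kernel_lens[i])]
                for i in range(bs)
            ]
        ).to(torch.int32)
        # Every request contributes num_draft_tokens query tokens.
        qo_indptr = torch.arange(0, bs + 1, dtype=torch.int32) * self.num_draft_tokens
        # Packed per-request mask: cached prefix fully visible, causal over draft tokens.
        masks = []
        for i in range(bs):
            kv_len = int(paged_kernel_lens[i])
            prefix = kv_len - self.num_draft_tokens
            m = torch.zeros(self.num_draft_tokens, kv_len, dtype=torch.uint8)
            m[:, :prefix] = 1
            m[:, prefix:] = torch.tril(
                torch.ones(self.num_draft_tokens, self.num_draft_tokens, dtype=torch.uint8)
            )
            masks.append(m.flatten())
        custom_mask = torch.cat(masks)
        return kv_indices, kv_indptr, qo_indptr, custom_mask
```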
@@ -702,6 +829,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.head_dim,
             1,
             q_data_type=self.q_data_type,
+            custom_mask=custom_mask,
         )
python/sglang/srt/layers/attention/torch_native_backend.py  (view file @ f44d1439)

 from __future__ import annotations

-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING

 import torch
 from torch.nn.functional import scaled_dot_product_attention

@@ -23,43 +23,6 @@ class TorchNativeAttnBackend(AttentionBackend):
         """Init the metadata for a forward pass."""
         pass

-    def init_cuda_graph_state(self, max_bs: int):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor] = None,
-    ):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def init_forward_metadata_replay_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor] = None,
-    ):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )

     def _run_sdpa_forward_extend(
         self,
         query: torch.Tensor,
python/sglang/srt/layers/attention/triton_backend.py  (view file @ f44d1439)

 from __future__ import annotations

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

 import torch

 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo


 class TritonAttnBackend(AttentionBackend):
     ...
@@ -80,11 +81,17 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_token: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens=None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         # NOTE: encoder_lens expected to be zeros or None
+        assert encoder_lens is None, "Not supported"
+        assert forward_mode.is_decode(), "Not supported"
+        assert spec_info is None, "Not supported"
         self.forward_metadata = (
             self.cuda_graph_attn_logits,
             None,

@@ -96,7 +103,9 @@ class TritonAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens=None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         # NOTE: encoder_lens expected to be zeros or None
         self.cuda_graph_start_loc.zero_()

@@ -107,9 +116,9 @@ class TritonAttnBackend(AttentionBackend):
     def forward_extend(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,

@@ -146,9 +155,9 @@ class TritonAttnBackend(AttentionBackend):
     def forward_decode(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
python/sglang/srt/model_executor/cuda_graph_runner.py  (view file @ f44d1439)

@@ -25,14 +25,14 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

-from sglang.srt.layers.logits_processor import (
-    LogitsMetadata,
-    LogitsProcessor,
-    LogitsProcessorOutput,
-)
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
 from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather

 if TYPE_CHECKING:

@@ -153,6 +153,10 @@ class CudaGraphRunner:
             if bs <= model_runner.req_to_token_pool.size
             and bs <= model_runner.server_args.cuda_graph_max_bs
         ]
+
+        self.capture_forward_mode = ForwardMode.DECODE
+        self.num_tokens_per_bs = 1
+
         self.compile_bs = (
             [
                 bs

@@ -165,8 +169,8 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
-        self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+        self.max_num_token = self.max_bs * self.num_tokens_per_bs
+        self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
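The runner's CUDA-graph buffers are now sized by token count rather than request count, since a speculative step can push several tokens per request through one graph. `num_tokens_per_bs` stays 1 for plain decode, so existing behaviour is unchanged; below is a quick illustrative calculation with an assumed draft length of 4, which is not a value fixed by the commit.

```python
# Illustration only: how per-token buffers scale once num_tokens_per_bs > 1.
max_bs = 160             # example capture batch size
num_tokens_per_bs = 1    # plain decode: one query token per request
print(max_bs * num_tokens_per_bs)  # 160 slots

num_tokens_per_bs = 4    # hypothetical target-verify step with 4 draft tokens per request
max_num_token = max_bs * num_tokens_per_bs
print(max_num_token)     # 640 slots for input_ids, positions, out_cache_loc
```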
@@ -179,12 +183,13 @@ class CudaGraphRunner:
         # Common inputs
         with torch.device("cuda"):
-            self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)

         if self.is_encoder_decoder:
@@ -229,6 +234,9 @@ class CudaGraphRunner:
         self.model_runner.model.capture_mode = False

     def can_run(self, forward_batch: ForwardBatch):
+        if not forward_batch.forward_mode.is_cuda_graph():
+            return False
+
         if self.enable_dp_attention:
             min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
                 forward_batch.global_num_tokens

@@ -258,12 +266,12 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-            capture_bs = (
+            capture_range = (
                 tqdm.tqdm(self.capture_bs)
                 if get_tensor_model_parallel_rank() == 0
                 else self.capture_bs
             )
-            for bs in capture_bs:
+            for bs in capture_range:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -283,12 +291,15 @@ class CudaGraphRunner:
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
+        num_token = bs * self.num_tokens_per_bs

         # Common inputs
-        input_ids = self.input_ids[:bs]
+        input_ids = self.input_ids[:num_token]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
-        out_cache_loc = self.out_cache_loc[:bs]
+        out_cache_loc = self.out_cache_loc[:num_token]
+        positions = self.positions[:num_token]
         if self.is_encoder_decoder:
             encoder_lens = self.encoder_lens[:bs]
         else:
@@ -304,37 +315,41 @@ class CudaGraphRunner:
         global_num_tokens = None
         gathered_buffer = None

+        forward_batch = ForwardBatch(
+            forward_mode=self.capture_forward_mode,
+            batch_size=bs,
+            input_ids=input_ids,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            attn_backend=self.model_runner.attn_backend,
+            out_cache_loc=out_cache_loc,
+            seq_lens_sum=seq_lens_sum,
+            encoder_lens=encoder_lens,
+            return_logprob=False,
+            top_logprobs_nums=[0] * num_token,
+            positions=positions,
+            global_num_tokens=global_num_tokens,
+            mrope_positions=mrope_positions,
+            gathered_buffer=gathered_buffer,
+        )
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
-            bs, req_pool_indices, seq_lens, encoder_lens
+            bs,
+            num_token,
+            req_pool_indices,
+            seq_lens,
+            encoder_lens,
+            forward_batch.forward_mode,
+            forward_batch.spec_info,
         )

         # Run and capture
         def run_once():
-            forward_batch = ForwardBatch(
-                forward_mode=ForwardMode.DECODE,
-                batch_size=bs,
-                input_ids=input_ids,
-                req_pool_indices=req_pool_indices,
-                seq_lens=seq_lens,
-                req_to_token_pool=self.model_runner.req_to_token_pool,
-                token_to_kv_pool=self.model_runner.token_to_kv_pool,
-                attn_backend=self.model_runner.attn_backend,
-                out_cache_loc=out_cache_loc,
-                seq_lens_sum=seq_lens_sum,
-                encoder_lens=encoder_lens,
-                return_logprob=False,
-                top_logprobs_nums=[0] * bs,
-                positions=clamp_position(seq_lens),
-                mrope_positions=mrope_positions,
-                global_num_tokens=global_num_tokens,
-                gathered_buffer=gathered_buffer,
-            )
             logits_output = forward(input_ids, forward_batch.positions, forward_batch)
-            return logits_output.next_token_logits
+            return logits_output.next_token_logits, logits_output.hidden_states

         for _ in range(2):
             torch.cuda.synchronize()
@@ -360,6 +375,9 @@ class CudaGraphRunner:
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
+        # In normal decoding case, raw_bs == raw_num_token
+        # But in speculative decoding, raw_num_token is raw_bs * self.num_tokens_per_bs
+        raw_num_token = forward_batch.input_ids.numel()

         # Pad
         if self.enable_dp_attention:
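As the code comment above spells out, `replay` now distinguishes the number of requests from the number of tokens actually fed through the graph. A quick worked example with an assumed draft length (the values are illustrative only):

```python
# Illustration only: request count vs. token count at replay time.
raw_bs = 3               # requests in the incoming batch
num_tokens_per_bs = 4    # hypothetical draft length for target verification
raw_num_token = raw_bs * num_tokens_per_bs
# forward_batch.input_ids has one entry per token, so input_ids.numel() == 12,
# while per-request tensors such as seq_lens still have raw_bs == 3 entries.
print(raw_num_token)     # 12
```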
@@ -374,10 +392,13 @@ class CudaGraphRunner:
             self.out_cache_loc.zero_()

         # Common inputs
-        self.input_ids[:raw_bs].copy_(forward_batch.input_ids)
+        self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
-        self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
+        self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
+        positions = clamp_position(forward_batch.seq_lens)
+        self.positions[:raw_num_token].copy_(positions)
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
@@ -390,13 +411,18 @@ class CudaGraphRunner:
             self.seq_lens,
             forward_batch.seq_lens_sum + (bs - raw_bs),
             self.encoder_lens,
+            forward_batch.forward_mode,
+            forward_batch.spec_info,
         )

         # Replay
         self.graphs[bs].replay()
-        next_token_logits = self.output_buffers[bs][:raw_bs]
+        next_token_logits, hidden_states = self.output_buffers[bs]
         logits_output = LogitsProcessorOutput(
-            next_token_logits=next_token_logits,
+            next_token_logits=next_token_logits[:raw_num_token],
+            hidden_states=(
+                hidden_states[:raw_num_token] if hidden_states is not None else None
+            ),
         )
         return logits_output
python/sglang/srt/model_executor/forward_batch_info.py  (view file @ f44d1439)

@@ -96,7 +96,7 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DRAFT_EXTEND

     def is_cuda_graph(self):
-        return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY)
+        return self == ForwardMode.DECODE or self == ForwardMode.TARGET_VERIFY

     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST