zhaoyu6 / sglang · Commit f44d1439

Unverified commit f44d1439, authored Dec 30, 2024 by Lianmin Zheng, committed by GitHub on Dec 30, 2024.

Support target model verification in the attention backend (#2678)

Co-authored-by: yukavio <kavioyu@gmail.com>

Parent: b6b57fc2

Showing 7 changed files with 309 additions and 226 deletions (+309 -226).
python/sglang/srt/layers/attention/__init__.py                   +14   -5
python/sglang/srt/layers/attention/double_sparsity_backend.py     +0  -52
python/sglang/srt/layers/attention/flashinfer_backend.py        +209  -81
python/sglang/srt/layers/attention/torch_native_backend.py        +1  -38
python/sglang/srt/layers/attention/triton_backend.py             +20  -11
python/sglang/srt/model_executor/cuda_graph_runner.py            +64  -38
python/sglang/srt/model_executor/forward_batch_info.py            +1   -1
python/sglang/srt/layers/attention/__init__.py  (+14 -5)

+from __future__ import annotations
+
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

 import torch

-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+    from sglang.srt.speculative.spec_info import SpecInfo


 class AttentionBackend(ABC):

@@ -22,9 +26,12 @@ class AttentionBackend(ABC):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_token: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor] = None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()

@@ -35,7 +42,9 @@ class AttentionBackend(ABC):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor] = None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()
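For orientation, here is a minimal sketch (not part of the commit) of a backend implementing the widened CUDA-graph hooks. The class and attribute names are hypothetical; only the hook signatures mirror the abstract methods above, which now receive the token count, the forward mode, and optional speculative-decoding metadata.

from __future__ import annotations

from typing import Optional

import torch


class NaiveBackendSketch:
    """Hypothetical backend used only to illustrate the new hook signatures."""

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_token: int,                  # new: tokens per captured batch, not just requests
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode,                    # ForwardMode in sglang; left untyped in this sketch
        spec_info,                       # Optional[SpecInfo] in sglang
    ):
        # A real backend would size its CUDA-graph buffers by num_token and
        # branch on forward_mode (decode vs. target-verify), as FlashInfer does below.
        self.metadata = (forward_mode, num_token, spec_info)

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode,
        spec_info,
    ):
        # Refill the captured buffers for the actual batch before replaying the graph.
        self.metadata = (forward_mode, seq_lens_sum, spec_info)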
python/sglang/srt/layers/attention/double_sparsity_backend.py  (+0 -52)

@@ -3,7 +3,6 @@ from __future__ import annotations
 from typing import TYPE_CHECKING

 import torch
-import torch.nn as nn

 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict

@@ -52,8 +51,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
         self.forward_metadata = None
-        self.cuda_graph_max_seq_len = model_runner.model_config.context_len

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""

@@ -115,55 +112,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
             ds_req_to_token,
         )

-    def init_cuda_graph_state(self, max_bs: int):
-        # TODO(Andy): Support CUDA graph for double sparse attention
-        raise ValueError(
-            "Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
-
-        self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device="cuda"
-        )
-        self.cuda_graph_attn_logits = torch.empty(
-            (
-                self.num_head,
-                self.cuda_graph_max_total_num_tokens,
-            ),
-            dtype=self.reduce_dtype,
-            device="cuda",
-        )
-
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens=None,
-    ):
-        # NOTE: encoder_lens expected to be zeros or None
-        self.forward_metadata = (
-            self.cuda_graph_start_loc,
-            self.cuda_graph_attn_logits,
-            self.cuda_graph_max_seq_len,
-            None,
-        )
-
-    def init_forward_metadata_replay_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens=None,
-    ):
-        # NOTE: encoder_lens expected to be zeros or None
-        self.cuda_graph_start_loc.zero_()
-        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        return 1
-
     def forward_extend(
         self,
         q,
python/sglang/srt/layers/attention/flashinfer_backend.py  (+209 -81)

@@ -10,7 +10,7 @@
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
-from typing import TYPE_CHECKING, List, Union
+from typing import TYPE_CHECKING, List, Optional, Union

 import torch
 import triton

@@ -18,12 +18,13 @@
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo

 if is_flashinfer_available():
     from flashinfer import (

@@ -113,11 +114,15 @@ class FlashInferAttnBackend(AttentionBackend):
         # Two wrappers: one for sliding window attention and one for full attention.
         # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
         self.prefill_wrappers_paged = []
+        self.prefill_wrappers_verify = []
         self.decode_wrappers = []
         for _ in range(self.num_wrappers):
             self.prefill_wrappers_paged.append(
                 BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
             )
+            self.prefill_wrappers_verify.append(
+                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+            )
             self.decode_wrappers.append(
                 BatchDecodeWithPagedKVCacheWrapper(
                     self.workspace_buffer,

@@ -135,6 +140,7 @@ class FlashInferAttnBackend(AttentionBackend):
         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
         self.decode_cuda_graph_metadata = {}
+        self.prefill_cuda_graph_metadata = {}

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():

@@ -144,8 +150,37 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.seq_lens_sum,
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
             )
             self.forward_metadata = DecodeMetadata(self.decode_wrappers)
+        elif forward_batch.forward_mode.is_draft_extend():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_wrappers_paged,
+                use_ragged=False,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrappers_paged, False, False
+            )
+        elif forward_batch.forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_wrappers_verify,
+                use_ragged=False,
+                encoder_lens=forward_batch.encoder_lens,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrappers_verify, False, False
+            )
         else:
             prefix_lens = forward_batch.extend_prefix_lens

@@ -165,6 +200,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 prefill_wrappers=self.prefill_wrappers_paged,
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
+                spec_info=None,
             )
             self.forward_metadata = PrefillMetadata(
                 self.prefill_wrappers_paged, use_ragged, extend_no_prefix

@@ -180,37 +216,80 @@ class FlashInferAttnBackend(AttentionBackend):
             cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
         ]
+        self.cuda_graph_custom_mask = torch.zeros(
+            (max_bs * self.max_context_len),
+            dtype=torch.uint8,
+            device="cuda",
+        )
+        self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
+        self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]

     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_token: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens: torch.Tensor = None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-        decode_wrappers = []
-        for i in range(self.num_wrappers):
-            decode_wrappers.append(
-                BatchDecodeWithPagedKVCacheWrapper(
-                    self.workspace_buffer,
-                    "NHD",
-                    use_cuda_graph=True,
-                    use_tensor_cores=self.decode_use_tensor_cores,
-                    paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1],
-                    paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                    paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs],
-                )
-            )
-        seq_lens_sum = seq_lens.sum().item()
-        self.indices_updater_decode.update(
-            req_pool_indices,
-            seq_lens,
-            seq_lens_sum,
-            decode_wrappers=decode_wrappers,
-            encoder_lens=encoder_lens,
-        )
-        self.decode_cuda_graph_metadata[bs] = decode_wrappers
-        self.forward_metadata = DecodeMetadata(decode_wrappers)
+        if forward_mode.is_decode():
+            decode_wrappers = []
+            for i in range(self.num_wrappers):
+                decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_token + 1],
+                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_token],
+                    )
+                )
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_decode.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                decode_wrappers=decode_wrappers,
+                encoder_lens=encoder_lens,
+                spec_info=spec_info,
+            )
+            self.decode_cuda_graph_metadata[bs] = decode_wrappers
+            self.forward_metadata = DecodeMetadata(decode_wrappers)
+        elif forward_mode.is_target_verify():
+            prefill_wrappers = []
+            for i in range(self.num_wrappers):
+                prefill_wrappers.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_cuda_graph=True,
+                        qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],
+                        paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
+                        paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
+                        custom_mask_buf=self.cuda_graph_custom_mask,
+                        qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
+                    )
+                )
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_prefill.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=prefill_wrappers,
+                use_ragged=False,
+                encoder_lens=encoder_lens,
+                spec_info=spec_info,
+            )
+            self.prefill_cuda_graph_metadata[bs] = prefill_wrappers
+            self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)
+        else:
+            raise ValueError(f"Invalid mode: {forward_mode=}")

     def init_forward_metadata_replay_cuda_graph(
         self,

@@ -218,24 +297,41 @@ class FlashInferAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens: torch.Tensor = None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-        self.indices_updater_decode.update(
-            req_pool_indices[:bs],
-            seq_lens[:bs],
-            seq_lens_sum,
-            decode_wrappers=self.decode_cuda_graph_metadata[bs],
-            encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
-        )
+        if forward_mode.is_decode():
+            self.indices_updater_decode.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                decode_wrappers=self.decode_cuda_graph_metadata[bs],
+                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+                spec_info=spec_info,
+            )
+        elif forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
+                use_ragged=False,
+                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+                spec_info=spec_info,
+            )
+        else:
+            raise ValueError("Invalid forward mode")

     def get_cuda_graph_seq_len_fill_value(self):
         return 0

     def forward_extend(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,

@@ -293,9 +389,9 @@ class FlashInferAttnBackend(AttentionBackend): the same q/k/v type annotations are added to forward_decode.

@@ -348,7 +444,6 @@ (FlashInferIndicesUpdaterDecode) and @@ -507,7 +616,6 @@ (FlashInferIndicesUpdaterPrefill): one line is dropped from each constructor; the surrounding assignments (data_type, q_data_type, sliding_window_size, attn_backend) and the "# Buffers and wrappers" comment are unchanged.

@@ -371,7 +466,8 @@, @@ -382,7 +478,8 @@, @@ -400,7 +498,8 @@, @@ -432,7 +532,8 @@ (FlashInferIndicesUpdaterDecode) and @@ -534,7 +642,8 @@, @@ -547,7 +656,8 @@, @@ -578,7 +689,8 @@, @@ -617,7 +730,8 @@ (FlashInferIndicesUpdaterPrefill): every update* method signature changes `encoder_lens: torch.Tensor,` to `encoder_lens: Optional[torch.Tensor],` and gains a trailing `spec_info: Optional[SpecInfo],` parameter.

@@ -392,6 +489,7 @@, @@ -424,6 +523,7 @@, @@ -452,6 +553,7 @@, @@ -568,6 +678,7 @@, @@ -607,6 +719,7 @@, @@ -643,6 +757,7 @@: the corresponding call_begin_forward(...) calls pass the new `spec_info` argument through.

@@ -462,23 +564,30 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
+        spec_info: Optional[SpecInfo],
     ):
-        bs = len(req_pool_indices)
-        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = torch.empty(
-            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
-        )
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices,
-            paged_kernel_lens,
-            kv_indptr,
-            kv_start_idx,
-            kv_indices,
-            self.req_to_token.shape[1],
-        )
+        if spec_info is None:
+            bs = len(req_pool_indices)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                kv_start_idx,
+                kv_indices,
+                self.req_to_token.shape[1],
+            )
+        else:
+            bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
+                req_pool_indices,
+                paged_kernel_lens,
+                self.req_to_token,
+            )

         wrapper.end_forward()
         wrapper.begin_forward(

@@ -658,25 +773,37 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
+        spec_info: Optional[SpecInfo],
     ):
         bs = len(req_pool_indices)
-        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = torch.empty(
-            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
-        )
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices,
-            paged_kernel_lens,
-            kv_indptr,
-            kv_start_idx,
-            kv_indices,
-            self.req_to_token.shape[1],
-        )
-        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-        qo_indptr = qo_indptr[: bs + 1]
+        if spec_info is None:
+            # Normal extend
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                kv_start_idx,
+                kv_indices,
+                self.req_to_token.shape[1],
+            )
+            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+        else:
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    req_pool_indices,
+                    paged_kernel_lens,
+                    self.req_to_token,
+                )
+            )

         # extend part
         if use_ragged:

@@ -702,6 +829,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.head_dim,
                 1,
                 q_data_type=self.q_data_type,
+                custom_mask=custom_mask,
             )
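To summarize the FlashInfer change in miniature: plain decode keeps the decode wrappers, while target verification routes the draft tokens through a separate set of prefill-style "verify" wrappers so that a custom attention mask produced from spec_info can be applied. The sketch below is a toy dispatcher, not the commit's code; the enum and function names are illustrative stand-ins for sglang's ForwardMode and the wrapper lists shown in the diff.

from enum import IntEnum, auto


class ModeSketch(IntEnum):
    # Stand-ins for the ForwardMode values this backend now distinguishes.
    DECODE = auto()
    TARGET_VERIFY = auto()


def pick_wrappers(mode: ModeSketch, decode_wrappers, verify_wrappers):
    """Toy dispatcher mirroring the branching added to init_forward_metadata."""
    if mode == ModeSketch.DECODE:
        return decode_wrappers
    elif mode == ModeSketch.TARGET_VERIFY:
        # The verify wrappers are prefill wrappers configured with the
        # custom-mask buffers, so draft-token verification can mask per request.
        return verify_wrappers
    raise ValueError(f"Invalid mode: {mode=}")


# Example: pick_wrappers(ModeSketch.TARGET_VERIFY, ["decode"], ["verify"]) -> ["verify"]
assert pick_wrappers(ModeSketch.TARGET_VERIFY, ["decode"], ["verify"]) == ["verify"]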
python/sglang/srt/layers/attention/torch_native_backend.py  (+1 -38)

 from __future__ import annotations

-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING

 import torch
 from torch.nn.functional import scaled_dot_product_attention

@@ -23,43 +23,6 @@ class TorchNativeAttnBackend(AttentionBackend):
         """Init the metadata for a forward pass."""
         pass

-    def init_cuda_graph_state(self, max_bs: int):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor] = None,
-    ):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def init_forward_metadata_replay_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor] = None,
-    ):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
     def _run_sdpa_forward_extend(
         self,
         query: torch.Tensor,
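With the stubs deleted, a backend that does not override the CUDA-graph hooks falls through to the base class's NotImplementedError (see the abstract methods above) instead of raising its own ValueError; the user-facing advice is unchanged, run with --disable-cuda-graph. The classes below are hypothetical, purely to illustrate that effect, and do not reproduce sglang's actual class bodies.

class BaseBackendSketch:
    """Stand-in for the abstract AttentionBackend hooks shown earlier."""

    def init_forward_metadata_capture_cuda_graph(self, *args, **kwargs):
        raise NotImplementedError()


class TorchNativeLikeBackend(BaseBackendSketch):
    pass  # no CUDA-graph overrides anymore


try:
    TorchNativeLikeBackend().init_forward_metadata_capture_cuda_graph()
except NotImplementedError:
    print("CUDA graph capture is not implemented; launch with --disable-cuda-graph")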
python/sglang/srt/layers/attention/triton_backend.py  (+20 -11)

 from __future__ import annotations

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

 import torch

 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo


 class TritonAttnBackend(AttentionBackend):

@@ -80,11 +81,17 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_token: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens=None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-        # NOTE: encoder_lens expected to be zeros or None
+        assert encoder_lens is None, "Not supported"
+        assert forward_mode.is_decode(), "Not supported"
+        assert spec_info is None, "Not supported"
+
         self.forward_metadata = (
             self.cuda_graph_attn_logits,
             None,

@@ -96,7 +103,9 @@ class TritonAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens=None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         # NOTE: encoder_lens expected to be zeros or None
         self.cuda_graph_start_loc.zero_()

@@ -107,9 +116,9 @@ and @@ -146,9 +155,9 @@: forward_extend and forward_decode annotate q, k, v as torch.Tensor (previously untyped).
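A small restatement of the guard the Triton backend now performs before capture: it accepts the widened signature but only handles plain decode batches. The helper name below is hypothetical; only the three assertions come from the diff, and forward_mode is expected to expose is_decode() as sglang's ForwardMode does.

from typing import Optional

import torch


def assert_capture_supported(forward_mode, encoder_lens: Optional[torch.Tensor], spec_info) -> None:
    # Reject encoder-decoder batches, non-decode modes, and speculative metadata,
    # matching the assertions added in init_forward_metadata_capture_cuda_graph.
    assert encoder_lens is None, "Not supported"
    assert forward_mode.is_decode(), "Not supported"
    assert spec_info is None, "Not supported"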
python/sglang/srt/model_executor/cuda_graph_runner.py  (+64 -38)

@@ -25,14 +25,14 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

-from sglang.srt.layers.logits_processor import (
-    LogitsMetadata,
-    LogitsProcessor,
-    LogitsProcessorOutput,
-)
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
 from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather

 if TYPE_CHECKING:

@@ -153,6 +153,10 @@ class CudaGraphRunner:
             if bs <= model_runner.req_to_token_pool.size
             and bs <= model_runner.server_args.cuda_graph_max_bs
         ]
+
+        self.capture_forward_mode = ForwardMode.DECODE
+        self.num_tokens_per_bs = 1
+
         self.compile_bs = (
             [
                 bs

@@ -165,8 +169,8 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
-        self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+        self.max_num_token = self.max_bs * self.num_tokens_per_bs
+        self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )

@@ -179,12 +183,13 @@ class CudaGraphRunner:
         # Common inputs
         with torch.device("cuda"):
-            self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)

             if self.is_encoder_decoder:

@@ -229,6 +234,9 @@ class CudaGraphRunner:
         self.model_runner.model.capture_mode = False

     def can_run(self, forward_batch: ForwardBatch):
+        if not forward_batch.forward_mode.is_cuda_graph():
+            return False
+
         if self.enable_dp_attention:
             min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
                 forward_batch.global_num_tokens

@@ -258,12 +266,12 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-            capture_bs = (
+            capture_range = (
                 tqdm.tqdm(self.capture_bs)
                 if get_tensor_model_parallel_rank() == 0
                 else self.capture_bs
             )
-            for bs in capture_bs:
+            for bs in capture_range:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,

@@ -283,12 +291,15 @@ class CudaGraphRunner:
     def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
+        num_token = bs * self.num_tokens_per_bs

         # Common inputs
-        input_ids = self.input_ids[:bs]
+        input_ids = self.input_ids[:num_token]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
-        out_cache_loc = self.out_cache_loc[:bs]
+        out_cache_loc = self.out_cache_loc[:num_token]
+        positions = self.positions[:num_token]
         if self.is_encoder_decoder:
             encoder_lens = self.encoder_lens[:bs]
         else:

@@ -304,37 +315,41 @@ class CudaGraphRunner:
             global_num_tokens = None
             gathered_buffer = None

+        forward_batch = ForwardBatch(
+            forward_mode=self.capture_forward_mode,
+            batch_size=bs,
+            input_ids=input_ids,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            attn_backend=self.model_runner.attn_backend,
+            out_cache_loc=out_cache_loc,
+            seq_lens_sum=seq_lens_sum,
+            encoder_lens=encoder_lens,
+            return_logprob=False,
+            top_logprobs_nums=[0] * num_token,
+            positions=positions,
+            global_num_tokens=global_num_tokens,
+            mrope_positions=mrope_positions,
+            gathered_buffer=gathered_buffer,
+        )
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
+            num_token,
             req_pool_indices,
             seq_lens,
             encoder_lens,
+            forward_batch.forward_mode,
+            forward_batch.spec_info,
         )

         # Run and capture
         def run_once():
-            forward_batch = ForwardBatch(
-                forward_mode=ForwardMode.DECODE,
-                batch_size=bs,
-                input_ids=input_ids,
-                req_pool_indices=req_pool_indices,
-                seq_lens=seq_lens,
-                req_to_token_pool=self.model_runner.req_to_token_pool,
-                token_to_kv_pool=self.model_runner.token_to_kv_pool,
-                attn_backend=self.model_runner.attn_backend,
-                out_cache_loc=out_cache_loc,
-                seq_lens_sum=seq_lens_sum,
-                encoder_lens=encoder_lens,
-                return_logprob=False,
-                top_logprobs_nums=[0] * bs,
-                positions=clamp_position(seq_lens),
-                mrope_positions=mrope_positions,
-                global_num_tokens=global_num_tokens,
-                gathered_buffer=gathered_buffer,
-            )
             logits_output = forward(input_ids, forward_batch.positions, forward_batch)
-            return logits_output.next_token_logits
+            return logits_output.next_token_logits, logits_output.hidden_states

         for _ in range(2):
             torch.cuda.synchronize()

@@ -360,6 +375,9 @@ class CudaGraphRunner:
     def replay(self, forward_batch: ForwardBatch):
         assert forward_batch.out_cache_loc is not None
         raw_bs = forward_batch.batch_size
+        # In normal decoding case, raw_bs == raw_num_token
+        # But in speculative decoding, raw_num_token is raw_bs * self.num_tokens_per_bs
+        raw_num_token = forward_batch.input_ids.numel()

         # Pad
         if self.enable_dp_attention:

@@ -374,10 +392,13 @@ class CudaGraphRunner:
             self.out_cache_loc.zero_()

         # Common inputs
-        self.input_ids[:raw_bs].copy_(forward_batch.input_ids)
+        self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
-        self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
+        self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
+        positions = clamp_position(forward_batch.seq_lens)
+        self.positions[:raw_num_token].copy_(positions)

         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:

@@ -390,13 +411,18 @@ class CudaGraphRunner:
             self.seq_lens,
             forward_batch.seq_lens_sum + (bs - raw_bs),
             self.encoder_lens,
+            forward_batch.forward_mode,
+            forward_batch.spec_info,
         )

         # Replay
         self.graphs[bs].replay()
-        next_token_logits = self.output_buffers[bs][:raw_bs]
+        next_token_logits, hidden_states = self.output_buffers[bs]

         logits_output = LogitsProcessorOutput(
-            next_token_logits=next_token_logits,
+            next_token_logits=next_token_logits[:raw_num_token],
+            hidden_states=(
+                hidden_states[:raw_num_token] if hidden_states is not None else None
+            ),
         )
         return logits_output
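The CudaGraphRunner change boils down to separating request-level from token-level bookkeeping: in plain decode each request contributes one token, so num_token == bs, but with speculative target verification each request carries num_tokens_per_bs draft tokens. The helper below is a toy sketch of that slicing rule, not code from the commit; the function name and the dict layout are hypothetical.

import torch


def slice_capture_buffers(buffers: dict, bs: int, num_tokens_per_bs: int) -> dict:
    """Toy helper showing the slicing rule CudaGraphRunner now follows:
    request-level buffers are cut to bs, token-level buffers to
    num_token = bs * num_tokens_per_bs (greater than 1 only for speculative modes)."""
    num_token = bs * num_tokens_per_bs
    return {
        "input_ids": buffers["input_ids"][:num_token],          # token-level
        "out_cache_loc": buffers["out_cache_loc"][:num_token],  # token-level
        "positions": buffers["positions"][:num_token],          # token-level
        "req_pool_indices": buffers["req_pool_indices"][:bs],   # request-level
        "seq_lens": buffers["seq_lens"][:bs],                   # request-level
    }


# Example with max_bs = 8 and plain decode (one token per request).
# At replay time the token count is recovered as forward_batch.input_ids.numel().
bufs = {
    "input_ids": torch.zeros(8, dtype=torch.int32),
    "out_cache_loc": torch.zeros(8, dtype=torch.int32),
    "positions": torch.zeros(8, dtype=torch.int64),
    "req_pool_indices": torch.zeros(8, dtype=torch.int32),
    "seq_lens": torch.ones(8, dtype=torch.int32),
}
sliced = slice_capture_buffers(bufs, bs=4, num_tokens_per_bs=1)
assert sliced["input_ids"].numel() == 4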
python/sglang/srt/model_executor/forward_batch_info.py  (+1 -1)

@@ -96,7 +96,7 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DRAFT_EXTEND

     def is_cuda_graph(self):
-        return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY)
+        return self == ForwardMode.DECODE or self == ForwardMode.TARGET_VERIFY

     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
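For reference, a cut-down stand-in for the predicate touched here: both DECODE and TARGET_VERIFY batches are eligible for CUDA-graph replay, and the membership test is rewritten as two explicit equality checks. The enum below is illustrative only; its members other than the ones named in the diff are assumptions.

from enum import IntEnum, auto


class ForwardModeSketch(IntEnum):
    """Minimal stand-in for sglang's ForwardMode, enough to show is_cuda_graph."""
    EXTEND = auto()
    DECODE = auto()
    TARGET_VERIFY = auto()
    DRAFT_EXTEND = auto()

    def is_cuda_graph(self):
        # After the change: explicit equality checks instead of a tuple membership test.
        return self == ForwardModeSketch.DECODE or self == ForwardModeSketch.TARGET_VERIFY


assert ForwardModeSketch.TARGET_VERIFY.is_cuda_graph()
assert not ForwardModeSketch.EXTEND.is_cuda_graph()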