sglang / Commits / 6d0fa73e

Commit 6d0fa73e (unverified), authored Oct 17, 2024 by Lianmin Zheng, committed by GitHub on Oct 17, 2024.

Simplify flashinfer utilities (#1704)

Parent: 9e0dac1a

Showing 8 changed files with 391 additions and 337 deletions (+391 −337).
Files changed:

  python/sglang/srt/layers/attention/__init__.py                  +27    −5
  python/sglang/srt/layers/attention/double_sparsity_backend.py    +2    −2
  python/sglang/srt/layers/attention/flashinfer_backend.py        +352   −83
  python/sglang/srt/layers/attention/flashinfer_utils.py           +0   −237
  python/sglang/srt/layers/attention/triton_backend.py             +2    −2
  python/sglang/srt/managers/schedule_batch.py                     +4    −4
  python/sglang/srt/model_executor/forward_batch_info.py           +1    −4
  python/sglang/srt/model_executor/model_runner.py                 +3    −0
python/sglang/srt/layers/attention/__init__.py  (view file @ 6d0fa73e)

 from abc import ABC, abstractmethod

+import torch
 from torch import nn

 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
...
@@ -18,13 +19,13 @@ class AttentionBackend(ABC):
         raise NotImplementedError()

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()

     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()
...
@@ -33,17 +34,38 @@ class AttentionBackend(ABC):
         """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
         raise NotImplementedError()

-    def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(q, k, v, layer, forward_batch)
         else:
             return self.forward_extend(q, k, v, layer, forward_batch)

-    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
         """Run a forward for decode."""
         raise NotImplementedError()

-    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
         """Run a forward for extend."""
         raise NotImplementedError()
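For context, here is a minimal stand-alone sketch of the dispatch pattern used by the re-typed AttentionBackend.forward(): decode batches are routed to forward_decode() and everything else to forward_extend(). The DummyForwardMode, DummyBatch, and EchoBackend names below are invented for illustration and are not part of sglang.

    # Illustration only; mirrors the dispatch in AttentionBackend.forward()
    import torch
    from torch import nn


    class DummyForwardMode:
        def __init__(self, decode: bool):
            self._decode = decode

        def is_decode(self) -> bool:
            return self._decode


    class DummyBatch:
        def __init__(self, decode: bool):
            self.forward_mode = DummyForwardMode(decode)


    class EchoBackend:
        """Toy backend: forward() routes to decode or extend based on the batch mode."""

        def forward(self, q, k, v, layer: nn.Module, forward_batch):
            if forward_batch.forward_mode.is_decode():
                return self.forward_decode(q, k, v, layer, forward_batch)
            else:
                return self.forward_extend(q, k, v, layer, forward_batch)

        def forward_decode(self, q, k, v, layer, forward_batch):
            return "decode", q.shape

        def forward_extend(self, q, k, v, layer, forward_batch):
            return "extend", q.shape


    backend = EchoBackend()
    q = k = v = torch.zeros(4, 8)
    print(backend.forward(q, k, v, nn.Identity(), DummyBatch(decode=True)))   # ("decode", ...)
    print(backend.forward(q, k, v, nn.Identity(), DummyBatch(decode=False)))  # ("extend", ...)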
python/sglang/srt/layers/attention/double_sparsity_backend.py  (view file @ 6d0fa73e)

...
@@ -134,7 +134,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         self.forward_metadata = (
             self.cuda_graph_start_loc,
...
@@ -144,7 +144,7 @@ class DoubleSparseAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
...
python/sglang/srt/layers/attention/flashinfer_backend.py  (view file @ 6d0fa73e)

...
@@ -7,18 +7,17 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """

+from enum import Enum, auto
 from typing import TYPE_CHECKING

 import torch
 import torch.nn as nn
+import triton
+import triton.language as tl

 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_utils import (
-    WrapperDispatch,
-    update_flashinfer_indices,
-)
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_flashinfer_available

 if TYPE_CHECKING:
...
@@ -34,13 +33,18 @@ if is_flashinfer_available():
     from flashinfer.decode import _grouped_size_compiled_for_decode_kernels


+class WrapperDispatch(Enum):
+    SLIDING_WINDOW = auto()
+    CROSS_ATTENTION = auto()
+
+
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""

     def __init__(self, model_runner: ModelRunner):
         super().__init__()
         self.model_runner = model_runner

         # Parse constants
         if not _grouped_size_compiled_for_decode_kernels(
             model_runner.model_config.num_attention_heads // model_runner.tp_size,
             model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
...
@@ -48,27 +52,43 @@ class FlashInferAttnBackend(AttentionBackend):
             self.decode_use_tensor_cores = True
         else:
             self.decode_use_tensor_cores = False

-        self.workspace_buffer = torch.empty(
-            global_config.flashinfer_workspace_size,
-            dtype=torch.uint8,
-            device="cuda",
-        )
         self.max_context_len = model_runner.model_config.context_len

         assert not (
             model_runner.sliding_window_size is not None
             and model_runner.has_cross_attention
         ), "Sliding window and cross attention are not supported together"

-        self.num_wrappers = 1
-        self.dispatch_reason = None
         if model_runner.sliding_window_size is not None:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.SLIDING_WINDOW
         elif model_runner.has_cross_attention:
             self.num_wrappers = 2
             self.dispatch_reason = WrapperDispatch.CROSS_ATTENTION
+        else:
+            self.num_wrappers = 1
+            self.dispatch_reason = None

+        # Allocate buffers
+        self.workspace_buffer = torch.empty(
+            global_config.flashinfer_workspace_size,
+            dtype=torch.uint8,
+            device=model_runner.device,
+        )
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = [
+            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
+            for _ in range(self.num_wrappers)
+        ]
+        self.kv_last_page_len = torch.ones(
+            (max_bs,), dtype=torch.int32, device=model_runner.device
+        )
+        self.qo_indptr = [
+            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
+            for _ in range(self.num_wrappers)
+        ]

+        # Create wrappers
         # NOTE: we do not use ragged attention when there are multiple wrappers
         self.prefill_wrapper_ragged = (
             BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
...
@@ -92,26 +112,23 @@ class FlashInferAttnBackend(AttentionBackend):
             )
         )

+        # Create indices updater
+        self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self)
+        self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill(model_runner, self)
+
+        # Other metadata
         self.forward_metadata = None
         self.cuda_graph_metadata = {}

-    def _get_wrapper_idx(self, layer: nn.Module):
-        if self.num_wrappers == 1:
-            return 0
-
-        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
-            return layer.sliding_window_size == -1
-        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
-            return layer.is_cross_attention
-
-        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
-
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
-            prefix_lens = None
-            use_ragged = False
-            extend_no_prefix = False
-            total_num_tokens = None
+            self.indices_updater_decode.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+            )
+            self.forward_metadata = (self.decode_wrappers,)
         else:
             prefix_lens = forward_batch.extend_prefix_lens
...
@@ -123,48 +140,32 @@ class FlashInferAttnBackend(AttentionBackend):
             ):
                 use_ragged = True
-                total_num_tokens = torch.sum(forward_batch.seq_lens).item()

             extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()

-            update_flashinfer_indices(
-                forward_batch.forward_mode,
-                self.model_runner,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                prefix_lens,
-                use_ragged=use_ragged,
-            )
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                prefix_lens,
+                use_ragged,
+            )

-        self.forward_metadata = (
-            use_ragged,
-            extend_no_prefix,
-            total_num_tokens,
-            self.decode_wrappers,
-        )
+            self.forward_metadata = (
+                use_ragged,
+                extend_no_prefix,
+            )

     def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_kv_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.model_runner.model_config.context_len,),
+        cuda_graph_kv_indices = torch.zeros(
+            (max_bs * self.max_context_len,),
             dtype=torch.int32,
             device="cuda",
         )
-        self.cuda_graph_kv_last_page_len = torch.ones(
-            (max_bs,), dtype=torch.int32, device="cuda"
-        )
-
-        # NOTE: the buffers are always in the form of list
-        self.cuda_graph_kv_indptr = [self.cuda_graph_kv_indptr] + [
-            self.cuda_graph_kv_indptr.clone() for _ in range(self.num_wrappers - 1)
-        ]
-        self.cuda_graph_kv_indices = [self.cuda_graph_kv_indices] + [
-            self.cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
-        ]
+        self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [
+            cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
+        ]

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         decode_wrappers = []
         for i in range(self.num_wrappers):
...
@@ -174,35 +175,21 @@ class FlashInferAttnBackend(AttentionBackend):
                     "NHD",
                     use_cuda_graph=True,
                     use_tensor_cores=self.decode_use_tensor_cores,
-                    paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
+                    paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1],
                     paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                    paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
+                    paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs],
                 )
             )

-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices,
-            seq_lens,
-            None,
-            decode_wrappers,
-        )
+        self.indices_updater_decode.update(req_pool_indices, seq_lens, decode_wrappers)

         self.cuda_graph_metadata[bs] = decode_wrappers
-        self.forward_metadata = (False, False, None, decode_wrappers)
+        self.forward_metadata = (decode_wrappers,)

     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
-        update_flashinfer_indices(
-            ForwardMode.DECODE,
-            self.model_runner,
-            req_pool_indices[:bs],
-            seq_lens[:bs],
-            None,
-            self.cuda_graph_metadata[bs],
+        self.indices_updater_decode.update(
+            req_pool_indices[:bs], seq_lens[:bs], self.cuda_graph_metadata[bs]
         )

     def get_cuda_graph_seq_len_fill_value(self):
...
@@ -213,7 +200,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self._get_wrapper_idx(layer)
         ]
-        use_ragged, extend_no_prefix, _, _ = self.forward_metadata
+        use_ragged, extend_no_prefix = self.forward_metadata

         if not use_ragged:
             if k is not None:
...
@@ -259,7 +246,7 @@ class FlashInferAttnBackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
-        decode_wrapper = self.forward_metadata[-1][self._get_wrapper_idx(layer)]
+        decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]

         if k is not None:
             assert v is not None
...
@@ -275,3 +262,285 @@ class FlashInferAttnBackend(AttentionBackend):
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+    def _get_wrapper_idx(self, layer: nn.Module):
+        if self.num_wrappers == 1:
+            return 0
+
+        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            return layer.sliding_window_size == -1
+        if self.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            return layer.is_cross_attention
+
+        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")
+
+
+class FlashInferIndicesUpdaterDecode:
+    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+        # Constants
+        self.num_qo_heads = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            model_runner.tp_size
+        )
+        self.head_dim = model_runner.model_config.head_dim
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
+        self.sliding_window_size = model_runner.sliding_window_size
+
+        # Buffers and wrappers
+        self.kv_indptr = attn_backend.kv_indptr
+        self.kv_last_page_len = attn_backend.kv_last_page_len
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.decode_wrappers = attn_backend.decode_wrappers
+
+        # Dispatch
+        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            self.update = self.update_sliding_window
+        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            self.update = self.update_cross_attention
+        else:
+            assert attn_backend.num_wrappers == 1
+            self.update = self.update_single_wrapper
+
+    def update_single_wrapper(self, req_pool_indices, seq_lens, decode_wrappers=None):
+        decode_wrappers = decode_wrappers or self.decode_wrappers
+        self.call_begin_forward(
+            decode_wrappers[0], req_pool_indices, seq_lens, self.kv_indptr[0], None
+        )
+
+    def update_sliding_window(self, req_pool_indices, seq_lens, decode_wrappers=None):
+        decode_wrappers = decode_wrappers or self.decode_wrappers
+
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # Sliding window attention
+                paged_kernel_lens = torch.minimum(
+                    # TODO: replace this with clamp
+                    seq_lens,
+                    torch.tensor(self.sliding_window_size + 1),
+                )
+            else:
+                # Full attention
+                paged_kernel_lens = seq_lens
+
+            kv_start_idx = seq_lens - paged_kernel_lens
+
+            self.call_begin_forward(
+                decode_wrappers[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens,
+                self.kv_indptr[wrapper_id],
+                kv_start_idx,
+            )
+
+    def update_cross_attention(self):
+        raise NotImplementedError()
+
+    def call_begin_forward(
+        self, wrapper, req_pool_indices, paged_kernel_lens, kv_indptr, kv_start_idx
+    ):
+        bs = len(req_pool_indices)
+        kv_indptr = kv_indptr[: bs + 1]
+        # TODO: optimize the blocking call on kv_indptr[-1]
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            kv_start_idx,
+            kv_indices,
+            self.max_context_len,
+        )
+
+        wrapper.end_forward()
+        wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            data_type=self.data_type,
+            q_data_type=self.q_data_type,
+        )
+
+
+class FlashInferIndicesUpdaterPrefill:
+    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+        # Constants
+        self.num_qo_heads = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            model_runner.tp_size
+        )
+        self.head_dim = model_runner.model_config.head_dim
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
+        self.sliding_window_size = model_runner.sliding_window_size
+
+        # Buffers and wrappers
+        self.kv_indptr = attn_backend.kv_indptr
+        self.kv_last_page_len = attn_backend.kv_last_page_len
+        self.qo_indptr = attn_backend.qo_indptr
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.wrapper_ragged = attn_backend.prefill_wrapper_ragged
+        self.wrappers_paged = attn_backend.prefill_wrappers_paged
+
+        # Dispatch
+        if attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
+            self.update = self.update_sliding_window
+        elif attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
+            self.update = self.update_cross_attention
+        else:
+            assert attn_backend.num_wrappers == 1
+            self.update = self.update_single_wrapper
+
+    def update_single_wrapper(self, req_pool_indices, seq_lens, prefix_lens, use_ragged):
+        if use_ragged:
+            paged_kernel_lens = prefix_lens
+        else:
+            paged_kernel_lens = seq_lens
+
+        self.call_begin_forward(
+            self.wrapper_ragged,
+            self.wrappers_paged[0],
+            req_pool_indices,
+            paged_kernel_lens,
+            seq_lens,
+            prefix_lens,
+            None,
+            self.kv_indptr[0],
+            self.qo_indptr[0],
+            use_ragged,
+        )
+
+    def update_sliding_window(self, req_pool_indices, seq_lens, prefix_lens, use_ragged):
+        for wrapper_id in range(2):
+            if wrapper_id == 0:
+                # window attention use paged only
+                paged_kernel_lens = torch.minimum(
+                    seq_lens,
+                    torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
+                )
+            else:
+                # full attention
+                paged_kernel_lens = seq_lens
+
+            kv_start_idx = seq_lens - paged_kernel_lens
+
+            self.call_begin_forward(
+                self.wrapper_ragged,
+                self.wrappers_paged[wrapper_id],
+                req_pool_indices,
+                paged_kernel_lens,
+                seq_lens,
+                prefix_lens,
+                kv_start_idx,
+                self.kv_indptr[wrapper_id],
+                self.qo_indptr[wrapper_id],
+                use_ragged,
+            )
+
+    def update_cross_attention(self):
+        raise NotImplementedError()
+
+    def call_begin_forward(
+        self,
+        wrapper_ragged,
+        wrapper_paged,
+        req_pool_indices,
+        paged_kernel_lens,
+        seq_lens,
+        prefix_lens,
+        kv_start_idx,
+        kv_indptr,
+        qo_indptr,
+        use_ragged,
+    ):
+        bs = len(req_pool_indices)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            kv_start_idx,
+            kv_indices,
+            self.max_context_len,
+        )
+
+        qo_indptr = qo_indptr[: bs + 1]
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        # extend part
+        if use_ragged:
+            wrapper_ragged.end_forward()
+            wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+            )
+
+        # cached part
+        wrapper_paged.end_forward()
+        wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            self.kv_last_page_len[:bs],
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+        )
+
+
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    kv_indices_ptr,
+    max_context_len: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    req_to_token_ptr += req_pool_index * max_context_len
+    kv_indices_ptr += kv_indices_offset
+
+    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+    st_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = ld_offset < kv_end
+        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+        ld_offset += BLOCK_SIZE
+        st_offset += BLOCK_SIZE
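As a rough illustration of what the new FlashInferIndicesUpdater* classes and the Triton kernel compute before each begin_forward call, here is a pure-PyTorch sketch of building kv_indptr and kv_indices from per-request page counts and the req_to_token table. The toy table and lengths below are invented for the example and are not real sglang data.

    import torch

    max_context_len = 8
    # req_to_token[r, t] = KV-cache slot of token t of request r (toy values)
    req_to_token = torch.arange(4 * max_context_len, dtype=torch.int32).reshape(4, max_context_len)

    req_pool_indices = torch.tensor([2, 0])    # two running requests
    paged_kernel_lens = torch.tensor([5, 3])   # tokens whose cached KV is read
    kv_start_idx = torch.tensor([0, 0])        # nonzero only for the sliding-window wrapper

    bs = len(req_pool_indices)
    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)   # [0, 5, 8]

    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for i in range(bs):
        r = int(req_pool_indices[i])
        s = int(kv_start_idx[i])
        e = s + int(paged_kernel_lens[i])
        lo, hi = int(kv_indptr[i]), int(kv_indptr[i + 1])
        # The Triton kernel performs this copy in BLOCK_SIZE-wide chunks, one program per request.
        kv_indices[lo:hi] = req_to_token[r, s:e]

    print(kv_indptr.tolist(), kv_indices.tolist())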
python/sglang/srt/layers/attention/flashinfer_utils.py  deleted 100644 → 0  (view file @ 9e0dac1a)

from enum import Enum, auto

import torch
import triton
import triton.language as tl


class WrapperDispatch(Enum):
    SLIDING_WINDOW = auto()
    CROSS_ATTENTION = auto()


@triton.jit
def create_flashinfer_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    kv_indices_ptr,
    max_context_len: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)
    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    kv_indices_offset = tl.load(kv_indptr + pid)

    kv_start = 0
    kv_end = 0
    if kv_start_idx:
        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
        kv_end = kv_start
    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

    req_to_token_ptr += req_pool_index * max_context_len
    kv_indices_ptr += kv_indices_offset

    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
    st_offset = tl.arange(0, BLOCK_SIZE)
    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = ld_offset < kv_end
        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
        ld_offset += BLOCK_SIZE
        st_offset += BLOCK_SIZE


class FlashinferUpdater:
    def __init__(
        self,
        forward_mode,
        model_runner,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        decode_wrappers=None,
        use_ragged=False,
    ):
        self.forward_mode = forward_mode
        self.model_runner = model_runner
        self.req_pool_indices = req_pool_indices
        self.seq_lens = seq_lens
        self.prefix_lens = prefix_lens
        self.use_ragged = use_ragged

        self.num_qo_heads = (
            model_runner.model_config.num_attention_heads // model_runner.tp_size
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
            model_runner.tp_size
        )
        self.head_dim = model_runner.model_config.head_dim
        self.batch_size = len(req_pool_indices)

        self.decode_wrappers = (
            decode_wrappers or self.model_runner.attn_backend.decode_wrappers
        )
        self.prefill_wrapper_ragged = (
            self.model_runner.attn_backend.prefill_wrapper_ragged
        )
        self.prefill_wrappers_paged = (
            self.model_runner.attn_backend.prefill_wrappers_paged
        )

        self.kv_last_page_len = torch.ones(
            (self.batch_size,), dtype=torch.int32, device="cuda"
        )

    def _update_decode_indices(self, decode_wrapper):
        assert not isinstance(decode_wrapper, list)
        decode_wrapper.end_forward()
        decode_wrapper.begin_forward(
            self.kv_indptr,
            self.kv_indices,
            self.kv_last_page_len,
            self.num_qo_heads,
            self.num_kv_heads,
            self.head_dim,
            1,
            data_type=self.model_runner.kv_cache_dtype,
            q_data_type=self.model_runner.dtype,
        )

    def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
        assert not isinstance(paged_wrapper, list)
        assert not isinstance(ragged_wrapper, list)

        # extend part
        qo_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)

        if self.use_ragged:
            ragged_wrapper.end_forward()
            ragged_wrapper.begin_forward(
                qo_indptr,
                qo_indptr,
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
            )

        # cached part
        paged_wrapper.end_forward()
        paged_wrapper.begin_forward(
            qo_indptr,
            self.kv_indptr,
            self.kv_indices,
            self.kv_last_page_len,
            self.num_qo_heads,
            self.num_kv_heads,
            self.head_dim,
            1,
        )

    def _get_indices(self, dispatch_reason: WrapperDispatch = None, wrapper_id=0):
        if dispatch_reason is None:
            if self.use_ragged:
                paged_kernel_lens = self.prefix_lens
            else:
                paged_kernel_lens = self.seq_lens
            self.kv_start_idx = None
        elif dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
            if wrapper_id == 0:
                # window attention use paged only
                if self.forward_mode.is_decode():
                    paged_kernel_lens = torch.minimum(
                        self.seq_lens,
                        torch.tensor(self.model_runner.sliding_window_size + 1),
                    )
                else:
                    paged_kernel_lens = torch.minimum(
                        self.seq_lens,
                        torch.tensor(self.model_runner.sliding_window_size)
                        + self.seq_lens
                        - self.prefix_lens,
                    )
            else:
                # full attention
                paged_kernel_lens = self.seq_lens
            self.kv_start_idx = self.seq_lens - paged_kernel_lens

        self.kv_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
        self.kv_indices = torch.empty(
            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
        )

        create_flashinfer_kv_indices_triton[(self.batch_size,)](
            self.model_runner.req_to_token_pool.req_to_token,
            self.req_pool_indices,
            paged_kernel_lens,
            self.kv_indptr,
            self.kv_start_idx,
            self.kv_indices,
            self.model_runner.req_to_token_pool.req_to_token.size(1),
        )

    def _update_indicess_single_wrapper(self):
        self._get_indices()

        if self.forward_mode.is_decode():
            self._update_decode_indices(self.decode_wrappers[0])
        else:
            self._update_extend_indices(
                self.prefill_wrapper_ragged,
                self.prefill_wrappers_paged[0],
            )

    def _update_indices_cross_attention(self):
        pass

    def _update_indices_sliding_window(self):
        assert self.use_ragged is False
        for wrapper_id in range(2):
            self._get_indices(WrapperDispatch.SLIDING_WINDOW, wrapper_id)
            if self.forward_mode.is_decode():
                self._update_decode_indices(self.decode_wrappers[wrapper_id])
            else:
                self._update_extend_indices(
                    None,
                    self.prefill_wrappers_paged[wrapper_id],
                )


def update_flashinfer_indices(
    forward_mode,
    model_runner,
    req_pool_indices,
    seq_lens,
    prefix_lens,
    decode_wrappers=None,
    use_ragged=False,
):
    updater = FlashinferUpdater(
        forward_mode,
        model_runner,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        decode_wrappers,
        use_ragged,
    )

    dispatch_reason = model_runner.attn_backend.dispatch_reason

    if dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
        updater._update_indices_sliding_window()
    elif dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
        updater._update_indices_cross_attention()
    else:
        assert model_runner.attn_backend.num_wrappers == 1
        updater._update_indicess_single_wrapper()
python/sglang/srt/layers/attention/triton_backend.py  (view file @ 6d0fa73e)

...
@@ -81,7 +81,7 @@ class TritonAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         self.forward_metadata = (
             self.cuda_graph_start_loc,
...
@@ -91,7 +91,7 @@ class TritonAttnBackend(AttentionBackend):
         )

     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
...
python/sglang/srt/managers/schedule_batch.py  (view file @ 6d0fa73e)

...
@@ -744,7 +744,6 @@ class ScheduleBatch:
         self.forward_mode = ForwardMode.DECODE
         self.input_ids = self.output_ids
-        self.seq_lens.add_(1)
         self.output_ids = None
         if self.sampling_info.penalizer_orchestrator:
             self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
...
@@ -755,9 +754,10 @@ class ScheduleBatch:
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)

-        self.req_to_token_pool.req_to_token[
-            self.req_pool_indices, self.seq_lens - 1
-        ] = self.out_cache_loc
+        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
+            self.out_cache_loc
+        )
+        self.seq_lens.add_(1)

     def filter_batch(
         self,
...
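A tiny sketch of the reordered update in ScheduleBatch (all tensors below are made up): the newly allocated cache slot is now written at column seq_lens first, and seq_lens is incremented afterwards, so the write lands on the same cells that the old "add_(1) first, then write at seq_lens - 1" order targeted.

    import torch

    req_to_token = torch.zeros(2, 6, dtype=torch.int64)   # toy request -> token-slot table
    req_pool_indices = torch.tensor([0, 1])
    seq_lens = torch.tensor([3, 1])
    out_cache_loc = torch.tensor([101, 202])              # newly allocated KV slots

    # New order: write the slot of the token being decoded at column seq_lens ...
    req_to_token[req_pool_indices, seq_lens] = out_cache_loc
    # ... and only then advance the sequence lengths.
    seq_lens.add_(1)

    print(req_to_token)  # 101 lands at (0, 3) and 202 at (1, 1)
    print(seq_lens)      # tensor([4, 2])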
python/sglang/srt/model_executor/forward_batch_info.py  (view file @ 6d0fa73e)

...
@@ -134,9 +134,7 @@ class ForwardBatch:
         )

         # Init position information
-        if ret.forward_mode.is_decode():
-            ret.positions = (ret.seq_lens - 1).to(torch.int64)
-        else:
+        if not ret.forward_mode.is_decode():
             ret.positions = torch.tensor(
                 np.concatenate(
                     [
...
@@ -164,7 +162,6 @@ class ForwardBatch:
         ret.req_to_token_pool = model_runner.req_to_token_pool
         ret.token_to_kv_pool = model_runner.token_to_kv_pool
         ret.attn_backend = model_runner.attn_backend
-        model_runner.attn_backend.init_forward_metadata(ret)

         # Init lora information
         if model_runner.server_args.lora_paths is not None:
...
python/sglang/srt/model_executor/model_runner.py  (view file @ 6d0fa73e)

...
@@ -551,11 +551,14 @@ class ModelRunner:
         ):
             return self.cuda_graph_runner.replay(forward_batch)

+        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
+        self.attn_backend.init_forward_metadata(forward_batch)
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )

     def forward_extend(self, forward_batch: ForwardBatch):
+        self.attn_backend.init_forward_metadata(forward_batch)
         if self.is_generation:
             return self.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
...
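For the decode path now handled inside ModelRunner.forward_decode, a small sketch (toy values only) of the positions computation that runs right before init_forward_metadata: each request's query position is the 0-indexed position of the token it is about to produce, i.e. its current length minus one.

    import torch

    seq_lens = torch.tensor([5, 9, 2])          # toy per-request lengths
    positions = (seq_lens - 1).to(torch.int64)
    print(positions)                            # tensor([4, 8, 1])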