sglang · Commit 619bb6dd (Unverified)
Authored Sep 30, 2024 by Liangsheng Yin; committed by GitHub on Sep 30, 2024

Dispatch flashinfer wrappers (#1550)

Parent: b88ea90d
Showing 2 changed files with 76 additions and 96 deletions (+76, -96)

  python/sglang/srt/layers/attention/flashinfer_backend.py  (+61, -85)
  python/sglang/srt/layers/attention/flashinfer_utils.py    (+15, -11)
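In short, the commit replaces the `sliding_window_size is None` special cases with a uniform list of wrappers plus a per-layer index: the backend keeps `prefill_wrappers_paged` and `decode_wrappers` of length `num_wrappers` and picks an entry with `_get_wrapper_idx(layer)`. Below is a minimal, runnable sketch of that dispatch idea; the `FakeLayer`/`FakeBackend` classes and the stand-in strings are illustrative only, not the real sglang or flashinfer types.

```python
# Illustrative sketch only: strings stand in for flashinfer's
# BatchDecodeWithPagedKVCacheWrapper objects, FakeLayer for nn.Module layers.

class FakeLayer:
    def __init__(self, sliding_window_size: int):
        # -1 means full attention; any other value is a sliding-window layer
        self.sliding_window_size = sliding_window_size


class FakeBackend:
    def __init__(self, sliding_window_size=None):
        # Two wrappers when the model uses sliding-window attention
        # (one for sliding window, one for full attention), otherwise one.
        self.num_wrappers = 2 if sliding_window_size is not None else 1
        self.decode_wrappers = [f"decode_wrapper_{i}" for i in range(self.num_wrappers)]

    def _get_wrapper_idx(self, layer: FakeLayer):
        if self.num_wrappers == 1:
            return 0
        # A bool indexes a Python list as 0/1: sliding-window layers -> 0,
        # full-attention layers (sliding_window_size == -1) -> 1.
        return layer.sliding_window_size == -1


backend = FakeBackend(sliding_window_size=4096)
print(backend.decode_wrappers[backend._get_wrapper_idx(FakeLayer(4096))])  # decode_wrapper_0
print(backend.decode_wrappers[backend._get_wrapper_idx(FakeLayer(-1))])    # decode_wrapper_1
```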
python/sglang/srt/layers/attention/flashinfer_backend.py (view file @ 619bb6dd)
@@ -53,29 +53,27 @@ class FlashInferAttnBackend(AttentionBackend):
             device="cuda",
         )
-        if model_runner.sliding_window_size is None:
-            self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-                self.workspace_buffer, "NHD"
-            )
-            self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                self.workspace_buffer, "NHD"
-            )
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.workspace_buffer,
-                "NHD",
-                use_tensor_cores=self.decode_use_tensor_cores,
-            )
+        if model_runner.sliding_window_size is not None:
+            self.num_wrappers = 2
         else:
-            # Two wrappers: one for sliding window attention and one for full attention.
-            # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
-            self.prefill_wrapper_ragged = None
-            self.prefill_wrapper_paged = []
-            self.decode_wrapper = []
-            for _ in range(2):
-                self.prefill_wrapper_paged.append(
-                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
-                )
-                self.decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
+            self.num_wrappers = 1
+
+        # NOTE: we do not use ragged attention when there are multiple wrappers
+        self.prefill_wrapper_ragged = (
+            BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
+            if self.num_wrappers == 1
+            else None
+        )
+
+        # Two wrappers: one for sliding window attention and one for full attention.
+        # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
+        self.prefill_wrappers_paged = []
+        self.decode_wrappers = []
+        for _ in range(self.num_wrappers):
+            self.prefill_wrappers_paged.append(
+                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+            )
+            self.decode_wrappers.append(
+                BatchDecodeWithPagedKVCacheWrapper(
+                    self.workspace_buffer,
+                    "NHD",
@@ -86,6 +84,13 @@ class FlashInferAttnBackend(AttentionBackend):
         self.forward_metadata = None
         self.cuda_graph_metadata = {}

+    def _get_wrapper_idx(self, layer: nn.Module):
+        if self.num_wrappers == 1:
+            return 0
+
+        # TODO: make sure the idx is related to sliding window size
+        return layer.sliding_window_size == -1
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
             prefix_lens = None
@@ -99,7 +104,7 @@ class FlashInferAttnBackend(AttentionBackend):
             use_ragged = False
             if (
                 torch.sum(forward_batch.seq_lens).item() >= 4096
-                and self.model_runner.sliding_window_size is None
+                and self.num_wrappers == 1
             ):
                 use_ragged = True
@@ -119,7 +124,7 @@ class FlashInferAttnBackend(AttentionBackend):
             use_ragged,
             extend_no_prefix,
             total_num_tokens,
-            self.decode_wrapper,
+            self.decode_wrappers,
         )

     def init_cuda_graph_state(self, max_bs: int):
@@ -135,33 +140,20 @@ class FlashInferAttnBackend(AttentionBackend):
             (max_bs,), dtype=torch.int32, device="cuda"
         )

-        if self.model_runner.sliding_window_size is not None:
-            self.cuda_graph_kv_indptr = [
-                self.cuda_graph_kv_indptr,
-                self.cuda_graph_kv_indptr.clone(),
-            ]
-            self.cuda_graph_kv_indices = [
-                self.cuda_graph_kv_indices,
-                self.cuda_graph_kv_indices.clone(),
-            ]
+        # NOTE: the buffers are always in the form of list
+        self.cuda_graph_kv_indptr = [self.cuda_graph_kv_indptr] + [
+            self.cuda_graph_kv_indptr.clone() for _ in range(self.num_wrappers - 1)
+        ]
+        self.cuda_graph_kv_indices = [self.cuda_graph_kv_indices] + [
+            self.cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
+        ]

     def init_forward_metadata_capture_cuda_graph(
         self, bs: int, req_pool_indices, seq_lens
     ):
-        if self.model_runner.sliding_window_size is None:
-            decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.workspace_buffer,
-                "NHD",
-                use_cuda_graph=True,
-                use_tensor_cores=self.decode_use_tensor_cores,
-                paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[: bs + 1],
-                paged_kv_indices_buffer=self.cuda_graph_kv_indices,
-                paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
-            )
-        else:
-            decode_wrapper = []
-            for i in range(2):
-                decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
+        decode_wrappers = []
+        for i in range(self.num_wrappers):
+            decode_wrappers.append(
+                BatchDecodeWithPagedKVCacheWrapper(
+                    self.workspace_buffer,
+                    "NHD",
@@ -169,9 +161,7 @@ class FlashInferAttnBackend(AttentionBackend):
                     use_tensor_cores=self.decode_use_tensor_cores,
                     paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
                     paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                    paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[
-                        :bs
-                    ],
+                    paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
                 )
             )
@@ -181,12 +171,12 @@ class FlashInferAttnBackend(AttentionBackend):
             req_pool_indices,
             seq_lens,
             None,
-            decode_wrapper,
+            decode_wrappers,
         )

-        self.cuda_graph_metadata[bs] = decode_wrapper
+        self.cuda_graph_metadata[bs] = decode_wrappers

-        self.forward_metadata = (False, False, None, decode_wrapper)
+        self.forward_metadata = (False, False, None, decode_wrappers)

     def init_forward_metadata_replay_cuda_graph(
         self, bs: int, req_pool_indices, seq_lens
@@ -204,17 +194,11 @@ class FlashInferAttnBackend(AttentionBackend):
             return 0

     def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
-        if not isinstance(self.prefill_wrapper_paged, list):
-            prefill_wrapper_paged = self.prefill_wrapper_paged
-        else:
-            if layer.sliding_window_size != -1:
-                prefill_wrapper_paged = self.prefill_wrapper_paged[0]
-            else:
-                prefill_wrapper_paged = self.prefill_wrapper_paged[1]
-        use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
-            self.forward_metadata
-        )
+        prefill_wrapper_paged = self.prefill_wrappers_paged[
+            self._get_wrapper_idx(layer)
+        ]
+        use_ragged, extend_no_prefix, _, _ = self.forward_metadata

         if not use_ragged:
             if k is not None:
@@ -260,15 +244,7 @@ class FlashInferAttnBackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
-        use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
-            self.forward_metadata
-        )
-        if isinstance(decode_wrapper, list):
-            if layer.sliding_window_size != -1:
-                decode_wrapper = decode_wrapper[0]
-            else:
-                decode_wrapper = decode_wrapper[1]
+        decode_wrapper = self.forward_metadata[-1][self._get_wrapper_idx(layer)]

         if k is not None:
             assert v is not None
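A detail worth calling out from `init_cuda_graph_state` above: the CUDA-graph `kv_indptr`/`kv_indices` buffers are now always lists with one entry per wrapper, built by cloning the base buffer `num_wrappers - 1` times. A small sketch of that expansion pattern under simplified assumptions (plain Python lists instead of torch tensors, with `copy()` standing in for `tensor.clone()`):

```python
# Sketch of the "buffers are always a list" pattern from init_cuda_graph_state.

def expand_to_wrapper_list(base_buffer, num_wrappers):
    # The first wrapper keeps the existing buffer; extra wrappers get independent copies.
    return [base_buffer] + [base_buffer.copy() for _ in range(num_wrappers - 1)]


cuda_graph_kv_indptr = [0] * 9  # stand-in for torch.zeros(max_bs + 1, dtype=torch.int32)
buffers = expand_to_wrapper_list(cuda_graph_kv_indptr, num_wrappers=2)

assert len(buffers) == 2
assert buffers[0] is cuda_graph_kv_indptr      # shared with the single-wrapper path
assert buffers[1] is not cuda_graph_kv_indptr  # per-wrapper clone
```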
python/sglang/srt/layers/attention/flashinfer_utils.py (view file @ 619bb6dd)
@@ -47,7 +47,7 @@ class FlashinferUpdater:
         req_pool_indices,
         seq_lens,
         prefix_lens,
-        decode_wrapper=None,
+        decode_wrappers=None,
         use_ragged=False,
     ):
         self.forward_mode = forward_mode
@@ -66,14 +66,14 @@ class FlashinferUpdater:
         self.head_dim = model_runner.model_config.head_dim
         self.batch_size = len(req_pool_indices)

-        self.decode_wrapper = (
-            decode_wrapper or self.model_runner.attn_backend.decode_wrapper
+        self.decode_wrappers = (
+            decode_wrappers or self.model_runner.attn_backend.decode_wrappers
         )
         self.prefill_wrapper_ragged = (
             self.model_runner.attn_backend.prefill_wrapper_ragged
         )
-        self.prefill_wrapper_paged = (
-            self.model_runner.attn_backend.prefill_wrapper_paged
+        self.prefill_wrappers_paged = (
+            self.model_runner.attn_backend.prefill_wrappers_paged
         )

         self.kv_last_page_len = torch.ones(
@@ -142,6 +142,7 @@ class FlashinferUpdater:
         )

     def _update_decode_indices(self, decode_wrapper):
+        assert not isinstance(decode_wrapper, list)
         decode_wrapper.end_forward()
         decode_wrapper.begin_forward(
             self.kv_indptr,
@@ -156,6 +157,9 @@ class FlashinferUpdater:
         )

     def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
+        assert not isinstance(paged_wrapper, list)
+        assert not isinstance(ragged_wrapper, list)
+
         # extend part
         qo_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
@@ -189,11 +193,11 @@ class FlashinferUpdater:
         self._init_indices_no_sliding_window()

         if self.forward_mode.is_decode():
-            self._update_decode_indices(self.decode_wrapper)
+            self._update_decode_indices(self.decode_wrappers[0])
         else:
             self._update_extend_indices(
                 self.prefill_wrapper_ragged,
-                self.prefill_wrapper_paged,
+                self.prefill_wrappers_paged[0],
             )

     def update_indices_sliding_window(self):
@@ -202,11 +206,11 @@ class FlashinferUpdater:
         for wrapper_id in range(2):
             self._init_indices_sliding_window(wrapper_id)

             if self.forward_mode.is_decode():
-                self._update_decode_indices(self.decode_wrapper[wrapper_id])
+                self._update_decode_indices(self.decode_wrappers[wrapper_id])
             else:
                 self._update_extend_indices(
                     None,
-                    self.prefill_wrapper_paged[wrapper_id],
+                    self.prefill_wrappers_paged[wrapper_id],
                 )
@@ -216,7 +220,7 @@ def update_flashinfer_indices(
     req_pool_indices,
     seq_lens,
     prefix_lens,
-    decode_wrapper=None,
+    decode_wrappers=None,
     use_ragged=False,
 ):
     updater = FlashinferUpdater(
@@ -225,7 +229,7 @@ def update_flashinfer_indices(
         req_pool_indices,
         seq_lens,
         prefix_lens,
-        decode_wrapper,
+        decode_wrappers,
         use_ragged,
     )
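On the utils side, `_update_decode_indices` and `_update_extend_indices` now assert that they receive a single wrapper rather than a list, so the calling update paths are the ones that index into `decode_wrappers` and `prefill_wrappers_paged` (index 0 without sliding windows, `wrapper_id` with them). A hedged sketch of that caller/callee contract, with stand-in strings instead of the real flashinfer wrappers:

```python
def _update_decode_indices(decode_wrapper):
    # The helper works on exactly one wrapper; passing the whole list is a bug.
    assert not isinstance(decode_wrapper, list)
    return f"updated {decode_wrapper}"


# Without sliding windows there is a single wrapper and the caller uses index 0;
# with sliding windows the caller loops over wrapper_id (stand-in strings here).
decode_wrappers = ["sliding_window_wrapper", "full_attention_wrapper"]
for wrapper_id in range(len(decode_wrappers)):
    print(_update_decode_indices(decode_wrappers[wrapper_id]))
```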