Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
14cb544d
Unverified
Commit
14cb544d
authored
Aug 15, 2024
by
Ying Sheng
Committed by
GitHub
Aug 15, 2024
Browse files
[Fix] fix flashinfer usage for window attention (#1107)
parent
e86b1ccb
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
18 deletions
+12
-18
python/sglang/srt/layers/radix_attention.py
python/sglang/srt/layers/radix_attention.py
+1
-4
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+6
-8
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+5
-6
No files found.
python/sglang/srt/layers/radix_attention.py
View file @
14cb544d
...
@@ -120,12 +120,9 @@ class RadixAttention(nn.Module):
...
@@ -120,12 +120,9 @@ class RadixAttention(nn.Module):
# using two wrappers is unnecessary in the current PR, but are prepared for future PRs
# using two wrappers is unnecessary in the current PR, but are prepared for future PRs
prefill_wrapper_ragged
=
input_metadata
.
flashinfer_prefill_wrapper_ragged
prefill_wrapper_ragged
=
input_metadata
.
flashinfer_prefill_wrapper_ragged
prefill_wrapper_paged
=
input_metadata
.
flashinfer_prefill_wrapper_paged
prefill_wrapper_paged
=
input_metadata
.
flashinfer_prefill_wrapper_paged
if
self
.
sliding_window_size
!=
-
1
:
if
self
.
sliding_window_size
!=
-
1
or
self
.
reuse
:
prefill_wrapper_ragged
=
prefill_wrapper_ragged
[
0
]
prefill_wrapper_paged
=
prefill_wrapper_paged
[
0
]
prefill_wrapper_paged
=
prefill_wrapper_paged
[
0
]
else
:
else
:
if
isinstance
(
prefill_wrapper_ragged
,
list
):
prefill_wrapper_ragged
=
prefill_wrapper_ragged
[
1
]
if
isinstance
(
prefill_wrapper_paged
,
list
):
if
isinstance
(
prefill_wrapper_paged
,
list
):
prefill_wrapper_paged
=
prefill_wrapper_paged
[
1
]
prefill_wrapper_paged
=
prefill_wrapper_paged
[
1
]
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
14cb544d
...
@@ -324,9 +324,11 @@ def update_flashinfer_indices(
...
@@ -324,9 +324,11 @@ def update_flashinfer_indices(
else
:
else
:
kv_last_page_len
=
torch
.
ones
((
batch_size
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
kv_last_page_len
=
torch
.
ones
((
batch_size
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
for
wrapper_id
in
range
(
2
):
for
wrapper_id
in
range
(
2
):
if
flashinfer_use_ragged
:
if
flashinfer_use_ragged
and
wrapper_id
==
1
:
# full attention use ragged+paged
paged_kernel_lens
=
prefix_lens
paged_kernel_lens
=
prefix_lens
else
:
else
:
# window attention use paged only
paged_kernel_lens
=
seq_lens
paged_kernel_lens
=
seq_lens
if
wrapper_id
==
0
and
forward_mode
==
ForwardMode
.
DECODE
:
if
wrapper_id
==
0
and
forward_mode
==
ForwardMode
.
DECODE
:
...
@@ -374,13 +376,9 @@ def update_flashinfer_indices(
...
@@ -374,13 +376,9 @@ def update_flashinfer_indices(
)
)
qo_indptr
[
1
:]
=
torch
.
cumsum
(
seq_lens
-
prefix_lens
,
dim
=
0
)
qo_indptr
[
1
:]
=
torch
.
cumsum
(
seq_lens
-
prefix_lens
,
dim
=
0
)
if
flashinfer_use_ragged
:
if
flashinfer_use_ragged
and
wrapper_id
==
1
:
model_runner
.
flashinfer_prefill_wrapper_ragged
[
model_runner
.
flashinfer_prefill_wrapper_ragged
.
end_forward
()
wrapper_id
model_runner
.
flashinfer_prefill_wrapper_ragged
.
begin_forward
(
].
end_forward
()
model_runner
.
flashinfer_prefill_wrapper_ragged
[
wrapper_id
].
begin_forward
(
qo_indptr
,
qo_indptr
,
qo_indptr
,
qo_indptr
,
num_qo_heads
,
num_qo_heads
,
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
14cb544d
...
@@ -342,15 +342,14 @@ class ModelRunner:
...
@@ -342,15 +342,14 @@ class ModelRunner:
dtype
=
torch
.
uint8
,
dtype
=
torch
.
uint8
,
device
=
"cuda"
,
device
=
"cuda"
,
)
)
self
.
flashinfer_prefill_wrapper_ragged
=
[]
self
.
flashinfer_prefill_wrapper_ragged
=
(
BatchPrefillWithRaggedKVCacheWrapper
(
self
.
flashinfer_workspace_buffer
,
"NHD"
)
)
self
.
flashinfer_prefill_wrapper_paged
=
[]
self
.
flashinfer_prefill_wrapper_paged
=
[]
self
.
flashinfer_decode_wrapper
=
[]
self
.
flashinfer_decode_wrapper
=
[]
for
i
in
range
(
2
):
for
i
in
range
(
2
):
self
.
flashinfer_prefill_wrapper_ragged
.
append
(
BatchPrefillWithRaggedKVCacheWrapper
(
self
.
flashinfer_workspace_buffer
,
"NHD"
)
)
self
.
flashinfer_prefill_wrapper_paged
.
append
(
self
.
flashinfer_prefill_wrapper_paged
.
append
(
BatchPrefillWithPagedKVCacheWrapper
(
BatchPrefillWithPagedKVCacheWrapper
(
self
.
flashinfer_workspace_buffer
,
"NHD"
self
.
flashinfer_workspace_buffer
,
"NHD"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment