sglang · commit a523a3c1

Reduce hardcoded logic of kernel usage (#707)

Unverified commit a523a3c1, authored by Mingyi on Jul 23, 2024 and committed via GitHub on Jul 23, 2024. Parent: 9f94728f
Showing 2 changed files with 23 additions and 13 deletions. The change replaces the hardcoded `total_num_tokens <= 4096` checks scattered through the FlashInfer code paths with a single `use_ragged` flag: it is computed once in `InputMetadata.create` and threaded through `init_flashinfer_args` and `RadixAttention.extend_forward_flashinfer` to choose between the ragged and paged prefill kernels.

- python/sglang/srt/layers/radix_attention.py (+4, -2)
- python/sglang/srt/managers/controller/infer_batch.py (+19, -11)
python/sglang/srt/layers/radix_attention.py
```diff
@@ -85,9 +85,9 @@ class RadixAttention(nn.Module):
         return o
 
     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        self.store_kv_cache(k, v, input_metadata)
+        if not input_metadata.use_ragged:
+            self.store_kv_cache(k, v, input_metadata)
 
-        if input_metadata.total_num_tokens <= 4096:
             o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                 input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
```
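A note on the paged call above: FlashInfer's paged prefill wrapper consumes queries laid out as `[num_tokens, num_heads, head_dim]`, which is exactly what the `q.contiguous().view(...)` reshape produces. A standalone illustration with invented sizes:

```python
import torch

# Invented sizes for illustration; in the layer they come from the TP config.
num_tokens, tp_q_head_num, head_dim = 12, 8, 128

# q arrives flattened as [num_tokens, num_heads * head_dim] ...
q = torch.randn(num_tokens, tp_q_head_num * head_dim)

# ... and is viewed as [num_tokens, num_heads, head_dim] for the wrapper.
q3 = q.contiguous().view(-1, tp_q_head_num, head_dim)
print(q3.shape)  # torch.Size([12, 8, 128])
```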
```diff
@@ -122,6 +122,8 @@ class RadixAttention(nn.Module):
             o, _ = merge_state(o1, s1, o2, s2)
 
+            self.store_kv_cache(k, v, input_metadata)
+
         if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
             torch.cuda.synchronize()
```
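In this branch, `o1, s1` and `o2, s2` are partial attention outputs with their log-sum-exp scores from the ragged pass (new tokens) and the paged pass (cached prefix), and `merge_state` fuses them into attention over the full sequence. The added `store_kv_cache` mirrors the one moved into the non-ragged branch above: since the ragged pass consumes the fresh `k`/`v` tensors rather than the pooled buffer, the write can safely happen after the merge. Below is a rough CPU reference for the merge semantics, assuming `s1`/`s2` are log-sum-exp normalizers (the real FlashInfer kernel is fused on GPU and uses base-2 logs internally):

```python
import torch

def merge_state_reference(o1, s1, o2, s2):
    """CPU sketch of merging two partial attention states.

    o*: [num_tokens, num_heads, head_dim] partial outputs,
    s*: [num_tokens, num_heads] log-sum-exp of the attention scores.
    """
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max)  # softmax mass contributed by part 1
    w2 = torch.exp(s2 - s_max)  # softmax mass contributed by part 2
    o = (o1 * w1[..., None] + o2 * w2[..., None]) / (w1 + w2)[..., None]
    s = s_max + torch.log(w1 + w2)  # normalizer of the merged state
    return o, s

# Quick sanity check on random data.
o1, o2 = torch.randn(4, 8, 16), torch.randn(4, 8, 16)
s1, s2 = torch.randn(4, 8), torch.randn(4, 8)
o, s = merge_state_reference(o1, s1, o2, s2)
print(o.shape, s.shape)  # torch.Size([4, 8, 16]) torch.Size([4, 8])
```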
python/sglang/srt/managers/controller/infer_batch.py
```diff
@@ -726,6 +726,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+    use_ragged: bool = False
 
     @classmethod
     def create(
```
```diff
@@ -741,7 +742,10 @@ class InputMetadata:
         return_logprob=False,
         skip_flashinfer_init=False,
     ):
+        use_ragged = False
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
+                use_ragged = True
             init_flashinfer_args(
                 forward_mode,
                 model_runner,
```
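The kernel-selection heuristic now lives in exactly one place. A minimal sketch of the predicate as a standalone function (the helper name is ours, not the repo's):

```python
import torch

def should_use_ragged(is_decode: bool, seq_lens: torch.Tensor) -> bool:
    # Mirrors the condition above: only prefill/extend batches whose total
    # token count exceeds 4096 take the ragged path; decode never does.
    return (not is_decode) and int(torch.sum(seq_lens)) > 4096

print(should_use_ragged(False, torch.tensor([2048, 2048, 1024])))  # True  (5120 > 4096)
print(should_use_ragged(True, torch.tensor([8192])))               # False (decode)
```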
```diff
@@ -749,6 +753,7 @@ class InputMetadata:
                 seq_lens,
                 prefix_lens,
                 model_runner.flashinfer_decode_wrapper,
+                use_ragged,
             )
 
         batch_size = len(req_pool_indices)
```
```diff
@@ -803,6 +808,7 @@ class InputMetadata:
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
             flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+            use_ragged=use_ragged,
         )
 
         if model_runner.server_args.disable_flashinfer:
```
```diff
@@ -823,6 +829,7 @@ def init_flashinfer_args(
     seq_lens,
     prefix_lens,
     flashinfer_decode_wrapper,
+    use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
```
```diff
@@ -831,10 +838,10 @@ def init_flashinfer_args(
     batch_size = len(req_pool_indices)
     total_num_tokens = int(torch.sum(seq_lens))
 
-    if forward_mode == ForwardMode.DECODE or total_num_tokens <= 4096:
-        paged_kernel_lens = seq_lens
-    else:
+    if use_ragged:
         paged_kernel_lens = prefix_lens
+    else:
+        paged_kernel_lens = seq_lens
 
     kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
```
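`kv_indptr` is a CSR-style offset array: entry `i` marks where request `i`'s KV entries begin in the flattened layout, so switching `paged_kernel_lens` from `seq_lens` to `prefix_lens` restricts the paged kernel to the cached prefixes whenever the ragged path is active. A self-contained CPU illustration of the construction above, with invented lengths:

```python
import torch

# Invented per-request KV lengths for illustration.
paged_kernel_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

batch_size = len(paged_kernel_lens)
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

# Request i owns the slice [kv_indptr[i], kv_indptr[i + 1]).
print(kv_indptr)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
```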
```diff
@@ -867,14 +874,15 @@ def init_flashinfer_args(
         qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
 
-        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-            qo_indptr,
-            qo_indptr,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-        )
+        if use_ragged:
+            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )
 
         # cached part
         model_runner.flashinfer_prefill_wrapper_paged.end_forward()
```
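Two details worth noting in the guarded block: the ragged wrapper is now only (re)initialized when the ragged path will actually run, and `qo_indptr` is passed twice to `begin_forward` because in the ragged pass the new tokens serve as both queries and keys/values (they attend among themselves, while the paged pass covers the cached prefix). The offsets count only each request's uncached suffix; a small illustration with invented lengths:

```python
import torch

# Invented lengths: total tokens per request, and how many are already cached.
seq_lens = torch.tensor([7, 10, 4], dtype=torch.int32)
prefix_lens = torch.tensor([3, 6, 0], dtype=torch.int32)

batch_size = len(seq_lens)
qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

# Each request contributes only its new (uncached) tokens to the ragged pass.
print(qo_indptr)  # tensor([ 0,  4,  8, 12], dtype=torch.int32)
```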