sglang · Commits · a523a3c1
"tests/test_training.py" did not exist on "89f2011cede1a899cc3f7e4b47ae1178d3a6e68c"
Reduce hardcoded logic of kernel usage (#707)

Unverified commit, authored Jul 23, 2024 by Mingyi; committed via GitHub on Jul 23, 2024. Parent commit: 9f94728f.
Showing 2 changed files, with 23 additions and 13 deletions:

python/sglang/srt/layers/radix_attention.py (+4, -2)
python/sglang/srt/managers/controller/infer_batch.py (+19, -11)
python/sglang/srt/layers/radix_attention.py
```diff
@@ -85,9 +85,9 @@ class RadixAttention(nn.Module):
         return o

     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        self.store_kv_cache(k, v, input_metadata)
+        if not input_metadata.use_ragged:
+            self.store_kv_cache(k, v, input_metadata)

-        if input_metadata.total_num_tokens <= 4096:
             o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                 input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
```
```diff
@@ -122,6 +122,8 @@ class RadixAttention(nn.Module):
             o, _ = merge_state(o1, s1, o2, s2)

+            self.store_kv_cache(k, v, input_metadata)
+
         if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
             torch.cuda.synchronize()
```
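In the ragged branch above, `merge_state` is the FlashInfer op that fuses the attention result computed over the new (ragged) tokens with the result computed over the paged prefix cache. A reference sketch of the math it performs, using standard softmax-merging algebra; this is a conceptual model, not FlashInfer's actual kernel:

```python
import torch

def merge_state_reference(o1, s1, o2, s2):
    """Merge two partial attention results via their log-sum-exp scores.

    o1, o2: (num_tokens, num_heads, head_dim) partial attention outputs
    s1, s2: (num_tokens, num_heads) log-sum-exp of the attention logits
    Each partial output is re-weighted by its share of the combined softmax
    normalizer, as if attention had run over the union of both key sets.
    """
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max)  # relative normalizer mass of part 1
    w2 = torch.exp(s2 - s_max)  # relative normalizer mass of part 2
    o = (w1.unsqueeze(-1) * o1 + w2.unsqueeze(-1) * o2) / (w1 + w2).unsqueeze(-1)
    s = s_max + torch.log(w1 + w2)  # combined log-sum-exp
    return o, s
```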
python/sglang/srt/managers/controller/infer_batch.py
```diff
@@ -726,6 +726,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+    use_ragged: bool = False

     @classmethod
     def create(
```
```diff
@@ -741,7 +742,10 @@ class InputMetadata:
         return_logprob=False,
         skip_flashinfer_init=False,
     ):
+        use_ragged = False
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
+                use_ragged = True
             init_flashinfer_args(
                 forward_mode,
                 model_runner,
```
```diff
@@ -749,6 +753,7 @@ class InputMetadata:
                 seq_lens,
                 prefix_lens,
                 model_runner.flashinfer_decode_wrapper,
+                use_ragged,
             )

         batch_size = len(req_pool_indices)
```
```diff
@@ -803,6 +808,7 @@ class InputMetadata:
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
             flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+            use_ragged=use_ragged,
         )

         if model_runner.server_args.disable_flashinfer:
```
```diff
@@ -823,6 +829,7 @@ def init_flashinfer_args(
     seq_lens,
     prefix_lens,
     flashinfer_decode_wrapper,
+    use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
```
```diff
@@ -831,10 +838,10 @@ def init_flashinfer_args(
     batch_size = len(req_pool_indices)
     total_num_tokens = int(torch.sum(seq_lens))

-    if forward_mode == ForwardMode.DECODE or total_num_tokens <= 4096:
-        paged_kernel_lens = seq_lens
-    else:
+    if use_ragged:
         paged_kernel_lens = prefix_lens
+    else:
+        paged_kernel_lens = seq_lens

     kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
```
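The `kv_indptr` built above is a CSR-style offset array: request `i` owns slots `kv_indptr[i]:kv_indptr[i+1]` of the flattened KV pool, so switching `paged_kernel_lens` between `prefix_lens` and `seq_lens` is what points the paged kernel at either the cached prefix only (ragged mode) or the full sequences. A toy illustration on CPU with made-up lengths (the real code allocates on "cuda"):

```python
import torch

# Made-up per-request KV lengths; in the diff these come from
# prefix_lens (use_ragged) or seq_lens (otherwise).
paged_kernel_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
batch_size = paged_kernel_lens.numel()

kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
print(kv_indptr)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
```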
```diff
@@ -867,14 +874,15 @@ def init_flashinfer_args(
         qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

-        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-            qo_indptr,
-            qo_indptr,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-        )
+        if use_ragged:
+            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )

         # cached part
         model_runner.flashinfer_prefill_wrapper_paged.end_forward()
```
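`qo_indptr` mirrors `kv_indptr` but over `seq_lens - prefix_lens`, i.e. only the tokens being extended; passing it twice to `begin_forward` makes the ragged kernel attend the new tokens against themselves, leaving the cached prefix to the paged wrapper. A toy example with hypothetical lengths:

```python
import torch

seq_lens = torch.tensor([10, 7, 4], dtype=torch.int32)    # hypothetical totals
prefix_lens = torch.tensor([6, 7, 0], dtype=torch.int32)  # hypothetical cached parts
batch_size = seq_lens.numel()

qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
print(qo_indptr)  # tensor([0, 4, 4, 8], dtype=torch.int32): 4, 0, 4 new tokens
```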