Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
cf99eab7
Unverified
Commit
cf99eab7
authored
Jul 23, 2024
by
Ying Sheng
Committed by
GitHub
Jul 23, 2024
Browse files
Fix flashinfer (#700)
parent
9fdea29d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
20 deletions
+34
-20
python/sglang/srt/layers/radix_attention.py
python/sglang/srt/layers/radix_attention.py
+32
-19
python/sglang/srt/managers/controller/infer_batch.py
python/sglang/srt/managers/controller/infer_batch.py
+2
-1
No files found.
python/sglang/srt/layers/radix_attention.py
View file @
cf99eab7
...
...
@@ -85,32 +85,45 @@ class RadixAttention(nn.Module):
return
o
def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
    """Run the extend (prefill) attention pass through flashinfer kernels.

    Small batches take a single paged-KV prefill over the full sequences;
    large batches split the work into causal ragged attention among the new
    tokens plus non-causal attention over the cached prefix, merged by LSE.

    Args:
        q, k, v: query/key/value activations for the new tokens of this
            batch, flattened to ``(num_tokens, num_heads * head_dim)``.
        input_metadata: per-batch metadata carrying the flashinfer prefill
            wrappers, the paged KV pool, and scheduling counters.

    Returns:
        Attention output with the same ``(num_tokens, tp_q_head_num * head_dim)``
        layout as ``q``.
    """
    # Write this layer's new K/V into the paged pool up front. The
    # small-batch fast path below attends over the *full* sequences straight
    # from the paged cache, so the new tokens must already be resident there.
    # NOTE(review): assumes store_kv_cache only copies k/v into the pool and
    # does not mutate them — confirm against its definition.
    self.store_kv_cache(k, v, input_metadata)

    # All kernels consume a (num_tokens, num_heads, head_dim) view; build the
    # query view once instead of repeating it per call site.
    q_view = q.contiguous().view(-1, self.tp_q_head_num, self.head_dim)

    if input_metadata.total_num_tokens <= 4096:
        # Small batches: one causal paged-KV prefill over prefix + new tokens.
        # NOTE(review): this 4096 cutoff must agree with the one used when
        # choosing paged_kernel_lens in init_flashinfer_args — keep in sync.
        o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
            q_view,
            input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
            causal=True,
            sm_scale=self.scaling,
            logits_soft_cap=self.logit_cap,
        )
    else:
        # Large batches: causal self-attention among the new tokens only,
        # reading k/v directly (ragged layout), returning output + LSE.
        o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
            q_view,
            k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
            v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
            causal=True,
            sm_scale=self.scaling,
            logits_soft_cap=self.logit_cap,
        )

        if input_metadata.extend_no_prefix:
            # No cached prefix: the ragged result is already the full answer.
            o = o1
        else:
            # Non-causal cross-attention of the new tokens over the cached
            # prefix, then merge the two partial results via their LSEs.
            o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
                q_view,
                input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                causal=False,
                sm_scale=self.scaling,
                logits_soft_cap=self.logit_cap,
            )
            o, _ = merge_state(o1, s1, o2, s2)

    # On very large batches, synchronize so the host does not queue kernels
    # arbitrarily far ahead of the device.
    if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
        torch.cuda.synchronize()

    return o.view(-1, self.tp_q_head_num * self.head_dim)
...
...
python/sglang/srt/managers/controller/infer_batch.py
View file @
cf99eab7
...
...
@@ -829,8 +829,9 @@ def init_flashinfer_args(
num_kv_heads
=
model_runner
.
model_config
.
get_num_kv_heads
(
model_runner
.
tp_size
)
head_dim
=
model_runner
.
model_config
.
head_dim
batch_size
=
len
(
req_pool_indices
)
total_num_tokens
=
int
(
torch
.
sum
(
seq_lens
))
if
forward_mode
==
ForwardMode
.
DECODE
:
if
forward_mode
==
ForwardMode
.
DECODE
or
total_num_tokens
<=
4096
:
paged_kernel_lens
=
seq_lens
else
:
paged_kernel_lens
=
prefix_lens
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment