Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
b4d34cd3
Unverified
Commit
b4d34cd3
authored
Mar 03, 2025
by
yinfan98
Committed by
GitHub
Mar 02, 2025
Browse files
Fix nightly-test CI (#3826)
parent
728e175f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
2 deletions
+13
-2
python/sglang/srt/layers/attention/flashinfer_backend.py
python/sglang/srt/layers/attention/flashinfer_backend.py
+13
-2
No files found.
python/sglang/srt/layers/attention/flashinfer_backend.py
View file @
b4d34cd3
...
@@ -422,7 +422,10 @@ class FlashInferAttnBackend(AttentionBackend):
...
@@ -422,7 +422,10 @@ class FlashInferAttnBackend(AttentionBackend):
else
:
else
:
o2
,
s2
=
prefill_wrapper_paged
.
forward_return_lse
(
o2
,
s2
=
prefill_wrapper_paged
.
forward_return_lse
(
q
.
contiguous
().
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
q
.
contiguous
().
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
forward_batch
.
token_to_kv_pool
.
get_kv_buffer
(
layer
.
layer_id
),
self
.
_to_dtype
(
forward_batch
.
token_to_kv_pool
.
get_kv_buffer
(
layer
.
layer_id
),
q
.
dtype
,
),
causal
=
False
,
causal
=
False
,
sm_scale
=
layer
.
scaling
,
sm_scale
=
layer
.
scaling
,
logits_soft_cap
=
layer
.
logit_cap
,
logits_soft_cap
=
layer
.
logit_cap
,
...
@@ -464,7 +467,9 @@ class FlashInferAttnBackend(AttentionBackend):
...
@@ -464,7 +467,9 @@ class FlashInferAttnBackend(AttentionBackend):
o
=
decode_wrapper
.
forward
(
o
=
decode_wrapper
.
forward
(
q
.
contiguous
().
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
q
.
contiguous
().
view
(
-
1
,
layer
.
tp_q_head_num
,
layer
.
head_dim
),
forward_batch
.
token_to_kv_pool
.
get_kv_buffer
(
layer
.
layer_id
),
self
.
_to_dtype
(
forward_batch
.
token_to_kv_pool
.
get_kv_buffer
(
layer
.
layer_id
),
q
.
dtype
),
sm_scale
=
layer
.
scaling
,
sm_scale
=
layer
.
scaling
,
logits_soft_cap
=
layer
.
logit_cap
,
logits_soft_cap
=
layer
.
logit_cap
,
k_scale
=
layer
.
k_scale
,
k_scale
=
layer
.
k_scale
,
...
@@ -473,6 +478,12 @@ class FlashInferAttnBackend(AttentionBackend):
...
@@ -473,6 +478,12 @@ class FlashInferAttnBackend(AttentionBackend):
return
o
.
view
(
-
1
,
layer
.
tp_q_head_num
*
layer
.
head_dim
)
return
o
.
view
(
-
1
,
layer
.
tp_q_head_num
*
layer
.
head_dim
)
def
_to_dtype
(
self
,
kv_tuple
,
dtype
):
if
kv_tuple
[
0
].
dtype
!=
dtype
:
return
tuple
(
t
.
to
(
dtype
)
for
t
in
kv_tuple
)
else
:
return
kv_tuple
def
_get_wrapper_idx
(
self
,
layer
:
RadixAttention
):
def
_get_wrapper_idx
(
self
,
layer
:
RadixAttention
):
if
self
.
num_wrappers
==
1
:
if
self
.
num_wrappers
==
1
:
return
0
return
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment