Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e678cc71
Unverified
Commit
e678cc71
authored
Sep 05, 2025
by
Mahmoud Ashraf
Committed by
GitHub
Sep 05, 2025
Browse files
[bugfix]: use correct cache location for cross attention in torch native backend (#8622)
parent
4efe844a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
6 deletions
+12
-6
python/sglang/srt/layers/attention/torch_native_backend.py
python/sglang/srt/layers/attention/torch_native_backend.py
+12
-6
No files found.
python/sglang/srt/layers/attention/torch_native_backend.py
View file @
e678cc71
...
...
@@ -193,10 +193,13 @@ class TorchNativeAttnBackend(AttentionBackend):
else
:
o
=
torch
.
empty_like
(
q
)
if
layer
.
is_cross_attention
:
cache_loc
=
forward_batch
.
encoder_out_cache_loc
else
:
cache_loc
=
forward_batch
.
out_cache_loc
if
save_kv_cache
:
forward_batch
.
token_to_kv_pool
.
set_kv_buffer
(
layer
,
forward_batch
.
out_cache_loc
,
k
,
v
)
forward_batch
.
token_to_kv_pool
.
set_kv_buffer
(
layer
,
cache_loc
,
k
,
v
)
use_gqa
=
layer
.
tp_q_head_num
!=
layer
.
tp_k_head_num
...
...
@@ -241,10 +244,13 @@ class TorchNativeAttnBackend(AttentionBackend):
else
:
o
=
torch
.
empty_like
(
q
)
if
layer
.
is_cross_attention
:
cache_loc
=
forward_batch
.
encoder_out_cache_loc
else
:
cache_loc
=
forward_batch
.
out_cache_loc
if
save_kv_cache
:
forward_batch
.
token_to_kv_pool
.
set_kv_buffer
(
layer
,
forward_batch
.
out_cache_loc
,
k
,
v
)
forward_batch
.
token_to_kv_pool
.
set_kv_buffer
(
layer
,
cache_loc
,
k
,
v
)
use_gqa
=
layer
.
tp_q_head_num
!=
layer
.
tp_k_head_num
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment