Unverified Commit e678cc71 authored by Mahmoud Ashraf, committed by GitHub

[bugfix]: use correct cache location for cross attention in torch native backend (#8622)

parent 4efe844a
@@ -193,10 +193,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

+        if layer.is_cross_attention:
+            cache_loc = forward_batch.encoder_out_cache_loc
+        else:
+            cache_loc = forward_batch.out_cache_loc
+
         if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)

         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
...@@ -241,10 +244,13 @@ class TorchNativeAttnBackend(AttentionBackend): ...@@ -241,10 +244,13 @@ class TorchNativeAttnBackend(AttentionBackend):
else: else:
o = torch.empty_like(q) o = torch.empty_like(q)
if layer.is_cross_attention:
cache_loc = forward_batch.encoder_out_cache_loc
else:
cache_loc = forward_batch.out_cache_loc
if save_kv_cache: if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
layer, forward_batch.out_cache_loc, k, v
)
use_gqa = layer.tp_q_head_num != layer.tp_k_head_num use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
......
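The fix hinges on one distinction: before this change the backend always wrote K/V to `forward_batch.out_cache_loc`, so cross-attention K/V computed from encoder tokens landed in the decoder's cache slots; now cross-attention layers write to `forward_batch.encoder_out_cache_loc` instead. The sketch below illustrates only that selection logic in isolation; it is not the actual SGLang backend code, and `ForwardBatchStub` and `select_cache_loc` are hypothetical names used for illustration.

```python
# Minimal sketch of the cache-location selection introduced by this commit.
# ForwardBatchStub is a hypothetical stand-in for the real ForwardBatch.
from dataclasses import dataclass

import torch


@dataclass
class ForwardBatchStub:
    # KV-pool slots for the decoder tokens in this batch (self-attention K/V).
    out_cache_loc: torch.Tensor
    # KV-pool slots for the encoder tokens (cross-attention K/V).
    encoder_out_cache_loc: torch.Tensor


def select_cache_loc(is_cross_attention: bool, fb: ForwardBatchStub) -> torch.Tensor:
    # Cross-attention caches encoder-side K/V; self-attention caches decoder-side K/V.
    return fb.encoder_out_cache_loc if is_cross_attention else fb.out_cache_loc


if __name__ == "__main__":
    fb = ForwardBatchStub(
        out_cache_loc=torch.tensor([0, 1, 2]),
        encoder_out_cache_loc=torch.tensor([100, 101, 102]),
    )
    print(select_cache_loc(True, fb))   # tensor([100, 101, 102])
    print(select_cache_loc(False, fb))  # tensor([0, 1, 2])
```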