Unverified Commit e678cc71 authored by Mahmoud Ashraf, committed by GitHub

[bugfix]: use correct cache location for cross attention in torch native backend (#8622)

parent 4efe844a
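The two hunks below apply the same fix at both call sites: when the layer is a cross-attention layer, its K/V entries must be written to the encoder cache slots (encoder_out_cache_loc) rather than the decoder slots (out_cache_loc). A minimal sketch of the corrected selection logic, using only names that appear in the diff; the helper _select_kv_cache_loc is illustrative and not part of the patch:

    def _select_kv_cache_loc(layer, forward_batch):
        # Cross-attention layers cache the encoder's K/V, so their entries live
        # in the encoder cache slots; self-attention layers use the regular slots.
        if layer.is_cross_attention:
            return forward_batch.encoder_out_cache_loc
        return forward_batch.out_cache_loc

    # The save path then writes to the selected location:
    #   cache_loc = _select_kv_cache_loc(layer, forward_batch)
    #   if save_kv_cache:
    #       forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)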
@@ -193,10 +193,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
+        if layer.is_cross_attention:
+            cache_loc = forward_batch.encoder_out_cache_loc
+        else:
+            cache_loc = forward_batch.out_cache_loc
         if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
@@ -241,10 +244,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
+        if layer.is_cross_attention:
+            cache_loc = forward_batch.encoder_out_cache_loc
+        else:
+            cache_loc = forward_batch.out_cache_loc
         if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num