Unverified commit bda41c70, authored by Song, committed by GitHub
Browse files

hotfix attn alibi wo head mapping (#496)


Co-authored-by: oliveryuan <oliveryuan@basemind.com>
parent 453bafb9
......@@ -199,6 +199,7 @@ def run_single_query_cached_kv_attention(
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda")
scale = float(1.0 / (head_size**0.5))
output = torch.empty(num_tokens,
......@@ -211,6 +212,7 @@ def run_single_query_cached_kv_attention(
query,
key_cache,
value_cache,
head_mapping,
scale,
block_tables,
context_lens,
......
......@@ -408,6 +408,7 @@ class PagedAttentionWithALiBi(PagedAttention):
query,
key_cache,
value_cache,
self.head_mapping,
self.scale,
input_metadata.block_tables,
input_metadata.context_lens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment