Unverified Commit d7a1c6d6 authored by Tao Peng, committed by GitHub
Browse files

Fix paged attention testing. (#495)


Signed-off-by: Tao Peng <jiankeng.pt@alibaba-inc.com>
parent 7d5a155e
...@@ -164,6 +164,7 @@ def run_single_query_cached_kv_attention( ...@@ -164,6 +164,7 @@ def run_single_query_cached_kv_attention(
block_size: int, block_size: int,
num_blocks: int, num_blocks: int,
dtype: torch.dtype, dtype: torch.dtype,
num_kv_heads: int = None,
) -> None: ) -> None:
qkv = torch.empty(num_tokens, qkv = torch.empty(num_tokens,
3, 3,
...@@ -202,6 +203,14 @@ def run_single_query_cached_kv_attention( ...@@ -202,6 +203,14 @@ def run_single_query_cached_kv_attention(
head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda") head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda")
scale = float(1.0 / (head_size**0.5)) scale = float(1.0 / (head_size**0.5))
num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
assert num_heads % num_kv_heads == 0
num_queries_per_kv = num_heads // num_kv_heads
head_mapping = torch.repeat_interleave(
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
num_queries_per_kv)
output = torch.empty(num_tokens, output = torch.empty(num_tokens,
num_heads, num_heads,
head_size, head_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment