Commit 8ab073b4 authored by PanZezhong's avatar PanZezhong Committed by wooway777
Browse files

issue/1033 fix mha_varlen test

parent f6496d44
import os
import sys
import infinicore
import torch
import infinicore
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from framework import (
......@@ -14,13 +15,17 @@ from framework import (
TestCase,
)
# Test Cases: (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds)
# Test Cases: (num_heads, num_kv_heads, head_size, block_size, [request_batch])
_TEST_CASES_DATA = [
(1, 1, 1, 128, 256, 16, 1),
(1, 4, 4, 128, 256, 16, 4),
(2, 8, 8, 128, 256, 16, 2),
(1, 1, 128, 256, [(250,), (7,)]),
(4, 4, 128, 256, [(250,), (7,)]),
(1, 1, 128, 256, [(260, 73), (1, 1)]),
(8, 2, 128, 256, [(250,), (7,)]),
(8, 2, 128, 256, [(260, 73), (1, 1)]),
]
_MAX_SEQUENCE_LENGTH = 8192
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
infinicore.bfloat16: {"atol": 2e-2, "rtol": 2e-2},
......@@ -58,24 +63,24 @@ def parse_test_cases():
test_cases = []
for (
num_seqs,
num_heads,
num_kv_heads,
head_size,
block_size,
max_step_len,
num_rounds,
request_batches,
) in _TEST_CASES_DATA:
scale = head_size**-0.5
num_blocks = 512
manager = SimpleCacheManager(num_blocks, block_size)
num_seqs = len(request_batches[0])
kv_lens = torch.zeros(num_seqs, dtype=torch.int32)
persistent_k = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
persistent_v = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
for r in range(num_rounds):
q_lens = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int32)
for r, req in enumerate(request_batches):
assert len(req) == num_seqs, "All requests should have the same length"
q_lens = torch.tensor(req, dtype=torch.int32)
kv_lens = kv_lens + q_lens
total_q_tokens = q_lens.sum().item()
cum_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int32)
......@@ -134,12 +139,6 @@ def parse_test_cases():
set_tensor=padded_tables.clone(),
dtype=infinicore.int32,
),
# TensorSpec.from_tensor(
# kv_lens.shape,
# init_mode=TensorInitializer.MANUAL,
# set_tensor=kv_lens.clone(),
# dtype=infinicore.int64,
# ),
TensorSpec.from_tensor(
cum_seqlens_q.shape,
init_mode=TensorInitializer.MANUAL,
......@@ -155,8 +154,8 @@ def parse_test_cases():
],
kwargs={
"scale": scale,
"max_seqlen_q": max_step_len + num_rounds,
"max_seqlen_k": max_step_len + num_rounds,
"max_seqlen_q": _MAX_SEQUENCE_LENGTH,
"max_seqlen_k": _MAX_SEQUENCE_LENGTH,
},
tolerance=tolerance,
description=f"MHA_Varlen_Round_{r}_{str(dtype).split('.')[-1]}",
......@@ -191,6 +190,15 @@ def ref_paged_attention_multi_turn(
K = torch.stack(keys, dim=0)
V = torch.stack(values, dim=0)
q_heads = cur_q.shape[1]
kv_heads = K.shape[1]
assert q_heads % kv_heads == 0
group_size = q_heads // kv_heads
if group_size > 1:
K = K.repeat_interleave(group_size, dim=1)
V = V.repeat_interleave(group_size, dim=1)
scores = torch.einsum("qhd,khd->hqk", cur_q.float(), K.float()) * scale
mask = torch.full((q_len, total_len), float("-inf"), device=query.device)
for t in range(q_len):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment