Commit 8ab073b4 authored by PanZezhong, committed by wooway777
Browse files

issue/1033 fix mha_varlen test

parent f6496d44
import os import os
import sys import sys
import infinicore
import torch import torch
import infinicore
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from framework import ( from framework import (
...@@ -14,13 +15,17 @@ from framework import ( ...@@ -14,13 +15,17 @@ from framework import (
TestCase, TestCase,
) )
# Test Cases: (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds) # Test Cases: (num_heads, num_kv_heads, head_size, block_size, [request_batch])
_TEST_CASES_DATA = [ _TEST_CASES_DATA = [
(1, 1, 1, 128, 256, 16, 1), (1, 1, 128, 256, [(250,), (7,)]),
(1, 4, 4, 128, 256, 16, 4), (4, 4, 128, 256, [(250,), (7,)]),
(2, 8, 8, 128, 256, 16, 2), (1, 1, 128, 256, [(260, 73), (1, 1)]),
(8, 2, 128, 256, [(250,), (7,)]),
(8, 2, 128, 256, [(260, 73), (1, 1)]),
] ]
_MAX_SEQUENCE_LENGTH = 8192
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
infinicore.float16: {"atol": 1e-2, "rtol": 1e-2}, infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
infinicore.bfloat16: {"atol": 2e-2, "rtol": 2e-2}, infinicore.bfloat16: {"atol": 2e-2, "rtol": 2e-2},
...@@ -58,24 +63,24 @@ def parse_test_cases(): ...@@ -58,24 +63,24 @@ def parse_test_cases():
test_cases = [] test_cases = []
for ( for (
num_seqs,
num_heads, num_heads,
num_kv_heads, num_kv_heads,
head_size, head_size,
block_size, block_size,
max_step_len, request_batches,
num_rounds,
) in _TEST_CASES_DATA: ) in _TEST_CASES_DATA:
scale = head_size**-0.5 scale = head_size**-0.5
num_blocks = 512 num_blocks = 512
manager = SimpleCacheManager(num_blocks, block_size) manager = SimpleCacheManager(num_blocks, block_size)
num_seqs = len(request_batches[0])
kv_lens = torch.zeros(num_seqs, dtype=torch.int32) kv_lens = torch.zeros(num_seqs, dtype=torch.int32)
persistent_k = torch.zeros((num_blocks, num_kv_heads, block_size, head_size)) persistent_k = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
persistent_v = torch.zeros((num_blocks, num_kv_heads, block_size, head_size)) persistent_v = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
for r in range(num_rounds): for r, req in enumerate(request_batches):
q_lens = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int32) assert len(req) == num_seqs, "All requests should have the same length"
q_lens = torch.tensor(req, dtype=torch.int32)
kv_lens = kv_lens + q_lens kv_lens = kv_lens + q_lens
total_q_tokens = q_lens.sum().item() total_q_tokens = q_lens.sum().item()
cum_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int32) cum_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int32)
...@@ -134,12 +139,6 @@ def parse_test_cases(): ...@@ -134,12 +139,6 @@ def parse_test_cases():
set_tensor=padded_tables.clone(), set_tensor=padded_tables.clone(),
dtype=infinicore.int32, dtype=infinicore.int32,
), ),
# TensorSpec.from_tensor(
# kv_lens.shape,
# init_mode=TensorInitializer.MANUAL,
# set_tensor=kv_lens.clone(),
# dtype=infinicore.int64,
# ),
TensorSpec.from_tensor( TensorSpec.from_tensor(
cum_seqlens_q.shape, cum_seqlens_q.shape,
init_mode=TensorInitializer.MANUAL, init_mode=TensorInitializer.MANUAL,
...@@ -155,8 +154,8 @@ def parse_test_cases(): ...@@ -155,8 +154,8 @@ def parse_test_cases():
], ],
kwargs={ kwargs={
"scale": scale, "scale": scale,
"max_seqlen_q": max_step_len + num_rounds, "max_seqlen_q": _MAX_SEQUENCE_LENGTH,
"max_seqlen_k": max_step_len + num_rounds, "max_seqlen_k": _MAX_SEQUENCE_LENGTH,
}, },
tolerance=tolerance, tolerance=tolerance,
description=f"MHA_Varlen_Round_{r}_{str(dtype).split('.')[-1]}", description=f"MHA_Varlen_Round_{r}_{str(dtype).split('.')[-1]}",
...@@ -191,6 +190,15 @@ def ref_paged_attention_multi_turn( ...@@ -191,6 +190,15 @@ def ref_paged_attention_multi_turn(
K = torch.stack(keys, dim=0) K = torch.stack(keys, dim=0)
V = torch.stack(values, dim=0) V = torch.stack(values, dim=0)
q_heads = cur_q.shape[1]
kv_heads = K.shape[1]
assert q_heads % kv_heads == 0
group_size = q_heads // kv_heads
if group_size > 1:
K = K.repeat_interleave(group_size, dim=1)
V = V.repeat_interleave(group_size, dim=1)
scores = torch.einsum("qhd,khd->hqk", cur_q.float(), K.float()) * scale scores = torch.einsum("qhd,khd->hqk", cur_q.float(), K.float()) * scale
mask = torch.full((q_len, total_len), float("-inf"), device=query.device) mask = torch.full((q_len, total_len), float("-inf"), device=query.device)
for t in range(q_len): for t in range(q_len):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment