Commit 1715b204 authored by pengcheng888's avatar pengcheng888
Browse files

issue/608 - 修改rope的测试脚本

parent b2e1f8b7
......@@ -3,10 +3,10 @@ import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import infinicore_tensor_from_torch, is_broadcast
from framework.utils import is_broadcast
from infinicore.nn.functional import RopeAlgo
import infinicore
......@@ -17,11 +17,11 @@ import infinicore
_TEST_CASES_DATA = [
# ntok, num, head_dim, Algo
(1, 1, 64, RopeAlgo.GPT_NEOX),
(5, 32, 64, RopeAlgo.GPT_NEOX),
(1, 1, 128, RopeAlgo.GPT_J),
(10, 1, 64, RopeAlgo.GPT_J),
# bs, seq_len, num, head_dim, Algo
(1, 1, 1, 64, RopeAlgo.GPT_NEOX),
(1, 5, 32, 64, RopeAlgo.GPT_NEOX),
(1, 1, 1, 128, RopeAlgo.GPT_J),
(1, 10, 1, 64, RopeAlgo.GPT_J),
]
# Tolerance configuration
......@@ -43,14 +43,14 @@ def parse_test_cases():
test_cases = []
for data in _TEST_CASES_DATA:
ntok, num, head_dim = data[0], data[1], data[2]
algo = data[3]
bs, seq_len, num, head_dim = data[0], data[1], data[2], data[3]
algo = data[4]
# Determine shapes based on batch dimension
out_shape = (ntok, num, head_dim)
x_shape = (ntok, num, head_dim)
sin_table_shape = (ntok, head_dim // 2)
cos_table_shape = (ntok, head_dim // 2)
out_shape = (bs, seq_len, num, head_dim)
x_shape = (bs, seq_len, num, head_dim)
sin_table_shape = (seq_len, head_dim // 2)
cos_table_shape = (seq_len, head_dim // 2)
# Check if tensors support in-place operations
c_supports_inplace = not is_broadcast(out_shape)
......@@ -151,18 +151,13 @@ class OpTest(BaseOperatorTest):
def infinicore_operator(self, x, sin_table, cos_table, algo, out=None, **kwargs):
"""InfiniCore Rope implementation"""
ntok = x.shape[0]
torch_device = "cpu"
if x.device.type != "cpu":
torch_device = "cuda"
bs, seq_len, num, head_dim = x.shape
# 创建 pos_ids的变量
pos_ids_torch = torch.arange(0, ntok, dtype=torch.int32, device=torch_device)
pos_ids_ref = infinicore_tensor_from_torch(pos_ids_torch)
pos_ids_infini = infinicore.empty(
list(pos_ids_ref.shape), dtype=pos_ids_ref.dtype, device=pos_ids_ref.device
## 创建 pos_ids的变量
cache_position_list = [list(range(0, seq_len)) for i in range(bs)]
pos_ids_infini = infinicore.from_list(
cache_position_list, dtype=infinicore.int64, device=x.device
)
pos_ids_infini.copy_(pos_ids_ref)
# 计算
pos_ids = pos_ids_infini
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment