Unverified Commit 874cc65b authored by pengcheng888's avatar pengcheng888 Committed by GitHub
Browse files

Merge pull request #653 from pengcheng888/issue/652

issue/652- 修改rope函数的测试脚本
parents edc11eb5 ea174701
......@@ -4,6 +4,7 @@ import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
......@@ -153,11 +154,13 @@ class OpTest(BaseOperatorTest):
bs, seq_len, num, head_dim = x.shape
infini_device = x.device
torch_device = torch.device(type=infini_device.type, index=infini_device.index)
## 创建 pos_ids的变量
cache_position_list = [list(range(0, seq_len)) for i in range(bs)]
pos_ids_infini = infinicore.from_list(
cache_position_list, dtype=infinicore.int64, device=x.device
)
pos_ids_torch = torch.tensor(cache_position_list, device=torch_device)
pos_ids_infini = infinicore.from_torch(pos_ids_torch)
# 计算
pos_ids = pos_ids_infini
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment