Commit b82153fe authored by Catheriany's avatar Catheriany
Browse files

issue/150: rope算子测例

parent 4e4d3415
...@@ -27,7 +27,7 @@ DECLARE_INFINIOP_TEST(rope) ...@@ -27,7 +27,7 @@ DECLARE_INFINIOP_TEST(rope)
REGISTER_INFINIOP_TEST(gemm) \ REGISTER_INFINIOP_TEST(gemm) \
REGISTER_INFINIOP_TEST(random_sample) \ REGISTER_INFINIOP_TEST(random_sample) \
REGISTER_INFINIOP_TEST(mul) \ REGISTER_INFINIOP_TEST(mul) \
REGISTER_INFINIOP_TEST(rope) \ REGISTER_INFINIOP_TEST(rope) \
} }
namespace infiniop_test { namespace infiniop_test {
......
from ast import List from ast import List
import numpy as np import numpy as np
import gguf import gguf
from enum import Enum, auto from typing import List
from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides
...@@ -108,15 +109,6 @@ _TEST_CASES_ = [ ...@@ -108,15 +109,6 @@ _TEST_CASES_ = [
_TENSOR_DTYPES_ = [np.float16, np.float32] _TENSOR_DTYPES_ = [np.float16, np.float32]
class Inplace(Enum):
OUT_OF_PLACE = auto()
INPLACE_X = auto()
_INPLACE_ = [
Inplace.OUT_OF_PLACE,
Inplace.INPLACE_X,
]
if __name__ == "__main__": if __name__ == "__main__":
test_writer = InfiniopTestWriter("rope.gguf") test_writer = InfiniopTestWriter("rope.gguf")
...@@ -124,26 +116,23 @@ if __name__ == "__main__": ...@@ -124,26 +116,23 @@ if __name__ == "__main__":
for dtype in _TENSOR_DTYPES_: for dtype in _TENSOR_DTYPES_:
for shape, stride_x, stride_y in _TEST_CASES_: for shape, stride_x, stride_y in _TEST_CASES_:
for inplace in _INPLACE_: x = np.random.rand(*shape).astype(dtype)
x = np.random.rand(*shape).astype(dtype)
if inplace == Inplace.INPLACE_X: y = np.random.rand(*shape).astype(dtype)
y = x
else: pos_ids = np.arange(0, x.shape[0], dtype=np.int32)
y = np.random.rand(*shape).astype(dtype)
sin_table, cos_table = sin_cos_table(pos_ids, x.shape[2], theta=1e5, dtype=dtype)
pos_ids = np.arange(0, x.shape[0], dtype=np.int32)
test_case = RoPETestCase(
sin_table, cos_table = sin_cos_table(pos_ids, x.shape[2], theta=1e5, dtype=dtype) y=y,
x=x,
test_case = RoPETestCase( stride_y=stride_y,
y=y, stride_x=stride_x,
x=x, pos_ids=pos_ids,
stride_y=stride_y, sin_table=sin_table,
stride_x=stride_x, cos_table=cos_table,
pos_ids=pos_ids, )
sin_table=sin_table, test_cases.append(test_case)
cos_table=cos_table,
)
test_cases.append(test_case)
test_writer.add_tests(test_cases) test_writer.add_tests(test_cases)
test_writer.save() test_writer.save()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment