Commit 15bcbdfc authored by Catheriany's avatar Catheriany
Browse files

issue/150: rope算子测例

parent b82153fe
...@@ -46,7 +46,6 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info, ...@@ -46,7 +46,6 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
const Tdata *sin_table, const Tdata *sin_table,
const Tdata *cos_table, const Tdata *cos_table,
cudaStream_t stream) { cudaStream_t stream) {
auto dimx = uint32_t(info.seqlen), auto dimx = uint32_t(info.seqlen),
dimy = uint32_t(info.nhead); dimy = uint32_t(info.nhead);
int nthreads = std::max(int(info.table_dim), block_size); int nthreads = std::max(int(info.table_dim), block_size);
......
...@@ -92,11 +92,15 @@ class RoPETestCase(InfiniopTestCase): ...@@ -92,11 +92,15 @@ class RoPETestCase(InfiniopTestCase):
test_writer.gguf_key("ans"), ans, raw_dtype=gguf.GGMLQuantizationType.F64 test_writer.gguf_key("ans"), ans, raw_dtype=gguf.GGMLQuantizationType.F64
) )
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules if __name__ == "__main__":
_TEST_CASES_ = [ # ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
# (shape, x_strides, y_strides) # (shape, x_strides, y_strides)
((1, 32, 128), None, None), ((1, 32, 128), None, None),
((10, 32, 64), None, None), ((10, 32, 64), None, None),
...@@ -105,12 +109,9 @@ _TEST_CASES_ = [ ...@@ -105,12 +109,9 @@ _TEST_CASES_ = [
((4, 1, 32), gguf_strides(64, 64, 1), None), ((4, 1, 32), gguf_strides(64, 64, 1), None),
((11, 33, 128), None, gguf_strides(8000, 200, 1)), ((11, 33, 128), None, gguf_strides(8000, 200, 1)),
((3, 32, 128), gguf_strides(8000, 200, 1), gguf_strides(7000, 128, 1)), ((3, 32, 128), gguf_strides(8000, 200, 1), gguf_strides(7000, 128, 1)),
] ]
_TENSOR_DTYPES_ = [np.float16, np.float32] _TENSOR_DTYPES_ = [np.float16, np.float32]
if __name__ == "__main__":
test_writer = InfiniopTestWriter("rope.gguf") test_writer = InfiniopTestWriter("rope.gguf")
test_cases = [] test_cases = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment