Commit 03edef48 authored by Graylatzhou's avatar Graylatzhou
Browse files

issue/183 Mul算子CPU&CUDA

parent 9e5d4eba
......@@ -12,6 +12,7 @@
#include "infiniop/ops/global_avg_pool.h"
#include "infiniop/ops/max_pool.h"
#include "infiniop/ops/mlp.h"
#include "infiniop/ops/mul.h"
#include "infiniop/ops/random_sample.h"
#include "infiniop/ops/rearrange.h"
#include "infiniop/ops/relu.h"
......@@ -19,6 +20,5 @@
#include "infiniop/ops/rope.h"
#include "infiniop/ops/swiglu.h"
#include "infiniop/tensor_descriptor.h"
#include "infiniop/ops/mul.h"
#endif // __INFINIOP_API_H__
......@@ -6,12 +6,9 @@ from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_stri
def mul(
a: np.ndarray,
b: np.ndarray,
c: np.ndarray = None,
b: np.ndarray
):
if c is None:
return np.multiply(a, b)
return np.multiply(a, b, out=c)
return np.multiply(a, b)
def random_tensor(shape, dtype):
rate = 1e-3
......@@ -56,7 +53,7 @@ class MulTestCase(InfiniopTestCase):
a_fp64 = self.a.astype(np.float64)
b_fp64 = self.b.astype(np.float64)
ans_fp64 = np.multiply(a_fp64, b_fp64)
ans = mul(self.a, self.b, self.c)
ans = mul(self.a, self.b)
test_writer.add_tensor(
test_writer.gguf_key("ans"), ans, raw_dtype=np_dtype_to_ggml(ans.dtype)
)
......@@ -73,17 +70,17 @@ if __name__ == '__main__':
random_tensor((2, 3), np.float32),
gguf_strides(3, 1),
random_tensor((2, 3), np.float32),
(1, 2),
gguf_strides(1, 2),
random_tensor((2, 3), np.float32),
gguf_strides(3, 1),
),
MulTestCase(
random_tensor((2, 3), np.float16),
(1, 2),
gguf_strides(1, 2),
random_tensor((2, 3), np.float16),
gguf_strides(3, 1),
random_tensor((2, 3), np.float16),
(1, 2),
gguf_strides(1, 2),
),
MulTestCase(
random_tensor((2, 3), np.float64),
......@@ -91,7 +88,7 @@ if __name__ == '__main__':
random_tensor((2, 3), np.float64),
gguf_strides(3, 1),
random_tensor((2, 3), np.float64),
(1, 2),
gguf_strides(1, 2),
),
MulTestCase(
random_tensor((4, 6), np.float16),
......@@ -103,49 +100,49 @@ if __name__ == '__main__':
),
MulTestCase(
random_tensor((1, 2048), np.float16),
(1, 1),
gguf_strides(1, 1),
random_tensor((1, 2048), np.float16),
gguf_strides(2048, 1),
random_tensor((1, 2048), np.float16),
(1, 1),
gguf_strides(1, 1),
),
MulTestCase(
random_tensor((2048, 2048), np.float32),
None,
random_tensor((2048, 2048), np.float32),
(1, 2048),
gguf_strides(1, 2048),
random_tensor((2048, 2048), np.float32),
None,
),
MulTestCase(
random_tensor((2, 4, 2048), np.float16),
(4 * 2048, 2048, 1),
gguf_strides(4 * 2048, 2048, 1),
random_tensor((2, 4, 2048), np.float16),
(1, 2, 2 * 4),
gguf_strides(1, 2, 2 * 4),
random_tensor((2, 4, 2048), np.float16),
(4 * 2048, 2048, 1),
gguf_strides(4 * 2048, 2048, 1),
),
MulTestCase(
random_tensor((2, 4, 2048), np.float32),
(1, 2, 2 * 4),
gguf_strides(1, 2, 2 * 4),
random_tensor((2, 4, 2048), np.float32),
None,
random_tensor((2, 4, 2048), np.float32),
(1, 2, 2 * 4),
gguf_strides(1, 2, 2 * 4),
),
MulTestCase(
random_tensor((2048, 2560), np.float32),
gguf_strides(2560, 1),
random_tensor((2048, 2560), np.float32),
(1, 2048),
gguf_strides(1, 2048),
random_tensor((2048, 2560), np.float32),
gguf_strides(2560, 1),
),
MulTestCase(
random_tensor((4, 48, 64), np.float16),
(64 * 48, 64, 1),
gguf_strides(64 * 48, 64, 1),
random_tensor((4, 48, 64), np.float16),
(1, 4, 4 * 48),
gguf_strides(1, 4, 4 * 48),
random_tensor((4, 48, 64), np.float16),
None
),
......@@ -153,11 +150,10 @@ if __name__ == '__main__':
random_tensor((4, 48, 64), np.float32),
None,
random_tensor((4, 48, 64), np.float32),
(1, 4, 4 * 48),
gguf_strides(1, 4, 4 * 48),
random_tensor((4, 48, 64), np.float32),
(48 * 64, 64, 1),
),
gguf_strides(48 * 64, 64, 1),
)
]
test_writer.add_tests(test_cases)
test_writer.save()
......@@ -240,4 +240,5 @@ if __name__ == "__main__":
for device in get_test_devices(args):
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
\ No newline at end of file
print("\033[92mTest passed!\033[0m")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment