Unverified Commit f1ec6988 authored by pengcheng888's avatar pengcheng888 Committed by GitHub
Browse files

Merge pull request #624 from gongchensu/feature/tensor_matmul

#624 Implement __matmul__ operator for Tensor.
parents f4bf6ac9 79bd83d6
......@@ -101,6 +101,9 @@ class Tensor:
def __add__(self, other):
return infinicore.add(self, other)
def __matmul__(self, other):
return infinicore.matmul(self, other)
def __mul__(self, other):
return infinicore.mul(self, other)
......
......@@ -80,6 +80,34 @@ def test2():
print("abs error: ", torch.abs(ans2_torch_ref - torch_ans2_result).max())
def test3():
"测试infinicore.Tensor的@运算符功能(矩阵乘法)"
shape1 = [2, 3]
shape2 = [3, 4]
x1_torch = torch.rand(shape1, dtype=torch.float32, device="cpu")
x2_torch = torch.rand(shape2, dtype=torch.float32, device="cpu")
x1_infini = infinicore.from_torch(x1_torch.clone())
x2_infini = infinicore.from_torch(x2_torch.clone())
ans_infini = x1_infini @ x2_infini
ans_torch_ref = x1_torch @ x2_torch
print("----------------------------------------")
torch_ans_result = torch.zeros([2, 4], dtype=torch.float32, device="cpu")
torch_ans = infinicore.from_blob(
torch_ans_result.data_ptr(),
[2, 4],
dtype=infinicore.float32,
device=infinicore.device("cpu", 0),
)
torch_ans.copy_(ans_infini)
print("abs error: ", torch.abs(ans_torch_ref - torch_ans_result).max())
if __name__ == "__main__":
# test()
test2()
test3()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment