Unverified Commit 79e6883b authored by pengcheng888's avatar pengcheng888 Committed by GitHub
Browse files

Merge pull request #683 from pengcheng888/issue/680_mm

issue/680-为mulmat添加alpha参数算子
parents cad2d45a 4916ff8a
......@@ -5,7 +5,7 @@
namespace infinicore::op {

// Matrix multiplication scaled by `alpha`: the factor is forwarded as the
// alpha coefficient of the underlying GEMM (see matmul.cc, beta fixed at 0).
// The default of 1.0f preserves the original un-scaled matmul behavior, so
// existing two-argument callers are unaffected.
Tensor matmul(Tensor a, Tensor b, float alpha = 1.0f);

// In-place variant: writes the scaled product into `c`.
void matmul_(Tensor c, Tensor a, Tensor b, float alpha = 1.0f);

} // namespace infinicore::op
......@@ -2,10 +2,10 @@ from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
def matmul(input, other, *, alpha=1.0, out=None):
    """Matrix multiplication of two tensors, scaled by ``alpha``.

    Args:
        input: left-hand tensor.
        other: right-hand tensor.
        alpha: scalar forwarded to the native kernel; the default of 1.0
            keeps the original un-scaled matmul behavior.
        out: optional destination tensor. When given, the result is written
            into it in place and ``out`` is returned.

    Returns:
        A new ``Tensor`` holding the result, or ``out`` when provided.
    """
    if out is None:
        # Out-of-place: wrap the freshly allocated native result.
        return Tensor(_infinicore.matmul(input._underlying, other._underlying, alpha))
    # In-place: the native call fills out's underlying buffer.
    _infinicore.matmul_(out._underlying, input._underlying, other._underlying, alpha)
    return out
......@@ -3,11 +3,11 @@
namespace infinicore::op {
Tensor matmul(Tensor a, Tensor b) {
return gemm(a, b, 1.0f, 0.0f);
Tensor matmul(Tensor a, Tensor b, float alpha) {
return gemm(a, b, alpha, 0.0f);
}
void matmul_(Tensor c, Tensor a, Tensor b) {
Gemm::execute(c, a, b, 1.0f, 0.0f);
void matmul_(Tensor c, Tensor a, Tensor b, float alpha) {
Gemm::execute(c, a, b, alpha, 0.0f);
}
} // namespace infinicore::op
......@@ -13,6 +13,7 @@ inline void bind_matmul(py::module &m) {
&op::matmul,
py::arg("a"),
py::arg("b"),
// Optional scaling factor; defaulting to 1.0f keeps the old 2-arg call valid.
py::arg("alpha") = 1.0f,
R"doc(Matrix multiplication of two tensors.)doc");
m.def("matmul_",
......@@ -20,6 +21,7 @@ inline void bind_matmul(py::module &m) {
py::arg("c"),
py::arg("a"),
py::arg("b"),
// Mirrors the out-of-place binding: optional alpha, defaulting to 1.0f.
py::arg("alpha") = 1.0f,
R"doc(In-place matrix multiplication.)doc");
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment