Commit 4916ff8a authored by pengcheng888's avatar pengcheng888
Browse files

issue/680-为mulmat添加alpha参数算子

parent cad2d45a
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
namespace infinicore::op { namespace infinicore::op {
Tensor matmul(Tensor a, Tensor b); Tensor matmul(Tensor a, Tensor b, float alpha = 1.0f);
void matmul_(Tensor c, Tensor a, Tensor b); void matmul_(Tensor c, Tensor a, Tensor b, float alpha = 1.0f);
} // namespace infinicore::op } // namespace infinicore::op
...@@ -2,10 +2,10 @@ from infinicore.lib import _infinicore ...@@ -2,10 +2,10 @@ from infinicore.lib import _infinicore
from infinicore.tensor import Tensor from infinicore.tensor import Tensor
def matmul(input, other, *, out=None): def matmul(input, other, *, alpha=1.0, out=None):
if out is None: if out is None:
return Tensor(_infinicore.matmul(input._underlying, other._underlying)) return Tensor(_infinicore.matmul(input._underlying, other._underlying, alpha))
_infinicore.matmul_(out._underlying, input._underlying, other._underlying) _infinicore.matmul_(out._underlying, input._underlying, other._underlying, alpha)
return out return out
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
namespace infinicore::op { namespace infinicore::op {
Tensor matmul(Tensor a, Tensor b) { Tensor matmul(Tensor a, Tensor b, float alpha) {
return gemm(a, b, 1.0f, 0.0f); return gemm(a, b, alpha, 0.0f);
} }
void matmul_(Tensor c, Tensor a, Tensor b) { void matmul_(Tensor c, Tensor a, Tensor b, float alpha) {
Gemm::execute(c, a, b, 1.0f, 0.0f); Gemm::execute(c, a, b, alpha, 0.0f);
} }
} // namespace infinicore::op } // namespace infinicore::op
...@@ -13,6 +13,7 @@ inline void bind_matmul(py::module &m) { ...@@ -13,6 +13,7 @@ inline void bind_matmul(py::module &m) {
&op::matmul, &op::matmul,
py::arg("a"), py::arg("a"),
py::arg("b"), py::arg("b"),
py::arg("alpha") = 1.0f,
R"doc(Matrix multiplication of two tensors.)doc"); R"doc(Matrix multiplication of two tensors.)doc");
m.def("matmul_", m.def("matmul_",
...@@ -20,6 +21,7 @@ inline void bind_matmul(py::module &m) { ...@@ -20,6 +21,7 @@ inline void bind_matmul(py::module &m) {
py::arg("c"), py::arg("c"),
py::arg("a"), py::arg("a"),
py::arg("b"), py::arg("b"),
py::arg("alpha") = 1.0f,
R"doc(In-place matrix multiplication.)doc"); R"doc(In-place matrix multiplication.)doc");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment