Unverified Commit 2286cf78 authored by PanZezhong1725, committed by GitHub

Merge pull request #548 from gongchensu/feature/add_mul_python_api

Feature/add mul python api
parents 2d0a83cf a565b363
#pragma once
#include "../device.hpp"
#include "common/op.hpp"
namespace infinicore::op {
// Element-wise multiplication operator: execute(c, a, b) computes c = a * b.
class Mul {
public:
using schema = void (*)(Tensor, Tensor, Tensor);
static void execute(Tensor c, Tensor a, Tensor b);
static common::OpDispatcher<schema> &dispatcher();
};
Tensor mul(Tensor a, Tensor b);
void mul_(Tensor c, Tensor a, Tensor b);
} // namespace infinicore::op
@@ -30,6 +30,7 @@ from infinicore.dtype import (
from infinicore.ops.add import add
from infinicore.ops.attention import attention
from infinicore.ops.matmul import matmul
from infinicore.ops.mul import mul
from infinicore.ops.rearrange import rearrange
from infinicore.tensor import (
Tensor,
@@ -76,6 +77,7 @@ __all__ = [
"add",
"attention",
"matmul",
"mul",
"rearrange",
"empty",
"empty_like",
...
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
def mul(input, other, *, out=None):
    r"""Element-wise multiplication of two tensors."""
    if out is None:
        return Tensor(_infinicore.mul(input._underlying, other._underlying))

    _infinicore.mul_(out._underlying, input._underlying, other._underlying)
    # Mirror torch semantics by returning the provided output tensor.
    return out
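A minimal usage sketch of the new Python entry point (illustration only: `a` and `b` stand in for existing infinicore.Tensor objects of matching shape and dtype created elsewhere, and `empty_like` is assumed to behave like the torch-style helper exported from the package __init__):

import infinicore

# `a` and `b` are placeholders for tensors created with the existing APIs.
c = infinicore.mul(a, b)            # out-of-place: allocates and returns a new Tensor
out = infinicore.empty_like(a)      # assumed torch-style allocation helper
infinicore.mul(a, b, out=out)       # writes the element-wise product into `out`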
#include "infinicore/ops/mul.hpp"
namespace infinicore::op {
common::OpDispatcher<Mul::schema> &Mul::dispatcher() {
static common::OpDispatcher<Mul::schema> dispatcher_;
return dispatcher_;
}
// Dispatch to the implementation registered for the current device type.
void Mul::execute(Tensor c, Tensor a, Tensor b) {
dispatcher().lookup(context::getDevice().getType())(c, a, b);
}
// Out-of-place entry point: allocate the output from a's metadata, then reuse the in-place path.
Tensor mul(Tensor a, Tensor b) {
auto c = Tensor::empty(a->shape(), a->dtype(), a->device());
mul_(c, a, b);
return c;
}
void mul_(Tensor c, Tensor a, Tensor b) {
Mul::execute(c, a, b);
}
} // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/mul.hpp"
#include <infiniop.h>
namespace infinicore::op::mul_impl::infiniop {
// Per-thread cache of infiniop mul descriptors, keyed by a hash of the tensor
// descriptors; evicted entries are destroyed by the callback below.
thread_local common::OpCache<size_t, infiniopMulDescriptor_t> caches(
100, // capacity
[](infiniopMulDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyMulDescriptor(desc));
desc = nullptr;
}
});
void calculate(Tensor c, Tensor a, Tensor b) {
size_t seed = hash_combine(c, b, a);
auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();
auto &cache = caches.getCache(device_type, device_index);
// Look up a cached descriptor for this tensor configuration; create and cache one on miss.
auto desc_opt = cache.get(seed);
infiniopMulDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor(
context::getInfiniopHandle(), &desc,
c->desc(), a->desc(), b->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}
// Query and allocate the workspace required by this descriptor.
size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetMulWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
INFINICORE_CHECK_ERROR(infiniopMul(
desc, workspace->data(), workspace_size,
c->data(), a->data(), b->data(), context::getStream()));
}
// Register this backend with the Mul dispatcher at static-initialization time.
static bool registered = []() {
Mul::dispatcher().registerAll(&calculate, false);
return true;
}();
} // namespace infinicore::op::mul_impl::infiniop
@@ -6,6 +6,7 @@
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/matmul.hpp"
#include "ops/mul.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
#include "ops/silu.hpp"
@@ -20,6 +21,7 @@ inline void bind(py::module &m) {
bind_attention(m);
bind_causal_softmax(m);
bind_matmul(m);
bind_mul(m);
bind_rearrange(m);
bind_rms_norm(m);
bind_silu(m);
...
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/mul.hpp"
namespace py = pybind11;
namespace infinicore::ops {
inline void bind_mul(py::module &m) {
m.def("mul",
&op::mul,
py::arg("a"),
py::arg("b"),
R"doc(Element-wise multiplication of two tensors.)doc");
m.def("mul_",
&op::mul_,
py::arg("c"),
py::arg("a"),
py::arg("b"),
R"doc(In-place element-wise tensor multiplication.)doc");
}
} // namespace infinicore::ops
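For reference, a sketch of how these raw bindings relate to the public wrapper (illustration only; `a`, `b`, and `c` stand in for existing infinicore.Tensor objects):

from infinicore.lib import _infinicore

# The raw bindings operate on the underlying C++ tensors; infinicore.mul handles
# the Tensor <-> _underlying conversion shown here explicitly.
c_raw = _infinicore.mul(a._underlying, b._underlying)          # out-of-place, returns an underlying tensor
_infinicore.mul_(c._underlying, a._underlying, b._underlying)  # in-place: c = a * b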
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (shape, a_strides, b_strides, c_strides)
_TEST_CASES_DATA = [
((13, 4), None, None, None),
((13, 4), (10, 1), (10, 1), (10, 1)),
((13, 4), (0, 1), None, None),
((13, 4, 4), None, None, None),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
((13, 4, 4), (4, 0, 1), (0, 4, 1), None),
((16, 5632), None, None, None),
((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
]
# Data types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# Tolerance
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 0, "rtol": 1e-2},
infinicore.float32: {"atol": 0, "rtol": 1e-3},
infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}
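These entries are meant to be read as allclose-style bounds. A rough illustration of the kind of check implied by the map (the actual comparison is implemented in the test framework and may use torch.testing.assert_close or an equivalent helper):

import torch

a = torch.randn(13, 4, dtype=torch.float16)
b = torch.randn(13, 4, dtype=torch.float16)
expected = torch.mul(a.float(), b.float()).to(torch.float16)  # higher-precision reference
actual = torch.mul(a, b)
ok = torch.allclose(actual, expected, atol=0, rtol=1e-2)      # float16 entry of _TOLERANCE_MAP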
def build_test_cases():
test_cases = []
for data in _TEST_CASES_DATA:
shape = data[0]
a_strides = data[1] if len(data) > 1 else None
b_strides = data[2] if len(data) > 2 else None
c_strides = data[3] if len(data) > 3 else None
a_supports_inplace = not is_broadcast(a_strides)
b_supports_inplace = not is_broadcast(b_strides)
c_supports_inplace = not is_broadcast(c_strides)
for dtype in _TENSOR_DTYPES:
tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
a_spec = TensorSpec.from_tensor(shape, a_strides, dtype)
b_spec = TensorSpec.from_tensor(shape, b_strides, dtype)
c_spec = TensorSpec.from_tensor(shape, c_strides, dtype)
# Out-of-place (return value)
test_cases.append(
TestCase(
inputs=[a_spec, b_spec],
kwargs={},
output_spec=None,
comparison_target=None,
tolerance=tolerance,
description=f"Mul - OUT_OF_PLACE (dtype={dtype})",
)
)
# With explicit output tensor (mul(a, b, out=c))
if c_supports_inplace:
test_cases.append(
TestCase(
inputs=[a_spec, b_spec],
kwargs={},
output_spec=c_spec,
comparison_target="out",
tolerance=tolerance,
description=f"Mul - INPLACE(out) (dtype={dtype})",
)
)
# In-place on first input (mul(a, b, out=a))
if a_supports_inplace:
test_cases.append(
TestCase(
inputs=[a_spec, b_spec],
kwargs={"out": 0},
output_spec=None,
comparison_target=0,
tolerance=tolerance,
description=f"Mul - INPLACE(a) (dtype={dtype})",
)
)
# In-place on second input (mul(a, b, out=b))
if b_supports_inplace:
test_cases.append(
TestCase(
inputs=[a_spec, b_spec],
kwargs={"out": 1},
output_spec=None,
comparison_target=1,
tolerance=tolerance,
description=f"Mul - INPLACE(b) (dtype={dtype})",
)
)
return test_cases
_TEST_CASES = build_test_cases()
class OpTest(BaseOperatorTest):
"""Mul test with simplified test case parsing"""
def __init__(self):
super().__init__("Mul")
def get_test_cases(self):
return _TEST_CASES
def torch_operator(self, a, b, out=None, **kwargs):
return torch.mul(a, b, out=out)
def infinicore_operator(self, a, b, out=None, **kwargs):
try:
return infinicore.mul(a, b, out=out)
except AttributeError as exc:
raise NotImplementedError("InfiniCore mul operator not available") from exc
def main():
"""Main entry point"""
runner = GenericTestRunner(OpTest)
runner.run_and_exit()
if __name__ == "__main__":
main()