Unverified commit 4ee91092 authored by pengcheng888, committed by GitHub

Merge pull request #534 from gongchensu/feature/add_silu_python_api

Add the silu operator Python interface and tests.
parents 6cef1a9f 229ed8ad
@@ -6,4 +6,5 @@
 #include "ops/ones.hpp"
 #include "ops/rearrange.hpp"
 #include "ops/rms_norm.hpp"
+#include "ops/silu.hpp"
 #include "ops/swiglu.hpp"
#pragma once

#include "../device.hpp"

#include "common/op.hpp"

namespace infinicore::op {

class Silu {
public:
    using schema = void (*)(Tensor, Tensor);
    static void execute(Tensor output, Tensor input);
    static common::OpDispatcher<schema> &dispatcher();
};

// Out-of-place variant: allocates and returns a new output tensor.
Tensor silu(Tensor input);

// Writes the result into the caller-provided `output` tensor.
void silu_(Tensor output, Tensor input);

} // namespace infinicore::op
@@ -30,6 +30,7 @@ from infinicore.ops.attention import attention
 from infinicore.ops.matmul import matmul
 from infinicore.ops.rearrange import rearrange
 from infinicore.ops.rms_norm import rms_norm
+from infinicore.ops.silu import silu
 from infinicore.ops.swiglu import swiglu
 from infinicore.tensor import (
     empty,
@@ -75,6 +76,7 @@ __all__ = [
     "matmul",
     "rearrange",
     "rms_norm",
+    "silu",
     "swiglu",
     "empty",
     "from_blob",
...
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def silu(input, *, out=None):
    r"""Apply the SiLU (Swish) activation function, element-wise."""
    if out is None:
        return Tensor(_infinicore.silu(input._underlying))

    _infinicore.silu_(out._underlying, input._underlying)
    return out
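# Usage sketch (illustrative only; `infinicore.empty` is assumed from the
# package exports above, and device selection is omitted for brevity):
#
#     x = infinicore.empty((2, 4), dtype=infinicore.float32)
#     y = infinicore.silu(x)       # allocates and returns a new tensor
#     infinicore.silu(x, out=y)    # writes into an existing tensor and returns it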
#include "infinicore/ops/silu.hpp"
#include <stdexcept>
namespace infinicore::op {
common::OpDispatcher<Silu::schema> &Silu::dispatcher() {
static common::OpDispatcher<Silu::schema> dispatcher_;
return dispatcher_;
};
void Silu::execute(Tensor output, Tensor input) {
auto device_type = context::getDevice().getType();
auto func = dispatcher().lookup(device_type);
if (func == nullptr) {
throw std::runtime_error("No Silu implementation found for device type: " + std::to_string(static_cast<int>(device_type)));
}
func(output, input);
}
Tensor silu(Tensor input) {
Shape shape = input->shape();
auto output = Tensor::empty(shape, input->dtype(), input->device());
silu_(output, input);
return output;
}
void silu_(Tensor output, Tensor input) {
Silu::execute(output, input);
}
} // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/silu.hpp"
#include <infiniop.h>
namespace infinicore::op::silu_impl::infiniop {
thread_local common::OpCache<size_t, infiniopSiluDescriptor_t> caches(
100, // capacity
[](infiniopSiluDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroySiluDescriptor(desc));
desc = nullptr;
}
});
void calculate(Tensor output, Tensor input) {
    // Key the descriptor cache on the tensors' metadata.
    size_t seed = hash_combine(output, input);

    auto device_type = context::getDevice().getType();
    auto device_index = context::getDevice().getIndex();

    auto &cache = caches.getCache(device_type, device_index);
    auto desc_opt = cache.get(seed);
    infiniopSiluDescriptor_t desc = nullptr;

    if (!desc_opt) {
        // Cache miss: build a descriptor for this configuration and keep it.
        INFINICORE_CHECK_ERROR(infiniopCreateSiluDescriptor(
            context::getInfiniopHandle(), &desc,
            output->desc(), input->desc()));
        cache.put(seed, desc);
    } else {
        desc = *desc_opt;
    }

    size_t workspace_size = 0;
    INFINICORE_CHECK_ERROR(infiniopGetSiluWorkspaceSize(desc, &workspace_size));
    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

    INFINICORE_CHECK_ERROR(infiniopSilu(
        desc, workspace->data(), workspace_size,
        output->data(), input->data(), context::getStream()));
}

// Register this backend for all device types when the library is loaded.
static bool registered = []() {
    Silu::dispatcher().registerAll(&calculate, false);
    return true;
}();

} // namespace infinicore::op::silu_impl::infiniop
@@ -7,6 +7,7 @@
 #include "ops/matmul.hpp"
 #include "ops/rearrange.hpp"
 #include "ops/rms_norm.hpp"
+#include "ops/silu.hpp"
 #include "ops/swiglu.hpp"

 namespace py = pybind11;
@@ -19,6 +20,7 @@ inline void bind(py::module &m) {
     bind_matmul(m);
     bind_rearrange(m);
     bind_rms_norm(m);
+    bind_silu(m);
     bind_swiglu(m);
 }
...
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/silu.hpp"
namespace py = pybind11;
namespace infinicore::ops {
inline void bind_silu(py::module &m) {
m.def("silu",
&op::silu,
py::arg("input"),
R"doc(SiLU (Swish) activation function.)doc");
m.def("silu_",
&op::silu_,
py::arg("output"),
py::arg("input"),
R"doc(In-place SiLU (Swish) activation function.)doc");
}
} // namespace infinicore::ops
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (operation_mode, shape, input_strides, output_strides)
# SiLU is a single-input activation function: output = input * sigmoid(input)
_TEST_CASES_DATA = [
# Basic 2D SiLU
(TestCase.BOTH, (2, 4), None, None),
(TestCase.BOTH, (128, 64), None, None),
# 3D SiLU
(TestCase.BOTH, (2, 4, 8), None, None),
(TestCase.BOTH, (4, 48, 6), None, None),
    # Strided tensors (see the note after this list)
(TestCase.BOTH, (1, 2048), (4096, 1), (4096, 1)),
(TestCase.BOTH, (6, 2560), (2048, 1), (2560, 1)),
# Mixed cases
(TestCase.BOTH, (8, 16, 32), None, None),
# Large tensors
(TestCase.BOTH, (16, 5632), None, None),
(TestCase.BOTH, (4, 4, 5632), None, None),
]
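# For intuition: a strided case such as (TestCase.BOTH, (1, 2048), (4096, 1), (4096, 1))
# asks for non-contiguous buffers. A rough torch analogue (illustrative only;
# TensorSpec handles the actual construction) is:
#
#     t = torch.empty_strided((1, 2048), (4096, 1), dtype=torch.float16)
#
# i.e. consecutive rows sit 4096 elements apart even though a row holds 2048.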
def parse_test_cases(data):
"""
Parse silu test case data according to format:
(operation_mode, shape, input_strides, output_strides)
"""
operation_mode = data[0]
shape = data[1]
input_strides = data[2] if len(data) > 2 else None
output_strides = data[3] if len(data) > 3 else None
# Create input specifications
inputs = []
# Tensor input
if input_strides is not None:
inputs.append(TensorSpec.from_strided_tensor(shape, input_strides))
else:
inputs.append(TensorSpec.from_tensor(shape))
# Output tensor
if output_strides is not None:
output = TensorSpec.from_strided_tensor(shape, output_strides)
else:
output = TensorSpec.from_tensor(shape)
return TestCase(operation_mode, inputs, output)
# Parse test cases
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
# Data types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# Tolerance per dtype (see the sketch after this map)
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 1e-3, "rtol": 1e-3},
infinicore.float32: {"atol": 1e-5, "rtol": 1e-5},
infinicore.bfloat16: {"atol": 5e-3, "rtol": 1e-2},
}
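# The runner is expected to compare results roughly as sketched below
# (illustrative only; the actual check lives in the test framework):
#
#     tol = _TOLERANCE_MAP[dtype]
#     torch.testing.assert_close(actual, expected, atol=tol["atol"], rtol=tol["rtol"])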
class OpTest(BaseOperatorTest):
"""SiLU test with simplified test case parsing"""
def __init__(self):
super().__init__("SiLU")
def get_test_cases(self):
return _TEST_CASES
def get_tensor_dtypes(self):
return _TENSOR_DTYPES
def get_tolerance_map(self):
return _TOLERANCE_MAP
    def torch_operator(self, input, out=None, **kwargs):
        # Reference SiLU: input * sigmoid(input); equivalent to torch.nn.functional.silu.
        result = input * torch.sigmoid(input)
        if out is not None:
            out.copy_(result)
            return out
        return result
def infinicore_operator(self, input, out=None, **kwargs):
return infinicore.silu(input, out=out)
def main():
"""Main entry point"""
runner = GenericTestRunner(OpTest)
runner.run_and_exit()
if __name__ == "__main__":
main()