Unverified Commit 229ed8ad authored by pengcheng888's avatar pengcheng888 Committed by GitHub
Browse files

Merge branch 'main' into feature/add_silu_python_api

parents 3633ae01 6cef1a9f
@@ -7,3 +7,4 @@
 #include "ops/rearrange.hpp"
 #include "ops/rms_norm.hpp"
 #include "ops/silu.hpp"
+#include "ops/swiglu.hpp"
#pragma once
#include "../device.hpp"
#include "common/op.hpp"
namespace infinicore::op {
// SwiGLU operator front-end. Kernels are looked up per device type in a
// dispatch table; backends install themselves at static-init time (see the
// infiniop backend's self-registration).
class SwiGLU {
public:
    // Kernel signature: (c, a, b) — c is the output, a and b the inputs.
    using schema = void (*)(Tensor, Tensor, Tensor);
    // Invoke the kernel registered for the current device; throws
    // std::runtime_error when no kernel is registered (see swiglu.cc).
    static void execute(Tensor c, Tensor a, Tensor b);
    // Singleton dispatch table mapping device type -> kernel.
    static common::OpDispatcher<schema> &dispatcher();
};
// Allocating form: the result tensor is created with a's shape/dtype/device.
Tensor swiglu(Tensor a, Tensor b);
// In-place form: writes the result into c.
void swiglu_(Tensor c, Tensor a, Tensor b);
} // namespace infinicore::op
@@ -31,6 +31,7 @@
 from infinicore.ops.matmul import matmul
 from infinicore.ops.rearrange import rearrange
 from infinicore.ops.rms_norm import rms_norm
 from infinicore.ops.silu import silu
+from infinicore.ops.swiglu import swiglu
 from infinicore.tensor import (
     empty,
     from_blob,
@@ -76,6 +77,7 @@
     "rearrange",
     "rms_norm",
     "silu",
+    "swiglu",
     "empty",
     "from_blob",
     "ones",
...
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
def swiglu(input, other, *, out=None):
    """Apply SwiGLU to ``input`` and ``other``.

    Args:
        input: First input tensor.
        other: Second input tensor (expected to match ``input``'s shape —
            the C++ side allocates the result with ``input``'s shape).
        out: Optional pre-allocated output tensor written in place.

    Returns:
        The result tensor. When ``out`` is supplied, ``out`` itself is
        returned (the in-place branch previously returned ``None``, which
        was inconsistent with the allocating branch and with the usual
        ``out=`` convention).
    """
    if out is None:
        return Tensor(_infinicore.swiglu(input._underlying, other._underlying))
    _infinicore.swiglu_(out._underlying, input._underlying, other._underlying)
    return out
#include "infinicore/ops/swiglu.hpp"
#include <stdexcept>
namespace infinicore::op {
// Singleton dispatch table (function-local static: initialized once,
// thread-safe since C++11). Removed the stray ';' that followed the
// function body — harmless but non-idiomatic (clang-tidy:
// readability-redundant-declaration/semicolon).
common::OpDispatcher<SwiGLU::schema> &SwiGLU::dispatcher() {
    static common::OpDispatcher<SwiGLU::schema> dispatcher_;
    return dispatcher_;
}
// Resolve the kernel registered for the active device and invoke it with
// (c, a, b); throws when no backend registered a kernel for this device.
void SwiGLU::execute(Tensor c, Tensor a, Tensor b) {
    const auto device = context::getDevice().getType();
    if (auto kernel = dispatcher().lookup(device); kernel != nullptr) {
        kernel(c, a, b);
        return;
    }
    throw std::runtime_error("No SwiGLU implementation found for device type: " + std::to_string(static_cast<int>(device)));
}
// Allocating variant: create the output with the same shape/dtype/device
// as `a`, then delegate to the in-place variant.
Tensor swiglu(Tensor a, Tensor b) {
    auto out = Tensor::empty(a->shape(), a->dtype(), a->device());
    swiglu_(out, a, b);
    return out;
}
// In-place variant: computes SwiGLU of (a, b) into `c` via the dispatcher
// (see SwiGLU::execute for device resolution and error behavior).
void swiglu_(Tensor c, Tensor a, Tensor b) {
    SwiGLU::execute(c, a, b);
}
} // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/swiglu.hpp"
#include <infiniop.h>
namespace infinicore::op::swiglu_impl::infiniop {
// Per-thread bounded cache (capacity 100) of infiniop SwiGLU descriptors,
// keyed by a hash of the tensor metadata (see calculate()). thread_local
// avoids any cross-thread synchronization on the cache itself.
thread_local common::OpCache<size_t, infiniopSwiGLUDescriptor_t> caches(
    100, // capacity
    // Eviction callback: destroy the infiniop descriptor when an entry
    // is dropped, and null the handle defensively.
    [](infiniopSwiGLUDescriptor_t &desc) {
        if (desc != nullptr) {
            INFINICORE_CHECK_ERROR(infiniopDestroySwiGLUDescriptor(desc));
            desc = nullptr;
        }
    });
// Execute SwiGLU on (a, b) into `c` through the infiniop backend.
// Descriptors are cached per (device type, device index, tensor metadata)
// so repeated calls with identical layouts skip descriptor creation.
void calculate(Tensor c, Tensor a, Tensor b) {
    // Key the cache on the tensors in the SAME order the descriptor is
    // created with below, (c, a, b). The previous (c, b, a) order was
    // functionally consistent with itself but gratuitously different
    // from the creation call, inviting future mismatches.
    size_t seed = hash_combine(c, a, b);
    auto device_type = context::getDevice().getType();
    auto device_index = context::getDevice().getIndex();
    auto &cache = caches.getCache(device_type, device_index);

    infiniopSwiGLUDescriptor_t desc = nullptr;
    if (auto cached = cache.get(seed)) {
        desc = *cached;
    } else {
        INFINICORE_CHECK_ERROR(infiniopCreateSwiGLUDescriptor(
            context::getInfiniopHandle(), &desc,
            c->desc(), a->desc(), b->desc()));
        cache.put(seed, desc);
    }

    // Workspace size is queried per call; presumably cheap — the
    // descriptor handle alone determines it. TODO(review): confirm and
    // consider caching the size alongside the descriptor if it is hot.
    size_t workspace_size = 0;
    INFINICORE_CHECK_ERROR(infiniopGetSwiGLUWorkspaceSize(desc, &workspace_size));
    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

    INFINICORE_CHECK_ERROR(infiniopSwiGLU(
        desc, workspace->data(), workspace_size,
        c->data(), a->data(), b->data(), context::getStream()));
}
// Self-registration at static-init time: installs `calculate` as the
// SwiGLU kernel via the dispatcher's registerAll. NOTE(review): the
// `false` flag presumably means "do not overwrite existing
// registrations" — confirm against OpDispatcher::registerAll.
static bool registered = []() {
    SwiGLU::dispatcher().registerAll(&calculate, false);
    return true;
}();
} // namespace infinicore::op::swiglu_impl::infiniop
@@ -8,6 +8,7 @@
 #include "ops/rearrange.hpp"
 #include "ops/rms_norm.hpp"
 #include "ops/silu.hpp"
+#include "ops/swiglu.hpp"

 namespace py = pybind11;
@@ -20,6 +21,7 @@ inline void bind(py::module &m) {
     bind_rearrange(m);
     bind_rms_norm(m);
     bind_silu(m);
+    bind_swiglu(m);
 }
 } // namespace infinicore::ops
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/swiglu.hpp"
namespace py = pybind11;
namespace infinicore::ops {
inline void bind_swiglu(py::module &m) {
m.def("swiglu",
&op::swiglu,
py::arg("a"),
py::arg("b"),
R"doc(SwiGLU activation function.)doc");
m.def("swiglu_",
&op::swiglu_,
py::arg("c"),
py::arg("a"),
py::arg("b"),
R"doc(In-place SwiGLU activation function.)doc");
}
} // namespace infinicore::ops
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (operation_mode, shape, a_strides, b_strides, c_strides)
# SwiGLU operates element-wise on two tensors of the same shape
# Each entry: (operation_mode, shape, a_strides, b_strides, c_strides).
# A strides value of None means a plain contiguous tensor; tuples are
# presumably element strides per dimension — confirm against
# TensorSpec.from_strided_tensor.
_TEST_CASES_DATA = [
    # Basic 2D SwiGLU
    (TestCase.BOTH, (2, 4), None, None, None),
    (TestCase.BOTH, (128, 64), None, None, None),
    # 3D SwiGLU
    (TestCase.BOTH, (2, 4, 8), None, None, None),
    (TestCase.BOTH, (4, 48, 6), None, None, None),
    # Strided tensors (padded rows / transposed-style layouts)
    (TestCase.BOTH, (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
    (TestCase.BOTH, (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
    # Mixed cases
    (TestCase.BOTH, (8, 16, 32), None, None, None),
    # Large tensors
    (TestCase.BOTH, (16, 5632), None, None, None),
    (TestCase.BOTH, (4, 4, 5632), None, None, None),
]
def parse_test_cases(data):
    """
    Build a TestCase from a tuple of the form
    (operation_mode, shape, a_strides, b_strides, c_strides),
    where the trailing stride entries are optional and default to None.
    """
    mode, shape = data[0], data[1]
    # Normalize to exactly three stride slots so unpacking is uniform.
    given = list(data[2:5])
    a_strides, b_strides, c_strides = given + [None] * (3 - len(given))

    def make_spec(strides):
        # None -> contiguous tensor; otherwise an explicitly strided one.
        if strides is None:
            return TensorSpec.from_tensor(shape)
        return TensorSpec.from_strided_tensor(shape, strides)

    inputs = [make_spec(a_strides), make_spec(b_strides)]
    output = make_spec(c_strides)
    return TestCase(mode, inputs, output)
# Materialize TestCase objects once at import time.
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]

# Data types exercised for every test case.
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]

# Per-dtype comparison tolerances (bfloat16 is loosest: 8-bit mantissa).
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 1e-3, "rtol": 1e-3},
    infinicore.float32: {"atol": 1e-5, "rtol": 1e-5},
    infinicore.bfloat16: {"atol": 5e-3, "rtol": 1e-2},
}
class OpTest(BaseOperatorTest):
    """SwiGLU operator test driven by the shared test framework."""

    def __init__(self):
        super().__init__("SwiGLU")

    def get_test_cases(self):
        return _TEST_CASES

    def get_tensor_dtypes(self):
        return _TENSOR_DTYPES

    def get_tolerance_map(self):
        return _TOLERANCE_MAP

    def torch_operator(self, a, b, out=None, **kwargs):
        # Reference semantics: a * b * sigmoid(b), i.e. a gated by silu(b).
        reference = a * b * torch.sigmoid(b)
        if out is None:
            return reference
        out.copy_(reference)
        return out

    def infinicore_operator(self, a, b, out=None, **kwargs):
        # Operator under test (exercises both allocating and out= paths).
        return infinicore.swiglu(a, b, out=out)
def main():
    """Run the SwiGLU test suite and exit with its status code."""
    GenericTestRunner(OpTest).run_and_exit()


if __name__ == "__main__":
    main()
@@ -85,7 +85,7 @@ target("infinicore-test")
     add_files(os.projectdir().."/src/infinicore/context/*.cc")
     add_files(os.projectdir().."/src/infinicore/context/*/*.cc")
     add_files(os.projectdir().."/src/infinicore/tensor/*.cc")
-    add_files(os.projectdir().."/src/infinicore/op/*/*.cc")
+    add_files(os.projectdir().."/src/infinicore/ops/*/*.cc")
     add_files(os.projectdir().."/src/infinicore-test/*.cc")
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment