Commit 5614e1be authored by wooway777's avatar wooway777
Browse files

issue/931 - ninetoothed swiglu for nv, il, mtx

parent 6ac8f906
import ninetoothed
from . import swiglu
import infiniop.ninetoothed.build
def build():
MAX_NDIM = 5
ndim_values = range(1, MAX_NDIM + 1)
dtype_values = (
ninetoothed.float16,
ninetoothed.bfloat16,
ninetoothed.float32,
)
constexpr_param_grid = {
"ndim": ndim_values,
"dtype": dtype_values,
"block_size": (1024,),
}
infiniop.ninetoothed.build.build(
swiglu.premake,
constexpr_param_grid,
caller="cuda",
op_name="swiglu",
output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
)
#ifndef SWIGLU_H
#define SWIGLU_H
#include "../../../handle.h"
#include "../../../operator.h"
#include "../../../tensor.h"
#include "../../../../../build/ninetoothed/swiglu.h"
#include "../../../ninetoothed/utils.h"
namespace op::swiglu::ninetoothed {
class Descriptor final : public InfiniopDescriptor {
public:
Descriptor(
infiniopHandle_t handle,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) : InfiniopDescriptor{handle->device, handle->device_id},
out_shape_{out_desc->shape()},
out_strides_{out_desc->strides()},
up_shape_{input_desc_vec[0]->shape()},
up_strides_{input_desc_vec[0]->strides()},
gate_shape_{input_desc_vec[1]->shape()},
gate_strides_{input_desc_vec[1]->strides()},
dtype_{out_desc->dtype()} {}
~Descriptor() = default;
size_t workspaceSize() const {
return 0;
}
static infiniStatus_t create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
*desc_ptr = new Descriptor(handle, out_desc, input_desc_vec);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
auto out_nt{::ninetoothed::Tensor(output, out_shape_, out_strides_)};
auto up_nt{::ninetoothed::Tensor(inputs[0], up_shape_, up_strides_)};
auto gate_nt{::ninetoothed::Tensor(inputs[1], gate_shape_, gate_strides_)};
if (launch_swiglu(stream,
out_nt,
up_nt,
gate_nt,
out_shape_.size(),
dtype_,
1024)) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}
return INFINI_STATUS_SUCCESS;
}
private:
using Size = ::ninetoothed::Tensor<>::Size;
using Stride = ::ninetoothed::Tensor<>::Stride;
std::vector<Size> out_shape_;
std::vector<Stride> out_strides_;
std::vector<Size> up_shape_;
std::vector<Stride> up_strides_;
std::vector<Size> gate_shape_;
std::vector<Stride> gate_strides_;
infiniDtype_t dtype_;
};
} // namespace op::swiglu::ninetoothed
#endif // SWIGLU_H
import functools
import ninetoothed.language as ntl
from ninetoothed import Tensor
from ntops.kernels.element_wise import arrangement
def application(output, up, gate):
output = ntl.sigmoid(ntl.cast(gate, ntl.float32)) * gate * up # noqa: F841
def premake(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, block_size=block_size)
tensors = (
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
)
return arrangement_, application, tensors
......@@ -6,14 +6,22 @@
#include "cpu/swiglu_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#if defined(ENABLE_NINETOOTHED)
#include "ninetoothed/swiglu.h"
#else
#include "nvidia/swiglu_nvidia.cuh"
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/swiglu_kunlun.h"
#endif
#ifdef ENABLE_METAX_API
#if defined(ENABLE_NINETOOTHED)
#include "ninetoothed/swiglu.h"
#else
#include "metax/swiglu_metax.h"
#endif
#endif
#ifdef ENABLE_CAMBRICON_API
#include "bang/swiglu_bang.h"
#endif
......@@ -46,11 +54,19 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NINETOOTHED
CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#else
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
#ifdef ENABLE_NINETOOTHED
CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#else
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -61,8 +77,12 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_METAX_API
#ifdef ENABLE_NINETOOTHED
CREATE(INFINI_DEVICE_METAX, ninetoothed);
#else
CREATE(INFINI_DEVICE_METAX, metax);
#endif
#endif
#ifdef ENABLE_CAMBRICON_API
CREATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
......@@ -92,11 +112,19 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NINETOOTHED
GET(INFINI_DEVICE_NVIDIA, ninetoothed);
#else
GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
#ifdef ENABLE_NINETOOTHED
GET(INFINI_DEVICE_ILUVATAR, ninetoothed);
#else
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -107,8 +135,12 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_METAX_API
#ifdef ENABLE_NINETOOTHED
GET(INFINI_DEVICE_METAX, ninetoothed);
#else
GET(INFINI_DEVICE_METAX, metax);
#endif
#endif
#ifdef ENABLE_CAMBRICON_API
GET(INFINI_DEVICE_CAMBRICON, bang);
#endif
......@@ -145,11 +177,19 @@ __C infiniStatus_t infiniopSwiGLU(
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NINETOOTHED
CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#else
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
#ifdef ENABLE_NINETOOTHED
CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#else
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -160,8 +200,12 @@ __C infiniStatus_t infiniopSwiGLU(
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_METAX_API
#ifdef ENABLE_NINETOOTHED
CALCULATE(INFINI_DEVICE_METAX, ninetoothed);
#else
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#endif
#ifdef ENABLE_CAMBRICON_API
CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
#endif
......@@ -193,11 +237,19 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NINETOOTHED
DELETE(INFINI_DEVICE_NVIDIA, ninetoothed);
#else
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#endif
#ifdef ENABLE_ILUVATAR_API
#ifdef ENABLE_NINETOOTHED
DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#else
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -208,8 +260,12 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_METAX_API
#ifdef ENABLE_NINETOOTHED
DELETE(INFINI_DEVICE_METAX, ninetoothed);
#else
DELETE(INFINI_DEVICE_METAX, metax);
#endif
#endif
#ifdef ENABLE_CAMBRICON_API
DELETE(INFINI_DEVICE_CAMBRICON, bang);
#endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment