Commit 8b760951 authored by zhangyue's avatar zhangyue
Browse files

Merge branch 'main' of https://github.com/InfiniTensor/InfiniCore into issue-385

parents eb3972eb d4b03cf7
#ifndef __CLIP_KUNLUN_KERNEL_H__
#define __CLIP_KUNLUN_KERNEL_H__
#include <xpu/kernel/xtdk_io.h>
namespace op::clip::kunlun {
typedef struct ClipOp {
public:
    // Three operands per element: the value, the lower bound, the upper bound.
    static constexpr int num_inputs = 3;

    // Clamp inputs[0] into the closed interval [inputs[1], inputs[2]].
    template <typename T>
    inline __device__ T operator()(const T *inputs) const {
        T value = inputs[0];
        T lower = inputs[1];
        T upper = inputs[2];
        return fmax(fmin(value, upper), lower);
    }

    // bfloat16 overload: widen to float for the clamp, then narrow back.
    inline __device__ bfloat16_t operator()(const bfloat16_t *inputs) const {
        float value = __bfloat162float(inputs[0]);
        float lower = __bfloat162float(inputs[1]);
        float upper = __bfloat162float(inputs[2]);
        return __float2bfloat16(fmax(fmin(value, upper), lower));
    }
} ClipOp;
} // namespace op::clip::kunlun
#endif // __CLIP_KUNLUN_KERNEL_H__
...@@ -11,6 +11,9 @@ ...@@ -11,6 +11,9 @@
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
#include "metax/clip_metax.h" #include "metax/clip_metax.h"
#endif #endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/clip_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateClipDescriptor( __C infiniStatus_t infiniopCreateClipDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
...@@ -42,6 +45,9 @@ __C infiniStatus_t infiniopCreateClipDescriptor( ...@@ -42,6 +45,9 @@ __C infiniStatus_t infiniopCreateClipDescriptor(
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax); CREATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -69,6 +75,9 @@ __C infiniStatus_t infiniopGetClipWorkspaceSize(infiniopClipDescriptor_t desc, s ...@@ -69,6 +75,9 @@ __C infiniStatus_t infiniopGetClipWorkspaceSize(infiniopClipDescriptor_t desc, s
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax) GET(INFINI_DEVICE_METAX, metax)
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun)
#endif #endif
} }
...@@ -106,6 +115,9 @@ __C infiniStatus_t infiniopClip( ...@@ -106,6 +115,9 @@ __C infiniStatus_t infiniopClip(
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax); CALCULATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -136,6 +148,9 @@ infiniopDestroyClipDescriptor(infiniopClipDescriptor_t desc) { ...@@ -136,6 +148,9 @@ infiniopDestroyClipDescriptor(infiniopClipDescriptor_t desc) {
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax); DELETE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
#ifndef __GEMM_MOORE_H__
#define __GEMM_MOORE_H__
#include "../gemm.h"
// Declares the GEMM Descriptor class for the Moore (MUSA) backend via the
// shared DESCRIPTOR macro defined in ../gemm.h.
DESCRIPTOR(moore)
#endif // __GEMM_MOORE_H__
#include "../../../devices/musa/common_musa.h" #include "../../../devices/moore/moore_common.h"
#include "../../../devices/musa/musa_handle.h" #include "../../../devices/moore/moore_handle.h"
#include "gemm_musa.h" #include "gemm_moore.h"
namespace op::gemm::musa { namespace op::gemm::moore {
struct Descriptor::Opaque { struct Descriptor::Opaque {
std::shared_ptr<device::musa::Handle::Internal> internal; std::shared_ptr<device::moore::Handle::Internal> internal;
}; };
Descriptor::~Descriptor() { Descriptor::~Descriptor() {
...@@ -18,10 +18,10 @@ infiniStatus_t Descriptor::create( ...@@ -18,10 +18,10 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) { infiniopTensorDescriptor_t b_desc) {
auto handle = reinterpret_cast<device::musa::Handle *>(handle_); auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
auto dtype = c_desc->dtype(); auto dtype = c_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32); CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR); auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result); CHECK_RESULT(result);
...@@ -33,41 +33,63 @@ infiniStatus_t Descriptor::create( ...@@ -33,41 +33,63 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
template <typename Tdata> infiniStatus_t Descriptor::calculate(
infiniStatus_t calculate( void *workspace,
const MatmulInfo &info, size_t workspace_size,
std::shared_ptr<device::musa::Handle::Internal> &_internal,
void *c, void *c,
float beta, float beta,
const void *a, const void *a,
const void *b, const void *b,
float alpha, float alpha,
void *stream) { void *stream) const {
musaDataType a_type, b_type, c_type; musaDataType a_type, b_type, c_type;
mublasComputeType_t compute_type; mublasComputeType_t compute_type;
Tdata alpha_, beta_;
if constexpr (std::is_same<Tdata, half>::value) { // MUSA's GEMM operations require that the scalar values alpha and beta have the same data type as the matrices.
alpha_ = __float2half(alpha); // This ensures correct computation during the muBLAS GEMM operation.
beta_ = __float2half(beta); // Declare half-precision variables to handle F16 types.
half alpha_h, beta_h;
// Initialize generic void pointers for alpha and beta.
// They point to the original float values
// It will be used directly when the GEMM operation is performed with F32 data.
const void *p_alpha = &alpha;
const void *p_beta = &beta;
switch (_dtype) {
case INFINI_DTYPE_F16:
a_type = b_type = c_type = MUSA_R_16F; a_type = b_type = c_type = MUSA_R_16F;
compute_type = MUBLAS_COMPUTE_16F; compute_type = MUBLAS_COMPUTE_16F;
} else {
alpha_ = alpha; // Convert alpha/beta to half-precision and update the pointers.
beta_ = beta; alpha_h = __float2half(alpha);
beta_h = __float2half(beta);
p_alpha = &alpha_h;
p_beta = &beta_h;
break;
case INFINI_DTYPE_BF16:
a_type = b_type = c_type = MUSA_R_16BF;
compute_type = MUBLAS_COMPUTE_32F;
break;
case INFINI_DTYPE_F32:
a_type = b_type = c_type = MUSA_R_32F; a_type = b_type = c_type = MUSA_R_32F;
compute_type = MUBLAS_COMPUTE_32F_FAST_TF32; compute_type = MUBLAS_COMPUTE_32F_FAST_TF32;
break;
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
if (info.is_transed) { if (_info.is_transed) {
std::swap(a, b); std::swap(a, b);
} }
auto op_a = info.a_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T; auto op_a = _info.a_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
auto op_b = info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T; auto op_b = _info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
CHECK_STATUS(_internal->useMublas( CHECK_STATUS(_opaque->internal->useMublas(
(musaStream_t)stream, (musaStream_t)stream,
[&](mublasHandle_t handle) { [&](mublasHandle_t handle) {
CHECK_MUBLAS( CHECK_MUBLAS(
...@@ -75,24 +97,24 @@ infiniStatus_t calculate( ...@@ -75,24 +97,24 @@ infiniStatus_t calculate(
handle, handle,
op_a, op_a,
op_b, op_b,
static_cast<int>(info.m), static_cast<int>(_info.m),
static_cast<int>(info.n), static_cast<int>(_info.n),
static_cast<int>(info.k), static_cast<int>(_info.k),
&alpha_, p_alpha,
a, a,
a_type, a_type,
static_cast<int>(info.a_matrix.ld()), static_cast<int>(_info.a_matrix.ld()),
info.a_matrix.stride, _info.a_matrix.stride,
b, b,
b_type, b_type,
static_cast<int>(info.b_matrix.ld()), static_cast<int>(_info.b_matrix.ld()),
info.b_matrix.stride, _info.b_matrix.stride,
&beta_, p_beta,
c, c,
c_type, c_type,
static_cast<int>(info.c_matrix.ld()), static_cast<int>(_info.c_matrix.ld()),
info.c_matrix.stride, _info.c_matrix.stride,
static_cast<int>(info.batch), static_cast<int>(_info.batch),
compute_type, compute_type,
MUBLAS_GEMM_DEFAULT)); MUBLAS_GEMM_DEFAULT));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
...@@ -100,22 +122,4 @@ infiniStatus_t calculate( ...@@ -100,22 +122,4 @@ infiniStatus_t calculate(
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
infiniStatus_t Descriptor::calculate(void *workspace, } // namespace op::gemm::moore
size_t workspace_size,
void *c,
float beta,
const void *a,
const void *b,
float alpha,
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
return musa::calculate<half>(_info, _opaque->internal, c, beta, a, b, alpha, stream);
case INFINI_DTYPE_F32:
return musa::calculate<float>(_info,_opaque->internal, c, beta, a, b, alpha, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
} // namespace op::gemm::musa
#ifndef __GEMM_MUSA_H__
#define __GEMM_MUSA_H__
#include "../gemm.h"
DESCRIPTOR(musa)
#endif // __GEMM_MUSA_H__
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include "metax/gemm_metax.h" #include "metax/gemm_metax.h"
#endif #endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
#include "musa/gemm_musa.h" #include "moore/gemm_moore.h"
#endif #endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
#include "kunlun/gemm_kunlun.h" #include "kunlun/gemm_kunlun.h"
...@@ -61,7 +61,7 @@ __C infiniStatus_t infiniopCreateGemmDescriptor( ...@@ -61,7 +61,7 @@ __C infiniStatus_t infiniopCreateGemmDescriptor(
CREATE(INFINI_DEVICE_METAX, metax); CREATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, musa); CREATE(INFINI_DEVICE_MOORE, moore);
#endif #endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
...@@ -106,7 +106,7 @@ infiniopGetGemmWorkspaceSize( ...@@ -106,7 +106,7 @@ infiniopGetGemmWorkspaceSize(
GET(INFINI_DEVICE_METAX, metax); GET(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, musa); GET(INFINI_DEVICE_MOORE, moore);
#endif #endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun); GET(INFINI_DEVICE_KUNLUN, kunlun);
...@@ -158,7 +158,7 @@ __C infiniStatus_t infiniopGemm( ...@@ -158,7 +158,7 @@ __C infiniStatus_t infiniopGemm(
CALCULATE(INFINI_DEVICE_METAX, metax); CALCULATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, musa); CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif #endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun); CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
...@@ -200,7 +200,7 @@ infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc) { ...@@ -200,7 +200,7 @@ infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc) {
DELETE(INFINI_DEVICE_METAX, metax); DELETE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, musa); DELETE(INFINI_DEVICE_MOORE, moore);
#endif #endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun); DELETE(INFINI_DEVICE_KUNLUN, kunlun);
......
#ifndef __MUL_KUNLUN_KERNEL_H__
#define __MUL_KUNLUN_KERNEL_H__
namespace op::mul::kunlun {
typedef struct MulOp {
public:
    // Two operands per element.
    static constexpr int num_inputs = 2;

    // Element-wise product of the two inputs.
    template <typename T>
    inline __device__ T operator()(const T *inputs) const {
        return inputs[0] * inputs[1];
    }

    // bfloat16 overload: multiply in float precision, then narrow back.
    inline __device__ bfloat16_t operator()(const bfloat16_t *inputs) const {
        float lhs = __bfloat162float(inputs[0]);
        float rhs = __bfloat162float(inputs[1]);
        return __float2bfloat16(lhs * rhs);
    }
} MulOp;
} // namespace op::mul::kunlun
#endif // __MUL_KUNLUN_KERNEL_H__
#ifndef __MUL_KUNLUN_API_H__
#define __MUL_KUNLUN_API_H__
#include "../../../elementwise/kunlun/elementwise_kunlun_api.h"
// Declares the Mul operator Descriptor for the Kunlun backend via the shared
// elementwise descriptor macro.
ELEMENTWISE_DESCRIPTOR(mul, kunlun)
#endif // __MUL_KUNLUN_API_H__
#include "../../../elementwise/kunlun/elementwise_kunlun.h"
#include "kernel.h"
#include "mul_kunlun.h"
namespace op::elementwise::kunlun {
// Explicitly instantiate the generic elementwise kernel for each dtype the
// Mul operator dispatches to below (float, half, bfloat16).
using MulOp = op::mul::kunlun::MulOp;
INSTANTIATE_ELEMENTWISE_KERNEL(MulOp::num_inputs, MulOp, float);
INSTANTIATE_ELEMENTWISE_KERNEL(MulOp::num_inputs, MulOp, half);
INSTANTIATE_ELEMENTWISE_KERNEL(MulOp::num_inputs, MulOp, bfloat16_t);
} // namespace op::elementwise::kunlun
namespace op::mul::kunlun {

Descriptor::~Descriptor() = default;

// Create a Mul descriptor for the Kunlun backend.
// Validates that the output dtype is one of F16/F32/BF16 and that both input
// shapes match the output shape, then builds the shared elementwise descriptor.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &a_desc = input_desc_vec.at(0);
    const auto &b_desc = input_desc_vec.at(1);
    const auto &c_shape = out_desc->shape();
    const auto &a_shape = a_desc->shape();
    const auto &b_shape = b_desc->shape();

    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
    CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);

    // create KUNLUN elementwise descriptor
    CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}

// Launch the elementwise multiply on the Kunlun device.
// `workspace` must be at least `_workspace_size` bytes; dispatches to the
// dtype-specific kernel instantiation (see op::elementwise::kunlun above).
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // NOTE(review): the leading template argument (8) mirrors the other Kunlun
    // elementwise ops — presumably a packing/unroll factor; confirm against the
    // elementwise_kunlun implementation.
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<8, MulOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<8, MulOp, bfloat16_t>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<8, MulOp, float>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // Unreachable: every switch path returns (dead trailing return removed).
}
} // namespace op::mul::kunlun
...@@ -11,6 +11,9 @@ ...@@ -11,6 +11,9 @@
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
#include "metax/mul_metax.h" #include "metax/mul_metax.h"
#endif #endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/mul_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateMulDescriptor( __C infiniStatus_t infiniopCreateMulDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
...@@ -42,6 +45,9 @@ __C infiniStatus_t infiniopCreateMulDescriptor( ...@@ -42,6 +45,9 @@ __C infiniStatus_t infiniopCreateMulDescriptor(
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax); CREATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -70,6 +76,9 @@ __C infiniStatus_t infiniopGetMulWorkspaceSize(infiniopMulDescriptor_t desc, siz ...@@ -70,6 +76,9 @@ __C infiniStatus_t infiniopGetMulWorkspaceSize(infiniopMulDescriptor_t desc, siz
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax); GET(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -107,6 +116,9 @@ __C infiniStatus_t infiniopMul( ...@@ -107,6 +116,9 @@ __C infiniStatus_t infiniopMul(
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax); CALCULATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -137,6 +149,9 @@ infiniopDestroyMulDescriptor(infiniopMulDescriptor_t desc) { ...@@ -137,6 +149,9 @@ infiniopDestroyMulDescriptor(infiniopMulDescriptor_t desc) {
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax); DELETE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
...@@ -22,7 +22,7 @@ __device__ void rmsnormBlock( ...@@ -22,7 +22,7 @@ __device__ void rmsnormBlock(
// Thread_0 computes RMS=1/sqrt(ss/dim+epsilon) and stores in shared memory // Thread_0 computes RMS=1/sqrt(ss/dim+epsilon) and stores in shared memory
__shared__ Tcompute rms; __shared__ Tcompute rms;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
rms = Tdata(rsqrtf(ss / Tcompute(dim) + epsilon)); rms = Tcompute(rsqrtf(ss / Tcompute(dim) + epsilon));
} }
__syncthreads(); __syncthreads();
......
#ifndef __RMS_NORM_MUSA_CUH__ #ifndef __RMS_NORM_MOORE_H__
#define __RMS_NORM_MUSA_CUH__ #define __RMS_NORM_MOORE_H__
#include "../rms_norm.h" #include "../rms_norm.h"
DESCRIPTOR(musa) DESCRIPTOR(moore)
#endif #endif
#include "../../../devices/musa/common_musa.h" #include "../../../devices/moore/moore_common.h"
#include "../cuda/rms_norm_kernel.cuh" #include "rms_norm_moore.h"
#include "rms_norm_musa.cuh"
namespace op::rms_norm::musa { #include "../../../devices/moore/moore_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_MOORE_KERNEL rmsnormKernel(
Tdata *__restrict__ y,
ptrdiff_t stride_y,
const Tdata *__restrict__ x,
ptrdiff_t stride_x,
const Tweight *__restrict__ w,
size_t dim,
float epsilon) {
rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y, x, stride_x, w, dim, epsilon);
}
namespace op::rms_norm::moore {
struct Descriptor::Opaque { struct Descriptor::Opaque {
std::shared_ptr<device::musa::Handle::Internal> internal; std::shared_ptr<device::moore::Handle::Internal> internal;
}; };
Descriptor::~Descriptor() { Descriptor::~Descriptor() {
...@@ -29,7 +47,7 @@ infiniStatus_t Descriptor::create( ...@@ -29,7 +47,7 @@ infiniStatus_t Descriptor::create(
} }
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::musa::Handle *>(handle)->internal()}, new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
std::move(info), std::move(info),
0, 0,
handle->device, handle->device_id); handle->device, handle->device_id);
...@@ -46,20 +64,24 @@ infiniStatus_t launchKernel( ...@@ -46,20 +64,24 @@ infiniStatus_t launchKernel(
float epsilon, float epsilon,
musaStream_t musa_stream) { musaStream_t musa_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \ #define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
rmsnormBlock<BLOCK_SIZE, Tdata, Tweight, Tcompute><<<batch_size, BLOCK_SIZE, 0, musa_stream>>>( \ rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size, BLOCK_SIZE, 0, musa_stream>>>( \
reinterpret_cast<Tdata *>(y), \ reinterpret_cast<Tdata *>(y), \
stride_y, \ stride_y, \
reinterpret_cast<const Tdata *>(x), \ reinterpret_cast<const Tdata *>(x), \
stride_x, \ stride_x, \
reinterpret_cast<const Tweight *>(w), \ reinterpret_cast<const Tweight *>(w), \
dim, \ dim, \
epsilon) epsilon)
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) { if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float); LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) { } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float); LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__mt_bfloat16, __mt_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__mt_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) { } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float); LAUNCH_KERNEL(float, float, float);
} else { } else {
...@@ -87,11 +109,15 @@ infiniStatus_t Descriptor::calculate( ...@@ -87,11 +109,15 @@ infiniStatus_t Descriptor::calculate(
auto musa_stream = reinterpret_cast<musaStream_t>(stream); auto musa_stream = reinterpret_cast<musaStream_t>(stream);
// launch kernel with different block sizes // launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream)); CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_2048>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else { } else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
} }
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
} // namespace op::rms_norm::musa } // namespace op::rms_norm::moore
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "metax/rms_norm_metax.cuh" #include "metax/rms_norm_metax.cuh"
#endif #endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
#include "musa/rms_norm_musa.cuh" #include "moore/rms_norm_moore.h"
#endif #endif
#ifdef ENABLE_KUNLUN_API #ifdef ENABLE_KUNLUN_API
#include "kunlun/rms_norm_kunlun.h" #include "kunlun/rms_norm_kunlun.h"
...@@ -64,7 +64,7 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor( ...@@ -64,7 +64,7 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
CREATE(INFINI_DEVICE_METAX, metax); CREATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, musa); CREATE(INFINI_DEVICE_MOORE, moore);
#endif #endif
} }
...@@ -105,7 +105,7 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d ...@@ -105,7 +105,7 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
GET(INFINI_DEVICE_METAX, metax); GET(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, musa); GET(INFINI_DEVICE_MOORE, moore);
#endif #endif
} }
...@@ -147,7 +147,7 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works ...@@ -147,7 +147,7 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
CALCULATE(INFINI_DEVICE_METAX, metax); CALCULATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, musa); CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif #endif
} }
...@@ -188,7 +188,7 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t ...@@ -188,7 +188,7 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
DESTROY(INFINI_DEVICE_METAX, metax); DESTROY(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, musa); DESTROY(INFINI_DEVICE_MOORE, moore);
#endif #endif
} }
......
#ifndef __SUB_KUNLUN_KERNEL_H__
#define __SUB_KUNLUN_KERNEL_H__
namespace op::sub::kunlun {
typedef struct SubOp {
public:
    // Two operands per element.
    static constexpr int num_inputs = 2;

    // Element-wise difference: inputs[0] - inputs[1].
    template <typename T>
    inline __device__ T operator()(const T *inputs) const {
        return inputs[0] - inputs[1];
    }

    // bfloat16 overload: subtract in float precision, then narrow back.
    inline __device__ bfloat16_t operator()(const bfloat16_t *inputs) const {
        float lhs = __bfloat162float(inputs[0]);
        float rhs = __bfloat162float(inputs[1]);
        return __float2bfloat16(lhs - rhs);
    }
} SubOp;
} // namespace op::sub::kunlun
#endif // __SUB_KUNLUN_KERNEL_H__
#ifndef __SUB_KUNLUN_API_H__
#define __SUB_KUNLUN_API_H__
#include "../../../elementwise/kunlun/elementwise_kunlun_api.h"
// Declares the Sub operator Descriptor for the Kunlun backend via the shared
// elementwise descriptor macro.
ELEMENTWISE_DESCRIPTOR(sub, kunlun)
#endif // __SUB_KUNLUN_API_H__
#include "../../../elementwise/kunlun/elementwise_kunlun.h"
#include "kernel.h"
#include "sub_kunlun.h"
namespace op::elementwise::kunlun {
// Explicitly instantiate the generic elementwise kernel for each dtype the
// Sub operator dispatches to below (float, half, bfloat16).
using SubOp = op::sub::kunlun::SubOp;
INSTANTIATE_ELEMENTWISE_KERNEL(SubOp::num_inputs, SubOp, float);
INSTANTIATE_ELEMENTWISE_KERNEL(SubOp::num_inputs, SubOp, half);
INSTANTIATE_ELEMENTWISE_KERNEL(SubOp::num_inputs, SubOp, bfloat16_t);
} // namespace op::elementwise::kunlun
namespace op::sub::kunlun {

Descriptor::~Descriptor() = default;

// Create a Sub descriptor for the Kunlun backend.
// Validates that the output dtype is one of F16/F32/BF16 and that both input
// shapes match the output shape, then builds the shared elementwise descriptor.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &a_desc = input_desc_vec.at(0);
    const auto &b_desc = input_desc_vec.at(1);
    const auto &c_shape = out_desc->shape();
    const auto &a_shape = a_desc->shape();
    const auto &b_shape = b_desc->shape();

    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
    CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);

    // create KUNLUN elementwise descriptor
    CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}

// Launch the elementwise subtraction on the Kunlun device.
// `workspace` must be at least `_workspace_size` bytes; dispatches to the
// dtype-specific kernel instantiation (see op::elementwise::kunlun above).
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // NOTE(review): the leading template argument (8) mirrors the other Kunlun
    // elementwise ops — presumably a packing/unroll factor; confirm against the
    // elementwise_kunlun implementation.
    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<8, SubOp, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<8, SubOp, bfloat16_t>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<8, SubOp, float>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // Unreachable: every switch path returns (dead trailing return removed).
}
} // namespace op::sub::kunlun
...@@ -11,6 +11,9 @@ ...@@ -11,6 +11,9 @@
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
#include "metax/sub_metax.h" #include "metax/sub_metax.h"
#endif #endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/sub_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateSubDescriptor( __C infiniStatus_t infiniopCreateSubDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
...@@ -42,6 +45,9 @@ __C infiniStatus_t infiniopCreateSubDescriptor( ...@@ -42,6 +45,9 @@ __C infiniStatus_t infiniopCreateSubDescriptor(
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax); CREATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -70,6 +76,10 @@ __C infiniStatus_t infiniopGetSubWorkspaceSize(infiniopSubDescriptor_t desc, siz ...@@ -70,6 +76,10 @@ __C infiniStatus_t infiniopGetSubWorkspaceSize(infiniopSubDescriptor_t desc, siz
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax); GET(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
...@@ -106,6 +116,9 @@ __C infiniStatus_t infiniopSub( ...@@ -106,6 +116,9 @@ __C infiniStatus_t infiniopSub(
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax); CALCULATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -136,6 +149,9 @@ infiniopDestroySubDescriptor(infiniopSubDescriptor_t desc) { ...@@ -136,6 +149,9 @@ infiniopDestroySubDescriptor(infiniopSubDescriptor_t desc) {
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax); DELETE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
#ifndef __SWIGLU_BANG_API_H__
#define __SWIGLU_BANG_API_H__
#include "../../../elementwise/bang/elementwise_bang.h"
// Declares the SwiGLU operator Descriptor for the Bang backend via the shared
// elementwise descriptor macro.
ELEMENTWISE_DESCRIPTOR(swiglu, bang)
#endif // __SWIGLU_BANG_API_H__
#include "swiglu_bang.h"
// Operator Interface Declaration
LAUNCH_ELEMENTWISE_KERNEL(SwiGLU)
namespace op::swiglu::bang {

// Adapter that forwards elementwise-framework launch requests to the SwiGLU
// kernel launcher declared by LAUNCH_ELEMENTWISE_KERNEL(SwiGLU) above.
typedef struct SwiGLUOp {
    // Two input operands (up-projection and gate, per the descriptors below).
    static constexpr size_t num_inputs = 2;
    template <typename Tdata, typename... Args>
    static infiniStatus_t launch(Args... args) {
        launchSwiGLUKernel<Tdata>(args...);
        return INFINI_STATUS_SUCCESS;
    }
} SwiGLUOp;

Descriptor::~Descriptor() = default;

// Create a SwiGLU descriptor for the Bang backend.
// Validates that the output dtype is one of F16/F32/BF16 and that both input
// shapes (up, gate) match the output shape, then builds the shared elementwise
// descriptor.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::bang::Handle *>(handle_);
    auto dtype = out_desc->dtype();

    const auto &up_desc = input_desc_vec.at(0);
    const auto &gate_desc = input_desc_vec.at(1);
    const auto &out_shape = out_desc->shape();
    const auto &up_shape = up_desc->shape();
    const auto &gate_shape = gate_desc->shape();

    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
    CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);

    // create Bang elementwise descriptor
    CREATE_ELEMENTWISE_BANG_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}

// Launch SwiGLU on the Bang device queue.
// `workspace` must be at least `_workspace_size` bytes; dispatches to the
// dtype-specific kernel via SwiGLUOp::launch.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *queue) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    switch (_dtype) {
    case INFINI_DTYPE_F16:
        return _device_info->calculate<SwiGLUOp, half>(_info, workspace, output, inputs, queue);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<SwiGLUOp, bfloat16_t>(_info, workspace, output, inputs, queue);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<SwiGLUOp, float>(_info, workspace, output, inputs, queue);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    // Unreachable: every switch path returns (dead trailing return removed).
}
} // namespace op::swiglu::bang
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment