Commit 05247bb7 authored by PanZezhong's avatar PanZezhong
Browse files

issue/291/refactor: 适配沐曦


Signed-off-by: default avatarPanZezhong <panzezhong@qiyuanlab.com>
parent abf1e021
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#define CHECK_CUDA(API) CHECK_INTERNAL(API, cudaSuccess) #define CHECK_CUDA(API) CHECK_INTERNAL(API, cudaSuccess)
using cuda_bfloat16 = nv_bfloat16; using cuda_bfloat16 = nv_bfloat16;
using cuda_bfloat162 = nv_bfloat162;
namespace device::cuda { namespace device::cuda {
// return the memory offset of original tensor, given the flattened index of broadcasted tensor // return the memory offset of original tensor, given the flattened index of broadcasted tensor
......
#define INFINIOP_MACA_KERNEL __global__ void #define INFINIOP_MACA_KERNEL __global__ void
#include <maca_bf16.h>
#include <maca_fp16.h>
// Possible maximum number of threads per block for MACA architectures // Possible maximum number of threads per block for MACA architectures
// Used for picking correct kernel launch configuration // Used for picking correct kernel launch configuration
#define MACA_BLOCK_SIZE_1024 1024 #define MACA_BLOCK_SIZE_1024 1024
...@@ -10,7 +7,8 @@ ...@@ -10,7 +7,8 @@
#define CHECK_MACA(API) CHECK_INTERNAL(API, hcSuccess) #define CHECK_MACA(API) CHECK_INTERNAL(API, hcSuccess)
using cuda_bfloat16 = maca_bfloat16; using cuda_bfloat16 = hpcc_bfloat16;
using cuda_bfloat162 = hpcc_bfloat162;
namespace device::maca { namespace device::maca {
...@@ -52,7 +50,7 @@ exp_(const float val) { ...@@ -52,7 +50,7 @@ exp_(const float val) {
__forceinline__ __device__ long double __forceinline__ __device__ long double
exp_(const long double val) { exp_(const long double val) {
return expl(val); return exp(val);
} }
__forceinline__ __device__ double __forceinline__ __device__ double
...@@ -65,7 +63,7 @@ exp_(const __half x) { ...@@ -65,7 +63,7 @@ exp_(const __half x) {
return hexp(x); return hexp(x);
} }
__forceinline__ __device__ __hpcc_bfloat16; __forceinline__ __device__ __hpcc_bfloat16
exp_(const __hpcc_bfloat16; x) { exp_(const __hpcc_bfloat16 x) {
return hexp(x); return hexp(x);
} }
...@@ -21,9 +21,7 @@ infiniStatus_t Descriptor::create( ...@@ -21,9 +21,7 @@ infiniStatus_t Descriptor::create(
auto handle = reinterpret_cast<device::maca::Handle *>(handle_); auto handle = reinterpret_cast<device::maca::Handle *>(handle_);
auto dtype = c_desc->dtype(); auto dtype = c_desc->dtype();
if (dtype != INFINI_DTYPE_F16 && dtype != INFINI_DTYPE_F32) { CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR); auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
CHECK_RESULT(result); CHECK_RESULT(result);
...@@ -53,7 +51,10 @@ infiniStatus_t Descriptor::calculate( ...@@ -53,7 +51,10 @@ infiniStatus_t Descriptor::calculate(
a_type = b_type = c_type = HPCC_R_16F; a_type = b_type = c_type = HPCC_R_16F;
compute_type = HCBLAS_COMPUTE_32F; compute_type = HCBLAS_COMPUTE_32F;
break; break;
case INFINI_DTYPE_BF16:
a_type = b_type = c_type = HPCC_R_16BF;
compute_type = HCBLAS_COMPUTE_32F;
break;
case INFINI_DTYPE_F32: case INFINI_DTYPE_F32:
a_type = b_type = c_type = HPCC_R_32F; a_type = b_type = c_type = HPCC_R_32F;
compute_type = HCBLAS_COMPUTE_32F_FAST_TF32; compute_type = HCBLAS_COMPUTE_32F_FAST_TF32;
......
#ifndef __RMS_NORM_CUDA_KERNEL_H__ #ifndef __RMS_NORM_CUDA_KERNEL_H__
#define __RMS_NORM_CUDA_KERNEL_H__ #define __RMS_NORM_CUDA_KERNEL_H__
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tweight, typename Tcompute> template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_CUDA_KERNEL rmsnormBlock( __device__ void rmsnormBlock(
Tdata *__restrict__ y, Tdata *__restrict__ y,
ptrdiff_t stride_y, ptrdiff_t stride_y,
const Tdata *__restrict__ x, const Tdata *__restrict__ x,
......
#include "../../../devices/maca/common_maca.h" #include "../../../devices/maca/common_maca.h"
#include "../cuda/rms_norm_kernel.cuh"
#include "rms_norm_metax.cuh" #include "rms_norm_metax.cuh"
#include "../../../devices/maca/maca_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
// Kernel entry point for the RMS-norm operator on MACA devices.
//
// This wrapper performs no work of its own: it hands every argument,
// unchanged and in the same order, to the rmsnormBlock device routine.
// BLOCK_SIZE and Tcompute are passed explicitly; Tdata and Tweight are
// deduced from the pointer arguments.
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_MACA_KERNEL rmsnormKernel(Tdata *__restrict__ y, ptrdiff_t stride_y,
                                   const Tdata *__restrict__ x, ptrdiff_t stride_x,
                                   const Tweight *__restrict__ w,
                                   size_t dim, float epsilon) {
    rmsnormBlock<BLOCK_SIZE, Tcompute>(
        y, stride_y,
        x, stride_x,
        w, dim, epsilon);
}
namespace op::rms_norm::maca { namespace op::rms_norm::maca {
struct Descriptor::Opaque { struct Descriptor::Opaque {
...@@ -47,7 +65,7 @@ infiniStatus_t launchKernel( ...@@ -47,7 +65,7 @@ infiniStatus_t launchKernel(
hcStream_t maca_stream) { hcStream_t maca_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \ #define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
rmsnormBlock<BLOCK_SIZE, Tdata, Tweight, Tcompute><<<batch_size, BLOCK_SIZE, 0, maca_stream>>>( \ rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size, BLOCK_SIZE, 0, maca_stream>>>( \
reinterpret_cast<Tdata *>(y), \ reinterpret_cast<Tdata *>(y), \
stride_y, \ stride_y, \
reinterpret_cast<const Tdata *>(x), \ reinterpret_cast<const Tdata *>(x), \
...@@ -91,8 +109,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -91,8 +109,8 @@ infiniStatus_t Descriptor::calculate(
auto maca_stream = reinterpret_cast<hcStream_t>(stream); auto maca_stream = reinterpret_cast<hcStream_t>(stream);
// launch kernel with different block sizes // launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { if (_opaque->internal->maxThreadsPerBlock() == MACA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, maca_stream)); CHECK_STATUS(launchKernel<MACA_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, maca_stream));
} else { } else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
} }
......
...@@ -8,6 +8,18 @@ ...@@ -8,6 +8,18 @@
#include "../cuda/kernel.cuh" #include "../cuda/kernel.cuh"
// Kernel entry point for the RMS-norm operator on CUDA devices.
//
// The body is a pure forwarder: all arguments go, unchanged and in the
// same order, to the rmsnormBlock device routine. BLOCK_SIZE and Tcompute
// are given explicitly; Tdata and Tweight are deduced from the pointers.
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_CUDA_KERNEL rmsnormKernel(Tdata *__restrict__ y, ptrdiff_t stride_y,
                                   const Tdata *__restrict__ x, ptrdiff_t stride_x,
                                   const Tweight *__restrict__ w,
                                   size_t dim, float epsilon) {
    rmsnormBlock<BLOCK_SIZE, Tcompute>(
        y, stride_y,
        x, stride_x,
        w, dim, epsilon);
}
namespace op::rms_norm::nvidia { namespace op::rms_norm::nvidia {
struct Descriptor::Opaque { struct Descriptor::Opaque {
...@@ -53,7 +65,7 @@ infiniStatus_t launchKernel( ...@@ -53,7 +65,7 @@ infiniStatus_t launchKernel(
cudaStream_t cuda_stream) { cudaStream_t cuda_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \ #define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
rmsnormBlock<BLOCK_SIZE, Tdata, Tweight, Tcompute><<<batch_size, BLOCK_SIZE, 0, cuda_stream>>>( \ rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size, BLOCK_SIZE, 0, cuda_stream>>>( \
reinterpret_cast<Tdata *>(y), \ reinterpret_cast<Tdata *>(y), \
stride_y, \ stride_y, \
reinterpret_cast<const Tdata *>(x), \ reinterpret_cast<const Tdata *>(x), \
...@@ -108,4 +120,4 @@ infiniStatus_t Descriptor::calculate( ...@@ -108,4 +120,4 @@ infiniStatus_t Descriptor::calculate(
} }
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
} // namespace op::rms_norm::cuda } // namespace op::rms_norm::nvidia
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#define __INFINIOP_ROPE_CUDA_KERNEL_CUH__ #define __INFINIOP_ROPE_CUDA_KERNEL_CUH__
template <typename Tdata, typename Tindex, typename Tangle> template <typename Tdata, typename Tindex, typename Tangle>
INFINIOP_CUDA_KERNEL ropeThreadPerItem( __device__ void ropeThreadPerItemBlock(
Tdata *y_, Tdata *y_,
const Tdata *x_, const Tdata *x_,
const Tindex *__restrict__ pos_ids, const Tindex *__restrict__ pos_ids,
...@@ -28,9 +28,9 @@ INFINIOP_CUDA_KERNEL ropeThreadPerItem( ...@@ -28,9 +28,9 @@ INFINIOP_CUDA_KERNEL ropeThreadPerItem(
Tangle y0 = x.x * cos__ - x.y * sin__, Tangle y0 = x.x * cos__ - x.y * sin__,
y1 = x.x * sin__ + x.y * cos__; y1 = x.x * sin__ + x.y * cos__;
y = half2(y0, y1); y = half2(y0, y1);
} else if constexpr (std::is_same<Tdata, __nv_bfloat16>::value) { } else if constexpr (std::is_same<Tdata, cuda_bfloat16>::value) {
auto &y = reinterpret_cast<__nv_bfloat162 &>(y_[y_offset + 2 * i]); auto &y = reinterpret_cast<cuda_bfloat162 &>(y_[y_offset + 2 * i]);
auto &x = reinterpret_cast<const __nv_bfloat162 &>(x_[x_offset + 2 * i]); auto &x = reinterpret_cast<const cuda_bfloat162 &>(x_[x_offset + 2 * i]);
Tangle x0 = __low2bfloat16(x); Tangle x0 = __low2bfloat16(x);
Tangle x1 = __high2bfloat16(x); Tangle x1 = __high2bfloat16(x);
......
#ifndef __INFINIOP_ROPE_MACA_KERNEL_H__
#define __INFINIOP_ROPE_MACA_KERNEL_H__
#include "../../../devices/maca/maca_kernel_common.h"
// Rotates consecutive element pairs of x_ by table-provided sine/cosine
// values and writes the rotated pairs to y_.
//
// Indexing, as written in the body:
//   - blockIdx.x selects the entry of pos_ids and scales the *_stride_seqlen terms;
//   - blockIdx.y scales the *_stride_nhead terms;
//   - each thread starts at threadIdx.x and advances by blockDim.x until
//     table_dim pairs have been covered.
// The sin/cos row used is pos_ids[blockIdx.x] * table_dim.
template <typename Tdata, typename Tindex, typename Tangle>
INFINIOP_MACA_KERNEL ropeThreadPerItem(
    Tdata *y_,
    const Tdata *x_,
    const Tindex *__restrict__ pos_ids,
    const Tangle *__restrict__ sin_table,
    const Tangle *__restrict__ cos_table,
    size_t table_dim,
    ptrdiff_t y_stride_seqlen,
    ptrdiff_t y_stride_nhead,
    ptrdiff_t x_stride_seqlen,
    ptrdiff_t x_stride_nhead) {

    // Start of this block's row in the sin/cos tables.
    const size_t table_row = size_t(pos_ids[blockIdx.x]) * table_dim;

    // Base offsets of the destination and source slices handled by this block.
    auto dst_base = blockIdx.x * y_stride_seqlen + blockIdx.y * y_stride_nhead;
    auto src_base = blockIdx.x * x_stride_seqlen + blockIdx.y * x_stride_nhead;

    size_t k = threadIdx.x;
    while (k < table_dim) {
        const size_t pair = 2 * k; // index of the first element of the k-th pair

        Tangle sin_v = sin_table[table_row + k];
        Tangle cos_v = cos_table[table_row + k];

        if constexpr (std::is_same_v<Tdata, half>) {
            // Half precision: read and write the pair through a single half2 access.
            const auto &in = reinterpret_cast<const half2 &>(x_[src_base + pair]);
            Tangle out0 = in.x * cos_v - in.y * sin_v;
            Tangle out1 = in.x * sin_v + in.y * cos_v;
            reinterpret_cast<half2 &>(y_[dst_base + pair]) = half2(out0, out1);
        } else {
            // Any other element type: convert to Tangle, rotate, convert back.
            Tangle in0 = x_[src_base + pair];
            Tangle in1 = x_[src_base + pair + 1];
            y_[dst_base + pair] = Tdata(in0 * cos_v - in1 * sin_v);
            y_[dst_base + pair + 1] = Tdata(in0 * sin_v + in1 * cos_v);
        }

        k += blockDim.x;
    }
}
#endif
...@@ -3,6 +3,6 @@ ...@@ -3,6 +3,6 @@
#include "../rope.h" #include "../rope.h"
DESCRIPTOR(maca) DESCRIPTOR(metax)
#endif // __INFINIOP_ROPE_MACA_H__ #endif // __INFINIOP_ROPE_MACA_H__
#include "../../../devices/maca/common_maca.h" #include "../../../devices/maca/common_maca.h"
#include "rope_maca.h" #include "rope_metax.h"
#include "rope_maca_kernel.h"
#include "../../../devices/maca/maca_kernel_common.h"
#include "../cuda/kernel.cuh"
// Kernel entry point for RoPE on MACA devices.
//
// Contains no logic of its own: every argument is passed, unchanged and in
// the same order, to the ropeThreadPerItemBlock device routine, whose
// template parameters are deduced from the arguments.
template <typename Tdata, typename Tindex, typename Tangle>
INFINIOP_MACA_KERNEL ropeThreadPerItemKernel(Tdata *y_, const Tdata *x_,
                                             const Tindex *__restrict__ pos_ids,
                                             const Tangle *__restrict__ sin_table,
                                             const Tangle *__restrict__ cos_table,
                                             size_t table_dim,
                                             ptrdiff_t y_stride_seqlen, ptrdiff_t y_stride_nhead,
                                             ptrdiff_t x_stride_seqlen, ptrdiff_t x_stride_nhead) {
    ropeThreadPerItemBlock(y_, x_,
                           pos_ids, sin_table, cos_table, table_dim,
                           y_stride_seqlen, y_stride_nhead,
                           x_stride_seqlen, x_stride_nhead);
}
namespace op::rope::maca { namespace op::rope::metax {
struct Descriptor::Opaque { struct Descriptor::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal; std::shared_ptr<device::maca::Handle::Internal> internal;
...@@ -50,7 +73,7 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info, ...@@ -50,7 +73,7 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
dimy = uint32_t(info.nhead); dimy = uint32_t(info.nhead);
int nthreads = std::max(int(info.table_dim), block_size); int nthreads = std::max(int(info.table_dim), block_size);
ropeThreadPerItem<<<dim3(dimx, dimy), nthreads, 0, stream>>>( ropeThreadPerItemKernel<<<dim3(dimx, dimy), nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim, y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead); info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
...@@ -102,6 +125,8 @@ infiniStatus_t Descriptor::calculate( ...@@ -102,6 +125,8 @@ infiniStatus_t Descriptor::calculate(
switch (_info.data_type) { switch (_info.data_type) {
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
ROPE_TYPE(half); ROPE_TYPE(half);
case INFINI_DTYPE_BF16:
ROPE_TYPE(cuda_bfloat16);
case INFINI_DTYPE_F32: case INFINI_DTYPE_F32:
ROPE_TYPE(float); ROPE_TYPE(float);
case INFINI_DTYPE_F64: case INFINI_DTYPE_F64:
......
...@@ -5,6 +5,26 @@ ...@@ -5,6 +5,26 @@
#include "../cuda/kernel.cuh" #include "../cuda/kernel.cuh"
// Kernel entry point for RoPE on CUDA devices.
//
// A pure forwarder: all arguments are handed, unchanged and in the same
// order, to the ropeThreadPerItemBlock device routine, whose template
// parameters are deduced from the arguments.
template <typename Tdata, typename Tindex, typename Tangle>
INFINIOP_CUDA_KERNEL ropeThreadPerItemKernel(Tdata *y_, const Tdata *x_,
                                             const Tindex *__restrict__ pos_ids,
                                             const Tangle *__restrict__ sin_table,
                                             const Tangle *__restrict__ cos_table,
                                             size_t table_dim,
                                             ptrdiff_t y_stride_seqlen, ptrdiff_t y_stride_nhead,
                                             ptrdiff_t x_stride_seqlen, ptrdiff_t x_stride_nhead) {
    ropeThreadPerItemBlock(y_, x_,
                           pos_ids, sin_table, cos_table, table_dim,
                           y_stride_seqlen, y_stride_nhead,
                           x_stride_seqlen, x_stride_nhead);
}
namespace op::rope::nvidia { namespace op::rope::nvidia {
struct Descriptor::Opaque { struct Descriptor::Opaque {
...@@ -53,7 +73,7 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info, ...@@ -53,7 +73,7 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
dimy = uint32_t(info.nhead); dimy = uint32_t(info.nhead);
int nthreads = std::max(int(info.table_dim), block_size); int nthreads = std::max(int(info.table_dim), block_size);
ropeThreadPerItem<<<dim3(dimx, dimy), nthreads, 0, stream>>>( ropeThreadPerItemKernel<<<dim3(dimx, dimy), nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim, y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead); info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
...@@ -106,7 +126,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -106,7 +126,7 @@ infiniStatus_t Descriptor::calculate(
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
ROPE_TYPE(half); ROPE_TYPE(half);
case INFINI_DTYPE_BF16: case INFINI_DTYPE_BF16:
ROPE_TYPE(__nv_bfloat16); ROPE_TYPE(cuda_bfloat16);
case INFINI_DTYPE_F32: case INFINI_DTYPE_F32:
ROPE_TYPE(float); ROPE_TYPE(float);
case INFINI_DTYPE_F64: case INFINI_DTYPE_F64:
...@@ -121,4 +141,4 @@ infiniStatus_t Descriptor::calculate( ...@@ -121,4 +141,4 @@ infiniStatus_t Descriptor::calculate(
#undef ROPE_TYPE #undef ROPE_TYPE
#undef CALCULATE_ROPE #undef CALCULATE_ROPE
} // namespace op::rope::cuda } // namespace op::rope::nvidia
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "ascend/rope_ascend.h" #include "ascend/rope_ascend.h"
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
#include "maca/rope_maca.h" #include "metax/rope_metax.h"
#endif #endif
__C infiniStatus_t infiniopCreateRoPEDescriptor( __C infiniStatus_t infiniopCreateRoPEDescriptor(
...@@ -43,7 +43,7 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor( ...@@ -43,7 +43,7 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
CREATE(INFINI_DEVICE_NVIDIA, nvidia); CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, maca); CREATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
CREATE(INFINI_DEVICE_ASCEND, ascend); CREATE(INFINI_DEVICE_ASCEND, ascend);
...@@ -84,7 +84,7 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, ...@@ -84,7 +84,7 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
GET(INFINI_DEVICE_NVIDIA, nvidia); GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, maca); GET(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
...@@ -135,7 +135,7 @@ __C infiniStatus_t infiniopRoPE( ...@@ -135,7 +135,7 @@ __C infiniStatus_t infiniopRoPE(
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, maca); CALCULATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
...@@ -181,7 +181,7 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) { ...@@ -181,7 +181,7 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
DELETE(INFINI_DEVICE_NVIDIA, nvidia); DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, maca); DELETE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
......
#ifndef __SWIGLU_CUDA_H__ #ifndef __SWIGLU_CUDA_H__
#define __SWIGLU_CUDA_H__ #define __SWIGLU_CUDA_H__
#include "../../../elementwise/cuda/elementwise_cuda.cuh"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
namespace op::swiglu::cuda { namespace op::swiglu::cuda {
typedef struct SwiGLUOp { typedef struct SwiGLUOp {
private: private:
...@@ -14,13 +10,13 @@ private: ...@@ -14,13 +10,13 @@ private:
return h2rcp(__hadd2(make_half2(1, 1), h2exp(__hneg2(x)))); return h2rcp(__hadd2(make_half2(1, 1), h2exp(__hneg2(x))));
} else if constexpr (std::is_same_v<T, half>) { } else if constexpr (std::is_same_v<T, half>) {
return hrcp(__hadd(half(1.f), __float2half(__expf(__half2float(__hneg(x)))))); return hrcp(__hadd(half(1.f), __float2half(__expf(__half2float(__hneg(x))))));
} else if constexpr (std::is_same_v<T, __nv_bfloat162>) { } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
float x0 = __bfloat162float(__low2bfloat16(x)); float x0 = __bfloat162float(__low2bfloat16(x));
float x1 = __bfloat162float(__high2bfloat16(x)); float x1 = __bfloat162float(__high2bfloat16(x));
float sig0 = __frcp_rn(__fadd_rn(1.0f, __expf(-x0))); float sig0 = __frcp_rn(__fadd_rn(1.0f, __expf(-x0)));
float sig1 = __frcp_rn(__fadd_rn(1.0f, __expf(-x1))); float sig1 = __frcp_rn(__fadd_rn(1.0f, __expf(-x1)));
return __floats2bfloat162_rn(sig0, sig1); return __floats2bfloat162_rn(sig0, sig1);
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) { } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
float xf = __bfloat162float(x); float xf = __bfloat162float(x);
return __float2bfloat16_rn(__frcp_rn(__fadd_rn(1.0f, __expf(-xf)))); return __float2bfloat16_rn(__frcp_rn(__fadd_rn(1.0f, __expf(-xf))));
} else if constexpr (std::is_same_v<T, float>) { } else if constexpr (std::is_same_v<T, float>) {
...@@ -38,8 +34,8 @@ public: ...@@ -38,8 +34,8 @@ public:
return __hmul2(__hmul2(gate, sigmoid(gate)), up); return __hmul2(__hmul2(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<T, half>) { } else if constexpr (std::is_same_v<T, half>) {
return __hmul(__hmul(gate, sigmoid(gate)), up); return __hmul(__hmul(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<T, __nv_bfloat162>) { } else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
__nv_bfloat162 sig = sigmoid(gate); cuda_bfloat162 sig = sigmoid(gate);
float gate0 = __bfloat162float(__low2bfloat16(gate)); float gate0 = __bfloat162float(__low2bfloat16(gate));
float gate1 = __bfloat162float(__high2bfloat16(gate)); float gate1 = __bfloat162float(__high2bfloat16(gate));
float sig0 = __bfloat162float(__low2bfloat16(sig)); float sig0 = __bfloat162float(__low2bfloat16(sig));
...@@ -49,8 +45,8 @@ public: ...@@ -49,8 +45,8 @@ public:
float res0 = __fmul_rn(__fmul_rn(gate0, sig0), up0); float res0 = __fmul_rn(__fmul_rn(gate0, sig0), up0);
float res1 = __fmul_rn(__fmul_rn(gate1, sig1), up1); float res1 = __fmul_rn(__fmul_rn(gate1, sig1), up1);
return __floats2bfloat162_rn(res0, res1); return __floats2bfloat162_rn(res0, res1);
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) { } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
__nv_bfloat16 sig = sigmoid(gate); cuda_bfloat16 sig = sigmoid(gate);
float gatef = __bfloat162float(gate); float gatef = __bfloat162float(gate);
float sigf = __bfloat162float(sig); float sigf = __bfloat162float(sig);
float upf = __bfloat162float(up); float upf = __bfloat162float(up);
......
#ifndef __SWIGLU_MACA_H__
#define __SWIGLU_MACA_H__
#include "../../../elementwise/maca/elementwise_maca.h"
#include <hctlass/half.h>
namespace op::swiglu::maca {

// Element-wise functor that evaluates gate * sigmoid(gate) * up.
//
// half2, half and float each use their own device intrinsics; every other
// type falls back to plain arithmetic operators and std::exp.
typedef struct SwiGLUOp {
public:
    // The functor consumes two operands per output element: up and gate.
    static constexpr size_t num_inputs = 2;

    // Returns gate * sigmoid(gate) * up for one element (or one half2 pair).
    template <typename T>
    __device__ __forceinline__ T operator()(const T &up, const T &gate) const {
        const T sig = sigmoid(gate);
        if constexpr (std::is_same_v<T, half2>) {
            return __hmul2(__hmul2(gate, sig), up);
        } else if constexpr (std::is_same_v<T, half>) {
            return __hmul(__hmul(gate, sig), up);
        } else if constexpr (std::is_same_v<T, float>) {
            return __fmul_rn(__fmul_rn(gate, sig), up);
        } else {
            return gate * sig * up;
        }
    }

private:
    // Logistic function 1 / (1 + exp(-x)), specialised per element type.
    template <typename T>
    __device__ __forceinline__ T sigmoid(const T &x) const {
        if constexpr (std::is_same_v<T, half2>) {
            // Packed half path: negate, exponentiate, add one, reciprocal.
            const half2 exp_neg = h2exp(__hneg2(x));
            return h2rcp(__hadd2(make_half2(1, 1), exp_neg));
        } else if constexpr (std::is_same_v<T, half>) {
            // The exponential is taken in float, then converted back to half.
            const float neg_x = __half2float(__hneg(x));
            const half exp_neg = __float2half(__expf(neg_x));
            return hrcp(__hadd(half(1.f), exp_neg));
        } else if constexpr (std::is_same_v<T, float>) {
            const float denom = __fadd_rn(1, __expf(-x));
            return __frcp_rn(denom);
        } else {
            return 1 / (1 + std::exp(-x));
        }
    }
} SwiGLUOp;

} // namespace op::swiglu::maca
#endif
...@@ -3,6 +3,6 @@ ...@@ -3,6 +3,6 @@
#include "../../../elementwise/maca/elementwise_maca_api.h" #include "../../../elementwise/maca/elementwise_maca_api.h"
ELEMENTWISE_DESCRIPTOR(swiglu, maca) ELEMENTWISE_DESCRIPTOR(swiglu, metax, maca)
#endif // __SWIGLU_MACA_API_H__ #endif // __SWIGLU_MACA_API_H__
#include "swiglu_maca.h" #include "swiglu_metax.h"
#include "swiglu_maca_internal.h"
namespace op::swiglu::maca { #include "../../../elementwise/maca/elementwise_maca.h"
#include "../cuda/kernel.cuh"
namespace op::swiglu::metax {
Descriptor::~Descriptor() = default; Descriptor::~Descriptor() = default;
...@@ -20,7 +23,7 @@ infiniStatus_t Descriptor::create( ...@@ -20,7 +23,7 @@ infiniStatus_t Descriptor::create(
const auto &up_shape = up_desc->shape(); const auto &up_shape = up_desc->shape();
const auto &gate_shape = gate_desc->shape(); const auto &gate_shape = gate_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape); CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
// create MACA elementwise descriptor // create MACA elementwise descriptor
...@@ -42,15 +45,17 @@ infiniStatus_t Descriptor::calculate( ...@@ -42,15 +45,17 @@ infiniStatus_t Descriptor::calculate(
switch (_dtype) { switch (_dtype) {
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
return _device_info->calculate<256, SwiGLUOp, half>(_info, workspace, output, inputs, stream); return _device_info->calculate<256, cuda::SwiGLUOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<256, cuda::SwiGLUOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32: case INFINI_DTYPE_F32:
return _device_info->calculate<256, SwiGLUOp, float>(_info, workspace, output, inputs, stream); return _device_info->calculate<256, cuda::SwiGLUOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64: case INFINI_DTYPE_F64:
return _device_info->calculate<256, SwiGLUOp, double>(_info, workspace, output, inputs, stream); return _device_info->calculate<256, cuda::SwiGLUOp, double>(_info, workspace, output, inputs, stream);
default: default:
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
} // namespace op::swiglu::maca } // namespace op::swiglu::metax
#include "swiglu_nvidia.cuh" #include "swiglu_nvidia.cuh"
#include "../../../elementwise/cuda/elementwise_cuda.cuh"
#include "../cuda/kernel.cuh" #include "../cuda/kernel.cuh"
namespace op::swiglu::nvidia { namespace op::swiglu::nvidia {
...@@ -44,7 +47,7 @@ infiniStatus_t Descriptor::calculate( ...@@ -44,7 +47,7 @@ infiniStatus_t Descriptor::calculate(
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
return _device_info->calculate<256, cuda::SwiGLUOp, half>(_info, workspace, output, inputs, stream); return _device_info->calculate<256, cuda::SwiGLUOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_BF16: case INFINI_DTYPE_BF16:
return _device_info->calculate<256, cuda::SwiGLUOp, __nv_bfloat16>(_info, workspace, output, inputs, stream); return _device_info->calculate<256, cuda::SwiGLUOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32: case INFINI_DTYPE_F32:
return _device_info->calculate<256, cuda::SwiGLUOp, float>(_info, workspace, output, inputs, stream); return _device_info->calculate<256, cuda::SwiGLUOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64: case INFINI_DTYPE_F64:
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "kunlun/swiglu_kunlun.h" #include "kunlun/swiglu_kunlun.h"
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
#include "maca/swiglu_maca.h" #include "metax/swiglu_metax.h"
#endif #endif
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
#include "ascend/swiglu_ascend.h" #include "ascend/swiglu_ascend.h"
...@@ -46,7 +46,7 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor( ...@@ -46,7 +46,7 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
CREATE(INFINI_DEVICE_KUNLUN, kunlun); CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, maca); CREATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
...@@ -96,7 +96,7 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des ...@@ -96,7 +96,7 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
GET(INFINI_DEVICE_KUNLUN, kunlun); GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, maca); GET(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
...@@ -144,7 +144,7 @@ __C infiniStatus_t infiniopSwiGLU( ...@@ -144,7 +144,7 @@ __C infiniStatus_t infiniopSwiGLU(
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun); CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, maca); CALCULATE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
...@@ -190,7 +190,7 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) { ...@@ -190,7 +190,7 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
DELETE(INFINI_DEVICE_KUNLUN, kunlun); DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, maca); DELETE(INFINI_DEVICE_METAX, metax);
#endif #endif
#ifdef ENABLE_CAMBRICON_MLU #ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: { case DevCambriconMlu: {
......
...@@ -174,7 +174,7 @@ target("infini-utils") ...@@ -174,7 +174,7 @@ target("infini-utils")
add_cxflags("-fPIC", "-Wno-unknown-pragmas") add_cxflags("-fPIC", "-Wno-unknown-pragmas")
if has_config("omp") then if has_config("omp") then
add_cxflags("-fopenmp") add_cxflags("-fopenmp")
add_ldflags("-fopenmp") add_ldflags("-fopenmp", {force = true})
end end
end end
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment