Unverified Commit e4605f7c authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #293 from YdrMaster/distinct-cuda

issue291 合并 cuda 代码
parents 5025ebed eac2b0ca
#include "../../../devices/cuda/cuda_common.cuh"
#include "rms_norm_cuda.cuh"
#include "rms_norm_kernel.cuh"
#include "rms_norm_nvidia.cuh"
namespace op::rms_norm::cuda {
#include "../../../devices/cuda/cuda_kernel_common.cuh"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
// RMS-Norm kernel entry point: thin __global__ wrapper that forwards to the
// shared per-block implementation `rmsnormBlock` (defined in kernel.cuh).
// Template parameters:
//   BLOCK_SIZE - threads per block; must match the launch configuration
//   Tcompute   - accumulation/compute type used for the reduction
//   Tdata      - element type of the input x and output y
//   Tweight    - element type of the scale weights w
// Arguments:
//   y, stride_y - output rows; consecutive rows are stride_y elements apart
//   x, stride_x - input rows; consecutive rows are stride_x elements apart
//   w           - per-element scale weights, length `dim`
//   dim         - number of elements per row
//   epsilon     - numerical-stability term (consumed inside rmsnormBlock)
// NOTE(review): the launch site uses <<<batch_size, BLOCK_SIZE>>>, so row
// selection presumably happens via blockIdx inside rmsnormBlock — confirm
// against kernel.cuh.
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_CUDA_KERNEL rmsnormKernel(
Tdata *__restrict__ y,
ptrdiff_t stride_y,
const Tdata *__restrict__ x,
ptrdiff_t stride_x,
const Tweight *__restrict__ w,
size_t dim,
float epsilon) {
// Tdata/Tweight are deduced from the pointer arguments; only BLOCK_SIZE and
// Tcompute need to be spelled out explicitly.
rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y, x, stride_x, w, dim, epsilon);
}
namespace op::rms_norm::nvidia {
struct Descriptor::Opaque {
std::shared_ptr<device::cuda::Handle::Internal> internal;
......@@ -46,14 +64,14 @@ infiniStatus_t launchKernel(
float epsilon,
cudaStream_t cuda_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
rmsnormBlock<BLOCK_SIZE, Tdata, Tweight, Tcompute><<<batch_size, BLOCK_SIZE, 0, cuda_stream>>>( \
reinterpret_cast<Tdata *>(y), \
stride_y, \
reinterpret_cast<const Tdata *>(x), \
stride_x, \
reinterpret_cast<const Tweight *>(w), \
dim, \
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size, BLOCK_SIZE, 0, cuda_stream>>>( \
reinterpret_cast<Tdata *>(y), \
stride_y, \
reinterpret_cast<const Tdata *>(x), \
stride_x, \
reinterpret_cast<const Tweight *>(w), \
dim, \
epsilon)
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
......@@ -102,4 +120,4 @@ infiniStatus_t Descriptor::calculate(
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::rms_norm::cuda
} // namespace op::rms_norm::nvidia
......@@ -3,6 +3,6 @@
#include "../rms_norm.h"
DESCRIPTOR(cuda)
DESCRIPTOR(nvidia)
#endif
......@@ -6,13 +6,13 @@
#include "cpu/rms_norm_cpu.h"
#endif
#ifdef ENABLE_NVIDIA_API
#include "cuda/rms_norm_cuda.cuh"
#include "nvidia/rms_norm_nvidia.cuh"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/rms_norm_aclnn.h"
#endif
#ifdef ENABLE_METAX_API
#include "maca/rms_norm_maca.cuh"
#include "metax/rms_norm_metax.cuh"
#endif
#ifdef ENABLE_MOORE_API
#include "musa/rms_norm_musa.cuh"
......@@ -37,17 +37,17 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
y_desc, \
x_desc, \
w_desc, \
epsilon);
epsilon)
switch (handle->device) {
#ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu)
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, cuda)
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun)
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
......@@ -55,13 +55,13 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor(
}
#endif
#ifdef ENABLE_ASCEND_API
CREATE(INFINI_DEVICE_ASCEND, ascend)
CREATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, maca)
CREATE(INFINI_DEVICE_METAX, maca);
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, musa)
CREATE(INFINI_DEVICE_MOORE, musa);
#endif
}
......@@ -75,17 +75,17 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::rms_norm::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu)
GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, cuda)
GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun)
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
......@@ -93,13 +93,13 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d
}
#endif
#ifdef ENABLE_ASCEND_API
GET(INFINI_DEVICE_ASCEND, ascend)
GET(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, maca)
GET(INFINI_DEVICE_METAX, maca);
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, musa)
GET(INFINI_DEVICE_MOORE, musa);
#endif
}
......@@ -114,17 +114,17 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<op::rms_norm::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, y, x, w, stream);
workspace, workspace_size, y, x, w, stream)
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu)
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, cuda)
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun)
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
......@@ -132,13 +132,13 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works
}
#endif
#ifdef ENABLE_ASCEND_API
CALCULATE(INFINI_DEVICE_ASCEND, ascend)
CALCULATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, maca)
CALCULATE(INFINI_DEVICE_METAX, maca);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, musa)
CALCULATE(INFINI_DEVICE_MOORE, musa);
#endif
}
......@@ -152,17 +152,17 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::rms_norm::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
DESTROY(INFINI_DEVICE_CPU, cpu)
DESTROY(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, cuda)
DESTROY(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
DESTROY(INFINI_DEVICE_KUNLUN, kunlun)
DESTROY(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
......@@ -170,13 +170,13 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t
}
#endif
#ifdef ENABLE_ASCEND_API
DESTROY(INFINI_DEVICE_ASCEND, ascend)
DESTROY(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, maca)
DESTROY(INFINI_DEVICE_METAX, maca);
#endif
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, musa)
DESTROY(INFINI_DEVICE_MOORE, musa);
#endif
}
......
#ifndef __INFINIOP_ROPE_CUDA_KERNEL_CUH__
#define __INFINIOP_ROPE_CUDA_KERNEL_CUH__
#include "../../../devices/cuda/cuda_kernel_common.cuh"
template <typename Tdata, typename Tindex, typename Tangle>
INFINIOP_CUDA_KERNEL ropeThreadPerItem(
__device__ void ropeThreadPerItemBlock(
Tdata *y_,
const Tdata *x_,
const Tindex *__restrict__ pos_ids,
......@@ -30,9 +28,9 @@ INFINIOP_CUDA_KERNEL ropeThreadPerItem(
Tangle y0 = x.x * cos__ - x.y * sin__,
y1 = x.x * sin__ + x.y * cos__;
y = half2(y0, y1);
} else if constexpr (std::is_same<Tdata, __nv_bfloat16>::value) {
auto &y = reinterpret_cast<__nv_bfloat162 &>(y_[y_offset + 2 * i]);
auto &x = reinterpret_cast<const __nv_bfloat162 &>(x_[x_offset + 2 * i]);
} else if constexpr (std::is_same<Tdata, cuda_bfloat16>::value) {
auto &y = reinterpret_cast<cuda_bfloat162 &>(y_[y_offset + 2 * i]);
auto &x = reinterpret_cast<const cuda_bfloat162 &>(x_[x_offset + 2 * i]);
Tangle x0 = __low2bfloat16(x);
Tangle x1 = __high2bfloat16(x);
......
#ifndef __INFINIOP_ROPE_MACA_KERNEL_H__
#define __INFINIOP_ROPE_MACA_KERNEL_H__
#include "../../../devices/maca/maca_kernel_common.h"
// Rotary Position Embedding (RoPE) kernel for the MACA backend.
// Grid layout (from the offset math below): blockIdx.x indexes the sequence
// position, blockIdx.y indexes the attention head. Threads of a block stride
// over the `table_dim` rotation pairs of one (position, head) slice.
// Template parameters:
//   Tdata  - element type of x/y (half gets a packed half2 fast path)
//   Tindex - integer type of pos_ids
//   Tangle - type of the sin/cos tables and of the intermediate math
// Each pair (x[2i], x[2i+1]) is rotated by the angle at row pos_ids[blockIdx.x],
// column i of the sin/cos tables:
//   y0 = x0*cos - x1*sin,  y1 = x0*sin + x1*cos
// NOTE(review): the half2 path reinterprets &y_[offset + 2i] as half2 — this
// assumes the 2i element is 4-byte aligned (i.e. `offset` is even); confirm
// the strides guarantee that.
template <typename Tdata, typename Tindex, typename Tangle>
INFINIOP_MACA_KERNEL ropeThreadPerItem(
Tdata *y_,
const Tdata *x_,
const Tindex *__restrict__ pos_ids,
const Tangle *__restrict__ sin_table,
const Tangle *__restrict__ cos_table,
size_t table_dim,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
// Start of this (sequence position, head) slice in y and x.
auto y_offset = blockIdx.x * y_stride_seqlen + blockIdx.y * y_stride_nhead;
auto x_offset = blockIdx.x * x_stride_seqlen + blockIdx.y * x_stride_nhead;
// Position id selects one row of the precomputed sin/cos tables.
size_t pos_id = size_t(pos_ids[blockIdx.x]);
auto table_offset = pos_id * table_dim;
// Block-stride loop over the rotation pairs of this slice.
for (size_t i = threadIdx.x; i < table_dim; i += blockDim.x) {
Tangle sin__ = sin_table[table_offset + i],
cos__ = cos_table[table_offset + i];
if constexpr (std::is_same<Tdata, half>::value) {
// half fast path: load/store the pair as one packed half2.
auto &y = reinterpret_cast<half2 &>(y_[y_offset + 2 * i]);
auto &x = reinterpret_cast<const half2 &>(x_[x_offset + 2 * i]);
// 2D rotation of (x.x, x.y), computed in Tangle precision.
Tangle y0 = x.x * cos__ - x.y * sin__,
y1 = x.x * sin__ + x.y * cos__;
y = half2(y0, y1);
} else {
// Generic path: scalar loads of the even/odd elements of the pair.
Tangle x0 = x_[x_offset + 2 * i],
x1 = x_[x_offset + 2 * i + 1];
y_[y_offset + 2 * i] = Tdata(x0 * cos__ - x1 * sin__);
y_[y_offset + 2 * i + 1] = Tdata(x0 * sin__ + x1 * cos__);
}
}
}
#endif
......@@ -3,6 +3,6 @@
#include "../rope.h"
DESCRIPTOR(maca)
DESCRIPTOR(metax)
#endif // __INFINIOP_ROPE_MACA_H__
#include "../../../devices/maca/common_maca.h"
#include "rope_maca.h"
#include "rope_maca_kernel.h"
#include "rope_metax.h"
#include "../../../devices/maca/maca_kernel_common.h"
#include "../cuda/kernel.cuh"
// MACA __global__ entry point for RoPE. This is a pure forwarder: all of the
// rotation logic lives in ropeThreadPerItemBlock (../cuda/kernel.cuh), which
// is shared between backends; this wrapper only provides the kernel linkage.
// The parameter list mirrors the block function exactly: output/input tensors,
// position ids, sin/cos lookup tables, the table width, and the seqlen/nhead
// strides of y and x.
template <typename Tdata, typename Tindex, typename Tangle>
INFINIOP_MACA_KERNEL ropeThreadPerItemKernel(
Tdata *y_,
const Tdata *x_,
const Tindex *__restrict__ pos_ids,
const Tangle *__restrict__ sin_table,
const Tangle *__restrict__ cos_table,
size_t table_dim,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
// Forward every argument unchanged to the shared device implementation.
ropeThreadPerItemBlock(y_, x_, pos_ids, sin_table, cos_table, table_dim,
y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead);
}
namespace op::rope::maca {
namespace op::rope::metax {
struct Descriptor::Opaque {
std::shared_ptr<device::maca::Handle::Internal> internal;
......@@ -50,7 +73,7 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
dimy = uint32_t(info.nhead);
int nthreads = std::max(int(info.table_dim), block_size);
ropeThreadPerItem<<<dim3(dimx, dimy), nthreads, 0, stream>>>(
ropeThreadPerItemKernel<<<dim3(dimx, dimy), nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
......@@ -102,6 +125,8 @@ infiniStatus_t Descriptor::calculate(
switch (_info.data_type) {
case INFINI_DTYPE_F16:
ROPE_TYPE(half);
case INFINI_DTYPE_BF16:
ROPE_TYPE(cuda_bfloat16);
case INFINI_DTYPE_F32:
ROPE_TYPE(float);
case INFINI_DTYPE_F64:
......
#include "../../../devices/cuda/cuda_common.cuh"
#include "rope_cuda.cuh"
#include "rope_cuda_kernel.cuh"
#include "rope_nvidia.cuh"
#include "../../../devices/cuda/cuda_kernel_common.cuh"
#include "../cuda/kernel.cuh"
// CUDA __global__ entry point for RoPE. A pure forwarder: the actual rotation
// is implemented once in ropeThreadPerItemBlock (../cuda/kernel.cuh) and
// reused across backends; this wrapper only supplies the kernel linkage.
// Parameters mirror the block function exactly: output/input tensors, position
// ids, sin/cos lookup tables, the table width, and the seqlen/nhead strides
// of y and x.
template <typename Tdata, typename Tindex, typename Tangle>
INFINIOP_CUDA_KERNEL ropeThreadPerItemKernel(
Tdata *y_,
const Tdata *x_,
const Tindex *__restrict__ pos_ids,
const Tangle *__restrict__ sin_table,
const Tangle *__restrict__ cos_table,
size_t table_dim,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
// Forward every argument unchanged to the shared device implementation.
ropeThreadPerItemBlock(y_, x_, pos_ids, sin_table, cos_table, table_dim,
y_stride_seqlen, y_stride_nhead, x_stride_seqlen, x_stride_nhead);
}
namespace op::rope::cuda {
namespace op::rope::nvidia {
struct Descriptor::Opaque {
std::shared_ptr<device::cuda::Handle::Internal> internal;
......@@ -50,7 +73,7 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
dimy = uint32_t(info.nhead);
int nthreads = std::max(int(info.table_dim), block_size);
ropeThreadPerItem<<<dim3(dimx, dimy), nthreads, 0, stream>>>(
ropeThreadPerItemKernel<<<dim3(dimx, dimy), nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
info.y_stride_seqlen, info.y_stride_nhead, info.x_stride_seqlen, info.x_stride_nhead);
......@@ -103,7 +126,7 @@ infiniStatus_t Descriptor::calculate(
case INFINI_DTYPE_F16:
ROPE_TYPE(half);
case INFINI_DTYPE_BF16:
ROPE_TYPE(__nv_bfloat16);
ROPE_TYPE(cuda_bfloat16);
case INFINI_DTYPE_F32:
ROPE_TYPE(float);
case INFINI_DTYPE_F64:
......@@ -118,4 +141,4 @@ infiniStatus_t Descriptor::calculate(
#undef ROPE_TYPE
#undef CALCULATE_ROPE
} // namespace op::rope::cuda
} // namespace op::rope::nvidia
......@@ -3,6 +3,6 @@
#include "../rope.h"
DESCRIPTOR(cuda)
DESCRIPTOR(nvidia)
#endif // __INFINIOP_ROPE_CUDA_H__
......@@ -6,13 +6,13 @@
#include "cpu/rope_cpu.h"
#endif
#ifdef ENABLE_NVIDIA_API
#include "cuda/rope_cuda.cuh"
#include "nvidia/rope_nvidia.cuh"
#endif
#ifdef ENABLE_ASCEND_API
#include "ascend/rope_ascend.h"
#endif
#ifdef ENABLE_METAX_API
#include "maca/rope_maca.h"
#include "metax/rope_metax.h"
#endif
__C infiniStatus_t infiniopCreateRoPEDescriptor(
......@@ -40,10 +40,10 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor(
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, cuda);
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, maca);
CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_ASCEND_API
CREATE(INFINI_DEVICE_ASCEND, ascend);
......@@ -81,10 +81,10 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, cuda);
GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, maca);
GET(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
......@@ -132,10 +132,10 @@ __C infiniStatus_t infiniopRoPE(
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, maca);
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
......@@ -178,10 +178,10 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) {
DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
DELETE(INFINI_DEVICE_NVIDIA, cuda);
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, maca);
DELETE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
......
......@@ -3,7 +3,7 @@
#include "../../../elementwise/cpu/elementwise_cpu.h"
ELEMENTWISE_DESCRIPTOR(sub, cpu)
ELEMENTWISE_DESCRIPTOR(sub, cpu, cpu)
namespace op::sub::cpu {
typedef struct SubOp {
......
#include "sub_cuda.cuh"
#include "sub_cuda_internal.cuh"
#include "../cuda/kernel.cuh"
#include "sub_nvidia.cuh"
namespace op::sub::cuda {
namespace op::sub::nvidia {
Descriptor::~Descriptor() = default;
......@@ -43,17 +43,17 @@ infiniStatus_t Descriptor::calculate(
switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<256, SubOp, half>(_info, workspace, output, inputs, stream);
return _device_info->calculate<256, cuda::SubOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<256, SubOp, float>(_info, workspace, output, inputs, stream);
return _device_info->calculate<256, cuda::SubOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<256, SubOp, double>(_info, workspace, output, inputs, stream);
return _device_info->calculate<256, cuda::SubOp, double>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<256, SubOp, __nv_bfloat16>(_info, workspace, output, inputs, stream);
return _device_info->calculate<256, cuda::SubOp, __nv_bfloat16>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::sub::cuda
} // namespace op::sub::nvidia
......@@ -3,6 +3,6 @@
#include "../../../elementwise/cuda/elementwise_cuda_api.cuh"
ELEMENTWISE_DESCRIPTOR(sub, cuda)
ELEMENTWISE_DESCRIPTOR(sub, nvidia, cuda)
#endif // __SUB_CUDA_API_H__
......@@ -6,7 +6,7 @@
#include "cpu/sub_cpu.h"
#endif
#ifdef ENABLE_NVIDIA_API
#include "cuda/sub_cuda.cuh"
#include "nvidia/sub_nvidia.cuh"
#endif
__C infiniStatus_t infiniopCreateSubDescriptor(
......@@ -31,7 +31,7 @@ __C infiniStatus_t infiniopCreateSubDescriptor(
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, cuda);
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
default:
......@@ -46,14 +46,14 @@ __C infiniStatus_t infiniopGetSubWorkspaceSize(infiniopSubDescriptor_t desc, siz
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::sub::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu)
GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, cuda)
GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -83,7 +83,7 @@ __C infiniStatus_t infiniopSub(
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
default:
......@@ -107,7 +107,7 @@ infiniopDestroySubDescriptor(infiniopSubDescriptor_t desc) {
DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
DELETE(INFINI_DEVICE_NVIDIA, cuda);
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
default:
......
......@@ -3,7 +3,7 @@
#include "../../../elementwise/cpu/elementwise_cpu.h"
ELEMENTWISE_DESCRIPTOR(swiglu, cpu)
ELEMENTWISE_DESCRIPTOR(swiglu, cpu, cpu)
namespace op::swiglu::cpu {
typedef struct SwiGLUOp {
......
#ifndef __SWIGLU_CUDA_H__
#define __SWIGLU_CUDA_H__
#include "../../../elementwise/cuda/elementwise_cuda.cuh"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
namespace op::swiglu::cuda {
typedef struct SwiGLUOp {
private:
......@@ -14,13 +10,13 @@ private:
return h2rcp(__hadd2(make_half2(1, 1), h2exp(__hneg2(x))));
} else if constexpr (std::is_same_v<T, half>) {
return hrcp(__hadd(half(1.f), __float2half(__expf(__half2float(__hneg(x))))));
} else if constexpr (std::is_same_v<T, __nv_bfloat162>) {
} else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
float x0 = __bfloat162float(__low2bfloat16(x));
float x1 = __bfloat162float(__high2bfloat16(x));
float sig0 = __frcp_rn(__fadd_rn(1.0f, __expf(-x0)));
float sig1 = __frcp_rn(__fadd_rn(1.0f, __expf(-x1)));
return __floats2bfloat162_rn(sig0, sig1);
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
float xf = __bfloat162float(x);
return __float2bfloat16_rn(__frcp_rn(__fadd_rn(1.0f, __expf(-xf))));
} else if constexpr (std::is_same_v<T, float>) {
......@@ -38,8 +34,8 @@ public:
return __hmul2(__hmul2(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<T, half>) {
return __hmul(__hmul(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<T, __nv_bfloat162>) {
__nv_bfloat162 sig = sigmoid(gate);
} else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
cuda_bfloat162 sig = sigmoid(gate);
float gate0 = __bfloat162float(__low2bfloat16(gate));
float gate1 = __bfloat162float(__high2bfloat16(gate));
float sig0 = __bfloat162float(__low2bfloat16(sig));
......@@ -49,8 +45,8 @@ public:
float res0 = __fmul_rn(__fmul_rn(gate0, sig0), up0);
float res1 = __fmul_rn(__fmul_rn(gate1, sig1), up1);
return __floats2bfloat162_rn(res0, res1);
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
__nv_bfloat16 sig = sigmoid(gate);
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
cuda_bfloat16 sig = sigmoid(gate);
float gatef = __bfloat162float(gate);
float sigf = __bfloat162float(sig);
float upf = __bfloat162float(up);
......
#ifndef __SWIGLU_MACA_H__
#define __SWIGLU_MACA_H__
#include "../../../elementwise/maca/elementwise_maca.h"
#include <hctlass/half.h>
namespace op::swiglu::maca {
// Elementwise SwiGLU functor for the MACA backend:
//   out = gate * sigmoid(gate) * up
// Specialized per element type: packed half2, scalar half, float (fast
// intrinsics), and a generic fallback for any other arithmetic type.
typedef struct SwiGLUOp {
private:
    // Logistic function 1 / (1 + e^-v), dispatched on the element type at
    // compile time so each specialization uses the matching intrinsics.
    template <typename T>
    __device__ __forceinline__ T logistic(const T &v) const {
        if constexpr (std::is_same_v<T, half2>) {
            // Packed pair of halves: lane-wise rcp(1 + exp(-v)).
            return h2rcp(__hadd2(make_half2(1, 1), h2exp(__hneg2(v))));
        } else if constexpr (std::is_same_v<T, half>) {
            // Scalar half: exp evaluated in float, result converted back.
            return hrcp(__hadd(half(1.f), __float2half(__expf(__half2float(__hneg(v))))));
        } else if constexpr (std::is_same_v<T, float>) {
            // Fast-math float intrinsics (reduced-precision rcp/exp).
            return __frcp_rn(__fadd_rn(1, __expf(-v)));
        } else {
            // Generic fallback (e.g. double).
            return 1 / (1 + std::exp(-v));
        }
    }

public:
    // This op consumes two input tensors: (up, gate).
    static constexpr size_t num_inputs = 2;

    // Apply SwiGLU to one element (or one packed half2 pair).
    template <typename T>
    __device__ __forceinline__ T operator()(const T &up, const T &gate) const {
        const T sig = logistic(gate); // gate activation, shared by all branches
        if constexpr (std::is_same_v<T, half2>) {
            return __hmul2(__hmul2(gate, sig), up);
        } else if constexpr (std::is_same_v<T, half>) {
            return __hmul(__hmul(gate, sig), up);
        } else if constexpr (std::is_same_v<T, float>) {
            return __fmul_rn(__fmul_rn(gate, sig), up);
        } else {
            return gate * sig * up;
        }
    }
} SwiGLUOp;
} // namespace op::swiglu::maca
#endif
......@@ -3,6 +3,6 @@
#include "../../../elementwise/maca/elementwise_maca_api.h"
ELEMENTWISE_DESCRIPTOR(swiglu, maca)
ELEMENTWISE_DESCRIPTOR(swiglu, metax, maca)
#endif // __SWIGLU_MACA_API_H__
#include "swiglu_maca.h"
#include "swiglu_maca_internal.h"
#include "swiglu_metax.h"
namespace op::swiglu::maca {
#include "../../../elementwise/maca/elementwise_maca.h"
#include "../cuda/kernel.cuh"
namespace op::swiglu::metax {
Descriptor::~Descriptor() = default;
......@@ -20,7 +23,7 @@ infiniStatus_t Descriptor::create(
const auto &up_shape = up_desc->shape();
const auto &gate_shape = gate_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
// create MACA elementwise descriptor
......@@ -42,15 +45,17 @@ infiniStatus_t Descriptor::calculate(
switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<256, SwiGLUOp, half>(_info, workspace, output, inputs, stream);
return _device_info->calculate<256, cuda::SwiGLUOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<256, cuda::SwiGLUOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<256, SwiGLUOp, float>(_info, workspace, output, inputs, stream);
return _device_info->calculate<256, cuda::SwiGLUOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<256, SwiGLUOp, double>(_info, workspace, output, inputs, stream);
return _device_info->calculate<256, cuda::SwiGLUOp, double>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::swiglu::maca
} // namespace op::swiglu::metax
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment