Commit 7712471f authored by zhuyue

Add NVIDIA GPU implementation for add_rms_norm and make residual_out required.

parent 2a432b34
......@@ -5,27 +5,43 @@ from infinicore.tensor import Tensor
def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
"""
Fused Add and RMS Normalization.
Args:
a: First input tensor
b: Second input tensor
weight: Scale weights
epsilon: Small constant for numerical stability, default is 1e-5
out: Optional output tuple (y, residual_out) for in-place operation
Returns:
Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b).
The add_result can be used as the residual input for subsequent layers.
"""
if out is None:
result = _infinicore.add_rms_norm(a._underlying, b._underlying, weight._underlying, epsilon)
result = _infinicore.add_rms_norm(
a._underlying, b._underlying, weight._underlying, epsilon
)
return (Tensor(result[0]), Tensor(result[1]))
y, residual_out = out
_infinicore.add_rms_norm_(y._underlying, residual_out._underlying, a._underlying, b._underlying, weight._underlying, epsilon)
_infinicore.add_rms_norm_(
y._underlying,
residual_out._underlying,
a._underlying,
b._underlying,
weight._underlying,
epsilon,
)
return (y, residual_out)
def add_rms_norm_(y, residual_out, a, b, weight, epsilon=1e-5):
"""In-place Fused Add and RMS Normalization."""
_infinicore.add_rms_norm_(y._underlying, residual_out._underlying, a._underlying, b._underlying, weight._underlying, epsilon)
_infinicore.add_rms_norm_(
y._underlying,
residual_out._underlying,
a._underlying,
b._underlying,
weight._underlying,
epsilon,
)
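For orientation, a minimal usage sketch of the binding above. The module path (infinicore.nn.functional) and the tensor constructors (from_numpy, empty_like) are assumptions standing in for whatever creation API infinicore actually exposes; only the add_rms_norm signature itself comes from this diff.

# Hypothetical usage sketch; from_numpy / empty_like are placeholder constructors.
import numpy as np
import infinicore
import infinicore.nn.functional as F  # assumed module path for add_rms_norm

a = infinicore.from_numpy(np.random.randn(4, 1024).astype(np.float16))
b = infinicore.from_numpy(np.random.randn(4, 1024).astype(np.float16))
w = infinicore.from_numpy(np.ones(1024, dtype=np.float16))

# Out-of-place: allocates and returns (normalized_result, add_result).
y, residual = F.add_rms_norm(a, b, w, epsilon=1e-5)

# In-place: the caller supplies both outputs via the keyword-only `out` tuple.
y2, res2 = infinicore.empty_like(a), infinicore.empty_like(a)
F.add_rms_norm(a, b, w, epsilon=1e-5, out=(y2, res2))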
......@@ -6,8 +6,8 @@
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::add_rms_norm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
AddRMSNormInfo _info; \
......@@ -19,7 +19,7 @@
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
......@@ -29,24 +29,24 @@
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t weight_desc, \
float epsilon, \
infiniopTensorDescriptor_t residual_out_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *y, \
const void *a, \
const void *b, \
const void *weight, \
void *residual_out, \
void *stream) const; \
}; \
}
......
......@@ -36,16 +36,13 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
T *residual_out_ptr = info->has_residual_out ?
(residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1]) : nullptr;
T *residual_out_ptr = residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1];
// Compute add(a, b) once and store it
T sum_squared = (T)0;
for (size_t k = 0; k < dim; k++) {
T sum_val = a_ptr[k] + b_ptr[k];
if (residual_out_ptr != nullptr) {
residual_out_ptr[k] = sum_val; // Store add result
}
residual_out_ptr[k] = sum_val; // Store add result
sum_squared += sum_val * sum_val;
}
......@@ -54,18 +51,9 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T
T rms = (T)1 / std::sqrt(sum_squared / (T)(dim) + (T)(info->epsilon));
// Apply normalization: y = (a + b) * w * rms
// Reuse the stored sum values if residual_out was computed, otherwise recompute
if (residual_out_ptr != nullptr) {
// Reuse stored values
for (size_t k = 0; k < dim; k++) {
y_ptr[k] = residual_out_ptr[k] * w[k] * rms;
}
} else {
// Recompute sum
for (size_t k = 0; k < dim; k++) {
T sum_val = a_ptr[k] + b_ptr[k];
y_ptr[k] = sum_val * w[k] * rms;
}
// Reuse stored values from residual_out
for (size_t k = 0; k < dim; k++) {
y_ptr[k] = residual_out_ptr[k] * w[k] * rms;
}
}
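For reference, a NumPy sketch of the computation this CPU path performs, row-wise over the last axis (illustrative only): residual = a + b, rms = 1 / sqrt(mean(residual^2) + epsilon), y = residual * w * rms.

import numpy as np

def add_rms_norm_ref(a, b, w, epsilon=1e-5):
    # residual is what the kernel writes into residual_out
    residual = a + b
    rms = 1.0 / np.sqrt((residual * residual).mean(axis=-1, keepdims=True) + epsilon)
    # w broadcasts over the last dimension, matching weight shape [dim]
    y = residual * w * rms
    return y, residual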
......@@ -90,16 +78,13 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1];
const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1];
T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
T *residual_out_ptr = info->has_residual_out ?
(residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1]) : nullptr;
T *residual_out_ptr = residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1];
// Compute sum of squares for RMS normalization and store add result
float sum_squared = 0.0f;
for (size_t k = 0; k < dim; k++) {
float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]);
if (residual_out_ptr != nullptr) {
residual_out_ptr[k] = utils::cast<T>(sum_val); // Store add result
}
residual_out_ptr[k] = utils::cast<T>(sum_val); // Store add result
sum_squared += sum_val * sum_val;
}
......@@ -107,35 +92,18 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const
float rms = 1.f / std::sqrt(sum_squared / (float)(dim) + info->epsilon);
// Apply normalization: y = (a + b) * w * rms
// Reuse stored values if residual_out was computed, otherwise recompute
if (residual_out_ptr != nullptr) {
// Reuse stored values
for (size_t k = 0; k < dim; k++) {
float sum_val = utils::cast<float>(residual_out_ptr[k]);
float val;
if constexpr (std::is_same<Tw, float>::value) {
val = sum_val * w[k] * rms;
} else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
val = sum_val * utils::cast<float>(w[k]) * rms;
} else {
std::abort();
}
y_ptr[k] = utils::cast<T>(val);
}
} else {
// Recompute sum
for (size_t k = 0; k < dim; k++) {
float sum_val = utils::cast<float>(a_ptr[k]) + utils::cast<float>(b_ptr[k]);
float val;
if constexpr (std::is_same<Tw, float>::value) {
val = sum_val * w[k] * rms;
} else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
val = sum_val * utils::cast<float>(w[k]) * rms;
} else {
std::abort();
}
y_ptr[k] = utils::cast<T>(val);
// Reuse stored values from residual_out
for (size_t k = 0; k < dim; k++) {
float sum_val = utils::cast<float>(residual_out_ptr[k]);
float val;
if constexpr (std::is_same<Tw, float>::value) {
val = sum_val * w[k] * rms;
} else if constexpr (std::is_same<Tw, T>::value || std::is_same_v<Tw, fp16_t> || std::is_same_v<Tw, bf16_t>) {
val = sum_val * utils::cast<float>(w[k]) * rms;
} else {
std::abort();
}
y_ptr[k] = utils::cast<T>(val);
}
}
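The half-precision path accumulates in float32 and casts results back to the storage type; a NumPy sketch of that pattern (illustrative, mirroring the utils::cast calls above):

import numpy as np

def add_rms_norm_half_ref(a, b, w, epsilon=1e-5):
    # a, b are float16/bfloat16-like arrays; accumulate in float32 for stability
    residual32 = a.astype(np.float32) + b.astype(np.float32)
    rms = 1.0 / np.sqrt((residual32 ** 2).mean(axis=-1, keepdims=True) + epsilon)
    y32 = residual32 * w.astype(np.float32) * rms
    # cast back to the storage dtype, as the kernel does with utils::cast<T>
    return y32.astype(a.dtype), residual32.astype(a.dtype)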
......
#ifndef __ADD_RMS_NORM_CUDA_KERNEL_H__
#define __ADD_RMS_NORM_CUDA_KERNEL_H__
#include <cub/block/block_reduce.cuh>
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
__device__ void add_rmsnormBlock(
Tdata *__restrict__ y,
Tdata *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const Tdata *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const Tdata *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const Tweight *__restrict__ w,
size_t nhead,
size_t dim,
float epsilon) {
// Each block takes care of one head in one batch
// Each thread handles every BLOCK_SIZE-th element of the row
size_t batch_idx = blockIdx.x / nhead;
size_t head_idx = blockIdx.x % nhead;
auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
auto a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
auto b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
auto w_ptr = w;
Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;
// Compute add(a, b) and sum of squares in one pass
Tcompute sum_squared = 0;
for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
Tcompute sum_val = Tcompute(a_ptr[i]) + Tcompute(b_ptr[i]);
residual_out_ptr[i] = Tdata(sum_val); // Store add result
sum_squared += sum_val * sum_val;
}
// Block-reduce sum of squares
using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
__shared__ typename BlockReduce::TempStorage temp_storage;
sum_squared = BlockReduce(temp_storage).Sum(sum_squared);
// Thread 0 computes rms = 1/sqrt(sum_squared / dim + epsilon) and stores it in shared memory
__shared__ Tcompute rms;
if (threadIdx.x == 0) {
rms = Tcompute(rsqrtf(sum_squared / Tcompute(dim) + epsilon));
}
__syncthreads();
// Apply normalization: y = (a + b) * w * rms
// Reuse stored values from residual_out
for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
Tcompute sum_val = Tcompute(residual_out_ptr[i]); // Reuse stored value
y_ptr[i] = Tdata(sum_val * Tcompute(w_ptr[i]) * rms);
}
}
#endif
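A pure-Python sketch of the kernel's work partitioning, for reference: blockIdx.x selects one (batch, head) row, and each thread visits every BLOCK_SIZE-th element of that row, so the threads of a block jointly cover the row exactly once before the block-wide reduction. The numbers are illustrative.

def simulate_block_mapping(batch, nhead, dim, block_size):
    # One CUDA block per (batch, head) row: blockIdx.x = batch_idx * nhead + head_idx
    for block_idx in range(batch * nhead):
        batch_idx, head_idx = divmod(block_idx, nhead)
        covered = set()
        for tid in range(block_size):
            # Thread tid touches columns tid, tid + block_size, tid + 2*block_size, ...
            covered.update(range(tid, dim, block_size))
        # Together the threads visit every column of the row exactly once
        assert covered == set(range(dim)), (batch_idx, head_idx)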
......@@ -34,12 +34,12 @@ public:
auto atype = y_desc->dtype();
auto wtype = weight_desc->dtype();
// Check that all input tensors have the same dtype
if (a_desc->dtype() != atype || b_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
if (atype == INFINI_DTYPE_F16 || atype == INFINI_DTYPE_BF16) {
// For half-precision types (FP16/BF16), weights can be the same half-precision type or FP32
if (wtype != atype && wtype != INFINI_DTYPE_F32 && wtype != INFINI_DTYPE_BF16 && wtype != INFINI_DTYPE_F16) {
......@@ -71,9 +71,7 @@ public:
batch = y_desc->dim(0);
dim = y_desc->dim(1);
if (a_desc->dim(0) != batch || a_desc->dim(1) != dim ||
b_desc->dim(0) != batch || b_desc->dim(1) != dim ||
weight_desc->dim(0) != dim) {
if (a_desc->dim(0) != batch || a_desc->dim(1) != dim || b_desc->dim(0) != batch || b_desc->dim(1) != dim || weight_desc->dim(0) != dim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
} else if (y_ndim == 3) {
......@@ -81,9 +79,7 @@ public:
nhead = y_desc->dim(1);
dim = y_desc->dim(2);
if (a_desc->dim(0) != batch || a_desc->dim(1) != nhead || a_desc->dim(2) != dim ||
b_desc->dim(0) != batch || b_desc->dim(1) != nhead || b_desc->dim(2) != dim ||
weight_desc->dim(0) != dim) {
if (a_desc->dim(0) != batch || a_desc->dim(1) != nhead || a_desc->dim(2) != dim || b_desc->dim(0) != batch || b_desc->dim(1) != nhead || b_desc->dim(2) != dim || weight_desc->dim(0) != dim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
} else {
......@@ -91,32 +87,30 @@ public:
}
// Check contiguity of the last dimension
if (y_desc->stride(y_ndim - 1) != 1 ||
a_desc->stride(a_ndim - 1) != 1 ||
b_desc->stride(b_ndim - 1) != 1 ||
weight_desc->stride(w_ndim - 1) != 1) {
if (y_desc->stride(y_ndim - 1) != 1 || a_desc->stride(a_ndim - 1) != 1 || b_desc->stride(b_ndim - 1) != 1 || weight_desc->stride(w_ndim - 1) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
// Check residual_out_desc if provided
bool has_residual_out = (residual_out_desc != nullptr);
if (has_residual_out) {
const size_t residual_out_ndim = residual_out_desc->ndim();
if (residual_out_ndim != y_ndim) {
// residual_out_desc is required (always needed for fused operator)
if (residual_out_desc == nullptr) {
return INFINI_STATUS_BAD_PARAM;
}
const size_t residual_out_ndim = residual_out_desc->ndim();
if (residual_out_ndim != y_ndim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (residual_out_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// Check shape matches
for (size_t i = 0; i < y_ndim; i++) {
if (residual_out_desc->dim(i) != y_desc->dim(i)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (residual_out_desc->dtype() != atype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// Check shape matches
for (size_t i = 0; i < y_ndim; i++) {
if (residual_out_desc->dim(i) != y_desc->dim(i)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
if (residual_out_desc->stride(residual_out_ndim - 1) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
}
if (residual_out_desc->stride(residual_out_ndim - 1) != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
AddRMSNormInfo info;
......@@ -127,10 +121,8 @@ public:
info.y_strides = y_desc->strides();
info.a_strides = a_desc->strides();
info.b_strides = b_desc->strides();
info.has_residual_out = has_residual_out;
if (has_residual_out) {
info.residual_out_strides = residual_out_desc->strides();
}
info.has_residual_out = true; // Always true now
info.residual_out_strides = residual_out_desc->strides();
return utils::Result<AddRMSNormInfo>(info);
}
};
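The validation above can be summarized as: y, a, b, and residual_out must share the same 2D [batch, dim] or 3D [batch, nhead, dim] shape and dtype, weight must be [dim], and every tensor must be contiguous in its last dimension. A Python restatement of those rules (illustrative; descriptors modeled as plain dicts, dtype checks omitted):

def check_add_rms_norm_shapes(y, a, b, w, residual_out):
    # Each argument is a dict with "shape" and "strides" tuples (illustrative model).
    if residual_out is None:
        return "BAD_PARAM"                     # residual_out is now required
    if len(y["shape"]) not in (2, 3):
        return "BAD_TENSOR_SHAPE"
    if not (y["shape"] == a["shape"] == b["shape"] == residual_out["shape"]):
        return "BAD_TENSOR_SHAPE"
    if w["shape"] != (y["shape"][-1],):        # weight is 1D over the last dim
        return "BAD_TENSOR_SHAPE"
    if any(t["strides"][-1] != 1 for t in (y, a, b, w, residual_out)):
        return "BAD_TENSOR_STRIDES"            # last dimension must be contiguous
    return "SUCCESS"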
......
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "add_rms_norm_nvidia.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_CUDA_KERNEL add_rmsnormKernel(
Tdata *__restrict__ y,
Tdata *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const Tdata *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const Tdata *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const Tweight *__restrict__ w,
size_t nhead,
size_t dim,
float epsilon) {
add_rmsnormBlock<BLOCK_SIZE, Tcompute>(
y, residual_out,
stride_y_batch, stride_y_nhead,
stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
w, nhead, dim, epsilon);
}
namespace op::add_rms_norm::nvidia {
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon,
infiniopTensorDescriptor_t residual_out_desc) {
auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
CHECK_RESULT(result);
auto info = result.take();
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
// launch kernel with different data types
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
uint32_t batch_size, size_t nhead, size_t dim,
void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
const void *w, infiniDtype_t wtype,
float epsilon,
cudaStream_t cuda_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
add_rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, cuda_stream>>>( \
reinterpret_cast<Tdata *>(y), \
reinterpret_cast<Tdata *>(residual_out), \
stride_y_batch, \
stride_y_nhead, \
stride_residual_out_batch, \
stride_residual_out_nhead, \
reinterpret_cast<const Tdata *>(a), \
stride_a_batch, \
stride_a_nhead, \
reinterpret_cast<const Tdata *>(b), \
stride_b_batch, \
stride_b_nhead, \
reinterpret_cast<const Tweight *>(w), \
nhead, \
dim, \
epsilon)
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, __nv_bfloat16, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__nv_bfloat16, __nv_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(__nv_bfloat16, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__nv_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
#undef LAUNCH_KERNEL
return INFINI_STATUS_SUCCESS;
}
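The dtype combinations accepted by launchKernel, summarized below; every path computes in float32, and any other pairing returns INFINI_STATUS_BAD_TENSOR_DTYPE.

# (data dtype, weight dtype) -> (Tdata, Tweight, Tcompute), mirroring LAUNCH_KERNEL above
SUPPORTED_DTYPE_COMBOS = {
    ("F16",  "F16"):  ("half",          "half",          "float"),
    ("F16",  "BF16"): ("half",          "__nv_bfloat16", "float"),
    ("F16",  "F32"):  ("half",          "float",         "float"),
    ("BF16", "BF16"): ("__nv_bfloat16", "__nv_bfloat16", "float"),
    ("BF16", "F16"):  ("__nv_bfloat16", "half",          "float"),
    ("BF16", "F32"):  ("__nv_bfloat16", "float",         "float"),
    ("F32",  "F32"):  ("float",         "float",         "float"),
}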
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, const void *a, const void *b, const void *weight,
void *residual_out, void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
auto stride_a_batch = _info.a_strides[0];
auto stride_a_nhead = _info.a_strides[1];
auto stride_b_batch = _info.b_strides[0];
auto stride_b_nhead = _info.b_strides[1];
auto stride_y_batch = _info.y_strides[0];
auto stride_y_nhead = _info.y_strides[1];
auto stride_residual_out_batch = _info.residual_out_strides[0];
auto stride_residual_out_nhead = _info.residual_out_strides[1];
auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
// launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, cuda_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, cuda_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_4096>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, cuda_stream));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::add_rms_norm::nvidia
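Descriptor::calculate derives its launch configuration directly from the tensor shape: the grid has batch * nhead blocks and the block size equals the device's maxThreadsPerBlock, restricted to the CUDA_BLOCK_SIZE_* variants handled above. A small sketch, assuming those constants equal 512, 1024, and 4096:

def launch_config(shape, max_threads_per_block):
    # shape is [batch, dim] or [batch, nhead, dim]; one block per (batch, head) row
    batch = shape[0]
    nhead = shape[1] if len(shape) > 2 else 1
    if max_threads_per_block not in (512, 1024, 4096):    # assumed constant values
        return "DEVICE_ARCHITECTURE_NOT_SUPPORTED"
    return batch * nhead, max_threads_per_block           # (grid blocks, threads/block)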
#ifndef __ADD_RMS_NORM_NVIDIA_CUDA_H__
#define __ADD_RMS_NORM_NVIDIA_CUDA_H__
#include "../add_rms_norm.h"
DESCRIPTOR(nvidia)
#endif
......@@ -6,8 +6,7 @@
#include "cpu/add_rms_norm_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
// TODO: Add NVIDIA implementation
// #include "nvidia/add_rms_norm_nvidia.cuh"
#include "nvidia/add_rms_norm_nvidia.cuh"
#endif
#ifdef ENABLE_ASCEND_API
// TODO: Add Ascend implementation
......@@ -40,16 +39,16 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
float epsilon,
infiniopTensorDescriptor_t residual_out_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::add_rms_norm::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor **>(desc_ptr), \
y_desc, \
a_desc, \
b_desc, \
weight_desc, \
epsilon, \
residual_out_desc)
switch (handle->device) {
......@@ -57,16 +56,16 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
// CREATE(INFINI_DEVICE_NVIDIA, nvidia);
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
// CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
// CREATE(INFINI_DEVICE_QY, nvidia);
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
// CREATE(INFINI_DEVICE_HYGON, nvidia);
CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
// CREATE(INFINI_DEVICE_KUNLUN, kunlun);
......@@ -80,8 +79,8 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
__C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS
......@@ -90,16 +89,16 @@ __C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescript
GET(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
// GET(INFINI_DEVICE_NVIDIA, nvidia);
GET(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
// GET(INFINI_DEVICE_ILUVATAR, nvidia);
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
// GET(INFINI_DEVICE_QY, nvidia);
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
// GET(INFINI_DEVICE_HYGON, nvidia);
GET(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
// GET(INFINI_DEVICE_KUNLUN, kunlun);
......@@ -123,9 +122,9 @@ __C infiniStatus_t infiniopAddRMSNorm(
void *residual_out,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::add_rms_norm::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, y, a, b, weight, residual_out, stream)
switch (desc->device_type) {
......@@ -134,16 +133,16 @@ __C infiniStatus_t infiniopAddRMSNorm(
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
// CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
// CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
// CALCULATE(INFINI_DEVICE_QY, nvidia);
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
// CALCULATE(INFINI_DEVICE_HYGON, nvidia);
CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
// CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
......@@ -159,9 +158,9 @@ __C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescrip
return INFINI_STATUS_SUCCESS;
}
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::add_rms_norm::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS
switch (desc->device_type) {
......@@ -169,16 +168,16 @@ __C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescrip
DESTROY(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
// DESTROY(INFINI_DEVICE_NVIDIA, nvidia);
DESTROY(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
// DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
// DESTROY(INFINI_DEVICE_QY, nvidia);
DESTROY(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
// DESTROY(INFINI_DEVICE_HYGON, nvidia);
DESTROY(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_KUNLUN_API
// DESTROY(INFINI_DEVICE_KUNLUN, kunlun);
......