Commit dfafc21f authored by wooway777

issue/884 - add_rms_norm on iluvatar, metax and moore

parent 0c204dfd
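For reference, all three new backends implement the same fused operation: add the two inputs, write the sum out as `residual_out`, then RMS-normalize that sum against the weight `w`. Below is a minimal scalar sketch, assuming the conventional fused-add-RMSNorm definition; the authoritative version is the shared block kernel in ../cuda/kernel.cuh, which this commit reuses but does not modify.

#include <cmath>
#include <cstddef>

// Hypothetical host-side reference for one row of length `dim`:
//   residual_out = a + b
//   y = residual_out / sqrt(mean(residual_out^2) + epsilon) * w
void add_rms_norm_ref(float *y, float *residual_out,
                      const float *a, const float *b, const float *w,
                      size_t dim, float epsilon) {
    float sum_sq = 0.0f;
    for (size_t i = 0; i < dim; ++i) {
        residual_out[i] = a[i] + b[i]; // fused residual add
        sum_sq += residual_out[i] * residual_out[i];
    }
    float inv_rms = 1.0f / std::sqrt(sum_sq / dim + epsilon);
    for (size_t i = 0; i < dim; ++i) {
        y[i] = residual_out[i] * inv_rms * w[i]; // normalize and scale
    }
}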
#ifndef __ADD_RMS_NORM_METAX_CUH__
#define __ADD_RMS_NORM_METAX_CUH__

#include "../add_rms_norm.h"

DESCRIPTOR(metax)

#endif // __ADD_RMS_NORM_METAX_CUH__

#include "../../../devices/metax/metax_common.h"
#include "add_rms_norm_metax.cuh"
#include "../../../devices/metax/metax_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
// Kernel function template for add_rms_norm on the Metax platform
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_METAX_KERNEL add_rmsnormKernel(
    Tdata *__restrict__ y,
    Tdata *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w,
    size_t nhead,
    size_t dim,
    float epsilon) {
    add_rmsnormBlock<BLOCK_SIZE, Tcompute>(
        y, residual_out,
        stride_y_batch, stride_y_nhead,
        stride_residual_out_batch, stride_residual_out_nhead,
        a, stride_a_batch, stride_a_nhead,
        b, stride_b_batch, stride_b_nhead,
        w, nhead, dim, epsilon);
}
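The wrapper above forwards to add_rmsnormBlock, which lives in the shared ../cuda/kernel.cuh and is not part of this diff. The following is a hypothetical sketch of the block-level algorithm it is expected to implement — one block per (batch, head) row, a cub::BlockReduce sum of squares, then a rescale pass. Names and internals here are assumptions for illustration, not the shared code:

// Sketch only: one block normalizes one (batch, head) row of length `dim`.
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
__device__ void add_rmsnorm_block_sketch(
    Tdata *y, Tdata *residual_out,
    ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_ro_batch, ptrdiff_t stride_ro_nhead,
    const Tdata *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
    const Tdata *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
    const Tweight *w, size_t nhead, size_t dim, float epsilon) {
    // The launch below uses batch_size * nhead blocks; recover the 2-D index.
    size_t batch = blockIdx.x / nhead;
    size_t head = blockIdx.x % nhead;
    const Tdata *a_row = a + batch * stride_a_batch + head * stride_a_nhead;
    const Tdata *b_row = b + batch * stride_b_batch + head * stride_b_nhead;
    Tdata *r_row = residual_out + batch * stride_ro_batch + head * stride_ro_nhead;
    Tdata *y_row = y + batch * stride_y_batch + head * stride_y_nhead;

    // Pass 1: fused add, store the residual, accumulate the sum of squares.
    Tcompute partial = 0;
    for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
        Tcompute s = Tcompute(a_row[i]) + Tcompute(b_row[i]);
        r_row[i] = Tdata(s);
        partial += s * s;
    }
    using BlockReduce = cub::BlockReduce<Tcompute, static_cast<int>(BLOCK_SIZE)>;
    __shared__ typename BlockReduce::TempStorage tmp;
    __shared__ Tcompute inv_rms;
    Tcompute sum_sq = BlockReduce(tmp).Sum(partial);
    if (threadIdx.x == 0) {
        // Tcompute is float in every instantiation used by this commit.
        inv_rms = rsqrtf(sum_sq / static_cast<Tcompute>(dim) + epsilon);
    }
    __syncthreads();

    // Pass 2: each thread rescales the elements it wrote in pass 1.
    for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
        y_row[i] = Tdata(Tcompute(r_row[i]) * inv_rms * Tcompute(w[i]));
    }
}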
namespace op::add_rms_norm::metax {

// Internal opaque structure for the Metax device handle
struct Descriptor::Opaque {
    std::shared_ptr<device::metax::Handle::Internal> internal;
};

// Destructor
Descriptor::~Descriptor() {
    delete _opaque;
}
// Create descriptor for the add_rms_norm operator
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon,
    infiniopTensorDescriptor_t residual_out_desc) {
    auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
    CHECK_RESULT(result);
    auto info = result.take();

    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
        std::move(info),
        0,
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Launch the kernel for the given (activation, weight) data-type combination
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
    uint32_t batch_size, size_t nhead, size_t dim,
    void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
    void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
    const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
    const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
    const void *w, infiniDtype_t wtype,
    float epsilon,
    hcStream_t stream) {

#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute)                                                             \
    add_rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, stream>>>( \
        reinterpret_cast<Tdata *>(y),                                                                       \
        reinterpret_cast<Tdata *>(residual_out),                                                            \
        stride_y_batch,                                                                                     \
        stride_y_nhead,                                                                                     \
        stride_residual_out_batch,                                                                          \
        stride_residual_out_nhead,                                                                          \
        reinterpret_cast<const Tdata *>(a),                                                                 \
        stride_a_batch,                                                                                     \
        stride_a_nhead,                                                                                     \
        reinterpret_cast<const Tdata *>(b),                                                                 \
        stride_b_batch,                                                                                     \
        stride_b_nhead,                                                                                     \
        reinterpret_cast<const Tweight *>(w),                                                               \
        nhead,                                                                                              \
        dim,                                                                                                \
        epsilon)

    // Dispatch over the supported data type combinations, following the
    // pattern used by the existing Metax operators
    if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(half, half, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(half, float, float);
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(half, __hpcc_bfloat16, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(__hpcc_bfloat16, half, float);
    } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(float, float, float);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

#undef LAUNCH_KERNEL
    return INFINI_STATUS_SUCCESS;
}
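Note on the dispatch above: F16 and BF16 activations may pair with F16, BF16, or F32 weights, while F32 activations require F32 weights; every path uses float as Tcompute, so the reduction is accumulated in full precision regardless of the storage types.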
// Main calculation function
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y, const void *a, const void *b, const void *weight,
    void *residual_out, void *stream_) const {
    // Check workspace size
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // Extract tensor strides and dimensions
    auto stride_a_batch = _info.a_strides[0];
    auto stride_a_nhead = _info.a_strides[1];
    auto stride_b_batch = _info.b_strides[0];
    auto stride_b_nhead = _info.b_strides[1];
    auto stride_y_batch = _info.y_strides[0];
    auto stride_y_nhead = _info.y_strides[1];
    auto stride_residual_out_batch = _info.residual_out_strides[0];
    auto stride_residual_out_nhead = _info.residual_out_strides[1];
    auto dim = _info.dim();
    uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
    size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
    auto stream = reinterpret_cast<hcStream_t>(stream_);

    // Launch kernel with appropriate block size based on device capability
    if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
        CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
            batch_size, nhead, dim,
            y, _info.atype, stride_y_batch, stride_y_nhead,
            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
            a, stride_a_batch, stride_a_nhead,
            b, stride_b_batch, stride_b_nhead,
            weight, _info.wtype, _info.epsilon, stream));
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::add_rms_norm::metax
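To make the launch geometry concrete with a hypothetical shape: batch_size = 4, nhead = 32, dim = 4096 gives a grid of batch_size * nhead = 128 blocks of METAX_BLOCK_SIZE_1024 threads, and, assuming the grid-stride loop sketched earlier, each thread handles dim / 1024 = 4 elements per pass.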
#ifndef __ADD_RMS_NORM_MOORE_H__
#define __ADD_RMS_NORM_MOORE_H__

#include "../add_rms_norm.h"

DESCRIPTOR(moore)

#endif // __ADD_RMS_NORM_MOORE_H__

#include "../../../devices/moore/moore_common.h"
#include "add_rms_norm_moore.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
// Kernel function template for add_rms_norm on the Moore platform
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_MOORE_KERNEL add_rmsnormKernel(
    Tdata *__restrict__ y,
    Tdata *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w,
    size_t nhead,
    size_t dim,
    float epsilon) {
    add_rmsnormBlock<BLOCK_SIZE, Tcompute>(
        y, residual_out,
        stride_y_batch, stride_y_nhead,
        stride_residual_out_batch, stride_residual_out_nhead,
        a, stride_a_batch, stride_a_nhead,
        b, stride_b_batch, stride_b_nhead,
        w, nhead, dim, epsilon);
}
namespace op::add_rms_norm::moore {

// Internal opaque structure for the Moore device handle
struct Descriptor::Opaque {
    std::shared_ptr<device::moore::Handle::Internal> internal;
};

// Destructor
Descriptor::~Descriptor() {
    delete _opaque;
}
// Create descriptor for the add_rms_norm operator
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon,
    infiniopTensorDescriptor_t residual_out_desc) {
    auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
    CHECK_RESULT(result);
    auto info = result.take();

    *desc_ptr = new Descriptor(
        new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
        std::move(info),
        0,
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Launch kernel with different data types
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
    uint32_t batch_size, size_t nhead, size_t dim,
    void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
    void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
    const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
    const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
    const void *w, infiniDtype_t wtype,
    float epsilon,
    musaStream_t musa_stream) {

#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute)                                                                  \
    add_rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, musa_stream>>>( \
        reinterpret_cast<Tdata *>(y),                                                                            \
        reinterpret_cast<Tdata *>(residual_out),                                                                 \
        stride_y_batch,                                                                                          \
        stride_y_nhead,                                                                                          \
        stride_residual_out_batch,                                                                               \
        stride_residual_out_nhead,                                                                               \
        reinterpret_cast<const Tdata *>(a),                                                                      \
        stride_a_batch,                                                                                          \
        stride_a_nhead,                                                                                          \
        reinterpret_cast<const Tdata *>(b),                                                                      \
        stride_b_batch,                                                                                          \
        stride_b_nhead,                                                                                          \
        reinterpret_cast<const Tweight *>(w),                                                                    \
        nhead,                                                                                                   \
        dim,                                                                                                     \
        epsilon)

    // Handle different data type combinations
    if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(half, half, float);
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(half, __mt_bfloat16, float);
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(half, float, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(__mt_bfloat16, __mt_bfloat16, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(__mt_bfloat16, half, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(__mt_bfloat16, float, float);
    } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(float, float, float);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

#undef LAUNCH_KERNEL
    return INFINI_STATUS_SUCCESS;
}
// Main calculation function
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y, const void *a, const void *b, const void *weight,
    void *residual_out, void *stream) const {
    // Check workspace size
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // Extract tensor strides and dimensions
    auto stride_a_batch = _info.a_strides[0];
    auto stride_a_nhead = _info.a_strides[1];
    auto stride_b_batch = _info.b_strides[0];
    auto stride_b_nhead = _info.b_strides[1];
    auto stride_y_batch = _info.y_strides[0];
    auto stride_y_nhead = _info.y_strides[1];
    auto stride_residual_out_batch = _info.residual_out_strides[0];
    auto stride_residual_out_nhead = _info.residual_out_strides[1];
    auto dim = _info.dim();
    uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
    size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
    auto musa_stream = reinterpret_cast<musaStream_t>(stream);

    // Launch kernel with appropriate block size based on device capability
    if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) {
        CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>(
            batch_size, nhead, dim,
            y, _info.atype, stride_y_batch, stride_y_nhead,
            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
            a, stride_a_batch, stride_a_nhead,
            b, stride_b_batch, stride_b_nhead,
            weight, _info.wtype, _info.epsilon, musa_stream));
    } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) {
        CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>(
            batch_size, nhead, dim,
            y, _info.atype, stride_y_batch, stride_y_nhead,
            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
            a, stride_a_batch, stride_a_nhead,
            b, stride_b_batch, stride_b_nhead,
            weight, _info.wtype, _info.epsilon, musa_stream));
    } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) {
        CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_2048>(
            batch_size, nhead, dim,
            y, _info.atype, stride_y_batch, stride_y_nhead,
            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
            a, stride_a_batch, stride_a_nhead,
            b, stride_b_batch, stride_b_nhead,
            weight, _info.wtype, _info.epsilon, musa_stream));
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::add_rms_norm::moore
@@ -143,7 +143,15 @@ infiniStatus_t Descriptor::calculate(

     auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
     // launch kernel with different block sizes
-    if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
+    if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
+        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(
+            batch_size, nhead, dim,
+            y, _info.atype, stride_y_batch, stride_y_nhead,
+            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
+            a, stride_a_batch, stride_a_nhead,
+            b, stride_b_batch, stride_b_nhead,
+            weight, _info.wtype, _info.epsilon, cuda_stream));
+    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
         CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(
             batch_size, nhead, dim,
             y, _info.atype, stride_y_batch, stride_y_nhead,
@@ -151,8 +159,8 @@ infiniStatus_t Descriptor::calculate(
             a, stride_a_batch, stride_a_nhead,
             b, stride_b_batch, stride_b_nhead,
             weight, _info.wtype, _info.epsilon, cuda_stream));
-    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
-        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(
+    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) {
+        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_2048>(
             batch_size, nhead, dim,
             y, _info.atype, stride_y_batch, stride_y_nhead,
             residual_out, stride_residual_out_batch, stride_residual_out_nhead,
@@ -17,12 +17,10 @@
 // #include "bang/add_rms_norm_bang.h"
 #endif
 #ifdef ENABLE_METAX_API
-// TODO: Add Metax implementation
-// #include "metax/add_rms_norm_metax.cuh"
+#include "metax/add_rms_norm_metax.cuh"
 #endif
 #ifdef ENABLE_MOORE_API
-// TODO: Add Moore implementation
-// #include "moore/add_rms_norm_moore.h"
+#include "moore/add_rms_norm_moore.h"
 #endif
 #ifdef ENABLE_KUNLUN_API
 // TODO: Add Kunlun implementation
@@ -61,6 +59,12 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor(
 #ifdef ENABLE_ILUVATAR_API
     CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_MOORE_API
+    CREATE(INFINI_DEVICE_MOORE, moore);
+#endif
+#ifdef ENABLE_METAX_API
+    CREATE(INFINI_DEVICE_METAX, metax);
+#endif
 #ifdef ENABLE_QY_API
     CREATE(INFINI_DEVICE_QY, nvidia);
 #endif
@@ -94,6 +98,12 @@ __C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescript
 #ifdef ENABLE_ILUVATAR_API
     GET(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_MOORE_API
+    GET(INFINI_DEVICE_MOORE, moore);
+#endif
+#ifdef ENABLE_METAX_API
+    GET(INFINI_DEVICE_METAX, metax);
+#endif
 #ifdef ENABLE_QY_API
     GET(INFINI_DEVICE_QY, nvidia);
 #endif
@@ -138,6 +148,12 @@ __C infiniStatus_t infiniopAddRMSNorm(
 #ifdef ENABLE_ILUVATAR_API
     CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_MOORE_API
+    CALCULATE(INFINI_DEVICE_MOORE, moore);
+#endif
+#ifdef ENABLE_METAX_API
+    CALCULATE(INFINI_DEVICE_METAX, metax);
+#endif
 #ifdef ENABLE_QY_API
     CALCULATE(INFINI_DEVICE_QY, nvidia);
 #endif
@@ -173,6 +189,12 @@ __C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescrip
 #ifdef ENABLE_ILUVATAR_API
     DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_MOORE_API
+    DESTROY(INFINI_DEVICE_MOORE, moore);
+#endif
+#ifdef ENABLE_METAX_API
+    DESTROY(INFINI_DEVICE_METAX, metax);
+#endif
 #ifdef ENABLE_QY_API
     DESTROY(INFINI_DEVICE_QY, nvidia);
 #endif
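End to end, callers reach the new Moore and Metax paths through the __C entry points dispatched above. A hypothetical call sequence, assuming the C wrappers mirror the Descriptor methods shown in this commit (argument order inferred from Descriptor::create and Descriptor::calculate, not verified against the public header):

infiniStatus_t run_add_rms_norm(
    infiniopHandle_t handle,
    infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t w_desc,
    infiniopTensorDescriptor_t residual_out_desc,
    void *y, const void *a, const void *b, const void *w,
    void *residual_out, void *stream) {
    infiniopAddRMSNormDescriptor_t desc;
    CHECK_STATUS(infiniopCreateAddRMSNormDescriptor(
        handle, &desc, y_desc, a_desc, b_desc, w_desc, 1e-5f, residual_out_desc));

    size_t workspace_size = 0;
    CHECK_STATUS(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size));
    void *workspace = nullptr; // allocate workspace_size bytes on the device here

    CHECK_STATUS(infiniopAddRMSNorm(desc, workspace, workspace_size,
                                    y, a, b, w, residual_out, stream));
    return infiniopDestroyAddRMSNormDescriptor(desc);
}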