Commit 0611cb1b authored by wooway777's avatar wooway777
Browse files

issue/791 - fix add_rmsnorm api on mtx and mth

parent 4ddc6647
......@@ -8,8 +8,10 @@
// Posible maximum number of threads per block for METAX architectures
// Used for picking correct kernel launch configuration
#define METAX_BLOCK_SIZE_1024 1024
#define METAX_BLOCK_SIZE_512 512
#define METAX_BLOCK_SIZE_1024 1024
#define METAX_BLOCK_SIZE_2048 2048
#define METAX_BLOCK_SIZE_4096 4096
#define CHECK_METAX(API) CHECK_INTERNAL(API, hcSuccess)
......
......@@ -6,6 +6,7 @@
// Posible maximum number of threads per block for MUSA architectures
// Used for picking correct kernel launch configuration
#define MOORE_BLOCK_SIZE_4096 4096
#define MOORE_BLOCK_SIZE_2048 2048
#define MOORE_BLOCK_SIZE_1024 1024
#define MOORE_BLOCK_SIZE_512 512
......
......@@ -53,12 +53,12 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon,
infiniopTensorDescriptor_t residual_out_desc) {
auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
float epsilon) {
auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
......@@ -104,16 +104,16 @@ infiniStatus_t launchKernel(
// Handle different data type combinations following Metax pattern
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(__hpcc_bfloat16, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float);
} else {
......@@ -128,8 +128,8 @@ infiniStatus_t launchKernel(
// Main calculation function
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, const void *a, const void *b, const void *weight,
void *residual_out, void *stream_) const {
void *y, void *residual_out, const void *a, const void *b, const void *weight,
void *stream) const {
// Check workspace size
if (workspace_size < _workspace_size) {
......@@ -148,17 +148,41 @@ infiniStatus_t Descriptor::calculate(
auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
auto stream = reinterpret_cast<hcStream_t>(stream_);
auto stream_ = reinterpret_cast<hcStream_t>(stream);
// Launch kernel with appropriate block size based on device capability
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
// Launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_512>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream));
weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_2048>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_4096>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
......
......@@ -53,12 +53,12 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon,
infiniopTensorDescriptor_t residual_out_desc) {
auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
float epsilon) {
auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
CHECK_RESULT(result);
auto info = result.take();
......@@ -128,8 +128,8 @@ infiniStatus_t launchKernel(
// Main calculation function
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, const void *a, const void *b, const void *weight,
void *residual_out, void *stream) const {
void *y, void *residual_out, const void *a, const void *b, const void *weight,
void *stream) const {
// Check workspace size
if (workspace_size < _workspace_size) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment