Commit 0611cb1b authored by wooway777's avatar wooway777
Browse files

issue/791 - fix add_rmsnorm api on mtx and mth

parent 4ddc6647
...@@ -8,8 +8,10 @@ ...@@ -8,8 +8,10 @@
// Posible maximum number of threads per block for METAX architectures // Posible maximum number of threads per block for METAX architectures
// Used for picking correct kernel launch configuration // Used for picking correct kernel launch configuration
#define METAX_BLOCK_SIZE_1024 1024
#define METAX_BLOCK_SIZE_512 512 #define METAX_BLOCK_SIZE_512 512
#define METAX_BLOCK_SIZE_1024 1024
#define METAX_BLOCK_SIZE_2048 2048
#define METAX_BLOCK_SIZE_4096 4096
#define CHECK_METAX(API) CHECK_INTERNAL(API, hcSuccess) #define CHECK_METAX(API) CHECK_INTERNAL(API, hcSuccess)
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
// Posible maximum number of threads per block for MUSA architectures // Posible maximum number of threads per block for MUSA architectures
// Used for picking correct kernel launch configuration // Used for picking correct kernel launch configuration
#define MOORE_BLOCK_SIZE_4096 4096
#define MOORE_BLOCK_SIZE_2048 2048 #define MOORE_BLOCK_SIZE_2048 2048
#define MOORE_BLOCK_SIZE_1024 1024 #define MOORE_BLOCK_SIZE_1024 1024
#define MOORE_BLOCK_SIZE_512 512 #define MOORE_BLOCK_SIZE_512 512
......
...@@ -53,12 +53,12 @@ infiniStatus_t Descriptor::create( ...@@ -53,12 +53,12 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t handle, infiniopHandle_t handle,
Descriptor **desc_ptr, Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc, infiniopTensorDescriptor_t weight_desc,
float epsilon, float epsilon) {
infiniopTensorDescriptor_t residual_out_desc) { auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
CHECK_RESULT(result); CHECK_RESULT(result);
auto info = result.take(); auto info = result.take();
...@@ -104,16 +104,16 @@ infiniStatus_t launchKernel( ...@@ -104,16 +104,16 @@ infiniStatus_t launchKernel(
// Handle different data type combinations following Metax pattern // Handle different data type combinations following Metax pattern
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) { if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float); LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) { } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, __hpcc_bfloat16, float); LAUNCH_KERNEL(half, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) { } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(__hpcc_bfloat16, half, float); LAUNCH_KERNEL(__hpcc_bfloat16, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) { } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float); LAUNCH_KERNEL(float, float, float);
} else { } else {
...@@ -128,8 +128,8 @@ infiniStatus_t launchKernel( ...@@ -128,8 +128,8 @@ infiniStatus_t launchKernel(
// Main calculation function // Main calculation function
infiniStatus_t Descriptor::calculate( infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size, void *workspace, size_t workspace_size,
void *y, const void *a, const void *b, const void *weight, void *y, void *residual_out, const void *a, const void *b, const void *weight,
void *residual_out, void *stream_) const { void *stream) const {
// Check workspace size // Check workspace size
if (workspace_size < _workspace_size) { if (workspace_size < _workspace_size) {
...@@ -148,17 +148,41 @@ infiniStatus_t Descriptor::calculate( ...@@ -148,17 +148,41 @@ infiniStatus_t Descriptor::calculate(
auto dim = _info.dim(); auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]); uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1; size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
auto stream = reinterpret_cast<hcStream_t>(stream_); auto stream_ = reinterpret_cast<hcStream_t>(stream);
// Launch kernel with appropriate block size based on device capability // Launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) { if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_512>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>( CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
batch_size, nhead, dim, batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead, y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead, residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead, a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead, b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream)); weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_2048>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_4096>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else { } else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
} }
......
...@@ -53,12 +53,12 @@ infiniStatus_t Descriptor::create( ...@@ -53,12 +53,12 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t handle, infiniopHandle_t handle,
Descriptor **desc_ptr, Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc, infiniopTensorDescriptor_t weight_desc,
float epsilon, float epsilon) {
infiniopTensorDescriptor_t residual_out_desc) { auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
CHECK_RESULT(result); CHECK_RESULT(result);
auto info = result.take(); auto info = result.take();
...@@ -128,8 +128,8 @@ infiniStatus_t launchKernel( ...@@ -128,8 +128,8 @@ infiniStatus_t launchKernel(
// Main calculation function // Main calculation function
infiniStatus_t Descriptor::calculate( infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size, void *workspace, size_t workspace_size,
void *y, const void *a, const void *b, const void *weight, void *y, void *residual_out, const void *a, const void *b, const void *weight,
void *residual_out, void *stream) const { void *stream) const {
// Check workspace size // Check workspace size
if (workspace_size < _workspace_size) { if (workspace_size < _workspace_size) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment