Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
0611cb1b
Commit
0611cb1b
authored
Jan 26, 2026
by
wooway777
Browse files
issue/791 - fix add_rmsnorm api on mtx and mth
parent
4ddc6647
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
48 additions
and
21 deletions
+48
-21
src/infiniop/devices/metax/metax_kernel_common.h
src/infiniop/devices/metax/metax_kernel_common.h
+3
-1
src/infiniop/devices/moore/moore_kernel_common.h
src/infiniop/devices/moore/moore_kernel_common.h
+1
-0
src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca
src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca
+39
-15
src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu
src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu
+5
-5
No files found.
src/infiniop/devices/metax/metax_kernel_common.h
View file @
0611cb1b
...
@@ -8,8 +8,10 @@
...
@@ -8,8 +8,10 @@
// Possible maximum number of threads per block for METAX architectures
// Possible maximum number of threads per block for METAX architectures
// Used for picking correct kernel launch configuration
// Used for picking correct kernel launch configuration
#define METAX_BLOCK_SIZE_1024 1024
#define METAX_BLOCK_SIZE_512 512
#define METAX_BLOCK_SIZE_512 512
#define METAX_BLOCK_SIZE_1024 1024
#define METAX_BLOCK_SIZE_2048 2048
#define METAX_BLOCK_SIZE_4096 4096
#define CHECK_METAX(API) CHECK_INTERNAL(API, hcSuccess)
#define CHECK_METAX(API) CHECK_INTERNAL(API, hcSuccess)
...
...
src/infiniop/devices/moore/moore_kernel_common.h
View file @
0611cb1b
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
// Possible maximum number of threads per block for MUSA architectures
// Possible maximum number of threads per block for MUSA architectures
// Used for picking correct kernel launch configuration
// Used for picking correct kernel launch configuration
#define MOORE_BLOCK_SIZE_4096 4096
#define MOORE_BLOCK_SIZE_2048 2048
#define MOORE_BLOCK_SIZE_2048 2048
#define MOORE_BLOCK_SIZE_1024 1024
#define MOORE_BLOCK_SIZE_1024 1024
#define MOORE_BLOCK_SIZE_512 512
#define MOORE_BLOCK_SIZE_512 512
...
...
src/infiniop/ops/add_rms_norm/metax/add_rms_norm_metax.maca
View file @
0611cb1b
...
@@ -53,12 +53,12 @@ infiniStatus_t Descriptor::create(
...
@@ -53,12 +53,12 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
infiniopHandle_t handle,
Descriptor **desc_ptr,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon,
float epsilon) {
infiniopTensorDescriptor_t residual_out_desc) {
auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
CHECK_RESULT(result);
CHECK_RESULT(result);
auto info = result.take();
auto info = result.take();
...
@@ -104,16 +104,16 @@ infiniStatus_t launchKernel(
...
@@ -104,16 +104,16 @@ infiniStatus_t launchKernel(
// Handle different data type combinations following Metax pattern
// Handle different data type combinations following Metax pattern
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, __hpcc_bfloat16, float);
LAUNCH_KERNEL(half, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(__hpcc_bfloat16, half, float);
LAUNCH_KERNEL(__hpcc_bfloat16, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float);
LAUNCH_KERNEL(float, float, float);
} else {
} else {
...
@@ -128,8 +128,8 @@ infiniStatus_t launchKernel(
...
@@ -128,8 +128,8 @@ infiniStatus_t launchKernel(
// Main calculation function
// Main calculation function
infiniStatus_t Descriptor::calculate(
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *workspace, size_t workspace_size,
void *y, const void *a, const void *b, const void *weight,
void *y,
void *residual_out,
const void *a, const void *b, const void *weight,
void *residual_out, void *stream_) const {
void *stream) const {
// Check workspace size
// Check workspace size
if (workspace_size < _workspace_size) {
if (workspace_size < _workspace_size) {
...
@@ -148,17 +148,41 @@ infiniStatus_t Descriptor::calculate(
...
@@ -148,17 +148,41 @@ infiniStatus_t Descriptor::calculate(
auto dim = _info.dim();
auto dim = _info.dim();
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
auto stream = reinterpret_cast<hcStream_t>(stream_);
auto stream_ = reinterpret_cast<hcStream_t>(stream);
// Launch kernel with appropriate block size based on device capability
// Launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_512>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
batch_size, nhead, dim,
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream));
weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_2048>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_4096) {
CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_4096>(
batch_size, nhead, dim,
y, _info.atype, stride_y_batch, stride_y_nhead,
residual_out, stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
weight, _info.wtype, _info.epsilon, stream_));
} else {
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
}
...
...
src/infiniop/ops/add_rms_norm/moore/add_rms_norm_moore.mu
View file @
0611cb1b
...
@@ -53,12 +53,12 @@ infiniStatus_t Descriptor::create(
...
@@ -53,12 +53,12 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
infiniopHandle_t handle,
Descriptor **desc_ptr,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon,
float epsilon) {
infiniopTensorDescriptor_t residual_out_desc) {
auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon);
auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
CHECK_RESULT(result);
CHECK_RESULT(result);
auto info = result.take();
auto info = result.take();
...
@@ -128,8 +128,8 @@ infiniStatus_t launchKernel(
...
@@ -128,8 +128,8 @@ infiniStatus_t launchKernel(
// Main calculation function
// Main calculation function
infiniStatus_t Descriptor::calculate(
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *workspace, size_t workspace_size,
void *y, const void *a, const void *b, const void *weight,
void *y,
void *residual_out,
const void *a, const void *b, const void *weight,
void *residual_out, void *stream) const {
void *stream) const {
// Check workspace size
// Check workspace size
if (workspace_size < _workspace_size) {
if (workspace_size < _workspace_size) {
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment