Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
e1cba119
Commit
e1cba119
authored
Aug 15, 2025
by
zhushuang
Browse files
fix: add bf16 support and resolve build issues for rms_norm in moore gpu
parent
1d1e0649
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
121 additions
and
13 deletions
+121
-13
src/infiniop/devices/musa/musa_kernel_common.h
src/infiniop/devices/musa/musa_kernel_common.h
+74
-0
src/infiniop/ops/rms_norm/musa/rms_norm_musa.h
src/infiniop/ops/rms_norm/musa/rms_norm_musa.h
+8
-0
src/infiniop/ops/rms_norm/musa/rms_norm_musa.mu
src/infiniop/ops/rms_norm/musa/rms_norm_musa.mu
+38
-12
src/infiniop/ops/rms_norm/operator.cc
src/infiniop/ops/rms_norm/operator.cc
+1
-1
No files found.
src/infiniop/devices/musa/musa_kernel_common.h
0 → 100644
View file @
e1cba119
#pragma once // FIX: new header had no include guard; prevents double-definition of the aliases below

// Marks a function as a MUSA kernel entry point (CUDA-style __global__).
#define INFINIOP_MUSA_KERNEL __global__ void

#include <musa_bf16.h>
#include <musa_fp16.h>

// Possible maximum number of threads per block for MUSA architectures.
// Used for picking the correct kernel launch configuration.
#define MUSA_BLOCK_SIZE_2048 2048
#define MUSA_BLOCK_SIZE_1024 1024
#define MUSA_BLOCK_SIZE_512 512

// Wraps a MUSA runtime call and checks its return code against musaSuccess.
// NOTE(review): relies on CHECK_INTERNAL being defined by an earlier include
// at every use site — confirm with callers.
#define CHECK_MUSA(API) CHECK_INTERNAL(API, musaSuccess)

// Aliases so shared kernels can use CUDA-style bfloat16 type names on MUSA.
using musa_bfloat16 = mt_bfloat16;
using musa_bfloat162 = mt_bfloat162;
namespace device::musa {

// Given the flat index of an element in a broadcasted tensor, return the
// memory offset of the corresponding element in the original tensor.
// `broadcasted_strides` are the (contiguous) strides of the broadcasted
// shape, used to decompose the flat index into per-dimension coordinates;
// `target_strides` map those coordinates into the original tensor's memory.
__forceinline__ __device__ __host__ size_t indexToReducedOffset(
    size_t flat_index,
    size_t ndim,
    const ptrdiff_t *broadcasted_strides,
    const ptrdiff_t *target_strides) {
    size_t offset = 0;
    for (size_t d = 0; d < ndim; ++d) {
        const size_t coord = flat_index / broadcasted_strides[d];
        offset += coord * target_strides[d];
        flat_index %= broadcasted_strides[d];
    }
    return offset;
}

// Return the memory offset of the element with the given flat index in a
// (possibly non-contiguous) tensor described by `shape` and `strides`.
// Walks dimensions from innermost to outermost, peeling one coordinate
// per iteration.
__forceinline__ __device__ __host__ size_t indexToOffset(
    size_t flat_index,
    size_t ndim,
    const size_t *shape,
    const ptrdiff_t *strides) {
    size_t offset = 0;
    for (size_t d = ndim; d-- > 0;) {
        offset += (flat_index % shape[d]) * strides[d];
        flat_index /= shape[d];
    }
    return offset;
}

} // namespace device::musa
// Single-precision exponential; dispatches to the float overload expf.
__forceinline__ __device__ float exp_(const float val) {
    return expf(val);
}
// Double-precision exponential; forwards to the standard exp.
__forceinline__ __device__ double exp_(const double val) {
    return exp(val);
}
// Half-precision exponential. <musa_fp16.h> may not provide hexp on this
// platform, so widen to float, take expf, and narrow the result back.
__forceinline__ __device__ __half exp_(const __half x) {
    return __float2half(expf(__half2float(x)));
}
// bfloat16 exponential. <musa_bf16.h> may not provide hexp, so compute in
// float precision and convert the result back down to bfloat16.
__forceinline__ __device__ __mt_bfloat16 exp_(const __mt_bfloat16 x) {
    return __float2bfloat16(expf(__bfloat162float(x)));
}
src/infiniop/ops/rms_norm/musa/rms_norm_musa.
cu
h
→
src/infiniop/ops/rms_norm/musa/rms_norm_musa.h
View file @
e1cba119
#ifndef __RMS_NORM_MUSA_
CU
H__
#ifndef __RMS_NORM_MUSA_H__
#define __RMS_NORM_MUSA_
CU
H__
#define __RMS_NORM_MUSA_H__
#include "../rms_norm.h"
#include "../rms_norm.h"
...
...
src/infiniop/ops/rms_norm/musa/rms_norm_musa.mu
View file @
e1cba119
#include "../../../devices/musa/common_musa.h"
#include "../../../devices/musa/common_musa.h"
#include "../cuda/rms_norm_kernel.cuh"
#include "rms_norm_musa.h"
#include "rms_norm_musa.cuh"
#include "../../../devices/musa/musa_kernel_common.h"
#include <cub/block/block_reduce.cuh>
#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
// MUSA __global__ entry point for RMS normalization. This wrapper only
// instantiates the shared rmsnormBlock device helper (from
// ../cuda/kernel.cuh) as a launchable kernel; the reduction itself lives
// there. NOTE(review): launched as <<<batch_size, BLOCK_SIZE>>> by
// launchKernel, so each block appears to process one row — confirm against
// rmsnormBlock.
//
// BLOCK_SIZE : threads per block; must match the launch block dimension
// Tcompute   : accumulation type for the reduction (e.g. float)
// Tdata      : element type of the input x and output y
// Tweight    : element type of the scale vector w
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_MUSA_KERNEL rmsnormKernel(
Tdata *__restrict__ y,
ptrdiff_t stride_y, // stride (in elements) between consecutive rows of y
const Tdata *__restrict__ x,
ptrdiff_t stride_x, // stride (in elements) between consecutive rows of x
const Tweight *__restrict__ w,
size_t dim, // number of elements normalized together (row length)
float epsilon) { // numerical-stability term added before the rsqrt
rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y, x, stride_x, w, dim, epsilon);
}
namespace op::rms_norm::musa {
namespace op::rms_norm::musa {
...
@@ -46,20 +64,24 @@ infiniStatus_t launchKernel(
...
@@ -46,20 +64,24 @@ infiniStatus_t launchKernel(
float epsilon,
float epsilon,
musaStream_t musa_stream) {
musaStream_t musa_stream) {
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute)
\
rmsnorm
Block
<BLOCK_SIZE, Tdata, Tweight
, Tcompute
><<<batch_size, BLOCK_SIZE, 0, musa_stream>>>( \
rmsnorm
Kernel
<BLOCK_SIZE,
Tcompute,
Tdata, Tweight><<<batch_size, BLOCK_SIZE, 0, musa_stream>>>( \
reinterpret_cast<Tdata *>(y), \
reinterpret_cast<Tdata *>(y),
\
stride_y, \
stride_y,
\
reinterpret_cast<const Tdata *>(x), \
reinterpret_cast<const Tdata *>(x),
\
stride_x, \
stride_x,
\
reinterpret_cast<const Tweight *>(w), \
reinterpret_cast<const Tweight *>(w),
\
dim, \
dim,
\
epsilon)
epsilon)
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(half, float, float);
LAUNCH_KERNEL(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(__mt_bfloat16, __mt_bfloat16, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(__mt_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL(float, float, float);
LAUNCH_KERNEL(float, float, float);
} else {
} else {
...
@@ -87,8 +109,12 @@ infiniStatus_t Descriptor::calculate(
...
@@ -87,8 +109,12 @@ infiniStatus_t Descriptor::calculate(
auto musa_stream = reinterpret_cast<musaStream_t>(stream);
auto musa_stream = reinterpret_cast<musaStream_t>(stream);
// launch kernel with different block sizes
// launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
if (_opaque->internal->maxThreadsPerBlock() == MUSA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
CHECK_STATUS(launchKernel<MUSA_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MUSA_BLOCK_SIZE_512) {
CHECK_STATUS(launchKernel<MUSA_BLOCK_SIZE_512>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else if (_opaque->internal->maxThreadsPerBlock() == MUSA_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<MUSA_BLOCK_SIZE_2048>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
} else {
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
}
...
...
src/infiniop/ops/rms_norm/operator.cc
View file @
e1cba119
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
#include "metax/rms_norm_metax.cuh"
#include "metax/rms_norm_metax.cuh"
#endif
#endif
#ifdef ENABLE_MOORE_API
#ifdef ENABLE_MOORE_API
#include "musa/rms_norm_musa.
cu
h"
#include "musa/rms_norm_musa.h"
#endif
#endif
#ifdef ENABLE_KUNLUN_API
#ifdef ENABLE_KUNLUN_API
#include "kunlun/rms_norm_kunlun.h"
#include "kunlun/rms_norm_kunlun.h"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment