Unverified Commit 81a5f627 authored by PanZezhong1725, committed by GitHub

Merge pull request #374 from InfiniTensor/issue/369

Issue/369: add bf16 support and resolve build issues for rms_norm on Moore GPUs
parents 1d1e0649 e1cba119
#define INFINIOP_MUSA_KERNEL __global__ void
#include <musa_bf16.h>
#include <musa_fp16.h>
// Possible maximum number of threads per block for MUSA architectures
// Used for picking correct kernel launch configuration
#define MUSA_BLOCK_SIZE_2048 2048
#define MUSA_BLOCK_SIZE_1024 1024
#define MUSA_BLOCK_SIZE_512 512
#define CHECK_MUSA(API) CHECK_INTERNAL(API, musaSuccess)
using musa_bfloat16 = mt_bfloat16;
using musa_bfloat162 = mt_bfloat162;
namespace device::musa {

// Return the memory offset into the original tensor, given the flat index
// into the broadcasted tensor.
__forceinline__ __device__ __host__ size_t
indexToReducedOffset(
    size_t flat_index,
    size_t ndim,
    const ptrdiff_t *broadcasted_strides,
    const ptrdiff_t *target_strides) {
    size_t res = 0;
    for (size_t i = 0; i < ndim; ++i) {
        res += flat_index / broadcasted_strides[i] * target_strides[i];
        flat_index %= broadcasted_strides[i];
    }
    return res;
}

// Return the memory offset of an element in a tensor, given its flat index.
__forceinline__ __device__ __host__ size_t
indexToOffset(
    size_t flat_index,
    size_t ndim,
    const size_t *shape,
    const ptrdiff_t *strides) {
    size_t res = 0;
    for (size_t i = ndim; i-- > 0;) {
        res += (flat_index % shape[i]) * strides[i];
        flat_index /= shape[i];
    }
    return res;
}
} // namespace device::musa
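Both helpers are also __host__-qualified, so their index arithmetic can be sanity-checked on the CPU. A minimal sketch, assuming the header above is included and the file is compiled with the MUSA toolchain; the shapes, strides, and main() below are illustrative and not part of this commit:

#include <cassert>
#include <cstddef>

int main() {
    // Hypothetical row-major 2x3 tensor: shape = {2, 3}, strides = {3, 1}.
    const size_t shape[] = {2, 3};
    const ptrdiff_t strides[] = {3, 1};
    // Flat index 4 -> element (1, 1) -> offset 1*3 + 1*1 = 4.
    assert(device::musa::indexToOffset(4, 2, shape, strides) == 4);

    // Broadcasting a 1x3 tensor over 2x3: broadcasted_strides describe the
    // contiguous 2x3 view, target_strides describe the original 1x3 tensor
    // (stride 0 along the broadcast dimension).
    const ptrdiff_t broadcasted_strides[] = {3, 1};
    const ptrdiff_t target_strides[] = {0, 1};
    // Flat index 4 -> (1, 1) in the broadcast view -> offset 1 in the original.
    assert(device::musa::indexToReducedOffset(4, 2, broadcasted_strides, target_strides) == 1);
    return 0;
}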
__forceinline__ __device__ float
exp_(const float val) {
    return expf(val);
}

__forceinline__ __device__ double
exp_(const double val) {
    return exp(val);
}

// <musa_fp16.h> may not support hexp, so compute in float and convert back
__forceinline__ __device__ __half
exp_(const __half x) {
    float f_val = __half2float(x);
    float f_result = expf(f_val);
    return __float2half(f_result);
}

// <musa_bf16.h> may not support hexp, so compute in float and convert back
__forceinline__ __device__ __mt_bfloat16
exp_(const __mt_bfloat16 x) {
    float f_val = __bfloat162float(x);
    float f_result = expf(f_val);
    return __float2bfloat16(f_result);
}
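These overloads let templated device code call exp_ on any supported element type without per-type branching. A minimal sketch of how such an elementwise kernel could look; the kernel name and launch configuration are illustrative only and not part of this commit:

// Illustrative only: applies exp_ elementwise to a contiguous buffer of n elements.
template <typename Tdata>
INFINIOP_MUSA_KERNEL expKernel(Tdata *__restrict__ y,
                               const Tdata *__restrict__ x,
                               size_t n) {
    size_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        y[i] = exp_(x[i]);
    }
}

// Hypothetical launch for a bf16 buffer of n elements on `stream`:
// expKernel<__mt_bfloat16><<<(n + 255) / 256, 256, 0, stream>>>(y, x, n);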
-#ifndef __RMS_NORM_MUSA_CUH__
-#define __RMS_NORM_MUSA_CUH__
+#ifndef __RMS_NORM_MUSA_H__
+#define __RMS_NORM_MUSA_H__
 #include "../rms_norm.h"
......
 #include "../../../devices/musa/common_musa.h"
-#include "../cuda/rms_norm_kernel.cuh"
-#include "rms_norm_musa.cuh"
+#include "rms_norm_musa.h"
+#include "../../../devices/musa/musa_kernel_common.h"
+#include <cub/block/block_reduce.cuh>
+#include "../../../reduce/cuda/reduce.cuh"
+#include "../cuda/kernel.cuh"
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_MUSA_KERNEL rmsnormKernel(
Tdata *__restrict__ y,
ptrdiff_t stride_y,
const Tdata *__restrict__ x,
ptrdiff_t stride_x,
const Tweight *__restrict__ w,
size_t dim,
float epsilon) {
rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y, x, stride_x, w, dim, epsilon);
}
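rmsnormBlock itself comes from the shared CUDA kernel header included above ("../cuda/kernel.cuh") and is not shown in this diff. For reference, a scalar float-only sketch of the math it is expected to compute for one row; the function name and structure below are illustrative, not the actual block-parallel kernel:

#include <cmath>
#include <cstddef>

// Illustrative scalar reference for one row of RMS normalization:
//   y[i] = w[i] * x[i] / sqrt(mean(x^2) + epsilon)
void rmsnormRowReference(float *y, const float *x, const float *w,
                         size_t dim, float epsilon) {
    float sum_sq = 0.0f;
    for (size_t i = 0; i < dim; ++i) {
        sum_sq += x[i] * x[i];
    }
    float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(dim) + epsilon);
    for (size_t i = 0; i < dim; ++i) {
        y[i] = x[i] * w[i] * inv_rms;
    }
}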
 namespace op::rms_norm::musa {
@@ -46,20 +64,24 @@ infiniStatus_t launchKernel(
     float epsilon,
     musaStream_t musa_stream) {
 #define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
-    rmsnormBlock<BLOCK_SIZE, Tdata, Tweight, Tcompute><<<batch_size, BLOCK_SIZE, 0, musa_stream>>>( \
+    rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size, BLOCK_SIZE, 0, musa_stream>>>( \
         reinterpret_cast<Tdata *>(y), \
         stride_y, \
         reinterpret_cast<const Tdata *>(x), \
         stride_x, \
         reinterpret_cast<const Tweight *>(w), \
         dim, \
         epsilon)
     if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
         LAUNCH_KERNEL(half, half, float);
     } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
         LAUNCH_KERNEL(half, float, float);
+    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
+        LAUNCH_KERNEL(__mt_bfloat16, __mt_bfloat16, float);
+    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
+        LAUNCH_KERNEL(__mt_bfloat16, float, float);
     } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
         LAUNCH_KERNEL(float, float, float);
     } else {
@@ -87,8 +109,12 @@ infiniStatus_t Descriptor::calculate(
     auto musa_stream = reinterpret_cast<musaStream_t>(stream);
     // launch kernel with different block sizes
-    if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
-        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
+    if (_opaque->internal->maxThreadsPerBlock() == MUSA_BLOCK_SIZE_1024) {
+        CHECK_STATUS(launchKernel<MUSA_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
+    } else if (_opaque->internal->maxThreadsPerBlock() == MUSA_BLOCK_SIZE_512) {
+        CHECK_STATUS(launchKernel<MUSA_BLOCK_SIZE_512>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
+    } else if (_opaque->internal->maxThreadsPerBlock() == MUSA_BLOCK_SIZE_2048) {
+        CHECK_STATUS(launchKernel<MUSA_BLOCK_SIZE_2048>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
     } else {
         return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
     }
......
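For the new bf16-activation, f32-weight path, LAUNCH_KERNEL(__mt_bfloat16, float, float) expands (approximately, with the macro body shown above) to:

// Approximate expansion of LAUNCH_KERNEL(__mt_bfloat16, float, float):
rmsnormKernel<BLOCK_SIZE, float, __mt_bfloat16, float><<<batch_size, BLOCK_SIZE, 0, musa_stream>>>(
    reinterpret_cast<__mt_bfloat16 *>(y),
    stride_y,
    reinterpret_cast<const __mt_bfloat16 *>(x),
    stride_x,
    reinterpret_cast<const float *>(w),
    dim,
    epsilon);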
@@ -15,7 +15,7 @@
 #include "metax/rms_norm_metax.cuh"
 #endif
 #ifdef ENABLE_MOORE_API
-#include "musa/rms_norm_musa.cuh"
+#include "musa/rms_norm_musa.h"
 #endif
 #ifdef ENABLE_KUNLUN_API
 #include "kunlun/rms_norm_kunlun.h"
......