Unverified commit d4b03cf7, authored by spike-zhu, committed by GitHub

issue/246: add Elementwise and SwiGLU in moore gpu

parent 9b758b9b
@@ -11,8 +11,8 @@
 #define CHECK_MOORE(API) CHECK_INTERNAL(API, musaSuccess)
-using musa_bfloat16 = mt_bfloat16;
-using musa_bfloat162 = mt_bfloat162;
+using cuda_bfloat16 = mt_bfloat16;
+using cuda_bfloat162 = mt_bfloat162;
 namespace device::moore {
@@ -52,6 +52,11 @@ exp_(const float val) {
     return expf(val);
 }
+__forceinline__ __device__ long double
+exp_(const long double val) {
+    return exp(val);
+}
+
 __forceinline__ __device__ double
 exp_(const double val) {
     return exp(val);
 }
...
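The new long double overload rounds out the exp_ dispatch set, so templated device code can call exp_(x) without caring about precision. A minimal sketch of a hypothetical generic caller (not part of this commit):

// Hypothetical generic device helper: overload resolution picks the exp_ that
// matches T exactly, so float inputs are not silently promoted to double.
template <typename T>
__forceinline__ __device__ T sigmoid_sketch(const T x) {
    return static_cast<T>(1) / (static_cast<T>(1) + device::moore::exp_(-x));
}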
#ifndef __INFINIOP_ELEMENTWISE_MOORE_H__
#define __INFINIOP_ELEMENTWISE_MOORE_H__
#include "../../../utils.h"
#include "../../devices/moore/moore_common.h"
#include "../../devices/moore/moore_kernel_common.h"
#include "elementwise_moore_api.h"
namespace op::elementwise::moore {
template <typename T>
__device__ __forceinline__ const T *typedInputPtr(const void *ptr) {
return reinterpret_cast<const T *>(ptr);
}
__device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim,
const size_t *shape, const ptrdiff_t *strides) {
return is_contiguous ? idx : device::moore::indexToOffset(idx, ndim, shape, strides);
}
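// InputIndexer maps the flat output index to each input's memory offset:
// contiguous inputs reuse the flat index as-is, broadcasted inputs are mapped
// through the output strides, and everything else goes through the input's
// own shape/stride arithmetic.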
struct InputIndexer {
size_t idx;
size_t ndim;
const bool *input_contiguous;
const bool *input_broadcasted;
const size_t *input_shapes;
const ptrdiff_t *input_strides;
const ptrdiff_t *output_strides;
__device__ __forceinline__ size_t operator()(size_t input_id) const {
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
? device::moore::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: device::moore::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
}
};
template <typename F, size_t... Is>
__device__ __forceinline__ void unpackInputsAndApply(F &&f, std::index_sequence<Is...>) {
f(std::integral_constant<size_t, Is>{}...);
}
template <size_t N, typename Op, typename Tdata, typename... Args>
INFINIOP_MOORE_KERNEL elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
const bool *__restrict__ input_contiguous,
const bool *__restrict__ input_broadcasted,
const size_t *__restrict__ output_shape,
const size_t *__restrict__ input_shapes,
const ptrdiff_t *__restrict__ output_strides,
const ptrdiff_t *__restrict__ input_strides,
Tdata *output,
const void *const *inputs,
size_t offset,
Args... args) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;
if (idx < output_size) {
const Tdata *const *typed_inputs = reinterpret_cast<const Tdata *const *>(inputs);
size_t out_idx = getOutputIndex(idx, output_contiguous, ndim, output_shape, output_strides);
InputIndexer indexer{idx, ndim, input_contiguous, input_broadcasted, input_shapes, input_strides, output_strides};
unpackInputsAndApply(
[&](auto... Is) {
output[out_idx] = Op{}(typed_inputs[Is.value][indexer(Is.value)]..., std::forward<Args>(args)...);
},
std::make_index_sequence<N>{});
}
}
template <typename Op, typename Tout, typename... Tin>
INFINIOP_MOORE_KERNEL elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
const bool *__restrict__ input_contiguous,
const bool *__restrict__ input_broadcasted,
const size_t *__restrict__ output_shape,
const size_t *__restrict__ input_shapes,
const ptrdiff_t *__restrict__ output_strides,
const ptrdiff_t *__restrict__ input_strides,
Tout *output,
const void *const *__restrict__ inputs,
size_t offset) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;
if (idx < output_size) {
size_t out_idx = getOutputIndex(idx, output_contiguous, ndim, output_shape, output_strides);
InputIndexer indexer{idx, ndim, input_contiguous, input_broadcasted, input_shapes, input_strides, output_strides};
unpackInputsAndApply(
[&](auto... Is) {
output[out_idx] = Op{}.template operator()<Tout, Tin...>(
(typedInputPtr<Tin>(inputs[Is.value])[indexer(Is.value)])...);
},
std::index_sequence_for<Tin...>{});
}
}
struct DeviceImpl::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
Opaque(const std::shared_ptr<device::moore::Handle::Internal> &internal)
: internal(internal) {}
template <uint32_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
musaStream_t stream,
Args &&...args) {
return launchElementwiseKernel<BLOCK_SIZE, N>(
info, workspace,
reinterpret_cast<Tdata *>(output), inputs,
elementwiseKernel<N, Op, Tdata, Args...>,
stream,
std::forward<Args>(args)...);
}
template <uint32_t BLOCK_SIZE, size_t N, typename Op, typename Tout, typename... Tin, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
musaStream_t stream,
Args &&...args) {
return launchElementwiseKernel<BLOCK_SIZE, N>(
info, workspace,
reinterpret_cast<Tout *>(output), inputs,
elementwiseKernel<Op, Tout, Tin...>,
stream);
}
private:
template <size_t N>
infiniStatus_t infoToDevice(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
const void *const *h_inputs_arr,
const void **&d_inputs_arr,
const bool *&d_input_contiguous,
const bool *&d_input_broadcasted,
const size_t *&d_output_shape,
const ptrdiff_t *&d_output_strides,
const size_t *&d_input_shapes,
const ptrdiff_t *&d_input_strides,
musaStream_t stream) const {
constexpr auto input_size = N;
const auto ndim = info.getNdim();
constexpr auto input_arr_size = N * sizeof(*h_inputs_arr);
const int8_t *info_meta_start = info.getMetaStart();
const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
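// Workspace layout (mirrors the pointer assignments below):
// [ input pointer array | output shape | output strides |
//   input shapes | input strides | contiguous flags | broadcast flags ]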
// copy the input pointer array and meta to device
CHECK_MOORE(musaMemcpyAsync(workspace, h_inputs_arr, input_arr_size, musaMemcpyHostToDevice, stream));
CHECK_MOORE(musaMemcpyAsync((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), musaMemcpyHostToDevice, stream));
// offset/assign the pointers
d_inputs_arr = reinterpret_cast<const void **>(workspace);
d_output_shape = reinterpret_cast<const size_t *>(d_meta_start);
d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape + ndim);
d_input_shapes = reinterpret_cast<const size_t *>(d_output_strides + ndim);
d_input_strides = reinterpret_cast<const ptrdiff_t *>(d_input_shapes + input_size * ndim);
d_input_contiguous = reinterpret_cast<const bool *>(d_input_strides + input_size * ndim);
d_input_broadcasted = reinterpret_cast<const bool *>(d_input_contiguous + input_size);
return INFINI_STATUS_SUCCESS;
}
template <uint32_t BLOCK_SIZE, size_t N, typename KernelFunc, typename Tout, typename... Args>
infiniStatus_t launchElementwiseKernel(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
Tout *output,
const std::vector<const void *> &inputs,
KernelFunc kernel_func,
musaStream_t stream,
Args &&...args) {
auto output_size = info.getOutputSize();
if (output_size == 0) {
return INFINI_STATUS_SUCCESS;
}
// Device pointers
const void **d_inputs_arr = nullptr;
const bool *d_input_contiguous = nullptr;
const bool *d_input_broadcasted = nullptr;
const size_t *d_output_shape = nullptr;
const ptrdiff_t *d_output_strides = nullptr;
const size_t *d_input_shapes = nullptr;
const ptrdiff_t *d_input_strides = nullptr;
CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr,
d_input_contiguous, d_input_broadcasted,
d_output_shape, d_output_strides,
d_input_shapes, d_input_strides, stream));
dim3 blockDims(std::min(BLOCK_SIZE, static_cast<uint32_t>(internal->maxThreadsPerBlock())));
dim3 gridDims(std::min(uint32_t(CEIL_DIV(output_size, blockDims.x)), static_cast<uint32_t>(internal->gridSizeX())));
size_t step = gridDims.x * blockDims.x;
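// One launch may cover fewer than output_size elements, so the kernel is
// re-launched with an increasing offset until the whole output is processed.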
for (size_t i = 0; i < output_size; i += step) {
kernel_func<<<gridDims, blockDims, 0, stream>>>(
output_size, info.getNdim(), info.isOutputContiguous(),
d_input_contiguous, d_input_broadcasted,
d_output_shape, d_input_shapes,
d_output_strides, d_input_strides,
output, reinterpret_cast<const void **>(d_inputs_arr),
i, std::forward<Args>(args)...);
}
return INFINI_STATUS_SUCCESS;
}
};
template <typename... Args>
utils::Result<DeviceImpl *> DeviceImpl::create(Args &&...args) {
auto opaque = std::make_shared<Opaque>(std::forward<Args>(args)...);
return utils::Result<DeviceImpl *>(new DeviceImpl(opaque));
}
/* Invoke elementwise operation for different input types */
template <uint32_t BLOCK_SIZE, typename Op, typename Tout, typename... Tin, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
static_assert(sizeof...(Tin) == N, "Input type count mismatch");
return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tout, Tin...>(
info, workspace, output, inputs,
reinterpret_cast<musaStream_t>(stream),
std::forward<Args>(args)...);
}
/* Invoke elementwise operation when all inputs have the same dtype */
template <uint32_t BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tdata>(
info, workspace, output, inputs,
reinterpret_cast<musaStream_t>(stream),
std::forward<Args>(args)...);
}
} // namespace op::elementwise::moore
#endif
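For orientation, operators plug into this header as stateless functors that expose num_inputs and a device operator(). A minimal hypothetical example (not part of this commit) that would work with the same-dtype calculate path:

// Hypothetical example functor, shown only to illustrate the Op contract that
// elementwiseKernel instantiates; it is not part of this commit.
struct AddOp {
    static constexpr size_t num_inputs = 2;
    template <typename T>
    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
        return a + b;
    }
};
// A descriptor would dispatch it through the same-dtype overload, e.g.:
//   _device_info->calculate<256, AddOp, float>(_info, workspace, output, inputs, stream);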
#ifndef __INFINIOP_ELEMENTWISE_MOORE_API_H__
#define __INFINIOP_ELEMENTWISE_MOORE_API_H__
#include "../elementwise.h"
namespace op::elementwise::moore {
class DeviceImpl final {
struct Opaque;
std::shared_ptr<Opaque> _opaque;
DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
public:
~DeviceImpl() = default;
template <typename... Args>
static utils::Result<DeviceImpl *> create(Args &&...args);
template <uint32_t BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
template <uint32_t BLOCK_SIZE, typename Op, typename Tout, typename... Tin,
typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
};
} // namespace op::elementwise::moore
#define CREATE_ELEMENTWISE_MOORE_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
auto device_impl_result = op::elementwise::moore::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
std::move(info), \
std::move(device_impl_result.take()), \
workspace_size, \
HANDLE->device, \
HANDLE->device_id);
#endif // __INFINIOP_ELEMENTWISE_MOORE_API_H__
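For reference, the workspace_size computed by the macro must cover the device-side input pointer array plus the metadata block that infoToDevice later copies. A hedged breakdown for an N-input, ndim-dimensional op (any internal alignment or padding inside ElementwiseInfo's meta block is an assumption here):

// Hedged sketch of the workspace computed by the macro above, matching the
// layout consumed by infoToDevice:
//   meta ≈ ndim * sizeof(size_t)                 // output shape
//        + ndim * sizeof(ptrdiff_t)              // output strides
//        + N * ndim * sizeof(size_t)             // input shapes
//        + N * ndim * sizeof(ptrdiff_t)          // input strides
//        + 2 * N * sizeof(bool)                  // contiguous + broadcast flags
//   workspace_size = N * sizeof(void *) + meta   // pointer array comes first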
#ifndef __SWIGLU_CUDA_H__
#define __SWIGLU_CUDA_H__
/*
* This file contains the SwiGLU operation implementation for the MUSA backend.
*
* It uses the 'op::swiglu::cuda' namespace to maintain a consistent code structure
* and interface with the CUDA implementation, ensuring code alignment across different
* hardware platforms.
*/
namespace op::swiglu::cuda {
typedef struct SwiGLUOp {
private:
template <typename T>
__device__ __forceinline__ T sigmoid(const T &x) const {
if constexpr (std::is_same_v<T, half2>) {
return h2rcp(__hadd2(make_half2(1, 1), h2exp(__hneg2(x))));
} else if constexpr (std::is_same_v<T, half>) {
// This implementation uses standard floating-point arithmetic to compute the sigmoid,
// ensuring portability on MUSA platforms.
//
// The original CUDA implementation relied on platform-specific intrinsics such as
// hrcp for half precision, which are not supported on the MUSA platform.
// To work around this, the half-precision input is first converted to a
// higher-precision float, the sigmoid is computed, and the result is cast back to half.
float xf = __half2float(x);
float sigf = 1.0f / (1.0f + std::exp(-xf));
return __float2half(sigf);
} else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
float x0 = __bfloat162float(__low2bfloat16(x));
float x1 = __bfloat162float(__high2bfloat16(x));
float sig0 = __frcp_rn(__fadd_rn(1.0f, __expf(-x0)));
float sig1 = __frcp_rn(__fadd_rn(1.0f, __expf(-x1)));
return __floats2bfloat162_rn(sig0, sig1);
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
float xf = __bfloat162float(x);
return __float2bfloat16_rn(__frcp_rn(__fadd_rn(1.0f, __expf(-xf))));
} else if constexpr (std::is_same_v<T, float>) {
return __frcp_rn(__fadd_rn(1, __expf(-x)));
} else {
return 1 / (1 + std::exp(-x));
}
}
public:
static constexpr size_t num_inputs = 2;
template <typename T>
__device__ __forceinline__ T operator()(const T &up, const T &gate) const {
if constexpr (std::is_same_v<T, half2>) {
return __hmul2(__hmul2(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<T, half>) {
return __hmul(__hmul(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<T, cuda_bfloat162>) {
cuda_bfloat162 sig = sigmoid(gate);
// On the MUSA platform, `__low2float()` and `__high2float()` extract and convert
// the bfloat16 halves to float in one step, replacing the two-step CUDA idiom
// (`__low2bfloat16` followed by `__bfloat162float`), since MUSA may not
// support `__low2bfloat16`.
float gate0 = __low2float(gate);
float gate1 = __high2float(gate);
float sig0 = __low2float(sig);
float sig1 = __high2float(sig);
float up0 = __low2float(up);
float up1 = __high2float(up);
float res0 = __fmul_rn(__fmul_rn(gate0, sig0), up0);
float res1 = __fmul_rn(__fmul_rn(gate1, sig1), up1);
return __floats2bfloat162_rn(res0, res1);
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
cuda_bfloat16 sig = sigmoid(gate);
float gatef = __bfloat162float(gate);
float sigf = __bfloat162float(sig);
float upf = __bfloat162float(up);
return __float2bfloat16_rn(__fmul_rn(__fmul_rn(gatef, sigf), upf));
} else if constexpr (std::is_same_v<T, float>) {
return __fmul_rn(__fmul_rn(gate, sigmoid(gate)), up);
} else {
return gate * sigmoid(gate) * up;
}
}
} SwiGLUOp;
} // namespace op::swiglu::cuda
#endif // __SWIGLU_CUDA_H__
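Every type-specialized branch above computes the same scalar recipe, out = up * gate * sigmoid(gate). A plain host-side reference (a sketch useful as a test oracle, not part of this commit):

#include <cmath>

// Host-side scalar reference for SwiGLU: out = up * gate * sigmoid(gate).
// A sketch for validating the device kernels against, nothing more.
inline float swiglu_ref(float up, float gate) {
    const float sig = 1.0f / (1.0f + std::exp(-gate));
    return up * gate * sig;
}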
#ifndef __SWIGLU_MOORE_API_H__
#define __SWIGLU_MOORE_API_H__
#include "../../../elementwise/moore/elementwise_moore_api.h"
ELEMENTWISE_DESCRIPTOR(swiglu, moore)
#endif // __SWIGLU_MOORE_API_H__
#include "swiglu_moore.h"
#include "../../../elementwise/moore/elementwise_moore.h"
#include "siwglu_moore_kernel.h"
namespace op::swiglu::moore {
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
auto dtype = out_desc->dtype();
const auto &up_desc = input_desc_vec.at(0);
const auto &gate_desc = input_desc_vec.at(1);
const auto &out_shape = out_desc->shape();
const auto &up_shape = up_desc->shape();
const auto &gate_shape = gate_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
// create MOORE elementwise descriptor
CREATE_ELEMENTWISE_MOORE_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<256, cuda::SwiGLUOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_BF16:
return _device_info->calculate<256, cuda::SwiGLUOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<256, cuda::SwiGLUOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<256, cuda::SwiGLUOp, double>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::swiglu::moore
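Putting the pieces together, the descriptor above is driven like this (a hedged sketch; creation of the handle, tensor descriptors, device buffers, a workspace of at least the descriptor's reported size, and the stream is assumed to happen elsewhere):

// Hedged usage sketch of the Moore SwiGLU descriptor implied by the code above.
op::swiglu::moore::Descriptor *desc = nullptr;
CHECK_STATUS(op::swiglu::moore::Descriptor::create(
    handle, &desc, out_desc, {up_desc, gate_desc}));
CHECK_STATUS(desc->calculate(workspace, workspace_size,
                             d_out, {d_up, d_gate}, stream));
delete desc;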
@@ -20,6 +20,9 @@
 #ifdef ENABLE_ASCEND_API
 #include "ascend/swiglu_ascend.h"
 #endif
+#ifdef ENABLE_MOORE_API
+#include "moore/swiglu_moore.h"
+#endif

 __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
     infiniopHandle_t handle,
@@ -60,10 +63,8 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
 #ifdef ENABLE_ASCEND_API
     CREATE(INFINI_DEVICE_ASCEND, ascend);
 #endif
-#ifdef ENABLE_MTHREADS_GPU
-    case DevMthreadsGpu:
-        return musaCreateSwiGLUDescriptor(
-            handle, (SwiGLUMusaDescriptor_t *)desc_ptr, c_desc, a_desc, b_desc);
+#ifdef ENABLE_MOORE_API
+    CREATE(INFINI_DEVICE_MOORE, moore);
 #endif
     default:
@@ -102,10 +103,8 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc
 #ifdef ENABLE_ASCEND_API
     GET(INFINI_DEVICE_ASCEND, ascend);
 #endif
-#ifdef ENABLE_MTHREADS_GPU
-    case DevMthreadsGpu: {
-        return musaGetSwiGLUWorkspaceSize((SwiGLUMusaDescriptor_t)desc, size);
-    }
+#ifdef ENABLE_MOORE_API
+    GET(INFINI_DEVICE_MOORE, moore);
 #endif
 }
@@ -151,9 +150,8 @@ __C infiniStatus_t infiniopSwiGLU(
 #ifdef ENABLE_ASCEND_API
     CALCULATE(INFINI_DEVICE_ASCEND, ascend);
 #endif
-#ifdef ENABLE_MTHREADS_GPU
-    case DevMthreadsGpu:
-        return musaSwiGLU((SwiGLUMusaDescriptor_t)desc, c, a, b, stream);
+#ifdef ENABLE_MOORE_API
+    CALCULATE(INFINI_DEVICE_MOORE, moore);
 #endif
     default:
@@ -194,9 +192,8 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
 #ifdef ENABLE_ASCEND_API
     DELETE(INFINI_DEVICE_ASCEND, ascend)
 #endif
-#ifdef ENABLE_MTHREADS_GPU
-    case DevMthreadsGpu:
-        return musaDestroySwiGLUDescriptor((SwiGLUMusaDescriptor_t)desc);
+#ifdef ENABLE_MOORE_API
+    DELETE(INFINI_DEVICE_MOORE, moore);
 #endif
     default:
...