Commit 9faf1ffc authored by zhangyue's avatar zhangyue

issue/209: kunlun elementwise first commit

parent 1a4cfb99
@@ -16,7 +16,7 @@ typedef XPUStream kunlunStream_t;
typedef XPUEvent kunlunEvent_t;
typedef xdnn::Context *xdnnHandle_t;
-#define CHECK_XDNN(API) CHECK_INTERNAL(API, XPU_SUCCESS)
+#define CHECK_KUNLUN(API) CHECK_INTERNAL(API, XPU_SUCCESS)
namespace device::kunlun {
...
-#ifndef __INFINIOP_KUNLUN_COMMON_H__
-#define __INFINIOP_KUNLUN_COMMON_H__
+#ifndef __INFINIOP_KUNLUN_KERNEL_COMMON_H__
+#define __INFINIOP_KUNLUN_KERNEL_COMMON_H__
// This header file will only be included by .xpu files
#include "kunlun_kernel_dtype.h"
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_io.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/runtime.h"
namespace device::kunlun::kernel {
// Get mask for kunlun xpu 512-bit register calculation:
// if the data does not fill 512 bits, pad with zeros and use
// a mask to identify the real data
@@ -26,6 +29,37 @@ inline __device__ void atomicAddF32(__shared_ptr__ float *ptr, float value) {
}
}
// Map a flat output index to the offset within a broadcasted input:
// decompose the index with the broadcasted (output) strides, then
// recombine it with the target (input) strides
inline __device__ size_t indexToReducedOffset(
size_t flat_index,
size_t ndim,
const ptrdiff_t *broadcasted_strides,
const ptrdiff_t *target_strides) {
size_t res = 0;
for (size_t i = 0; i < ndim; ++i) {
res += flat_index / broadcasted_strides[i].value * target_strides[i].value;
flat_index %= broadcasted_strides[i].value;
mfence();
}
return res;
}
// Map a flat index to a memory offset from the tensor's shape and strides
inline __device__ size_t indexToOffset(
size_t flat_index,
size_t ndim,
const _size_t *shape,
const ptrdiff_t *strides) {
size_t res = 0;
for (size_t i = ndim; i-- > 0;) {
res += (flat_index % shape[i].value) * strides[i].value;
flat_index /= shape[i].value;
mfence();
}
return res;
}
} // namespace device::kunlun::kernel
// TODO: atomicAddF16
// TODO: atomicAddI8
#endif
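The two index helpers above are plain stride arithmetic. indexToOffset turns a flat index into a memory offset from a shape/stride pair; indexToReducedOffset handles broadcast inputs by decomposing the flat index with the output's strides before recombining with the input's strides, so dimensions with input stride 0 contribute nothing. A minimal host-side sketch (hypothetical helper with plain 64-bit types, no padded device structs):

#include <cassert>
#include <cstddef>

// Host-side sketch of indexToReducedOffset: decompose flat_index with the
// output (broadcasted) strides, recombine with the input's own strides.
static size_t reducedOffsetHost(size_t flat_index, size_t ndim,
                                const long *output_strides,
                                const long *input_strides) {
    size_t res = 0;
    for (size_t i = 0; i < ndim; ++i) {
        res += flat_index / output_strides[i] * input_strides[i];
        flat_index %= output_strides[i];
    }
    return res;
}

int main() {
    // Output 2x3 (row-major strides {3, 1}); input 1x3 broadcast along dim 0
    // (strides {0, 1}). Flat output index 4 is element (1, 1), which must
    // read input element (0, 1), i.e. offset 1.
    const long output_strides[] = {3, 1};
    const long input_strides[] = {0, 1};
    assert(reducedOffsetHost(4, 2, output_strides, input_strides) == 1);
    return 0;
}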
#ifndef __INFINIOP_KUNLUN_DTYPE_H__
#define __INFINIOP_KUNLUN_DTYPE_H__
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_io.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/runtime.h"
// On the kunlun device this ptrdiff_t holds elements of a ptrdiff_t
// array copied from the host: device long is 32-bit, so value plus
// padding together match the host's 64-bit ptrdiff_t
typedef struct ptrdiff_t {
long value;   // low 32 bits (the actual value)
long padding; // high 32 bits (padding)
} ptrdiff_t;
// Same padded layout for size_t arrays copied from the host
typedef struct _size_t {
size_t value;
size_t padding;
} _size_t;
#endif
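The padding fields make each device-side element 64 bits wide, matching the host's ptrdiff_t/size_t, so the host can copy stride and shape arrays with a single xpu_memcpy and the kernel reads the low half through .value (assuming values fit in 32 bits and a little-endian host). A hedged host-side sketch of that layout, using a hypothetical mirror struct:

#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical host-side mirror of the device struct; the real device type
// relies on the XPU compiler's 32-bit long.
struct KunlunPtrdiff {
    int32_t value;   // what the kernel reads
    int32_t padding; // unused high half of the host's 64-bit integer
};

int main() {
    static_assert(sizeof(KunlunPtrdiff) == sizeof(int64_t),
                  "one padded slot per 64-bit host element");
    int64_t host_stride = 42;
    KunlunPtrdiff slot;
    // Byte-for-byte copy, as xpu_memcpy does for the whole array.
    std::memcpy(&slot, &host_stride, sizeof(slot));
    assert(slot.value == 42); // holds on a little-endian host
    return 0;
}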
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_H__
#define __INFINIOP_ELEMENTWISE_KUNLUN_H__
#include "../../../utils.h"
#include "../../devices/kunlun/kunlun_handle.h"
#include "elementwise_kunlun_api.h"
namespace op::elementwise::kunlun {
struct DeviceImpl::Opaque {
std::shared_ptr<device::kunlun::Handle::Internal> internal;
Opaque(const std::shared_ptr<device::kunlun::Handle::Internal> &internal_)
: internal(internal_) {}
template <size_t N, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
kunlunStream_t stream,
Args &&...args) {
auto output_size = info.getOutputSize();
if (output_size == 0) {
return INFINI_STATUS_SUCCESS;
}
// Device pointers
const void **d_inputs_arr = nullptr;
const bool *d_input_contiguous = nullptr;
const bool *d_input_broadcasted = nullptr;
const size_t *d_output_shape = nullptr;
const ptrdiff_t *d_output_strides = nullptr;
const size_t *d_input_shapes = nullptr;
const ptrdiff_t *d_input_strides = nullptr;
CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr,
d_input_contiguous, d_input_broadcasted,
d_output_shape, d_output_strides,
d_input_shapes, d_input_strides));
Op::template launch<Tdata>(
output_size,
info.getNdim(),
info.isOutputContiguous(),
reinterpret_cast<const void *>(d_input_contiguous),
reinterpret_cast<const void *>(d_input_broadcasted),
reinterpret_cast<const void *>(d_output_shape),
reinterpret_cast<const void *>(d_input_shapes),
reinterpret_cast<const void *>(d_output_strides),
reinterpret_cast<const void *>(d_input_strides),
output,
reinterpret_cast<const void *const *>(d_inputs_arr),
stream,
args...);
return INFINI_STATUS_SUCCESS;
}
private:
template <size_t N>
infiniStatus_t infoToDevice(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
const void *const *h_inputs_arr,
const void **&d_inputs_arr,
const bool *&d_input_contiguous,
const bool *&d_input_broadcasted,
const size_t *&d_output_shape,
const ptrdiff_t *&d_output_strides,
const size_t *&d_input_shapes,
const ptrdiff_t *&d_input_strides) const {
constexpr auto input_size = N;
const auto ndim = info.getNdim();
constexpr auto input_arr_size = N * sizeof(*h_inputs_arr);
const int8_t *info_meta_start = info.getMetaStart();
const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
// copy the input pointer array and meta to device
CHECK_KUNLUN(xpu_memcpy(workspace, h_inputs_arr, input_arr_size, XPU_HOST_TO_DEVICE));
CHECK_KUNLUN(xpu_memcpy((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), XPU_HOST_TO_DEVICE));
// offset/assign the pointers
d_inputs_arr = reinterpret_cast<const void **>(workspace);
d_output_shape = reinterpret_cast<const size_t *>(d_meta_start);
d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape + ndim);
d_input_shapes = reinterpret_cast<const size_t *>(d_output_strides + ndim);
d_input_strides = reinterpret_cast<const ptrdiff_t *>(d_input_shapes + input_size * ndim);
d_input_contiguous = reinterpret_cast<const bool *>(d_input_strides + input_size * ndim);
d_input_broadcasted = reinterpret_cast<const bool *>(d_input_contiguous + input_size);
return INFINI_STATUS_SUCCESS;
}
};
template <typename... Args>
utils::Result<DeviceImpl *> DeviceImpl::create(Args &&...args) {
auto opaque = std::make_shared<Opaque>(std::forward<Args>(args)...);
return utils::Result<DeviceImpl *>(new DeviceImpl(opaque));
}
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
return _opaque->calculateImpl<N, Op, Tdata>(
info, workspace, output, inputs,
reinterpret_cast<kunlunStream_t>(stream),
std::forward<Args>(args)...);
}
} // namespace op::elementwise::kunlun
// Template for kunlun kernel interface declaration
#define LAUNCH_ELEMENTWISE_KERNEL(OpName) \
template <typename Tdata, typename... Args> \
void launch##OpName##Kernel( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
Args... args);
#endif
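For reference, the pointer arithmetic in infoToDevice implies the following layout for the single workspace buffer (a sketch; it assumes the meta blob behind ElementwiseInfo::getMetaStart() uses the same field order, which is exactly what the offsets above encode):

// Workspace layout written by infoToDevice (one contiguous device buffer):
//
//   d_inputs_arr         const void*[N]     at offset 0
//   d_output_shape       size_t[ndim]       at offset N * sizeof(void*)
//   d_output_strides     ptrdiff_t[ndim]
//   d_input_shapes       size_t[N * ndim]
//   d_input_strides      ptrdiff_t[N * ndim]
//   d_input_contiguous   bool[N]
//   d_input_broadcasted  bool[N]
//
// Total: N * sizeof(void*) + info.getMetaMemSize(), matching the
// workspace_size computed by CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR in
// elementwise_kunlun_api.h below.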
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_API_H__
#define __INFINIOP_ELEMENTWISE_KUNLUN_API_H__
#include "../elementwise.h"
namespace op::elementwise::kunlun {
class DeviceImpl final {
struct Opaque;
std::shared_ptr<Opaque> _opaque;
DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
public:
~DeviceImpl() = default;
template <typename... Args>
static utils::Result<DeviceImpl *> create(Args &&...args);
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
};
} // namespace op::elementwise::kunlun
#define CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
auto device_impl_result = op::elementwise::kunlun::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
std::move(info), \
std::move(device_impl_result.take()), \
workspace_size, \
HANDLE->device, \
HANDLE->device_id);
#endif
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_XPU__
#define __INFINIOP_ELEMENTWISE_KUNLUN_XPU__
#include "../../devices/kunlun/kunlun_kernel_common.h"
#include "xpu/kernel/xtdk_io.h"
using namespace device::kunlun::kernel;
/**
 * @brief Computes the memory offset of each input operand's element
 */
struct InputIndexer {
size_t idx;
size_t ndim;
const bool *input_contiguous;
const bool *input_broadcasted;
const _size_t *input_shapes;
const ptrdiff_t *input_strides;
const ptrdiff_t *output_strides;
__device__ size_t operator()(size_t input_id) const {
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
? indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
}
};
/**
* @brief Computes the output index in memory, accounting for strides if non-contiguous.
*
* @param idx Linear index.
* @param is_contiguous Whether the output tensor is contiguous.
* @param ndim Number of dimensions.
* @param shape Shape of the output tensor.
* @param strides Strides of the output tensor.
* @return Memory offset index.
*/
inline __device__ size_t
getOutputIndex(size_t idx,
bool is_contiguous,
size_t ndim,
const _size_t *shape,
const ptrdiff_t *strides) {
return is_contiguous ? idx : indexToOffset(idx, ndim, shape, strides);
}
template <size_t N, typename Op, typename Tdata, typename... Args>
__device__ void launchOp(
__global_ptr__ Tdata **typed_inputs, // gm pointer
__global_ptr__ Tdata *output, // gm pointer output
Tdata *inputs_buf, // local mem buffer
size_t *input_indexes,
size_t output_index,
Args... args) {
static_assert(N == Op::num_inputs, "template N is not equal to Op::num_inputs!\n");
// Copy one element of each input to the local buffer
#pragma unroll
for (size_t i = 0; i < N; i++) {
auto gm = typed_inputs[i] + input_indexes[i];
auto lm = inputs_buf + i;
GM2LM_ASYNC(gm, lm, 1 * sizeof(Tdata));
}
mfence();
// Compute the elementwise op; inputs_buf holds one element per operand
Tdata out = Op{}(inputs_buf, args...);
// Copy out to gm
LM2GM_ASYNC(&out, output + output_index, 1 * sizeof(Tdata));
mfence();
}
template <size_t N, typename Op, typename Tdata, typename... Args>
__global__ void elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
const bool *input_contiguous_gm,
const bool *input_broadcasted_gm,
const _size_t *output_shape_gm,
const _size_t *input_shapes_gm,
const ptrdiff_t *output_strides_gm,
const ptrdiff_t *input_strides_gm,
Tdata *output,
const void *const *inputs,
Args... args) {
// Only elementwise ops with at most 3 inputs are supported
static_assert(N < 4, "elementwise kernel supports at most 3 inputs");
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
// Cast input gm pointer type
auto typed_inputs = reinterpret_cast<const __global_ptr__ Tdata *const __global_ptr__ *>(inputs);
const int BUFF_SIZE = 64;
// Input data cache
__local__ Tdata inputs_buf[N];
// Input contiguous/broadcasted flags
__local__ bool input_contiguous[N];
__local__ bool input_broadcasted[N];
// Input shape/strides
__local__ _size_t input_shapes[N * ndim];
__local__ ptrdiff_t input_strides[N * ndim];
// Output shape/strides
__local__ _size_t output_shape[ndim];
__local__ ptrdiff_t output_strides[ndim];
// Inputs gm ptr buf
__local__ __global_ptr__ Tdata *typed_inputs_ptr[N];
// Load from gm
GM2LM_ASYNC(input_contiguous_gm, input_contiguous, N * sizeof(bool));
GM2LM_ASYNC(input_broadcasted_gm, input_broadcasted, N * sizeof(bool));
GM2LM_ASYNC(input_shapes_gm, input_shapes, N * ndim * sizeof(_size_t));
GM2LM_ASYNC(input_strides_gm, input_strides, N * ndim * sizeof(ptrdiff_t));
GM2LM_ASYNC(output_shape_gm, output_shape, ndim * sizeof(_size_t));
GM2LM_ASYNC(output_strides_gm, output_strides, ndim * sizeof(ptrdiff_t));
GM2LM_ASYNC(typed_inputs, typed_inputs_ptr, N * sizeof(__global_ptr__ Tdata *));
mfence();
int len_per_loop = min(BUFF_SIZE, roundup_div(output_size, nthreads));
for (int start = thread_id * len_per_loop; start < output_size; start += nthreads * len_per_loop) {
size_t read_len = min(len_per_loop, output_size - start);
for (int idx = start; idx < start + read_len; ++idx) {
size_t out_idx = getOutputIndex(static_cast<size_t>(idx), output_contiguous,
ndim, output_shape, output_strides);
InputIndexer indexer{static_cast<size_t>(idx), ndim, input_contiguous, input_broadcasted,
input_shapes, input_strides, output_strides};
// Get index offset for every operand
size_t indexes[N];
for (size_t i = 0; i < N; i++) {
indexes[i] = indexer(i);
}
// Launch the operator
launchOp<N, Op, Tdata>(&typed_inputs_ptr[0], output, inputs_buf, indexes, out_idx, args...);
}
}
sync_cluster();
}
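The grid-stride loop above hands each thread chunks of len_per_loop consecutive elements. A quick host-side check of that arithmetic (hypothetical numbers, for the <<<8, 64>>> launch used below, i.e. nthreads = 512):

#include <algorithm>
#include <cstdio>

int main() {
    const long BUFF_SIZE = 64;
    const long output_size = 100000;
    const long nthreads = 8 * 64;                                    // cluster_num() * core_num()
    const long per_thread = (output_size + nthreads - 1) / nthreads; // roundup_div -> 196
    const long len_per_loop = std::min(BUFF_SIZE, per_thread);       // min(64, 196) = 64
    // Thread 0 visits [0, 64), [32768, 32832), [65536, 65600), [98304, 98368)
    for (long start = 0; start < output_size; start += nthreads * len_per_loop) {
        long read_len = std::min(len_per_loop, output_size - start);
        std::printf("[%ld, %ld)\n", start, start + read_len);
    }
    return 0;
}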
#define LAUNCH_ELEMENTWISE_KERNEL_IMPL(OpName, Op) \
template <typename Tdata, typename... Args> \
void launch##OpName##Kernel( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
Args... args) { \
elementwiseKernel<Op::num_inputs, Op, Tdata><<<8, 64, stream>>>( \
output_size, ndim, output_contiguous, \
reinterpret_cast<const bool *>(input_contiguous), \
reinterpret_cast<const bool *>(input_broadcasted), \
reinterpret_cast<const _size_t *>(output_shape), \
reinterpret_cast<const _size_t *>(input_shapes), \
reinterpret_cast<const ptrdiff_t *>(output_strides), \
reinterpret_cast<const ptrdiff_t *>(input_strides), \
reinterpret_cast<Tdata *>(output), inputs, args...); \
}
#define LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(OpName, T, ...) \
template void launch##OpName##Kernel<T, ##__VA_ARGS__>( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
##__VA_ARGS__);
#endif
@@ -62,7 +62,7 @@ infiniStatus_t calculate(
(kunlunStream_t)stream,
[&](xdnnHandle_t handle) {
for (size_t i = 0; i < info.batch; i++) {
-CHECK_XDNN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
+CHECK_KUNLUN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
handle,
(Tdata *)((char *)a + i * info.a_matrix.stride * unit),
(Tdata *)((char *)b + i * info.b_matrix.stride * unit),
...
#ifndef __RMS_NORM_KUNLUN_KERNEL_XPU__
#define __RMS_NORM_KUNLUN_KERNEL_XPU__
-#include "../../../devices/kunlun/kunlun_common.h"
+#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../reduce/kunlun/reduce_kunlun.h"
using namespace device::kunlun::kernel;
// Element wise mul used in x * w
static inline __device__ void elementwiseMulRms(float *x, float *w, float *y, int count, float rms) {
int remain = count % 16;
...
#include "swiglu_kunlun.h"
// Declare the elementwise kernel launch interface for this op
LAUNCH_ELEMENTWISE_KERNEL(SwiGLU)
namespace op::swiglu::kunlun {
typedef struct SwiGLUOp {
static constexpr size_t num_inputs = 2;
template <typename Tdata, typename... Args>
static infiniStatus_t launch(Args... args) {
launchSwiGLUKernel<Tdata>(args...);
return INFINI_STATUS_SUCCESS;
}
} SwiGLUOp;
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
auto dtype = out_desc->dtype();
const auto &up_desc = input_desc_vec.at(0);
const auto &gate_desc = input_desc_vec.at(1);
const auto &out_shape = out_desc->shape();
const auto &up_shape = up_desc->shape();
const auto &gate_shape = gate_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F32);
CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
// create KUNLUN elementwise descriptor
CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
switch (_dtype) {
case INFINI_DTYPE_F32:
return _device_info->calculate<SwiGLUOp, float>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::swiglu::kunlun
\ No newline at end of file
#ifndef __SWIGLU_KUNLUN_H__
#define __SWIGLU_KUNLUN_H__
#include "../../../elementwise/kunlun/elementwise_kunlun.h"
ELEMENTWISE_DESCRIPTOR(swiglu, kunlun)
#endif // __SWIGLU_KUNLUN_H__
#ifndef __SWIGLU_KUNLUN_KERNEL_XPU__
#define __SWIGLU_KUNLUN_KERNEL_XPU__
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../elementwise/kunlun/elementwise_kunlun_kernel.h"
/// @brief Defines the SwiGLU op evaluated on local-memory operands
typedef struct SwiGLUOp {
private:
template <typename T>
inline __device__ T sigmoid(T x) const {
return 1.0f / (1.0f + exp(-x));
}
public:
// Every op must define this static member; the kernel templates rely on it
static constexpr size_t num_inputs = 2;
template <typename T>
inline __device__ T operator()(const T *inputs) const {
T up = inputs[0];
T gate = inputs[1];
T out = gate * sigmoid(gate) * up;
return out;
}
} SwiGLUOp;
// Definition of the swiglu kernel launch interface
LAUNCH_ELEMENTWISE_KERNEL_IMPL(SwiGLU, SwiGLUOp)
// Template instantiate
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(SwiGLU, float)
#endif // __SWIGLU_KUNLUN_KERNEL_XPU__
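SwiGLUOp computes out = up * gate * sigmoid(gate), i.e. up scaled by SiLU(gate). A quick reference check with hypothetical values (not part of the commit):

#include <cassert>
#include <cmath>

// Host-side reference for the functor above.
static float swigluRef(float up, float gate) {
    float sig = 1.0f / (1.0f + std::exp(-gate));
    return gate * sig * up;
}

int main() {
    // sigmoid(1) ~= 0.7311, so swiglu(up = 2, gate = 1) ~= 1.4622
    assert(std::fabs(swigluRef(2.0f, 1.0f) - 1.4622f) < 1e-3f);
    return 0;
}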
@@ -8,6 +8,9 @@
#ifdef ENABLE_CUDA_API
#include "cuda/swiglu_cuda.cuh"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/swiglu_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateSwiGLUDescriptor(
infiniopHandle_t handle,
@@ -33,6 +36,9 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
#ifdef ENABLE_CUDA_API
CREATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangCreateSwiGLUDescriptor((BangHandle_t)handle,
@@ -80,6 +86,9 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
#ifdef ENABLE_CUDA_API
GET(INFINI_DEVICE_NVIDIA, cuda)
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun)
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangGetSwiGLUWorkspaceSize((SwiGLUBangDescriptor_t)desc, size);
@@ -127,6 +136,9 @@ __C infiniStatus_t infiniopSwiGLU(
#ifdef ENABLE_CUDA_API
CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangSwiGLU((SwiGLUBangDescriptor_t)desc, c, a, b, stream);
@@ -168,6 +180,9 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
#ifdef ENABLE_CUDA_API
DELETE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangDestroySwiGLUDescriptor((SwiGLUBangDescriptor_t)desc);
...
#ifndef __INFINIOP_REDUCE_KUNLUN_H__
#define __INFINIOP_REDUCE_KUNLUN_H__
-#include "../../devices/kunlun/kunlun_common.h"
+#include "../../devices/kunlun/kunlun_kernel_common.h"
namespace op::common_kunlun::reduce_op {
using namespace device::kunlun::kernel;
// Use the 16-float SIMD instruction to compute the reduction
// data_ptr points into local memory (LM)
static inline __device__ float sumSquaredF32(float *data_ptr, int count) {
...
@@ -83,6 +83,7 @@ def test(
    b_stride=None,
    c_stride=None,
    dtype=torch.float16,
    sync=None
):
    print(
        f"Testing Gemm on {torch_device} with alpha:{alpha}, beta:{beta},"
@@ -104,6 +105,9 @@ def test(
    ]
    a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]]
    if sync is not None:
        sync()
    descriptor = infiniopGemmDescriptor_t()
    check_error(
        lib.infiniopCreateGemmDescriptor(
...
@@ -423,6 +423,7 @@ def test_operator(lib, device, test_func, test_cases, tensor_dtypes):
                infiniDeviceEnum_str_map[device],
                *test_case,
                tensor_dtype,
                get_sync_func(device)
            )
        finally:
            destroy_handle(lib, handle)
@@ -471,3 +472,14 @@ def get_test_devices(args):
        devices_to_test = [InfiniDeviceEnum.CPU]
    return devices_to_test


def get_sync_func(device):
    import torch

    if device == InfiniDeviceEnum.CPU:
        sync = None
    else:
        sync = getattr(torch, infiniDeviceEnum_str_map[device]).synchronize
    return sync
@@ -27,17 +27,17 @@ _TEST_CASES = [
    # y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype
    ((1, 4), (1, 4), (4,), None, None, torch.float32),
    ((16, 2048), (16, 2048), (2048,), None, None, torch.float32),
-   ((16, 2048), (16, 2048), (2048,), None, None, torch.float16),
+   # ((16, 2048), (16, 2048), (2048,), None, None, torch.float16),
    ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float32),
-   ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float16),
+   # ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float16),
]
# x types used for testing
-_TENSOR_DTYPES = [torch.float16]
+_TENSOR_DTYPES = [torch.float32]
# Tolerance map for different data types
_TOLERANCE_MAP = {
-   torch.float16: {"atol": 1e-3, "rtol": 1e-3},
+   torch.float32: {"atol": 1e-3, "rtol": 1e-3},
}
DEBUG = False
@@ -72,6 +72,7 @@ def test(
    x_stride,
    w_dtype=torch.float16,
    dtype=torch.float16,
    sync=None
):
    print(
        f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
@@ -89,9 +90,11 @@ def test(
        rearrange_if_needed(tensor, stride)
        for tensor, stride in zip([x, y], [x_stride, y_stride])
    ]
    x_tensor, y_tensor, w_tensor = [to_tensor(tensor, lib) for tensor in [x, y, w]]
    if sync is not None:
        sync()
    descriptor = infiniopRMSNormDescriptor_t()
    check_error(
...