Unverified commit 3eb14921 authored by PanZezhong1725, committed by GitHub

Merge pull request #210 from InfiniTensor/issue-209-elementwise-kunlun

Issue/209 elementwise kunlun
parents 1a4cfb99 cda0ccba
@@ -16,7 +16,7 @@ typedef XPUStream kunlunStream_t;
typedef XPUEvent kunlunEvent_t;
typedef xdnn::Context *xdnnHandle_t;
-#define CHECK_XDNN(API) CHECK_INTERNAL(API, XPU_SUCCESS)
+#define CHECK_KUNLUN(API) CHECK_INTERNAL(API, XPU_SUCCESS)
namespace device::kunlun {
......
-#ifndef __INFINIOP_KUNLUN_COMMON_H__
-#define __INFINIOP_KUNLUN_COMMON_H__
+#ifndef __INFINIOP_KUNLUN_KERNEL_COMMON_H__
+#define __INFINIOP_KUNLUN_KERNEL_COMMON_H__
// This header file will only be included by .xpu files
#include "kunlun_kernel_dtype.h"
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/runtime.h"
namespace device::kunlun::kernel {
// Get mask for kunlun xpu 512-bit register calculation:
// if the data does not fill 512 bits, pad with zeros and use
// a mask to identify the real data
@@ -26,6 +28,37 @@ inline __device__ void atomicAddF32(__shared_ptr__ float *ptr, float value) {
}
}
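// Maps a flat output index to the offset of a broadcasted input:
// the index is decomposed with the output (broadcasted) strides and
// rebuilt with the input (target) strides.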
inline __device__ size_t indexToReducedOffset(
size_t flat_index,
size_t ndim,
const _ptrdiff_t *broadcasted_strides,
const _ptrdiff_t *target_strides) {
size_t res = 0;
for (size_t i = 0; i < ndim; ++i) {
res += flat_index / broadcasted_strides[i].value * target_strides[i].value;
flat_index %= broadcasted_strides[i].value;
mfence();
}
return res;
}
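// Maps a flat index to a memory offset for a tensor with the given
// shape and strides (innermost dimension last).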
inline __device__ size_t indexToOffset(
size_t flat_index,
size_t ndim,
const _size_t *shape,
const _ptrdiff_t *strides) {
size_t res = 0;
for (size_t i = ndim; i-- > 0;) {
res += (flat_index % shape[i].value) * strides[i].value;
flat_index /= shape[i].value;
mfence();
}
return res;
}
} // namespace device::kunlun::kernel
// TODO: atomicAddF16
// TODO: atomicAddI8
#endif
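To see what the two index helpers above compute, here is a minimal host-side sketch of the same arithmetic using plain size_t/ptrdiff_t instead of the padded device structs (the function name, shape, and strides are made up for illustration):

#include <cstddef>
#include <cstdio>

// Host analogue of indexToOffset: decompose flat_index over `shape`
// (innermost dimension last) and weight each coordinate by `strides`.
static size_t indexToOffsetHost(size_t flat_index, size_t ndim,
                                const size_t *shape, const ptrdiff_t *strides) {
    size_t res = 0;
    for (size_t i = ndim; i-- > 0;) {
        res += (flat_index % shape[i]) * strides[i];
        flat_index /= shape[i];
    }
    return res;
}

int main() {
    // A 2x3 tensor stored with a padded row stride of 4 elements.
    size_t shape[] = {2, 3};
    ptrdiff_t strides[] = {4, 1};
    // Flat index 4 -> coordinates (1, 1) -> offset 1*4 + 1*1 = 5.
    printf("%zu\n", indexToOffsetHost(4, 2, shape, strides));
    return 0;
}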
#ifndef __INFINIOP_KUNLUN_DTYPE_H__
#define __INFINIOP_KUNLUN_DTYPE_H__
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/runtime.h"
// Device-side view of a ptrdiff_t array copied from the host:
// each 64-bit host value occupies a 32-bit value plus 32 bits of padding.
typedef struct _ptrdiff_t {
long value; // 32 bit
long padding; // 32 bit
} _ptrdiff_t;
// Same layout as _ptrdiff_t, but for size_t
typedef struct _size_t {
size_t value;
size_t padding;
} _size_t;
#endif
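The padded structs above exist so that a host array of 64-bit size_t/ptrdiff_t values can be copied to the device byte-for-byte and read through a 32-bit view; the kernels only ever touch the .value half. A hedged host-side illustration (DevicePtrdiff is a hypothetical mirror type, assuming a little-endian 64-bit host and stride values that fit in 32 bits):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical host mirror of the device _ptrdiff_t above.
struct DevicePtrdiff {
    int32_t value;   // low 32 bits of the host ptrdiff_t
    int32_t padding; // high 32 bits, ignored by the kernels
};

int main() {
    ptrdiff_t host_strides[] = {4096, 64, 1};
    DevicePtrdiff dev[3];
    static_assert(sizeof(dev) == sizeof(host_strides), "layouts must match byte-for-byte");
    std::memcpy(dev, host_strides, sizeof(host_strides)); // what the raw copy amounts to
    for (const auto &d : dev) {
        printf("%d\n", d.value); // prints 4096, 64, 1
    }
    return 0;
}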
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_H__
#define __INFINIOP_ELEMENTWISE_KUNLUN_H__
#include "../../../utils.h"
#include "../../devices/kunlun/kunlun_handle.h"
#include "elementwise_kunlun_api.h"
namespace op::elementwise::kunlun {
struct DeviceImpl::Opaque {
std::shared_ptr<device::kunlun::Handle::Internal> internal;
Opaque(const std::shared_ptr<device::kunlun::Handle::Internal> &internal_)
: internal(internal_) {}
template <size_t N, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
kunlunStream_t stream,
Args &&...args) {
auto output_size = info.getOutputSize();
if (output_size == 0) {
return INFINI_STATUS_SUCCESS;
}
// Device pointers
const void **d_inputs_arr = nullptr;
const bool *d_input_contiguous = nullptr;
const bool *d_input_broadcasted = nullptr;
const size_t *d_output_shape = nullptr;
const ptrdiff_t *d_output_strides = nullptr;
const size_t *d_input_shapes = nullptr;
const ptrdiff_t *d_input_strides = nullptr;
CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr,
d_input_contiguous, d_input_broadcasted,
d_output_shape, d_output_strides,
d_input_shapes, d_input_strides));
Op::template launch<Tdata>(
output_size,
info.getNdim(),
info.isOutputContiguous(),
reinterpret_cast<const void *>(d_input_contiguous),
reinterpret_cast<const void *>(d_input_broadcasted),
reinterpret_cast<const void *>(d_output_shape),
reinterpret_cast<const void *>(d_input_shapes),
reinterpret_cast<const void *>(d_output_strides),
reinterpret_cast<const void *>(d_input_strides),
output,
reinterpret_cast<const void *const *>(d_inputs_arr),
stream,
args...);
return INFINI_STATUS_SUCCESS;
}
private:
template <size_t N>
infiniStatus_t infoToDevice(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
const void *const *h_inputs_arr,
const void **&d_inputs_arr,
const bool *&d_input_contiguous,
const bool *&d_input_broadcasted,
const size_t *&d_output_shape,
const ptrdiff_t *&d_output_strides,
const size_t *&d_input_shapes,
const ptrdiff_t *&d_input_strides) const {
constexpr auto input_size = N;
const auto ndim = info.getNdim();
constexpr auto input_arr_size = N * sizeof(*h_inputs_arr);
const int8_t *info_meta_start = info.getMetaStart();
const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
// copy the input pointer array and meta to device
CHECK_KUNLUN(xpu_memcpy(workspace, h_inputs_arr, input_arr_size, XPU_HOST_TO_DEVICE));
CHECK_KUNLUN(xpu_memcpy((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), XPU_HOST_TO_DEVICE));
// offset/assign the pointers
d_inputs_arr = reinterpret_cast<const void **>(workspace);
d_output_shape = reinterpret_cast<const size_t *>(d_meta_start);
d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape + ndim);
d_input_shapes = reinterpret_cast<const size_t *>(d_output_strides + ndim);
d_input_strides = reinterpret_cast<const ptrdiff_t *>(d_input_shapes + input_size * ndim);
d_input_contiguous = reinterpret_cast<const bool *>(d_input_strides + input_size * ndim);
d_input_broadcasted = reinterpret_cast<const bool *>(d_input_contiguous + input_size);
return INFINI_STATUS_SUCCESS;
}
};
template <typename... Args>
utils::Result<DeviceImpl *> DeviceImpl::create(Args &&...args) {
auto opaque = std::make_shared<Opaque>(std::forward<Args>(args)...);
return utils::Result<DeviceImpl *>(new DeviceImpl(opaque));
}
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
return _opaque->calculateImpl<N, Op, Tdata>(
info, workspace, output, inputs,
reinterpret_cast<kunlunStream_t>(stream),
std::forward<Args>(args)...);
}
} // namespace op::elementwise::kunlun
// Template for kunlun kernel interface declaration
#define LAUNCH_ELEMENTWISE_KERNEL(OpName) \
template <typename Tdata, typename... Args> \
void launch##OpName##Kernel( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
Args... args);
#endif
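For reference, the workspace that calculateImpl receives is carved up by infoToDevice in a fixed order; the sketch below is a reading of the pointer arithmetic above, and the size formula matches the one used at descriptor creation time:

// Workspace layout written by infoToDevice<N> (regions packed back to back):
//
//   [ N      x void*     ]  d_inputs_arr        -- device copy of the input pointer array
//   [ ndim   x size_t    ]  d_output_shape      --+
//   [ ndim   x ptrdiff_t ]  d_output_strides      |
//   [ N*ndim x size_t    ]  d_input_shapes        |  copied in one xpu_memcpy from
//   [ N*ndim x ptrdiff_t ]  d_input_strides       |  info.getMetaStart(), spanning
//   [ N      x bool      ]  d_input_contiguous    |  info.getMetaMemSize() bytes
//   [ N      x bool      ]  d_input_broadcasted --+
//
// workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *)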
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_API_H__
#define __INFINIOP_ELEMENTWISE_KUNLUN_API_H__
#include "../elementwise.h"
namespace op::elementwise::kunlun {
class DeviceImpl final {
struct Opaque;
std::shared_ptr<Opaque> _opaque;
DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
public:
~DeviceImpl() = default;
template <typename... Args>
static utils::Result<DeviceImpl *> create(Args &&...args);
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
};
} // namespace op::elementwise::kunlun
#define CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
auto device_impl_result = op::elementwise::kunlun::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
std::move(info), \
std::move(device_impl_result.take()), \
workspace_size, \
HANDLE->device, \
HANDLE->device_id);
#endif
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_XPU__
#define __INFINIOP_ELEMENTWISE_KUNLUN_XPU__
#include "../../devices/kunlun/kunlun_kernel_common.h"
using namespace device::kunlun::kernel;
/**
 * @brief Computes the memory offset of each input operand for a given flat output index
*/
struct InputIndexer {
size_t idx;
size_t ndim;
const bool *input_contiguous;
const bool *input_broadcasted;
const _size_t *input_shapes;
const _ptrdiff_t *input_strides;
const _ptrdiff_t *output_strides;
__device__ size_t operator()(size_t input_id) const {
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
? indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
}
};
/**
* @brief Computes the output index in memory, accounting for strides if non-contiguous.
*
* @param idx Linear index.
* @param is_contiguous Whether the output tensor is contiguous.
* @param ndim Number of dimensions.
* @param shape Shape of the output tensor.
* @param strides Strides of the output tensor.
* @return Memory offset index.
*/
inline __device__ size_t
getOutputIndex(size_t idx,
bool is_contiguous,
size_t ndim,
const _size_t *shape,
const _ptrdiff_t *strides) {
return is_contiguous ? idx : indexToOffset(idx, ndim, shape, strides);
}
template <size_t N, typename Op, typename Tdata, typename... Args>
__device__ void launchOp(
__global_ptr__ Tdata **typed_inputs, // gm pointer
__global_ptr__ Tdata *output, // gm pointer output
Tdata *inputs_buf, // local mem buffer
size_t *input_indexes,
size_t output_index,
Args... args) {
static_assert(N == Op::num_inputs, "template N is not equal to Op::num_inputs!\n");
// Copy inputs to the local-memory buffer
#pragma unroll
for (size_t i = 0; i < N; i++) {
auto gm = typed_inputs[i] + input_indexes[i];
auto lm = inputs_buf + i;
GM2LM_ASYNC(gm, lm, 1 * sizeof(Tdata));
}
mfence();
// Compute the elementwise op; inputs_buf holds all the operands
Tdata out = Op{}(inputs_buf, args...);
// Copy out to gm
LM2GM_ASYNC(&out, output + output_index, 1 * sizeof(Tdata));
mfence();
}
template <size_t N, typename Op, typename Tdata, typename... Args>
__global__ void elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
const bool *input_contiguous_gm,
const bool *input_broadcasted_gm,
const _size_t *output_shape_gm,
const _size_t *input_shapes_gm,
const _ptrdiff_t *output_strides_gm,
const _ptrdiff_t *input_strides_gm,
Tdata *output,
const void *const *inputs,
Args... args) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
// Cast input gm pointer type
auto typed_inputs = reinterpret_cast<const __global_ptr__ Tdata *const __global_ptr__ *>(inputs);
const int BUFF_SIZE = 64;
// Input data cache
__local__ Tdata inputs_buf[N];
// Input contiguous/broadcasted flags
__local__ bool input_contiguous[N];
__local__ bool input_broadcasted[N];
// Input shape/strides
__local__ _size_t input_shapes[N * ndim];
__local__ _ptrdiff_t input_strides[N * ndim];
// Output shape/strides
__local__ _size_t output_shape[ndim];
__local__ _ptrdiff_t output_strides[ndim];
// Inputs gm ptr buf
__local__ __global_ptr__ Tdata *typed_inputs_ptr[N];
// Load from gm
GM2LM_ASYNC(input_contiguous_gm, input_contiguous, N * sizeof(bool));
GM2LM_ASYNC(input_broadcasted_gm, input_broadcasted, N * sizeof(bool));
GM2LM_ASYNC(input_shapes_gm, input_shapes, N * ndim * sizeof(_size_t));
GM2LM_ASYNC(input_strides_gm, input_strides, N * ndim * sizeof(_ptrdiff_t));
GM2LM_ASYNC(output_shape_gm, output_shape, ndim * sizeof(_size_t));
GM2LM_ASYNC(output_strides_gm, output_strides, ndim * sizeof(_ptrdiff_t));
GM2LM_ASYNC(typed_inputs, typed_inputs_ptr, N * sizeof(__global_ptr__ Tdata *));
mfence();
int len_per_loop = min(BUFF_SIZE, roundup_div(output_size, nthreads));
for (int start = thread_id * len_per_loop; start < output_size; start += nthreads * len_per_loop) {
size_t read_len = min(len_per_loop, output_size - start);
for (int idx = start; idx < start + read_len; ++idx) {
size_t out_idx = getOutputIndex(static_cast<size_t>(idx), output_contiguous,
ndim, output_shape, output_strides);
InputIndexer indexer{static_cast<size_t>(idx), ndim, input_contiguous, input_broadcasted,
input_shapes, input_strides, output_strides};
// Get index offset for every operand
size_t indexes[N];
for (size_t i = 0; i < N; i++) {
indexes[i] = indexer(i);
}
// Launch the operator
launchOp<N, Op, Tdata>(&typed_inputs_ptr[0], output, inputs_buf, indexes, out_idx, args...);
}
}
sync_cluster();
}
#define LAUNCH_ELEMENTWISE_KERNEL_IMPL(OpName, Op) \
template <typename Tdata, typename... Args> \
void launch##OpName##Kernel( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
Args... args) { \
elementwiseKernel<Op::num_inputs, Op, Tdata><<<8, 64, stream>>>( \
output_size, ndim, output_contiguous, \
reinterpret_cast<const bool *>(input_contiguous), \
reinterpret_cast<const bool *>(input_broadcasted), \
reinterpret_cast<const _size_t *>(output_shape), \
reinterpret_cast<const _size_t *>(input_shapes), \
reinterpret_cast<const _ptrdiff_t *>(output_strides), \
reinterpret_cast<const _ptrdiff_t *>(input_strides), \
reinterpret_cast<Tdata *>(output), inputs, args...); \
}
#define LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(OpName, T, ...) \
template void launch##OpName##Kernel<T, ##__VA_ARGS__>( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
##__VA_ARGS__);
#endif
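The launch in LAUNCH_ELEMENTWISE_KERNEL_IMPL uses a fixed grid of 8 clusters x 64 cores, and each core walks the output in chunks of at most BUFF_SIZE elements. A small host-side sketch of that partitioning (the sample size and output are illustrative only):

#include <algorithm>
#include <cstdio>

int main() {
    const int nthreads = 8 * 64;    // cluster_num() * core_num() for the <<<8, 64, stream>>> launch
    const int BUFF_SIZE = 64;
    const int output_size = 100000; // example element count
    // Same formula as the kernel: min(BUFF_SIZE, roundup_div(output_size, nthreads)).
    int len_per_loop = std::min(BUFF_SIZE, (output_size + nthreads - 1) / nthreads);
    // Thread 0 visits [0, len), [nthreads*len, nthreads*len + len), ...
    for (int start = 0; start < output_size; start += nthreads * len_per_loop) {
        int read_len = std::min(len_per_loop, output_size - start);
        printf("thread 0 processes [%d, %d)\n", start, start + read_len);
    }
    return 0;
}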
@@ -62,7 +62,7 @@ infiniStatus_t calculate(
(kunlunStream_t)stream,
[&](xdnnHandle_t handle) {
for (size_t i = 0; i < info.batch; i++) {
-CHECK_XDNN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
+CHECK_KUNLUN((xdnn::fc_fusion<Tdata, Tdata, Tdata, int16_t>(
handle,
(Tdata *)((char *)a + i * info.a_matrix.stride * unit),
(Tdata *)((char *)b + i * info.b_matrix.stride * unit),
......
#ifndef __RMS_NORM_KUNLUN_KERNEL_XPU__
#define __RMS_NORM_KUNLUN_KERNEL_XPU__
-#include "../../../devices/kunlun/kunlun_common.h"
+#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../reduce/kunlun/reduce_kunlun.h"
using namespace device::kunlun::kernel;
// Element-wise multiply used for x * w
static inline __device__ void elementwiseMulRms(float *x, float *w, float *y, int count, float rms) {
int remain = count % 16;
......
#include "swiglu_kunlun.h"
// Declare the kernel launch interface for this op
LAUNCH_ELEMENTWISE_KERNEL(SwiGLU)
namespace op::swiglu::kunlun {
typedef struct SwiGLUOp {
static constexpr size_t num_inputs = 2;
template <typename Tdata, typename... Args>
static infiniStatus_t launch(Args... args) {
launchSwiGLUKernel<Tdata>(args...);
return INFINI_STATUS_SUCCESS;
}
} SwiGLUOp;
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::kunlun::Handle *>(handle_);
auto dtype = out_desc->dtype();
const auto &up_desc = input_desc_vec.at(0);
const auto &gate_desc = input_desc_vec.at(1);
const auto &out_shape = out_desc->shape();
const auto &up_shape = up_desc->shape();
const auto &gate_shape = gate_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F32);
CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
// create KUNLUN elementwise descriptor
CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
switch (_dtype) {
case INFINI_DTYPE_F32:
return _device_info->calculate<SwiGLUOp, float>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::swiglu::kunlun
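Note that the SwiGLUOp in this translation unit only forwards its arguments to launchSwiGLUKernel; the arithmetic itself lives in the device-side SwiGLUOp defined in the kernel header further below.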
#ifndef __SWIGLU_KUNLUN_H__
#define __SWIGLU_KUNLUN_H__
#include "../../../elementwise/kunlun/elementwise_kunlun.h"
ELEMENTWISE_DESCRIPTOR(swiglu, kunlun)
#endif // __SWIGLU_KUNLUN_H__
#ifndef __SWIGLU_KUNLUN_H__
#define __SWIGLU_KUNLUN_H__
#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../elementwise/kunlun/elementwise_kunlun_kernel.h"
/// @brief SwiGLU op applied to operands staged in local memory
typedef struct SwiGLUOp {
private:
template <typename T>
inline __device__ T sigmoid(T x) const {
return 1.0f / (1.0f + exp(-x));
}
public:
// Every elementwise op must define num_inputs
static constexpr size_t num_inputs = 2;
template <typename T>
inline __device__ T operator()(const T *inputs) const {
T up = inputs[0];
T gate = inputs[1];
T out = gate * sigmoid(gate) * up;
return out;
}
} SwiGLUOp;
// Definition for swiglu kernel interface
LAUNCH_ELEMENTWISE_KERNEL_IMPL(SwiGLU, SwiGLUOp)
// Explicit template instantiation
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(SwiGLU, float)
#endif // __SWIGLU_KUNLUN_H__
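As a sanity check of the operator above (out = gate * sigmoid(gate) * up, i.e. SiLU(gate) scaled by up), a minimal host-side reference; the function name and sample values are illustrative:

#include <cmath>
#include <cstdio>

// Host reference for the device SwiGLUOp above.
static float swigluRef(float up, float gate) {
    float sig = 1.0f / (1.0f + std::exp(-gate));
    return gate * sig * up;
}

int main() {
    // inputs[0] is "up" and inputs[1] is "gate", matching operator() above.
    printf("%f\n", swigluRef(2.0f, 1.0f)); // 1 * sigmoid(1) * 2 ~= 1.462117
    return 0;
}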
@@ -8,6 +8,9 @@
#ifdef ENABLE_CUDA_API
#include "cuda/swiglu_cuda.cuh"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/swiglu_kunlun.h"
#endif
__C infiniStatus_t infiniopCreateSwiGLUDescriptor(
infiniopHandle_t handle,
@@ -33,6 +36,9 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
#ifdef ENABLE_CUDA_API
CREATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangCreateSwiGLUDescriptor((BangHandle_t)handle,
@@ -80,6 +86,9 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des
#ifdef ENABLE_CUDA_API
GET(INFINI_DEVICE_NVIDIA, cuda)
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun)
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangGetSwiGLUWorkspaceSize((SwiGLUBangDescriptor_t)desc, size);
@@ -127,6 +136,9 @@ __C infiniStatus_t infiniopSwiGLU(
#ifdef ENABLE_CUDA_API
CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangSwiGLU((SwiGLUBangDescriptor_t)desc, c, a, b, stream);
@@ -168,6 +180,9 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
#ifdef ENABLE_CUDA_API
DELETE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangDestroySwiGLUDescriptor((SwiGLUBangDescriptor_t)desc);
......
#ifndef __INFINIOP_REDUCE_KUNLUN_H__
#define __INFINIOP_REDUCE_KUNLUN_H__
-#include "../../devices/kunlun/kunlun_common.h"
+#include "../../devices/kunlun/kunlun_kernel_common.h"
namespace op::common_kunlun::reduce_op {
using namespace device::kunlun::kernel;
// Use the 16-float SIMD instruction to compute the reduction
// data_ptr points to LM (local memory)
static inline __device__ float sumSquaredF32(float *data_ptr, int count) {
......
@@ -4,6 +4,7 @@
#include "bang/infinirt_bang.h"
#include "cpu/infinirt_cpu.h"
#include "cuda/infinirt_cuda.cuh"
#include "kunlun/infinirt_kunlun.h"
#include "maca/infinirt_maca.h"
#include "musa/infinirt_musa.h"
@@ -66,8 +67,11 @@ __C infiniStatus_t infinirtGetDevice(infiniDevice_t *device_ptr, int *device_id_
case INFINI_DEVICE_MOORE: \
_status = infinirt::musa::API PARAMS; \
break; \
case INFINI_DEVICE_KUNLUN: \
_status = infinirt::kunlun::API PARAMS; \
break; \
default: \
-return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
+_status = INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
} \
{ ACTION; } \
return _status; \
......
@@ -101,6 +101,7 @@ def test(
    v_stride=None,
    k_cache_stride=None,
    v_cache_stride=None,
    sync=None
):
    print(
        f"Testing Attention on {torch_device} with n_q_head:{n_q_head} n_kv_head:{n_kv_head} seq_len:{seq_len} head_dim:{head_dim} pos:{pos} "
@@ -140,6 +141,9 @@
    k_cache_tensor = to_tensor(k_cache, lib)
    v_cache_tensor = to_tensor(v_cache, lib)
    if sync is not None:
        sync()
    descriptor = infiniopAttentionDescriptor_t()
    check_error(
        lib.infiniopCreateAttentionDescriptor(
......
@@ -88,6 +88,7 @@ def test(
    padding,
    strides,
    tensor_dtype=torch.float16,
    sync=None
):
    print(
        f"Testing AvgPool on {torch_device} with x_shape:{x_shape} kernel_shape:{k_shape} padding:{padding} strides:{strides} dtype:{tensor_dtype}"
@@ -109,6 +110,10 @@
    x_tensor = to_tensor(x, lib)
    y_tensor = to_tensor(y, lib)
    if sync is not None:
        sync()
    descriptor = infiniopAvgPoolDescriptor_t()
    check_error(
......
@@ -87,6 +87,7 @@ def test(
    y_stride=None,
    inplace=Inplace.OUT_OF_PLACE,
    dtype=torch.float16,
    sync=None
):
    print(
        f"Testing CausalSoftmax on {torch_device} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{dtype} inplace:{inplace}"
@@ -108,6 +109,9 @@
    y = rearrange_if_needed(y, y_stride)
    y_tensor = to_tensor(y, lib)
    if sync is not None:
        sync()
    descriptor = infiniopCausalSoftmaxDescriptor_t()
    check_error(
        lib.infiniopCreateCausalSoftmaxDescriptor(
......
@@ -95,6 +95,7 @@ def test(
    dilations,
    tensor_stride=None,
    tensor_dtype=torch.float16,
    sync=None
):
    assert len(pads) == len(strides) == len(dilations)
    print(
@@ -118,8 +119,11 @@
    x_tensor = to_tensor(x, lib)
    w_tensor = to_tensor(w, lib)
    y_tensor = to_tensor(y, lib)
    if sync is not None:
        sync()
    descriptor = infiniopConvDescriptor_t()
    check_error(
        lib.infiniopCreateConvDescriptor(
            handle,
......
@@ -52,6 +52,7 @@ def test(
    y_stride=None,
    x_stride=None,
    tensor_dtype=torch.float16,
    sync=None
):
    print(
        f"Testing Expand on {torch_device} with x_shape:{x_shape} y_shape:{y_shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{tensor_dtype}"
@@ -76,8 +77,11 @@
    x_tensor = to_tensor(x, lib)
    y_tensor = to_tensor(y, lib)
    if sync is not None:
        sync()
    descriptor = infiniopExpandDescriptor_t()
    check_error(
        lib.infiniopCreateExpandDescriptor(
            handle,
......
@@ -83,6 +83,7 @@ def test(
    b_stride=None,
    c_stride=None,
    dtype=torch.float16,
    sync=None
):
    print(
        f"Testing Gemm on {torch_device} with alpha:{alpha}, beta:{beta},"
@@ -104,6 +105,9 @@
    ]
    a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]]
    if sync is not None:
        sync()
    descriptor = infiniopGemmDescriptor_t()
    check_error(
        lib.infiniopCreateGemmDescriptor(
......