Commit f7e7c7ba authored by zhangyue's avatar zhangyue

Support compiling handwritten kernels on p800; refactor the elementwise operator component

parent 7d3ca92d
#include "kunlun_common.h"
#include "../../../utils.h"
#include <functional>
namespace device::kunlun {
infiniStatus_t Handle::Internal::useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const {
    auto handle = dnn_handles.pop();
    if (!handle) {
        // Pool was empty: create a fresh xdnn context for this call
        handle = xdnn::create_context();
    }
    (*handle)->set_stream(stream);
    CHECK_STATUS(f(*handle));
    dnn_handles.push(std::move(*handle));
    return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
    *handle_ptr = new Handle(device_id);
    return INFINI_STATUS_SUCCESS;
}
} // namespace device::kunlun
\ No newline at end of file
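A minimal caller sketch for the pooled-handle pattern above (illustrative only, not part of this commit: `addF32` is a hypothetical helper and it assumes an xdnn op with the usual `(ctx, x, y, z, len)` shape; only the `useXdnn` plumbing is taken from the code above):

infiniStatus_t addF32(const device::kunlun::Handle *handle,
                      kunlunStream_t stream,
                      const float *x, const float *y, float *z, int64_t len) {
    // useXdnn pops a context from the pool (creating one if the pool is
    // empty), binds it to `stream`, runs the callback, then returns it.
    return handle->internal()->useXdnn(stream, [&](xdnnHandle_t ctx) {
        CHECK_KUNLUN(xdnn::add<float>(ctx, x, y, z, len));
        return INFINI_STATUS_SUCCESS;
    });
}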
#include "../pool.h"
#include "kunlun_handle.h"
#include <xpu/runtime.h>
#include <xpu/runtime_ex.h>
#include <xpu/xdnn.h>
namespace xdnn = baidu::xpu::api;
typedef XPUStream kunlunStream_t;
typedef XPUEvent kunlunEvent_t;
typedef xdnn::Context *xdnnHandle_t;
#define CHECK_KUNLUN(API) CHECK_INTERNAL(API, XPU_SUCCESS)
namespace device::kunlun {
class Handle::Internal {
Pool<xdnnHandle_t> dnn_handles;
template <typename T>
using Fn = std::function<infiniStatus_t(T)>;
public:
infiniStatus_t useXdnn(kunlunStream_t stream, const Fn<xdnnHandle_t> &f) const;
};
} // namespace device::kunlun
#include "kunlun_handle.h" #include "kunlun_common.h"
namespace device::kunlun { namespace device::kunlun {
...@@ -10,20 +10,4 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & { ...@@ -10,20 +10,4 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
return _internal; return _internal;
} }
} // namespace device::kunlun
#ifndef __INFINIOP_KUNLUN_HANDLE_H__
#define __INFINIOP_KUNLUN_HANDLE_H__

#include "../../handle.h"
#include <memory>

namespace device::kunlun {

@@ -33,15 +19,6 @@ public:
    static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id);
};

} // namespace device::kunlun

#endif // __INFINIOP_KUNLUN_HANDLE_H__
@@ -2,13 +2,25 @@
#define __INFINIOP_KUNLUN_KERNEL_COMMON_H__

// This header file will only be included by .xpu files
#include "xpu/runtime.h"
#include <xpu/kernel/xtdk.h>
#include <xpu/kernel/xtdk_io.h>
#include <xpu/kernel/xtdk_math.h>
#include <xpu/kernel/xtdk_simd.h>

namespace device::kunlun::kernel {
// Kunlun device code is 32-bit, so a 64-bit host ptrdiff_t arrives as a
// 32-bit value plus 32 bits of padding.
typedef struct _ptrdiff_t {
    ptrdiff_t value;   // 32 bit
    ptrdiff_t padding; // 32 bit
} _ptrdiff_t;

// same layout as _ptrdiff_t
typedef struct _size_t {
    size_t value;
    size_t padding;
} _size_t;
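// Layout-check sketch (an assumption made explicit: value + padding must
// together span one 64-bit host field; these asserts are not in this commit):
static_assert(sizeof(_size_t) == 8, "_size_t must mirror a 64-bit host size_t");
static_assert(sizeof(_ptrdiff_t) == 8, "_ptrdiff_t must mirror a 64-bit host ptrdiff_t");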
// Get mask for kunlun xpu 512bit register calculation
// If the data does not fill 512 bits, pad it with zeros and use the
// mask to identify the real data
@@ -28,37 +40,50 @@ inline __device__ void atomicAddF32(__shared_ptr__ float *ptr, float value) {
    }
}
/**
 * @brief Get the offset into a broadcasted input from a flat output index
 * flat_index: flattened index of the output tensor
 * ndim: number of dims of the output tensor
 * broadcasted_strides: strides of the output tensor
 * target_strides: strides of the input tensor
 */
inline __device__ int indexToReducedOffset(
    int flat_index,                        // output flatten index
    int ndim,                              // output dims
    const _ptrdiff_t *broadcasted_strides, // output strides
    const _ptrdiff_t *target_strides) {    // strides of inputs
    int res = 0;
    for (int i = 0; i < ndim; ++i) {
        res += flat_index / broadcasted_strides[i].value * target_strides[i].value;
        flat_index %= broadcasted_strides[i].value;
        mfence();
    }
    return res;
}
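// Worked example (illustrative, not from this commit): output shape (2, 3)
// with strides {3, 1}; an input broadcast along dim 0 has strides {0, 1}.
// flat_index 4 -> dim 0: 4 / 3 = 1 contributes 1 * 0; remainder 4 % 3 = 1
// -> dim 1 contributes 1 * 1. Offset = 1, i.e. element (0, 1) of the input.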
/**
 * @brief Get the real memory offset of a flat input index
 * flat_index: flattened input index
 * ndim: number of dims of the input tensor
 * shape: shape of the input tensor
 * strides: strides of the input tensor
 */
inline __device__ int indexToOffset(
    int flat_index,
    int ndim,
    const _size_t *shape,
    const _ptrdiff_t *strides) {
    int res = 0;
    for (int i = ndim; i-- > 0;) {
        res += (flat_index % shape[i].value) * strides[i].value;
        flat_index /= shape[i].value;
        mfence();
    }
    return res;
}
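// Worked example (illustrative, not from this commit): shape (2, 3) and
// flat_index 4 -> innermost dim: 4 % 3 = 1, then 4 / 3 = 1 for dim 0, so
// the coordinates are (1, 1). Row-major strides {3, 1} give offset
// 1*3 + 1*1 = 4; transposed strides {1, 2} give offset 1*1 + 1*2 = 3.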
// TODO: atomicAddF16
// TODO: atomicAddI8

} // namespace device::kunlun::kernel

#endif
#ifndef __INFINIOP_KUNLUN_DTYPE_H__
#define __INFINIOP_KUNLUN_DTYPE_H__

#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/runtime.h"

// kunlun ptrdiff_t* is used to hold a ptrdiff_t array copied from the host
typedef struct _ptrdiff_t {
    long value;   // 32 bit
    long padding; // 32 bit
} _ptrdiff_t;

// same layout as _ptrdiff_t
typedef struct _size_t {
    size_t value;
    size_t padding;
} _size_t;

#endif
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_XPU__
#define __INFINIOP_ELEMENTWISE_KUNLUN_XPU__

#include "../../../utils.h"
#include "../../devices/kunlun/kunlun_common.h"
#include "../../devices/kunlun/kunlun_kernel_common.h"
#include "elementwise_kunlun_api.h"

namespace op::elementwise::kunlun {
using namespace device::kunlun::kernel;

template <typename T>
__device__ const T *typedInputPtr(const void *ptr) {
    return reinterpret_cast<const T *>(ptr);
}

/**
 * @brief Computes the input tile offset
 */
struct InputIndexer {
    int idx;
    int ndim;
    const bool *input_contiguous;
    const bool *input_broadcasted;
    const _size_t *input_shapes;
    const _ptrdiff_t *input_strides;
    const _ptrdiff_t *output_strides;

    inline __device__ int operator()(int input_id) const {
        return input_contiguous[input_id]
                 ? idx
                 : (input_broadcasted[input_id]
                        ? indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
                        : indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
    }
};

/**
 * @brief Computes the output index in memory, accounting for strides if non-contiguous.
 *
 * @param idx Linear index.
 * @param is_contiguous Whether the output tensor is contiguous.
 * @param ndim Number of dimensions.
 * @param shape Shape of the output tensor.
 * @param strides Strides of the output tensor.
 * @return Memory offset index.
 */
inline __device__ int
getOutputIndex(int idx,
               bool is_contiguous,
               int ndim,
               const _size_t *shape,
               const _ptrdiff_t *strides) {
    return is_contiguous ? idx : indexToOffset(idx, ndim, shape, strides);
}

/**
 * @brief Applies Op to one output element, gathering its operands first
 */
template <int N, typename Op, typename Tdata, typename... Args>
__device__ void launchOp(
    __global_ptr__ Tdata **typed_inputs, // gm pointers
    __global_ptr__ Tdata *output,        // gm pointer of output
    Tdata *inputs_buf,                   // local mem buffer
    int *input_indexes,
    int output_index,
    Args... args) {
    static_assert(N == Op::num_inputs, "template N is not equal to Op::num_inputs!\n");
    // Copy inputs to local buffer
#pragma unroll
    for (int i = 0; i < N; i++) {
        auto gm = typed_inputs[i] + input_indexes[i];
        auto lm = inputs_buf + i;
        GM2LM_ASYNC(gm, lm, 1 * sizeof(Tdata));
    }
    mfence();
    // Calculate elementwise; inputs_buf holds all operands
    Tdata out = Op{}(inputs_buf, args...);
    // Copy the result back to gm
    LM2GM_ASYNC(&out, output + output_index, 1 * sizeof(Tdata));
    mfence();
}

template <int N, typename Op, typename Tdata, typename... Args>
__global__ void elementwiseKernel(
    int output_size,
    int ndim,
    bool output_contiguous,
    const bool *input_contiguous_gm,
    const bool *input_broadcasted_gm,
    const void *output_shape_gm,
    const void *input_shapes_gm,
    const void *output_strides_gm,
    const void *input_strides_gm,
    Tdata *output,
    const void *const *inputs,
    Args... args) {
    int cid = core_id();
    int ncores = core_num();
    if (cid >= ncores) {
        return;
    }
    int thread_id = ncores * cluster_id() + cid;
    int nthreads = ncores * cluster_num();

    // Cast input gm pointer type
    auto typed_inputs = reinterpret_cast<const __global_ptr__ Tdata *const __global_ptr__ *>(inputs);

    const int BUFF_SIZE = 64;
    // Input data cache
    __local__ Tdata inputs_buf[N];
    // Input contiguous/broadcasted flags
    __local__ bool input_contiguous[N];
    __local__ bool input_broadcasted[N];
    // Input shapes/strides
    __local__ _size_t input_shapes[N * ndim];
    __local__ _ptrdiff_t input_strides[N * ndim];
    // Output shape/strides
    __local__ _size_t output_shape[ndim];
    __local__ _ptrdiff_t output_strides[ndim];
    // Input gm pointer buffer
    __local__ __global_ptr__ Tdata *typed_inputs_ptr[N];

    // Load metadata from gm
    GM2LM_ASYNC(input_contiguous_gm, input_contiguous, N * sizeof(bool));
    GM2LM_ASYNC(input_broadcasted_gm, input_broadcasted, N * sizeof(bool));
    GM2LM_ASYNC(input_shapes_gm, input_shapes, N * ndim * sizeof(_size_t));
    GM2LM_ASYNC(input_strides_gm, input_strides, N * ndim * sizeof(_ptrdiff_t));
    GM2LM_ASYNC(output_shape_gm, output_shape, ndim * sizeof(_size_t));
    GM2LM_ASYNC(output_strides_gm, output_strides, ndim * sizeof(_ptrdiff_t));
    GM2LM_ASYNC(typed_inputs, typed_inputs_ptr, N * sizeof(__global_ptr__ Tdata *));
    mfence();

    int len_per_loop = min(BUFF_SIZE, roundup_div(output_size, nthreads));
    for (int start = thread_id * len_per_loop; start < output_size; start += nthreads * len_per_loop) {
        int read_len = min(len_per_loop, output_size - start);
        for (int idx = start; idx < start + read_len; ++idx) {
            int out_idx = getOutputIndex(idx, output_contiguous,
                                         ndim, output_shape, output_strides);
            InputIndexer indexer{idx, ndim, input_contiguous, input_broadcasted,
                                 input_shapes, input_strides, output_strides};
            // Get index offset for every operand
            int indexes[N];
            for (int i = 0; i < N; i++) {
                indexes[i] = indexer(i);
            }
            // Launch operator
            launchOp<N, Op, Tdata>(&typed_inputs_ptr[0], output, inputs_buf, indexes, out_idx, args...);
        }
    }
    sync_cluster();
}
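// Work-partition example (illustrative numbers, not from this commit):
// output_size = 1000 with 8 clusters x 64 cores gives nthreads = 512;
// roundup_div(1000, 512) = 2, so len_per_loop = min(64, 2) = 2 and thread t
// processes [2t, 2t + 2) before striding by nthreads * len_per_loop = 1024.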
struct DeviceImpl::Opaque {
    std::shared_ptr<device::kunlun::Handle::Internal> internal;

    Opaque(const std::shared_ptr<device::kunlun::Handle::Internal> &internal_)
        : internal(internal_) {}

    template <uint32_t BLOCK_SIZE, int N, typename Op, typename Tdata, typename... Args>
    infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
                                 void *workspace,
                                 void *output,
                                 const std::vector<const void *> &inputs,
                                 kunlunStream_t stream,
                                 Args &&...args) {
        return launchElementwiseKernel<BLOCK_SIZE, N>(
            info,
            workspace,
            reinterpret_cast<Tdata *>(output),
            inputs,
            elementwiseKernel<N, Op, Tdata, Args...>,
            stream,
            std::forward<Args>(args)...);
    }
private:
    /**
     * @brief Transfers elementwise operation metadata and input pointers from host to device memory.
     *
     * @tparam N Number of input tensors.
     *
     * @param info Elementwise operation metadata (shapes, strides, flags, etc.).
     * @param workspace Pointer to device workspace memory for storing metadata and input pointers.
     * @param h_inputs_arr Host array of input tensor pointers.
     * @param d_inputs_arr Input reference to device array of input tensor pointers.
     * @param d_input_contiguous Input reference to device array indicating whether each input is contiguous.
     * @param d_input_broadcasted Input reference to device array indicating whether each input is broadcasted.
     * @param d_output_shape Output reference to device array holding the output tensor shape.
     * @param d_output_strides Output reference to device array holding output tensor strides.
     * @param d_input_shapes Output reference to flattened input tensor shapes (N * ndim).
     * @param d_input_strides Output reference to flattened input tensor strides (N * ndim).
     * @param stream KUNLUN stream used for asynchronous memory transfer.
     * @return infiniStatus_t Status indicating success or failure of the memory transfer and setup.
     */
    template <int N>
    infiniStatus_t infoToDevice(
        const op::elementwise::ElementwiseInfo &info,
        void *workspace,
        const void *const *h_inputs_arr,
        __global_ptr__ const void **&d_inputs_arr,
        __global_ptr__ const bool *&d_input_contiguous,
        __global_ptr__ const bool *&d_input_broadcasted,
        __global_ptr__ const size_t *&d_output_shape,
        __global_ptr__ const ptrdiff_t *&d_output_strides,
        __global_ptr__ const size_t *&d_input_shapes,
        __global_ptr__ const ptrdiff_t *&d_input_strides,
        kunlunStream_t stream) const {

        constexpr auto input_size = N;
        const auto ndim = info.getNdim();
        constexpr auto input_arr_size = N * sizeof(*h_inputs_arr);
        auto info_meta_start = info.getMetaStart(); // host meta pointer
        auto d_meta_start = reinterpret_cast<__global_ptr__ int8_t *>(workspace)
                          + input_arr_size; // device meta pointer

        // Copy the input pointer array and meta to device
        CHECK_KUNLUN(xpu_memcpy_async(workspace, h_inputs_arr, input_arr_size, XPU_HOST_TO_DEVICE, stream));
        CHECK_KUNLUN(xpu_memcpy_async((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), XPU_HOST_TO_DEVICE, stream));
        xpu_wait(stream);

        // Offset/assign the pointers
        d_inputs_arr = reinterpret_cast<__global_ptr__ const void **>(workspace);
        d_output_shape = reinterpret_cast<__global_ptr__ const size_t *>(d_meta_start);
        d_output_strides = reinterpret_cast<__global_ptr__ const ptrdiff_t *>(d_output_shape + ndim);
        d_input_shapes = reinterpret_cast<__global_ptr__ const size_t *>(d_output_strides + ndim);
        d_input_strides = reinterpret_cast<__global_ptr__ const ptrdiff_t *>(d_input_shapes + input_size * ndim);
        d_input_contiguous = reinterpret_cast<const bool *>(d_input_strides + input_size * ndim);
        d_input_broadcasted = reinterpret_cast<const bool *>(d_input_contiguous + input_size);

        // contiguous / broadcast info
        const bool *contiguous = info.getInputContiguous();
        const bool *broadcasted = info.getInputBroadcasted();

        return INFINI_STATUS_SUCCESS;
    }
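    // Workspace layout assumed by infoToDevice (derived from the pointer
    // arithmetic above; entries are laid out back to back):
    //   [ N      x void*     ]  input pointer array
    //   [ ndim   x size_t    ]  output shape
    //   [ ndim   x ptrdiff_t ]  output strides
    //   [ N*ndim x size_t    ]  input shapes
    //   [ N*ndim x ptrdiff_t ]  input strides
    //   [ N      x bool      ]  input contiguous flags
    //   [ N      x bool      ]  input broadcasted flags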
    /**
     * @brief Launch elementwise kernel
     */
    template <uint32_t BLOCK_SIZE, int N, typename KernelFunc, typename Tout, typename... Args>
    infiniStatus_t launchElementwiseKernel(
        const op::elementwise::ElementwiseInfo &info,
        void *workspace,
        Tout *output,
        const std::vector<const void *> &inputs,
        KernelFunc kernel_func,
        kunlunStream_t stream,
        Args &&...args) {
        auto output_size = info.getOutputSize();
        if (output_size == 0) {
            return INFINI_STATUS_SUCCESS;
        }

        // Device pointers
        __global_ptr__ const void **d_inputs_arr = nullptr;
        __global_ptr__ const bool *d_input_contiguous = nullptr;
        __global_ptr__ const bool *d_input_broadcasted = nullptr;
        __global_ptr__ const size_t *d_output_shape = nullptr;
        __global_ptr__ const ptrdiff_t *d_output_strides = nullptr;
        __global_ptr__ const size_t *d_input_shapes = nullptr;
        __global_ptr__ const ptrdiff_t *d_input_strides = nullptr;

        CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr,
                                     d_input_contiguous, d_input_broadcasted,
                                     d_output_shape, d_output_strides,
                                     d_input_shapes, d_input_strides, stream));

        kernel_func<<<BLOCK_SIZE, 64, stream>>>(
            output_size,
            info.getNdim(),
            info.isOutputContiguous(),
            d_input_contiguous,
            d_input_broadcasted,
            reinterpret_cast<__global_ptr__ const void *>(d_output_shape),
            reinterpret_cast<__global_ptr__ const void *>(d_input_shapes),
            reinterpret_cast<__global_ptr__ const void *>(d_output_strides),
            reinterpret_cast<__global_ptr__ const void *>(d_input_strides),
            output,
            reinterpret_cast<__global_ptr__ const void **>(d_inputs_arr),
            args...);

        return INFINI_STATUS_SUCCESS;
    }
};
@@ -101,37 +302,35 @@ utils::Result<DeviceImpl *> DeviceImpl::create(Args &&...args) {
    return utils::Result<DeviceImpl *>(new DeviceImpl(opaque));
}
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
                                     void *workspace,
                                     void *output,
                                     const std::vector<const void *> &inputs,
                                     void *stream,
                                     Args &&...args) {
    constexpr int N = Op::num_inputs;
    return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tdata>(
        info, workspace, output, inputs,
        reinterpret_cast<kunlunStream_t>(stream),
        std::forward<Args>(args)...);
}
#define INSTANTIATE_ELEMENTWISE_KERNEL(N, Op, Tdata, ...)                     \
    template __global__ void elementwiseKernel<N, Op, Tdata, ##__VA_ARGS__>( \
        int output_size,                                                      \
        int ndim,                                                             \
        bool output_contiguous,                                               \
        const bool *input_contiguous_gm,                                      \
        const bool *input_broadcasted_gm,                                     \
        const void *output_shape_gm,                                          \
        const void *input_shapes_gm,                                          \
        const void *output_strides_gm,                                        \
        const void *input_strides_gm,                                         \
        Tdata *output,                                                        \
        const void *const *inputs,                                            \
        ##__VA_ARGS__);
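// Hypothetical second operator (illustrative only; this commit defines just
// SwiGLUOp) showing the contract elementwiseKernel expects from an Op:
// a static num_inputs constant plus a __device__ operator() over the packed
// local-memory operand buffer.
struct AddOp {
    static constexpr int num_inputs = 2;
    template <typename T>
    inline __device__ T operator()(const T *inputs) const {
        return inputs[0] + inputs[1];
    }
};
// INSTANTIATE_ELEMENTWISE_KERNEL(AddOp::num_inputs, AddOp, float);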
} // namespace op::elementwise::kunlun

#endif
@@ -17,7 +17,10 @@ public:
    template <typename... Args>
    static utils::Result<DeviceImpl *> create(Args &&...args);

    /**
     * @brief Launches an elementwise operation whose operands share the same data type.
     */
    template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
    infiniStatus_t calculate(
        const op::elementwise::ElementwiseInfo &info,
        void *workspace,
@@ -25,6 +28,20 @@ public:
        const std::vector<const void *> &inputs,
        void *stream,
        Args &&...args);
    // /**
    //  * @brief Launches an elementwise operation where operand types differ
    //  */
    // template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin,
    //           typename... Args,
    //           std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
    // infiniStatus_t calculate(
    //     const op::elementwise::ElementwiseInfo &info,
    //     void *workspace,
    //     void *output,
    //     const std::vector<const void *> &inputs,
    //     void *stream,
    //     Args &&...args);
};

} // namespace op::elementwise::kunlun
......
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_XPU__
#define __INFINIOP_ELEMENTWISE_KUNLUN_XPU__

#include "../../devices/kunlun/kunlun_kernel_common.h"

using namespace device::kunlun::kernel;

/**
 * @brief Computes the input tile offset
 */
struct InputIndexer {
    size_t idx;
    size_t ndim;
    const bool *input_contiguous;
    const bool *input_broadcasted;
    const _size_t *input_shapes;
    const _ptrdiff_t *input_strides;
    const _ptrdiff_t *output_strides;

    __device__ size_t operator()(size_t input_id) const {
        return input_contiguous[input_id]
                 ? idx
                 : (input_broadcasted[input_id]
                        ? indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
                        : indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
    }
};

/**
 * @brief Computes the output index in memory, accounting for strides if non-contiguous.
 *
 * @param idx Linear index.
 * @param is_contiguous Whether the output tensor is contiguous.
 * @param ndim Number of dimensions.
 * @param shape Shape of the output tensor.
 * @param strides Strides of the output tensor.
 * @return Memory offset index.
 */
inline __device__ size_t
getOutputIndex(size_t idx,
               bool is_contiguous,
               size_t ndim,
               const _size_t *shape,
               const _ptrdiff_t *strides) {
    return is_contiguous ? idx : indexToOffset(idx, ndim, shape, strides);
}
template <size_t N, typename Op, typename Tdata, typename... Args>
__device__ void launchOp(
    __global_ptr__ Tdata **typed_inputs, // gm pointers
    __global_ptr__ Tdata *output,        // gm pointer of output
    Tdata *inputs_buf,                   // local mem buffer
    size_t *input_indexes,
    size_t output_index,
    Args... args) {
    static_assert(N == Op::num_inputs, "template N is not equal to Op::num_inputs!\n");
    // Copy inputs to local buffer
#pragma unroll
    for (size_t i = 0; i < N; i++) {
        auto gm = typed_inputs[i] + input_indexes[i];
        auto lm = inputs_buf + i;
        GM2LM_ASYNC(gm, lm, 1 * sizeof(Tdata));
    }
    mfence();
    // Calculate elementwise; inputs_buf holds all operands
    Tdata out = Op{}(inputs_buf, args...);
    // Copy the result back to gm
    LM2GM_ASYNC(&out, output + output_index, 1 * sizeof(Tdata));
    mfence();
}

template <size_t N, typename Op, typename Tdata, typename... Args>
__global__ void elementwiseKernel(
    size_t output_size,
    size_t ndim,
    bool output_contiguous,
    const bool *input_contiguous_gm,
    const bool *input_broadcasted_gm,
    const _size_t *output_shape_gm,
    const _size_t *input_shapes_gm,
    const _ptrdiff_t *output_strides_gm,
    const _ptrdiff_t *input_strides_gm,
    Tdata *output,
    const void *const *inputs,
    Args... args) {
    int cid = core_id();
    int ncores = core_num();
    if (cid >= ncores) {
        return;
    }
    int thread_id = ncores * cluster_id() + cid;
    int nthreads = ncores * cluster_num();

    // Cast input gm pointer type
    auto typed_inputs = reinterpret_cast<const __global_ptr__ Tdata *const __global_ptr__ *>(inputs);

    const int BUFF_SIZE = 64;
    // Input data cache
    __local__ Tdata inputs_buf[N];
    // Input contiguous/broadcasted flags
    __local__ bool input_contiguous[N];
    __local__ bool input_broadcasted[N];
    // Input shapes/strides
    __local__ _size_t input_shapes[N * ndim];
    __local__ _ptrdiff_t input_strides[N * ndim];
    // Output shape/strides
    __local__ _size_t output_shape[ndim];
    __local__ _ptrdiff_t output_strides[ndim];
    // Input gm pointer buffer
    __local__ __global_ptr__ Tdata *typed_inputs_ptr[N];

    // Load metadata from gm
    GM2LM_ASYNC(input_contiguous_gm, input_contiguous, N * sizeof(bool));
    GM2LM_ASYNC(input_broadcasted_gm, input_broadcasted, N * sizeof(bool));
    GM2LM_ASYNC(input_shapes_gm, input_shapes, N * ndim * sizeof(_size_t));
    GM2LM_ASYNC(input_strides_gm, input_strides, N * ndim * sizeof(_ptrdiff_t));
    GM2LM_ASYNC(output_shape_gm, output_shape, ndim * sizeof(_size_t));
    GM2LM_ASYNC(output_strides_gm, output_strides, ndim * sizeof(_ptrdiff_t));
    GM2LM_ASYNC(typed_inputs, typed_inputs_ptr, N * sizeof(__global_ptr__ Tdata *));
    mfence();

    int len_per_loop = min(BUFF_SIZE, roundup_div(output_size, nthreads));
    for (int start = thread_id * len_per_loop; start < output_size; start += nthreads * len_per_loop) {
        size_t read_len = min(len_per_loop, output_size - start);
        for (int idx = start; idx < start + read_len; ++idx) {
            size_t out_idx = getOutputIndex(static_cast<size_t>(idx), output_contiguous,
                                            ndim, output_shape, output_strides);
            InputIndexer indexer{static_cast<size_t>(idx), ndim, input_contiguous, input_broadcasted,
                                 input_shapes, input_strides, output_strides};
            // Get index offset for every operand
            size_t indexes[N];
            for (size_t i = 0; i < N; i++) {
                indexes[i] = indexer(i);
            }
            // Launch operator
            launchOp<N, Op, Tdata>(&typed_inputs_ptr[0], output, inputs_buf, indexes, out_idx, args...);
        }
    }
    sync_cluster();
}
#define LAUNCH_ELEMENTWISE_KERNEL_IMPL(OpName, Op)                       \
    template <typename Tdata, typename... Args>                          \
    void launch##OpName##Kernel(                                         \
        size_t output_size,                                              \
        size_t ndim,                                                     \
        bool output_contiguous,                                          \
        const void *input_contiguous,                                    \
        const void *input_broadcasted,                                   \
        const void *output_shape,                                        \
        const void *input_shapes,                                        \
        const void *output_strides,                                      \
        const void *input_strides,                                       \
        void *output,                                                    \
        const void *const *inputs,                                       \
        XPUStream stream,                                                \
        Args... args) {                                                  \
        elementwiseKernel<Op::num_inputs, Op, Tdata><<<8, 64, stream>>>( \
            output_size, ndim, output_contiguous,                        \
            reinterpret_cast<const bool *>(input_contiguous),            \
            reinterpret_cast<const bool *>(input_broadcasted),           \
            reinterpret_cast<const _size_t *>(output_shape),             \
            reinterpret_cast<const _size_t *>(input_shapes),             \
            reinterpret_cast<const _ptrdiff_t *>(output_strides),        \
            reinterpret_cast<const _ptrdiff_t *>(input_strides),         \
            reinterpret_cast<Tdata *>(output), inputs, args...);         \
    }

#define LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(OpName, T, ...) \
    template void launch##OpName##Kernel<T, ##__VA_ARGS__>(   \
        size_t output_size,                                   \
        size_t ndim,                                          \
        bool output_contiguous,                               \
        const void *input_contiguous,                         \
        const void *input_broadcasted,                        \
        const void *output_shape,                             \
        const void *input_shapes,                             \
        const void *output_strides,                           \
        const void *input_strides,                            \
        void *output,                                         \
        const void *const *inputs,                            \
        XPUStream stream,                                     \
        ##__VA_ARGS__);

#endif
@@ -265,9 +265,9 @@ private:
 * @param info Elementwise operation metadata (shapes, strides, flags, etc.).
 * @param workspace Pointer to device workspace memory for storing metadata and input pointers.
 * @param h_inputs_arr Host array of input tensor pointers.
 * @param d_inputs_arr Input reference to device array of input tensor pointers.
 * @param d_input_contiguous Input reference to device array indicating whether each input is contiguous.
 * @param d_input_broadcasted Input reference to device array indicating whether each input is broadcasted.
 * @param d_output_shape Output reference to device array holding the output tensor shape.
 * @param d_output_strides Output reference to device array holding output tensor strides.
 * @param d_input_shapes Output reference to flattened input tensor shapes (N * ndim).
......
#include "gemm_kunlun.h" #include "gemm_kunlun.h"
#include "../../../../utils.h" #include "../../../../utils.h"
#include "../../../devices/kunlun/kunlun_handle.h" #include "../../../devices/kunlun/kunlun_common.h"
namespace op::gemm::kunlun { namespace op::gemm::kunlun {
......
#include "rms_norm_kunlun.h" #include "rms_norm_kunlun.h"
#include "../../../devices/kunlun/kunlun_handle.h" #include "../../../devices/kunlun/kunlun_common.h"
#include <memory> #include <memory>
#include <stdint.h> #include <stdint.h>
......
#ifndef __SWIGLU_KUNLUN_H__
#define __SWIGLU_KUNLUN_H__

#include "../../../elementwise/kunlun/elementwise_kunlun_api.h"

ELEMENTWISE_DESCRIPTOR(swiglu, kunlun)
......
#include "../../../elementwise/kunlun/elementwise_kunlun.h"
#include "swiglu_kunlun.h" #include "swiglu_kunlun.h"
// Op interface declare namespace op::elementwise::kunlun {
LAUNCH_ELEMENTWISE_KERNEL(SwiGLU)
namespace op::swiglu::kunlun {
/// @brief SwiGLU op kernel
typedef struct SwiGLUOp { typedef struct SwiGLUOp {
static constexpr size_t num_inputs = 2; private:
template <typename Tdata, typename... Args> template <typename T>
static infiniStatus_t launch(Args... args) { inline __device__ T sigmoid(T x) const {
launchSwiGLUKernel<Tdata>(args...); return 1.0f / (1.0f + exp(-x));
return INFINI_STATUS_SUCCESS; }
public:
// This static number must be set in other Ops
static constexpr int num_inputs = 2;
template <typename T>
inline __device__ T operator()(const T *inputs) const {
T up = inputs[0];
T gate = inputs[1];
T out = gate * sigmoid(gate) * up;
return out;
} }
} SwiGLUOp; } SwiGLUOp;
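// Math note: operator() computes SwiGLU(up, gate) = up * gate * sigmoid(gate)
// = up * SiLU(gate); e.g. gate = 0 yields sigmoid(gate) = 0.5 and out = 0
// regardless of up.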
INSTANTIATE_ELEMENTWISE_KERNEL(SwiGLUOp::num_inputs, SwiGLUOp, float);
} // namespace op::elementwise::kunlun
namespace op::swiglu::kunlun {
Descriptor::~Descriptor() = default;

infiniStatus_t Descriptor::create(
@@ -53,7 +67,7 @@ infiniStatus_t Descriptor::calculate(
    switch (_dtype) {
    case INFINI_DTYPE_F32:
        return _device_info->calculate<8, op::elementwise::kunlun::SwiGLUOp, float>(_info, workspace, output, inputs, stream);
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
......
#ifndef __SWIGLU_KUNLUN_H__
#define __SWIGLU_KUNLUN_H__

#include "../../../devices/kunlun/kunlun_kernel_common.h"
#include "../../../elementwise/kunlun/elementwise_kunlun_kernel.h"

/// @brief Define the swiglu op over local mem
typedef struct SwiGLUOp {
private:
    template <typename T>
    inline __device__ T sigmoid(T x) const {
        return 1.0f / (1.0f + exp(-x));
    }

public:
    // Every Op must define this static member
    static constexpr size_t num_inputs = 2;

    template <typename T>
    inline __device__ T operator()(const T *inputs) const {
        T up = inputs[0];
        T gate = inputs[1];
        T out = gate * sigmoid(gate) * up;
        return out;
    }
} SwiGLUOp;

// Definition of the swiglu kernel interface
LAUNCH_ELEMENTWISE_KERNEL_IMPL(SwiGLU, SwiGLUOp)
// Template instantiation
LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(SwiGLU, float)

#endif // __SWIGLU_KUNLUN_H__
@@ -5,27 +5,22 @@ local XTDK_DIR = path.join(KUNLUN_HOME, "xtdk")
local XDNN_DIR = path.join(KUNLUN_HOME, "xhpc", "xdnn")

-- Add include dirs
add_includedirs(path.join(XRE_DIR, "include"))
add_includedirs(path.join(XDNN_DIR, "include"))
add_includedirs(path.join(XTDK_DIR, "include"))
add_linkdirs(path.join(XRE_DIR, "so"))
add_linkdirs(path.join(XDNN_DIR, "so"))
add_links("xpurt", "xpuapi")
rule("xpu") rule("xpu")
set_extensions(".xpu") set_extensions(".xpu")
on_load(function (target)
target:add("includedirs", path.join(os.projectdir(), "include"))
end)
on_build_file(function (target, sourcefile) on_build_file(function (target, sourcefile)
local sourcefile_config = target:fileconfig(sourcefile) or {}
local includedirs = sourcefile_config.includedirs or {}
local objectfile = target:objectfile(sourcefile) local objectfile = target:objectfile(sourcefile)
print("Compiling:", sourcefile, "->", objectfile)
-- local basename = objectfile:gsub("%.o$", "")
os.mkdir(path.directory(objectfile)) os.mkdir(path.directory(objectfile))
local cc = path.join(XTDK_DIR, "bin/clang++") local cc = path.join(XTDK_DIR, "bin/clang++")
local includedirs = table.concat(target:get("includedirs"), " ") local includedirs = table.concat(target:get("includedirs"), " ")
@@ -35,19 +30,15 @@ rule("xpu")
        }

        local args = {
            "--sysroot=/",
            "--target=" .. arch_map[os.arch()],
            "-fPIC",
            "--xpu-arch=xpu3",
            "-std=c++17",
            "-O2",
            "-fno-builtin",
            "-c", sourcefile,
            "-o", objectfile
        }

        for _, includedir in ipairs(target:get("includedirs")) do
@@ -59,8 +50,7 @@ rule("xpu")
        assert(ok == 0, "Compile failed: " .. sourcefile)
        table.insert(target:objectfiles(), objectfile)
        print(target:objectfiles())
    end)
rule_end()
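-- For reference, the rule above produces a compile line of roughly this
-- shape (illustrative; the exact target triple comes from arch_map and the
-- include flags from the loop in the collapsed hunk):
--   <XTDK_DIR>/bin/clang++ --sysroot=/ --target=<arch> -fPIC --xpu-arch=xpu3
--     -std=c++17 -O2 -fno-builtin -I<each includedir> -c foo.xpu -o <objectfile>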
@@ -79,7 +69,15 @@ target("infiniop-kunlun")
    -- compile handwritten kernels
    local xpu_files = os.files(src_dir .. "/ops/*/kunlun/*.xpu")
    if #xpu_files > 0 then
        add_files(xpu_files, {
            rule = "xpu",
            includedirs = {
                path.join(os.projectdir(), "include"),
                path.join(XRE_DIR, "include"),
                path.join(XDNN_DIR, "include"),
                path.join(XTDK_DIR, "include")
            }
        })
    end
target_end()
@@ -89,7 +87,7 @@ target("infinirt-kunlun")
    set_languages("cxx17")
    on_install(function (target) end)
    -- Add source files
    add_files("$(projectdir)/src/infinirt/kunlun/*.cc")
    add_cxflags("-lstdc++ -Wall -Werror -fPIC")
target_end()