Commit c2e87202 authored by Catheriany

Merge remote-tracking branch 'origin/main' into issue/142

parents 41818f84 c203635b
#ifndef __INFINIOP_KUNLUN_DTYPE_H__
#define __INFINIOP_KUNLUN_DTYPE_H__
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_math.h"
#include "xpu/kernel/xtdk_simd.h"
#include "xpu/runtime.h"
// On Kunlun, ptrdiff_t values copied from the host are stored as this struct:
// the device `long` is 32-bit, so each 64-bit host element spans value + padding
typedef struct _ptrdiff_t {
long value; // 32 bit
long padding; // 32 bit
} _ptrdiff_t;
// Same layout, used for size_t values copied from the host
typedef struct _size_t {
size_t value;
size_t padding;
} _size_t;
#endif
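For illustration (an assumption about the copy path, not part of this header): on a little-endian host with a 64-bit ptrdiff_t, an array memcpy'd to the device is reinterpreted element by element as _ptrdiff_t, with the low half landing in value and the high half in padding. A hypothetical host-side helper makes the split explicit:
#include <cstddef>
#include <cstdint>
#include <cstring>
// Hypothetical host-side view of one 64-bit ptrdiff_t as the two 32-bit halves
// the device-side _ptrdiff_t will see (little-endian assumption).
struct HostPtrdiffView {
    int32_t value;   // low 32 bits
    int32_t padding; // high 32 bits
};
inline HostPtrdiffView splitForKunlun(std::ptrdiff_t v) {
    HostPtrdiffView out;
    std::memcpy(&out, &v, sizeof(out)); // same bytes a device kernel reads
    return out;
}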
@@ -17,9 +17,24 @@ class Handle::Internal {
template <typename T>
using Fn = std::function<infiniStatus_t(T)>;
int _warp_size,
_max_threads_per_block,
_block_size[3],
_grid_size[3];
public:
Internal(int);
infiniStatus_t useMcblas(hcStream_t stream, const Fn<hcblasHandle_t> &f) const;
infiniStatus_t useMcdnn(hcStream_t stream, const Fn<hcdnnHandle_t> &f) const;
int warpSize() const;
int maxThreadsPerBlock() const;
int blockSizeX() const;
int blockSizeY() const;
int blockSizeZ() const;
int gridSizeX() const;
int gridSizeY() const;
int gridSizeZ() const;
};
hcdnnDataType_t getHcdnnDtype(infiniDtype_t dt);
...
@@ -3,7 +3,7 @@
namespace device::maca {
Handle::Handle(infiniDevice_t device, int device_id)
: InfiniopHandle{device, device_id},
_internal(std::make_shared<Handle::Internal>(device_id)) {}
Handle::Handle(int device_id) : Handle(INFINI_DEVICE_METAX, device_id) {}
@@ -11,6 +11,19 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
return _internal;
}
Handle::Internal::Internal(int device_id) {
hcDeviceProp_t prop;
hcGetDeviceProperties(&prop, device_id);
_warp_size = prop.warpSize;
_max_threads_per_block = prop.maxThreadsPerBlock;
_block_size[0] = prop.maxThreadsDim[0];
_block_size[1] = prop.maxThreadsDim[1];
_block_size[2] = prop.maxThreadsDim[2];
_grid_size[0] = prop.maxGridSize[0];
_grid_size[1] = prop.maxGridSize[1];
_grid_size[2] = prop.maxGridSize[2];
}
infiniStatus_t Handle::Internal::useMcblas(hcStream_t stream, const Fn<hcblasHandle_t> &f) const {
auto handle = mcblas_handles.pop();
if (!handle) {
@@ -33,6 +46,15 @@ infiniStatus_t Handle::Internal::useMcdnn(hcStream_t stream, const Fn<hcdnnHandl
return INFINI_STATUS_SUCCESS;
}
int Handle::Internal::warpSize() const { return _warp_size; }
int Handle::Internal::maxThreadsPerBlock() const { return _max_threads_per_block; }
int Handle::Internal::blockSizeX() const { return _block_size[0]; }
int Handle::Internal::blockSizeY() const { return _block_size[1]; }
int Handle::Internal::blockSizeZ() const { return _block_size[2]; }
int Handle::Internal::gridSizeX() const { return _grid_size[0]; }
int Handle::Internal::gridSizeY() const { return _grid_size[1]; }
int Handle::Internal::gridSizeZ() const { return _grid_size[2]; }
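These accessors give backends the device limits they need when sizing kernel launches; a minimal sketch (hypothetical helper, assuming Internal is reachable through handle->internal() as above) of clamping a 1-D launch:
#include <algorithm>
#include <cstddef>
// Hypothetical launch-shape helper built on the accessors defined above.
struct Launch1D {
    int block; // threads per block
    int grid;  // blocks along x
};
inline Launch1D chooseLaunch1D(const device::maca::Handle::Internal &internal,
                               size_t total, int preferred_block = 256) {
    int block = std::min(preferred_block, internal.maxThreadsPerBlock());
    long long blocks = (static_cast<long long>(total) + block - 1) / block;
    int grid = static_cast<int>(std::min<long long>(blocks, internal.gridSizeX()));
    return {block, grid};
}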
hcdnnDataType_t getHcdnnDtype(infiniDtype_t dt) {
switch (dt) {
case INFINI_DTYPE_F16:
...
@@ -16,12 +16,27 @@ class Handle::Internal {
Pool<std::unique_ptr<mublasHandle_t>> mublas_handles;
Pool<std::unique_ptr<::musa::dnn::Handle>> mudnn_handles;
int _warp_size,
_max_threads_per_block,
_block_size[3],
_grid_size[3];
template <typename T>
using Fn = std::function<infiniStatus_t(T)>;
public:
Internal(int);
infiniStatus_t useMublas(musaStream_t stream, const Fn<mublasHandle_t> &f) const;
infiniStatus_t useMudnn(musaStream_t stream, const Fn<::musa::dnn::Handle &> &f) const;
int warpSize() const;
int maxThreadsPerBlock() const;
int blockSizeX() const;
int blockSizeY() const;
int blockSizeZ() const;
int gridSizeX() const;
int gridSizeY() const;
int gridSizeZ() const;
};
} // namespace device::musa
@@ -3,7 +3,7 @@
namespace device::musa {
Handle::Handle(infiniDevice_t device, int device_id)
: InfiniopHandle{device, device_id},
_internal(std::make_shared<Handle::Internal>(device_id)) {}
Handle::Handle(int device_id) : Handle(INFINI_DEVICE_MOORE, device_id) {}
@@ -11,6 +11,19 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
return _internal;
}
Handle::Internal::Internal(int device_id) {
musaDeviceProp prop;
musaGetDeviceProperties(&prop, device_id);
_warp_size = prop.warpSize;
_max_threads_per_block = prop.maxThreadsPerBlock;
_block_size[0] = prop.maxThreadsDim[0];
_block_size[1] = prop.maxThreadsDim[1];
_block_size[2] = prop.maxThreadsDim[2];
_grid_size[0] = prop.maxGridSize[0];
_grid_size[1] = prop.maxGridSize[1];
_grid_size[2] = prop.maxGridSize[2];
}
infiniStatus_t Handle::Internal::useMublas(musaStream_t stream, const Fn<mublasHandle_t> &f) const {
std::unique_ptr<mublasHandle_t> handle;
auto opt_handle = mublas_handles.pop();
@@ -40,6 +53,15 @@ infiniStatus_t Handle::Internal::useMudnn(musaStream_t stream, const Fn<::musa::
return INFINI_STATUS_SUCCESS;
}
int Handle::Internal::warpSize() const { return _warp_size; }
int Handle::Internal::maxThreadsPerBlock() const { return _max_threads_per_block; }
int Handle::Internal::blockSizeX() const { return _block_size[0]; }
int Handle::Internal::blockSizeY() const { return _block_size[1]; }
int Handle::Internal::blockSizeZ() const { return _block_size[2]; }
int Handle::Internal::gridSizeX() const { return _grid_size[0]; }
int Handle::Internal::gridSizeY() const { return _grid_size[1]; }
int Handle::Internal::gridSizeZ() const { return _grid_size[2]; }
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
*handle_ptr = new Handle(INFINI_DEVICE_MOORE, device_id);
return INFINI_STATUS_SUCCESS;
...
#ifndef __INFINIOP_ELEMENTWISE_CPU_H__
#define __INFINIOP_ELEMENTWISE_CPU_H__
#include "../../devices/cpu/common_cpu.h"
#include "../elementwise.h"
#include <utility>
/**
* @brief Define the process for initializing a Descriptor of an elementwise operation
* for its CPU implementation
*
* @param HANDLE The device handle.
* @param DTYPE The output dtype.
* @param OUT_DESC The output tensor descriptor.
* @param INPUT_DESC_VEC A vector containing input tensor descriptors.
*/
#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
info_result.take(), \
nullptr, \
0, \
HANDLE->device, \
HANDLE->device_id);
namespace op::elementwise::cpu {
/**
* @brief CPU-specific device implementation for resource management and
* calculation implementations.
*
* This class encapsulates device-specific behavior and execution logic.
* Use the static create() method to instantiate a DeviceImpl.
*/
class DeviceImpl final {
struct Opaque;
std::shared_ptr<Opaque> _opaque;
DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
public:
~DeviceImpl() = default;
template <typename... Args>
static utils::Result<DeviceImpl> create(Args &&...args);
/**
* @brief Dispatches an elementwise operation with uniform input types.
*
* @tparam Op The elementwise operation to perform.
* @tparam Tdata The common data type of all inputs and output.
* @tparam Args Additional backend-specific arguments.
* @param info Precomputed tensor metadata (shapes, strides, etc.).
* @param output Pointer to the output tensor buffer.
* @param inputs Vector of input tensor data pointers.
* @param stream Device execution stream.
* @param args Additional backend-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
/**
* @brief Dispatches an elementwise operation with heterogeneous input types.
*
* Supports operations where each input may have a different type, as defined by Op.
* The number of input types must match the operation's expected input count.
*
* @tparam Op The elementwise operation to perform.
* @tparam Tout Output data type.
* @tparam Tin Variadic input data types.
* @tparam Args Additional backend-specific arguments.
* @param info Precomputed tensor metadata (shapes, strides, etc.).
* @param output Pointer to the output tensor buffer.
* @param inputs Vector of input tensor data pointers.
* @param stream Device execution stream.
* @param args Additional backend-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
template <typename Op, typename Tout, typename... Tin,
typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
};
// Define the Opaque struct for CPU, which is empty
struct DeviceImpl::Opaque {};
template <typename... Args>
utils::Result<DeviceImpl> DeviceImpl::create(Args &&...args) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}
// Perform elementwise operation for different input types
template <typename Op, typename Tout, typename... Tin, size_t... Is, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate_impl(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
Args &&...args) {
Tout *out = reinterpret_cast<Tout *>(output);
std::tuple<const Tin *...> input_ptrs = {reinterpret_cast<const Tin *>(inputs[Is])...};
ptrdiff_t output_size = info.getOutputSize();
#pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.isOutputContiguous()
? i
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());
auto get_input_idx = [&](size_t input_id) {
return info.getInputContiguous()[input_id]
? i
: (info.getInputBroadcasted()[input_id]
? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
};
out[out_idx] = utils::cast<Tout>(
Op{}.template operator()<Tout, Tin...>(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...));
}
}
// Invoke elementwise operation for different input types
template <typename Op, typename Tout, typename... Tin, typename... Args, std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch");
calculate_impl<Op, Tout, Tin...>(info, output, inputs, std::make_index_sequence<sizeof...(Tin)>{}, std::forward<Args>(args)...);
return INFINI_STATUS_SUCCESS;
}
// Perform elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, size_t... Is, typename... Args>
void calculate_impl(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
Args &&...args) {
Tdata *out = reinterpret_cast<Tdata *>(output);
std::array<const Tdata *, sizeof...(Is)> ins = {reinterpret_cast<const Tdata *>(inputs[Is])...};
const ptrdiff_t output_size = info.getOutputSize();
#pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.isOutputContiguous()
? i
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());
auto get_input_idx = [&](size_t input_id) {
return info.getInputContiguous()[input_id]
? i
: (info.getInputBroadcasted()[input_id]
? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
};
if constexpr (std::is_same_v<Tdata, fp16_t>) {
out[out_idx] = utils::cast<fp16_t>(Op{}(utils::cast<float>(ins[Is][get_input_idx(Is)])..., std::forward<Args>(args)...));
} else {
out[out_idx] = Op{}(ins[Is][get_input_idx(Is)]..., std::forward<Args>(args)...);
}
}
}
// Invoke elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...);
return INFINI_STATUS_SUCCESS;
}
} // namespace op::elementwise::cpu
#endif // __INFINIOP_ELEMENTWISE_CPU_H__
#ifndef __INFINIOP_ELEMENTWISE_CUDA_H__
#define __INFINIOP_ELEMENTWISE_CUDA_H__
#include "../../../utils.h"
#include "../../devices/cuda/cuda_common.cuh"
#include "../../devices/cuda/cuda_kernel_common.cuh"
#include "elementwise_cuda_api.cuh"
namespace op::elementwise::cuda {
/**
* @brief Casts an untyped device pointer to a typed pointer of type T.
*
* @tparam T Desired pointer type.
*
* @param ptr Untyped pointer.
* @return Pointer of type const T*.
*/
template <typename T>
__device__ __forceinline__ const T *typedInputPtr(const void *ptr) {
return reinterpret_cast<const T *>(ptr);
}
/**
* @brief Computes the output index in memory, accounting for strides if non-contiguous.
*
* @param idx Linear index.
* @param is_contiguous Whether the output tensor is contiguous.
* @param ndim Number of dimensions.
* @param shape Shape of the output tensor.
* @param strides Strides of the output tensor.
* @return Memory offset index.
*/
__device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim,
const size_t *shape, const ptrdiff_t *strides) {
return is_contiguous ? idx : device::cuda::indexToOffset(idx, ndim, shape, strides);
}
/**
* @brief Computes input element offset for broadcasting and strided access.
*
* Used to map a linear output index to the corresponding index in an input tensor,
* considering contiguity and broadcasting.
*/
struct InputIndexer {
size_t idx;
size_t ndim;
const bool *input_contiguous;
const bool *input_broadcasted;
const size_t *input_shapes;
const ptrdiff_t *input_strides;
const ptrdiff_t *output_strides;
/**
* @brief Computes the memory offset for a given input tensor at current index.
*
* @param input_id ID of the input tensor.
* @return Offset into the input tensor.
*/
__device__ __forceinline__ size_t operator()(size_t input_id) const {
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: device::cuda::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
}
};
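For reference, host-side sketches of the index arithmetic the indexer relies on (these mirror the usual definitions and assume non-negative strides; the actual device::cuda::indexToOffset / indexToReducedOffset implementations live in the device headers):
#include <cstddef>
// Sketch: unflatten idx against the tensor's own shape, then dot with its strides.
inline size_t sketchIndexToOffset(size_t idx, size_t ndim,
                                  const size_t *shape, const ptrdiff_t *strides) {
    size_t offset = 0;
    for (size_t d = ndim; d-- > 0;) {
        offset += idx % shape[d] * static_cast<size_t>(strides[d]);
        idx /= shape[d];
    }
    return offset;
}
// Sketch: unflatten idx against the *output* strides, then dot with the input
// strides; broadcast dimensions carry stride 0 and thus contribute nothing.
inline size_t sketchIndexToReducedOffset(size_t idx, size_t ndim,
                                         const ptrdiff_t *output_strides,
                                         const ptrdiff_t *input_strides) {
    size_t offset = 0;
    for (size_t d = 0; d < ndim; ++d) {
        size_t stride = static_cast<size_t>(output_strides[d]);
        offset += idx / stride * static_cast<size_t>(input_strides[d]);
        idx %= stride;
    }
    return offset;
}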
/**
* @brief Invokes a callable with compile-time index constants.
*
* Used to unpack index sequence for variadic template processing of inputs.
*
* @tparam F Callable type.
* @tparam Is Compile-time index sequence.
*
* @param f Callable to invoke with index constants.
*/
template <typename F, size_t... Is>
__device__ __forceinline__ void unpackInputsAndApply(F &&f, std::index_sequence<Is...>) {
f(std::integral_constant<size_t, Is>{}...);
}
/**
* @brief CUDA kernel for performing elementwise operations on tensors where all inputs share the same data type.
*
* @tparam N Number of input tensors.
* @tparam Op Operator type implementing operator()(Tdata...).
* @tparam Tdata Common data type for inputs and output.
* @tparam Args Additional arguments to pass to the operator.
*
* @param output_size Total number of output elements.
* @param ndim Number of dimensions in tensors.
* @param output_contiguous Whether the output tensor is contiguous in memory.
* @param input_contiguous Array indicating if each input tensor is contiguous.
* @param input_broadcasted Array indicating if each input tensor is broadcasted.
* @param output_shape Shape of the output tensor.
* @param input_shapes Shapes of the input tensors.
* @param output_strides Strides for the output tensor.
* @param input_strides Strides for each input tensor.
* @param output Output buffer.
* @param inputs Array of input pointers, all of type Tdata.
* @param offset Linear offset to support partitioned execution.
* @param args Additional arguments passed to the operator.
*/
template <size_t N, typename Op, typename Tdata, typename... Args>
INFINIOP_CUDA_KERNEL elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
const bool *__restrict__ input_contiguous,
const bool *__restrict__ input_broadcasted,
const size_t *__restrict__ output_shape,
const size_t *__restrict__ input_shapes,
const ptrdiff_t *__restrict__ output_strides,
const ptrdiff_t *__restrict__ input_strides,
Tdata *output,
const void *const *inputs,
size_t offset,
Args... args) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;
if (idx < output_size) {
const Tdata *const *typed_inputs = reinterpret_cast<const Tdata *const *>(inputs);
size_t out_idx = getOutputIndex(idx, output_contiguous, ndim, output_shape, output_strides);
InputIndexer indexer{idx, ndim, input_contiguous, input_broadcasted, input_shapes, input_strides, output_strides};
unpackInputsAndApply(
[&](auto... Is) {
output[out_idx] = Op{}(typed_inputs[Is.value][indexer(Is.value)]..., std::forward<Args>(args)...);
},
std::make_index_sequence<N>{});
}
}
/**
* @brief CUDA kernel for performing an elementwise operation on tensors with support
* for broadcasting and mixed data types.
*
* @tparam Op Operator type implementing a templated operator() for (Tout, Tin...).
* @tparam Tout Output data type.
* @tparam Tin Variadic input data types.
*
* @param output_size Total number of output elements.
* @param ndim Number of dimensions in the tensors.
* @param output_contiguous Whether the output tensor is contiguous.
* @param input_contiguous Array indicating whether each input is contiguous.
* @param input_broadcasted Array indicating whether each input is broadcasted.
* @param output_shape Shape of the output tensor.
* @param input_shapes Shapes of the input tensors.
* @param output_strides Strides of the output tensor.
* @param input_strides Strides of the input tensors.
* @param output Pointer to the output buffer.
* @param inputs Array of untyped input pointers.
* @param offset Linear offset into the output for partitioned execution.
*/
template <typename Op, typename Tout, typename... Tin>
INFINIOP_CUDA_KERNEL elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
const bool *__restrict__ input_contiguous,
const bool *__restrict__ input_broadcasted,
const size_t *__restrict__ output_shape,
const size_t *__restrict__ input_shapes,
const ptrdiff_t *__restrict__ output_strides,
const ptrdiff_t *__restrict__ input_strides,
Tout *output,
const void *const *__restrict__ inputs,
size_t offset) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;
if (idx < output_size) {
size_t out_idx = getOutputIndex(idx, output_contiguous, ndim, output_shape, output_strides);
InputIndexer indexer{idx, ndim, input_contiguous, input_broadcasted, input_shapes, input_strides, output_strides};
unpackInputsAndApply(
[&](auto... Is) {
output[out_idx] = Op{}.template operator()<Tout, Tin...>(
(typedInputPtr<Tin>(inputs[Is.value])[indexer(Is.value)])...);
},
std::index_sequence_for<Tin...>{});
}
}
struct DeviceImpl::Opaque {
std::shared_ptr<device::cuda::Handle::Internal> internal;
Opaque(const std::shared_ptr<device::cuda::Handle::Internal> &internal)
: internal(internal) {}
/**
* @brief Executes an elementwise operation where all inputs and the output share the same data type.
*
* @tparam BLOCK_SIZE CUDA block size used for kernel launch.
* @tparam N Number of input tensors.
* @tparam Op Functor representing the elementwise operation.
* @tparam Tdata Data type of both input and output tensors.
* @tparam Args Optional additional arguments passed to the operation.
*
* @param info Metadata about the operation including shape, size, and dimensionality.
* @param workspace Temporary workspace used for storing metadata on device.
* @param output Pointer to the output buffer.
* @param inputs Vector of pointers to input buffers.
* @param stream CUDA stream for asynchronous execution.
* @param args Additional arguments forwarded to the operation.
* @return infiniStatus_t Returns success or failure status.
*/
template <uint32_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
cudaStream_t stream,
Args &&...args) {
return launchElementwiseKernel<BLOCK_SIZE, N>(
info, workspace,
reinterpret_cast<Tdata *>(output), inputs,
elementwiseKernel<N, Op, Tdata, Args...>,
stream,
std::forward<Args>(args)...);
}
/**
* @brief Executes an elementwise operation with mixed input and output data types.
*
* @tparam BLOCK_SIZE CUDA block size used for kernel launch.
* @tparam N Number of input tensors.
* @tparam Op Functor representing the elementwise operation.
* @tparam Tout Data type of the output tensor.
* @tparam Tin... Data types of the input tensors.
* @tparam Args Optional additional arguments passed to the operation (unused).
*
* @param info Metadata about the operation including shape, size, and dimensionality.
* @param workspace Temporary workspace used for storing metadata on device.
* @param output Pointer to the output buffer.
* @param inputs Vector of pointers to input buffers.
* @param stream CUDA stream for asynchronous execution.
* @param args Additional arguments forwarded to the operation.
* @return infiniStatus_t Returns success or failure status.
*/
template <uint32_t BLOCK_SIZE, size_t N, typename Op, typename Tout, typename... Tin, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
cudaStream_t stream,
Args &&...args) {
return launchElementwiseKernel<BLOCK_SIZE, N>(
info, workspace,
reinterpret_cast<Tout *>(output), inputs,
elementwiseKernel<Op, Tout, Tin...>,
stream);
}
private:
/**
* @brief Transfers elementwise operation metadata and input pointers from host to device memory.
*
* @tparam N Number of input tensors.
*
* @param info Elementwise operation metadata (shapes, strides, flags, etc.).
* @param workspace Pointer to device workspace memory for storing metadata and input pointers.
* @param h_inputs_arr Host array of input tensor pointers.
* @param d_inputs_arr Output reference to device array of input tensor pointers.
* @param d_input_contiguous Output reference to device array indicating whether each input is contiguous.
* @param d_input_broadcasted Output reference to device array indicating whether each input is broadcasted.
* @param d_output_shape Output reference to device array holding the output tensor shape.
* @param d_output_strides Output reference to device array holding output tensor strides.
* @param d_input_shapes Output reference to flattened input tensor shapes (N * ndim).
* @param d_input_strides Output reference to flattened input tensor strides (N * ndim).
* @param stream CUDA stream used for asynchronous memory transfer.
* @return infiniStatus_t Status indicating success or failure of the memory transfer and setup.
*/
template <size_t N>
infiniStatus_t infoToDevice(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
const void *const *h_inputs_arr,
const void **&d_inputs_arr,
const bool *&d_input_contiguous,
const bool *&d_input_broadcasted,
const size_t *&d_output_shape,
const ptrdiff_t *&d_output_strides,
const size_t *&d_input_shapes,
const ptrdiff_t *&d_input_strides,
cudaStream_t stream) const {
constexpr auto input_size = N;
const auto ndim = info.getNdim();
constexpr auto input_arr_size = N * sizeof(*h_inputs_arr);
const int8_t *info_meta_start = info.getMetaStart();
const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
// copy the input pointer array and meta to device
CHECK_CUDA(cudaMemcpyAsync(workspace, h_inputs_arr, input_arr_size, cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), cudaMemcpyHostToDevice, stream));
// offset/assign the pointers
d_inputs_arr = reinterpret_cast<const void **>(workspace);
d_output_shape = reinterpret_cast<const size_t *>(d_meta_start);
d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape + ndim);
d_input_shapes = reinterpret_cast<const size_t *>(d_output_strides + ndim);
d_input_strides = reinterpret_cast<const ptrdiff_t *>(d_input_shapes + input_size * ndim);
d_input_contiguous = reinterpret_cast<const bool *>(d_input_strides + input_size * ndim);
d_input_broadcasted = reinterpret_cast<const bool *>(d_input_contiguous + input_size);
return INFINI_STATUS_SUCCESS;
}
/**
* @brief Launches the elementwise kernel for the specified operation.
*
* @tparam BLOCK_SIZE Number of threads per block.
* @tparam N Number of input tensors.
* @tparam KernelFunc Type of the kernel function pointer.
* @tparam Tout Output data type.
* @tparam Args Additional arguments to be forwarded to the kernel.
*
* @param info Metadata about the elementwise operation (shapes, strides, etc.).
* @param workspace CUDA memory used for storing metadata.
* @param output Pointer to output buffer on device.
* @param inputs Vector of device pointers to input tensors.
* @param kernel_func Kernel function to launch.
* @param stream CUDA stream for asynchronous execution.
* @param args Additional arguments passed to the kernel.
* @return infiniStatus_t Status code indicating success or failure.
*/
template <uint32_t BLOCK_SIZE, size_t N, typename KernelFunc, typename Tout, typename... Args>
infiniStatus_t launchElementwiseKernel(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
Tout *output,
const std::vector<const void *> &inputs,
KernelFunc kernel_func,
cudaStream_t stream,
Args &&...args) {
auto output_size = info.getOutputSize();
if (output_size == 0) {
return INFINI_STATUS_SUCCESS;
}
// Device pointers
const void **d_inputs_arr = nullptr;
const bool *d_input_contiguous = nullptr;
const bool *d_input_broadcasted = nullptr;
const size_t *d_output_shape = nullptr;
const ptrdiff_t *d_output_strides = nullptr;
const size_t *d_input_shapes = nullptr;
const ptrdiff_t *d_input_strides = nullptr;
CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr,
d_input_contiguous, d_input_broadcasted,
d_output_shape, d_output_strides,
d_input_shapes, d_input_strides, stream));
dim3 blockDims(std::min(BLOCK_SIZE, static_cast<uint32_t>(internal->maxThreadsPerBlock())));
dim3 gridDims(std::min(uint32_t(CEIL_DIV(output_size, blockDims.x)), static_cast<uint32_t>(internal->gridSizeX())));
size_t step = gridDims.x * blockDims.x;
for (size_t i = 0; i < output_size; i += step) {
kernel_func<<<gridDims, blockDims, 0, stream>>>(
output_size, info.getNdim(), info.isOutputContiguous(),
d_input_contiguous, d_input_broadcasted,
d_output_shape, d_input_shapes,
d_output_strides, d_input_strides,
output, reinterpret_cast<const void **>(d_inputs_arr),
i, std::forward<Args>(args)...);
}
return INFINI_STATUS_SUCCESS;
}
};
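For reference, a small sketch (hypothetical helper, not part of this header) of the workspace budget infoToDevice assumes; it matches what CREATE_ELEMENTWISE_CUDA_DESCRIPTOR reserves in the API header:
// Workspace layout consumed by infoToDevice:
//   [ N input pointers | output shape | output strides | input shapes | input strides | contiguous flags | broadcast flags ]
inline size_t workspaceSizeSketch(const op::elementwise::ElementwiseInfo &info) {
    return info.getInputSize() * sizeof(void *) // device copy of the input pointer array
         + info.getMetaMemSize();               // packed shapes, strides, and flags
}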
template <typename... Args>
utils::Result<DeviceImpl *> DeviceImpl::create(Args &&...args) {
auto opaque = std::make_shared<Opaque>(std::forward<Args>(args)...);
return utils::Result<DeviceImpl *>(new DeviceImpl(opaque));
}
/* Invoke elementwise operation for different input types */
template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
static_assert(sizeof...(Tin) == N, "Input type count mismatch");
return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tout, Tin...>(
info, workspace, output, inputs,
reinterpret_cast<cudaStream_t>(stream),
std::forward<Args>(args)...);
}
/* Invoke elementwise operation when all inputs have the same dtype */
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tdata>(
info, workspace, output, inputs,
reinterpret_cast<cudaStream_t>(stream),
std::forward<Args>(args)...);
}
} // namespace op::elementwise::cuda
#endif // __INFINIOP_ELEMENTWISE_CUDA_H__
#ifndef __INFINIOP_ELEMENTWISE_CUDA_API_H__
#define __INFINIOP_ELEMENTWISE_CUDA_API_H__
#include "../elementwise.h"
namespace op::elementwise::cuda {
/**
* @brief Define the methods and info needed by CUDA to perform elementwise operation
*/
class DeviceImpl final {
struct Opaque;
std::shared_ptr<Opaque> _opaque;
DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
public:
~DeviceImpl() = default;
template <typename... Args>
static utils::Result<DeviceImpl *> create(Args &&...args);
/**
* @brief Launches elementwise operation where all input types are the same.
*
* Calls the corresponding templated `calculateImpl` with a unified input type.
*
* @tparam BLOCK_SIZE Number of threads per block.
* @tparam Op Operation functor defining the computation.
* @tparam Tdata Data type for both input and output tensors.
* @tparam Args... Additional arguments passed to the operation.
*
* @param info Metadata describing tensor shapes, strides, etc.
* @param workspace Pointer to workspace buffer on device.
* @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
* @param args Additional operation-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
/**
* @brief Launches elementwise operation where input types may differ.
*
* Dispatches to templated `calculateImpl` using specified output and input types.
*
* @tparam BLOCK_SIZE Number of threads per block.
* @tparam Op Operation functor defining the computation.
* @tparam Tout Output data type.
* @tparam Tin... Input data types (must match Op::num_inputs).
* @tparam Args... Additional arguments passed to the operation.
*
* @param info Metadata describing tensor shapes, strides, etc.
* @param workspace Pointer to workspace buffer on device.
* @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
* @param args (UNUSED) Additional operation-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin,
typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
};
} // namespace op::elementwise::cuda
/**
* @brief Define the process for initializing a Descriptor of an elementwise operation
* for its CUDA implementation
*
* @param HANDLE The device handle.
* @param DTYPE The output dtype.
* @param OUT_DESC The output tensor descriptor.
* @param INPUT_DESC_VEC A vector containing input tensor descriptors.
*/
#define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
auto device_impl_result = op::elementwise::cuda::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
std::move(info), \
std::move(device_impl_result.take()), \
workspace_size, \
HANDLE->device, \
HANDLE->device_id);
#endif // __INFINIOP_ELEMENTWISE_CUDA_API_H__
#ifndef __INFINIOP_ELEMENTWISE_H__
#define __INFINIOP_ELEMENTWISE_H__
#include "../../utils.h"
#include "../operator.h"
#include "../tensor.h"
#include <algorithm>
#include <array>
#include <cstring>
#include <iostream>
#include <memory>
#include <numeric>
#include <vector>
#define ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE) \
\
namespace op::OP::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
infiniDtype_t _dtype; \
op::elementwise::ElementwiseInfo _info; \
std::unique_ptr<op::elementwise::NAMESPACE::DeviceImpl> _device_info; \
size_t _workspace_size; \
\
Descriptor( \
infiniDtype_t dtype, \
op::elementwise::ElementwiseInfo info, \
op::elementwise::NAMESPACE::DeviceImpl *device_info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_dtype(dtype), \
_info(std::move(info)), \
_device_info(std::move(device_info)), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t output_desc, \
std::vector<infiniopTensorDescriptor_t> input_descs); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *output, \
std::vector<const void *> inputs, \
void *stream) const; \
}; \
}
namespace op::elementwise {
/**
* @brief Stores the metadata required for performing an elementwise operation.
*
* This struct encapsulates shape, stride, and layout information for both
* output and multiple input tensors involved in an elementwise operation.
*
* All metadata (shapes, strides, contiguity and broadcast flags) is packed into a
* single contiguous buffer owned by the struct, so instances are cheap to move.
*
* Use ElementwiseInfo::create(...) to safely construct an instance from tensor descriptors.
*/
struct ElementwiseInfo {
private:
std::vector<size_t> _meta;
size_t _output_size;
size_t _input_size;
size_t _ndim;
bool _output_contiguous;
ElementwiseInfo(std::vector<size_t> meta,
size_t output_size,
size_t input_size,
size_t ndim,
bool output_contiguous)
: _meta(std::move(meta)), _output_size(output_size),
_input_size(input_size), _ndim(ndim),
_output_contiguous(output_contiguous) {}
public:
// Get the memory size of the metadata in bytes
inline size_t getMetaMemSize() const {
return _meta.size() * sizeof(size_t);
}
inline const int8_t *getMetaStart() const {
return reinterpret_cast<const int8_t *>(_meta.data());
}
inline size_t getOutputSize() const {
return _output_size;
}
inline size_t getInputSize() const {
return _input_size;
}
inline size_t getNdim() const {
return _ndim;
}
inline bool isOutputContiguous() const {
return _output_contiguous;
}
inline const size_t *getOutputShape() const {
return reinterpret_cast<const size_t *>(_meta.data());
}
inline const ptrdiff_t *getOutputStrides() const {
return reinterpret_cast<const ptrdiff_t *>(getOutputShape() + _ndim);
}
inline const size_t *getAllInputShapes() const {
return reinterpret_cast<const size_t *>(getOutputStrides() + _ndim);
}
inline const size_t *getInputShape(const size_t &index) const {
if (index < _input_size) {
return reinterpret_cast<const size_t *>(getAllInputShapes() + index * _ndim);
}
return nullptr;
}
inline const ptrdiff_t *getAllInputStrides() const {
return reinterpret_cast<const ptrdiff_t *>(getAllInputShapes() + _input_size * _ndim);
}
inline const ptrdiff_t *getInputStrides(const size_t &index) const {
if (index < _input_size) {
return reinterpret_cast<const ptrdiff_t *>(getAllInputStrides() + index * _ndim);
}
return nullptr;
}
inline const bool *getInputContiguous() const {
return reinterpret_cast<const bool *>(getAllInputStrides() + _input_size * _ndim);
}
inline const bool *getInputBroadcasted() const {
return reinterpret_cast<const bool *>(getInputContiguous() + _input_size);
}
using ResultType = utils::Result<ElementwiseInfo>;
/**
* @brief Construct ElementwiseInfo from output and input tensor descriptors.
* @param output_desc Descriptor of the output tensor.
* @param input_descs Descriptors of the input tensors.
* @return Result<ElementwiseInfo> with the successfully constructed ElementwiseInfo,
* or the status code.
*/
static ResultType create(
infiniopTensorDescriptor_t output_desc,
std::vector<infiniopTensorDescriptor_t> input_descs) {
if (!output_desc || input_descs.empty()) {
return INFINI_STATUS_BAD_PARAM;
}
// Destination cannot have broadcast setup
if (output_desc->hasBroadcastDim()) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
auto input_size = input_descs.size();
auto ndim = output_desc->ndim();
auto output_size = output_desc->numel();
auto output_contiguous = output_desc->isContiguous();
// Allocate memory for meta
auto shape_unit = output_desc->dim(0);
auto stride_unit = output_desc->stride(0);
size_t meta_mem_size = ndim * (sizeof(shape_unit) + sizeof(stride_unit))
+ input_size * ndim * sizeof(shape_unit)
+ input_size * ndim * sizeof(stride_unit)
+ 2 * input_size * sizeof(bool);
std::vector<size_t> meta(CEIL_DIV(meta_mem_size, sizeof(size_t)));
int8_t *meta_ptr = reinterpret_cast<int8_t *>(meta.data());
const auto output_shape = output_desc->shape();
const auto output_strides = output_desc->strides();
// Pointers to the sections within _meta
size_t *output_shape_p = reinterpret_cast<size_t *>(meta_ptr);
ptrdiff_t *output_strides_p = reinterpret_cast<ptrdiff_t *>(output_shape_p + ndim);
size_t *input_shapes = reinterpret_cast<size_t *>(output_strides_p + ndim);
ptrdiff_t *input_strides = reinterpret_cast<ptrdiff_t *>(input_shapes + input_size * ndim);
bool *input_contiguous = reinterpret_cast<bool *>(input_strides + input_size * ndim);
bool *input_broadcasted = input_contiguous + input_size;
// Copy output shape and strides
std::memcpy(output_shape_p, output_shape.data(), ndim * sizeof(*output_shape_p));
std::memcpy(output_strides_p, output_strides.data(), ndim * sizeof(*output_strides_p));
// Copy input shapes, strides, contiguous, and broadcasted flags
for (size_t i = 0; i < input_size; ++i) {
auto &desc = input_descs[i];
const auto in_shape = desc->shape();
const auto in_strides = desc->strides();
std::memcpy(input_shapes + i * ndim, in_shape.data(), ndim * sizeof(*input_shapes));
std::memcpy(input_strides + i * ndim, in_strides.data(), ndim * sizeof(*input_strides));
input_contiguous[i] = desc->isContiguous();
input_broadcasted[i] = !input_contiguous[i] && (desc->ndim() != ndim || desc->hasBroadcastDim());
}
ElementwiseInfo info(std::move(meta), output_size, input_size, ndim, output_contiguous);
return ResultType(std::move(info));
}
};
} // namespace op::elementwise
#endif // __INFINIOP_ELEMENTWISE_H__
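As a worked example of the packing done by create() (illustrative, assuming a 64-bit host where sizeof(size_t) == sizeof(ptrdiff_t) == 8 and sizeof(bool) == 1), a binary op on 3-D tensors needs:
output shape + strides: 3 * (8 + 8) = 48 bytes
input shapes: 2 * 3 * 8 = 48 bytes
input strides: 2 * 3 * 8 = 48 bytes
contiguous + broadcast flags: 2 * 2 * 1 = 4 bytes
That is 148 bytes, rounded up to 19 size_t slots, so getMetaMemSize() reports 152 bytes; the CUDA and Kunlun descriptors then add getInputSize() * sizeof(void *) = 16 bytes of workspace for the device copy of the input-pointer array.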
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_H__
#define __INFINIOP_ELEMENTWISE_KUNLUN_H__
#include "../../../utils.h"
#include "../../devices/kunlun/kunlun_handle.h"
#include "elementwise_kunlun_api.h"
namespace op::elementwise::kunlun {
struct DeviceImpl::Opaque {
std::shared_ptr<device::kunlun::Handle::Internal> internal;
Opaque(const std::shared_ptr<device::kunlun::Handle::Internal> &internal_)
: internal(internal_) {}
template <size_t N, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
kunlunStream_t stream,
Args &&...args) {
auto output_size = info.getOutputSize();
if (output_size == 0) {
return INFINI_STATUS_SUCCESS;
}
// Device pointers
const void **d_inputs_arr = nullptr;
const bool *d_input_contiguous = nullptr;
const bool *d_input_broadcasted = nullptr;
const size_t *d_output_shape = nullptr;
const ptrdiff_t *d_output_strides = nullptr;
const size_t *d_input_shapes = nullptr;
const ptrdiff_t *d_input_strides = nullptr;
CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr,
d_input_contiguous, d_input_broadcasted,
d_output_shape, d_output_strides,
d_input_shapes, d_input_strides));
Op::template launch<Tdata>(
output_size,
info.getNdim(),
info.isOutputContiguous(),
reinterpret_cast<const void *>(d_input_contiguous),
reinterpret_cast<const void *>(d_input_broadcasted),
reinterpret_cast<const void *>(d_output_shape),
reinterpret_cast<const void *>(d_input_shapes),
reinterpret_cast<const void *>(d_output_strides),
reinterpret_cast<const void *>(d_input_strides),
output,
reinterpret_cast<const void *const *>(d_inputs_arr),
stream,
args...);
return INFINI_STATUS_SUCCESS;
}
private:
template <size_t N>
infiniStatus_t infoToDevice(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
const void *const *h_inputs_arr,
const void **&d_inputs_arr,
const bool *&d_input_contiguous,
const bool *&d_input_broadcasted,
const size_t *&d_output_shape,
const ptrdiff_t *&d_output_strides,
const size_t *&d_input_shapes,
const ptrdiff_t *&d_input_strides) const {
constexpr auto input_size = N;
const auto ndim = info.getNdim();
constexpr auto input_arr_size = N * sizeof(*h_inputs_arr);
const int8_t *info_meta_start = info.getMetaStart();
const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
// copy the input pointer array and meta to device
CHECK_KUNLUN(xpu_memcpy(workspace, h_inputs_arr, input_arr_size, XPU_HOST_TO_DEVICE));
CHECK_KUNLUN(xpu_memcpy((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), XPU_HOST_TO_DEVICE));
// offset/assign the pointers
d_inputs_arr = reinterpret_cast<const void **>(workspace);
d_output_shape = reinterpret_cast<const size_t *>(d_meta_start);
d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape + ndim);
d_input_shapes = reinterpret_cast<const size_t *>(d_output_strides + ndim);
d_input_strides = reinterpret_cast<const ptrdiff_t *>(d_input_shapes + input_size * ndim);
d_input_contiguous = reinterpret_cast<const bool *>(d_input_strides + input_size * ndim);
d_input_broadcasted = reinterpret_cast<const bool *>(d_input_contiguous + input_size);
return INFINI_STATUS_SUCCESS;
}
};
template <typename... Args>
utils::Result<DeviceImpl *> DeviceImpl::create(Args &&...args) {
auto opaque = std::make_shared<Opaque>(std::forward<Args>(args)...);
return utils::Result<DeviceImpl *>(new DeviceImpl(opaque));
}
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
return _opaque->calculateImpl<N, Op, Tdata>(
info, workspace, output, inputs,
reinterpret_cast<kunlunStream_t>(stream),
std::forward<Args>(args)...);
}
} // namespace op::elementwise::kunlun
// Template for kunlun kernel interface declaration
#define LAUNCH_ELEMENTWISE_KERNEL(OpName) \
template <typename Tdata, typename... Args> \
void launch##OpName##Kernel( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
Args... args);
#endif
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_API_H__
#define __INFINIOP_ELEMENTWISE_KUNLUN_API_H__
#include "../elementwise.h"
namespace op::elementwise::kunlun {
class DeviceImpl final {
struct Opaque;
std::shared_ptr<Opaque> _opaque;
DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
public:
~DeviceImpl() = default;
template <typename... Args>
static utils::Result<DeviceImpl *> create(Args &&...args);
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
};
} // namespace op::elementwise::kunlun
#define CREATE_ELEMENTWISE_KUNLUN_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
auto device_impl_result = op::elementwise::kunlun::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
std::move(info), \
std::move(device_impl_result.take()), \
workspace_size, \
HANDLE->device, \
HANDLE->device_id);
#endif
#ifndef __INFINIOP_ELEMENTWISE_KUNLUN_XPU__
#define __INFINIOP_ELEMENTWISE_KUNLUN_XPU__
#include "../../devices/kunlun/kunlun_kernel_common.h"
using namespace device::kunlun::kernel;
/**
* @brief Computes the input element offset, accounting for broadcasting and strided access.
*/
struct InputIndexer {
size_t idx;
size_t ndim;
const bool *input_contiguous;
const bool *input_broadcasted;
const _size_t *input_shapes;
const _ptrdiff_t *input_strides;
const _ptrdiff_t *output_strides;
__device__ size_t operator()(size_t input_id) const {
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
? indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
}
};
/**
* @brief Computes the output index in memory, accounting for strides if non-contiguous.
*
* @param idx Linear index.
* @param is_contiguous Whether the output tensor is contiguous.
* @param ndim Number of dimensions.
* @param shape Shape of the output tensor.
* @param strides Strides of the output tensor.
* @return Memory offset index.
*/
inline __device__ size_t
getOutputIndex(size_t idx,
bool is_contiguous,
size_t ndim,
const _size_t *shape,
const _ptrdiff_t *strides) {
return is_contiguous ? idx : indexToOffset(idx, ndim, shape, strides);
}
template <size_t N, typename Op, typename Tdata, typename... Args>
__device__ void launchOp(
__global_ptr__ Tdata **typed_inputs, // gm pointer
__global_ptr__ Tdata *output, // gm pointer output
Tdata *inputs_buf, // local mem buffer
size_t *input_indexes,
size_t output_index,
Args... args) {
static_assert(N == Op::num_inputs, "template N is not equal to Op::num_inputs!\n");
// Copy each operand for this element into the local buffer
#pragma unroll
for (size_t i = 0; i < N; i++) {
auto gm = typed_inputs[i] + input_indexes[i];
auto lm = inputs_buf + i;
GM2LM_ASYNC(gm, lm, 1 * sizeof(Tdata));
}
mfence();
// Compute the elementwise result; inputs_buf holds all operands for this element
Tdata out = Op{}(inputs_buf, args...);
// Copy out to gm
LM2GM_ASYNC(&out, output + output_index, 1 * sizeof(Tdata));
mfence();
}
template <size_t N, typename Op, typename Tdata, typename... Args>
__global__ void elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
const bool *input_contiguous_gm,
const bool *input_broadcasted_gm,
const _size_t *output_shape_gm,
const _size_t *input_shapes_gm,
const _ptrdiff_t *output_strides_gm,
const _ptrdiff_t *input_strides_gm,
Tdata *output,
const void *const *inputs,
Args... args) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = ncores * cluster_id() + cid;
int nthreads = ncores * cluster_num();
// Cast input gm pointer type
auto typed_inputs = reinterpret_cast<const __global_ptr__ Tdata *const __global_ptr__ *>(inputs);
const int BUFF_SIZE = 64;
// Input data cache
__local__ Tdata inputs_buf[N];
// Input contiguous/broadcasted flags
__local__ bool input_contiguous[N];
__local__ bool input_broadcasted[N];
// Input shape/strides
__local__ _size_t input_shapes[N * ndim];
__local__ _ptrdiff_t input_strides[N * ndim];
// Output shape/strides
__local__ _size_t output_shape[ndim];
__local__ _ptrdiff_t output_strides[ndim];
// Inputs gm ptr buf
__local__ __global_ptr__ Tdata *typed_inputs_ptr[N];
// Load from gm
GM2LM_ASYNC(input_contiguous_gm, input_contiguous, N * sizeof(bool));
GM2LM_ASYNC(input_broadcasted_gm, input_broadcasted, N * sizeof(bool));
GM2LM_ASYNC(input_shapes_gm, input_shapes, N * ndim * sizeof(_size_t));
GM2LM_ASYNC(input_strides_gm, input_strides, N * ndim * sizeof(_ptrdiff_t));
GM2LM_ASYNC(output_shape_gm, output_shape, ndim * sizeof(_size_t));
GM2LM_ASYNC(output_strides_gm, output_strides, ndim * sizeof(_ptrdiff_t));
GM2LM_ASYNC(typed_inputs, typed_inputs_ptr, N * sizeof(__global_ptr__ Tdata *));
mfence();
int len_per_loop = min(BUFF_SIZE, roundup_div(output_size, nthreads));
for (int start = thread_id * len_per_loop; start < output_size; start += nthreads * len_per_loop) {
size_t read_len = min(len_per_loop, output_size - start);
for (int idx = start; idx < start + read_len; ++idx) {
size_t out_idx = getOutputIndex(static_cast<size_t>(idx), output_contiguous,
ndim, output_shape, output_strides);
InputIndexer indexer{static_cast<size_t>(idx), ndim, input_contiguous, input_broadcasted,
input_shapes, input_strides, output_strides};
// Get index offset for every operand
size_t indexes[N];
for (size_t i = 0; i < N; i++) {
indexes[i] = indexer(i);
}
// Apply the operator
launchOp<N, Op, Tdata>(&typed_inputs_ptr[0], output, inputs_buf, indexes, out_idx, args...);
}
}
sync_cluster();
}
#define LAUNCH_ELEMENTWISE_KERNEL_IMPL(OpName, Op) \
template <typename Tdata, typename... Args> \
void launch##OpName##Kernel( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
Args... args) { \
elementwiseKernel<Op::num_inputs, Op, Tdata><<<8, 64, stream>>>( \
output_size, ndim, output_contiguous, \
reinterpret_cast<const bool *>(input_contiguous), \
reinterpret_cast<const bool *>(input_broadcasted), \
reinterpret_cast<const _size_t *>(output_shape), \
reinterpret_cast<const _size_t *>(input_shapes), \
reinterpret_cast<const _ptrdiff_t *>(output_strides), \
reinterpret_cast<const _ptrdiff_t *>(input_strides), \
reinterpret_cast<Tdata *>(output), inputs, args...); \
}
#define LAUNCH_ELEMENTWISE_KERNEL_INSTANTIATE(OpName, T, ...) \
template void launch##OpName##Kernel<T, ##__VA_ARGS__>( \
size_t output_size, \
size_t ndim, \
bool output_contiguous, \
const void *input_contiguous, \
const void *input_broadcasted, \
const void *output_shape, \
const void *input_shapes, \
const void *output_strides, \
const void *input_strides, \
void *output, \
const void *const *inputs, \
XPUStream stream, \
##__VA_ARGS__);
#endif
#include "add_cpu.h"
namespace op::add::cpu {
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
auto dtype = out_desc->dtype();
const auto &a_desc = input_desc_vec.at(0);
const auto &b_desc = input_desc_vec.at(1);
const auto &c_shape = out_desc->shape();
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
// create CPU elementwise descriptor
CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<AddOp, fp16_t>(_info, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<AddOp, float>(_info, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<AddOp, double>(_info, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::add::cpu
#ifndef __ADD_CPU_H__
#define __ADD_CPU_H__
#include "../../../elementwise/cpu/elementwise_cpu.h"
ELEMENTWISE_DESCRIPTOR(add, cpu)
namespace op::add::cpu {
typedef struct AddOp {
public:
static constexpr size_t num_inputs = 2;
template <typename T>
T operator()(const T &a, const T &b) const {
return a + b;
}
} AddOp;
} // namespace op::add::cpu
#endif // __ADD_CPU_H__
#include "add_cuda.cuh"
#include "add_cuda_internal.cuh"
namespace op::add::cuda {
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::cuda::Handle *>(handle_);
auto dtype = out_desc->dtype();
const auto &a_desc = input_desc_vec.at(0);
const auto &b_desc = input_desc_vec.at(1);
const auto &c_shape = out_desc->shape();
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);
// create CUDA elementwise descriptor
CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<256, AddOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<256, AddOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<256, AddOp, double>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::add::cuda
#ifndef __ADD_CUDA_API_H__
#define __ADD_CUDA_API_H__
#include "../../../elementwise/cuda/elementwise_cuda_api.cuh"
ELEMENTWISE_DESCRIPTOR(add, cuda)
#endif // __ADD_CUDA_API_H__
#ifndef __ADD_CUDA_H__
#define __ADD_CUDA_H__
#include "../../../elementwise/cuda/elementwise_cuda.cuh"
#include <cuda_fp16.h>
namespace op::add::cuda {
typedef struct AddOp {
public:
static constexpr size_t num_inputs = 2;
template <typename T>
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
        if constexpr (std::is_same_v<T, half2>) {
            return __hadd2(a, b); // packed fp16 add (two half lanes at once)
        } else if constexpr (std::is_same_v<T, half>) {
            return __hadd(a, b); // scalar fp16 add intrinsic
        } else if constexpr (std::is_same_v<T, float>) {
            return __fadd_rd(a, b); // fp32 add intrinsic, rounding toward negative infinity
        } else {
            return a + b; // fall back to the built-in operator (e.g. double)
        }
}
} AddOp;
} // namespace op::add::cuda
#endif // __ADD_CUDA_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/add.h"
#ifdef ENABLE_CPU_API
#include "cpu/add_cpu.h"
#endif
#ifdef ENABLE_CUDA_API
#include "cuda/add_cuda.cuh"
#endif
__C infiniStatus_t infiniopCreateAddDescriptor(
infiniopHandle_t handle,
infiniopAddDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::add::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::add::NAMESPACE::Descriptor **>(desc_ptr), \
c_desc, \
{a_desc, \
b_desc})
switch (handle->device) {
#ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
CREATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
}
__C infiniStatus_t infiniopGetAddWorkspaceSize(infiniopAddDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::add::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu)
#endif
#ifdef ENABLE_CUDA_API
GET(INFINI_DEVICE_NVIDIA, cuda)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopAdd(
infiniopAddDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
const void *a,
const void *b,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::add::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, c, {a, b}, stream)
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CALCULATE
}
__C infiniStatus_t
infiniopDestroyAddDescriptor(infiniopAddDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::add::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
DELETE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_CUDA_API
DELETE(INFINI_DEVICE_NVIDIA, cuda);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DELETE
}
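// ---------------------------------------------------------------------------
// Illustrative usage sketch, not part of the library sources: it shows how a
// caller is expected to drive the Add C API defined above. The helper name
// `runAddExample` is made up for this sketch; the handle, tensor descriptors,
// device buffers and workspace are assumed to have been created elsewhere.
// Only functions and status codes that appear in this commit are used.
// ---------------------------------------------------------------------------
static infiniStatus_t runAddExample(
    infiniopHandle_t handle,
    infiniopTensorDescriptor_t c_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    void *c, const void *a, const void *b,
    void *workspace, size_t workspace_cap,
    void *stream) {
    infiniopAddDescriptor_t desc = nullptr;
    infiniStatus_t status = infiniopCreateAddDescriptor(handle, &desc, c_desc, a_desc, b_desc);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }
    size_t workspace_size = 0;
    status = infiniopGetAddWorkspaceSize(desc, &workspace_size);
    if (status == INFINI_STATUS_SUCCESS && workspace_size > workspace_cap) {
        status = INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    if (status == INFINI_STATUS_SUCCESS) {
        // c = a + b, elementwise, on whichever device the handle targets.
        status = infiniopAdd(desc, workspace, workspace_size, c, a, b, stream);
    }
    infiniopDestroyAddDescriptor(desc);
    return status;
}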
#ifndef ATTENTION_H
#define ATTENTION_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::attention::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc); \
}; \
}
#endif // ATTENTION_H
#include "../../operator.h"
#include "../../../utils.h"
#include "../../../utils/check.h"
#include "../../handle.h"
#include "../../tensor.h"
#include "infiniop/ops/attention.h"
#include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/rearrange.h"
#include <cmath>
#include <cstdint>
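// The attention operator is composed from existing primitives rather than a
// fused kernel:
//   1. rearrange k / v into the caches at row `pos`
//   2. matmul1: att_score = qk_alpha * q @ k_cache^T   (per kv head, grouped queries)
//   3. causal softmax over att_score, computed in place
//   4. matmul2: att_val = att_score @ v_cache
//   5. rearrange att_val into the [seq_len, n_q_head, head_dim] output
// The descriptor below stores the sub-operator descriptors plus byte offsets
// into one shared workspace buffer.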
struct InfiniopAttentionDescriptor {
InfiniopDescriptor _super;
infiniopRearrangeDescriptor_t rearrange_desc_k;
infiniopRearrangeDescriptor_t rearrange_desc_v;
infiniopRearrangeDescriptor_t rearrange_desc_q;
infiniopRearrangeDescriptor_t rearrange_desc_out;
infiniopGemmDescriptor_t matmul_desc1;
infiniopGemmDescriptor_t matmul_desc2;
infiniopCausalSoftmaxDescriptor_t softmax_desc;
size_t workspace_size;
size_t op_workspace_offset;
size_t op_workspace_size;
size_t q_cont_offset;
size_t att_score_offset;
size_t att_val_offset;
size_t k_cache_offset;
size_t v_cache_offset;
float qk_alpha;
};
__C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t handle,
infiniopAttentionDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
size_t pos) {
if (out_desc->ndim() != 3 || q_desc->ndim() != 3 || k_desc->ndim() != 3 || v_desc->ndim() != 3 || k_cache_desc->ndim() != 3 || v_cache_desc->ndim() != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (!out_desc->isContiguous(0, 2)) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
if (q_desc->strides()[2] != 1 || k_desc->strides()[2] != 1 || v_desc->strides()[2] != 1 || k_cache_desc->strides()[2] != 1 || v_cache_desc->strides()[2] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
size_t n_q_head = q_desc->shape()[0];
size_t seq_len = q_desc->shape()[1];
size_t head_dim = q_desc->shape()[2];
size_t hidden_size = n_q_head * head_dim;
size_t n_kv_head = k_desc->shape()[0];
size_t total_seq_len = seq_len + pos;
    size_t n_group = n_q_head / n_kv_head; // query heads per kv head (grouped-query attention)
    size_t alignment = 256;                // byte alignment used when partitioning the workspace
if (out_desc->shape()[0] != seq_len || out_desc->shape()[1] != n_q_head || out_desc->shape()[2] != head_dim) {
return INFINI_STATUS_BAD_PARAM;
}
// k: [n_kv_head, seq_len, head_dim]
if (k_desc->shape()[0] != n_kv_head || k_desc->shape()[1] != seq_len || k_desc->shape()[2] != head_dim) {
return INFINI_STATUS_BAD_PARAM;
}
// v: [n_kv_head, seq_len, head_dim]
if (v_desc->shape()[0] != n_kv_head || v_desc->shape()[1] != seq_len || v_desc->shape()[2] != head_dim) {
return INFINI_STATUS_BAD_PARAM;
}
// k_cache: [n_kv_head, _, head_dim]
if (k_cache_desc->shape()[0] != n_kv_head || k_cache_desc->shape()[1] < total_seq_len || k_cache_desc->shape()[2] != head_dim) {
return INFINI_STATUS_BAD_PARAM;
}
// v_cache: [n_kv_head, _, head_dim]
if (v_cache_desc->shape()[0] != n_kv_head || v_cache_desc->shape()[1] < total_seq_len || v_cache_desc->shape()[2] != head_dim) {
return INFINI_STATUS_BAD_PARAM;
}
// Rearrange k into k_cache
infiniopTensorDescriptor_t dst_k_desc;
CHECK_STATUS(infiniopCreateTensorDescriptor(&dst_k_desc, 3, k_desc->shape().data(), k_cache_desc->strides().data(), k_cache_desc->dtype()));
infiniopRearrangeDescriptor_t rearrange_desc_k;
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_k, dst_k_desc, k_desc));
// Rearrange v into v_cache
infiniopTensorDescriptor_t dst_v_desc;
CHECK_STATUS(infiniopCreateTensorDescriptor(&dst_v_desc, 3, v_desc->shape().data(), v_cache_desc->strides().data(), v_cache_desc->dtype()));
infiniopRearrangeDescriptor_t rearrange_desc_v;
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_v, dst_v_desc, v_desc));
infiniopRearrangeDescriptor_t rearrange_desc_q = nullptr;
size_t q_cont_size = 0;
infiniopTensorDescriptor_t rearranged_q_desc;
// Rearrange q into contiguous
if (!q_desc->isContiguous(0, 1)) {
CHECK_STATUS(infiniopCreateTensorDescriptor(&rearranged_q_desc, 3, q_desc->shape().data(), nullptr, q_desc->dtype()));
q_cont_size = utils::align(rearranged_q_desc->numel() * infiniSizeOf(rearranged_q_desc->dtype()), alignment);
        CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_q, rearranged_q_desc, q_desc));
}
// Matmul1: q * full_k
    // q: [n_q_head, seq_len, head_dim] -> [n_kv_head, n_group * seq_len, head_dim]
infiniopTensorDescriptor_t reshaped_q_desc;
CHECK_STATUS(infiniopCreateTensorDescriptor(&reshaped_q_desc, 3, q_desc->shape().data(), nullptr, q_desc->dtype()));
TRANSFORM_TENSOR_DESC(reshaped_q_desc, dimSplit(0, {n_kv_head, n_group}));
TRANSFORM_TENSOR_DESC(reshaped_q_desc, dimMerge(1, 2));
    // full_k: strided view over k_cache, permuted to [n_kv_head, head_dim, total_seq_len] so matmul1 computes q @ k^T
infiniopTensorDescriptor_t full_k_desc;
size_t full_k_shape[3] = {n_kv_head, total_seq_len, head_dim};
CHECK_STATUS(infiniopCreateTensorDescriptor(&full_k_desc, 3, full_k_shape, k_cache_desc->strides().data(), k_cache_desc->dtype()));
TRANSFORM_TENSOR_DESC(full_k_desc, dimPermute({0, 2, 1}));
// qk: [n_kv_head, n_group * seq_len, total_seq_len]
infiniopTensorDescriptor_t qk_desc;
size_t qk_shape[3] = {n_kv_head, n_group * seq_len, total_seq_len};
CHECK_STATUS(infiniopCreateTensorDescriptor(&qk_desc, 3, qk_shape, nullptr, q_desc->dtype()));
// matmul1_desc
    // scaled dot-product attention scale factor: qk_alpha = 1 / sqrt(head_dim)
    float qk_alpha = 1.f / std::sqrt(float(head_dim));
infiniopGemmDescriptor_t matmul1_desc;
CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul1_desc, qk_desc, reshaped_q_desc, full_k_desc));
// matmul1 workspace size
size_t matmul1_workspace_size;
CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul1_desc, &matmul1_workspace_size));
matmul1_workspace_size = utils::align(matmul1_workspace_size, alignment);
// attention score tensor size
size_t attn_score_size = utils::align(qk_desc->numel() * infiniSizeOf(qk_desc->dtype()), alignment);
// CausalSoftmax: softmax(qk)
// qk: [n_kv_head, n_group * seq_len, total_seq_len] -> [n_q_head, seq_len, total_seq_len]
TRANSFORM_TENSOR_DESC(qk_desc, dimSplit(1, {n_group, seq_len}));
TRANSFORM_TENSOR_DESC(qk_desc, dimMerge(0, 1));
infiniopCausalSoftmaxDescriptor_t softmax_desc;
CHECK_STATUS(infiniopCreateCausalSoftmaxDescriptor(handle, &softmax_desc, qk_desc, qk_desc));
// softmax workspace size
size_t softmax_workspace_size;
CHECK_STATUS(infiniopGetCausalSoftmaxWorkspaceSize(softmax_desc, &softmax_workspace_size));
softmax_workspace_size = utils::align(softmax_workspace_size, alignment);
// Matmul2: softmax(qk) * full_v
// softmax(qk): [n_q_head, seq_len, total_seq_len] -> [n_kv_head, n_group * seq_len, total_seq_len]
// full_v: [n_kv_head, total_seq_len, head_dim]
TRANSFORM_TENSOR_DESC(qk_desc, dimSplit(0, {n_kv_head, n_group}));
TRANSFORM_TENSOR_DESC(qk_desc, dimMerge(1, 2));
infiniopTensorDescriptor_t full_v_desc;
size_t full_v_shape[3] = {n_kv_head, total_seq_len, head_dim};
CHECK_STATUS(infiniopCreateTensorDescriptor(&full_v_desc, 3, full_v_shape, v_cache_desc->strides().data(), v_cache_desc->dtype()));
// temp_out: [n_kv_head, n_group * seq_len, head_dim]
infiniopTensorDescriptor_t att_val_desc;
size_t temp_out_shape[3] = {n_kv_head, n_group * seq_len, head_dim};
CHECK_STATUS(infiniopCreateTensorDescriptor(&att_val_desc, 3, temp_out_shape, nullptr, q_desc->dtype()));
// matmul2_desc
infiniopGemmDescriptor_t matmul2_desc;
CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul2_desc, att_val_desc, qk_desc, full_v_desc));
// matmul2 workspace size
size_t matmul2_workspace_size;
CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul2_desc, &matmul2_workspace_size));
matmul2_workspace_size = utils::align(matmul2_workspace_size, alignment);
// attention value tensor size
size_t att_val_size = utils::align(att_val_desc->numel() * infiniSizeOf(att_val_desc->dtype()), alignment);
// Rearrange temp_out into out
// out: [seq_len, n_q_head, head_dim]
// temp_out: [n_kv_head, n_group * seq_len, head_dim] -> [n_q_head, seq_len, head_dim] -> [seq_len, n_q_head, head_dim]
TRANSFORM_TENSOR_DESC(att_val_desc, dimSplit(1, {n_group, seq_len}));
TRANSFORM_TENSOR_DESC(att_val_desc, dimMerge(0, 1));
TRANSFORM_TENSOR_DESC(att_val_desc, dimPermute({1, 0, 2}));
infiniopRearrangeDescriptor_t rearrange_desc_out;
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_out, out_desc, att_val_desc));
// workspace size
size_t op_workspace_size = utils::align(std::max(std::max(matmul1_workspace_size, matmul2_workspace_size), softmax_workspace_size), alignment);
size_t temp_tensors_size = attn_score_size + std::max(q_cont_size, att_val_size);
size_t workspace_size = temp_tensors_size + op_workspace_size;
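    // Layout of the workspace buffer handed to infiniopAttention:
    //   [0, attn_score_size)                   attention scores (qk)
    //   [attn_score_size, temp_tensors_size)   contiguous q, later reused for att_val
    //   [temp_tensors_size, workspace_size)    scratch shared by the gemm / softmax calls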
    // byte offsets of row `pos` in the caches, where the incoming k/v get appended
    size_t k_cache_offset = 0;
    if (pos > 0) {
        k_cache_offset = pos * k_cache_desc->getByteStrides()[1];
    }
    size_t v_cache_offset = 0;
    if (pos > 0) {
        v_cache_offset = pos * v_cache_desc->getByteStrides()[1];
    }
// create attention descriptor
    *(InfiniopAttentionDescriptor **)desc_ptr = new InfiniopAttentionDescriptor{
        {handle->device, handle->device_id},
        rearrange_desc_k,
        rearrange_desc_v,
        rearrange_desc_q,
        rearrange_desc_out,
        matmul1_desc,
        matmul2_desc,
        softmax_desc,
        workspace_size,
        temp_tensors_size, // op_workspace_offset
        op_workspace_size,
        attn_score_size,   // q_cont_offset
        0,                 // att_score_offset
        attn_score_size,   // att_val_offset (reuses the q_cont region)
        k_cache_offset,
        v_cache_offset,
        qk_alpha,
    };
return INFINI_STATUS_SUCCESS;
}
__C __export infiniStatus_t infiniopGetAttentionWorkspaceSize(infiniopAttentionDescriptor_t desc, size_t *size) {
*size = ((InfiniopAttentionDescriptor *)desc)->workspace_size;
return INFINI_STATUS_SUCCESS;
}
__C __export infiniStatus_t infiniopAttention(infiniopAttentionDescriptor_t desc_,
void *workspace_,
size_t workspace_size_,
void *out,
void const *q,
void const *k,
void const *v,
void *k_cache,
void *v_cache,
void *stream) {
auto desc = (InfiniopAttentionDescriptor *)desc_;
if (workspace_size_ < desc->workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
void *workspace = (char *)workspace_ + desc->op_workspace_offset;
size_t workspace_size = desc->op_workspace_size;
void *att_score = (char *)workspace_ + desc->att_score_offset;
void *att_val = (char *)workspace_ + desc->att_val_offset;
void const *q_ = q;
// concat k and v to k_cache and v_cache
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_k,
(char *)k_cache + desc->k_cache_offset, k, stream));
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_v,
(char *)v_cache + desc->v_cache_offset, v, stream));
// rearrange q into contiguous
if (desc->rearrange_desc_q) {
void *q_cont = (char *)workspace_ + desc->q_cont_offset;
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_q, q_cont, q, stream));
q_ = q_cont;
}
// matmul1: q * full_k
CHECK_STATUS(infiniopGemm(desc->matmul_desc1,
workspace, workspace_size,
att_score, q_, k_cache, desc->qk_alpha, 0.0, stream));
// softmax(qk)
CHECK_STATUS(infiniopCausalSoftmax(desc->softmax_desc,
workspace, workspace_size,
att_score, att_score, stream));
// matmul2: softmax(qk) * full_v
CHECK_STATUS(infiniopGemm(desc->matmul_desc2,
workspace, workspace_size,
att_val, att_score, v_cache, 1.0, 0.0, stream));
// rearrange out
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_out, out, att_val, stream));
return INFINI_STATUS_SUCCESS;
}
__C __export infiniStatus_t infiniopDestroyAttentionDescriptor(infiniopAttentionDescriptor_t desc_) {
auto desc = (InfiniopAttentionDescriptor *)desc_;
if (desc->rearrange_desc_q) {
CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_q));
}
CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_k));
CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_v));
CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_out));
CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->matmul_desc1));
CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->matmul_desc2));
CHECK_STATUS(infiniopDestroyCausalSoftmaxDescriptor(desc->softmax_desc));
delete desc;
return INFINI_STATUS_SUCCESS;
}
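// ---------------------------------------------------------------------------
// Illustrative usage sketch, not part of the library sources: one decode step
// through the attention C API defined above. The helper name
// `runAttentionStep` is made up for this sketch; the handle, the six tensor
// descriptors, the device buffers and the workspace are assumed to exist, and
// `pos` is the number of tokens already stored in the k/v caches.
// ---------------------------------------------------------------------------
static infiniStatus_t runAttentionStep(
    infiniopHandle_t handle,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    size_t pos,
    void *out, const void *q, const void *k, const void *v,
    void *k_cache, void *v_cache,
    void *workspace, size_t workspace_cap,
    void *stream) {
    infiniopAttentionDescriptor_t desc = nullptr;
    CHECK_STATUS(infiniopCreateAttentionDescriptor(handle, &desc, out_desc, q_desc, k_desc, v_desc,
                                                   k_cache_desc, v_cache_desc, pos));
    size_t workspace_size = 0;
    CHECK_STATUS(infiniopGetAttentionWorkspaceSize(desc, &workspace_size));
    if (workspace_size > workspace_cap) {
        infiniopDestroyAttentionDescriptor(desc);
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }
    // Appends k/v to the caches at row `pos`, then runs gemm -> softmax -> gemm -> rearrange.
    CHECK_STATUS(infiniopAttention(desc, workspace, workspace_size, out, q, k, v, k_cache, v_cache, stream));
    return infiniopDestroyAttentionDescriptor(desc);
}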