Unverified Commit 95fd5c1b authored by PanZezhong1725, committed by GitHub

Merge pull request #156 from InfiniTensor/issue/127_add_general_cuda_elementwise

Issue/127/feat. General Elementwise Framework with Refactored SwiGLU (CPU & CUDA)
parents b985bc5e da881f4d
...@@ -11,7 +11,11 @@ __C __export infiniStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t hand
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc);
+ __C __export infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc,
+ void *workspace,
+ size_t workspace_size,
void *c,
void const *a,
void const *b,
......
...@@ -9,6 +9,10 @@
#define CUDA_BLOCK_SIZE_1024 1024
#define CUDA_BLOCK_SIZE_512 512
+ #define CHECK_CUDA(API) CHECK_INTERNAL(API, cudaSuccess)
+ namespace device::cuda {
// return the memory offset of original tensor, given the flattened index of broadcasted tensor
__forceinline__ __device__ __host__ size_t
indexToReducedOffset(
...@@ -38,6 +42,7 @@ indexToOffset(
}
return res;
}
+ } // namespace device::cuda
#ifdef ENABLE_CUDA_API
#include <cuda_fp16.h>
......
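The body of indexToReducedOffset is elided in this hunk; as a reading aid, here is a hedged sketch of the index arithmetic such a helper performs (an assumption, not the committed code; the name indexToReducedOffsetSketch is illustrative only):

// Hedged sketch: map a flat index over the broadcasted output layout back to an
// offset in the original tensor. Dimensions that were broadcast carry stride 0
// in target_strides and therefore contribute nothing to the offset.
__forceinline__ __device__ __host__ size_t
indexToReducedOffsetSketch(size_t flat_index, size_t ndim,
                           const ptrdiff_t *broadcasted_strides,
                           const ptrdiff_t *target_strides) {
    size_t res = 0;
    for (size_t i = 0; i < ndim; ++i) {
        res += flat_index / broadcasted_strides[i] * target_strides[i];
        flat_index %= broadcasted_strides[i];
    }
    return res;
}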
#ifndef __INFINIOP_ELEMENTWISE_CPU_H__
#define __INFINIOP_ELEMENTWISE_CPU_H__
#include "../../devices/cpu/common_cpu.h"
#include "../elementwise.h"
#include <utility>
/**
* @brief Define the process for initializing a Descriptor of an elementwise operation
* for its CPU implementation
*
* @param HANDLE The device handle.
* @param DTYPE The output dtype.
* @param OUT_DESC The output tensor descriptor.
* @param INPUT_DESC_VEC A vector containing input tensor descriptors.
*/
#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
info_result.take(), \
nullptr, \
0, \
HANDLE->device, \
HANDLE->device_id);
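A hedged usage sketch of the macro follows, mirroring the SwiGLU CPU create() shown later in this diff; note that the macro assigns to a `desc_ptr` that must already be in scope in the enclosing function:

// Hedged sketch; Descriptor is whichever operator's CPU Descriptor is being created.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
    auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
    auto dtype = out_desc->dtype();
    // ... operator-specific dtype and shape checks go here ...
    CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);
    return INFINI_STATUS_SUCCESS;
}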
namespace op::elementwise::cpu {
/**
* @brief CPU-specific device implementation that manages backend resources and
* provides the calculation routines.
*
* This class encapsulates device-specific behavior and execution logic.
* Use the static create() method to instantiate a DeviceImpl.
*/
class DeviceImpl final {
struct Opaque;
std::shared_ptr<Opaque> _opaque;
DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
public:
~DeviceImpl() = default;
template <typename... Args>
static utils::Result<DeviceImpl> create(Args &&...args);
/**
* @brief Dispatches an elementwise operation with uniform input types.
*
* @tparam Op The elementwise operation to perform.
* @tparam Tdata The common data type of all inputs and output.
* @tparam Args Additional backend-specific arguments.
* @param info Precomputed tensor metadata (shapes, strides, etc.).
* @param output Pointer to the output tensor buffer.
* @param inputs Vector of input tensor data pointers.
* @param stream Device execution stream.
* @param args Additional backend-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
/**
* @brief Dispatches an elementwise operation with heterogeneous input types.
*
* Supports operations where each input may have a different type, as defined by Op.
* The number of input types must match the operation's expected input count.
*
* @tparam Op The elementwise operation to perform.
* @tparam Tout Output data type.
* @tparam Tin Variadic input data types.
* @tparam Args Additional backend-specific arguments.
* @param info Precomputed tensor metadata (shapes, strides, etc.).
* @param output Pointer to the output tensor buffer.
* @param inputs Vector of input tensor data pointers.
* @param stream Device execution stream.
* @param args Additional backend-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
template <typename Op, typename Tout, typename... Tin,
typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
};
// Define the Opaque struct for CPU, which is empty
struct DeviceImpl::Opaque {};
template <typename... Args>
utils::Result<DeviceImpl> DeviceImpl::create(Args &&...args) {
return utils::Result<DeviceImpl>(nullptr);
}
// Perform elementwise operation for different input types
template <typename Op, typename Tout, typename... Tin, size_t... Is, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate_impl(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
Args &&...args) {
Tout *out = reinterpret_cast<Tout *>(output);
std::tuple<const Tin *...> input_ptrs = {reinterpret_cast<const Tin *>(inputs[Is])...};
ptrdiff_t output_size = info.getOutputSize();
#pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.isOutputContiguous()
? i
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());
auto get_input_idx = [&](size_t input_id) {
return info.getInputContiguous()[input_id]
? i
: (info.getInputBroadcasted()[input_id]
? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
};
out[out_idx] = utils::cast<Tout>(
Op{}.template operator()<Tout, Tin...>(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...));
}
}
// Invoke elementwise operation for different input types
template <typename Op, typename Tout, typename... Tin, typename... Args, std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch");
calculate_impl<Op, Tout, Tin...>(info, output, inputs, std::make_index_sequence<sizeof...(Tin)>{}, std::forward<Args>(args)...);
return INFINI_STATUS_SUCCESS;
}
// Perform elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, size_t... Is, typename... Args>
void calculate_impl(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
Args &&...args) {
Tdata *out = reinterpret_cast<Tdata *>(output);
std::array<const Tdata *, sizeof...(Is)> ins = {reinterpret_cast<const Tdata *>(inputs[Is])...};
const ptrdiff_t output_size = info.getOutputSize();
#pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.isOutputContiguous()
? i
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());
auto get_input_idx = [&](size_t input_id) {
return info.getInputContiguous()[input_id]
? i
: (info.getInputBroadcasted()[input_id]
? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
};
if constexpr (std::is_same_v<Tdata, fp16_t>) {
out[out_idx] = utils::cast<fp16_t>(Op{}(utils::cast<float>(ins[Is][get_input_idx(Is)])..., std::forward<Args>(args)...));
} else {
out[out_idx] = Op{}(ins[Is][get_input_idx(Is)]..., std::forward<Args>(args)...);
}
}
}
// Invoke elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...);
return INFINI_STATUS_SUCCESS;
}
} // namespace op::elementwise::cpu
#endif // __INFINIOP_ELEMENTWISE_CPU_H__
#ifndef __INFINIOP_ELEMENTWISE_CUDA_H__
#define __INFINIOP_ELEMENTWISE_CUDA_H__
#include "../../../utils.h"
#include "../../devices/cuda/cuda_common.cuh"
#include "../../devices/cuda/cuda_kernel_common.cuh"
#include "elementwise_cuda_api.cuh"
namespace op::elementwise::cuda {
/**
* @brief Casts an untyped device pointer to a typed pointer of type T.
*
* @tparam T Desired pointer type.
*
* @param ptr Untyped pointer.
* @return Pointer of type const T*.
*/
template <typename T>
__device__ __forceinline__ const T *typedInputPtr(const void *ptr) {
return reinterpret_cast<const T *>(ptr);
}
/**
* @brief Computes the output index in memory, accounting for strides if non-contiguous.
*
* @param idx Linear index.
* @param is_contiguous Whether the output tensor is contiguous.
* @param ndim Number of dimensions.
* @param shape Shape of the output tensor.
* @param strides Strides of the output tensor.
* @return Memory offset index.
*/
__device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim,
const size_t *shape, const ptrdiff_t *strides) {
return is_contiguous ? idx : device::cuda::indexToOffset(idx, ndim, shape, strides);
}
/**
* @brief Computes input element offset for broadcasting and strided access.
*
* Used to map a linear output index to the corresponding index in an input tensor,
* considering contiguity and broadcasting.
*/
struct InputIndexer {
size_t idx;
size_t ndim;
const bool *input_contiguous;
const bool *input_broadcasted;
const size_t *input_shapes;
const ptrdiff_t *input_strides;
const ptrdiff_t *output_strides;
/**
* @brief Computes the memory offset for a given input tensor at current index.
*
* @param input_id ID of the input tensor.
* @return Offset into the input tensor.
*/
__device__ __forceinline__ size_t operator()(size_t input_id) const {
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: device::cuda::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
}
};
/**
* @brief Invokes a callable with compile-time index constants.
*
* Used to unpack index sequence for variadic template processing of inputs.
*
* @tparam F Callable type.
* @tparam Is Compile-time index sequence.
*
* @param f Callable to invoke with index constants.
*/
template <typename F, size_t... Is>
__device__ __forceinline__ void unpackInputsAndApply(F &&f, std::index_sequence<Is...>) {
f(std::integral_constant<size_t, Is>{}...);
}
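As a reading aid (illustrative, not part of the diff), the unpacking boils down to handing the lambda the compile-time constants 0..N-1, so each Is.value in the kernels below is a constexpr input index:

// Hedged illustration: for N = 2 inputs, unpackInputsAndApply(f, std::make_index_sequence<2>{})
// calls f(std::integral_constant<size_t, 0>{}, std::integral_constant<size_t, 1>{}).
static_assert(std::is_same_v<std::make_index_sequence<2>, std::index_sequence<0, 1>>,
              "two inputs unpack to compile-time indices 0 and 1");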
/**
* @brief CUDA kernel for performing elementwise operations on tensors where all inputs share the same data type.
*
* @tparam N Number of input tensors.
* @tparam Op Operator type implementing operator()(Tdata...).
* @tparam Tdata Common data type for inputs and output.
* @tparam Args Additional arguments to pass to the operator.
*
* @param output_size Total number of output elements.
* @param ndim Number of dimensions in tensors.
* @param output_contiguous Whether the output tensor is contiguous in memory.
* @param input_contiguous Array indicating if each input tensor is contiguous.
* @param input_broadcasted Array indicating if each input tensor is broadcasted.
* @param output_shape Shape of the output tensor.
* @param input_shapes Shapes of the input tensors.
* @param output_strides Strides for the output tensor.
* @param input_strides Strides for each input tensor.
* @param output Output buffer.
* @param inputs Array of input pointers, all of type Tdata.
* @param offset Linear offset to support partitioned execution.
* @param args Additional arguments passed to the operator.
*/
template <size_t N, typename Op, typename Tdata, typename... Args>
INFINIOP_CUDA_KERNEL elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
const bool *__restrict__ input_contiguous,
const bool *__restrict__ input_broadcasted,
const size_t *__restrict__ output_shape,
const size_t *__restrict__ input_shapes,
const ptrdiff_t *__restrict__ output_strides,
const ptrdiff_t *__restrict__ input_strides,
Tdata *output,
const void *const *inputs,
size_t offset,
Args... args) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;
if (idx < output_size) {
const Tdata *const *typed_inputs = reinterpret_cast<const Tdata *const *>(inputs);
size_t out_idx = getOutputIndex(idx, output_contiguous, ndim, output_shape, output_strides);
InputIndexer indexer{idx, ndim, input_contiguous, input_broadcasted, input_shapes, input_strides, output_strides};
unpackInputsAndApply(
[&](auto... Is) {
output[out_idx] = Op{}(typed_inputs[Is.value][indexer(Is.value)]..., std::forward<Args>(args)...);
},
std::make_index_sequence<N>{});
}
}
/**
* @brief CUDA kernel for performing an elementwise operation on tensors with support
* for broadcasting and mixed data types.
*
* @tparam Op Operator type implementing a templated operator() for (Tout, Tin...).
* @tparam Tout Output data type.
* @tparam Tin Variadic input data types.
*
* @param output_size Total number of output elements.
* @param ndim Number of dimensions in the tensors.
* @param output_contiguous Whether the output tensor is contiguous.
* @param input_contiguous Array indicating whether each input is contiguous.
* @param input_broadcasted Array indicating whether each input is broadcasted.
* @param output_shape Shape of the output tensor.
* @param input_shapes Shapes of the input tensors.
* @param output_strides Strides of the output tensor.
* @param input_strides Strides of the input tensors.
* @param output Pointer to the output buffer.
* @param inputs Array of untyped input pointers.
* @param offset Linear offset into the output for partitioned execution.
*/
template <typename Op, typename Tout, typename... Tin>
INFINIOP_CUDA_KERNEL elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
const bool *__restrict__ input_contiguous,
const bool *__restrict__ input_broadcasted,
const size_t *__restrict__ output_shape,
const size_t *__restrict__ input_shapes,
const ptrdiff_t *__restrict__ output_strides,
const ptrdiff_t *__restrict__ input_strides,
Tout *output,
const void *const *__restrict__ inputs,
size_t offset) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;
if (idx < output_size) {
size_t out_idx = getOutputIndex(idx, output_contiguous, ndim, output_shape, output_strides);
InputIndexer indexer{idx, ndim, input_contiguous, input_broadcasted, input_shapes, input_strides, output_strides};
unpackInputsAndApply(
[&](auto... Is) {
output[out_idx] = Op{}.template operator()<Tout, Tin...>(
(typedInputPtr<Tin>(inputs[Is.value])[indexer(Is.value)])...);
},
std::index_sequence_for<Tin...>{});
}
}
struct DeviceImpl::Opaque {
std::shared_ptr<device::cuda::Handle::Internal> internal;
Opaque(const std::shared_ptr<device::cuda::Handle::Internal> &internal)
: internal(internal) {}
/**
* @brief Executes an elementwise operation where all inputs and the output share the same data type.
*
* @tparam BLOCK_SIZE CUDA block size used for kernel launch.
* @tparam N Number of input tensors.
* @tparam Op Functor representing the elementwise operation.
* @tparam Tdata Data type of both input and output tensors.
* @tparam Args Optional additional arguments passed to the operation.
*
* @param info Metadata about the operation including shape, size, and dimensionality.
* @param workspace Temporary workspace used for storing metadata on device.
* @param output Pointer to the output buffer.
* @param inputs Vector of pointers to input buffers.
* @param stream CUDA stream for asynchronous execution.
* @param args Additional arguments forwarded to the operation.
* @return infiniStatus_t Returns success or failure status.
*/
template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
cudaStream_t stream,
Args &&...args) {
return launchElementwiseKernel<BLOCK_SIZE, N>(
info, workspace,
reinterpret_cast<Tdata *>(output), inputs,
elementwiseKernel<N, Op, Tdata, Args...>,
stream,
std::forward<Args>(args)...);
}
/**
* @brief Executes an elementwise operation with mixed input and output data types.
*
* @tparam BLOCK_SIZE CUDA block size used for kernel launch.
* @tparam N Number of input tensors.
* @tparam Op Functor representing the elementwise operation.
* @tparam Tout Data type of the output tensor.
* @tparam Tin... Data types of the input tensors.
* @tparam Args Optional additional arguments passed to the operation.(UNUSED)
*
* @param info Metadata about the operation including shape, size, and dimensionality.
* @param workspace Temporary workspace used for storing metadata on device.
* @param output Pointer to the output buffer.
* @param inputs Vector of pointers to input buffers.
* @param stream CUDA stream for asynchronous execution.
* @param args Additional arguments forwarded to the operation.
* @return infiniStatus_t Returns success or failure status.
*/
template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tout, typename... Tin, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
cudaStream_t stream,
Args &&...args) {
return launchElementwiseKernel<BLOCK_SIZE, N>(
info, workspace,
reinterpret_cast<Tout *>(output), inputs,
elementwiseKernel<Op, Tout, Tin...>,
stream);
}
private:
/**
* @brief Transfers elementwise operation metadata and input pointers from host to device memory.
*
* @tparam N Number of input tensors.
*
* @param info Elementwise operation metadata (shapes, strides, flags, etc.).
* @param workspace Pointer to device workspace memory for storing metadata and input pointers.
* @param h_inputs_arr Host array of input tensor pointers.
* @param d_inputs_arr Output reference to device array of input tensor pointers.
* @param d_input_contiguous Output reference to device array indicating whether each input is contiguous.
* @param d_input_broadcasted Output reference to device array indicating whether each input is broadcasted.
* @param d_output_shape Output reference to device array holding the output tensor shape.
* @param d_output_strides Output reference to device array holding output tensor strides.
* @param d_input_shapes Output reference to flattened input tensor shapes (N * ndim).
* @param d_input_strides Output reference to flattened input tensor strides (N * ndim).
* @param stream CUDA stream used for asynchronous memory transfer.
* @return infiniStatus_t Status indicating success or failure of the memory transfer and setup.
*/
template <size_t N>
infiniStatus_t infoToDevice(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
const void *const *h_inputs_arr,
const void **&d_inputs_arr,
const bool *&d_input_contiguous,
const bool *&d_input_broadcasted,
const size_t *&d_output_shape,
const ptrdiff_t *&d_output_strides,
const size_t *&d_input_shapes,
const ptrdiff_t *&d_input_strides,
cudaStream_t stream) const {
constexpr auto input_size = N;
const auto ndim = info.getNdim();
constexpr auto input_arr_size = N * sizeof(*h_inputs_arr);
const int8_t *info_meta_start = info.getMetaStart();
const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
// copy the input pointer array and meta to device
CHECK_CUDA(cudaMemcpyAsync(workspace, h_inputs_arr, input_arr_size, cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), cudaMemcpyHostToDevice, stream));
// offset/assign the pointers
d_inputs_arr = reinterpret_cast<const void **>(workspace);
d_output_shape = reinterpret_cast<const size_t *>(d_meta_start);
d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape + ndim);
d_input_shapes = reinterpret_cast<const size_t *>(d_output_strides + ndim);
d_input_strides = reinterpret_cast<const ptrdiff_t *>(d_input_shapes + input_size * ndim);
d_input_contiguous = reinterpret_cast<const bool *>(d_input_strides + input_size * ndim);
d_input_broadcasted = reinterpret_cast<const bool *>(d_input_contiguous + input_size);
return INFINI_STATUS_SUCCESS;
}
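For reference, a hedged sketch (not part of the diff; elementwiseWorkspaceBytes is an illustrative name) of the total byte count this layout implies; it matches the workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *) used by CREATE_ELEMENTWISE_CUDA_DESCRIPTOR further below:

// Hedged sketch of the workspace byte layout consumed by infoToDevice for N
// inputs and ndim dimensions; offsets follow the pointer arithmetic above.
inline size_t elementwiseWorkspaceBytes(size_t N, size_t ndim) {
    size_t bytes = N * sizeof(void *);      // device copy of the input pointer array
    bytes += ndim * sizeof(size_t);         // output shape
    bytes += ndim * sizeof(ptrdiff_t);      // output strides
    bytes += N * ndim * sizeof(size_t);     // input shapes, flattened
    bytes += N * ndim * sizeof(ptrdiff_t);  // input strides, flattened
    bytes += N * sizeof(bool);              // input contiguous flags
    bytes += N * sizeof(bool);              // input broadcasted flags
    return bytes;
}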
/**
* @brief Launches the elementwise kernel for the specified operation.
*
* @tparam BLOCK_SIZE Number of threads per block.
* @tparam N Number of input tensors.
* @tparam KernelFunc Type of the kernel function pointer.
* @tparam Tout Output data type.
* @tparam Args Additional arguments to be forwarded to the kernel.
*
* @param info Metadata about the elementwise operation (shapes, strides, etc.).
* @param workspace CUDA memory used for storing metadata.
* @param output Pointer to output buffer on device.
* @param inputs Vector of device pointers to input tensors.
* @param kernel_func Kernel function to launch.
* @param stream CUDA stream for asynchronous execution.
* @param args Additional arguments passed to the kernel.
* @return infiniStatus_t Status code indicating success or failure.
*/
template <size_t BLOCK_SIZE, size_t N, typename KernelFunc, typename Tout, typename... Args>
infiniStatus_t launchElementwiseKernel(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
Tout *output,
const std::vector<const void *> &inputs,
KernelFunc kernel_func,
cudaStream_t stream,
Args &&...args) {
auto output_size = info.getOutputSize();
if (output_size == 0) {
return INFINI_STATUS_SUCCESS;
}
// Device pointers
const void **d_inputs_arr = nullptr;
const bool *d_input_contiguous = nullptr;
const bool *d_input_broadcasted = nullptr;
const size_t *d_output_shape = nullptr;
const ptrdiff_t *d_output_strides = nullptr;
const size_t *d_input_shapes = nullptr;
const ptrdiff_t *d_input_strides = nullptr;
CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr,
d_input_contiguous, d_input_broadcasted,
d_output_shape, d_output_strides,
d_input_shapes, d_input_strides, stream));
dim3 blockDims(std::min(BLOCK_SIZE, static_cast<size_t>(internal->maxThreadsPerBlock())));
dim3 gridDims(std::min(CEIL_DIV(output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX())));
size_t step = gridDims.x * blockDims.x;
for (size_t i = 0; i < output_size; i += step) {
kernel_func<<<gridDims, blockDims, 0, stream>>>(
output_size, info.getNdim(), info.isOutputContiguous(),
d_input_contiguous, d_input_broadcasted,
d_output_shape, d_input_shapes,
d_output_strides, d_input_strides,
output, reinterpret_cast<const void **>(d_inputs_arr),
i, std::forward<Args>(args)...);
}
return INFINI_STATUS_SUCCESS;
}
};
template <typename... Args>
utils::Result<DeviceImpl *> DeviceImpl::create(Args &&...args) {
auto opaque = std::make_shared<Opaque>(std::forward<Args>(args)...);
return utils::Result<DeviceImpl *>(new DeviceImpl(opaque));
}
/* Invoke elementwise operation for different input types */
template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
static_assert(sizeof...(Tin) == N, "Input type count mismatch");
return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tout, Tin...>(
info, workspace, output, inputs,
reinterpret_cast<cudaStream_t>(stream),
std::forward<Args>(args)...);
}
/* Invoke elementwise operation when all inputs have the same dtype */
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tdata>(
info, workspace, output, inputs,
reinterpret_cast<cudaStream_t>(stream),
std::forward<Args>(args)...);
}
} // namespace op::elementwise::cuda
#endif // __INFINIOP_ELEMENTWISE_CUDA_H__
#ifndef __INFINIOP_ELEMENTWISE_CUDA_API_H__
#define __INFINIOP_ELEMENTWISE_CUDA_API_H__
#include "../elementwise.h"
namespace op::elementwise::cuda {
/**
* @brief Defines the methods and info needed by the CUDA backend to perform an elementwise operation.
*/
class DeviceImpl final {
struct Opaque;
std::shared_ptr<Opaque> _opaque;
DeviceImpl(std::shared_ptr<Opaque> opaque) : _opaque(std::move(opaque)) {}
public:
~DeviceImpl() = default;
template <typename... Args>
static utils::Result<DeviceImpl *> create(Args &&...args);
/**
* @brief Launches elementwise operation where all input types are the same.
*
* Calls the corresponding templated `calculateImpl` with a unified input type.
*
* @tparam BLOCK_SIZE Number of threads per block.
* @tparam Op Operation functor defining the computation.
* @tparam Tdata Data type for both input and output tensors.
* @tparam Args... Additional arguments passed to the operation.
*
* @param info Metadata describing tensor shapes, strides, etc.
* @param workspace Pointer to workspace buffer on device.
* @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
* @param args Additional operation-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
/**
* @brief Launches elementwise operation where input types may differ.
*
* Dispatches to templated `calculateImpl` using specified output and input types.
*
* @tparam BLOCK_SIZE Number of threads per block.
* @tparam Op Operation functor defining the computation.
* @tparam Tout Output data type.
* @tparam Tin... Input data types (must match Op::num_inputs).
* @tparam Args... Additional arguments passed to the operation.
*
* @param info Metadata describing tensor shapes, strides, etc.
* @param workspace Pointer to workspace buffer on device.
* @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
* @param args (UNUSED) Additional operation-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin,
typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
};
} // namespace op::elementwise::cuda
/**
* @brief Define the process for initializing a Descriptor of an elementwise operation
* for its CUDA implementation
*
* @param HANDLE The device handle.
* @param DTYPE The output dtype.
* @param OUT_DESC The output tensor descriptor.
* @param INPUT_DESC_VEC A vector containing input tensor descriptors.
*/
#define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
auto device_impl_result = op::elementwise::cuda::DeviceImpl::create(HANDLE->internal()); \
CHECK_RESULT(device_impl_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
std::move(info), \
std::move(device_impl_result.take()), \
workspace_size, \
HANDLE->device, \
HANDLE->device_id);
#endif // __INFINIOP_ELEMENTWISE_CUDA_API_H__
#ifndef __INFINIOP_ELEMENTWISE_H__
#define __INFINIOP_ELEMENTWISE_H__
#include "../../utils.h"
#include "../operator.h"
#include "../tensor.h"
#include <algorithm>
#include <array>
#include <cstring>
#include <iostream>
#include <memory>
#include <numeric>
#include <vector>
#define ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE) \
\
namespace op::OP::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
infiniDtype_t _dtype; \
op::elementwise::ElementwiseInfo _info; \
std::unique_ptr<op::elementwise::NAMESPACE::DeviceImpl> _device_info; \
size_t _workspace_size; \
\
Descriptor( \
infiniDtype_t dtype, \
op::elementwise::ElementwiseInfo info, \
op::elementwise::NAMESPACE::DeviceImpl *device_info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_dtype(dtype), \
_info(std::move(info)), \
_device_info(std::move(device_info)), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t output_desc, \
std::vector<infiniopTensorDescriptor_t> input_descs); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *output, \
std::vector<const void *> inputs, \
void *stream) const; \
}; \
}
namespace op::elementwise {
/**
* @brief Stores the metadata required for performing an elementwise operation.
*
* This struct encapsulates shape, stride, and layout information for both
* output and multiple input tensors involved in an elementwise operation.
*
* Memory is manually managed and freed in the destructor.
* Supports move construction but disallows copy construction and copy/move assignment.
*
* Use ElementwiseInfo::create(...) to safely construct an instance from tensor descriptors.
*/
struct ElementwiseInfo {
private:
std::vector<size_t> _meta;
size_t _output_size;
size_t _input_size;
size_t _ndim;
bool _output_contiguous;
ElementwiseInfo(std::vector<size_t> meta,
size_t output_size,
size_t input_size,
size_t ndim,
bool output_contiguous)
: _meta(std::move(meta)), _output_size(output_size),
_input_size(input_size), _ndim(ndim),
_output_contiguous(output_contiguous) {}
public:
inline size_t getMetaMemSize() const {
return _meta.size();
}
inline const int8_t *getMetaStart() const {
return reinterpret_cast<const int8_t *>(_meta.data());
}
inline size_t getOutputSize() const {
return _output_size;
}
inline size_t getInputSize() const {
return _input_size;
}
inline size_t getNdim() const {
return _ndim;
}
inline bool isOutputContiguous() const {
return _output_contiguous;
}
inline const size_t *getOutputShape() const {
return reinterpret_cast<const size_t *>(_meta.data());
}
inline const ptrdiff_t *getOutputStrides() const {
return reinterpret_cast<const ptrdiff_t *>(getOutputShape() + _ndim);
}
inline const size_t *getAllInputShapes() const {
return reinterpret_cast<const size_t *>(getOutputStrides() + _ndim);
}
inline const size_t *getInputShape(const size_t &index) const {
if (index < _input_size) {
return reinterpret_cast<const size_t *>(getAllInputShapes() + index * _ndim);
}
return nullptr;
}
inline const ptrdiff_t *getAllInputStrides() const {
return reinterpret_cast<const ptrdiff_t *>(getAllInputShapes() + _input_size * _ndim);
}
inline const ptrdiff_t *getInputStrides(const size_t &index) const {
if (index < _input_size) {
return reinterpret_cast<const ptrdiff_t *>(getAllInputStrides() + index * _ndim);
}
return nullptr;
}
inline const bool *getInputContiguous() const {
return reinterpret_cast<const bool *>(getAllInputStrides() + _input_size * _ndim);
}
inline const bool *getInputBroadcasted() const {
return reinterpret_cast<const bool *>(getInputContiguous() + _input_size);
}
using ResultType = utils::Result<ElementwiseInfo>;
/**
* @brief Construct ElementwiseInfo from output and input tensor descriptors.
* @param output_desc Descriptor of the output tensor.
* @param input_descs Descriptors of the input tensors.
* @return Result<ElementwiseInfo> with the successfully constructed ElementwiseInfo,
* or the status code.
*/
static ResultType create(
infiniopTensorDescriptor_t output_desc,
std::vector<infiniopTensorDescriptor_t> input_descs) {
if (!output_desc || input_descs.empty()) {
return INFINI_STATUS_BAD_PARAM;
}
// Destination cannot have broadcast setup
if (output_desc->hasBroadcastDim()) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
auto input_size = input_descs.size();
auto ndim = output_desc->ndim();
auto output_size = output_desc->numel();
auto output_contiguous = output_desc->isContiguous();
// Allocate memory for meta
auto shape_unit = output_desc->dim(0);
auto stride_unit = output_desc->stride(0);
size_t meta_mem_size = ndim * (sizeof(shape_unit) + sizeof(stride_unit))
+ input_size * ndim * sizeof(shape_unit)
+ input_size * ndim * sizeof(stride_unit)
+ 2 * input_size * sizeof(bool);
std::vector<size_t> meta(meta_mem_size);
int8_t *meta_ptr = reinterpret_cast<int8_t *>(meta.data());
const auto output_shape = output_desc->shape();
const auto output_strides = output_desc->strides();
// Pointers to the sections within _meta
size_t *output_shape_p = reinterpret_cast<size_t *>(meta_ptr);
ptrdiff_t *output_strides_p = reinterpret_cast<ptrdiff_t *>(output_shape_p + ndim);
size_t *input_shapes = reinterpret_cast<size_t *>(output_strides_p + ndim);
ptrdiff_t *input_strides = reinterpret_cast<ptrdiff_t *>(input_shapes + input_size * ndim);
bool *input_contiguous = reinterpret_cast<bool *>(input_strides + input_size * ndim);
bool *input_broadcasted = input_contiguous + input_size;
// Copy output shape and strides
std::memcpy(output_shape_p, output_shape.data(), ndim * sizeof(*output_shape_p));
std::memcpy(output_strides_p, output_strides.data(), ndim * sizeof(*output_strides_p));
// Copy input shapes, strides, contiguous, and broadcasted flags
for (size_t i = 0; i < input_size; ++i) {
auto &desc = input_descs[i];
const auto in_shape = desc->shape();
const auto in_strides = desc->strides();
std::memcpy(input_shapes + i * ndim, in_shape.data(), ndim * sizeof(*input_shapes));
std::memcpy(input_strides + i * ndim, in_strides.data(), ndim * sizeof(*input_strides));
input_contiguous[i] = desc->isContiguous();
input_broadcasted[i] = !input_contiguous[i] && (desc->ndim() != ndim || desc->hasBroadcastDim());
}
ElementwiseInfo info(std::move(meta), output_size, input_size, ndim, output_contiguous);
return ResultType(std::move(info));
}
};
} // namespace op::elementwise
#endif // __INFINIOP_ELEMENTWISE_H__
...@@ -8,11 +8,13 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
- infiniopTensorDescriptor_t up_desc,
- infiniopTensorDescriptor_t gate_desc) {
+ std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
auto dtype = out_desc->dtype();
+ const auto &up_desc = input_desc_vec.at(0);
+ const auto &gate_desc = input_desc_vec.at(1);
const auto &out_shape = out_desc->shape();
const auto &up_shape = up_desc->shape();
const auto &gate_shape = gate_desc->shape();
...@@ -21,36 +23,26 @@ infiniStatus_t Descriptor::create(
CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
- op::binary::BinaryInfo info;
- CHECK_STATUS(op::binary::createBinaryInfo(info, out_desc, up_desc, gate_desc));
- // Create descriptor
- *desc_ptr = new Descriptor(
-     dtype,
-     std::move(info),
-     nullptr,
-     handle->device,
-     handle->device_id);
+ // create CPU elementwise descriptor
+ CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
- void *c,
- const void *a,
- const void *b,
+ void *workspace,
+ size_t workspace_size,
+ void *output,
+ std::vector<const void *> inputs,
void *stream) const {
switch (_dtype) {
case INFINI_DTYPE_F16:
- op::common_cpu::binary_op::calculate<fp16_t, SwiGLUOp>(_info, c, a, b);
- break;
+ return _device_info->calculate<SwiGLUOp, fp16_t>(_info, output, inputs, stream);
case INFINI_DTYPE_F32:
- op::common_cpu::binary_op::calculate<float, SwiGLUOp>(_info, c, a, b);
- break;
+ return _device_info->calculate<SwiGLUOp, float>(_info, output, inputs, stream);
case INFINI_DTYPE_F64:
- op::common_cpu::binary_op::calculate<double, SwiGLUOp>(_info, c, a, b);
- break;
+ return _device_info->calculate<SwiGLUOp, double>(_info, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
#ifndef __SWIGLU_CPU_H__
#define __SWIGLU_CPU_H__
- #include "../../../binary/cpu/binary_cpu.h"
+ #include "../../../elementwise/cpu/elementwise_cpu.h"
- BINARY_DESCRIPTOR(swiglu, cpu)
+ ELEMENTWISE_DESCRIPTOR(swiglu, cpu)
- struct SwiGLUOp {
+ namespace op::swiglu::cpu {
+ typedef struct SwiGLUOp {
private:
template <typename T>
T sigmoid(const T &x) const {
- return 1 / (1 + std::exp(-x));
+ return T(1) / (T(1) + std::exp(-x));
}
public:
+ static constexpr size_t num_inputs = 2;
template <typename T>
T operator()(const T &up, const T &gate) const {
return gate * sigmoid(gate) * up;
}
- };
+ } SwiGLUOp;
+ } // namespace op::swiglu::cpu
#endif // __SWIGLU_CPU_H__
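A minimal standalone sketch (not part of the commit) of the value SwiGLUOp computes, handy for checking kernel output by hand; swiglu_ref and the sample numbers are illustrative only:

#include <cmath>
#include <cstdio>

// swiglu_ref mirrors SwiGLUOp::operator()(up, gate) = gate * sigmoid(gate) * up.
static float swiglu_ref(float up, float gate) {
    float sig = 1.0f / (1.0f + std::exp(-gate)); // sigmoid(gate)
    return gate * sig * up;
}

int main() {
    // e.g. up = 2, gate = 1: sigmoid(1) ≈ 0.7311, so the result is ≈ 1.4621
    std::printf("%f\n", swiglu_ref(2.0f, 1.0f));
    return 0;
}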
#include "swiglu_cuda.cuh"
#include "swiglu_cuda_internal.cuh"
namespace op::swiglu::cuda {
Descriptor::~Descriptor() = default;
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
auto handle = reinterpret_cast<device::cuda::Handle *>(handle_);
auto dtype = out_desc->dtype();
const auto &up_desc = input_desc_vec.at(0);
const auto &gate_desc = input_desc_vec.at(1);
const auto &out_shape = out_desc->shape();
const auto &up_shape = up_desc->shape();
const auto &gate_shape = gate_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
// create CUDA elementwise descriptor
CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<256, SwiGLUOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<256, SwiGLUOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<256, SwiGLUOp, double>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::swiglu::cuda
#ifndef __SWIGLU_CUDA_API_H__
#define __SWIGLU_CUDA_API_H__
#include "../../../elementwise/cuda/elementwise_cuda_api.cuh"
ELEMENTWISE_DESCRIPTOR(swiglu, cuda)
#endif // __SWIGLU_CUDA_API_H__
#ifndef __SWIGLU_CUDA_H__
#define __SWIGLU_CUDA_H__
#include "../../../elementwise/cuda/elementwise_cuda.cuh"
#include <cuda_fp16.h>
namespace op::swiglu::cuda {
typedef struct SwiGLUOp {
private:
template <typename T>
__device__ __forceinline__ T sigmoid(const T &x) const {
if constexpr (std::is_same_v<T, half2>) {
return h2rcp(__hadd2(make_half2(1, 1), h2exp(__hneg2(x))));
} else if constexpr (std::is_same_v<T, half>) {
return hrcp(__hadd(half(1.f), __float2half(__expf(__half2float(__hneg(x))))));
} else if constexpr (std::is_same_v<T, float>) {
return __frcp_rd(__fadd_rd(1, __expf(-x)));
} else {
return 1 / (1 + std::exp(-x));
}
}
public:
static constexpr size_t num_inputs = 2;
template <typename T>
__device__ __forceinline__ T operator()(const T &up, const T &gate) const {
if constexpr (std::is_same_v<T, half2>) {
return __hmul2(__hmul2(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<T, half>) {
return __hmul(__hmul(gate, sigmoid(gate)), up);
} else if constexpr (std::is_same_v<T, float>) {
return __fmul_rd(__fmul_rd(gate, sigmoid(gate)), up);
} else {
return gate * sigmoid(gate) * up;
}
}
} SwiGLUOp;
} // namespace op::swiglu::cuda
#endif // __SWIGLU_CUDA_H__
...@@ -5,6 +5,9 @@
#ifdef ENABLE_CPU_API
#include "cpu/swiglu_cpu.h"
#endif
+ #ifdef ENABLE_CUDA_API
+ #include "cuda/swiglu_cuda.cuh"
+ #endif
__C infiniStatus_t infiniopCreateSwiGLUDescriptor(
infiniopHandle_t handle,
...@@ -19,19 +22,16 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
handle, \
reinterpret_cast<op::swiglu::NAMESPACE::Descriptor **>(desc_ptr), \
c_desc, \
- a_desc, \
- b_desc)
+ {a_desc, \
+ b_desc})
switch (handle->device) {
#ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
- #ifdef ENABLE_NV_GPU
- case DevNvGpu:
- return cudaCreateSwiGLUDescriptor((CudaHandle_t)handle,
-                                   (SwiGLUCudaDescriptor_t *)desc_ptr,
-                                   c_desc, a_desc, b_desc);
+ #ifdef ENABLE_CUDA_API
+ CREATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
...@@ -66,8 +66,49 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
#undef CREATE
}
__C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::swiglu::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu)
#endif
#ifdef ENABLE_CUDA_API
GET(INFINI_DEVICE_NVIDIA, cuda)
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangGetSwiGLUWorkspaceSize((SwiGLUBangDescriptor_t)desc, size);
}
#endif
#ifdef ENABLE_ASCEND_API
GET(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
return macaGetSwiGLUWorkspaceSize((SwiGLUMacaDescriptor_t)desc, size);
}
#endif
#ifdef ENABLE_MTHREADS_GPU
case DevMthreadsGpu: {
return musaGetSwiGLUWorkspaceSize((SwiGLUMusaDescriptor_t)desc, size);
}
#endif
}
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
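Taken together, a hedged sketch of the intended C-API call sequence after this change (creation of handle, c_desc, a_desc, b_desc and the cudaMalloc for the CUDA workspace are assumptions outside this diff; error handling omitted):

// Hypothetical caller; only the infiniop* calls below appear in this diff.
infiniopSwiGLUDescriptor_t desc;
infiniopCreateSwiGLUDescriptor(handle, &desc, c_desc, a_desc, b_desc);

size_t workspace_size = 0;
infiniopGetSwiGLUWorkspaceSize(desc, &workspace_size);

void *workspace = nullptr;
if (workspace_size > 0) {
    cudaMalloc(&workspace, workspace_size); // assumption: CUDA backend workspace allocation
}

infiniopSwiGLU(desc, workspace, workspace_size, c, a, b, stream);

cudaFree(workspace);
infiniopDestroySwiGLUDescriptor(desc);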
__C infiniStatus_t infiniopSwiGLU(
infiniopSwiGLUDescriptor_t desc,
+ void *workspace,
+ size_t workspace_size,
void *c,
const void *a,
const void *b,
...@@ -76,16 +117,15 @@ __C infiniStatus_t infiniopSwiGLU(
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::swiglu::NAMESPACE::Descriptor *>(desc) \
- ->calculate(c, a, b, stream)
+ ->calculate(workspace, workspace_size, c, {a, b}, stream)
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
- #ifdef ENABLE_NV_GPU
- case DevNvGpu:
- return cudaSwiGLU((SwiGLUCudaDescriptor_t)desc, c, a, b, stream);
+ #ifdef ENABLE_CUDA_API
+ CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
...@@ -125,9 +165,8 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) {
#ifdef ENABLE_CPU_API
DELETE(INFINI_DEVICE_CPU, cpu);
#endif
- #ifdef ENABLE_NV_GPU
- case DevNvGpu:
- return cudaDestroySwiGLUDescriptor((SwiGLUCudaDescriptor_t)desc);
+ #ifdef ENABLE_CUDA_API
+ DELETE(INFINI_DEVICE_NVIDIA, cuda);
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
......
...@@ -98,4 +98,6 @@ inline std::string infiniDtypeToString(infiniDtype_t dtype) {
}
}
+ #define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
#endif
import torch
import ctypes
- from ctypes import POINTER, Structure, c_int32, c_void_p
+ from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64
from libinfiniop import (
infiniopHandle_t,
infiniopTensorDescriptor_t,
...@@ -14,6 +14,7 @@ from libinfiniop import (
debug,
get_tolerance,
profile_operation,
+ create_workspace
)
from enum import Enum, auto
...@@ -25,8 +26,10 @@ _TEST_CASES_ = [
# shape, a_stride, b_stride, c_stride
((13, 4), None, None, None),
((13, 4), (10, 1), (10, 1), (10, 1)),
+ ((13, 4), (0, 1), None, None),
((13, 4, 4), None, None, None),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
+ ((13, 4, 4), (4, 0, 1), (0, 4, 1), None),
((16, 5632), None, None, None),
((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
((4, 4, 5632), None, None, None),
...@@ -78,6 +81,38 @@ def swiglu(a, b):
return a * b / (1 + torch.exp(-b.float()).to(b.dtype))
def process_tensors(c, c_strides, a, a_stride, b, b_stride, inplace):
"""
Rearrange the tensors if needed and apply the inplace config.
If inplace is requested and the output (i.e., c) would be placed onto a broadcasted input,
the inplace config is ignored and c falls back to its original (unbroadcasted) strides.
"""
original_c_strides = c_strides if c_strides else c.stride()
def _rearrange(tensor, strides):
if strides and 0 in strides:
tensor.set_(tensor.untyped_storage(), 0, tensor.shape, strides)
return tensor
else:
return rearrange_if_needed(tensor, strides)
a, b, c = [
_rearrange(tensor, stride)
for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_strides])
]
c = (
c
if inplace == Inplace.OUT_OF_PLACE
else (a if inplace == Inplace.INPLACE_A else b)
)
# if inplace is true and c has broadcasted config, reset it to the original unbroadcasted strides
if 0 in c.stride():
c.set_(c.untyped_storage(), 0, c.shape, original_c_strides)
return a, b, c
def test(
lib,
handle,
...@@ -98,18 +133,10 @@ def test(
a = torch.rand(shape, dtype=dtype).to(torch_device)
b = torch.rand(shape, dtype=dtype).to(torch_device)
c = torch.rand(shape, dtype=dtype).to(torch_device)
+ a, b, c = process_tensors(c, c_stride, a, a_stride, b, b_stride, inplace)
ans = swiglu(a, b)
- a, b, c = [
-     rearrange_if_needed(tensor, stride)
-     for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])
- ]
- c = (
-     c
-     if inplace == Inplace.OUT_OF_PLACE
-     else (a if inplace == Inplace.INPLACE_A else b)
- )
a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
c_tensor = (
to_tensor(c, lib)
...@@ -134,10 +161,19 @@ def test(
for tensor in [a_tensor, b_tensor, c_tensor]:
tensor.destroyDesc(lib)
workspace_size = c_uint64(0)
check_error(
lib.infiniopGetSwiGLUWorkspaceSize(descriptor, ctypes.byref(workspace_size))
)
workspace = create_workspace(workspace_size.value, c.device)
def lib_swiglu():
check_error(
lib.infiniopSwiGLU(
- descriptor, c_tensor.data, a_tensor.data, b_tensor.data, None
+ descriptor,
+ workspace.data_ptr() if workspace is not None else None,
+ workspace_size.value,
+ c_tensor.data, a_tensor.data, b_tensor.data, None
)
)
...@@ -170,10 +206,18 @@ if __name__ == "__main__":
infiniopTensorDescriptor_t,
]
lib.infiniopGetSwiGLUWorkspaceSize.restype = c_int32
lib.infiniopGetSwiGLUWorkspaceSize.argtypes = [
infiniopSwiGLUDescriptor_t,
POINTER(c_uint64),
]
lib.infiniopSwiGLU.restype = c_int32
lib.infiniopSwiGLU.argtypes = [
infiniopSwiGLUDescriptor_t,
c_void_p,
+ c_uint64,
+ c_void_p,
c_void_p,
c_void_p,
c_void_p,
......
...@@ -28,6 +28,7 @@ target("infiniop-cuda")
else
add_cuflags("-Xcompiler=-Wall", "-Xcompiler=-Werror")
add_cuflags("-Xcompiler=-fPIC")
+ add_cuflags("--extended-lambda")
add_culdflags("-Xcompiler=-fPIC")
add_cxxflags("-fPIC")
end
......