Commit 9cc0c416 authored by Zimin Li

issue/127: Refactor ElementwiseInfo, refactor elementwise to use workspace for storing meta, fix misc. issues
parent 40fdded5
......@@ -11,7 +11,11 @@ __C __export infiniStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t hand
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc);
__C __export infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
void const *a,
void const *b,
......
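For orientation, the new calling convention on the C API side: query the workspace size, allocate once, and pass it through. A minimal sketch assuming CUDA, reusing the repo's CHECK_CUDA/CHECK_STATUS macros for brevity; runSwiGLU is an illustrative wrapper, not part of this patch.

infiniStatus_t runSwiGLU(infiniopSwiGLUDescriptor_t desc,
                         void *c, const void *a, const void *b,
                         cudaStream_t stream) {
    // Ask the descriptor how much device scratch it needs for its metadata.
    size_t workspace_size = 0;
    CHECK_STATUS(infiniopGetSwiGLUWorkspaceSize(desc, &workspace_size));

    // Allocate once; the buffer can be reused across calls with this descriptor.
    void *workspace = nullptr;
    if (workspace_size > 0) {
        CHECK_CUDA(cudaMalloc(&workspace, workspace_size));
    }

    // Metadata is staged into `workspace` instead of per-call cudaMallocAsync.
    infiniStatus_t status = infiniopSwiGLU(desc, workspace, workspace_size, c, a, b, stream);

    cudaFree(workspace);
    return status;
}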
......@@ -9,6 +9,10 @@
#define CUDA_BLOCK_SIZE_1024 1024
#define CUDA_BLOCK_SIZE_512 512
#define CHECK_CUDA(API) CHECK_INTERNAL(API, cudaSuccess)
namespace device::cuda {
// Returns the memory offset into the original tensor, given the flattened index of the broadcasted tensor
__forceinline__ __device__ __host__ size_t
indexToReducedOffset(
......@@ -38,6 +42,7 @@ indexToOffset(
}
return res;
}
} // namespace device::cuda
#ifdef ENABLE_CUDA_API
#include <cuda_fp16.h>
......
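These two helpers carry the broadcast addressing used throughout the elementwise kernels. A host-side sketch of the math, assuming row-major flattening and dense output strides (the actual device bodies are outside this hunk):

// Offset of flat index `flat` in a tensor addressed by its own shape/strides.
size_t indexToOffsetSketch(size_t flat, size_t ndim,
                           const size_t *shape, const ptrdiff_t *strides) {
    size_t res = 0;
    for (size_t i = ndim; i-- > 0;) {
        res += (flat % shape[i]) * strides[i];
        flat /= shape[i];
    }
    return res;
}

// Offset inside the *original* input for flat index `flat` of the broadcasted,
// output-shaped view: decompose with the dense output strides, re-accumulate
// with the input strides, so broadcast dims (stride 0) fold onto one element.
size_t indexToReducedOffsetSketch(size_t flat, size_t ndim,
                                  const ptrdiff_t *out_strides,
                                  const ptrdiff_t *in_strides) {
    size_t res = 0;
    for (size_t i = 0; i < ndim; ++i) {
        res += flat / out_strides[i] * in_strides[i];
        flat %= out_strides[i];
    }
    return res;
}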
......@@ -18,6 +18,7 @@
dtype, \
info_result.take(), \
nullptr, \
0, \
handle->device, \
handle->device_id);
......@@ -103,24 +104,34 @@ infiniStatus_t DeviceImpl::create(DeviceImpl **device_info, Args &&...args) {
}
// Perform elementwise operation for different input types
template <typename Op, typename Tout, typename... Tin, size_t... Is, typename... Args, std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, std::index_sequence<Is...>, Args &&...args) {
template <typename Op, typename Tout, typename... Tin, size_t... Is, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate_impl(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
Args &&...args) {
Tout *out = reinterpret_cast<Tout *>(output);
std::tuple<const Tin *...> input_ptrs = {reinterpret_cast<const Tin *>(inputs[Is])...};
ptrdiff_t output_size = info.output_size;
ptrdiff_t output_size = info.getOutputSize();
#pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape, info.output_strides);
size_t out_idx = info.isOutputContiguous()
? i
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());
auto get_input_idx = [&](size_t input_id) {
return info.input_contiguous[input_id] ? i
: (info.input_broadcasted[input_id]
? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides, info.input_strides[input_id])
: op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id]));
return info.getInputContiguous()[input_id]
? i
: (info.getInputBroadcasted()[input_id]
? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
};
out[out_idx] = utils::cast<Tout>(Op{}.template operator()<Tout, Tin...>(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...));
out[out_idx] = utils::cast<Tout>(
Op{}.template operator()<Tout, Tin...>(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...));
}
}
......@@ -147,17 +158,20 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
Tdata *out = reinterpret_cast<Tdata *>(output);
std::array<const Tdata *, sizeof...(Is)> ins = {reinterpret_cast<const Tdata *>(inputs[Is])...};
const ptrdiff_t output_size = info.output_size;
const ptrdiff_t output_size = info.getOutputSize();
#pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape, info.output_strides);
size_t out_idx = info.isOutputContiguous()
? i
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());
auto get_input_idx = [&](size_t input_id) {
return info.input_contiguous[input_id] ? i
: (info.input_broadcasted[input_id]
? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides, info.input_strides[input_id])
: op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id]));
return info.getInputContiguous()[input_id]
? i
: (info.getInputBroadcasted()[input_id]
? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
};
if constexpr (std::is_same_v<Tdata, fp16_t>) {
......@@ -170,7 +184,11 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
// Invoke elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, void *stream, Args &&...args) {
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...);
return INFINI_STATUS_SUCCESS;
......
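For reference, the Op contract these templates rely on is a static `num_inputs` plus a call operator templated on `<Tout, Tin...>`. A hypothetical AddOp satisfying it (illustrative only, not part of this patch):

struct AddOp {
    static constexpr size_t num_inputs = 2;

    // Matches the Op{}.template operator()<Tout, Tin...>(...) call sites above.
    template <typename Tout, typename Ta, typename Tb>
    Tout operator()(const Ta &a, const Tb &b) const {
        return static_cast<Tout>(a + b);
    }
};

// Hypothetical dispatch, mirroring how concrete ops drive the CPU path:
// device_impl->calculate<AddOp, float>(info, output, inputs, stream);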
......@@ -3,6 +3,7 @@
#include "../../../utils.h"
#include "../../devices/cuda/cuda_common.cuh"
#include "../../devices/cuda/cuda_kernel_common.cuh"
#include "elementwise_cuda_api.cuh"
namespace op::elementwise::cuda {
......@@ -16,18 +17,17 @@ namespace op::elementwise::cuda {
* @param lambda Lambda to be called with std::integral_constant<size_t, Is>... as arguments.
*/
template <typename Lambda, size_t... Is>
__device__ __forceinline__ void call_expand(Lambda lambda, std::index_sequence<Is...>) {
__device__ __forceinline__ void callExpand(Lambda lambda, std::index_sequence<Is...>) {
lambda(std::integral_constant<size_t, Is>{}...);
}
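Concretely, for N = 2 the expansion produces a single call with two compile-time constants:

// callExpand(lambda, std::make_index_sequence<2>{}) is equivalent to:
//   lambda(std::integral_constant<size_t, 0>{},
//          std::integral_constant<size_t, 1>{});
// Each argument carries its index in its type, so the lambda can use it in
// constant expressions (e.g. indexing unrolled at compile time).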
/**
* @brief CUDA kernel for performing elementwise operations on tensors where all inputs share the same data type.
*
* @tparam N Number of input tensors.
* @tparam Op Operator type implementing operator()(Tdata...).
* @tparam Tdata Common data type for inputs and output.
* @tparam N Number of input tensors.
* @tparam Args Additional arguments to pass to the operator.
*
* @param output_size Total number of output elements.
* @param ndim Number of dimensions in tensors.
* @param output_contiguous Whether the output tensor is contiguous in memory.
......@@ -37,24 +37,22 @@ __device__ __forceinline__ void call_expand(Lambda lambda, std::index_sequence<I
* @param input_shapes Shapes of the input tensors.
* @param output_strides Strides for the output tensor.
* @param input_strides Strides for each input tensor.
* @param input_size Total number of input elements (optional, may be unused).
* @param output Output buffer.
* @param inputs Array of input pointers, all of type Tdata.
* @param offset Linear offset to support partitioned execution.
* @param args Additional arguments passed to the operator.
*/
template <size_t N, typename Op, typename Tdata, typename... Args>
INFINIOP_CUDA_KERNEL elementwise_kernel(
INFINIOP_CUDA_KERNEL elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
const bool *__restrict__ input_contiguous,
const bool *__restrict__ input_broadcasted,
const size_t *__restrict__ output_shape,
const size_t *__restrict__ *__restrict__ input_shapes,
const size_t *__restrict__ input_shapes,
const ptrdiff_t *__restrict__ output_strides,
const ptrdiff_t *__restrict__ *__restrict__ input_strides,
size_t input_size,
const ptrdiff_t *__restrict__ input_strides,
Tdata *output,
const Tdata *const *inputs,
size_t offset,
......@@ -68,8 +66,8 @@ INFINIOP_CUDA_KERNEL elementwise_kernel(
auto get_input_idx = [&] __device__(size_t input_id) {
return input_contiguous[input_id] ? idx
: (input_broadcasted[input_id]
? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides[input_id])
: device::cuda::indexToOffset(idx, ndim, input_shapes[input_id], input_strides[input_id]));
? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: device::cuda::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
};
// Use a helper to expand the index sequence into individual compile-time constants
......@@ -85,7 +83,7 @@ INFINIOP_CUDA_KERNEL elementwise_kernel(
}
};
call_expand(expand_inputs, std::make_index_sequence<N>{});
callExpand(expand_inputs, std::make_index_sequence<N>{});
}
}
......@@ -97,38 +95,38 @@ INFINIOP_CUDA_KERNEL elementwise_kernel(
* @return Pointer of type const T*.
*/
template <typename T>
__device__ inline const T *typed_input_ptr(const void *ptr) {
__device__ inline const T *typedInputPtr(const void *ptr) {
return reinterpret_cast<const T *>(ptr);
}
/**
* @brief Launches a type-safe elementwise operation on a single output element.
*
* @tparam Op Operator type implementing a templated operator() for (Tout, Tin...).
* @tparam Tout Output data type.
* @tparam Tin Variadic input data types.
* @tparam Is Index sequence corresponding to each input.
 * @brief Launches an elementwise operation at a specific output index.
*
* @param idx Linear index in the flattened output space.
* @param out_idx Actual output index (may be non-contiguous).
* @param ndim Number of dimensions in the tensors.
* @param input_contiguous Array indicating whether each input is contiguous.
* @param input_broadcasted Array indicating whether each input is broadcasted.
* @param input_shapes Shapes of the input tensors.
* @param input_strides Strides of the input tensors.
* @param inputs Raw pointers to input data.
* @param output Pointer to output data.
* @param ... Index sequence used for unpacking variadic inputs.
* @tparam Op Functor representing the elementwise operation.
* @tparam Tout Output data type.
* @tparam Tin... Input data types.
* @tparam Is... Index sequence for unpacking variadic inputs.
* @param idx Global linear index into the output tensor.
* @param out_idx Offset into the output array.
* @param ndim Number of dimensions in the tensors.
* @param input_contiguous Flags indicating whether each input is contiguous.
* @param input_broadcasted Flags indicating whether each input is broadcasted.
* @param input_shapes Flattened input shapes (N * ndim).
* @param input_strides Flattened input strides (N * ndim).
* @param output_strides Output tensor strides.
* @param inputs Array of pointers to input tensors.
* @param output Pointer to output tensor.
* @param ...Is Index sequence for iterating over input tensors.
*/
template <typename Op, typename Tout, typename... Tin, size_t... Is>
__device__ void launch_op(
__device__ void launchOp(
size_t idx,
size_t out_idx,
size_t ndim,
const bool *__restrict__ input_contiguous,
const bool *__restrict__ input_broadcasted,
const size_t *__restrict__ const *__restrict__ input_shapes,
const ptrdiff_t *__restrict__ const *__restrict__ input_strides,
const size_t *__restrict__ input_shapes,
const ptrdiff_t *__restrict__ input_strides,
const ptrdiff_t *__restrict__ output_strides,
const void *const *__restrict__ inputs,
Tout *output,
......@@ -138,12 +136,12 @@ __device__ void launch_op(
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides[input_id])
: device::cuda::indexToOffset(idx, ndim, input_shapes[input_id], input_strides[input_id]));
? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: device::cuda::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
};
output[out_idx] = Op{}.template operator()<Tout, Tin...>(
(typed_input_ptr<Tin>(inputs[Is])[get_input_idx(Is)])...);
(typedInputPtr<Tin>(inputs[Is])[get_input_idx(Is)])...);
}
/**
......@@ -153,7 +151,6 @@ __device__ void launch_op(
* @tparam Op Operator type implementing a templated operator() for (Tout, Tin...).
* @tparam Tout Output data type.
* @tparam Tin Variadic input data types.
*
* @param output_size Total number of output elements.
* @param ndim Number of dimensions in the tensors.
* @param output_contiguous Whether the output tensor is contiguous.
......@@ -163,23 +160,21 @@ __device__ void launch_op(
* @param input_shapes Shapes of the input tensors.
* @param output_strides Strides of the output tensor.
* @param input_strides Strides of the input tensors.
* @param input_size Total number of input elements (unused here, but may be used for validation).
* @param output Pointer to the output buffer.
* @param inputs Array of untyped input pointers.
* @param offset Linear offset into the output for partitioned execution.
*/
template <typename Op, typename Tout, typename... Tin>
INFINIOP_CUDA_KERNEL elementwise_kernel(
INFINIOP_CUDA_KERNEL elementwiseKernel(
size_t output_size,
size_t ndim,
bool output_contiguous,
const bool *__restrict__ input_contiguous,
const bool *__restrict__ input_broadcasted,
const size_t *__restrict__ output_shape,
const size_t *__restrict__ const *__restrict__ input_shapes,
const size_t *__restrict__ input_shapes,
const ptrdiff_t *__restrict__ output_strides,
const ptrdiff_t *__restrict__ const *__restrict__ input_strides,
size_t input_size,
const ptrdiff_t *__restrict__ input_strides,
Tout *output,
const void *const *__restrict__ inputs,
size_t offset) {
......@@ -193,7 +188,7 @@ INFINIOP_CUDA_KERNEL elementwise_kernel(
? idx
: device::cuda::indexToOffset(idx, ndim, output_shape, output_strides);
launch_op<Op, Tout, Tin...>(
launchOp<Op, Tout, Tin...>(
idx,
out_idx,
ndim,
......@@ -214,252 +209,185 @@ struct DeviceImpl::Opaque {
: internal(internal) {}
/**
* @brief Performs elementwise operations when all inputs and the output share the same data type.
* @brief Executes an elementwise operation where all inputs and the output share the same data type.
*
* @tparam BLOCK_SIZE The block size for the kernel launch.
* @tparam N The number of input tensors.
* @tparam Op The operation to perform (e.g., addition, multiplication).
* @tparam Tdata The data type of the input and output tensors.
* @tparam Args Additional arguments to be passed to the operation.
* @param info Structure containing elementwise operation information (size, shape, etc.).
* @param output Pointer to the output memory where results will be stored.
* @param inputs Vector of pointers to input tensors.
* @param stream CUDA stream used for asynchronous execution.
* @param args Additional arguments for the operation.
* @return infiniStatus_t Status indicating success or failure.
* @tparam BLOCK_SIZE CUDA block size used for kernel launch.
* @tparam N Number of input tensors.
* @tparam Op Functor representing the elementwise operation.
* @tparam Tdata Data type of both input and output tensors.
* @tparam Args Optional additional arguments passed to the operation.
* @param info Metadata about the operation including shape, size, and dimensionality.
* @param workspace Temporary workspace used for storing metadata on device.
* @param output Pointer to the output buffer.
* @param inputs Vector of pointers to input buffers.
* @param stream CUDA stream for asynchronous execution.
* @param args Additional arguments forwarded to the operation.
* @return infiniStatus_t Returns success or failure status.
*/
template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args, size_t... Is>
template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
cudaStream_t stream,
Args &&...args) {
if (info.output_size == 0) {
auto output_size = info.getOutputSize();
if (output_size == 0) {
return INFINI_STATUS_SUCCESS;
}
// casting the output and the inputs to Tdata pointers
Tdata *out = reinterpret_cast<Tdata *>(output);
const Tdata *inputs_arr[N];
const Tdata **d_inputs_arr = nullptr;
for (size_t i = 0; i < N; ++i) {
inputs_arr[i] = reinterpret_cast<const Tdata *>(inputs[i]);
}
CHECK_CUDA(cudaMallocAsync(&d_inputs_arr, N * sizeof(*d_inputs_arr), stream));
CHECK_CUDA(cudaMemcpyAsync(d_inputs_arr, inputs_arr, N * sizeof(*d_inputs_arr), cudaMemcpyHostToDevice, stream));
const void **d_inputs_arr = nullptr;
// create and send the info to device
const bool *d_bools = nullptr;
const bool *d_input_contiguous = nullptr;
const bool *d_input_broadcasted = nullptr;
const int8_t *d_output_shape_strides = nullptr;
const size_t *d_output_shape = nullptr;
const ptrdiff_t *d_output_strides = nullptr;
const size_t **d_input_shapes = nullptr;
const ptrdiff_t **d_input_strides = nullptr;
std::vector<const size_t *> tmp_device_ptrs(info.input_size);
std::vector<const ptrdiff_t *> tmp_device_ptrs_strides(info.input_size);
const size_t *d_input_shapes = nullptr;
const ptrdiff_t *d_input_strides = nullptr;
CHECK_STATUS(infoToDevice<N>(info, d_bools, d_input_contiguous,
d_input_broadcasted, d_output_shape_strides, d_output_shape,
d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides,
d_input_strides, stream));
CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr, d_input_contiguous, d_input_broadcasted,
d_output_shape, d_output_strides, d_input_shapes, d_input_strides, stream));
dim3 blockDims(std::min(BLOCK_SIZE, static_cast<size_t>(internal->maxThreadsPerBlock())));
dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX())));
dim3 gridDims(std::min(CEIL_DIV(output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX())));
size_t step = gridDims.x * blockDims.x;
for (size_t i = 0; i < info.output_size; i += step) {
elementwise_kernel<N, Op, Tdata, Args...><<<gridDims, blockDims, 0, stream>>>(
info.output_size,
info.ndim,
info.output_contiguous,
for (size_t i = 0; i < output_size; i += step) {
elementwiseKernel<N, Op, Tdata, Args...><<<gridDims, blockDims, 0, stream>>>(
output_size,
info.getNdim(),
info.isOutputContiguous(),
d_input_contiguous,
d_input_broadcasted,
d_output_shape,
d_input_shapes,
d_output_strides,
d_input_strides,
info.input_size, out, d_inputs_arr, i, std::forward<Args>(args)...);
out, reinterpret_cast<const Tdata **>(d_inputs_arr), i, std::forward<Args>(args)...);
}
CHECK_STATUS(freeAllDevice((const void **)d_inputs_arr, d_bools, d_output_shape_strides,
info.input_size, d_input_shapes, d_input_strides, stream));
return INFINI_STATUS_SUCCESS;
}
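The re-launch loop exists because gridDims.x is clamped to the device's gridSizeX(); with `offset = i`, each pass covers the next `step` elements. A worked example with illustrative numbers:

// output_size = 20'000'000, BLOCK_SIZE = 256, internal->gridSizeX() = 65'535:
//   blockDims.x = 256
//   gridDims.x  = min(CEIL_DIV(20'000'000, 256), 65'535) = min(78'125, 65'535) = 65'535
//   step        = 65'535 * 256 = 16'776'960
// Pass 1 covers [0, 16'776'960), pass 2 covers [16'776'960, 20'000'000);
// the kernel presumably guards with idx < output_size after adding the offset.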
/**
* @brief Performs elementwise operations when inputs and the outputs have mixed data types (i.e., different dtypes).
* @brief Executes an elementwise operation with mixed input and output data types.
*
* @tparam BLOCK_SIZE The block size for the kernel launch.
* @tparam N The number of input tensors.
* @tparam Op The operation to perform (e.g., addition, multiplication).
* @tparam Tout The output data type.
* @tparam Tin The input data types.
* @tparam Args Additional arguments to be passed to the operation.
* @param info Structure containing elementwise operation information (size, shape, etc.).
* @param output Pointer to the output memory where results will be stored.
* @param inputs Vector of pointers to input tensors.
* @param stream CUDA stream used for asynchronous execution.
* @param args Additional arguments for the operation.
* @return infiniStatus_t Status indicating success or failure.
* @tparam BLOCK_SIZE CUDA block size used for kernel launch.
* @tparam N Number of input tensors.
* @tparam Op Functor representing the elementwise operation.
* @tparam Tout Data type of the output tensor.
* @tparam Tin... Data types of the input tensors.
 * @tparam Args Optional additional arguments passed to the operation (unused).
* @param info Metadata about the operation including shape, size, and dimensionality.
* @param workspace Temporary workspace used for storing metadata on device.
* @param output Pointer to the output buffer.
* @param inputs Vector of pointers to input buffers.
* @param stream CUDA stream for asynchronous execution.
* @param args Additional arguments forwarded to the operation.
* @return infiniStatus_t Returns success or failure status.
*/
template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tout, typename... Tin, typename... Args, size_t... Is,
template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tout, typename... Tin, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
cudaStream_t stream,
Args &&...args) {
if (info.output_size == 0) {
auto output_size = info.getOutputSize();
if (output_size == 0) {
return INFINI_STATUS_SUCCESS;
}
Tout *out = reinterpret_cast<Tout *>(output);
// Store input pointers with the correct types
const std::tuple<const Tin *...> inputs_arr{reinterpret_cast<const Tin *>(inputs[Is])...};
const void **d_inputs_arr = nullptr;
// Create array of input pointers on host (void*) to copy to device
const void *host_input_ptrs[] = {reinterpret_cast<const void *>(std::get<Is>(inputs_arr))...};
CHECK_CUDA(cudaMallocAsync(&d_inputs_arr, N * sizeof(void *), stream));
CHECK_CUDA(cudaMemcpyAsync(d_inputs_arr, host_input_ptrs, N * sizeof(void *), cudaMemcpyHostToDevice, stream));
// Device pointers
const bool *d_bools = nullptr;
const bool *d_input_contiguous = nullptr;
const bool *d_input_broadcasted = nullptr;
const int8_t *d_output_shape_strides = nullptr;
const size_t *d_output_shape = nullptr;
const ptrdiff_t *d_output_strides = nullptr;
const size_t **d_input_shapes = nullptr;
const ptrdiff_t **d_input_strides = nullptr;
std::vector<const size_t *> tmp_device_ptrs(info.input_size);
std::vector<const ptrdiff_t *> tmp_device_ptrs_strides(info.input_size);
const size_t *d_input_shapes = nullptr;
const ptrdiff_t *d_input_strides = nullptr;
CHECK_STATUS(infoToDevice<N>(info, d_bools, d_input_contiguous,
d_input_broadcasted, d_output_shape_strides, d_output_shape,
d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides,
d_input_strides, stream));
CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr, d_input_contiguous, d_input_broadcasted,
d_output_shape, d_output_strides, d_input_shapes, d_input_strides, stream));
dim3 blockDims(std::min(BLOCK_SIZE, static_cast<size_t>(internal->maxThreadsPerBlock())));
dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX())));
dim3 gridDims(std::min(CEIL_DIV(output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX())));
size_t step = gridDims.x * blockDims.x;
for (size_t i = 0; i < info.output_size; i += step) {
elementwise_kernel<Op, Tout, Tin...><<<gridDims, blockDims, 0, stream>>>(
info.output_size,
info.ndim,
info.output_contiguous,
for (size_t i = 0; i < output_size; i += step) {
elementwiseKernel<Op, Tout, Tin...><<<gridDims, blockDims, 0, stream>>>(
output_size,
info.getNdim(),
info.isOutputContiguous(),
d_input_contiguous,
d_input_broadcasted,
d_output_shape,
d_input_shapes,
d_output_strides,
d_input_strides,
info.input_size, out, reinterpret_cast<const void **>(d_inputs_arr), i);
out, reinterpret_cast<const void **>(d_inputs_arr), i);
}
CHECK_STATUS(freeAllDevice(d_inputs_arr, d_bools, d_output_shape_strides, info.input_size, d_input_shapes, d_input_strides, stream));
return INFINI_STATUS_SUCCESS;
}
private:
/**
* @brief Transfers elementwise kernel metadata (shapes, strides, flags) from host to device.
* @brief Transfers elementwise operation metadata and input pointers from host to device memory.
*
* @tparam N Number of inputs.
* @param info Structure containing input/output metadata.
* @param d_bools Device pointer for input_contiguous and input_broadcasted flags.
* @param d_input_contiguous Device pointer to input contiguity flags.
* @param d_input_broadcasted Device pointer to input broadcasting flags.
* @param d_output_shape_strides Device buffer containing both output shape and strides.
* @param d_output_shape Device pointer to output shape.
* @param d_output_strides Device pointer to output strides.
* @param tmp_device_ptrs Temporary device pointers for input shapes.
* @param d_input_shapes Device array of pointers to input shapes.
* @param tmp_device_ptrs_strides Temporary device pointers for input strides.
* @param d_input_strides Device array of pointers to input strides.
* @param stream CUDA stream for async allocation and transfers.
* @return infiniStatus_t Status indicating success or failure.
* @tparam N Number of input tensors.
* @param info Elementwise operation metadata (shapes, strides, flags, etc.).
* @param workspace Pointer to device workspace memory for storing metadata and input pointers.
* @param h_inputs_arr Host array of input tensor pointers.
* @param d_inputs_arr Output reference to device array of input tensor pointers.
* @param d_input_contiguous Output reference to device array indicating whether each input is contiguous.
* @param d_input_broadcasted Output reference to device array indicating whether each input is broadcasted.
* @param d_output_shape Output reference to device array holding the output tensor shape.
* @param d_output_strides Output reference to device array holding output tensor strides.
* @param d_input_shapes Output reference to flattened input tensor shapes (N * ndim).
* @param d_input_strides Output reference to flattened input tensor strides (N * ndim).
* @param stream CUDA stream used for asynchronous memory transfer.
* @return infiniStatus_t Status indicating success or failure of the memory transfer and setup.
*/
template <size_t N>
infiniStatus_t infoToDevice(
const op::elementwise::ElementwiseInfo &info,
const bool *&d_bools,
void *workspace,
const void *const *h_inputs_arr,
const void **&d_inputs_arr,
const bool *&d_input_contiguous,
const bool *&d_input_broadcasted,
const int8_t *&d_output_shape_strides,
const size_t *&d_output_shape,
const ptrdiff_t *&d_output_strides,
std::vector<const size_t *> &tmp_device_ptrs,
const size_t **&d_input_shapes,
std::vector<const ptrdiff_t *> &tmp_device_ptrs_strides,
const ptrdiff_t **&d_input_strides,
const size_t *&d_input_shapes,
const ptrdiff_t *&d_input_strides,
cudaStream_t stream) const {
CHECK_CUDA(cudaMallocAsync(&d_bools, 2 * info.input_size * sizeof(*d_bools), stream));
CHECK_CUDA(cudaMemcpyAsync((void *)d_bools, info.input_contiguous, info.input_size * sizeof(*d_bools), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync((void *)(d_bools + info.input_size), info.input_broadcasted, info.input_size * sizeof(*d_bools), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMallocAsync(&d_output_shape_strides, info.ndim * (sizeof(*d_output_shape) + sizeof(*d_output_strides)), stream));
CHECK_CUDA(cudaMemcpyAsync((void *)d_output_shape_strides, info.output_shape, info.ndim * sizeof(*d_output_shape), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync((void *)(d_output_shape_strides + info.ndim * sizeof(*d_output_shape)), info.output_strides, info.ndim * sizeof(*d_output_strides), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMallocAsync(&d_input_shapes, info.input_size * sizeof(*d_input_shapes), stream));
for (size_t i = 0; i < info.input_size; ++i) {
CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs[i], info.ndim * sizeof(*&tmp_device_ptrs[i]), stream));
CHECK_CUDA(cudaMemcpyAsync((void *)tmp_device_ptrs[i], info.input_shapes[i],
info.ndim * sizeof(*tmp_device_ptrs[i]), cudaMemcpyHostToDevice, stream));
}
CHECK_CUDA(cudaMemcpyAsync((void *)d_input_shapes, tmp_device_ptrs.data(),
info.input_size * sizeof(*d_input_shapes), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMallocAsync(&d_input_strides, info.input_size * sizeof(*d_input_strides), stream));
for (size_t i = 0; i < info.input_size; ++i) {
CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs_strides[i], info.ndim * sizeof(*tmp_device_ptrs_strides[i]), stream));
CHECK_CUDA(cudaMemcpyAsync((void *)tmp_device_ptrs_strides[i], info.input_strides[i],
info.ndim * sizeof(*tmp_device_ptrs_strides[i]), cudaMemcpyHostToDevice, stream));
}
CHECK_CUDA(cudaMemcpyAsync((void *)d_input_strides, tmp_device_ptrs_strides.data(),
info.input_size * sizeof(*d_input_strides), cudaMemcpyHostToDevice, stream));
d_input_contiguous = d_bools;
d_input_broadcasted = d_bools + info.input_size;
d_output_shape = reinterpret_cast<const size_t *>(d_output_shape_strides);
d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape_strides + info.ndim * sizeof(*d_output_shape));
return INFINI_STATUS_SUCCESS;
}
constexpr auto input_size = N;
const auto ndim = info.getNdim();
constexpr auto input_arr_size = N * sizeof(*h_inputs_arr);
const int8_t *info_meta_start = info.getMetaStart();
const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
// copy the input pointer array and meta to device
CHECK_CUDA(cudaMemcpyAsync(workspace, h_inputs_arr, input_arr_size, cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), cudaMemcpyHostToDevice, stream));
// assign the device pointers to their offsets within the workspace
d_inputs_arr = reinterpret_cast<const void **>(workspace);
d_output_shape = reinterpret_cast<const size_t *>(d_meta_start);
d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape + ndim);
d_input_shapes = reinterpret_cast<const size_t *>(d_output_strides + ndim);
d_input_strides = reinterpret_cast<const ptrdiff_t *>(d_input_shapes + input_size * ndim);
d_input_contiguous = reinterpret_cast<const bool *>(d_input_strides + input_size * ndim);
d_input_broadcasted = reinterpret_cast<const bool *>(d_input_contiguous + input_size);
/**
* @brief Frees all device-allocated memory used for metadata in elementwise kernel execution.
*
* @param d_inputs_arr Device array of input pointers.
* @param d_bools Device memory holding input flags.
* @param d_output_shape_strides Device buffer holding output shape and strides.
* @param input_size Number of input tensors.
* @param d_input_shapes Device array of input shape pointers.
* @param d_input_strides Device array of input stride pointers.
* @param stream CUDA stream for async deallocation.
* @return infiniStatus_t Status indicating success or failure.
*/
inline infiniStatus_t freeAllDevice(const void **d_inputs_arr,
const bool *d_bools,
const int8_t *d_output_shape_strides,
const size_t input_size,
const size_t **d_input_shapes,
const ptrdiff_t **d_input_strides,
cudaStream_t stream) const {
CHECK_CUDA(cudaFreeAsync((void *)d_inputs_arr, stream));
CHECK_CUDA(cudaFreeAsync((void *)d_bools, stream));
CHECK_CUDA(cudaFreeAsync((void *)d_output_shape_strides, stream));
CHECK_CUDA(cudaFreeAsync((void *)d_input_shapes, stream));
CHECK_CUDA(cudaFreeAsync((void *)d_input_strides, stream));
return INFINI_STATUS_SUCCESS;
}
};
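Taken together, infoToDevice implies a two-region workspace filled by exactly two host-to-device copies; a sketch of the layout, with the size formula that the CREATE macro below mirrors:

// Workspace = [N input pointers][packed meta blob]:
//
//   offset 0                 : const void *inputs[N]   (N * sizeof(void*))
//   offset N * sizeof(void*) : meta blob, getMetaMemSize() bytes
//                              (output shape/strides, flattened input
//                               shapes/strides, contiguity/broadcast flags;
//                               interior layout defined by ElementwiseInfo)
//
// size_t workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *);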
......@@ -476,6 +404,7 @@ infiniStatus_t DeviceImpl::create(DeviceImpl **device_info,
template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
......@@ -483,8 +412,7 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf
constexpr size_t N = Op::num_inputs;
static_assert(sizeof...(Tin) == N, "Input type count mismatch");
return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tout, Tin...>(
info, output, inputs,
std::make_index_sequence<N>{},
info, workspace, output, inputs,
reinterpret_cast<cudaStream_t>(stream),
std::forward<Args>(args)...);
}
......@@ -492,14 +420,14 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf
/* Invoke elementwise operation when all inputs have the same dtype */
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tdata>(
info, output, inputs,
std::make_index_sequence<N>{},
info, workspace, output, inputs,
reinterpret_cast<cudaStream_t>(stream),
std::forward<Args>(args)...);
}
......
......@@ -31,6 +31,7 @@ public:
* @tparam Args... Additional arguments passed to the operation.
*
* @param info Metadata describing tensor shapes, strides, etc.
* @param workspace Pointer to workspace buffer on device.
* @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
......@@ -40,6 +41,7 @@ public:
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
......@@ -56,6 +58,7 @@ public:
* @tparam Tin... Input data types (must match Op::num_inputs).
* @tparam Args... Additional arguments passed to the operation.
* @param info Metadata describing tensor shapes, strides, etc.
* @param workspace Pointer to workspace buffer on device.
* @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
......@@ -67,6 +70,7 @@ public:
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output,
const std::vector<const void *> &inputs,
void *stream,
......@@ -82,14 +86,17 @@ public:
\
auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc); \
CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\
op::elementwise::cuda::DeviceImpl *device_impl; \
CHECK_STATUS(op::elementwise::cuda::DeviceImpl::create(&device_impl, handle->internal())); \
\
*desc_ptr = new Descriptor( \
dtype, \
std::move(info_result.take()), \
std::move(info), \
device_impl, \
workspace_size, \
handle->device, \
handle->device_id);
......
......@@ -19,21 +19,26 @@
infiniDtype_t _dtype; \
op::elementwise::ElementwiseInfo _info; \
std::unique_ptr<op::elementwise::NAMESPACE::DeviceImpl> _device_info; \
size_t _workspace_size; \
\
Descriptor( \
infiniDtype_t dtype, \
op::elementwise::ElementwiseInfo info, \
op::elementwise::NAMESPACE::DeviceImpl *device_info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_dtype(dtype), \
_info(std::move(info)), \
_device_info(device_info) {} \
_device_info(device_info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
......@@ -41,6 +46,7 @@
std::vector<infiniopTensorDescriptor_t> input_descs); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *output, \
std::vector<const void *> inputs, \
void *stream) const; \
......@@ -62,57 +68,70 @@ namespace op::elementwise {
*/
struct ElementwiseInfo {
private:
ElementwiseInfo() = default;
std::vector<int8_t> _meta;
size_t _output_size;
size_t _input_size;
size_t _ndim;
bool _output_contiguous;
ElementwiseInfo(std::vector<int8_t> meta,
size_t output_size,
size_t input_size,
size_t ndim,
bool output_contiguous)
: _meta(std::move(meta)), _output_size(output_size),
_input_size(input_size), _ndim(ndim),
_output_contiguous(output_contiguous) {}
public:
size_t output_size;
size_t ndim;
bool output_contiguous;
bool *input_contiguous;
bool *input_broadcasted;
size_t *output_shape;
size_t **input_shapes;
ptrdiff_t *output_strides;
ptrdiff_t **input_strides;
size_t input_size;
~ElementwiseInfo() {
delete[] input_contiguous;
delete[] input_broadcasted;
delete[] output_shape;
delete[] output_strides;
for (size_t i = 0; i < input_size; ++i) {
delete[] input_shapes[i];
delete[] input_strides[i];
inline size_t getMetaMemSize() const {
return _meta.size();
}
inline const int8_t *getMetaStart() const {
return _meta.data();
}
inline size_t getOutputSize() const {
return _output_size;
}
inline size_t getInputSize() const {
return _input_size;
}
inline size_t getNdim() const {
return _ndim;
}
inline bool isOutputContiguous() const {
return _output_contiguous;
}
inline const size_t *getOutputShape() const {
return reinterpret_cast<const size_t *>(_meta.data());
}
inline const ptrdiff_t *getOutputStrides() const {
return reinterpret_cast<const ptrdiff_t *>(getOutputShape() + _ndim);
}
inline const size_t *getAllInputShapes() const {
return reinterpret_cast<const size_t *>(getOutputStrides() + _ndim);
}
inline const size_t *getInputShape(const size_t &index) const {
if (index < _input_size) {
return reinterpret_cast<const size_t *>(getAllInputShapes() + index * _ndim);
}
return nullptr;
}
inline const ptrdiff_t *getAllInputStrides() const {
return reinterpret_cast<const ptrdiff_t *>(getAllInputShapes() + _input_size * _ndim);
}
inline const ptrdiff_t *getInputStrides(const size_t &index) const {
if (index < _input_size) {
return reinterpret_cast<const ptrdiff_t *>(getAllInputStrides() + index * _ndim);
}
delete[] input_shapes;
delete[] input_strides;
}
ElementwiseInfo(ElementwiseInfo &&other) noexcept
: output_size(other.output_size),
ndim(other.ndim),
output_contiguous(other.output_contiguous),
input_contiguous(other.input_contiguous),
input_broadcasted(other.input_broadcasted),
output_shape(other.output_shape),
input_shapes(other.input_shapes),
output_strides(other.output_strides),
input_strides(other.input_strides),
input_size(other.input_size) {
other.input_contiguous = nullptr;
other.input_broadcasted = nullptr;
other.output_shape = nullptr;
other.input_shapes = nullptr;
other.output_strides = nullptr;
other.input_strides = nullptr;
other.input_size = 0;
}
ElementwiseInfo(const ElementwiseInfo &other) = delete;
ElementwiseInfo &operator=(const ElementwiseInfo &other) = delete;
ElementwiseInfo &operator=(ElementwiseInfo &&other) = delete;
return nullptr;
}
inline const bool *getInputContiguous() const {
return reinterpret_cast<const bool *>(getAllInputStrides() + _input_size * _ndim);
}
inline const bool *getInputBroadcasted() const {
return reinterpret_cast<const bool *>(getInputContiguous() + _input_size);
}
using ResultType = utils::Result<ElementwiseInfo>;
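The accessors above pin down the layout of the packed _meta buffer. Spelled out, with N = getInputSize() and d = getNdim():

// _meta layout (matches the accessor offsets above):
//   size_t    output_shape[d]
//   ptrdiff_t output_strides[d]
//   size_t    input_shapes[N * d]    // shape of input i at [i*d, (i+1)*d)
//   ptrdiff_t input_strides[N * d]
//   bool      input_contiguous[N]
//   bool      input_broadcasted[N]
//
// Worked size on a 64-bit host, N = 2, d = 2:
//   2*8 + 2*8 + (2*2)*8 + (2*2)*8 + 2 + 2 = 100 bytes = getMetaMemSize()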
......@@ -136,40 +155,48 @@ public:
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
ElementwiseInfo info;
info.input_size = input_descs.size();
info.ndim = output_desc->ndim();
info.output_size = output_desc->numel();
info.output_contiguous = output_desc->isContiguous();
// Allocate memory for arrays
info.input_contiguous = new bool[info.input_size];
info.input_broadcasted = new bool[info.input_size];
info.output_shape = new size_t[info.ndim];
info.output_strides = new ptrdiff_t[info.ndim];
info.input_shapes = new size_t *[info.input_size];
info.input_strides = new ptrdiff_t *[info.input_size];
// Fill arrays
auto input_size = input_descs.size();
auto ndim = output_desc->ndim();
auto output_size = output_desc->numel();
auto output_contiguous = output_desc->isContiguous();
// Allocate memory for meta
auto shape_unit = output_desc->dim(0);
auto stride_unit = output_desc->stride(0);
size_t meta_mem_size = ndim * (sizeof(shape_unit) + sizeof(stride_unit))
+ input_size * ndim * sizeof(shape_unit)
+ input_size * ndim * sizeof(stride_unit)
+ 2 * input_size * sizeof(bool);
std::vector<int8_t> meta(meta_mem_size);
int8_t *meta_ptr = meta.data();
const auto output_shape = output_desc->shape();
const auto output_strides = output_desc->strides();
std::memcpy(info.output_shape, output_shape.data(), info.ndim * sizeof(*info.output_shape));
std::memcpy(info.output_strides, output_strides.data(), info.ndim * sizeof(*info.output_strides));
for (size_t i = 0; i < info.input_size; ++i) {
auto &desc = input_descs[i];
info.input_contiguous[i] = desc->isContiguous();
info.input_broadcasted[i] = !info.input_contiguous[i] && (desc->ndim() != info.ndim || desc->hasBroadcastDim());
// Pointers to the sections within _meta
size_t *output_shape_p = reinterpret_cast<size_t *>(meta_ptr);
ptrdiff_t *output_strides_p = reinterpret_cast<ptrdiff_t *>(output_shape_p + ndim);
size_t *input_shapes = reinterpret_cast<size_t *>(output_strides_p + ndim);
ptrdiff_t *input_strides = reinterpret_cast<ptrdiff_t *>(input_shapes + input_size * ndim);
bool *input_contiguous = reinterpret_cast<bool *>(input_strides + input_size * ndim);
bool *input_broadcasted = input_contiguous + input_size;
info.input_shapes[i] = new size_t[desc->ndim()];
const auto &in_shape = desc->shape();
std::memcpy(info.input_shapes[i], in_shape.data(), desc->ndim() * sizeof(*info.input_shapes[i]));
// Copy output shape and strides
std::memcpy(output_shape_p, output_shape.data(), ndim * sizeof(*output_shape_p));
std::memcpy(output_strides_p, output_strides.data(), ndim * sizeof(*output_strides_p));
info.input_strides[i] = new ptrdiff_t[desc->ndim()];
const auto &in_strides = desc->strides();
std::memcpy(info.input_strides[i], in_strides.data(), desc->ndim() * sizeof(*info.input_strides[i]));
// Copy input shapes, strides, contiguous, and broadcasted flags
for (size_t i = 0; i < input_size; ++i) {
auto &desc = input_descs[i];
const auto in_shape = desc->shape();
const auto in_strides = desc->strides();
std::memcpy(input_shapes + i * ndim, in_shape.data(), ndim * sizeof(*input_shapes));
std::memcpy(input_strides + i * ndim, in_strides.data(), ndim * sizeof(*input_strides));
input_contiguous[i] = desc->isContiguous();
input_broadcasted[i] = !input_contiguous[i] && (desc->ndim() != ndim || desc->hasBroadcastDim());
}
ElementwiseInfo info(std::move(meta), output_size, input_size, ndim, output_contiguous);
return ResultType(std::move(info));
}
};
......
......@@ -30,6 +30,8 @@ infiniStatus_t Descriptor::create(
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
......
......@@ -10,7 +10,7 @@ typedef struct SwiGLUOp {
private:
template <typename T>
T sigmoid(const T &x) const {
return 1 / (1 + std::exp(-x));
return T(1) / (T(1) + std::exp(-x));
}
public:
......
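The T(1) literals keep the arithmetic in T instead of promoting through int, which matters once T is a half-precision wrapper; for float/double the value is unchanged. For context, a sketch of the full combination, assuming the conventional SwiGLU definition (the operator body itself is outside this hunk):

#include <cmath>

// out = up * swish(gate) = up * gate * sigmoid(gate); float/double shown here.
template <typename T>
T swigluSketch(const T &up, const T &gate) {
    const T sig = T(1) / (T(1) + std::exp(-gate)); // sigmoid computed in T
    return up * gate * sig;
}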
......@@ -21,9 +21,7 @@ infiniStatus_t Descriptor::create(
const auto &gate_shape = gate_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
if (!SAME_VEC(out_shape, up_shape, gate_shape)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
// create CUDA elementwise descriptor
CREATE_ELEMENTWISE_CUDA_DESCRIPTOR
......@@ -32,17 +30,23 @@ infiniStatus_t Descriptor::create(
}
infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output,
std::vector<const void *> inputs,
void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
switch (_dtype) {
case INFINI_DTYPE_F16:
return _device_info->calculate<256, SwiGLUOp, half>(_info, output, inputs, stream);
return _device_info->calculate<256, SwiGLUOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32:
return _device_info->calculate<256, SwiGLUOp, float>(_info, output, inputs, stream);
return _device_info->calculate<256, SwiGLUOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64:
return _device_info->calculate<256, SwiGLUOp, double>(_info, output, inputs, stream);
return _device_info->calculate<256, SwiGLUOp, double>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
......@@ -66,8 +66,49 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
#undef CREATE
}
__C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::swiglu::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu)
#endif
#ifdef ENABLE_CUDA_API
GET(INFINI_DEVICE_NVIDIA, cuda)
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangGetSwiGLUWorkspaceSize((SwiGLUBangDescriptor_t)desc, size);
}
#endif
#ifdef ENABLE_ASCEND_API
GET(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
return macaGetSwiGLUWorkspaceSize((SwiGLUMacaDescriptor_t)desc, size);
}
#endif
#ifdef ENABLE_MTHREADS_GPU
case DevMthreadsGpu: {
return musaGetSwiGLUWorkspaceSize((SwiGLUMusaDescriptor_t)desc, size);
}
#endif
}
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopSwiGLU(
infiniopSwiGLUDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c,
const void *a,
const void *b,
......@@ -76,7 +117,7 @@ __C infiniStatus_t infiniopSwiGLU(
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::swiglu::NAMESPACE::Descriptor *>(desc) \
->calculate(c, {a, b}, stream)
->calculate(workspace, workspace_size, c, {a, b}, stream)
switch (desc->device_type) {
......
import torch
import ctypes
from ctypes import POINTER, Structure, c_int32, c_void_p
from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64
from libinfiniop import (
infiniopHandle_t,
infiniopTensorDescriptor_t,
......@@ -14,6 +14,7 @@ from libinfiniop import (
debug,
get_tolerance,
profile_operation,
create_workspace
)
from enum import Enum, auto
......@@ -160,10 +161,19 @@ def test(
for tensor in [a_tensor, b_tensor, c_tensor]:
tensor.destroyDesc(lib)
workspace_size = c_uint64(0)
check_error(
lib.infiniopGetSwiGLUWorkspaceSize(descriptor, ctypes.byref(workspace_size))
)
workspace = create_workspace(workspace_size.value, c.device)
def lib_swiglu():
check_error(
lib.infiniopSwiGLU(
descriptor, c_tensor.data, a_tensor.data, b_tensor.data, None
descriptor,
workspace.data_ptr() if workspace is not None else None,
workspace_size.value,
c_tensor.data, a_tensor.data, b_tensor.data, None
)
)
......@@ -196,10 +206,18 @@ if __name__ == "__main__":
infiniopTensorDescriptor_t,
]
lib.infiniopGetSwiGLUWorkspaceSize.restype = c_int32
lib.infiniopGetSwiGLUWorkspaceSize.argtypes = [
infiniopSwiGLUDescriptor_t,
POINTER(c_uint64),
]
lib.infiniopSwiGLU.restype = c_int32
lib.infiniopSwiGLU.argtypes = [
infiniopSwiGLUDescriptor_t,
c_void_p,
c_uint64,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
......