Commit 9cc0c416 authored by Zimin Li's avatar Zimin Li
Browse files

issue/127: Refactor ElementwiseInfo, refactor elementwise to use workspace for...

issue/127: Refactor ElementwiseInfo, refactor elementwise to use workspace for storing meta, fix misc. issues
parent 40fdded5
...@@ -11,7 +11,11 @@ __C __export infiniStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t hand ...@@ -11,7 +11,11 @@ __C __export infiniStatus_t infiniopCreateSwiGLUDescriptor(infiniopHandle_t hand
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc); infiniopTensorDescriptor_t b_desc);
__C __export infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc, __C __export infiniStatus_t infiniopSwiGLU(infiniopSwiGLUDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c, void *c,
void const *a, void const *a,
void const *b, void const *b,
......
...@@ -9,6 +9,10 @@ ...@@ -9,6 +9,10 @@
#define CUDA_BLOCK_SIZE_1024 1024 #define CUDA_BLOCK_SIZE_1024 1024
#define CUDA_BLOCK_SIZE_512 512 #define CUDA_BLOCK_SIZE_512 512
#define CHECK_CUDA(API) CHECK_INTERNAL(API, cudaSuccess)
namespace device::cuda {
// return the memory offset of original tensor, given the flattened index of broadcasted tensor // return the memory offset of original tensor, given the flattened index of broadcasted tensor
__forceinline__ __device__ __host__ size_t __forceinline__ __device__ __host__ size_t
indexToReducedOffset( indexToReducedOffset(
...@@ -38,6 +42,7 @@ indexToOffset( ...@@ -38,6 +42,7 @@ indexToOffset(
} }
return res; return res;
} }
} // namespace device::cuda
#ifdef ENABLE_CUDA_API #ifdef ENABLE_CUDA_API
#include <cuda_fp16.h> #include <cuda_fp16.h>
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
dtype, \ dtype, \
info_result.take(), \ info_result.take(), \
nullptr, \ nullptr, \
0, \
handle->device, \ handle->device, \
handle->device_id); handle->device_id);
...@@ -103,24 +104,34 @@ infiniStatus_t DeviceImpl::create(DeviceImpl **device_info, Args &&...args) { ...@@ -103,24 +104,34 @@ infiniStatus_t DeviceImpl::create(DeviceImpl **device_info, Args &&...args) {
} }
// Perform elementwise operation for different input types // Perform elementwise operation for different input types
template <typename Op, typename Tout, typename... Tin, size_t... Is, typename... Args, std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0> template <typename Op, typename Tout, typename... Tin, size_t... Is, typename... Args,
void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, std::index_sequence<Is...>, Args &&...args) { std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate_impl(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
Args &&...args) {
Tout *out = reinterpret_cast<Tout *>(output); Tout *out = reinterpret_cast<Tout *>(output);
std::tuple<const Tin *...> input_ptrs = {reinterpret_cast<const Tin *>(inputs[Is])...}; std::tuple<const Tin *...> input_ptrs = {reinterpret_cast<const Tin *>(inputs[Is])...};
ptrdiff_t output_size = info.output_size; ptrdiff_t output_size = info.getOutputSize();
#pragma omp parallel for #pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) { for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape, info.output_strides); size_t out_idx = info.isOutputContiguous()
? i
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());
auto get_input_idx = [&](size_t input_id) { auto get_input_idx = [&](size_t input_id) {
return info.input_contiguous[input_id] ? i return info.getInputContiguous()[input_id]
: (info.input_broadcasted[input_id] ? i
? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides, info.input_strides[input_id]) : (info.getInputBroadcasted()[input_id]
: op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id])); ? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
}; };
out[out_idx] = utils::cast<Tout>(Op{}.template operator()<Tout, Tin...>(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...)); out[out_idx] = utils::cast<Tout>(
Op{}.template operator()<Tout, Tin...>(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...));
} }
} }
...@@ -147,17 +158,20 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, ...@@ -147,17 +158,20 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
Tdata *out = reinterpret_cast<Tdata *>(output); Tdata *out = reinterpret_cast<Tdata *>(output);
std::array<const Tdata *, sizeof...(Is)> ins = {reinterpret_cast<const Tdata *>(inputs[Is])...}; std::array<const Tdata *, sizeof...(Is)> ins = {reinterpret_cast<const Tdata *>(inputs[Is])...};
const ptrdiff_t output_size = info.output_size; const ptrdiff_t output_size = info.getOutputSize();
#pragma omp parallel for #pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) { for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape, info.output_strides); size_t out_idx = info.isOutputContiguous()
? i
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());
auto get_input_idx = [&](size_t input_id) { auto get_input_idx = [&](size_t input_id) {
return info.input_contiguous[input_id] ? i return info.getInputContiguous()[input_id]
: (info.input_broadcasted[input_id] ? i
? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides, info.input_strides[input_id]) : (info.getInputBroadcasted()[input_id]
: op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id], info.input_strides[input_id])); ? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
}; };
if constexpr (std::is_same_v<Tdata, fp16_t>) { if constexpr (std::is_same_v<Tdata, fp16_t>) {
...@@ -170,7 +184,11 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, ...@@ -170,7 +184,11 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
// Invoke elementwise operation when all inputs have the same type // Invoke elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, typename... Args> template <typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, void *stream, Args &&...args) { infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs; constexpr size_t N = Op::num_inputs;
calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...); calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "../../../utils.h" #include "../../../utils.h"
#include "../../devices/cuda/cuda_common.cuh" #include "../../devices/cuda/cuda_common.cuh"
#include "../../devices/cuda/cuda_kernel_common.cuh"
#include "elementwise_cuda_api.cuh" #include "elementwise_cuda_api.cuh"
namespace op::elementwise::cuda { namespace op::elementwise::cuda {
...@@ -16,18 +17,17 @@ namespace op::elementwise::cuda { ...@@ -16,18 +17,17 @@ namespace op::elementwise::cuda {
* @param lambda Lambda to be called with std::integral_constant<size_t, Is>... as arguments. * @param lambda Lambda to be called with std::integral_constant<size_t, Is>... as arguments.
*/ */
template <typename Lambda, size_t... Is> template <typename Lambda, size_t... Is>
__device__ __forceinline__ void call_expand(Lambda lambda, std::index_sequence<Is...>) { __device__ __forceinline__ void callExpand(Lambda lambda, std::index_sequence<Is...>) {
lambda(std::integral_constant<size_t, Is>{}...); lambda(std::integral_constant<size_t, Is>{}...);
} }
/** /**
* @brief CUDA kernel for performing elementwise operations on tensors where all inputs share the same data type. * @brief CUDA kernel for performing elementwise operations on tensors where all inputs share the same data type.
* *
* @tparam N Number of input tensors.
* @tparam Op Operator type implementing operator()(Tdata...). * @tparam Op Operator type implementing operator()(Tdata...).
* @tparam Tdata Common data type for inputs and output. * @tparam Tdata Common data type for inputs and output.
* @tparam N Number of input tensors.
* @tparam Args Additional arguments to pass to the operator. * @tparam Args Additional arguments to pass to the operator.
*
* @param output_size Total number of output elements. * @param output_size Total number of output elements.
* @param ndim Number of dimensions in tensors. * @param ndim Number of dimensions in tensors.
* @param output_contiguous Whether the output tensor is contiguous in memory. * @param output_contiguous Whether the output tensor is contiguous in memory.
...@@ -37,24 +37,22 @@ __device__ __forceinline__ void call_expand(Lambda lambda, std::index_sequence<I ...@@ -37,24 +37,22 @@ __device__ __forceinline__ void call_expand(Lambda lambda, std::index_sequence<I
* @param input_shapes Shapes of the input tensors. * @param input_shapes Shapes of the input tensors.
* @param output_strides Strides for the output tensor. * @param output_strides Strides for the output tensor.
* @param input_strides Strides for each input tensor. * @param input_strides Strides for each input tensor.
* @param input_size Total number of input elements (optional, may be unused).
* @param output Output buffer. * @param output Output buffer.
* @param inputs Array of input pointers, all of type Tdata. * @param inputs Array of input pointers, all of type Tdata.
* @param offset Linear offset to support partitioned execution. * @param offset Linear offset to support partitioned execution.
* @param args Additional arguments passed to the operator. * @param args Additional arguments passed to the operator.
*/ */
template <size_t N, typename Op, typename Tdata, typename... Args> template <size_t N, typename Op, typename Tdata, typename... Args>
INFINIOP_CUDA_KERNEL elementwise_kernel( INFINIOP_CUDA_KERNEL elementwiseKernel(
size_t output_size, size_t output_size,
size_t ndim, size_t ndim,
bool output_contiguous, bool output_contiguous,
const bool *__restrict__ input_contiguous, const bool *__restrict__ input_contiguous,
const bool *__restrict__ input_broadcasted, const bool *__restrict__ input_broadcasted,
const size_t *__restrict__ output_shape, const size_t *__restrict__ output_shape,
const size_t *__restrict__ *__restrict__ input_shapes, const size_t *__restrict__ input_shapes,
const ptrdiff_t *__restrict__ output_strides, const ptrdiff_t *__restrict__ output_strides,
const ptrdiff_t *__restrict__ *__restrict__ input_strides, const ptrdiff_t *__restrict__ input_strides,
size_t input_size,
Tdata *output, Tdata *output,
const Tdata *const *inputs, const Tdata *const *inputs,
size_t offset, size_t offset,
...@@ -68,8 +66,8 @@ INFINIOP_CUDA_KERNEL elementwise_kernel( ...@@ -68,8 +66,8 @@ INFINIOP_CUDA_KERNEL elementwise_kernel(
auto get_input_idx = [&] __device__(size_t input_id) { auto get_input_idx = [&] __device__(size_t input_id) {
return input_contiguous[input_id] ? idx return input_contiguous[input_id] ? idx
: (input_broadcasted[input_id] : (input_broadcasted[input_id]
? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides[input_id]) ? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: device::cuda::indexToOffset(idx, ndim, input_shapes[input_id], input_strides[input_id])); : device::cuda::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
}; };
// Use a helper to expand the index sequence into individual compile-time constants // Use a helper to expand the index sequence into individual compile-time constants
...@@ -85,7 +83,7 @@ INFINIOP_CUDA_KERNEL elementwise_kernel( ...@@ -85,7 +83,7 @@ INFINIOP_CUDA_KERNEL elementwise_kernel(
} }
}; };
call_expand(expand_inputs, std::make_index_sequence<N>{}); callExpand(expand_inputs, std::make_index_sequence<N>{});
} }
} }
...@@ -97,38 +95,38 @@ INFINIOP_CUDA_KERNEL elementwise_kernel( ...@@ -97,38 +95,38 @@ INFINIOP_CUDA_KERNEL elementwise_kernel(
* @return Pointer of type const T*. * @return Pointer of type const T*.
*/ */
template <typename T> template <typename T>
__device__ inline const T *typed_input_ptr(const void *ptr) { __device__ inline const T *typedInputPtr(const void *ptr) {
return reinterpret_cast<const T *>(ptr); return reinterpret_cast<const T *>(ptr);
} }
/** /**
* @brief Launches a type-safe elementwise operation on a single output element. * @brief Launches elementwise operation at a specific output index.
*
* @tparam Op Operator type implementing a templated operator() for (Tout, Tin...).
* @tparam Tout Output data type.
* @tparam Tin Variadic input data types.
* @tparam Is Index sequence corresponding to each input.
* *
* @param idx Linear index in the flattened output space. * @tparam Op Functor representing the elementwise operation.
* @param out_idx Actual output index (may be non-contiguous). * @tparam Tout Output data type.
* @param ndim Number of dimensions in the tensors. * @tparam Tin... Input data types.
* @param input_contiguous Array indicating whether each input is contiguous. * @tparam Is... Index sequence for unpacking variadic inputs.
* @param input_broadcasted Array indicating whether each input is broadcasted. * @param idx Global linear index into the output tensor.
* @param input_shapes Shapes of the input tensors. * @param out_idx Offset into the output array.
* @param input_strides Strides of the input tensors. * @param ndim Number of dimensions in the tensors.
* @param inputs Raw pointers to input data. * @param input_contiguous Flags indicating whether each input is contiguous.
* @param output Pointer to output data. * @param input_broadcasted Flags indicating whether each input is broadcasted.
* @param ... Index sequence used for unpacking variadic inputs. * @param input_shapes Flattened input shapes (N * ndim).
* @param input_strides Flattened input strides (N * ndim).
* @param output_strides Output tensor strides.
* @param inputs Array of pointers to input tensors.
* @param output Pointer to output tensor.
* @param ...Is Index sequence for iterating over input tensors.
*/ */
template <typename Op, typename Tout, typename... Tin, size_t... Is> template <typename Op, typename Tout, typename... Tin, size_t... Is>
__device__ void launch_op( __device__ void launchOp(
size_t idx, size_t idx,
size_t out_idx, size_t out_idx,
size_t ndim, size_t ndim,
const bool *__restrict__ input_contiguous, const bool *__restrict__ input_contiguous,
const bool *__restrict__ input_broadcasted, const bool *__restrict__ input_broadcasted,
const size_t *__restrict__ const *__restrict__ input_shapes, const size_t *__restrict__ input_shapes,
const ptrdiff_t *__restrict__ const *__restrict__ input_strides, const ptrdiff_t *__restrict__ input_strides,
const ptrdiff_t *__restrict__ output_strides, const ptrdiff_t *__restrict__ output_strides,
const void *const *__restrict__ inputs, const void *const *__restrict__ inputs,
Tout *output, Tout *output,
...@@ -138,12 +136,12 @@ __device__ void launch_op( ...@@ -138,12 +136,12 @@ __device__ void launch_op(
return input_contiguous[input_id] return input_contiguous[input_id]
? idx ? idx
: (input_broadcasted[input_id] : (input_broadcasted[input_id]
? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides[input_id]) ? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: device::cuda::indexToOffset(idx, ndim, input_shapes[input_id], input_strides[input_id])); : device::cuda::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
}; };
output[out_idx] = Op{}.template operator()<Tout, Tin...>( output[out_idx] = Op{}.template operator()<Tout, Tin...>(
(typed_input_ptr<Tin>(inputs[Is])[get_input_idx(Is)])...); (typedInputPtr<Tin>(inputs[Is])[get_input_idx(Is)])...);
} }
/** /**
...@@ -153,7 +151,6 @@ __device__ void launch_op( ...@@ -153,7 +151,6 @@ __device__ void launch_op(
* @tparam Op Operator type implementing a templated operator() for (Tout, Tin...). * @tparam Op Operator type implementing a templated operator() for (Tout, Tin...).
* @tparam Tout Output data type. * @tparam Tout Output data type.
* @tparam Tin Variadic input data types. * @tparam Tin Variadic input data types.
*
* @param output_size Total number of output elements. * @param output_size Total number of output elements.
* @param ndim Number of dimensions in the tensors. * @param ndim Number of dimensions in the tensors.
* @param output_contiguous Whether the output tensor is contiguous. * @param output_contiguous Whether the output tensor is contiguous.
...@@ -163,23 +160,21 @@ __device__ void launch_op( ...@@ -163,23 +160,21 @@ __device__ void launch_op(
* @param input_shapes Shapes of the input tensors. * @param input_shapes Shapes of the input tensors.
* @param output_strides Strides of the output tensor. * @param output_strides Strides of the output tensor.
* @param input_strides Strides of the input tensors. * @param input_strides Strides of the input tensors.
* @param input_size Total number of input elements (unused here, but may be used for validation).
* @param output Pointer to the output buffer. * @param output Pointer to the output buffer.
* @param inputs Array of untyped input pointers. * @param inputs Array of untyped input pointers.
* @param offset Linear offset into the output for partitioned execution. * @param offset Linear offset into the output for partitioned execution.
*/ */
template <typename Op, typename Tout, typename... Tin> template <typename Op, typename Tout, typename... Tin>
INFINIOP_CUDA_KERNEL elementwise_kernel( INFINIOP_CUDA_KERNEL elementwiseKernel(
size_t output_size, size_t output_size,
size_t ndim, size_t ndim,
bool output_contiguous, bool output_contiguous,
const bool *__restrict__ input_contiguous, const bool *__restrict__ input_contiguous,
const bool *__restrict__ input_broadcasted, const bool *__restrict__ input_broadcasted,
const size_t *__restrict__ output_shape, const size_t *__restrict__ output_shape,
const size_t *__restrict__ const *__restrict__ input_shapes, const size_t *__restrict__ input_shapes,
const ptrdiff_t *__restrict__ output_strides, const ptrdiff_t *__restrict__ output_strides,
const ptrdiff_t *__restrict__ const *__restrict__ input_strides, const ptrdiff_t *__restrict__ input_strides,
size_t input_size,
Tout *output, Tout *output,
const void *const *__restrict__ inputs, const void *const *__restrict__ inputs,
size_t offset) { size_t offset) {
...@@ -193,7 +188,7 @@ INFINIOP_CUDA_KERNEL elementwise_kernel( ...@@ -193,7 +188,7 @@ INFINIOP_CUDA_KERNEL elementwise_kernel(
? idx ? idx
: device::cuda::indexToOffset(idx, ndim, output_shape, output_strides); : device::cuda::indexToOffset(idx, ndim, output_shape, output_strides);
launch_op<Op, Tout, Tin...>( launchOp<Op, Tout, Tin...>(
idx, idx,
out_idx, out_idx,
ndim, ndim,
...@@ -214,252 +209,185 @@ struct DeviceImpl::Opaque { ...@@ -214,252 +209,185 @@ struct DeviceImpl::Opaque {
: internal(internal) {} : internal(internal) {}
/** /**
* @brief Performs elementwise operations when all inputs and the output share the same data type. * @brief Executes an elementwise operation where all inputs and the output share the same data type.
* *
* @tparam BLOCK_SIZE The block size for the kernel launch. * @tparam BLOCK_SIZE CUDA block size used for kernel launch.
* @tparam N The number of input tensors. * @tparam N Number of input tensors.
* @tparam Op The operation to perform (e.g., addition, multiplication). * @tparam Op Functor representing the elementwise operation.
* @tparam Tdata The data type of the input and output tensors. * @tparam Tdata Data type of both input and output tensors.
* @tparam Args Additional arguments to be passed to the operation. * @tparam Args Optional additional arguments passed to the operation.
* @param info Structure containing elementwise operation information (size, shape, etc.). * @param info Metadata about the operation including shape, size, and dimensionality.
* @param output Pointer to the output memory where results will be stored. * @param workspace Temporary workspace used for storing metadata on device.
* @param inputs Vector of pointers to input tensors. * @param output Pointer to the output buffer.
* @param stream CUDA stream used for asynchronous execution. * @param inputs Vector of pointers to input buffers.
* @param args Additional arguments for the operation. * @param stream CUDA stream for asynchronous execution.
* @return infiniStatus_t Status indicating success or failure. * @param args Additional arguments forwarded to the operation.
* @return infiniStatus_t Returns success or failure status.
*/ */
template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args, size_t... Is> template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info, infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
cudaStream_t stream, cudaStream_t stream,
Args &&...args) { Args &&...args) {
if (info.output_size == 0) { auto output_size = info.getOutputSize();
if (output_size == 0) {
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
// casting the output and the inputs to Tdata pointers // casting the output and the inputs to Tdata pointers
Tdata *out = reinterpret_cast<Tdata *>(output); Tdata *out = reinterpret_cast<Tdata *>(output);
const Tdata *inputs_arr[N]; const void **d_inputs_arr = nullptr;
const Tdata **d_inputs_arr = nullptr;
for (size_t i = 0; i < N; ++i) {
inputs_arr[i] = reinterpret_cast<const Tdata *>(inputs[i]);
}
CHECK_CUDA(cudaMallocAsync(&d_inputs_arr, N * sizeof(*d_inputs_arr), stream));
CHECK_CUDA(cudaMemcpyAsync(d_inputs_arr, inputs_arr, N * sizeof(*d_inputs_arr), cudaMemcpyHostToDevice, stream));
// create and send the info to device // create and send the info to device
const bool *d_bools = nullptr;
const bool *d_input_contiguous = nullptr; const bool *d_input_contiguous = nullptr;
const bool *d_input_broadcasted = nullptr; const bool *d_input_broadcasted = nullptr;
const int8_t *d_output_shape_strides = nullptr;
const size_t *d_output_shape = nullptr; const size_t *d_output_shape = nullptr;
const ptrdiff_t *d_output_strides = nullptr; const ptrdiff_t *d_output_strides = nullptr;
const size_t **d_input_shapes = nullptr; const size_t *d_input_shapes = nullptr;
const ptrdiff_t **d_input_strides = nullptr; const ptrdiff_t *d_input_strides = nullptr;
std::vector<const size_t *> tmp_device_ptrs(info.input_size);
std::vector<const ptrdiff_t *> tmp_device_ptrs_strides(info.input_size);
CHECK_STATUS(infoToDevice<N>(info, d_bools, d_input_contiguous, CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr, d_input_contiguous, d_input_broadcasted,
d_input_broadcasted, d_output_shape_strides, d_output_shape, d_output_shape, d_output_strides, d_input_shapes, d_input_strides, stream));
d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides,
d_input_strides, stream));
dim3 blockDims(std::min(BLOCK_SIZE, static_cast<size_t>(internal->maxThreadsPerBlock()))); dim3 blockDims(std::min(BLOCK_SIZE, static_cast<size_t>(internal->maxThreadsPerBlock())));
dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX()))); dim3 gridDims(std::min(CEIL_DIV(output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX())));
size_t step = gridDims.x * blockDims.x; size_t step = gridDims.x * blockDims.x;
for (size_t i = 0; i < info.output_size; i += step) { for (size_t i = 0; i < output_size; i += step) {
elementwise_kernel<N, Op, Tdata, Args...><<<gridDims, blockDims, 0, stream>>>( elementwiseKernel<N, Op, Tdata, Args...><<<gridDims, blockDims, 0, stream>>>(
info.output_size, output_size,
info.ndim, info.getNdim(),
info.output_contiguous, info.isOutputContiguous(),
d_input_contiguous, d_input_contiguous,
d_input_broadcasted, d_input_broadcasted,
d_output_shape, d_output_shape,
d_input_shapes, d_input_shapes,
d_output_strides, d_output_strides,
d_input_strides, d_input_strides,
info.input_size, out, d_inputs_arr, i, std::forward<Args>(args)...); out, reinterpret_cast<const Tdata **>(d_inputs_arr), i, std::forward<Args>(args)...);
} }
CHECK_STATUS(freeAllDevice((const void **)d_inputs_arr, d_bools, d_output_shape_strides,
info.input_size, d_input_shapes, d_input_strides, stream));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
/** /**
* @brief Performs elementwise operations when inputs and the outputs have mixed data types (i.e., different dtypes). * @brief Executes an elementwise operation with mixed input and output data types.
* *
* @tparam BLOCK_SIZE The block size for the kernel launch. * @tparam BLOCK_SIZE CUDA block size used for kernel launch.
* @tparam N The number of input tensors. * @tparam N Number of input tensors.
* @tparam Op The operation to perform (e.g., addition, multiplication). * @tparam Op Functor representing the elementwise operation.
* @tparam Tout The output data type. * @tparam Tout Data type of the output tensor.
* @tparam Tin The input data types. * @tparam Tin... Data types of the input tensors.
* @tparam Args Additional arguments to be passed to the operation. * @tparam Args Optional additional arguments passed to the operation.(UNUSED)
* @param info Structure containing elementwise operation information (size, shape, etc.). * @param info Metadata about the operation including shape, size, and dimensionality.
* @param output Pointer to the output memory where results will be stored. * @param workspace Temporary workspace used for storing metadata on device.
* @param inputs Vector of pointers to input tensors. * @param output Pointer to the output buffer.
* @param stream CUDA stream used for asynchronous execution. * @param inputs Vector of pointers to input buffers.
* @param args Additional arguments for the operation. * @param stream CUDA stream for asynchronous execution.
* @return infiniStatus_t Status indicating success or failure. * @param args Additional arguments forwarded to the operation.
* @return infiniStatus_t Returns success or failure status.
*/ */
template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tout, typename... Tin, typename... Args, size_t... Is, template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tout, typename... Tin, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0> std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info, infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
cudaStream_t stream, cudaStream_t stream,
Args &&...args) { Args &&...args) {
if (info.output_size == 0) { auto output_size = info.getOutputSize();
if (output_size == 0) {
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
Tout *out = reinterpret_cast<Tout *>(output); Tout *out = reinterpret_cast<Tout *>(output);
// Store input pointers with the correct types
const std::tuple<const Tin *...> inputs_arr{reinterpret_cast<const Tin *>(inputs[Is])...};
const void **d_inputs_arr = nullptr; const void **d_inputs_arr = nullptr;
// Create array of input pointers on host (void*) to copy to device
const void *host_input_ptrs[] = {reinterpret_cast<const void *>(std::get<Is>(inputs_arr))...};
CHECK_CUDA(cudaMallocAsync(&d_inputs_arr, N * sizeof(void *), stream));
CHECK_CUDA(cudaMemcpyAsync(d_inputs_arr, host_input_ptrs, N * sizeof(void *), cudaMemcpyHostToDevice, stream));
// Device pointers // Device pointers
const bool *d_bools = nullptr;
const bool *d_input_contiguous = nullptr; const bool *d_input_contiguous = nullptr;
const bool *d_input_broadcasted = nullptr; const bool *d_input_broadcasted = nullptr;
const int8_t *d_output_shape_strides = nullptr;
const size_t *d_output_shape = nullptr; const size_t *d_output_shape = nullptr;
const ptrdiff_t *d_output_strides = nullptr; const ptrdiff_t *d_output_strides = nullptr;
const size_t **d_input_shapes = nullptr; const size_t *d_input_shapes = nullptr;
const ptrdiff_t **d_input_strides = nullptr; const ptrdiff_t *d_input_strides = nullptr;
std::vector<const size_t *> tmp_device_ptrs(info.input_size);
std::vector<const ptrdiff_t *> tmp_device_ptrs_strides(info.input_size);
CHECK_STATUS(infoToDevice<N>(info, d_bools, d_input_contiguous, CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr, d_input_contiguous, d_input_broadcasted,
d_input_broadcasted, d_output_shape_strides, d_output_shape, d_output_shape, d_output_strides, d_input_shapes, d_input_strides, stream));
d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides,
d_input_strides, stream));
dim3 blockDims(std::min(BLOCK_SIZE, static_cast<size_t>(internal->maxThreadsPerBlock()))); dim3 blockDims(std::min(BLOCK_SIZE, static_cast<size_t>(internal->maxThreadsPerBlock())));
dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX()))); dim3 gridDims(std::min(CEIL_DIV(output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX())));
size_t step = gridDims.x * blockDims.x; size_t step = gridDims.x * blockDims.x;
for (size_t i = 0; i < info.output_size; i += step) { for (size_t i = 0; i < output_size; i += step) {
elementwise_kernel<Op, Tout, Tin...><<<gridDims, blockDims, 0, stream>>>( elementwiseKernel<Op, Tout, Tin...><<<gridDims, blockDims, 0, stream>>>(
info.output_size, output_size,
info.ndim, info.getNdim(),
info.output_contiguous, info.isOutputContiguous(),
d_input_contiguous, d_input_contiguous,
d_input_broadcasted, d_input_broadcasted,
d_output_shape, d_output_shape,
d_input_shapes, d_input_shapes,
d_output_strides, d_output_strides,
d_input_strides, d_input_strides,
info.input_size, out, reinterpret_cast<const void **>(d_inputs_arr), i); out, reinterpret_cast<const void **>(d_inputs_arr), i);
} }
CHECK_STATUS(freeAllDevice(d_inputs_arr, d_bools, d_output_shape_strides, info.input_size, d_input_shapes, d_input_strides, stream));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
private: private:
/** /**
* @brief Transfers elementwise kernel metadata (shapes, strides, flags) from host to device. * @brief Transfers elementwise operation metadata and input pointers from host to device memory.
* *
* @tparam N Number of inputs. * @tparam N Number of input tensors.
* @param info Structure containing input/output metadata. * @param info Elementwise operation metadata (shapes, strides, flags, etc.).
* @param d_bools Device pointer for input_contiguous and input_broadcasted flags. * @param workspace Pointer to device workspace memory for storing metadata and input pointers.
* @param d_input_contiguous Device pointer to input contiguity flags. * @param h_inputs_arr Host array of input tensor pointers.
* @param d_input_broadcasted Device pointer to input broadcasting flags. * @param d_inputs_arr Output reference to device array of input tensor pointers.
* @param d_output_shape_strides Device buffer containing both output shape and strides. * @param d_input_contiguous Output reference to device array indicating whether each input is contiguous.
* @param d_output_shape Device pointer to output shape. * @param d_input_broadcasted Output reference to device array indicating whether each input is broadcasted.
* @param d_output_strides Device pointer to output strides. * @param d_output_shape Output reference to device array holding the output tensor shape.
* @param tmp_device_ptrs Temporary device pointers for input shapes. * @param d_output_strides Output reference to device array holding output tensor strides.
* @param d_input_shapes Device array of pointers to input shapes. * @param d_input_shapes Output reference to flattened input tensor shapes (N * ndim).
* @param tmp_device_ptrs_strides Temporary device pointers for input strides. * @param d_input_strides Output reference to flattened input tensor strides (N * ndim).
* @param d_input_strides Device array of pointers to input strides. * @param stream CUDA stream used for asynchronous memory transfer.
* @param stream CUDA stream for async allocation and transfers. * @return infiniStatus_t Status indicating success or failure of the memory transfer and setup.
* @return infiniStatus_t Status indicating success or failure.
*/ */
template <size_t N> template <size_t N>
infiniStatus_t infoToDevice( infiniStatus_t infoToDevice(
const op::elementwise::ElementwiseInfo &info, const op::elementwise::ElementwiseInfo &info,
const bool *&d_bools, void *workspace,
const void *const *h_inputs_arr,
const void **&d_inputs_arr,
const bool *&d_input_contiguous, const bool *&d_input_contiguous,
const bool *&d_input_broadcasted, const bool *&d_input_broadcasted,
const int8_t *&d_output_shape_strides,
const size_t *&d_output_shape, const size_t *&d_output_shape,
const ptrdiff_t *&d_output_strides, const ptrdiff_t *&d_output_strides,
std::vector<const size_t *> &tmp_device_ptrs, const size_t *&d_input_shapes,
const size_t **&d_input_shapes, const ptrdiff_t *&d_input_strides,
std::vector<const ptrdiff_t *> &tmp_device_ptrs_strides,
const ptrdiff_t **&d_input_strides,
cudaStream_t stream) const { cudaStream_t stream) const {
CHECK_CUDA(cudaMallocAsync(&d_bools, 2 * info.input_size * sizeof(*d_bools), stream)); constexpr auto input_size = N;
CHECK_CUDA(cudaMemcpyAsync((void *)d_bools, info.input_contiguous, info.input_size * sizeof(*d_bools), cudaMemcpyHostToDevice, stream)); const auto ndim = info.getNdim();
CHECK_CUDA(cudaMemcpyAsync((void *)(d_bools + info.input_size), info.input_broadcasted, info.input_size * sizeof(*d_bools), cudaMemcpyHostToDevice, stream)); constexpr auto input_arr_size = N * sizeof(*h_inputs_arr);
const int8_t *info_meta_start = info.getMetaStart();
CHECK_CUDA(cudaMallocAsync(&d_output_shape_strides, info.ndim * (sizeof(*d_output_shape) + sizeof(*d_output_strides)), stream)); const int8_t *d_meta_start = reinterpret_cast<int8_t *>(workspace) + input_arr_size;
CHECK_CUDA(cudaMemcpyAsync((void *)d_output_shape_strides, info.output_shape, info.ndim * sizeof(*d_output_shape), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync((void *)(d_output_shape_strides + info.ndim * sizeof(*d_output_shape)), info.output_strides, info.ndim * sizeof(*d_output_strides), cudaMemcpyHostToDevice, stream)); // copy the input pointer array and meta to device
CHECK_CUDA(cudaMemcpyAsync(workspace, h_inputs_arr, input_arr_size, cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMallocAsync(&d_input_shapes, info.input_size * sizeof(*d_input_shapes), stream)); CHECK_CUDA(cudaMemcpyAsync((void *)d_meta_start, info_meta_start, info.getMetaMemSize(), cudaMemcpyHostToDevice, stream));
for (size_t i = 0; i < info.input_size; ++i) {
CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs[i], info.ndim * sizeof(*&tmp_device_ptrs[i]), stream)); // offset/assign the pointers
CHECK_CUDA(cudaMemcpyAsync((void *)tmp_device_ptrs[i], info.input_shapes[i], d_inputs_arr = reinterpret_cast<const void **>(workspace);
info.ndim * sizeof(*tmp_device_ptrs[i]), cudaMemcpyHostToDevice, stream)); d_output_shape = reinterpret_cast<const size_t *>(d_meta_start);
} d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape + ndim);
CHECK_CUDA(cudaMemcpyAsync((void *)d_input_shapes, tmp_device_ptrs.data(), d_input_shapes = reinterpret_cast<const size_t *>(d_output_strides + ndim);
info.input_size * sizeof(*d_input_shapes), cudaMemcpyHostToDevice, stream)); d_input_strides = reinterpret_cast<const ptrdiff_t *>(d_input_shapes + input_size * ndim);
d_input_contiguous = reinterpret_cast<const bool *>(d_input_strides + input_size * ndim);
CHECK_CUDA(cudaMallocAsync(&d_input_strides, info.input_size * sizeof(*d_input_strides), stream)); d_input_broadcasted = reinterpret_cast<const bool *>(d_input_contiguous + input_size);
for (size_t i = 0; i < info.input_size; ++i) {
CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs_strides[i], info.ndim * sizeof(*tmp_device_ptrs_strides[i]), stream));
CHECK_CUDA(cudaMemcpyAsync((void *)tmp_device_ptrs_strides[i], info.input_strides[i],
info.ndim * sizeof(*tmp_device_ptrs_strides[i]), cudaMemcpyHostToDevice, stream));
}
CHECK_CUDA(cudaMemcpyAsync((void *)d_input_strides, tmp_device_ptrs_strides.data(),
info.input_size * sizeof(*d_input_strides), cudaMemcpyHostToDevice, stream));
d_input_contiguous = d_bools;
d_input_broadcasted = d_bools + info.input_size;
d_output_shape = reinterpret_cast<const size_t *>(d_output_shape_strides);
d_output_strides = reinterpret_cast<const ptrdiff_t *>(d_output_shape_strides + info.ndim * sizeof(*d_output_shape));
return INFINI_STATUS_SUCCESS;
}
/**
* @brief Frees all device-allocated memory used for metadata in elementwise kernel execution.
*
* @param d_inputs_arr Device array of input pointers.
* @param d_bools Device memory holding input flags.
* @param d_output_shape_strides Device buffer holding output shape and strides.
* @param input_size Number of input tensors.
* @param d_input_shapes Device array of input shape pointers.
* @param d_input_strides Device array of input stride pointers.
* @param stream CUDA stream for async deallocation.
* @return infiniStatus_t Status indicating success or failure.
*/
inline infiniStatus_t freeAllDevice(const void **d_inputs_arr,
const bool *d_bools,
const int8_t *d_output_shape_strides,
const size_t input_size,
const size_t **d_input_shapes,
const ptrdiff_t **d_input_strides,
cudaStream_t stream) const {
CHECK_CUDA(cudaFreeAsync((void *)d_inputs_arr, stream));
CHECK_CUDA(cudaFreeAsync((void *)d_bools, stream));
CHECK_CUDA(cudaFreeAsync((void *)d_output_shape_strides, stream));
CHECK_CUDA(cudaFreeAsync((void *)d_input_shapes, stream));
CHECK_CUDA(cudaFreeAsync((void *)d_input_strides, stream));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
}; };
...@@ -476,6 +404,7 @@ infiniStatus_t DeviceImpl::create(DeviceImpl **device_info, ...@@ -476,6 +404,7 @@ infiniStatus_t DeviceImpl::create(DeviceImpl **device_info,
template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin, typename... Args, template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>> std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
void *stream, void *stream,
...@@ -483,8 +412,7 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf ...@@ -483,8 +412,7 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf
constexpr size_t N = Op::num_inputs; constexpr size_t N = Op::num_inputs;
static_assert(sizeof...(Tin) == N, "Input type count mismatch"); static_assert(sizeof...(Tin) == N, "Input type count mismatch");
return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tout, Tin...>( return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tout, Tin...>(
info, output, inputs, info, workspace, output, inputs,
std::make_index_sequence<N>{},
reinterpret_cast<cudaStream_t>(stream), reinterpret_cast<cudaStream_t>(stream),
std::forward<Args>(args)...); std::forward<Args>(args)...);
} }
...@@ -492,14 +420,14 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf ...@@ -492,14 +420,14 @@ infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &inf
/* Invoke elementwise operation when all inputs have the same dtype */ /* Invoke elementwise operation when all inputs have the same dtype */
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args> template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
void *stream, void *stream,
Args &&...args) { Args &&...args) {
constexpr size_t N = Op::num_inputs; constexpr size_t N = Op::num_inputs;
return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tdata>( return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tdata>(
info, output, inputs, info, workspace, output, inputs,
std::make_index_sequence<N>{},
reinterpret_cast<cudaStream_t>(stream), reinterpret_cast<cudaStream_t>(stream),
std::forward<Args>(args)...); std::forward<Args>(args)...);
} }
......
...@@ -31,6 +31,7 @@ public: ...@@ -31,6 +31,7 @@ public:
* @tparam Args... Additional arguments passed to the operation. * @tparam Args... Additional arguments passed to the operation.
* *
* @param info Metadata describing tensor shapes, strides, etc. * @param info Metadata describing tensor shapes, strides, etc.
* @param workspace Pointer to workspace buffer on device.
* @param output Pointer to output buffer on device. * @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory). * @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*). * @param stream CUDA stream (opaque void*).
...@@ -40,6 +41,7 @@ public: ...@@ -40,6 +41,7 @@ public:
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args> template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
infiniStatus_t calculate( infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info, const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
void *stream, void *stream,
...@@ -56,6 +58,7 @@ public: ...@@ -56,6 +58,7 @@ public:
* @tparam Tin... Input data types (must match Op::num_inputs). * @tparam Tin... Input data types (must match Op::num_inputs).
* @tparam Args... Additional arguments passed to the operation. * @tparam Args... Additional arguments passed to the operation.
* @param info Metadata describing tensor shapes, strides, etc. * @param info Metadata describing tensor shapes, strides, etc.
* @param workspace Pointer to workspace buffer on device.
* @param output Pointer to output buffer on device. * @param output Pointer to output buffer on device.
* @param inputs Vector of input pointers (device memory). * @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*). * @param stream CUDA stream (opaque void*).
...@@ -67,6 +70,7 @@ public: ...@@ -67,6 +70,7 @@ public:
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0> std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
infiniStatus_t calculate( infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info, const op::elementwise::ElementwiseInfo &info,
void *workspace,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
void *stream, void *stream,
...@@ -82,14 +86,17 @@ public: ...@@ -82,14 +86,17 @@ public:
\ \
auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc); \ auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc); \
CHECK_RESULT(info_result); \ CHECK_RESULT(info_result); \
auto info = info_result.take(); \
auto workspace_size = info.getMetaMemSize() + info.getInputSize() * sizeof(void *); \
\ \
op::elementwise::cuda::DeviceImpl *device_impl; \ op::elementwise::cuda::DeviceImpl *device_impl; \
CHECK_STATUS(op::elementwise::cuda::DeviceImpl::create(&device_impl, handle->internal())); \ CHECK_STATUS(op::elementwise::cuda::DeviceImpl::create(&device_impl, handle->internal())); \
\ \
*desc_ptr = new Descriptor( \ *desc_ptr = new Descriptor( \
dtype, \ dtype, \
std::move(info_result.take()), \ std::move(info), \
device_impl, \ device_impl, \
workspace_size, \
handle->device, \ handle->device, \
handle->device_id); handle->device_id);
......
...@@ -19,21 +19,26 @@ ...@@ -19,21 +19,26 @@
infiniDtype_t _dtype; \ infiniDtype_t _dtype; \
op::elementwise::ElementwiseInfo _info; \ op::elementwise::ElementwiseInfo _info; \
std::unique_ptr<op::elementwise::NAMESPACE::DeviceImpl> _device_info; \ std::unique_ptr<op::elementwise::NAMESPACE::DeviceImpl> _device_info; \
size_t _workspace_size; \
\ \
Descriptor( \ Descriptor( \
infiniDtype_t dtype, \ infiniDtype_t dtype, \
op::elementwise::ElementwiseInfo info, \ op::elementwise::ElementwiseInfo info, \
op::elementwise::NAMESPACE::DeviceImpl *device_info, \ op::elementwise::NAMESPACE::DeviceImpl *device_info, \
size_t workspace_size, \
infiniDevice_t device_type, \ infiniDevice_t device_type, \
int device_id) \ int device_id) \
: InfiniopDescriptor{device_type, device_id}, \ : InfiniopDescriptor{device_type, device_id}, \
_dtype(dtype), \ _dtype(dtype), \
_info(std::move(info)), \ _info(std::move(info)), \
_device_info(device_info) {} \ _device_info(device_info), \
_workspace_size(workspace_size) {} \
\ \
public: \ public: \
~Descriptor(); \ ~Descriptor(); \
\ \
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \ static infiniStatus_t create( \
infiniopHandle_t handle, \ infiniopHandle_t handle, \
Descriptor **desc_ptr, \ Descriptor **desc_ptr, \
...@@ -41,6 +46,7 @@ ...@@ -41,6 +46,7 @@
std::vector<infiniopTensorDescriptor_t> input_descs); \ std::vector<infiniopTensorDescriptor_t> input_descs); \
\ \
infiniStatus_t calculate( \ infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *output, \ void *output, \
std::vector<const void *> inputs, \ std::vector<const void *> inputs, \
void *stream) const; \ void *stream) const; \
...@@ -62,57 +68,70 @@ namespace op::elementwise { ...@@ -62,57 +68,70 @@ namespace op::elementwise {
*/ */
struct ElementwiseInfo { struct ElementwiseInfo {
private: private:
ElementwiseInfo() = default; std::vector<int8_t> _meta;
size_t _output_size;
size_t _input_size;
size_t _ndim;
bool _output_contiguous;
ElementwiseInfo(std::vector<int8_t> meta,
size_t output_size,
size_t input_size,
size_t ndim,
bool output_contiguous)
: _meta(std::move(meta)), _output_size(output_size),
_input_size(input_size), _ndim(ndim),
_output_contiguous(output_contiguous) {}
public: public:
size_t output_size; inline size_t getMetaMemSize() const {
size_t ndim; return _meta.size();
bool output_contiguous; }
bool *input_contiguous; inline const int8_t *getMetaStart() const {
bool *input_broadcasted; return _meta.data();
size_t *output_shape; }
size_t **input_shapes; inline size_t getOutputSize() const {
ptrdiff_t *output_strides; return _output_size;
ptrdiff_t **input_strides; }
size_t input_size; inline size_t getInputSize() const {
return _input_size;
~ElementwiseInfo() { }
delete[] input_contiguous; inline size_t getNdim() const {
delete[] input_broadcasted; return _ndim;
delete[] output_shape; }
delete[] output_strides; inline bool isOutputContiguous() const {
return _output_contiguous;
for (size_t i = 0; i < input_size; ++i) { }
delete[] input_shapes[i]; inline const size_t *getOutputShape() const {
delete[] input_strides[i]; return reinterpret_cast<const size_t *>(_meta.data());
}
inline const ptrdiff_t *getOutputStrides() const {
return reinterpret_cast<const ptrdiff_t *>(getOutputShape() + _ndim);
}
inline const size_t *getAllInputShapes() const {
return reinterpret_cast<const size_t *>(getOutputStrides() + _ndim);
}
inline const size_t *getInputShape(const size_t &index) const {
if (index < _input_size) {
return reinterpret_cast<const size_t *>(getAllInputShapes() + index * _ndim);
}
return nullptr;
}
inline const ptrdiff_t *getAllInputStrides() const {
return reinterpret_cast<const ptrdiff_t *>(getAllInputShapes() + _input_size * _ndim);
}
inline const ptrdiff_t *getInputStrides(const size_t &index) const {
if (index < _input_size) {
return reinterpret_cast<const ptrdiff_t *>(getAllInputStrides() + index * _ndim);
} }
delete[] input_shapes; return nullptr;
delete[] input_strides; }
} inline const bool *getInputContiguous() const {
return reinterpret_cast<const bool *>(getAllInputStrides() + _input_size * _ndim);
ElementwiseInfo(ElementwiseInfo &&other) noexcept }
: output_size(other.output_size), inline const bool *getInputBroadcasted() const {
ndim(other.ndim), return reinterpret_cast<const bool *>(getInputContiguous() + _input_size);
output_contiguous(other.output_contiguous), }
input_contiguous(other.input_contiguous),
input_broadcasted(other.input_broadcasted),
output_shape(other.output_shape),
input_shapes(other.input_shapes),
output_strides(other.output_strides),
input_strides(other.input_strides),
input_size(other.input_size) {
other.input_contiguous = nullptr;
other.input_broadcasted = nullptr;
other.output_shape = nullptr;
other.input_shapes = nullptr;
other.output_strides = nullptr;
other.input_strides = nullptr;
other.input_size = 0;
}
ElementwiseInfo(const ElementwiseInfo &other) = delete;
ElementwiseInfo &operator=(const ElementwiseInfo &other) = delete;
ElementwiseInfo &operator=(ElementwiseInfo &&other) = delete;
using ResultType = utils::Result<ElementwiseInfo>; using ResultType = utils::Result<ElementwiseInfo>;
...@@ -136,40 +155,48 @@ public: ...@@ -136,40 +155,48 @@ public:
return INFINI_STATUS_BAD_TENSOR_STRIDES; return INFINI_STATUS_BAD_TENSOR_STRIDES;
} }
ElementwiseInfo info; auto input_size = input_descs.size();
info.input_size = input_descs.size(); auto ndim = output_desc->ndim();
info.ndim = output_desc->ndim(); auto output_size = output_desc->numel();
info.output_size = output_desc->numel(); auto output_contiguous = output_desc->isContiguous();
info.output_contiguous = output_desc->isContiguous();
// Allocate memory for meta
// Allocate memory for arrays auto shape_unit = output_desc->dim(0);
info.input_contiguous = new bool[info.input_size]; auto stride_unit = output_desc->stride(0);
info.input_broadcasted = new bool[info.input_size]; size_t meta_mem_size = ndim * (sizeof(shape_unit) + sizeof(stride_unit))
info.output_shape = new size_t[info.ndim]; + input_size * ndim * sizeof(shape_unit)
info.output_strides = new ptrdiff_t[info.ndim]; + input_size * ndim * sizeof(stride_unit)
info.input_shapes = new size_t *[info.input_size]; + 2 * input_size * sizeof(bool);
info.input_strides = new ptrdiff_t *[info.input_size]; std::vector<int8_t> meta(meta_mem_size);
int8_t *meta_ptr = meta.data();
// Fill arrays
const auto output_shape = output_desc->shape(); const auto output_shape = output_desc->shape();
const auto output_strides = output_desc->strides(); const auto output_strides = output_desc->strides();
std::memcpy(info.output_shape, output_shape.data(), info.ndim * sizeof(*info.output_shape));
std::memcpy(info.output_strides, output_strides.data(), info.ndim * sizeof(*info.output_strides));
for (size_t i = 0; i < info.input_size; ++i) { // Pointers to the sections within _meta
auto &desc = input_descs[i]; size_t *output_shape_p = reinterpret_cast<size_t *>(meta_ptr);
info.input_contiguous[i] = desc->isContiguous(); ptrdiff_t *output_strides_p = reinterpret_cast<ptrdiff_t *>(output_shape_p + ndim);
info.input_broadcasted[i] = !info.input_contiguous[i] && (desc->ndim() != info.ndim || desc->hasBroadcastDim()); size_t *input_shapes = reinterpret_cast<size_t *>(output_strides_p + ndim);
ptrdiff_t *input_strides = reinterpret_cast<ptrdiff_t *>(input_shapes + input_size * ndim);
bool *input_contiguous = reinterpret_cast<bool *>(input_strides + input_size * ndim);
bool *input_broadcasted = input_contiguous + input_size;
info.input_shapes[i] = new size_t[desc->ndim()]; // Copy output shape and strides
const auto &in_shape = desc->shape(); std::memcpy(output_shape_p, output_shape.data(), ndim * sizeof(*output_shape_p));
std::memcpy(info.input_shapes[i], in_shape.data(), desc->ndim() * sizeof(*info.input_shapes[i])); std::memcpy(output_strides_p, output_strides.data(), ndim * sizeof(*output_strides_p));
info.input_strides[i] = new ptrdiff_t[desc->ndim()]; // Copy input shapes, strides, contiguous, and broadcasted flags
const auto &in_strides = desc->strides(); for (size_t i = 0; i < input_size; ++i) {
std::memcpy(info.input_strides[i], in_strides.data(), desc->ndim() * sizeof(*info.input_strides[i])); auto &desc = input_descs[i];
const auto in_shape = desc->shape();
const auto in_strides = desc->strides();
std::memcpy(input_shapes + i * ndim, in_shape.data(), ndim * sizeof(*input_shapes));
std::memcpy(input_strides + i * ndim, in_strides.data(), ndim * sizeof(*input_strides));
input_contiguous[i] = desc->isContiguous();
input_broadcasted[i] = !input_contiguous[i] && (desc->ndim() != ndim || desc->hasBroadcastDim());
} }
ElementwiseInfo info(std::move(meta), output_size, input_size, ndim, output_contiguous);
return ResultType(std::move(info)); return ResultType(std::move(info));
} }
}; };
......
...@@ -30,6 +30,8 @@ infiniStatus_t Descriptor::create( ...@@ -30,6 +30,8 @@ infiniStatus_t Descriptor::create(
} }
infiniStatus_t Descriptor::calculate( infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output, void *output,
std::vector<const void *> inputs, std::vector<const void *> inputs,
void *stream) const { void *stream) const {
......
...@@ -10,7 +10,7 @@ typedef struct SwiGLUOp { ...@@ -10,7 +10,7 @@ typedef struct SwiGLUOp {
private: private:
template <typename T> template <typename T>
T sigmoid(const T &x) const { T sigmoid(const T &x) const {
return 1 / (1 + std::exp(-x)); return T(1) / (T(1) + std::exp(-x));
} }
public: public:
......
...@@ -21,9 +21,7 @@ infiniStatus_t Descriptor::create( ...@@ -21,9 +21,7 @@ infiniStatus_t Descriptor::create(
const auto &gate_shape = gate_desc->shape(); const auto &gate_shape = gate_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64); CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
if (!SAME_VEC(out_shape, up_shape, gate_shape)) { CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// create CUDA elementwise descriptor // create CUDA elementwise descriptor
CREATE_ELEMENTWISE_CUDA_DESCRIPTOR CREATE_ELEMENTWISE_CUDA_DESCRIPTOR
...@@ -32,17 +30,23 @@ infiniStatus_t Descriptor::create( ...@@ -32,17 +30,23 @@ infiniStatus_t Descriptor::create(
} }
infiniStatus_t Descriptor::calculate( infiniStatus_t Descriptor::calculate(
void *workspace,
size_t workspace_size,
void *output, void *output,
std::vector<const void *> inputs, std::vector<const void *> inputs,
void *stream) const { void *stream) const {
if (workspace_size < _workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
switch (_dtype) { switch (_dtype) {
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
return _device_info->calculate<256, SwiGLUOp, half>(_info, output, inputs, stream); return _device_info->calculate<256, SwiGLUOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32: case INFINI_DTYPE_F32:
return _device_info->calculate<256, SwiGLUOp, float>(_info, output, inputs, stream); return _device_info->calculate<256, SwiGLUOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F64: case INFINI_DTYPE_F64:
return _device_info->calculate<256, SwiGLUOp, double>(_info, output, inputs, stream); return _device_info->calculate<256, SwiGLUOp, double>(_info, workspace, output, inputs, stream);
default: default:
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
......
...@@ -66,8 +66,49 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor( ...@@ -66,8 +66,49 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor(
#undef CREATE #undef CREATE
} }
__C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<op::swiglu::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
GET(INFINI_DEVICE_CPU, cpu)
#endif
#ifdef ENABLE_CUDA_API
GET(INFINI_DEVICE_NVIDIA, cuda)
#endif
#ifdef ENABLE_CAMBRICON_MLU
case DevCambriconMlu: {
return bangGetSwiGLUWorkspaceSize((SwiGLUBangDescriptor_t)desc, size);
}
#endif
#ifdef ENABLE_ASCEND_API
GET(INFINI_DEVICE_ASCEND, ascend)
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
return macaGetSwiGLUWorkspaceSize((SwiGLUMacaDescriptor_t)desc, size);
}
#endif
#ifdef ENABLE_MTHREADS_GPU
case DevMthreadsGpu: {
return musaGetSwiGLUWorkspaceSize((SwiGLUMusaDescriptor_t)desc, size);
}
#endif
}
#undef GET
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopSwiGLU( __C infiniStatus_t infiniopSwiGLU(
infiniopSwiGLUDescriptor_t desc, infiniopSwiGLUDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *c, void *c,
const void *a, const void *a,
const void *b, const void *b,
...@@ -76,7 +117,7 @@ __C infiniStatus_t infiniopSwiGLU( ...@@ -76,7 +117,7 @@ __C infiniStatus_t infiniopSwiGLU(
#define CALCULATE(CASE, NAMESPACE) \ #define CALCULATE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
return reinterpret_cast<const op::swiglu::NAMESPACE::Descriptor *>(desc) \ return reinterpret_cast<const op::swiglu::NAMESPACE::Descriptor *>(desc) \
->calculate(c, {a, b}, stream) ->calculate(workspace, workspace_size, c, {a, b}, stream)
switch (desc->device_type) { switch (desc->device_type) {
......
import torch import torch
import ctypes import ctypes
from ctypes import POINTER, Structure, c_int32, c_void_p from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64
from libinfiniop import ( from libinfiniop import (
infiniopHandle_t, infiniopHandle_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
...@@ -14,6 +14,7 @@ from libinfiniop import ( ...@@ -14,6 +14,7 @@ from libinfiniop import (
debug, debug,
get_tolerance, get_tolerance,
profile_operation, profile_operation,
create_workspace
) )
from enum import Enum, auto from enum import Enum, auto
...@@ -160,10 +161,19 @@ def test( ...@@ -160,10 +161,19 @@ def test(
for tensor in [a_tensor, b_tensor, c_tensor]: for tensor in [a_tensor, b_tensor, c_tensor]:
tensor.destroyDesc(lib) tensor.destroyDesc(lib)
workspace_size = c_uint64(0)
check_error(
lib.infiniopGetSwiGLUWorkspaceSize(descriptor, ctypes.byref(workspace_size))
)
workspace = create_workspace(workspace_size.value, c.device)
def lib_swiglu(): def lib_swiglu():
check_error( check_error(
lib.infiniopSwiGLU( lib.infiniopSwiGLU(
descriptor, c_tensor.data, a_tensor.data, b_tensor.data, None descriptor,
workspace.data_ptr() if workspace is not None else None,
workspace_size.value,
c_tensor.data, a_tensor.data, b_tensor.data, None
) )
) )
...@@ -196,10 +206,18 @@ if __name__ == "__main__": ...@@ -196,10 +206,18 @@ if __name__ == "__main__":
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
] ]
lib.infiniopGetSwiGLUWorkspaceSize.restype = c_int32
lib.infiniopGetSwiGLUWorkspaceSize.argtypes = [
infiniopSwiGLUDescriptor_t,
POINTER(c_uint64),
]
lib.infiniopSwiGLU.restype = c_int32 lib.infiniopSwiGLU.restype = c_int32
lib.infiniopSwiGLU.argtypes = [ lib.infiniopSwiGLU.argtypes = [
infiniopSwiGLUDescriptor_t, infiniopSwiGLUDescriptor_t,
c_void_p, c_void_p,
c_uint64,
c_void_p,
c_void_p, c_void_p,
c_void_p, c_void_p,
c_void_p, c_void_p,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment