Commit 1d182fba authored by Zimin Li


issue/127: Optimize elementwise CUDA code by removing redundancy, change/correct kernel logic when all inputs have the same dtype
parent 9cc0c416
@@ -9,16 +9,74 @@

namespace op::elementwise::cuda {

/**
 * @brief Casts an untyped device pointer to a typed pointer of type T.
 *
 * @tparam T Desired pointer type.
 * @param ptr Untyped pointer.
 * @return Pointer of type const T*.
 */
template <typename T>
__device__ __forceinline__ const T *typedInputPtr(const void *ptr) {
    return reinterpret_cast<const T *>(ptr);
}

/**
 * @brief Computes the output index in memory, accounting for strides if non-contiguous.
 *
 * @param idx Linear index.
 * @param is_contiguous Whether the output tensor is contiguous.
 * @param ndim Number of dimensions.
 * @param shape Shape of the output tensor.
 * @param strides Strides of the output tensor.
 * @return Memory offset index.
 */
__device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim,
                                                 const size_t *shape, const ptrdiff_t *strides) {
    return is_contiguous ? idx : device::cuda::indexToOffset(idx, ndim, shape, strides);
}

/**
 * @brief Computes input element offsets for broadcasting and strided access.
 *
 * Used to map a linear output index to the corresponding index in an input tensor,
 * considering contiguity and broadcasting.
 */
struct InputIndexer {
    size_t idx;
    size_t ndim;
    const bool *input_contiguous;
    const bool *input_broadcasted;
    const size_t *input_shapes;
    const ptrdiff_t *input_strides;
    const ptrdiff_t *output_strides;

    /**
     * @brief Computes the memory offset of a given input tensor at the current index.
     *
     * @param input_id ID of the input tensor.
     * @return Offset into the input tensor.
     */
    __device__ __forceinline__ size_t operator()(size_t input_id) const {
        return input_contiguous[input_id]
                 ? idx
                 : (input_broadcasted[input_id]
                        ? device::cuda::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
                        : device::cuda::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
    }
};

/**
 * @brief Invokes a callable with compile-time index constants.
 *
 * Used to unpack an index sequence for variadic template processing of the inputs.
 *
 * @tparam F Callable type.
 * @tparam Is Compile-time index sequence.
 * @param f Callable to invoke with index constants.
 */
template <typename F, size_t... Is>
__device__ __forceinline__ void unpackInputsAndApply(F &&f, std::index_sequence<Is...>) {
    f(std::integral_constant<size_t, Is>{}...);
}
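// Example (added for illustration, not part of this commit): with N = 2,
//     unpackInputsAndApply(f, std::make_index_sequence<2>{});
// expands to
//     f(std::integral_constant<size_t, 0>{}, std::integral_constant<size_t, 1>{});
// so inside a generic lambda each pack element exposes its input id as a
// compile-time constant via `Is.value`, which can index both the input pointer
// array and an InputIndexer, as the kernels below do.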
/**
@@ -54,96 +112,25 @@ INFINIOP_CUDA_KERNEL elementwiseKernel(
    const ptrdiff_t *__restrict__ output_strides,
    const ptrdiff_t *__restrict__ input_strides,
    Tdata *output,
    const void *const *inputs,
    size_t offset,
    Args... args) {
    size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;

    if (idx < output_size) {
        const Tdata *const *typed_inputs = reinterpret_cast<const Tdata *const *>(inputs);
        size_t out_idx = getOutputIndex(idx, output_contiguous, ndim, output_shape, output_strides);
        InputIndexer indexer{idx, ndim, input_contiguous, input_broadcasted, input_shapes, input_strides, output_strides};

        unpackInputsAndApply(
            [&](auto... Is) {
                output[out_idx] = Op{}(typed_inputs[Is.value][indexer(Is.value)]..., std::forward<Args>(args)...);
            },
            std::make_index_sequence<N>{});
    }
}
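// Illustrative sketch only (hypothetical functor, not part of this commit): the Op
// used by the same-dtype kernel above is expected to be callable with one value of
// type Tdata per input tensor (plus any forwarded extra arguments).
struct ExampleAddOp {
    template <typename T>
    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
        // Binary add over two same-typed inputs; the real operators live elsewhere in the repo.
        return a + b;
    }
};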
/**
 * @brief CUDA kernel for performing an elementwise operation on tensors with support
 *        for broadcasting and mixed data types.
@@ -180,26 +167,18 @@ INFINIOP_CUDA_KERNEL elementwiseKernel(
    size_t offset) {
    size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;

    if (idx < output_size) {
        size_t out_idx = getOutputIndex(idx, output_contiguous, ndim, output_shape, output_strides);
        InputIndexer indexer{idx, ndim, input_contiguous, input_broadcasted, input_shapes, input_strides, output_strides};

        unpackInputsAndApply(
            [&](auto... Is) {
                output[out_idx] = Op{}.template operator()<Tout, Tin...>(
                    (typedInputPtr<Tin>(inputs[Is.value])[indexer(Is.value)])...);
            },
            std::index_sequence_for<Tin...>{});
    }
}
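// Illustrative sketch only (hypothetical functor, not part of this commit): the Op used
// by the mixed-dtype kernel above receives the output and input types explicitly, so it
// can convert between dtypes itself. Assumes arithmetic types convertible via static_cast.
struct ExampleMixedAddOp {
    template <typename Tout, typename Ta, typename Tb>
    __device__ __forceinline__ Tout operator()(const Ta &a, const Tb &b) const {
        // Convert each input to the output type before combining.
        return static_cast<Tout>(a) + static_cast<Tout>(b);
    }
};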
struct DeviceImpl::Opaque {
@@ -231,45 +210,12 @@ struct DeviceImpl::Opaque {
        const std::vector<const void *> &inputs,
        cudaStream_t stream,
        Args &&...args) {
        return launchElementwiseKernel<BLOCK_SIZE, N>(
            info, workspace,
            reinterpret_cast<Tdata *>(output), inputs,
            elementwiseKernel<N, Op, Tdata, Args...>,
            stream,
            std::forward<Args>(args)...);
    }
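    // For example (illustration only, using the hypothetical ExampleAddOp above):
    // instantiating this wrapper with BLOCK_SIZE = 256, N = 2, Op = ExampleAddOp and
    // Tdata = float makes it forward to launchElementwiseKernel<256, 2>(...) with the
    // kernel elementwiseKernel<2, ExampleAddOp, float>, which adds the two float
    // inputs elementwise into the output.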
/**
@@ -297,44 +243,12 @@ struct DeviceImpl::Opaque {
        const std::vector<const void *> &inputs,
        cudaStream_t stream,
        Args &&...args) {
        return launchElementwiseKernel<BLOCK_SIZE, N>(
            info, workspace,
            reinterpret_cast<Tout *>(output), inputs,
            elementwiseKernel<Op, Tout, Tin...>,
            stream,
            std::forward<Args>(args)...);
    }
private:
@@ -390,6 +304,70 @@ private:
        return INFINI_STATUS_SUCCESS;
    }

    /**
     * @brief Launches the elementwise kernel for the specified operation.
     *
     * @tparam BLOCK_SIZE Number of threads per block.
     * @tparam N Number of input tensors.
     * @tparam KernelFunc Type of the kernel function pointer.
     * @tparam Tout Output data type.
     * @tparam Args Additional arguments to be forwarded to the kernel.
     *
     * @param info Metadata about the elementwise operation (shapes, strides, etc.).
     * @param workspace CUDA memory used for storing metadata.
     * @param output Pointer to output buffer on device.
     * @param inputs Vector of device pointers to input tensors.
     * @param kernel_func Kernel function to launch.
     * @param stream CUDA stream for asynchronous execution.
     * @param args Additional arguments passed to the kernel.
     * @return infiniStatus_t Status code indicating success or failure.
     */
    template <size_t BLOCK_SIZE, size_t N, typename KernelFunc, typename Tout, typename... Args>
    infiniStatus_t launchElementwiseKernel(
        const op::elementwise::ElementwiseInfo &info,
        void *workspace,
        Tout *output,
        const std::vector<const void *> &inputs,
        KernelFunc kernel_func,
        cudaStream_t stream,
        Args &&...args) {

        auto output_size = info.getOutputSize();
        if (output_size == 0) {
            return INFINI_STATUS_SUCCESS;
        }

        // Device pointers
        const void **d_inputs_arr = nullptr;
        const bool *d_input_contiguous = nullptr;
        const bool *d_input_broadcasted = nullptr;
        const size_t *d_output_shape = nullptr;
        const ptrdiff_t *d_output_strides = nullptr;
        const size_t *d_input_shapes = nullptr;
        const ptrdiff_t *d_input_strides = nullptr;

        CHECK_STATUS(infoToDevice<N>(info, workspace, inputs.data(), d_inputs_arr,
                                     d_input_contiguous, d_input_broadcasted,
                                     d_output_shape, d_output_strides,
                                     d_input_shapes, d_input_strides, stream));

        dim3 blockDims(std::min(BLOCK_SIZE, static_cast<size_t>(internal->maxThreadsPerBlock())));
        dim3 gridDims(std::min(CEIL_DIV(output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX())));
        size_t step = gridDims.x * blockDims.x;

        for (size_t i = 0; i < output_size; i += step) {
            kernel_func<<<gridDims, blockDims, 0, stream>>>(
                output_size, info.getNdim(), info.isOutputContiguous(),
                d_input_contiguous, d_input_broadcasted,
                d_output_shape, d_input_shapes,
                d_output_strides, d_input_strides,
                output, reinterpret_cast<const void **>(d_inputs_arr),
                i, std::forward<Args>(args)...);
        }

        return INFINI_STATUS_SUCCESS;
    }
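    // Worked example (illustration only, assuming maxThreadsPerBlock() >= 256 and a
    // gridSizeX() cap of 1024): with BLOCK_SIZE = 256 and output_size = 1'000'000,
    // blockDims.x = 256, gridDims.x = min(CEIL_DIV(1'000'000, 256), 1024) = 1024, so
    // step = 262'144 and the loop issues four launches with offsets 0, 262'144,
    // 524'288 and 786'432; threads whose idx >= output_size simply do nothing.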
};

template <typename... Args>