Commit a283a8fa authored by Zimin Li's avatar Zimin Li
Browse files

issue/127: refactor ElementwiseInfo to use utils::Result, change elementwise...

issue/127: refactor ElementwiseInfo to use utils::Result, change elementwise calculate and calculateImpl to return infiniStatus_t, add CHECK_CUDA to cuda function calls
parent 7a8f2bca
...@@ -9,20 +9,27 @@ ...@@ -9,20 +9,27 @@
* @brief Define the process for initializing a Descriptor of an elementwise operation * @brief Define the process for initializing a Descriptor of an elementwise operation
* for its CPU implementation * for its CPU implementation
*/ */
#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR \ #define CREATE_ELEMENTWISE_CPU_DESCRIPTOR \
\ \
op::elementwise::ElementwiseInfo elementwise_info; \ auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc); \
CHECK_STATUS(op::elementwise::createElementwiseInfo(elementwise_info, out_desc, input_desc)); \ CHECK_RESULT(info_result); \
\ \
*desc_ptr = new Descriptor( \ *desc_ptr = new Descriptor( \
dtype, \ dtype, \
std::move(elementwise_info), \ std::move(info_result.take()), \
nullptr, \ nullptr, \
handle->device, \ handle->device, \
handle->device_id); handle->device_id);
namespace op::elementwise::cpu { namespace op::elementwise::cpu {
/**
* @brief CPU-specific device implementation for resource management and
* calculation implementations.
*
* This class encapsulates device-specific behavior and execution logic.
* Use the static create() method to instantiate a DeviceImpl.
*/
class DeviceImpl final { class DeviceImpl final {
struct Opaque; struct Opaque;
std::shared_ptr<struct Opaque> _opaque; std::shared_ptr<struct Opaque> _opaque;
...@@ -37,20 +44,48 @@ public: ...@@ -37,20 +44,48 @@ public:
DeviceImpl **device_info, DeviceImpl **device_info,
Args &&...args); Args &&...args);
/* Invoke elementwise operation when all inputs have the same type */ /**
* @brief Dispatches an elementwise operation with uniform input types.
*
* @tparam Op The elementwise operation to perform.
* @tparam Tdata The common data type of all inputs and output.
* @tparam Args Additional backend-specific arguments.
* @param info Precomputed tensor metadata (shapes, strides, etc.).
* @param output Pointer to the output tensor buffer.
* @param inputs Vector of input tensor data pointers.
* @param stream Device execution stream.
* @param args Additional backend-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
template <typename Op, typename Tdata, typename... Args> template <typename Op, typename Tdata, typename... Args>
void calculate( infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info, const op::elementwise::ElementwiseInfo &info,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
void *stream, void *stream,
Args &&...args); Args &&...args);
/* Invoke elementwise operation for different input types */ /**
* @brief Dispatches an elementwise operation with heterogeneous input types.
*
* Supports operations where each input may have a different type, as defined by Op.
* The number of input types must match the operation's expected input count.
*
* @tparam Op The elementwise operation to perform.
* @tparam Tout Output data type.
* @tparam Tin Variadic input data types.
* @tparam Args Additional backend-specific arguments.
* @param info Precomputed tensor metadata (shapes, strides, etc.).
* @param output Pointer to the output tensor buffer.
* @param inputs Vector of input tensor data pointers.
* @param stream Device execution stream.
* @param args Additional backend-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
template <typename Op, typename Tout, typename... Tin, template <typename Op, typename Tout, typename... Tin,
typename... Args, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0> std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate( infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info, const op::elementwise::ElementwiseInfo &info,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
...@@ -58,6 +93,7 @@ public: ...@@ -58,6 +93,7 @@ public:
Args &&...args); Args &&...args);
}; };
// Define the Opaque struct for CPU, which is empty
struct DeviceImpl::Opaque {}; struct DeviceImpl::Opaque {};
template <typename... Args> template <typename... Args>
...@@ -90,14 +126,15 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output, ...@@ -90,14 +126,15 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output,
// Invoke elementwise operation for different input types // Invoke elementwise operation for different input types
template <typename Op, typename Tout, typename... Tin, typename... Args, std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0> template <typename Op, typename Tout, typename... Tin, typename... Args, std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
void *stream, void *stream,
Args &&...args) { Args &&...args) {
static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch"); static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch");
calculate_impl<Op, Tout, Tin...>(info, output, inputs, std::make_index_sequence<sizeof...(Tin)>{}, std::forward<Args>(args)...); calculate_impl<Op, Tout, Tin...>(info, output, inputs, std::make_index_sequence<sizeof...(Tin)>{}, std::forward<Args>(args)...);
return INFINI_STATUS_SUCCESS;
} }
// Perform elementwise operation when all inputs have the same type // Perform elementwise operation when all inputs have the same type
...@@ -133,9 +170,10 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, ...@@ -133,9 +170,10 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
// Invoke elementwise operation when all inputs have the same type // Invoke elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, typename... Args> template <typename Op, typename Tdata, typename... Args>
void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, void *stream, Args &&...args) { infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, void *stream, Args &&...args) {
constexpr size_t N = Op::num_inputs; constexpr size_t N = Op::num_inputs;
calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...); calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...);
return INFINI_STATUS_SUCCESS;
} }
} // namespace op::elementwise::cpu } // namespace op::elementwise::cpu
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "../../../utils.h" #include "../../../utils.h"
#include "../../devices/cuda/cuda_common.cuh" #include "../../devices/cuda/cuda_common.cuh"
#include "elementwise_cuda_api.cuh" #include "elementwise_cuda_api.cuh"
namespace op::elementwise::cuda { namespace op::elementwise::cuda {
/** /**
...@@ -223,16 +224,17 @@ struct DeviceImpl::Opaque { ...@@ -223,16 +224,17 @@ struct DeviceImpl::Opaque {
* @param inputs Vector of pointers to input tensors. * @param inputs Vector of pointers to input tensors.
* @param stream CUDA stream used for asynchronous execution. * @param stream CUDA stream used for asynchronous execution.
* @param args Additional arguments for the operation. * @param args Additional arguments for the operation.
* @return infiniStatus_t Status indicating success or failure.
*/ */
template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args, size_t... Is> template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args, size_t... Is>
void calculateImpl(const op::elementwise::ElementwiseInfo &info, infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
std::index_sequence<Is...>, std::index_sequence<Is...>,
cudaStream_t stream, cudaStream_t stream,
Args &&...args) { Args &&...args) {
if (info.output_size == 0) { if (info.output_size == 0) {
return; return INFINI_STATUS_SUCCESS;
} }
// casting the output and the inputs to Tdata pointers // casting the output and the inputs to Tdata pointers
...@@ -242,8 +244,8 @@ struct DeviceImpl::Opaque { ...@@ -242,8 +244,8 @@ struct DeviceImpl::Opaque {
for (size_t i = 0; i < N; ++i) { for (size_t i = 0; i < N; ++i) {
inputs_arr[i] = reinterpret_cast<const Tdata *>(inputs[i]); inputs_arr[i] = reinterpret_cast<const Tdata *>(inputs[i]);
} }
cudaMallocAsync(&d_inputs_arr, N * sizeof(*d_inputs_arr), stream); CHECK_CUDA(cudaMallocAsync(&d_inputs_arr, N * sizeof(*d_inputs_arr), stream));
cudaMemcpyAsync(d_inputs_arr, inputs_arr, N * sizeof(*d_inputs_arr), cudaMemcpyHostToDevice, stream); CHECK_CUDA(cudaMemcpyAsync(d_inputs_arr, inputs_arr, N * sizeof(*d_inputs_arr), cudaMemcpyHostToDevice, stream));
// create and send the info to device // create and send the info to device
const bool *d_bools = nullptr; const bool *d_bools = nullptr;
...@@ -257,10 +259,10 @@ struct DeviceImpl::Opaque { ...@@ -257,10 +259,10 @@ struct DeviceImpl::Opaque {
std::vector<const size_t *> tmp_device_ptrs(info.input_size); std::vector<const size_t *> tmp_device_ptrs(info.input_size);
std::vector<const ptrdiff_t *> tmp_device_ptrs_strides(info.input_size); std::vector<const ptrdiff_t *> tmp_device_ptrs_strides(info.input_size);
infoToDevice<N>(info, d_bools, d_input_contiguous, CHECK_STATUS(infoToDevice<N>(info, d_bools, d_input_contiguous,
d_input_broadcasted, d_output_shape_strides, d_output_shape, d_input_broadcasted, d_output_shape_strides, d_output_shape,
d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides, d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides,
d_input_strides, stream); d_input_strides, stream));
dim3 blockDims(std::min(BLOCK_SIZE, static_cast<size_t>(internal->maxThreadsPerBlock()))); dim3 blockDims(std::min(BLOCK_SIZE, static_cast<size_t>(internal->maxThreadsPerBlock())));
dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX()))); dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX())));
...@@ -280,7 +282,9 @@ struct DeviceImpl::Opaque { ...@@ -280,7 +282,9 @@ struct DeviceImpl::Opaque {
info.input_size, out, d_inputs_arr, i, std::forward<Args>(args)...); info.input_size, out, d_inputs_arr, i, std::forward<Args>(args)...);
} }
freeAllDevice((const void **)d_inputs_arr, d_bools, d_output_shape_strides, info.input_size, d_input_shapes, d_input_strides, stream); CHECK_STATUS(freeAllDevice((const void **)d_inputs_arr, d_bools, d_output_shape_strides,
info.input_size, d_input_shapes, d_input_strides, stream));
return INFINI_STATUS_SUCCESS;
} }
/** /**
...@@ -297,17 +301,18 @@ struct DeviceImpl::Opaque { ...@@ -297,17 +301,18 @@ struct DeviceImpl::Opaque {
* @param inputs Vector of pointers to input tensors. * @param inputs Vector of pointers to input tensors.
* @param stream CUDA stream used for asynchronous execution. * @param stream CUDA stream used for asynchronous execution.
* @param args Additional arguments for the operation. * @param args Additional arguments for the operation.
* @return infiniStatus_t Status indicating success or failure.
*/ */
template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tout, typename... Tin, typename... Args, size_t... Is, template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tout, typename... Tin, typename... Args, size_t... Is,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0> std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculateImpl(const op::elementwise::ElementwiseInfo &info, infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
std::index_sequence<Is...>, std::index_sequence<Is...>,
cudaStream_t stream, cudaStream_t stream,
Args &&...args) { Args &&...args) {
if (info.output_size == 0) { if (info.output_size == 0) {
return; return INFINI_STATUS_SUCCESS;
} }
Tout *out = reinterpret_cast<Tout *>(output); Tout *out = reinterpret_cast<Tout *>(output);
...@@ -318,8 +323,8 @@ struct DeviceImpl::Opaque { ...@@ -318,8 +323,8 @@ struct DeviceImpl::Opaque {
// Create array of input pointers on host (void*) to copy to device // Create array of input pointers on host (void*) to copy to device
const void *host_input_ptrs[] = {reinterpret_cast<const void *>(std::get<Is>(inputs_arr))...}; const void *host_input_ptrs[] = {reinterpret_cast<const void *>(std::get<Is>(inputs_arr))...};
cudaMallocAsync(&d_inputs_arr, N * sizeof(void *), stream); CHECK_CUDA(cudaMallocAsync(&d_inputs_arr, N * sizeof(void *), stream));
cudaMemcpyAsync(d_inputs_arr, host_input_ptrs, N * sizeof(void *), cudaMemcpyHostToDevice, stream); CHECK_CUDA(cudaMemcpyAsync(d_inputs_arr, host_input_ptrs, N * sizeof(void *), cudaMemcpyHostToDevice, stream));
// Device pointers // Device pointers
const bool *d_bools = nullptr; const bool *d_bools = nullptr;
...@@ -333,10 +338,10 @@ struct DeviceImpl::Opaque { ...@@ -333,10 +338,10 @@ struct DeviceImpl::Opaque {
std::vector<const size_t *> tmp_device_ptrs(info.input_size); std::vector<const size_t *> tmp_device_ptrs(info.input_size);
std::vector<const ptrdiff_t *> tmp_device_ptrs_strides(info.input_size); std::vector<const ptrdiff_t *> tmp_device_ptrs_strides(info.input_size);
infoToDevice<N>(info, d_bools, d_input_contiguous, CHECK_STATUS(infoToDevice<N>(info, d_bools, d_input_contiguous,
d_input_broadcasted, d_output_shape_strides, d_output_shape, d_input_broadcasted, d_output_shape_strides, d_output_shape,
d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides, d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides,
d_input_strides, stream); d_input_strides, stream));
dim3 blockDims(std::min(BLOCK_SIZE, static_cast<size_t>(internal->maxThreadsPerBlock()))); dim3 blockDims(std::min(BLOCK_SIZE, static_cast<size_t>(internal->maxThreadsPerBlock())));
dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX()))); dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX())));
...@@ -356,7 +361,8 @@ struct DeviceImpl::Opaque { ...@@ -356,7 +361,8 @@ struct DeviceImpl::Opaque {
info.input_size, out, reinterpret_cast<const void **>(d_inputs_arr), i); info.input_size, out, reinterpret_cast<const void **>(d_inputs_arr), i);
} }
freeAllDevice(d_inputs_arr, d_bools, d_output_shape_strides, info.input_size, d_input_shapes, d_input_strides, stream); CHECK_STATUS(freeAllDevice(d_inputs_arr, d_bools, d_output_shape_strides, info.input_size, d_input_shapes, d_input_strides, stream));
return INFINI_STATUS_SUCCESS;
} }
private: private:
...@@ -393,31 +399,31 @@ private: ...@@ -393,31 +399,31 @@ private:
const ptrdiff_t **&d_input_strides, const ptrdiff_t **&d_input_strides,
cudaStream_t stream) const { cudaStream_t stream) const {
cudaMallocAsync(&d_bools, 2 * info.input_size * sizeof(*d_bools), stream); CHECK_CUDA(cudaMallocAsync(&d_bools, 2 * info.input_size * sizeof(*d_bools), stream));
cudaMemcpyAsync((void *)d_bools, info.input_contiguous, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream); CHECK_CUDA(cudaMemcpyAsync((void *)d_bools, info.input_contiguous, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream));
cudaMemcpyAsync((void *)(d_bools + info.input_size), info.input_broadcasted, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream); CHECK_CUDA(cudaMemcpyAsync((void *)(d_bools + info.input_size), info.input_broadcasted, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream));
cudaMallocAsync(&d_output_shape_strides, info.ndim * (sizeof(*d_output_shape) + sizeof(*d_output_strides)), stream); CHECK_CUDA(cudaMallocAsync(&d_output_shape_strides, info.ndim * (sizeof(*d_output_shape) + sizeof(*d_output_strides)), stream));
cudaMemcpyAsync((void *)d_output_shape_strides, info.output_shape, info.ndim * sizeof(*d_output_shape), cudaMemcpyHostToDevice, stream); CHECK_CUDA(cudaMemcpyAsync((void *)d_output_shape_strides, info.output_shape, info.ndim * sizeof(*d_output_shape), cudaMemcpyHostToDevice, stream));
cudaMemcpyAsync((void *)(d_output_shape_strides + info.ndim * sizeof(*d_output_shape)), info.output_strides, info.ndim * sizeof(*d_output_strides), cudaMemcpyHostToDevice, stream); CHECK_CUDA(cudaMemcpyAsync((void *)(d_output_shape_strides + info.ndim * sizeof(*d_output_shape)), info.output_strides, info.ndim * sizeof(*d_output_strides), cudaMemcpyHostToDevice, stream));
cudaMallocAsync(&d_input_shapes, info.input_size * sizeof(*d_input_shapes), stream); CHECK_CUDA(cudaMallocAsync(&d_input_shapes, info.input_size * sizeof(*d_input_shapes), stream));
for (size_t i = 0; i < info.input_size; ++i) { for (size_t i = 0; i < info.input_size; ++i) {
cudaMallocAsync(&tmp_device_ptrs[i], info.ndim * sizeof(*&tmp_device_ptrs[i]), stream); CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs[i], info.ndim * sizeof(*&tmp_device_ptrs[i]), stream));
cudaMemcpyAsync((void *)tmp_device_ptrs[i], info.input_shapes[i], CHECK_CUDA(cudaMemcpyAsync((void *)tmp_device_ptrs[i], info.input_shapes[i],
info.ndim * sizeof(size_t), cudaMemcpyHostToDevice, stream); info.ndim * sizeof(size_t), cudaMemcpyHostToDevice, stream));
} }
cudaMemcpyAsync((void *)d_input_shapes, tmp_device_ptrs.data(), CHECK_CUDA(cudaMemcpyAsync((void *)d_input_shapes, tmp_device_ptrs.data(),
info.input_size * sizeof(size_t *), cudaMemcpyHostToDevice, stream); info.input_size * sizeof(size_t *), cudaMemcpyHostToDevice, stream));
cudaMallocAsync(&d_input_strides, info.input_size * sizeof(*d_input_strides), stream); CHECK_CUDA(cudaMallocAsync(&d_input_strides, info.input_size * sizeof(*d_input_strides), stream));
for (size_t i = 0; i < info.input_size; ++i) { for (size_t i = 0; i < info.input_size; ++i) {
cudaMallocAsync(&tmp_device_ptrs_strides[i], info.ndim * sizeof(*&tmp_device_ptrs_strides[i]), stream); CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs_strides[i], info.ndim * sizeof(*&tmp_device_ptrs_strides[i]), stream));
cudaMemcpyAsync((void *)tmp_device_ptrs_strides[i], info.input_strides[i], CHECK_CUDA(cudaMemcpyAsync((void *)tmp_device_ptrs_strides[i], info.input_strides[i],
info.ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice, stream); info.ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice, stream));
} }
cudaMemcpyAsync((void *)d_input_strides, tmp_device_ptrs_strides.data(), CHECK_CUDA(cudaMemcpyAsync((void *)d_input_strides, tmp_device_ptrs_strides.data(),
info.input_size * sizeof(ptrdiff_t *), cudaMemcpyHostToDevice, stream); info.input_size * sizeof(ptrdiff_t *), cudaMemcpyHostToDevice, stream));
d_input_contiguous = d_bools; d_input_contiguous = d_bools;
d_input_broadcasted = d_bools + info.input_size; d_input_broadcasted = d_bools + info.input_size;
...@@ -447,11 +453,11 @@ private: ...@@ -447,11 +453,11 @@ private:
const ptrdiff_t **d_input_strides, const ptrdiff_t **d_input_strides,
cudaStream_t stream) const { cudaStream_t stream) const {
cudaFreeAsync((void *)d_inputs_arr, stream); CHECK_CUDA(cudaFreeAsync((void *)d_inputs_arr, stream));
cudaFreeAsync((void *)d_bools, stream); CHECK_CUDA(cudaFreeAsync((void *)d_bools, stream));
cudaFreeAsync((void *)d_output_shape_strides, stream); CHECK_CUDA(cudaFreeAsync((void *)d_output_shape_strides, stream));
cudaFreeAsync((void *)d_input_shapes, stream); CHECK_CUDA(cudaFreeAsync((void *)d_input_shapes, stream));
cudaFreeAsync((void *)d_input_strides, stream); CHECK_CUDA(cudaFreeAsync((void *)d_input_strides, stream));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
}; };
...@@ -479,17 +485,18 @@ infiniStatus_t DeviceImpl::create(DeviceImpl **device_info, ...@@ -479,17 +485,18 @@ infiniStatus_t DeviceImpl::create(DeviceImpl **device_info,
* @param inputs Vector of input pointers (device memory). * @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*). * @param stream CUDA stream (opaque void*).
* @param args (UNUSED) Additional operation-specific arguments. * @param args (UNUSED) Additional operation-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/ */
template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin, typename... Args, template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>> std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
void *stream, void *stream,
Args &&...args) { Args &&...args) {
constexpr size_t N = Op::num_inputs; constexpr size_t N = Op::num_inputs;
static_assert(sizeof...(Tin) == N, "Input type count mismatch"); static_assert(sizeof...(Tin) == N, "Input type count mismatch");
_opaque->calculateImpl<BLOCK_SIZE, N, Op, Tout, Tin...>( return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tout, Tin...>(
info, output, inputs, info, output, inputs,
std::make_index_sequence<N>{}, std::make_index_sequence<N>{},
reinterpret_cast<cudaStream_t>(stream), reinterpret_cast<cudaStream_t>(stream),
...@@ -510,20 +517,20 @@ void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, ...@@ -510,20 +517,20 @@ void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
* @param inputs Vector of input pointers (device memory). * @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*). * @param stream CUDA stream (opaque void*).
* @param args Additional operation-specific arguments. * @param args Additional operation-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/ */
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args> template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
void *stream, void *stream,
Args &&...args) { Args &&...args) {
constexpr size_t N = Op::num_inputs; constexpr size_t N = Op::num_inputs;
_opaque->calculateImpl<BLOCK_SIZE, N, Op, Tdata>( return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tdata>(
info, output, inputs, info, output, inputs,
std::make_index_sequence<N>{}, std::make_index_sequence<N>{},
reinterpret_cast<cudaStream_t>(stream), reinterpret_cast<cudaStream_t>(stream),
std::forward<Args>(args)...); std::forward<Args>(args)...);
cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream));
} }
} // namespace op::elementwise::cuda } // namespace op::elementwise::cuda
......
...@@ -24,7 +24,7 @@ public: ...@@ -24,7 +24,7 @@ public:
/* Invoke elementwise operation when all inputs have the same dtype */ /* Invoke elementwise operation when all inputs have the same dtype */
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args> template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
void calculate( infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info, const op::elementwise::ElementwiseInfo &info,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
...@@ -35,7 +35,7 @@ public: ...@@ -35,7 +35,7 @@ public:
template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin, template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin,
typename... Args, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0> std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate( infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info, const op::elementwise::ElementwiseInfo &info,
void *output, void *output,
const std::vector<const void *> &inputs, const std::vector<const void *> &inputs,
...@@ -48,19 +48,19 @@ public: ...@@ -48,19 +48,19 @@ public:
* @brief Define the process for initializing a Descriptor of an elementwise operation * @brief Define the process for initializing a Descriptor of an elementwise operation
* for its CUDA implementation * for its CUDA implementation
*/ */
#define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR \ #define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR \
\ \
op::elementwise::ElementwiseInfo elementwise_info; \ auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc); \
CHECK_STATUS(op::elementwise::createElementwiseInfo(elementwise_info, out_desc, input_desc)); \ CHECK_RESULT(info_result); \
\ \
op::elementwise::cuda::DeviceImpl *device_impl; \ op::elementwise::cuda::DeviceImpl *device_impl; \
CHECK_STATUS(op::elementwise::cuda::DeviceImpl::create(&device_impl, handle->internal())); \ CHECK_STATUS(op::elementwise::cuda::DeviceImpl::create(&device_impl, handle->internal())); \
\ \
*desc_ptr = new Descriptor( \ *desc_ptr = new Descriptor( \
dtype, \ dtype, \
std::move(elementwise_info), \ std::move(info_result.take()), \
device_impl, \ device_impl, \
handle->device, \ handle->device, \
handle->device_id); handle->device_id);
#endif // __INFINIOP_ELEMENTWISE_CUDA_API_H__ #endif // __INFINIOP_ELEMENTWISE_CUDA_API_H__
#ifndef __INFINIOP_ELEMENTWISE_H__ #ifndef __INFINIOP_ELEMENTWISE_H__
#define __INFINIOP_ELEMENTWISE_H__ #define __INFINIOP_ELEMENTWISE_H__
#include "../../utils.h"
#include "../operator.h" #include "../operator.h"
#include "../tensor.h" #include "../tensor.h"
#include <algorithm> #include <algorithm>
...@@ -47,8 +48,22 @@ ...@@ -47,8 +48,22 @@
namespace op::elementwise { namespace op::elementwise {
// struct that stores data needed for elementwise operation /**
* @brief Stores the metadata required for performing an elementwise operation.
*
* This struct encapsulates shape, stride, and layout information for both
* output and multiple input tensors involved in an elementwise operation.
*
* Memory is manually managed and freed in the destructor.
* Supports move construction but disallows copy construction and copy/move assignment.
*
* Use ElementwiseInfo::create(...) to safely construct an instance from tensor descriptors.
*/
struct ElementwiseInfo { struct ElementwiseInfo {
private:
ElementwiseInfo() = default;
public:
size_t output_size; size_t output_size;
size_t ndim; size_t ndim;
bool output_contiguous; bool output_contiguous;
...@@ -60,9 +75,6 @@ struct ElementwiseInfo { ...@@ -60,9 +75,6 @@ struct ElementwiseInfo {
ptrdiff_t **input_strides; ptrdiff_t **input_strides;
size_t input_size; size_t input_size;
ElementwiseInfo() = default;
// Destructor to free allocated memory
~ElementwiseInfo() { ~ElementwiseInfo() {
delete[] input_contiguous; delete[] input_contiguous;
delete[] input_broadcasted; delete[] input_broadcasted;
...@@ -77,37 +89,6 @@ struct ElementwiseInfo { ...@@ -77,37 +89,6 @@ struct ElementwiseInfo {
delete[] input_strides; delete[] input_strides;
} }
ElementwiseInfo(const ElementwiseInfo &other)
: output_size(other.output_size),
ndim(other.ndim),
output_contiguous(other.output_contiguous),
input_size(other.input_size) {
input_contiguous = new bool[input_size];
std::memcpy(input_contiguous, other.input_contiguous, input_size * sizeof(*input_contiguous));
input_broadcasted = new bool[input_size];
std::memcpy(input_broadcasted, other.input_broadcasted, input_size * sizeof(*input_broadcasted));
output_shape = new size_t[ndim];
std::memcpy(output_shape, other.output_shape, ndim * sizeof(*output_shape));
output_strides = new ptrdiff_t[ndim];
std::memcpy(output_strides, other.output_strides, ndim * sizeof(*output_strides));
input_shapes = new size_t *[input_size];
for (size_t i = 0; i < input_size; ++i) {
input_shapes[i] = new size_t[ndim];
std::memcpy(input_shapes[i], other.input_shapes[i], ndim * sizeof(*input_shapes[i]));
}
input_strides = new ptrdiff_t *[input_size];
for (size_t i = 0; i < input_size; ++i) {
input_strides[i] = new ptrdiff_t[ndim];
std::memcpy(input_strides[i], other.input_strides[i], ndim * sizeof(*input_strides[i]));
}
}
ElementwiseInfo(ElementwiseInfo &&other) noexcept ElementwiseInfo(ElementwiseInfo &&other) noexcept
: output_size(other.output_size), : output_size(other.output_size),
ndim(other.ndim), ndim(other.ndim),
...@@ -128,60 +109,69 @@ struct ElementwiseInfo { ...@@ -128,60 +109,69 @@ struct ElementwiseInfo {
other.input_size = 0; other.input_size = 0;
} }
ElementwiseInfo(const ElementwiseInfo &other) = delete;
ElementwiseInfo &operator=(const ElementwiseInfo &other) = delete; ElementwiseInfo &operator=(const ElementwiseInfo &other) = delete;
ElementwiseInfo &operator=(ElementwiseInfo &&other) = delete; ElementwiseInfo &operator=(ElementwiseInfo &&other) = delete;
};
inline infiniStatus_t createElementwiseInfo( using ResultType = utils::Result<ElementwiseInfo>;
ElementwiseInfo &info,
infiniopTensorDescriptor_t output_desc, /**
std::vector<infiniopTensorDescriptor_t> input_descs) { * @brief Construct ElementwiseInfo from output and input tensor descriptors.
* @param output_desc Descriptor of the output tensor.
* @param input_descs Descriptors of the input tensors.
* @return Result<ElementwiseInfo> with the successfully constructed ElementwiseInfo,
* or the status code.
*/
static ResultType create(
infiniopTensorDescriptor_t output_desc,
std::vector<infiniopTensorDescriptor_t> input_descs) {
if (!output_desc || input_descs.empty()) {
return INFINI_STATUS_BAD_PARAM;
}
if (!output_desc || input_descs.empty()) { // Destination cannot have broadcast setup
return INFINI_STATUS_BAD_PARAM; if (output_desc->hasBroadcastDim()) {
} return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
// Destination cannot have broadcast setup ElementwiseInfo info;
if (output_desc->hasBroadcastDim()) { info.input_size = input_descs.size();
return INFINI_STATUS_BAD_TENSOR_STRIDES; info.ndim = output_desc->ndim();
} info.output_size = output_desc->numel();
info.output_contiguous = output_desc->isContiguous();
// Allocate memory for arrays
info.input_contiguous = new bool[info.input_size];
info.input_broadcasted = new bool[info.input_size];
info.output_shape = new size_t[info.ndim];
info.output_strides = new ptrdiff_t[info.ndim];
info.input_shapes = new size_t *[info.input_size];
info.input_strides = new ptrdiff_t *[info.input_size];
// Fill arrays
const auto output_shape = output_desc->shape();
const auto output_strides = output_desc->strides();
std::memcpy(info.output_shape, output_shape.data(), info.ndim * sizeof(*info.output_shape));
std::memcpy(info.output_strides, output_strides.data(), info.ndim * sizeof(*info.output_strides));
for (size_t i = 0; i < info.input_size; ++i) {
auto &desc = input_descs[i];
info.input_contiguous[i] = desc->isContiguous();
info.input_broadcasted[i] = !info.input_contiguous[i] && (desc->ndim() != info.ndim || desc->hasBroadcastDim());
info.input_shapes[i] = new size_t[desc->ndim()];
const auto &in_shape = desc->shape();
std::memcpy(info.input_shapes[i], in_shape.data(), desc->ndim() * sizeof(*info.input_shapes[i]));
info.input_strides[i] = new ptrdiff_t[desc->ndim()];
const auto &in_strides = desc->strides();
std::memcpy(info.input_strides[i], in_strides.data(), desc->ndim() * sizeof(*info.input_strides[i]));
}
info.input_size = input_descs.size(); return ResultType(std::move(info));
info.ndim = output_desc->ndim();
info.output_size = output_desc->numel();
info.output_contiguous = output_desc->isContiguous();
// Allocate memory for arrays
info.input_contiguous = new bool[info.input_size];
info.input_broadcasted = new bool[info.input_size];
info.output_shape = new size_t[info.ndim];
info.output_strides = new ptrdiff_t[info.ndim];
info.input_shapes = new size_t *[info.input_size];
info.input_strides = new ptrdiff_t *[info.input_size];
// Fill arrays
const auto output_shape = output_desc->shape();
const auto output_strides = output_desc->strides();
std::memcpy(info.output_shape, output_shape.data(), info.ndim * sizeof(*info.output_shape));
std::memcpy(info.output_strides, output_strides.data(), info.ndim * sizeof(*info.output_strides));
for (size_t i = 0; i < info.input_size; ++i) {
auto &desc = input_descs[i];
info.input_contiguous[i] = desc->isContiguous();
info.input_broadcasted[i] = !info.input_contiguous[i] && (desc->ndim() != info.ndim || desc->hasBroadcastDim());
info.input_shapes[i] = new size_t[desc->ndim()];
const auto &in_shape = desc->shape();
std::memcpy(info.input_shapes[i], in_shape.data(), desc->ndim() * sizeof(*info.input_shapes[i]));
info.input_strides[i] = new ptrdiff_t[desc->ndim()];
const auto &in_strides = desc->strides();
std::memcpy(info.input_strides[i], in_strides.data(), desc->ndim() * sizeof(*info.input_strides[i]));
} }
};
return INFINI_STATUS_SUCCESS;
}
} // namespace op::elementwise } // namespace op::elementwise
#endif // __INFINIOP_ELEMENTWISE_H__ #endif // __INFINIOP_ELEMENTWISE_H__
...@@ -36,14 +36,11 @@ infiniStatus_t Descriptor::calculate( ...@@ -36,14 +36,11 @@ infiniStatus_t Descriptor::calculate(
switch (_dtype) { switch (_dtype) {
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
_device_info->calculate<SwiGLUOp, fp16_t>(_info, output, inputs, stream); return _device_info->calculate<SwiGLUOp, fp16_t>(_info, output, inputs, stream);
break;
case INFINI_DTYPE_F32: case INFINI_DTYPE_F32:
_device_info->calculate<SwiGLUOp, float>(_info, output, inputs, stream); return _device_info->calculate<SwiGLUOp, float>(_info, output, inputs, stream);
break;
case INFINI_DTYPE_F64: case INFINI_DTYPE_F64:
_device_info->calculate<SwiGLUOp, double>(_info, output, inputs, stream); return _device_info->calculate<SwiGLUOp, double>(_info, output, inputs, stream);
break;
default: default:
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
......
...@@ -38,14 +38,11 @@ infiniStatus_t Descriptor::calculate( ...@@ -38,14 +38,11 @@ infiniStatus_t Descriptor::calculate(
switch (_dtype) { switch (_dtype) {
case INFINI_DTYPE_F16: case INFINI_DTYPE_F16:
_device_info->calculate<256, SwiGLUOp, half>(_info, output, inputs, stream); return _device_info->calculate<256, SwiGLUOp, half>(_info, output, inputs, stream);
break;
case INFINI_DTYPE_F32: case INFINI_DTYPE_F32:
_device_info->calculate<256, SwiGLUOp, float>(_info, output, inputs, stream); return _device_info->calculate<256, SwiGLUOp, float>(_info, output, inputs, stream);
break;
case INFINI_DTYPE_F64: case INFINI_DTYPE_F64:
_device_info->calculate<256, SwiGLUOp, double>(_info, output, inputs, stream); return _device_info->calculate<256, SwiGLUOp, double>(_info, output, inputs, stream);
break;
default: default:
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment