Commit a283a8fa authored by Zimin Li

issue/127: refactor ElementwiseInfo to use utils::Result, change elementwise...

issue/127: refactor ElementwiseInfo to use utils::Result, change elementwise calculate and calculateImpl to return infiniStatus_t, add CHECK_CUDA to cuda function calls
parent 7a8f2bca
@@ -9,20 +9,27 @@
* @brief Define the process for initializing a Descriptor of an elementwise operation
* for its CPU implementation
*/
#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR \
\
op::elementwise::ElementwiseInfo elementwise_info; \
CHECK_STATUS(op::elementwise::createElementwiseInfo(elementwise_info, out_desc, input_desc)); \
\
*desc_ptr = new Descriptor( \
dtype, \
std::move(elementwise_info), \
nullptr, \
handle->device, \
#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR \
\
auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc); \
CHECK_RESULT(info_result); \
\
*desc_ptr = new Descriptor( \
dtype, \
std::move(info_result.take()), \
nullptr, \
handle->device, \
handle->device_id);
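The rewritten macro leans on the utils::Result pattern from utils.h instead of an output parameter: create() yields a Result that either carries an ElementwiseInfo or an error status, CHECK_RESULT propagates the error, and take() moves the value out. The actual definitions are outside this diff; below is a minimal sketch of the contract the macro assumes, with names inferred from usage in this commit rather than taken from utils.h.

#include <optional>
#include <utility>

// Hypothetical sketch of the utils::Result contract assumed by this macro;
// the real Result and CHECK_RESULT live in utils.h.
template <typename T>
class Result {
public:
    // Implicit conversion from a status lets create() return a bare error code.
    Result(infiniStatus_t status) : _status(status) {}
    Result(T &&value) : _status(INFINI_STATUS_SUCCESS), _value(std::move(value)) {}
    infiniStatus_t status() const { return _status; }
    T take() { return std::move(*_value); } // move the stored value out
private:
    infiniStatus_t _status;
    std::optional<T> _value;
};

// CHECK_RESULT presumably early-returns the error status, mirroring CHECK_STATUS:
#define CHECK_RESULT(RESULT) \
    do { \
        if ((RESULT).status() != INFINI_STATUS_SUCCESS) \
            return (RESULT).status(); \
    } while (0)

Under this sketch take() already returns an rvalue, so the std::move wrapped around it in the macro would be redundant but harmless.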
namespace op::elementwise::cpu {
/**
* @brief CPU-specific device implementation for resource management and
* calculation implementations.
*
* This class encapsulates device-specific behavior and execution logic.
* Use the static create() method to instantiate a DeviceImpl.
*/
class DeviceImpl final {
struct Opaque;
std::shared_ptr<struct Opaque> _opaque;
@@ -37,20 +37,48 @@ public:
DeviceImpl **device_info,
Args &&...args);
/* Invoke elementwise operation when all inputs have the same type */
/**
* @brief Dispatches an elementwise operation with uniform input types.
*
* @tparam Op The elementwise operation to perform.
* @tparam Tdata The common data type of all inputs and output.
* @tparam Args Additional backend-specific arguments.
* @param info Precomputed tensor metadata (shapes, strides, etc.).
* @param output Pointer to the output tensor buffer.
* @param inputs Vector of input tensor data pointers.
* @param stream Device execution stream.
* @param args Additional backend-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
template <typename Op, typename Tdata, typename... Args>
void calculate(
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
/* Invoke elementwise operation for different input types */
/**
* @brief Dispatches an elementwise operation with heterogeneous input types.
*
* Supports operations where each input may have a different type, as defined by Op.
* The number of input types must match the operation's expected input count.
*
* @tparam Op The elementwise operation to perform.
* @tparam Tout Output data type.
* @tparam Tin Variadic input data types.
* @tparam Args Additional backend-specific arguments.
* @param info Precomputed tensor metadata (shapes, strides, etc.).
* @param output Pointer to the output tensor buffer.
* @param inputs Vector of input tensor data pointers.
* @param stream Device execution stream.
* @param args Additional backend-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
template <typename Op, typename Tout, typename... Tin,
typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate(
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
@@ -58,6 +93,7 @@ public:
Args &&...args);
};
// Define the Opaque struct for CPU, which is empty
struct DeviceImpl::Opaque {};
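DeviceImpl follows the pimpl idiom: the header only forward-declares Opaque, and each backend defines it privately (empty on CPU, CUDA handles on the GPU side). std::shared_ptr is used rather than std::unique_ptr likely because its type-erased deleter tolerates the incomplete type at the point of declaration. A toy, self-contained illustration of the split:

#include <memory>

// Toy pimpl sketch mirroring the DeviceImpl/Opaque split above.
class Widget {
    struct Opaque;                   // forward declaration only, in the header
    std::shared_ptr<Opaque> _opaque; // shared_ptr tolerates the incomplete type
public:
    Widget();
};

// Defined in exactly one (backend-specific) translation unit:
struct Widget::Opaque { /* backend state would live here */ };
Widget::Widget() : _opaque(std::make_shared<Opaque>()) {}

int main() { Widget w; }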
template <typename... Args>
@@ -90,14 +126,15 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output,
// Invoke elementwise operation for different input types
template <typename Op, typename Tout, typename... Tin, typename... Args, std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch");
calculate_impl<Op, Tout, Tin...>(info, output, inputs, std::make_index_sequence<sizeof...(Tin)>{}, std::forward<Args>(args)...);
return INFINI_STATUS_SUCCESS;
}
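The index sequence passed to calculate_impl is what pairs each erased void* in inputs with the matching type in Tin...: the Is and Tin parameter packs expand in lockstep inside the implementation. A toy, self-contained version of that dispatch pattern (it just sums two differently typed inputs; the real calculate_impl iterates tensor elements):

#include <cstdio>
#include <utility>
#include <vector>

// Toy analogue of calculate_impl: inputs[Is] is cast to the Is-th type in Tin...,
// with both packs expanded together inside the fold expression.
template <typename Tout, typename... Tin, std::size_t... Is>
void apply_impl(void *out, const std::vector<const void *> &in, std::index_sequence<Is...>) {
    *reinterpret_cast<Tout *>(out) =
        (static_cast<Tout>(*reinterpret_cast<const Tin *>(in[Is])) + ...);
}

int main() {
    float a = 1.5f;
    double b = 2.25;
    double out = 0.0;
    apply_impl<double, float, double>(&out, {&a, &b}, std::make_index_sequence<2>{});
    std::printf("%g\n", out); // prints 3.75
}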
// Perform elementwise operation when all inputs have the same type
@@ -133,9 +170,10 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
// Invoke elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, typename... Args>
void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, void *stream, Args &&...args) {
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, void *stream, Args &&...args) {
constexpr size_t N = Op::num_inputs;
calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...);
return INFINI_STATUS_SUCCESS;
}
} // namespace op::elementwise::cpu
......
@@ -4,6 +4,7 @@
#include "../../../utils.h"
#include "../../devices/cuda/cuda_common.cuh"
#include "elementwise_cuda_api.cuh"
namespace op::elementwise::cuda {
/**
@@ -223,16 +224,17 @@ struct DeviceImpl::Opaque {
* @param inputs Vector of pointers to input tensors.
* @param stream CUDA stream used for asynchronous execution.
* @param args Additional arguments for the operation.
* @return infiniStatus_t Status indicating success or failure.
*/
template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename... Args, size_t... Is>
void calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
cudaStream_t stream,
Args &&...args) {
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
cudaStream_t stream,
Args &&...args) {
if (info.output_size == 0) {
return;
return INFINI_STATUS_SUCCESS;
}
// casting the output and the inputs to Tdata pointers
@@ -242,8 +244,8 @@ struct DeviceImpl::Opaque {
for (size_t i = 0; i < N; ++i) {
inputs_arr[i] = reinterpret_cast<const Tdata *>(inputs[i]);
}
cudaMallocAsync(&d_inputs_arr, N * sizeof(*d_inputs_arr), stream);
cudaMemcpyAsync(d_inputs_arr, inputs_arr, N * sizeof(*d_inputs_arr), cudaMemcpyHostToDevice, stream);
CHECK_CUDA(cudaMallocAsync(&d_inputs_arr, N * sizeof(*d_inputs_arr), stream));
CHECK_CUDA(cudaMemcpyAsync(d_inputs_arr, inputs_arr, N * sizeof(*d_inputs_arr), cudaMemcpyHostToDevice, stream));
// create and send the info to device
const bool *d_bools = nullptr;
@@ -257,10 +259,10 @@ struct DeviceImpl::Opaque {
std::vector<const size_t *> tmp_device_ptrs(info.input_size);
std::vector<const ptrdiff_t *> tmp_device_ptrs_strides(info.input_size);
infoToDevice<N>(info, d_bools, d_input_contiguous,
d_input_broadcasted, d_output_shape_strides, d_output_shape,
d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides,
d_input_strides, stream);
CHECK_STATUS(infoToDevice<N>(info, d_bools, d_input_contiguous,
d_input_broadcasted, d_output_shape_strides, d_output_shape,
d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides,
d_input_strides, stream));
dim3 blockDims(std::min(BLOCK_SIZE, static_cast<size_t>(internal->maxThreadsPerBlock())));
dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX())));
@@ -280,7 +282,9 @@ struct DeviceImpl::Opaque {
info.input_size, out, d_inputs_arr, i, std::forward<Args>(args)...);
}
freeAllDevice((const void **)d_inputs_arr, d_bools, d_output_shape_strides, info.input_size, d_input_shapes, d_input_strides, stream);
CHECK_STATUS(freeAllDevice((const void **)d_inputs_arr, d_bools, d_output_shape_strides,
info.input_size, d_input_shapes, d_input_strides, stream));
return INFINI_STATUS_SUCCESS;
}
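CHECK_CUDA, added throughout this commit, is what makes the new infiniStatus_t return type meaningful on the CUDA path: a failed allocation or copy now aborts calculateImpl with an error instead of being silently ignored. Its definition comes from cuda_common.cuh and is not part of this diff; a plausible minimal sketch (the specific error code returned is an assumption):

// Hypothetical sketch; the real CHECK_CUDA lives in cuda_common.cuh and may
// translate cudaError_t values to infiniStatus_t codes more precisely.
#define CHECK_CUDA(CALL) \
    do { \
        cudaError_t err_ = (CALL); \
        if (err_ != cudaSuccess) { \
            return INFINI_STATUS_INTERNAL_ERROR; /* assumed error code */ \
        } \
    } while (0)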
/**
@@ -297,17 +301,18 @@ struct DeviceImpl::Opaque {
* @param inputs Vector of pointers to input tensors.
* @param stream CUDA stream used for asynchronous execution.
* @param args Additional arguments for the operation.
* @return infiniStatus_t Status indicating success or failure.
*/
template <size_t BLOCK_SIZE, size_t N, typename Op, typename Tout, typename... Tin, typename... Args, size_t... Is,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
cudaStream_t stream,
Args &&...args) {
infiniStatus_t calculateImpl(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
cudaStream_t stream,
Args &&...args) {
if (info.output_size == 0) {
return;
return INFINI_STATUS_SUCCESS;
}
Tout *out = reinterpret_cast<Tout *>(output);
@@ -318,8 +323,8 @@ struct DeviceImpl::Opaque {
// Create array of input pointers on host (void*) to copy to device
const void *host_input_ptrs[] = {reinterpret_cast<const void *>(std::get<Is>(inputs_arr))...};
cudaMallocAsync(&d_inputs_arr, N * sizeof(void *), stream);
cudaMemcpyAsync(d_inputs_arr, host_input_ptrs, N * sizeof(void *), cudaMemcpyHostToDevice, stream);
CHECK_CUDA(cudaMallocAsync(&d_inputs_arr, N * sizeof(void *), stream));
CHECK_CUDA(cudaMemcpyAsync(d_inputs_arr, host_input_ptrs, N * sizeof(void *), cudaMemcpyHostToDevice, stream));
// Device pointers
const bool *d_bools = nullptr;
@@ -333,10 +338,10 @@ struct DeviceImpl::Opaque {
std::vector<const size_t *> tmp_device_ptrs(info.input_size);
std::vector<const ptrdiff_t *> tmp_device_ptrs_strides(info.input_size);
infoToDevice<N>(info, d_bools, d_input_contiguous,
d_input_broadcasted, d_output_shape_strides, d_output_shape,
d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides,
d_input_strides, stream);
CHECK_STATUS(infoToDevice<N>(info, d_bools, d_input_contiguous,
d_input_broadcasted, d_output_shape_strides, d_output_shape,
d_output_strides, tmp_device_ptrs, d_input_shapes, tmp_device_ptrs_strides,
d_input_strides, stream));
dim3 blockDims(std::min(BLOCK_SIZE, static_cast<size_t>(internal->maxThreadsPerBlock())));
dim3 gridDims(std::min(CEIL_DIV(info.output_size, blockDims.x), static_cast<size_t>(internal->gridSizeX())));
@@ -356,7 +361,8 @@ struct DeviceImpl::Opaque {
info.input_size, out, reinterpret_cast<const void **>(d_inputs_arr), i);
}
freeAllDevice(d_inputs_arr, d_bools, d_output_shape_strides, info.input_size, d_input_shapes, d_input_strides, stream);
CHECK_STATUS(freeAllDevice(d_inputs_arr, d_bools, d_output_shape_strides, info.input_size, d_input_shapes, d_input_strides, stream));
return INFINI_STATUS_SUCCESS;
}
private:
@@ -393,31 +399,31 @@ private:
const ptrdiff_t **&d_input_strides,
cudaStream_t stream) const {
cudaMallocAsync(&d_bools, 2 * info.input_size * sizeof(*d_bools), stream);
cudaMemcpyAsync((void *)d_bools, info.input_contiguous, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync((void *)(d_bools + info.input_size), info.input_broadcasted, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream);
CHECK_CUDA(cudaMallocAsync(&d_bools, 2 * info.input_size * sizeof(*d_bools), stream));
CHECK_CUDA(cudaMemcpyAsync((void *)d_bools, info.input_contiguous, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync((void *)(d_bools + info.input_size), info.input_broadcasted, info.input_size * sizeof(bool), cudaMemcpyHostToDevice, stream));
cudaMallocAsync(&d_output_shape_strides, info.ndim * (sizeof(*d_output_shape) + sizeof(*d_output_strides)), stream);
cudaMemcpyAsync((void *)d_output_shape_strides, info.output_shape, info.ndim * sizeof(*d_output_shape), cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync((void *)(d_output_shape_strides + info.ndim * sizeof(*d_output_shape)), info.output_strides, info.ndim * sizeof(*d_output_strides), cudaMemcpyHostToDevice, stream);
CHECK_CUDA(cudaMallocAsync(&d_output_shape_strides, info.ndim * (sizeof(*d_output_shape) + sizeof(*d_output_strides)), stream));
CHECK_CUDA(cudaMemcpyAsync((void *)d_output_shape_strides, info.output_shape, info.ndim * sizeof(*d_output_shape), cudaMemcpyHostToDevice, stream));
CHECK_CUDA(cudaMemcpyAsync((void *)(d_output_shape_strides + info.ndim * sizeof(*d_output_shape)), info.output_strides, info.ndim * sizeof(*d_output_strides), cudaMemcpyHostToDevice, stream));
cudaMallocAsync(&d_input_shapes, info.input_size * sizeof(*d_input_shapes), stream);
CHECK_CUDA(cudaMallocAsync(&d_input_shapes, info.input_size * sizeof(*d_input_shapes), stream));
for (size_t i = 0; i < info.input_size; ++i) {
cudaMallocAsync(&tmp_device_ptrs[i], info.ndim * sizeof(*&tmp_device_ptrs[i]), stream);
cudaMemcpyAsync((void *)tmp_device_ptrs[i], info.input_shapes[i],
info.ndim * sizeof(size_t), cudaMemcpyHostToDevice, stream);
CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs[i], info.ndim * sizeof(*&tmp_device_ptrs[i]), stream));
CHECK_CUDA(cudaMemcpyAsync((void *)tmp_device_ptrs[i], info.input_shapes[i],
info.ndim * sizeof(size_t), cudaMemcpyHostToDevice, stream));
}
cudaMemcpyAsync((void *)d_input_shapes, tmp_device_ptrs.data(),
info.input_size * sizeof(size_t *), cudaMemcpyHostToDevice, stream);
CHECK_CUDA(cudaMemcpyAsync((void *)d_input_shapes, tmp_device_ptrs.data(),
info.input_size * sizeof(size_t *), cudaMemcpyHostToDevice, stream));
cudaMallocAsync(&d_input_strides, info.input_size * sizeof(*d_input_strides), stream);
CHECK_CUDA(cudaMallocAsync(&d_input_strides, info.input_size * sizeof(*d_input_strides), stream));
for (size_t i = 0; i < info.input_size; ++i) {
cudaMallocAsync(&tmp_device_ptrs_strides[i], info.ndim * sizeof(*&tmp_device_ptrs_strides[i]), stream);
cudaMemcpyAsync((void *)tmp_device_ptrs_strides[i], info.input_strides[i],
info.ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice, stream);
CHECK_CUDA(cudaMallocAsync(&tmp_device_ptrs_strides[i], info.ndim * sizeof(*&tmp_device_ptrs_strides[i]), stream));
CHECK_CUDA(cudaMemcpyAsync((void *)tmp_device_ptrs_strides[i], info.input_strides[i],
info.ndim * sizeof(ptrdiff_t), cudaMemcpyHostToDevice, stream));
}
cudaMemcpyAsync((void *)d_input_strides, tmp_device_ptrs_strides.data(),
info.input_size * sizeof(ptrdiff_t *), cudaMemcpyHostToDevice, stream);
CHECK_CUDA(cudaMemcpyAsync((void *)d_input_strides, tmp_device_ptrs_strides.data(),
info.input_size * sizeof(ptrdiff_t *), cudaMemcpyHostToDevice, stream));
d_input_contiguous = d_bools;
d_input_broadcasted = d_bools + info.input_size;
@@ -447,11 +453,11 @@ private:
const ptrdiff_t **d_input_strides,
cudaStream_t stream) const {
cudaFreeAsync((void *)d_inputs_arr, stream);
cudaFreeAsync((void *)d_bools, stream);
cudaFreeAsync((void *)d_output_shape_strides, stream);
cudaFreeAsync((void *)d_input_shapes, stream);
cudaFreeAsync((void *)d_input_strides, stream);
CHECK_CUDA(cudaFreeAsync((void *)d_inputs_arr, stream));
CHECK_CUDA(cudaFreeAsync((void *)d_bools, stream));
CHECK_CUDA(cudaFreeAsync((void *)d_output_shape_strides, stream));
CHECK_CUDA(cudaFreeAsync((void *)d_input_shapes, stream));
CHECK_CUDA(cudaFreeAsync((void *)d_input_strides, stream));
return INFINI_STATUS_SUCCESS;
}
};
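Everything in this Opaque — the allocations, host-to-device copies, kernel launch, and frees — is enqueued on the same stream, so the cudaFreeAsync calls in freeAllDevice are ordered after the kernel and cannot race with it. A standalone sketch of that stream-ordered lifetime pattern (needs CUDA 11.2+ for the stream-ordered allocator):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void fill(int *p, int v, size_t n) {
    size_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) p[i] = v;
}

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    int *d = nullptr;
    // malloc -> launch -> free all execute in stream order, so the free is
    // safe even though it is issued before the kernel has actually run.
    cudaMallocAsync(&d, 1024 * sizeof(int), stream);
    fill<<<4, 256, 0, stream>>>(d, 7, 1024);
    cudaFreeAsync(d, stream);
    cudaStreamSynchronize(stream);
    std::printf("done: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaStreamDestroy(stream);
    return 0;
}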
@@ -479,17 +485,18 @@ infiniStatus_t DeviceImpl::create(DeviceImpl **device_info,
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
* @param args (UNUSED) Additional operation-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
static_assert(sizeof...(Tin) == N, "Input type count mismatch");
_opaque->calculateImpl<BLOCK_SIZE, N, Op, Tout, Tin...>(
return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tout, Tin...>(
info, output, inputs,
std::make_index_sequence<N>{},
reinterpret_cast<cudaStream_t>(stream),
@@ -510,20 +517,20 @@ void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
* @param inputs Vector of input pointers (device memory).
* @param stream CUDA stream (opaque void*).
* @param args Additional operation-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
constexpr size_t N = Op::num_inputs;
_opaque->calculateImpl<BLOCK_SIZE, N, Op, Tdata>(
return _opaque->calculateImpl<BLOCK_SIZE, N, Op, Tdata>(
info, output, inputs,
std::make_index_sequence<N>{},
reinterpret_cast<cudaStream_t>(stream),
std::forward<Args>(args)...);
cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream));
}
} // namespace op::elementwise::cuda
......
@@ -24,7 +24,7 @@ public:
/* Invoke elementwise operation when all inputs have the same dtype */
template <unsigned int BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
void calculate(
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
@@ -35,7 +35,7 @@ public:
template <unsigned int BLOCK_SIZE, typename Op, typename Tout, typename... Tin,
typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate(
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
@@ -48,19 +48,19 @@ public:
* @brief Define the process for initializing a Descriptor of an elementwise operation
* for its CUDA implementation
*/
#define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR \
\
op::elementwise::ElementwiseInfo elementwise_info; \
CHECK_STATUS(op::elementwise::createElementwiseInfo(elementwise_info, out_desc, input_desc)); \
\
op::elementwise::cuda::DeviceImpl *device_impl; \
CHECK_STATUS(op::elementwise::cuda::DeviceImpl::create(&device_impl, handle->internal())); \
\
*desc_ptr = new Descriptor( \
dtype, \
std::move(elementwise_info), \
device_impl, \
handle->device, \
#define CREATE_ELEMENTWISE_CUDA_DESCRIPTOR \
\
auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, input_desc); \
CHECK_RESULT(info_result); \
\
op::elementwise::cuda::DeviceImpl *device_impl; \
CHECK_STATUS(op::elementwise::cuda::DeviceImpl::create(&device_impl, handle->internal())); \
\
*desc_ptr = new Descriptor( \
dtype, \
std::move(info_result.take()), \
device_impl, \
handle->device, \
handle->device_id);
#endif // __INFINIOP_ELEMENTWISE_CUDA_API_H__
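Both CREATE_ELEMENTWISE_*_DESCRIPTOR macros expand inside an operator's create function and pick up handle, desc_ptr, out_desc, input_desc, and dtype from the surrounding scope. A hypothetical expansion site, modeled on how the swiglu operators further down consume the result (the signature and the dtype() accessor are illustrative, not taken from this diff):

// Hypothetical operator create(); every identifier the macro references
// (handle, desc_ptr, out_desc, input_desc, dtype) must already be in scope.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc) {
    auto dtype = out_desc->dtype(); // illustrative accessor
    CREATE_ELEMENTWISE_CUDA_DESCRIPTOR
    return INFINI_STATUS_SUCCESS;
}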
#ifndef __INFINIOP_ELEMENTWISE_H__
#define __INFINIOP_ELEMENTWISE_H__
#include "../../utils.h"
#include "../operator.h"
#include "../tensor.h"
#include <algorithm>
@@ -47,8 +48,22 @@
namespace op::elementwise {
// struct that stores data needed for elementwise operation
/**
* @brief Stores the metadata required for performing an elementwise operation.
*
* This struct encapsulates shape, stride, and layout information for both
* output and multiple input tensors involved in an elementwise operation.
*
* Memory is manually managed and freed in the destructor.
* Supports move construction but disallows copy construction and copy/move assignment.
*
* Use ElementwiseInfo::create(...) to safely construct an instance from tensor descriptors.
*/
struct ElementwiseInfo {
private:
ElementwiseInfo() = default;
public:
size_t output_size;
size_t ndim;
bool output_contiguous;
@@ -60,9 +75,6 @@ struct ElementwiseInfo {
ptrdiff_t **input_strides;
size_t input_size;
ElementwiseInfo() = default;
// Destructor to free allocated memory
~ElementwiseInfo() {
delete[] input_contiguous;
delete[] input_broadcasted;
@@ -77,37 +89,6 @@ struct ElementwiseInfo {
delete[] input_strides;
}
ElementwiseInfo(const ElementwiseInfo &other)
: output_size(other.output_size),
ndim(other.ndim),
output_contiguous(other.output_contiguous),
input_size(other.input_size) {
input_contiguous = new bool[input_size];
std::memcpy(input_contiguous, other.input_contiguous, input_size * sizeof(*input_contiguous));
input_broadcasted = new bool[input_size];
std::memcpy(input_broadcasted, other.input_broadcasted, input_size * sizeof(*input_broadcasted));
output_shape = new size_t[ndim];
std::memcpy(output_shape, other.output_shape, ndim * sizeof(*output_shape));
output_strides = new ptrdiff_t[ndim];
std::memcpy(output_strides, other.output_strides, ndim * sizeof(*output_strides));
input_shapes = new size_t *[input_size];
for (size_t i = 0; i < input_size; ++i) {
input_shapes[i] = new size_t[ndim];
std::memcpy(input_shapes[i], other.input_shapes[i], ndim * sizeof(*input_shapes[i]));
}
input_strides = new ptrdiff_t *[input_size];
for (size_t i = 0; i < input_size; ++i) {
input_strides[i] = new ptrdiff_t[ndim];
std::memcpy(input_strides[i], other.input_strides[i], ndim * sizeof(*input_strides[i]));
}
}
ElementwiseInfo(ElementwiseInfo &&other) noexcept
: output_size(other.output_size),
ndim(other.ndim),
@@ -128,60 +109,69 @@
other.input_size = 0;
}
ElementwiseInfo(const ElementwiseInfo &other) = delete;
ElementwiseInfo &operator=(const ElementwiseInfo &other) = delete;
ElementwiseInfo &operator=(ElementwiseInfo &&other) = delete;
};
inline infiniStatus_t createElementwiseInfo(
ElementwiseInfo &info,
infiniopTensorDescriptor_t output_desc,
std::vector<infiniopTensorDescriptor_t> input_descs) {
using ResultType = utils::Result<ElementwiseInfo>;
/**
* @brief Constructs an ElementwiseInfo from the output and input tensor descriptors.
* @param output_desc Descriptor of the output tensor.
* @param input_descs Descriptors of the input tensors.
* @return utils::Result<ElementwiseInfo> holding the constructed ElementwiseInfo on success,
* or an error status on failure.
*/
static ResultType create(
infiniopTensorDescriptor_t output_desc,
std::vector<infiniopTensorDescriptor_t> input_descs) {
if (!output_desc || input_descs.empty()) {
return INFINI_STATUS_BAD_PARAM;
}
// Destination cannot have broadcast setup
if (output_desc->hasBroadcastDim()) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
ElementwiseInfo info;
info.input_size = input_descs.size();
info.ndim = output_desc->ndim();
info.output_size = output_desc->numel();
info.output_contiguous = output_desc->isContiguous();
// Allocate memory for arrays
info.input_contiguous = new bool[info.input_size];
info.input_broadcasted = new bool[info.input_size];
info.output_shape = new size_t[info.ndim];
info.output_strides = new ptrdiff_t[info.ndim];
info.input_shapes = new size_t *[info.input_size];
info.input_strides = new ptrdiff_t *[info.input_size];
// Fill arrays
const auto output_shape = output_desc->shape();
const auto output_strides = output_desc->strides();
std::memcpy(info.output_shape, output_shape.data(), info.ndim * sizeof(*info.output_shape));
std::memcpy(info.output_strides, output_strides.data(), info.ndim * sizeof(*info.output_strides));
for (size_t i = 0; i < info.input_size; ++i) {
auto &desc = input_descs[i];
info.input_contiguous[i] = desc->isContiguous();
info.input_broadcasted[i] = !info.input_contiguous[i] && (desc->ndim() != info.ndim || desc->hasBroadcastDim());
info.input_shapes[i] = new size_t[desc->ndim()];
const auto &in_shape = desc->shape();
std::memcpy(info.input_shapes[i], in_shape.data(), desc->ndim() * sizeof(*info.input_shapes[i]));
info.input_strides[i] = new ptrdiff_t[desc->ndim()];
const auto &in_strides = desc->strides();
std::memcpy(info.input_strides[i], in_strides.data(), desc->ndim() * sizeof(*info.input_strides[i]));
}
return ResultType(std::move(info));
}
};
} // namespace op::elementwise
#endif // __INFINIOP_ELEMENTWISE_H__
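At a call site, the factory plus the move-only semantics defined above combine like this (the descriptor variables are hypothetical; take() and CHECK_RESULT as used by the macros earlier in this commit):

// Hypothetical caller; out_desc, a_desc, b_desc are illustrative descriptors.
auto info_result = op::elementwise::ElementwiseInfo::create(out_desc, {a_desc, b_desc});
CHECK_RESULT(info_result);      // propagates the error status, if any
auto info = info_result.take(); // moves the ElementwiseInfo out of the Result
// auto copy = info;            // ill-formed: the copy constructor is deleted
auto owner = std::move(info);   // ownership of the heap arrays transfers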
@@ -36,14 +36,11 @@ infiniStatus_t Descriptor::calculate(
switch (_dtype) {
case INFINI_DTYPE_F16:
_device_info->calculate<SwiGLUOp, fp16_t>(_info, output, inputs, stream);
break;
return _device_info->calculate<SwiGLUOp, fp16_t>(_info, output, inputs, stream);
case INFINI_DTYPE_F32:
_device_info->calculate<SwiGLUOp, float>(_info, output, inputs, stream);
break;
return _device_info->calculate<SwiGLUOp, float>(_info, output, inputs, stream);
case INFINI_DTYPE_F64:
_device_info->calculate<SwiGLUOp, double>(_info, output, inputs, stream);
break;
return _device_info->calculate<SwiGLUOp, double>(_info, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......
@@ -38,14 +38,11 @@ infiniStatus_t Descriptor::calculate(
switch (_dtype) {
case INFINI_DTYPE_F16:
_device_info->calculate<256, SwiGLUOp, half>(_info, output, inputs, stream);
break;
return _device_info->calculate<256, SwiGLUOp, half>(_info, output, inputs, stream);
case INFINI_DTYPE_F32:
_device_info->calculate<256, SwiGLUOp, float>(_info, output, inputs, stream);
break;
return _device_info->calculate<256, SwiGLUOp, float>(_info, output, inputs, stream);
case INFINI_DTYPE_F64:
_device_info->calculate<256, SwiGLUOp, double>(_info, output, inputs, stream);
break;
return _device_info->calculate<256, SwiGLUOp, double>(_info, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
......