Commit 7105d13d authored by Zimin Li

issue/127: refactor elementwise infra to support Opaque input when calculating

parent b985bc5e
#ifndef __INFINIOP_ELEMENTWISE_CPU_H__
#define __INFINIOP_ELEMENTWISE_CPU_H__
#include "../../devices/cpu/common_cpu.h"
#include "../elementwise.h"
#include <utility>
/**
* @brief Defines the procedure for initializing a Descriptor of an elementwise operation
* in its CPU implementation
*/
#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR \
\
op::elementwise::ElementwiseInfo elementwise_info; \
CHECK_STATUS(op::elementwise::createElementwiseInfo(elementwise_info, out_desc, input_desc)); \
\
*desc_ptr = new Descriptor( \
dtype, \
std::move(elementwise_info), \
nullptr, \
handle->device, \
handle->device_id);
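// Hedged usage sketch (the operator and the surrounding signature are assumptions, not
// part of this commit): the macro expects `handle`, `desc_ptr`, `dtype`, `out_desc`,
// and `input_desc` to already be in scope at the expansion site, roughly:
//
//   infiniStatus_t Descriptor::create(
//       infiniopHandle_t handle,
//       Descriptor **desc_ptr,
//       infiniopTensorDescriptor_t out_desc,
//       std::vector<infiniopTensorDescriptor_t> input_desc) {
//       auto dtype = out_desc->dtype();
//       CREATE_ELEMENTWISE_CPU_DESCRIPTOR
//       return INFINI_STATUS_SUCCESS;
//   }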
DEVICE_IMPL(cpu)
namespace op::elementwise::cpu {
struct DeviceImpl::Opaque {};
template <typename... Args>
infiniStatus_t DeviceImpl::create(DeviceImpl **device_info, Args &&...args) {
// The CPU backend keeps no device-specific state, so no Opaque object is allocated.
*device_info = new DeviceImpl(nullptr);
return INFINI_STATUS_SUCCESS;
}
// Perform elementwise operation for different input types
template <typename Op, typename Tout, typename... Tin, size_t... Is, typename... Args, std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, std::index_sequence<Is...>, Args &&...args) {
Tout *out = reinterpret_cast<Tout *>(output);
std::tuple<const Tin *...> input_ptrs = {reinterpret_cast<const Tin *>(inputs[Is])...};
ptrdiff_t output_size = info.output_size;
#pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape.data(), info.output_strides.data());
// Map the flat output index to each input's element offset: contiguous inputs
// reuse the flat index, broadcasted inputs fold it through the output strides,
// and strided (non-broadcast) inputs recompute it from their own shape and strides.
auto get_input_idx = [&](size_t input_id) {
return info.input_contiguous[input_id] ? i
: (info.input_broadcasted[input_id]
? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides.data(), info.input_strides[input_id].data())
: op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id].data(), info.input_strides[input_id].data()));
};
out[out_idx] = utils::cast<Tout>(Op{}(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...));
}
}
// Invoke elementwise operation for different input types
template <typename Op, typename Tout, typename... Tin, typename... Args, std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
Args &&...args) {
static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch");
calculate_impl<Op, Tout, Tin...>(info, output, inputs, std::make_index_sequence<sizeof...(Tin)>{}, std::forward<Args>(args)...);
}
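// A minimal sketch of the operator functor contract assumed by both calculate()
// overloads (the name AddOp and its body are illustrative, not part of this commit):
// the type only needs a static `num_inputs` constant and a call operator taking
// one element per input.
//
//   struct AddOp {
//       static constexpr size_t num_inputs = 2;
//       template <typename T>
//       T operator()(const T &a, const T &b) const {
//           return a + b;
//       }
//   };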
// Perform elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, size_t... Is, typename... Args>
void calculate_impl(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
Args &&...args) {
Tdata *out = reinterpret_cast<Tdata *>(output);
std::array<const Tdata *, sizeof...(Is)> ins = {reinterpret_cast<const Tdata *>(inputs[Is])...};
const ptrdiff_t output_size = info.output_size;
#pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape.data(), info.output_strides.data());
auto get_input_idx = [&](size_t input_id) {
return info.input_contiguous[input_id] ? i
: (info.input_broadcasted[input_id]
? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides.data(), info.input_strides[input_id].data())
: op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id].data(), info.input_strides[input_id].data()));
};
if constexpr (std::is_same_v<Tdata, fp16_t>) {
out[out_idx] = utils::cast<fp16_t>(Op{}(utils::cast<float>(ins[Is][get_input_idx(Is)])..., std::forward<Args>(args)...));
} else {
out[out_idx] = Op{}(ins[Is][get_input_idx(Is)]..., std::forward<Args>(args)...);
}
}
}
// Invoke elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, typename... Args>
void DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, Args &&...args) {
constexpr size_t N = Op::num_inputs;
calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...);
}
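// Illustrative dispatch (names assumed, not from this commit): with the AddOp
// functor sketched above, a Descriptor::calculate for an f32 binary op could
// forward to the homogeneous overload as
//
//   _device_info->calculate<AddOp, float>(_info, output, inputs);
//
// whereas operators whose inputs differ in type would use the heterogeneous
// overload, spelling out the output type followed by one input type per operand:
// calculate<Op, Tout, Tin...>(info, output, inputs, ...).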
} // namespace op::elementwise::cpu
#endif // __INFINIOP_ELEMENTWISE_CPU_H__
// #ifndef __INFINIOP_ELEMENTWISE_CUDA_H__
// #define __INFINIOP_ELEMENTWISE_CUDA_H__
// #include "../../devices/cuda/cuda_common.cuh"
// #include "../elementwise.h"
// #define ELEMENTWISE_CUDA_OPAQUE(OP) \
// \
// namespace op::OP::cuda { \
// struct Descriptor::Opaque { \
// std::shared_ptr<device::cuda::Handle::Internal> internal; \
// }; \
// \
// Descriptor::~Descriptor() { \
// delete _opaque; \
// } \
// } // namespace op::OP::cuda
// namespace op::common_cuda::elementwise_op {
// // Perform elementwise operation when all inputs have the same type
// template <size_t BLOCK_SIZE, typename Op, typename Tdata, size_t... Is, typename... Args>
// void _calculate_impl(const op::elementwise::ElementwiseInfo &info,
// void *output,
// const std::vector<const void *> &inputs,
// std::index_sequence<Is...>,
// Args &&...args) {
// Tdata *out = reinterpret_cast<Tdata *>(output);
// std::array<const Tdata *, sizeof...(Is)> ins = {reinterpret_cast<const Tdata *>(inputs[Is])...};
// const ptrdiff_t output_size = info.output_size;
// #pragma omp parallel for
// for (ptrdiff_t i = 0; i < output_size; ++i) {
// size_t out_idx = info.output_contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.output_shape.data(), info.output_strides.data());
// auto get_input_idx = [&](size_t input_id) {
// return info.input_contiguous[input_id] ? i
// : (info.input_broadcasted[input_id]
// ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.output_strides.data(), info.input_strides[input_id].data())
// : op::common_cpu::indexToOffset(i, info.ndim, info.input_shapes[input_id].data(), info.input_strides[input_id].data()));
// };
// if constexpr (std::is_same_v<Tdata, fp16_t>) {
// out[out_idx] = utils::cast<fp16_t>(Op{}(utils::cast<float>(ins[Is][get_input_idx(Is)])..., std::forward<Args>(args)...));
// } else {
// out[out_idx] = Op{}(ins[Is][get_input_idx(Is)]..., std::forward<Args>(args)...);
// }
// }
// }
// template <size_t BLOCK_SIZE, typename Op, typename Tdata, size_t... Is, typename... Args>
// void calculate_impl(const op::elementwise::ElementwiseInfo &info,
// void *output,
// const std::vector<const void *> &inputs,
// std::index_sequence<Is...>,
// Args &&...args) {
// if (info.output_size == 0) {
// return;
// }
// Tdata *out = reinterpret_cast<Tdata *>(output);
// std::array<const Tdata *, sizeof...(Is)> inputs_vec = {reinterpret_cast<const Tdata *>(inputs[Is])...};
// dim3 blockDims = dim3(std::min(static_cast<uint64_t>(BLOCK_SIZE), info.output_size));
// dim3 gridDims = dim3(std::min(ROUND_UP_DIV(info.output_size, blockDims.x), desc->max_grid_size));
// uint64_t step = gridDims.x * blockDims.x;
// _calculate_impl<BLOCK_SIZE, Op, Tdata>(info, output, inputs, std::index_sequence<Is...>{}, std::forward<Args>(args)...);
// }
// // Invoke elementwise operation when all inputs have the same type
// template <size_t BLOCK_SIZE, typename Op, typename Tdata, typename... Args>
// void calculate(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector<const void *> &inputs, Args &&...args) {
// constexpr size_t N = Op::num_inputs;
// calculate_impl<BLOCK_SIZE, Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...);
// }
// } // namespace op::common_cuda::elementwise_op
// #endif // __INFINIOP_ELEMENTWISE_CUDA_H__
#ifndef __INFINIOP_ELEMENTWISE_H__
#define __INFINIOP_ELEMENTWISE_H__
#include "../operator.h"
#include "../tensor.h"
#include <algorithm>
#include <memory>
#include <numeric>
#include <vector>
#define DEVICE_IMPL(NAMESPACE) \
\
namespace op::elementwise::NAMESPACE { \
class DeviceImpl final { \
struct Opaque; \
std::unique_ptr<Opaque> _opaque; \
\
DeviceImpl(Opaque *opaque) : _opaque(opaque) {} \
\
public: \
~DeviceImpl() = default; \
\
template <typename... Args> \
static infiniStatus_t create( \
DeviceImpl **device_info, \
Args &&...args); \
\
/* Invoke elementwise operation when all inputs have the same type */ \
template <typename Op, typename Tdata, typename... Args> \
void calculate( \
const op::elementwise::ElementwiseInfo &info, \
void *output, \
const std::vector<const void *> &inputs, \
Args &&...args); \
\
/* Invoke elementwise operation for different input types */ \
template <typename Op, typename Tout, typename... Tin, \
typename... Args, \
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0> \
void calculate( \
const op::elementwise::ElementwiseInfo &info, \
void *output, \
const std::vector<const void *> &inputs, \
Args &&...args); \
}; \
}
#define ELEMENTWISE_DESCRIPTOR(OP, NAMESPACE) \
\
namespace op::OP::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
infiniDtype_t _dtype; \
op::elementwise::ElementwiseInfo _info; \
std::unique_ptr<op::elementwise::NAMESPACE::DeviceImpl> _device_info; \
\
Descriptor( \
infiniDtype_t dtype, \
op::elementwise::ElementwiseInfo info, \
op::elementwise::NAMESPACE::DeviceImpl *device_info, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_dtype(dtype), \
_info(std::move(info)), \
_device_info(device_info) {} \
\
public: \
~Descriptor(); \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t output_desc, \
std::vector<infiniopTensorDescriptor_t> input_descs); \
\
infiniStatus_t calculate( \
void *output, \
std::vector<const void *> inputs, \
void *stream) const; \
}; \
}
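// Hypothetical instantiation (operator name illustrative, not from this commit):
// an operator header would pair the two macros, e.g.
//
//   ELEMENTWISE_DESCRIPTOR(add, cpu)
//
// declares op::add::cpu::Descriptor with the create()/calculate() interface above,
// while the operator's CPU source defines create() via CREATE_ELEMENTWISE_CPU_DESCRIPTOR.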
namespace op::elementwise {
// struct that stores data needed for elementwise operation
struct ElementwiseInfo {
size_t output_size;
size_t ndim;
bool output_contiguous;
std::vector<bool> input_contiguous;
std::vector<bool> input_broadcasted;
std::vector<size_t> output_shape;
std::vector<std::vector<size_t>> input_shapes;
std::vector<ptrdiff_t> output_strides;
std::vector<std::vector<ptrdiff_t>> input_strides;
};
inline infiniStatus_t createElementwiseInfo(
ElementwiseInfo &info,
infiniopTensorDescriptor_t output_desc,
std::vector<infiniopTensorDescriptor_t> input_descs) {
if (!output_desc || input_descs.empty()) {
return INFINI_STATUS_BAD_PARAM;
}
// The output tensor must not contain broadcast dimensions
if (output_desc->hasBroadcastDim()) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
const size_t input_size = input_descs.size();
const size_t out_ndim = output_desc->ndim();
// Initializing the ElementwiseInfo struct
info.output_size = output_desc->numel();
info.ndim = out_ndim;
info.output_contiguous = output_desc->isContiguous();
for (const auto &desc : input_descs) {
info.input_contiguous.emplace_back(desc->isContiguous());
}
for (size_t i = 0; i < input_size; ++i) {
const auto &desc = input_descs[i];
info.input_broadcasted.emplace_back(!info.input_contiguous[i] && (desc->ndim() != out_ndim || desc->hasBroadcastDim()));
}
info.output_shape = output_desc->shape();
info.output_strides = output_desc->strides();
for (const auto &desc : input_descs) {
info.input_shapes.emplace_back(desc->shape());
info.input_strides.emplace_back(desc->strides());
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::elementwise
#endif // __INFINIOP_ELEMENTWISE_H__