Commit 55cd22e3 authored by Jiacheng Huang's avatar Jiacheng Huang Committed by wooway777
Browse files

issue/402 - convenient ninetoothed util

对 `NineToothedTensor` 进行 C++ 层封装

加入使用数组作为 `shape` 和 `strides` 创建 `ninetoothed::Tensor` 的方式

使用 `ninetoothed::Tensor` 接入九齿的 ReLU 算子

Add an include guard to `ninetoothed/utils.h`
parent 7c5aa160
#ifndef __NINETOOTHED_UTILS__
#define __NINETOOTHED_UTILS__
#include <initializer_list>
#include <limits>
#include <type_traits>
#include <vector>
namespace ninetoothed {
template <typename T = float>
class Tensor {
public:
using Data = decltype(NineToothedTensor::data);
using Size = std::remove_pointer_t<decltype(NineToothedTensor::shape)>;
using Stride = std::remove_pointer_t<decltype(NineToothedTensor::strides)>;
template <typename Shape, typename Strides>
Tensor(const void *data, Shape shape, Strides strides) : data_{data}, shape_{shape}, strides_{strides}, ndim_{shape_.size()} {}
Tensor(const void *data, std::initializer_list<Size> shape, std::initializer_list<Stride> strides) : Tensor{data, decltype(shape_){shape}, decltype(strides_){strides}} {}
Tensor(const void *data, const Size *shape, const Stride *strides, Size ndim) : data_{data}, shape_{shape, shape + ndim}, strides_{strides, strides + ndim}, ndim_{shape_.size()} {}
Tensor(const T value) : value_{value}, data_{&value_}, ndim_{0} {}
operator NineToothedTensor() { return {const_cast<Data>(data_), shape_.data(), strides_.data()}; }
template <typename Shape>
Tensor expand(const Shape &sizes) const {
auto new_ndim{sizes.size()};
decltype(shape_) shape(new_ndim, 1);
decltype(strides_) strides(new_ndim, 0);
auto num_new_dims{new_ndim - ndim_};
for (auto dim{decltype(ndim_){0}}; dim < ndim_; ++dim) {
shape[dim + num_new_dims] = shape_[dim];
strides[dim + num_new_dims] = strides_[dim];
}
for (auto dim{decltype(new_ndim){0}}; dim < new_ndim; ++dim) {
if (sizes[dim] == std::numeric_limits<std::remove_reference_t<decltype(sizes[dim])>>::max() || shape[dim] != 1) {
continue;
}
shape[dim] = sizes[dim];
strides[dim] = 0;
}
return {data_, shape, strides};
}
Tensor expand_as(const Tensor &other) const {
return expand(other.shape_);
}
private:
const void *data_{nullptr};
std::vector<Size> shape_;
std::vector<Stride> strides_;
Size ndim_{0};
T value_{0};
};
} // namespace ninetoothed
#endif
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "../../../../../build/ninetoothed/relu.h" #include "../../../../../build/ninetoothed/relu.h"
#include "../../../devices/metax/metax_common.h" #include "../../../devices/metax/metax_common.h"
#include "../../../ninetoothed/utils.h"
#include "relu_metax.h" #include "relu_metax.h"
namespace op::relu::metax { namespace op::relu::metax {
...@@ -42,22 +43,10 @@ infiniStatus_t Descriptor::calculate( ...@@ -42,22 +43,10 @@ infiniStatus_t Descriptor::calculate(
} }
const auto &ndim{_info.getNdim()}; const auto &ndim{_info.getNdim()};
const auto &x_shape_{_info.getInputShape(0)};
const auto &x_strides_{_info.getInputStrides(0)}; auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}};
std::vector<uint64_t> x_shape_vec{x_shape_, x_shape_ + ndim}; auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}};
std::vector<int64_t> x_strides_vec{x_strides_, x_strides_ + ndim};
auto x_data{const_cast<void *>(inputs[0])};
auto x_shape{x_shape_vec.data()};
auto x_strides{x_strides_vec.data()};
const NineToothedTensor x{x_data, x_shape, x_strides};
const auto &y_shape_{_info.getOutputShape()};
const auto &y_strides_{_info.getOutputStrides()};
std::vector<uint64_t> y_shape_vec{y_shape_, y_shape_ + ndim};
std::vector<int64_t> y_strides_vec{y_strides_, y_strides_ + ndim};
auto y_data{output};
auto y_shape{y_shape_vec.data()};
auto y_strides{y_strides_vec.data()};
const NineToothedTensor y{y_data, y_shape, y_strides};
constexpr auto block_size{1024}; constexpr auto block_size{1024};
switch (_dtype) { switch (_dtype) {
......
#ifdef ENABLE_NINETOOTHED #ifdef ENABLE_NINETOOTHED
#include "../../../../../build/ninetoothed/relu.h" #include "../../../../../build/ninetoothed/relu.h"
#include "../../../ninetoothed/utils.h"
#endif #endif
#include "../../../devices/nvidia/nvidia_common.cuh" #include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../elementwise/nvidia/elementwise_nvidia.cuh" #include "../../../elementwise/nvidia/elementwise_nvidia.cuh"
...@@ -45,22 +46,10 @@ infiniStatus_t Descriptor::calculate( ...@@ -45,22 +46,10 @@ infiniStatus_t Descriptor::calculate(
} }
#ifdef ENABLE_NINETOOTHED #ifdef ENABLE_NINETOOTHED
const auto &ndim{_info.getNdim()}; const auto &ndim{_info.getNdim()};
const auto &x_shape_{_info.getInputShape(0)};
const auto &x_strides_{_info.getInputStrides(0)}; auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}};
std::vector<uint64_t> x_shape_vec{x_shape_, x_shape_ + ndim}; auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}};
std::vector<int64_t> x_strides_vec{x_strides_, x_strides_ + ndim};
auto x_data{const_cast<void *>(inputs[0])};
auto x_shape{x_shape_vec.data()};
auto x_strides{x_strides_vec.data()};
const NineToothedTensor x{x_data, x_shape, x_strides};
const auto &y_shape_{_info.getOutputShape()};
const auto &y_strides_{_info.getOutputStrides()};
std::vector<uint64_t> y_shape_vec{y_shape_, y_shape_ + ndim};
std::vector<int64_t> y_strides_vec{y_strides_, y_strides_ + ndim};
auto y_data{output};
auto y_shape{y_shape_vec.data()};
auto y_strides{y_strides_vec.data()};
const NineToothedTensor y{y_data, y_shape, y_strides};
constexpr auto block_size{1024}; constexpr auto block_size{1024};
switch (_dtype) { switch (_dtype) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment