Commit 1b648f2f authored by Chao Liu's avatar Chao Liu
Browse files

use constant tensor descriptor

parent 9657baec
...@@ -3,8 +3,14 @@ ...@@ -3,8 +3,14 @@
#include <initializer_list> #include <initializer_list>
#include "nvToolsExt.h" #include "nvToolsExt.h"
#include "tensor.hpp" #include "tensor.hpp"
#include "device_tensor.cuh" #include "constant_tensor_descriptor.cuh"
#include "device_tensor_descriptor.cuh"
#if 0
#include "direct_convolution.cuh" #include "direct_convolution.cuh"
#else
#include "constant_direct_convolution.cuh"
#endif
template <class T> template <class T>
struct GeneratorConstant struct GeneratorConstant
...@@ -38,11 +44,46 @@ struct GeneratorTensor ...@@ -38,11 +44,46 @@ struct GeneratorTensor
} }
}; };
template <typename T> // this is ugly, only for 4d
void host_convolution(const Tensor<T>& in, template <class TConstTensorDesc>
const Tensor<T>& wei, void ostream_ConstantTensorDescriptor(TConstTensorDesc, std::ostream& os = std::cout)
Tensor<T>& out, {
std::size_t num_thread) static_assert(TConstTensorDesc::nDim == 4, "nDim is not 4");
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
constexpr auto I2 = Index<2>{};
constexpr auto I3 = Index<3>{};
constexpr auto desc = TConstTensorDesc{};
os << "Lengths: {" << desc.GetLength(I0) << ", " << desc.GetLength(I1) << ", "
<< desc.GetLength(I2) << ", " << desc.GetLength(I3) << "}, "
<< "Strides: {" << desc.GetStride(I0) << ", " << desc.GetStride(I1) << ", "
<< desc.GetStride(I2) << ", " << desc.GetStride(I3) << "}" << std::endl;
}
// this is ugly, only for 4d
template <class TConstTensorDesc>
auto make_TensorDescriptor(TConstTensorDesc)
{
static_assert(TConstTensorDesc::nDim == 4, "nDim is not 4");
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
constexpr auto I2 = Index<2>{};
constexpr auto I3 = Index<3>{};
constexpr auto desc = TConstTensorDesc{};
std::initializer_list<unsigned> lengths = {
desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3)};
std::initializer_list<unsigned> strides = {
desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3)};
return TensorDescriptor(lengths, strides);
}
template <class T>
void host_convolution(const Tensor<T>& in, const Tensor<T>& wei, Tensor<T>& out)
{ {
auto f = [&](auto n, auto k, auto ho, auto wo) { auto f = [&](auto n, auto k, auto ho, auto wo) {
double v = 0; double v = 0;
...@@ -67,12 +108,12 @@ void host_convolution(const Tensor<T>& in, ...@@ -67,12 +108,12 @@ void host_convolution(const Tensor<T>& in,
out.mDesc.GetLengths()[2], out.mDesc.GetLengths()[2],
out.mDesc.GetLengths()[3]); out.mDesc.GetLengths()[3]);
f_par(num_thread); f_par(std::thread::hardware_concurrency());
} }
template <class T> template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution(const Tensor<T>& in, const Tensor<T>& wei, Tensor<T>& out) void device_convolution(
InDesc, const Tensor<T>& in, WeiDesc, const Tensor<T>& wei, OutDesc, Tensor<T>& out)
{ {
DeviceTensorDescriptor<4> in_desc_device(in.mDesc); DeviceTensorDescriptor<4> in_desc_device(in.mDesc);
DeviceTensorDescriptor<4> wei_desc_device(wei.mDesc); DeviceTensorDescriptor<4> wei_desc_device(wei.mDesc);
...@@ -103,6 +144,7 @@ void device_convolution(const Tensor<T>& in, const Tensor<T>& wei, Tensor<T>& ou ...@@ -103,6 +144,7 @@ void device_convolution(const Tensor<T>& in, const Tensor<T>& wei, Tensor<T>& ou
dim3 block_dim(64, 1, 1); dim3 block_dim(64, 1, 1);
dim3 grid_dim(1, 1, 1); dim3 grid_dim(1, 1, 1);
#if 0
gridwise_convolution<T, 3, 3, 4, 4, 2, 2, 1, 1, 8, 8, 1> gridwise_convolution<T, 3, 3, 4, 4, 2, 2, 1, 1, 8, 8, 1>
<<<grid_dim, block_dim>>>(in_desc_device, <<<grid_dim, block_dim>>>(in_desc_device,
static_cast<T*>(in_device_buf.GetDeviceBuffer()), static_cast<T*>(in_device_buf.GetDeviceBuffer()),
...@@ -110,6 +152,15 @@ void device_convolution(const Tensor<T>& in, const Tensor<T>& wei, Tensor<T>& ou ...@@ -110,6 +152,15 @@ void device_convolution(const Tensor<T>& in, const Tensor<T>& wei, Tensor<T>& ou
static_cast<T*>(wei_device_buf.GetDeviceBuffer()), static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
out_desc_device, out_desc_device,
static_cast<T*>(out_device_buf.GetDeviceBuffer())); static_cast<T*>(out_device_buf.GetDeviceBuffer()));
#else
gridwise_convolution<T, InDesc, WeiDesc, OutDesc, 4, 4, 2, 2, 1, 1, 8, 8, 1>
<<<grid_dim, block_dim>>>(InDesc{},
static_cast<T*>(in_device_buf.GetDeviceBuffer()),
WeiDesc{},
static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
OutDesc{},
static_cast<T*>(out_device_buf.GetDeviceBuffer()));
#endif
checkCudaErrors(cudaGetLastError()); checkCudaErrors(cudaGetLastError());
out_device_buf.FromDevice(out.mData.data()); out_device_buf.FromDevice(out.mData.data());
...@@ -117,34 +168,53 @@ void device_convolution(const Tensor<T>& in, const Tensor<T>& wei, Tensor<T>& ou ...@@ -117,34 +168,53 @@ void device_convolution(const Tensor<T>& in, const Tensor<T>& wei, Tensor<T>& ou
int main() int main()
{ {
#if 0 #if 1
Tensor<float> in({3, 16, 130, 130}); constexpr unsigned N = 1;
Tensor<float> wei({4, 16, 3, 3}); constexpr unsigned C = 1;
Tensor<float> out_host({3, 4, 128, 128}); constexpr unsigned HI = 18;
constexpr unsigned WI = 18;
constexpr unsigned K = 1;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#elif 0 #elif 0
Tensor<float> in({1, 1, 130, 130}); constexpr unsigned N = 1;
Tensor<float> wei({1, 1, 3, 3}); constexpr unsigned C = 1;
Tensor<float> out_host({1, 1, 128, 128}); constexpr unsigned HI = 130;
#elif 1 constexpr unsigned WI = 130;
Tensor<float> in({1, 1, 18, 18}); constexpr unsigned K = 1;
Tensor<float> wei({1, 1, 3, 3}); constexpr unsigned S = 3;
Tensor<float> out_host({1, 1, 16, 16}); constexpr unsigned R = 3;
#else #elif 0
Tensor<float> in({1, 1, 4, 4}); constexpr unsigned N = 3;
Tensor<float> wei({1, 1, 3, 3}); constexpr unsigned C = 16;
Tensor<float> out_host({1, 1, 2, 2}); constexpr unsigned HI = 130;
constexpr unsigned WI = 130;
constexpr unsigned K = 4;
constexpr unsigned S = 3;
constexpr unsigned R = 3;
#endif #endif
auto in_desc = make_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{});
auto wei_desc = make_ConstantTensorDescriptor(Sequence<K, C, S, R>{});
auto out_desc = get_output_4d_tensor_descriptor(in_desc, wei_desc);
ostream_ConstantTensorDescriptor(in_desc, std::cout << "in_desc: ");
ostream_ConstantTensorDescriptor(wei_desc, std::cout << "wei_desc: ");
ostream_ConstantTensorDescriptor(out_desc, std::cout << "out_desc: ");
Tensor<float> in(make_TensorDescriptor(in_desc));
Tensor<float> wei(make_TensorDescriptor(wei_desc));
Tensor<float> out_host(make_TensorDescriptor(out_desc));
Tensor<float> out_device = out_host; Tensor<float> out_device = out_host;
int num_thread = std::thread::hardware_concurrency(); int num_thread = std::thread::hardware_concurrency();
std::cout << __func__ << ": num_thread " << num_thread << std::endl;
in.GenerateTensorValue(GeneratorTensor<float>{}, num_thread); in.GenerateTensorValue(GeneratorTensor<float>{}, num_thread);
wei.GenerateTensorValue(GeneratorTensor<float>{}, num_thread); wei.GenerateTensorValue(GeneratorTensor<float>{}, num_thread);
host_convolution(in, wei, out_host, num_thread); host_convolution(in, wei, out_host);
device_convolution(in, wei, out_device); device_convolution(in_desc, in, wei_desc, wei, out_desc, out_device);
std::cout << __func__ << ": done" << std::endl; std::cout << __func__ << ": done" << std::endl;
......
This diff is collapsed.
#pragma once
#include "helper_cuda.h"
template <class T, T N>
struct Constant
{
const T mValue = N;
};
template <unsigned I>
using Index = Constant<unsigned, I>;
template <unsigned... Is>
struct Sequence
{
static constexpr unsigned nDim = sizeof...(Is);
const unsigned mData[nDim] = {Is...};
template <unsigned I>
__host__ __device__ constexpr unsigned Get(Index<I>) const
{
return mData[I];
}
};
#if 0
template<class F, class T, T... Is>
void for_each(F f, std::integer_sequence<T, Is...>)
{
f(Is)...;
}
template<class F, class T, T N>
void for_n_time(F f, Constant<T, N>)
{
for_each(f, std::make_integer_sequence<T, N>{});
}
#endif
template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
static constexpr unsigned nDim = Lengths::nDim;
using NDimConstant = Index<nDim>;
__host__ __device__ constexpr ConstantTensorDescriptor()
{
static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
}
__host__ __device__ constexpr unsigned GetDimension() const { return nDim; }
__host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }
__host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }
template <unsigned I>
__host__ __device__ constexpr unsigned GetLength(Index<I>) const
{
return Lengths{}.Get(Index<I>{});
}
template <unsigned I>
__host__ __device__ constexpr unsigned GetStride(Index<I>) const
{
return Strides{}.Get(Index<I>{});
}
#if 0
template <class... Is>
__host__ __device__ unsigned Get1dIndex(Is... is) const
{
static_assert(nDim == sizeof...(Is), "nDim not consistent");
const unsigned iss[nDim] = {static_cast<unsigned>(is)...};
unsigned idx = 0;
for_n_time([&](auto iDim) { idx += iss[iDim] * GetStride<iDim>(); }, NDimConstant{});
return idx;
}
#elif 1
// this is ugly, only for 4d
__host__ __device__ unsigned Get1dIndex(unsigned n, unsigned c, unsigned h, unsigned w) const
{
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
constexpr auto I2 = Index<2>{};
constexpr auto I3 = Index<3>{};
static_assert(nDim == 4, "nDim not consistent");
return n * GetStride(I0) + c * GetStride(I1) + h * GetStride(I2) + w * GetStride(I3);
}
#endif
};
// this is ugly, only for 4d
template <unsigned N, unsigned C, unsigned H, unsigned W>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<N, C, H, W>)
{
return Sequence<C * H * W, H * W, W, 1>{};
}
template <class Lengths>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths)
{
using Strides = decltype(calculate_default_strides(Lengths{}));
return ConstantTensorDescriptor<Lengths, Strides>{};
}
template <class Lengths, class Strides>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides)
{
return ConstantTensorDescriptor<Lengths, Strides>{};
}
// this is ugly, only for 4d
template <class InDesc, class WeiDesc>
__host__ __device__ constexpr auto get_output_4d_tensor_descriptor(InDesc, WeiDesc)
{
constexpr auto in_desc = InDesc{};
constexpr auto wei_desc = WeiDesc{};
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
constexpr auto I2 = Index<2>{};
constexpr auto I3 = Index<3>{};
static_assert(in_desc.GetDimension() == 4, "input nDim is not 4");
static_assert(wei_desc.GetDimension() == 4, "weight nDim is not 4");
static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
"input & weight dimension not consistent");
constexpr auto N = in_desc.GetLength(I0);
constexpr auto HI = in_desc.GetLength(I2);
constexpr auto WI = in_desc.GetLength(I3);
constexpr auto K = wei_desc.GetLength(I0);
constexpr auto S = wei_desc.GetLength(I2);
constexpr auto R = wei_desc.GetLength(I3);
constexpr auto HO = HI - S + 1;
constexpr auto WO = WI - R + 1;
return make_ConstantTensorDescriptor(Sequence<N, K, HO, WO>{});
}
// this is ugly, only for 4d
template <class TDesc>
__host__ __device__ void print_ConstantTensorDescriptor(TDesc)
{
constexpr auto desc = TDesc{};
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
constexpr auto I2 = Index<2>{};
constexpr auto I3 = Index<3>{};
static_assert(desc.GetDimension() == 4, "dim is not 4");
printf("dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
desc.GetDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
desc.GetLength(I3),
desc.GetStride(I0),
desc.GetStride(I1),
desc.GetStride(I2),
desc.GetStride(I3));
}
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include "constant_tensor_descriptor.cuh"
#include "helper_cuda.h" #include "helper_cuda.h"
#include "tensor.hpp" #include "tensor.hpp"
...@@ -19,7 +20,7 @@ struct DeviceTensorDescriptor ...@@ -19,7 +20,7 @@ struct DeviceTensorDescriptor
__host__ __device__ unsigned GetStride(unsigned i) const { return mpStrides[i]; } __host__ __device__ unsigned GetStride(unsigned i) const { return mpStrides[i]; }
// this is ugly // this is ugly, only for 4d
__host__ __device__ unsigned Get1dIndex(unsigned n, unsigned c, unsigned h, unsigned w) const __host__ __device__ unsigned Get1dIndex(unsigned n, unsigned c, unsigned h, unsigned w) const
{ {
return n * mpStrides[0] + c * mpStrides[1] + h * mpStrides[2] + w * mpStrides[3]; return n * mpStrides[0] + c * mpStrides[1] + h * mpStrides[2] + w * mpStrides[3];
...@@ -28,3 +29,32 @@ struct DeviceTensorDescriptor ...@@ -28,3 +29,32 @@ struct DeviceTensorDescriptor
unsigned mpLengths[NDim]; unsigned mpLengths[NDim];
unsigned mpStrides[NDim]; unsigned mpStrides[NDim];
}; };
// this is ugly, only for 4d
template <class TConstTensorDesc>
__host__ __device__ auto make_DeviceTensorDescriptor(TConstTensorDesc)
{
static_assert(TConstTensorDesc::nDim == 4, "nDim is not 4");
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
constexpr auto I2 = Index<2>{};
constexpr auto I3 = Index<3>{};
constexpr auto const_desc = TConstTensorDesc{};
constexpr auto ndim = const_desc.GetDimension();
auto desc = DeviceTensorDescriptor<ndim>{};
desc.mpLengths[0] = const_desc.GetLength(I0);
desc.mpLengths[1] = const_desc.GetLength(I1);
desc.mpLengths[2] = const_desc.GetLength(I2);
desc.mpLengths[3] = const_desc.GetLength(I3);
desc.mpStrides[0] = const_desc.GetStride(I0);
desc.mpStrides[1] = const_desc.GetStride(I1);
desc.mpStrides[2] = const_desc.GetStride(I2);
desc.mpStrides[3] = const_desc.GetStride(I3);
return desc;
}
#pragma once #pragma once
#include "device_tensor.cuh" #include "device_tensor_descriptor.cuh"
template <class TFloat, template <class TFloat,
unsigned NWorkLen0, unsigned NWorkLen0,
...@@ -13,7 +13,7 @@ __device__ void blockwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_desc ...@@ -13,7 +13,7 @@ __device__ void blockwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_desc
TFloat* __restrict__ p_dst, TFloat* __restrict__ p_dst,
F f) F f)
{ {
#if 1 #if 0
if(threadIdx.x == 0) if(threadIdx.x == 0)
{ {
printf("blockwise_4d_tensor_op: 0: \t" printf("blockwise_4d_tensor_op: 0: \t"
...@@ -80,7 +80,7 @@ __device__ void blockwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_desc ...@@ -80,7 +80,7 @@ __device__ void blockwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_desc
f(p_src[dindex], p_dst[sindex]); f(p_src[dindex], p_dst[sindex]);
#if 1 #if 0
// if(threadIdx.x == 0) // if(threadIdx.x == 0)
{ {
printf("blockwise_4d_tensor_op: 1: thread id %u, \t" printf("blockwise_4d_tensor_op: 1: thread id %u, \t"
...@@ -106,7 +106,7 @@ __device__ void threadwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_des ...@@ -106,7 +106,7 @@ __device__ void threadwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_des
TFloat* __restrict__ p_dst, TFloat* __restrict__ p_dst,
F f) F f)
{ {
#if 1 #if 0
if(threadIdx.x == 0) if(threadIdx.x == 0)
{ {
printf("threadwise_4d_tensor_op: 0: \t" printf("threadwise_4d_tensor_op: 0: \t"
...@@ -151,7 +151,7 @@ __device__ void threadwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_des ...@@ -151,7 +151,7 @@ __device__ void threadwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_des
f(p_src[sindex], p_dst[dindex]); f(p_src[sindex], p_dst[dindex]);
#if 1 #if 0
if(threadIdx.x == 0) if(threadIdx.x == 0)
{ {
printf("threadwise_4d_tensor_op: 1: thread id %u, \t" printf("threadwise_4d_tensor_op: 1: thread id %u, \t"
...@@ -178,7 +178,7 @@ __device__ void threadwise_direct_convolution(const DeviceTensorDescriptor<4>& i ...@@ -178,7 +178,7 @@ __device__ void threadwise_direct_convolution(const DeviceTensorDescriptor<4>& i
const DeviceTensorDescriptor<4>& out_desc, const DeviceTensorDescriptor<4>& out_desc,
TFloat* __restrict__ p_out) TFloat* __restrict__ p_out)
{ {
#if 1 #if 0
if(threadIdx.x == 0) if(threadIdx.x == 0)
{ {
printf("threadwise_direct_convolution: 0: \t" printf("threadwise_direct_convolution: 0: \t"
...@@ -212,7 +212,7 @@ __device__ void threadwise_direct_convolution(const DeviceTensorDescriptor<4>& i ...@@ -212,7 +212,7 @@ __device__ void threadwise_direct_convolution(const DeviceTensorDescriptor<4>& i
out_desc.GetStride(2), out_desc.GetStride(2),
out_desc.GetStride(3)); out_desc.GetStride(3));
} }
#elif 1 #elif 0
{ {
printf("threadwise_direct_convolution: 0: \t" printf("threadwise_direct_convolution: 0: \t"
"threadIdx.x %u \t" "threadIdx.x %u \t"
...@@ -275,7 +275,7 @@ __device__ void threadwise_direct_convolution(const DeviceTensorDescriptor<4>& i ...@@ -275,7 +275,7 @@ __device__ void threadwise_direct_convolution(const DeviceTensorDescriptor<4>& i
p_out[out_index] += p_wei[wei_index] * p_in[in_index]; p_out[out_index] += p_wei[wei_index] * p_in[in_index];
#if 1 #if 0
if(threadIdx.x == 0) if(threadIdx.x == 0)
{ {
printf("threadwise_direct_convolution: 1: \t" printf("threadwise_direct_convolution: 1: \t"
...@@ -320,7 +320,7 @@ __device__ void blockwise_convolution(const DeviceTensorDescriptor<4>& in_desc, ...@@ -320,7 +320,7 @@ __device__ void blockwise_convolution(const DeviceTensorDescriptor<4>& in_desc,
const DeviceTensorDescriptor<4>& out_desc, const DeviceTensorDescriptor<4>& out_desc,
TFloat* __restrict__ p_out) TFloat* __restrict__ p_out)
{ {
#if 1 #if 0
if(threadIdx.x == 0) if(threadIdx.x == 0)
{ {
printf("blockwise_convolution: 0: \t" printf("blockwise_convolution: 0: \t"
...@@ -501,7 +501,7 @@ __global__ void gridwise_convolution(const DeviceTensorDescriptor<4> in_desc, ...@@ -501,7 +501,7 @@ __global__ void gridwise_convolution(const DeviceTensorDescriptor<4> in_desc,
const DeviceTensorDescriptor<4> out_desc, const DeviceTensorDescriptor<4> out_desc,
TFloat* __restrict__ p_out) TFloat* __restrict__ p_out)
{ {
#if 1 #if 0
if(threadIdx.x == 0) if(threadIdx.x == 0)
{ {
printf("gridwise_convolution: 0: \t" printf("gridwise_convolution: 0: \t"
......
...@@ -69,28 +69,25 @@ auto construct_f_unpack_args(F, T args) ...@@ -69,28 +69,25 @@ auto construct_f_unpack_args(F, T args)
struct TensorDescriptor struct TensorDescriptor
{ {
TensorDescriptor() = delete; TensorDescriptor() = delete;
TensorDescriptor(DataType_t t, std::initializer_list<std::size_t> lens); TensorDescriptor(std::initializer_list<std::size_t> lens);
TensorDescriptor(DataType_t t, TensorDescriptor(std::initializer_list<std::size_t> lens,
std::initializer_list<std::size_t> lens,
std::initializer_list<std::size_t> strides); std::initializer_list<std::size_t> strides);
TensorDescriptor(DataType_t t, std::vector<std::size_t> lens, std::vector<std::size_t> strides); TensorDescriptor(std::vector<std::size_t> lens, std::vector<std::size_t> strides);
void CalculateStrides(); void CalculateStrides();
template <class Range> template <class Range>
TensorDescriptor(DataType_t t, const Range& lens) TensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end())
: mLens(lens.begin(), lens.end()), mDataType(t)
{ {
this->CalculateStrides(); this->CalculateStrides();
} }
template <class Range1, class Range2> template <class Range1, class Range2>
TensorDescriptor(DataType_t t, const Range1& lens, const Range2& strides) TensorDescriptor(const Range1& lens, const Range2& strides)
: mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()), mDataType(t) : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
{ {
} }
DataType_t GetDataType() const;
std::size_t GetDimension() const; std::size_t GetDimension() const;
std::size_t GetElementSize() const; std::size_t GetElementSize() const;
std::size_t GetElementSpace() const; std::size_t GetElementSpace() const;
...@@ -107,7 +104,6 @@ struct TensorDescriptor ...@@ -107,7 +104,6 @@ struct TensorDescriptor
} }
private: private:
DataType_t mDataType;
std::vector<std::size_t> mLens; std::vector<std::size_t> mLens;
std::vector<std::size_t> mStrides; std::vector<std::size_t> mStrides;
}; };
...@@ -220,22 +216,23 @@ template <class T> ...@@ -220,22 +216,23 @@ template <class T>
struct Tensor struct Tensor
{ {
template <class X> template <class X>
Tensor(std::initializer_list<X> lens) Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpace())
: mDesc(DataType<T>{}, lens), mData(mDesc.GetElementSpace())
{ {
} }
template <class X> template <class X>
Tensor(std::vector<X> lens) : mDesc(DataType<T>{}, lens), mData(mDesc.GetElementSpace()) Tensor(std::vector<X> lens) : mDesc(lens), mData(mDesc.GetElementSpace())
{ {
} }
template <class X, class Y> template <class X, class Y>
Tensor(std::vector<X> lens, std::vector<Y> strides) Tensor(std::vector<X> lens, std::vector<Y> strides)
: mDesc(DataType<T>{}, lens, strides), mData(mDesc.GetElementSpace()) : mDesc(lens, strides), mData(mDesc.GetElementSpace())
{ {
} }
Tensor(const TensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {}
template <class G> template <class G>
void GenerateTensorValue(G g, std::size_t num_thread = 1) void GenerateTensorValue(G g, std::size_t num_thread = 1)
{ {
......
...@@ -3,16 +3,13 @@ ...@@ -3,16 +3,13 @@
#include "tensor.hpp" #include "tensor.hpp"
TensorDescriptor::TensorDescriptor(DataType_t t, std::initializer_list<std::size_t> lens) TensorDescriptor::TensorDescriptor(std::initializer_list<std::size_t> lens) : mLens(lens)
: mLens(lens), mDataType(t)
{ {
this->CalculateStrides(); this->CalculateStrides();
} }
TensorDescriptor::TensorDescriptor(DataType_t t, TensorDescriptor::TensorDescriptor(std::vector<std::size_t> lens, std::vector<std::size_t> strides)
std::vector<std::size_t> lens, : mLens(lens), mStrides(strides)
std::vector<std::size_t> strides)
: mLens(lens), mStrides(strides), mDataType(t)
{ {
} }
...@@ -28,8 +25,6 @@ void TensorDescriptor::CalculateStrides() ...@@ -28,8 +25,6 @@ void TensorDescriptor::CalculateStrides()
mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies<std::size_t>()); mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies<std::size_t>());
} }
DataType_t TensorDescriptor::GetDataType() const { return mDataType; }
std::size_t TensorDescriptor::GetDimension() const { return mLens.size(); } std::size_t TensorDescriptor::GetDimension() const { return mLens.size(); }
std::size_t TensorDescriptor::GetElementSize() const std::size_t TensorDescriptor::GetElementSize() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment