#include template struct DeviceTensor { public: inline __device__ __host__ DeviceTensor(DType *p, const int *size) : dptr_(p) { for (int i = 0; i < Dim; ++i) { size_[i] = size ? size[i] : 0; } } inline __device__ __host__ unsigned getSize(const int i) const { assert(i < Dim); return size_[i]; } inline __device__ __host__ int numElements() const { int n = 1; for (int i = 0; i < Dim; ++i) { n *= size_[i]; } return n; } inline __device__ __host__ DeviceTensor select(const size_t x) const { assert(Dim > 1); int offset = x; for (int i = 1; i < Dim; ++i) { offset *= size_[i]; } DeviceTensor tensor(dptr_ + offset, nullptr); for (int i = 0; i < Dim - 1; ++i) { tensor.size_[i] = this->size_[i+1]; } return tensor; } inline __device__ __host__ DeviceTensor operator[](const size_t x) const { assert(Dim > 1); int offset = x; for (int i = 1; i < Dim; ++i) { offset *= size_[i]; } DeviceTensor tensor(dptr_ + offset, nullptr); for (int i = 0; i < Dim - 1; ++i) { tensor.size_[i] = this->size_[i+1]; } return tensor; } inline __device__ __host__ size_t InnerSize() const { assert(Dim >= 3); size_t sz = 1; for (size_t i = 2; i < Dim; ++i) { sz *= size_[i]; } return sz; } inline __device__ __host__ size_t ChannelCount() const { assert(Dim >= 3); return size_[1]; } inline __device__ __host__ DType* data_ptr() const { return dptr_; } DType *dptr_; int size_[Dim]; }; template struct DeviceTensor { inline __device__ __host__ DeviceTensor(DType *p, const int *size) : dptr_(p) { size_[0] = size ? size[0] : 0; } inline __device__ __host__ unsigned getSize(const int i) const { assert(i == 0); return size_[0]; } inline __device__ __host__ int numElements() const { return size_[0]; } inline __device__ __host__ DType &operator[](const size_t x) const { return *(dptr_ + x); } inline __device__ __host__ DType* data_ptr() const { return dptr_; } DType *dptr_; int size_[1]; }; template static DeviceTensor devicetensor(const at::Tensor &blob) { DType *data = blob.data_ptr(); DeviceTensor tensor(data, nullptr); for (int i = 0; i < Dim; ++i) { tensor.size_[i] = blob.size(i); } return tensor; }