#pragma once
#include "common.cuh"

// Compile-time descriptor of a 2-D matrix whose element (irow, icol) lives at
// linear offset irow * RowStride_ + icol (see Get1dIndex).
//
// Template parameters:
//   NRow_      - number of rows
//   NCol_      - number of columns
//   RowStride_ - distance, in elements, between the starts of consecutive rows;
//                must be >= NCol_ (enforced by the static_assert in the constructor)
//
// NOTE(review): the template parameter / argument lists in this header were
// restored from how the names are used in the bodies; `unsigned` is taken from
// the accessor return types. Confirm against Number / Sequence in common.cuh.
template <unsigned NRow_, unsigned NCol_, unsigned RowStride_>
struct ConstantMatrixDescriptor
{
    // constexpr so that `ConstantMatrixDescriptor<...>{}` is usable inside the
    // constexpr factory functions below; the body only validates the parameters.
    __host__ __device__ constexpr ConstantMatrixDescriptor()
    {
        static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!");
    }

    // Number of rows.
    __host__ __device__ constexpr unsigned NRow() const { return NRow_; }

    // Number of columns.
    __host__ __device__ constexpr unsigned NCol() const { return NCol_; }

    // Element distance between the starts of two consecutive rows.
    __host__ __device__ constexpr unsigned RowStride() const { return RowStride_; }

    // Matrix extents as a compile-time Sequence.
    // NOTE(review): arguments assumed to be <NRow_, NCol_> - confirm.
    __host__ __device__ constexpr auto GetLengths() const { return Sequence<NRow_, NCol_>{}; }

    // Number of addressable elements (NRow_ * NCol_); excludes any gap between
    // the end of one row and the start of the next.
    __host__ __device__ constexpr unsigned GetElementSize() const { return NRow_ * NCol_; }

    // Number of elements spanned when every row occupies RowStride_ slots
    // (NRow_ * RowStride_).
    __host__ __device__ constexpr unsigned GetElementSpace() const { return NRow_ * RowStride_; }

    // Linear offset of element (irow, icol). No bounds checking is performed.
    __host__ __device__ constexpr unsigned Get1dIndex(unsigned irow, unsigned icol) const
    {
        return irow * RowStride_ + icol;
    }

    // Descriptor for a SubNRow x SubNCol matrix; the extents are deduced from
    // the two Number tags.
    // NOTE(review): the result is assumed to keep this descriptor's RowStride_
    // (a window into the same storage) - confirm.
    template <unsigned SubNRow, unsigned SubNCol>
    __host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
                                                               Number<SubNCol>) const
    {
        return ConstantMatrixDescriptor<SubNRow, SubNCol, RowStride_>{};
    }
};

// Builds a descriptor from row/column counts only.
// NOTE(review): the row stride is assumed to equal NCol (no gap between rows),
// since the three-argument overload exists for an explicit stride - confirm.
template <unsigned NRow, unsigned NCol>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
{
    return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
}

// Builds a descriptor with an explicit row stride.
template <unsigned NRow, unsigned NCol, unsigned RowStride>
__host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
{
    return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
}

// Prints the descriptor's row count, column count and row stride, prefixed by `s`.
// The first argument is used only for its type; a fresh TDesc is
// default-constructed to query the values.
template <class TDesc>
__host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s)
{
    const auto desc = TDesc{};

    printf("%s NRow %u NCol %u RowStride %u\n", s, desc.NRow(), desc.NCol(), desc.RowStride());
}