Commit f17fa4d7 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Move 'Block2TileMap' definition into 'GridwisePermute'

parent b681fc26
......@@ -14,61 +14,6 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace detail {
template <index_t NPerBlock, index_t HPerBlock, index_t WPerBlock, typename GridDesc>
struct GridwisePermuteBlock2TileMap
{
static constexpr index_t NumDim = GridDesc::GetNumOfDimension();
static_assert(3 <= NumDim);
static constexpr auto I0 = Number<0>{};
GridwisePermuteBlock2TileMap() = delete;
GridwisePermuteBlock2TileMap(const GridwisePermuteBlock2TileMap&) = default;
GridwisePermuteBlock2TileMap(GridwisePermuteBlock2TileMap&&) = delete;
~GridwisePermuteBlock2TileMap() = default;
GridwisePermuteBlock2TileMap& operator=(const GridwisePermuteBlock2TileMap&) = delete;
GridwisePermuteBlock2TileMap& operator=(GridwisePermuteBlock2TileMap&&) = delete;
explicit GridwisePermuteBlock2TileMap(const GridDesc& desc) : desc_(desc) {}
__host__ constexpr index_t CalculateGridSize(const GridDesc& desc) const
{
const auto N0 = math::integer_divide_ceil(desc.GetLength(Number<NumDim - 3>{}), NPerBlock);
const auto H0 = math::integer_divide_ceil(desc.GetLength(Number<NumDim - 2>{}), HPerBlock);
const auto W0 = math::integer_divide_ceil(desc.GetLength(Number<NumDim - 1>{}), WPerBlock);
const index_t grid_size = N0 * H0 * W0;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
static_assert(TopIdx::Size() == 1);
auto block_1d_id = idx_top[I0];
const auto N0 = math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 3>{}), NPerBlock);
const auto H0 = math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 2>{}), HPerBlock);
const auto W0 = math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 1>{}), WPerBlock);
block_1d_id = block_1d_id % (N0 * H0 * W0);
index_t idx_N0 = block_1d_id / (H0 * W0);
index_t idx_H0 = (block_1d_id % (H0 * W0)) / W0;
index_t idx_W0 = block_1d_id % W0;
return make_tuple(idx_N0, idx_H0, idx_W0);
}
private:
const GridDesc desc_;
};
} // namespace detail
template <typename GridwisePermute,
typename InGridDesc,
......@@ -127,8 +72,66 @@ struct GridwisePermute
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using DefaultBlock2TileMap =
detail::GridwisePermuteBlock2TileMap<NPerBlock, HPerBlock, WPerBlock, InGridDesc>;
struct Block2TileMap
{
static constexpr index_t NumDim = InGridDesc::GetNumOfDimension();
static_assert(3 <= NumDim);
static constexpr auto I0 = Number<0>{};
Block2TileMap() = delete;
Block2TileMap(const Block2TileMap&) = default;
Block2TileMap(Block2TileMap&&) = delete;
~Block2TileMap() = default;
Block2TileMap& operator=(const Block2TileMap&) = delete;
Block2TileMap& operator=(Block2TileMap&&) = delete;
explicit Block2TileMap(const InGridDesc& desc) : desc_(desc) {}
__host__ constexpr index_t CalculateGridSize(const InGridDesc& desc) const
{
const auto N0 =
math::integer_divide_ceil(desc.GetLength(Number<NumDim - 3>{}), NPerBlock);
const auto H0 =
math::integer_divide_ceil(desc.GetLength(Number<NumDim - 2>{}), HPerBlock);
const auto W0 =
math::integer_divide_ceil(desc.GetLength(Number<NumDim - 1>{}), WPerBlock);
const index_t grid_size = N0 * H0 * W0;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
static_assert(TopIdx::Size() == 1);
auto block_1d_id = idx_top[I0];
const auto N0 =
math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 3>{}), NPerBlock);
const auto H0 =
math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 2>{}), HPerBlock);
const auto W0 =
math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 1>{}), WPerBlock);
block_1d_id = block_1d_id % (N0 * H0 * W0);
index_t idx_N0 = block_1d_id / (H0 * W0);
index_t idx_H0 = (block_1d_id % (H0 * W0)) / W0;
index_t idx_W0 = block_1d_id % W0;
return make_tuple(idx_N0, idx_H0, idx_W0);
}
private:
const InGridDesc desc_;
};
using DefaultBlock2TileMap = Block2TileMap;
__host__ __device__ static constexpr auto GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock()
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment