Commit 4809badf authored by Po-Yen, Chen
Browse files

Separate template parameters

parent 9a06e83e
...@@ -15,14 +15,13 @@ ...@@ -15,14 +15,13 @@
namespace ck { namespace ck {
namespace detail { namespace detail {
template <typename TileDims, typename GridDescriptor> template <index_t HPerBlock, index_t WPerBlock, typename GridDesc>
struct Block2TileMap struct Block2TileMap
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto NumDim = Number<GridDesc::GetNumOfDimension()>{};
static_assert(2 <= NumDim);
static constexpr index_t NumDim = TileDims::Size(); static constexpr auto I0 = Number<0>{};
static_assert(NumDim == 2);
static_assert(NumDim <= GridDescriptor::GetNumOfDimension());
Block2TileMap() = delete; Block2TileMap() = delete;
Block2TileMap(const Block2TileMap&) = default; Block2TileMap(const Block2TileMap&) = default;
...@@ -33,22 +32,16 @@ struct Block2TileMap ...@@ -33,22 +32,16 @@ struct Block2TileMap
Block2TileMap& operator=(const Block2TileMap&) = delete; Block2TileMap& operator=(const Block2TileMap&) = delete;
Block2TileMap& operator=(Block2TileMap&&) = delete; Block2TileMap& operator=(Block2TileMap&&) = delete;
explicit Block2TileMap(const GridDescriptor& desc) : desc_(desc) {} explicit Block2TileMap(const GridDesc& desc) : desc_(desc) {}
__host__ constexpr index_t CalculateGridSize(const GridDescriptor& desc) const __host__ constexpr index_t CalculateGridSize(const GridDesc& desc) const
{ {
return [&]() { const auto H0 = math::integer_divide_ceil(desc.GetLength(NumDim - Number<2>{}), HPerBlock);
std::array<index_t, 2> num_tiles_per_axis; const auto W0 = math::integer_divide_ceil(desc.GetLength(NumDim - Number<1>{}), WPerBlock);
static_for<NumDim - 2, NumDim, 1>{}([&](auto I) {
num_tiles_per_axis[I - (NumDim - 2)] = const index_t grid_size = H0 * W0;
math::integer_divide_ceil(desc.GetLength(I), TileDims::At(I - (NumDim - 2)));
}); return grid_size;
return std::accumulate(begin(num_tiles_per_axis),
end(num_tiles_per_axis),
index_t{1},
std::multiplies<index_t>{});
}();
} }
template <typename TopIdx> template <typename TopIdx>
...@@ -58,34 +51,17 @@ struct Block2TileMap ...@@ -58,34 +51,17 @@ struct Block2TileMap
auto block_1d_id = idx_top[I0]; auto block_1d_id = idx_top[I0];
std::array<index_t, 2> num_tiles_per_axis; const auto H0 = math::integer_divide_ceil(desc_.GetLength(NumDim - Number<2>{}), HPerBlock);
static_for<NumDim - 2, NumDim, 1>{}([&](auto I) { const auto W0 = math::integer_divide_ceil(desc_.GetLength(NumDim - Number<1>{}), WPerBlock);
num_tiles_per_axis[I - (NumDim - 2)] =
math::integer_divide_ceil(desc_.GetLength(I), TileDims::At(I - (NumDim - 2)));
});
std::array<index_t, 2> divisors;
index_t product = 1;
auto divisor = rbegin(divisors);
for(auto num_tiles = rbegin(num_tiles_per_axis); num_tiles != rend(num_tiles_per_axis);
++num_tiles)
{
product *= (*num_tiles);
*(divisor++) = product;
}
const index_t grid_size = divisors.front(); index_t idx_H0 = block_1d_id / W0;
block_1d_id = block_1d_id % grid_size; // swallow batch index index_t idx_W0 = block_1d_id % W0;
return generate_tuple( return make_tuple(idx_H0, idx_W0);
[&](auto I) {
return (block_1d_id % divisors[I]) / (divisors[I] / num_tiles_per_axis[I]);
},
Number<2>{});
} }
private: private:
const GridDescriptor desc_; const GridDesc desc_;
}; };
} // namespace detail } // namespace detail
...@@ -137,7 +113,7 @@ struct GridwisePermute ...@@ -137,7 +113,7 @@ struct GridwisePermute
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using DefaultBlock2TileMap = detail::Block2TileMap<Sequence<HPerBlock, WPerBlock>, InGridDesc>; using DefaultBlock2TileMap = detail::Block2TileMap<HPerBlock, WPerBlock, InGridDesc>;
__host__ __device__ static constexpr auto GetInBlockDescriptor() __host__ __device__ static constexpr auto GetInBlockDescriptor()
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment