Commit 4809badf authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Seperate template parameters

parent 9a06e83e
......@@ -15,14 +15,13 @@
namespace ck {
namespace detail {
template <typename TileDims, typename GridDescriptor>
template <index_t HPerBlock, index_t WPerBlock, typename GridDesc>
struct Block2TileMap
{
static constexpr auto I0 = Number<0>{};
static constexpr auto NumDim = Number<GridDesc::GetNumOfDimension()>{};
static_assert(2 <= NumDim);
static constexpr index_t NumDim = TileDims::Size();
static_assert(NumDim == 2);
static_assert(NumDim <= GridDescriptor::GetNumOfDimension());
static constexpr auto I0 = Number<0>{};
Block2TileMap() = delete;
Block2TileMap(const Block2TileMap&) = default;
......@@ -33,22 +32,16 @@ struct Block2TileMap
Block2TileMap& operator=(const Block2TileMap&) = delete;
Block2TileMap& operator=(Block2TileMap&&) = delete;
explicit Block2TileMap(const GridDescriptor& desc) : desc_(desc) {}
explicit Block2TileMap(const GridDesc& desc) : desc_(desc) {}
__host__ constexpr index_t CalculateGridSize(const GridDescriptor& desc) const
__host__ constexpr index_t CalculateGridSize(const GridDesc& desc) const
{
return [&]() {
std::array<index_t, 2> num_tiles_per_axis;
static_for<NumDim - 2, NumDim, 1>{}([&](auto I) {
num_tiles_per_axis[I - (NumDim - 2)] =
math::integer_divide_ceil(desc.GetLength(I), TileDims::At(I - (NumDim - 2)));
});
return std::accumulate(begin(num_tiles_per_axis),
end(num_tiles_per_axis),
index_t{1},
std::multiplies<index_t>{});
}();
const auto H0 = math::integer_divide_ceil(desc.GetLength(NumDim - Number<2>{}), HPerBlock);
const auto W0 = math::integer_divide_ceil(desc.GetLength(NumDim - Number<1>{}), WPerBlock);
const index_t grid_size = H0 * W0;
return grid_size;
}
template <typename TopIdx>
......@@ -58,34 +51,17 @@ struct Block2TileMap
auto block_1d_id = idx_top[I0];
std::array<index_t, 2> num_tiles_per_axis;
static_for<NumDim - 2, NumDim, 1>{}([&](auto I) {
num_tiles_per_axis[I - (NumDim - 2)] =
math::integer_divide_ceil(desc_.GetLength(I), TileDims::At(I - (NumDim - 2)));
});
std::array<index_t, 2> divisors;
index_t product = 1;
auto divisor = rbegin(divisors);
for(auto num_tiles = rbegin(num_tiles_per_axis); num_tiles != rend(num_tiles_per_axis);
++num_tiles)
{
product *= (*num_tiles);
*(divisor++) = product;
}
const index_t grid_size = divisors.front();
block_1d_id = block_1d_id % grid_size; // swallow batch index
return generate_tuple(
[&](auto I) {
return (block_1d_id % divisors[I]) / (divisors[I] / num_tiles_per_axis[I]);
},
Number<2>{});
const auto H0 = math::integer_divide_ceil(desc_.GetLength(NumDim - Number<2>{}), HPerBlock);
const auto W0 = math::integer_divide_ceil(desc_.GetLength(NumDim - Number<1>{}), WPerBlock);
index_t idx_H0 = block_1d_id / W0;
index_t idx_W0 = block_1d_id % W0;
return make_tuple(idx_H0, idx_W0);
}
private:
const GridDescriptor desc_;
const GridDesc desc_;
};
} // namespace detail
......@@ -137,7 +113,7 @@ struct GridwisePermute
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using DefaultBlock2TileMap = detail::Block2TileMap<Sequence<HPerBlock, WPerBlock>, InGridDesc>;
using DefaultBlock2TileMap = detail::Block2TileMap<HPerBlock, WPerBlock, InGridDesc>;
__host__ __device__ static constexpr auto GetInBlockDescriptor()
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment