Commit 84419a2b authored by Adam Osewski's avatar Adam Osewski
Browse files

Few fixes to B2C map and new functionality.

parent 80622468
......@@ -1214,6 +1214,12 @@ struct BlockToCTileMap_LinearKSplit
return make_tuple(M0_idx_, N0_idx_, K0_idx_);
}
__host__ __device__ index_t GetOutputTileIdx() const
{
const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
return M0_idx_ * N0 + N0_idx_;
}
template <typename CTileIdx, typename CTileDim>
__host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
const CTileDim& /* c_tile_dim */) const
......@@ -1222,9 +1228,14 @@ struct BlockToCTileMap_LinearKSplit
}
__host__ __device__ bool GetNextKTileIdx()
{
if(K0_idx_ + 1 < KSplit_)
{
K0_idx_++;
return K0_idx_ < KSplit_;
return true;
}
else
return false;
}
///
......@@ -1236,7 +1247,7 @@ struct BlockToCTileMap_LinearKSplit
///
__host__ __device__ bool IsFirstKSplitBlock(index_t tiles_per_block) const
{
return (K0_idx_ - tiles_per_block) <= 0;
return (K0_idx_ + 1 - tiles_per_block) <= 0;
}
__host__ __device__ index_t GetTileMIdx() const { return M0_idx_; }
......
......@@ -369,6 +369,7 @@ TEST(BlockToCTileMap, BlockToCTileMap_LinearKSplit_NextKTile)
const index_t MPerBlock = 128;
const index_t NPerBlock = 64;
const index_t KSplit = 3;
const index_t tiles_per_block = 1;
auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));
......@@ -377,6 +378,7 @@ TEST(BlockToCTileMap, BlockToCTileMap_LinearKSplit_NextKTile)
auto m0n0k0_idx = tile_map.CalculateBottomIndex(3);
EXPECT_EQ((std::vector<int>{m0n0k0_idx[I0], m0n0k0_idx[I1], m0n0k0_idx[I2]}),
(std::vector<int>{0, 1, 0}));
EXPECT_TRUE(tile_map.IsFirstKSplitBlock(tiles_per_block));
for(index_t i = 0; i < KSplit - 1; i++)
{
......@@ -384,9 +386,11 @@ TEST(BlockToCTileMap, BlockToCTileMap_LinearKSplit_NextKTile)
m0n0k0_idx = tile_map.GetBottomIndex();
EXPECT_EQ((std::vector<int>{m0n0k0_idx[I0], m0n0k0_idx[I1], m0n0k0_idx[I2]}),
(std::vector<int>{0, 1, i + 1}));
EXPECT_FALSE(tile_map.IsFirstKSplitBlock(tiles_per_block));
}
EXPECT_FALSE(tile_map.GetNextKTileIdx());
m0n0k0_idx = tile_map.GetBottomIndex();
EXPECT_EQ((std::vector<int>{m0n0k0_idx[I0], m0n0k0_idx[I1], m0n0k0_idx[I2]}),
(std::vector<int>{0, 1, 3}));
(std::vector<int>{0, 1, 2}));
EXPECT_FALSE(tile_map.IsFirstKSplitBlock(tiles_per_block));
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment