Commit 23377f7b authored by Adam Osewski's avatar Adam Osewski
Browse files

Formatting & fix IsTranspose

parent 70e0b55d
...@@ -81,7 +81,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -81,7 +81,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); } CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Policy::template IsTransposeC<Problem>();
}
template <GemmPipelineScheduler Scheduler> template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase struct PipelineImpl : public PipelineImplBase
...@@ -250,21 +253,19 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -250,21 +253,19 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
constexpr bool is_a_col_major = constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>; std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>; constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major ? static_assert(is_a_col_major
(KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), "A block window has incorrect lengths for defined ALayout!");
"A block window has incorrect lengths for defined ALayout!"); static_assert(is_b_row_major
static_assert(is_b_row_major ? ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
(KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
: KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
(NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && "B block window has incorrect lengths for defined BLayout!");
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
// Definitions of all needed tiles // Definitions of all needed tiles
...@@ -302,7 +303,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -302,7 +303,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step = constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step = constexpr BDramTileWindowStep b_dram_tile_window_step =
......
...@@ -454,15 +454,15 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -454,15 +454,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
{ {
using ALayout = remove_cvref_t<typename Problem::ALayout>; using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>(); constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
// Tile: MPerBlock X KPerBlock // Tile: MPerBlock X KPerBlock
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize, using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
MPerBlock, MPerBlock,
KPerBlock, KPerBlock,
VecLoadSize, VecLoadSize,
...@@ -472,7 +472,7 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -472,7 +472,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
// Tile: KPerBlock X MPerBlock // Tile: KPerBlock X MPerBlock
else else
{ {
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize, using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
KPerBlock, KPerBlock,
MPerBlock, MPerBlock,
VecLoadSize, VecLoadSize,
...@@ -486,15 +486,15 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -486,15 +486,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
{ {
using BLayout = remove_cvref_t<typename Problem::BLayout>; using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>(); constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
// Tile: KPerBlock X NPerBlock // Tile: KPerBlock X NPerBlock
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize, using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
KPerBlock, KPerBlock,
NPerBlock, NPerBlock,
VecLoadSize, VecLoadSize,
...@@ -504,7 +504,7 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -504,7 +504,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
// Tile: NPerBlock X KPerBlock // Tile: NPerBlock X KPerBlock
else else
{ {
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize, using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
VecLoadSize, VecLoadSize,
...@@ -550,7 +550,11 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -550,7 +550,11 @@ struct UniversalGemmPipelineAgBgCrPolicy
} }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; } template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Problem::TransposeC;
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment