Commit 23377f7b authored by Adam Osewski's avatar Adam Osewski
Browse files

Formatting & fix IsTranspose

parent 70e0b55d
......@@ -81,7 +81,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Policy::template IsTransposeC<Problem>();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
......@@ -251,18 +254,16 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_a_col_major ?
(KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
:
(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major ?
(KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
:
(NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
......@@ -302,7 +303,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
......
......@@ -550,7 +550,11 @@ struct UniversalGemmPipelineAgBgCrPolicy
}
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Problem::TransposeC;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment