Commit f3eb5a18 authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

tmp save

parent 26a4993d
...@@ -103,11 +103,11 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_til ...@@ -103,11 +103,11 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_til
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) // if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{ // {
return gemm_<Row, Row, Row>(args, s); // return gemm_<Row, Row, Row>(args, s);
} // }
else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{ {
return gemm_<Row, Col, Row>(args, s); return gemm_<Row, Col, Row>(args, s);
} }
......
...@@ -21,6 +21,9 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ...@@ -21,6 +21,9 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
typename Traits_::CDataType, typename Traits_::CDataType,
Traits_::kPadM, Traits_::kPadM,
Traits_::kPadN>>; Traits_::kPadN>>;
constexpr bool TransposeC = false;
using GemmUniversalTraits = ck_tile::
TileGemmUniversalTraits<Traits_::kPadM, Traits_::kPadN, Traits_::kPadK, Traits_::ALayout, Traits_::BLayout, Traits_::CLayout, TransposeC>;
using GemmTraits = ck_tile::TileGemmTraits<Traits_::kPadM, using GemmTraits = ck_tile::TileGemmTraits<Traits_::kPadM,
Traits_::kPadN, Traits_::kPadN,
Traits_::kPadK, Traits_::kPadK,
...@@ -53,10 +56,10 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ...@@ -53,10 +56,10 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
typename Traits_::BDataType, typename Traits_::BDataType,
typename Traits_::AccDataType, typename Traits_::AccDataType,
GemmShape, GemmShape,
GemmTraits, GemmUniversalTraits,
ck_tile::GemmPipelineScheduler::Intrawave, ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v, has_hot_loop_v,
tail_number_v>>; tail_number_v>, ck_tile::UniversalGemmPipelineAgBgCrPolicy>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args); auto kargs = Kernel::MakeKernelArgs(args);
......
...@@ -26,7 +26,7 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ...@@ -26,7 +26,7 @@ float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
Traits_::kPadK, Traits_::kPadK,
typename Traits_::ALayout, typename Traits_::ALayout,
typename Traits_::BLayout, typename Traits_::BLayout,
typename Traits_::CLayout>; typename Traits_::CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
ck_tile::GemmPipelineProblem<typename Traits_::ADataType, ck_tile::GemmPipelineProblem<typename Traits_::ADataType,
typename Traits_::BDataType, typename Traits_::BDataType,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment