Commit 987cc54d authored by ThomasNing's avatar ThomasNing
Browse files

Finish the integration to develop and have the correct result

parent 3b301468
...@@ -114,8 +114,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -114,8 +114,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
has_hot_loop_v, has_hot_loop_v,
tail_number_v>; tail_number_v>;
using GemmPipeline = using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
GEMM_PIPELINE<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue< using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<AccDataType, ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType, CDataType,
...@@ -241,8 +240,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -241,8 +240,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
Run(ck_tile::bool_constant<true>{}, Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{}); ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
} }
}
#endif #endif
}
else else
{ {
// Tail number always Full - #PrefetchStages // Tail number always Full - #PrefetchStages
...@@ -262,12 +261,12 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -262,12 +261,12 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
} }
return ave_time; return ave_time;
} }
#include "run_gemm_example.inc" #include "run_gemm_example.inc"
int run_gemm_example(int argc, char* argv[]) int run_gemm_example(int argc, char* argv[])
{ {
auto [result, arg_parser] = create_args(argc, argv); auto [result, arg_parser] = create_args(argc, argv);
if(!result) if(!result)
return -1; return -1;
...@@ -296,9 +295,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -296,9 +295,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
} }
else else
{ {
throw std::runtime_error( throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
"Unsupported data layout configuration for A,B and C tensors!");
}
} }
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
...@@ -490,8 +490,6 @@ struct GemmKernel ...@@ -490,8 +490,6 @@ struct GemmKernel
const auto& c_block_tile = GemmPipeline{}.template operator()( const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0); a_block_window, b_block_window, num_loop, smem_ptr_0);
// Run Epilogue Pipeline // Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I2); auto& c_block_window = gemm_tile_windows.at(I2);
...@@ -548,7 +546,7 @@ struct GemmKernel ...@@ -548,7 +546,7 @@ struct GemmKernel
EpiloguePipeline{} EpiloguePipeline{}
.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>( .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
c_block_window, c_block_tile, smem_ptr_0, smem_ptr_1); c_block_window, c_block_tile, smem_ptr_0);
} }
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
......
...@@ -69,9 +69,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem> ...@@ -69,9 +69,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t VectorSizeA = Policy::template GetVectorSizeA<Problem>(); static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t VectorSizeB = Policy::template GetVectorSizeB<Problem>(); static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
static constexpr index_t VectorSizeC = Policy::template GetVectorSizeC<Problem>(); static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
...@@ -117,9 +117,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem> ...@@ -117,9 +117,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
constexpr index_t B_LDS_Read_Width = KPerXDL; constexpr index_t B_LDS_Read_Width = KPerXDL;
constexpr index_t A_Buffer_Load_Inst_Num = constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * VectorSizeA); MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num = constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * VectorSizeB); NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL); constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL); constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment