Commit 987cc54d authored by ThomasNing

Finish the integration to develop and have the correct result

parent 3b301468
@@ -114,8 +114,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
                                          has_hot_loop_v,
                                          tail_number_v>;
-        using GemmPipeline =
-            GEMM_PIPELINE<UniversalGemmProblem>;
+        using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
         using GemmEpilogue = ck_tile::CShuffleEpilogue<
             ck_tile::CShuffleEpilogueProblem<AccDataType,
                                              CDataType,
@@ -241,64 +240,63 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
             Run(ck_tile::bool_constant<true>{},
                 ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
         }
 #endif
     }
     else
     {
         // Tail number always Full - #PrefetchStages
         if(tail_num == ck_tile::TailNumber::Full)
         {
             Run(ck_tile::bool_constant<false>{},
                 ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
         }
         else
         {
             std::ostringstream err;
             err << "When there's no hot loop, this tail number \"" << tail_num
                 << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
                 << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
             throw std::runtime_error(err.str());
         }
     }

     return ave_time;
 }
#include "run_gemm_example.inc" #include "run_gemm_example.inc"
int run_gemm_example(int argc, char* argv[]) int run_gemm_example(int argc, char* argv[])
{ {
auto [result, arg_parser] = create_args(argc, argv); auto [result, arg_parser] = create_args(argc, argv);
if(!result) if(!result)
return -1; return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string a_layout = arg_parser.get_str("a_layout"); std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout"); std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R") if(a_layout == "R" && b_layout == "R")
{ {
return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
} }
else if(a_layout == "R" && b_layout == "C") else if(a_layout == "R" && b_layout == "C")
{ {
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
} }
else if(a_layout == "C" && b_layout == "C") else if(a_layout == "C" && b_layout == "C")
{ {
return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
} }
else if(a_layout == "C" && b_layout == "R") else if(a_layout == "C" && b_layout == "R")
{ {
return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
} }
else else
{ {
throw std::runtime_error( throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
"Unsupported data layout configuration for A,B and C tensors!");
}
} }
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
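Note on the pattern above: every branch of the tail-number dispatch in gemm_calc funnels into the same Run helper, passing ck_tile::bool_constant and ck_tile::integral_constant arguments so that a runtime TailNumber value is lifted into compile-time template parameters before the kernel is instantiated. A minimal standalone sketch of that dispatch idea, using only the standard library (the names dispatch and run_kernel are illustrative, not from the repository):

// Sketch: lift a runtime TailNumber-style enum into compile-time template
// parameters, as the Run calls above do with ck_tile::integral_constant.
#include <iostream>
#include <stdexcept>
#include <type_traits>

enum class TailNumber { Full, Two };

template <typename HasHotLoop, typename TailNum>
void run_kernel(HasHotLoop, TailNum)
{
    // In the real example this would instantiate and launch the GEMM kernel
    // specialized for <HasHotLoop::value, TailNum::value>.
    std::cout << "hot_loop=" << HasHotLoop::value
              << " tail=" << static_cast<int>(TailNum::value) << '\n';
}

void dispatch(bool has_hot_loop, TailNumber tail_num)
{
    if(has_hot_loop && tail_num == TailNumber::Two)
        run_kernel(std::integral_constant<bool, true>{},
                   std::integral_constant<TailNumber, TailNumber::Two>{});
    else if(!has_hot_loop && tail_num == TailNumber::Full)
        run_kernel(std::integral_constant<bool, false>{},
                   std::integral_constant<TailNumber, TailNumber::Full>{});
    else
        throw std::runtime_error("unsupported tail number for this configuration");
}

int main() { dispatch(true, TailNumber::Two); }

Each branch pays one runtime comparison, after which the kernel body sees the values as compile-time constants and can be fully specialized.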
@@ -490,8 +490,6 @@ struct GemmKernel
         const auto& c_block_tile = GemmPipeline{}.template operator()(
             a_block_window, b_block_window, num_loop, smem_ptr_0);

         // Run Epilogue Pipeline
         auto& c_block_window = gemm_tile_windows.at(I2);
@@ -548,7 +546,7 @@ struct GemmKernel
         EpiloguePipeline{}
             .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
-                c_block_window, c_block_tile, smem_ptr_0, smem_ptr_1);
+                c_block_window, c_block_tile, smem_ptr_0);
     }

     CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
@@ -596,14 +594,14 @@ struct GemmKernel
         if constexpr(GemmPipeline::DoubleSmemBuffer == true)
         {
             RunGemm2LDS<memory_operation_enum::atomic_add>(a_ptr,
                                                            b_ptr,
                                                            c_ptr,
                                                            smem_ptr_0,
                                                            smem_ptr_1,
                                                            kargs,
                                                            splitk_batch_offset,
                                                            i_m,
                                                            i_n);
         }
         else
         {
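In the kernel hunks above, RunGemm2LDS receives two LDS base pointers (smem_ptr_0 and smem_ptr_1) when GemmPipeline::DoubleSmemBuffer is enabled, while the CShuffle epilogue call now takes only smem_ptr_0. The two pointers exist so the pipeline can ping-pong between buffers: one buffer feeds the math stage while the other is refilled by the next load. A host-side sketch of that double-buffering idea under those assumptions (load_tile and compute_tile are placeholders, not the actual pipeline stages):

// Host-side sketch of double-buffered ("ping-pong") staging, the idea behind
// passing smem_ptr_0 and smem_ptr_1 when DoubleSmemBuffer is enabled.
#include <array>
#include <cstdio>

void load_tile(std::array<float, 4>& buf, int k) { buf.fill(static_cast<float>(k)); }

float compute_tile(const std::array<float, 4>& buf)
{
    float acc = 0.f;
    for(float v : buf) acc += v;
    return acc;
}

float pipelined_loop(int num_loop)
{
    std::array<float, 4> smem_0{}, smem_1{};
    std::array<float, 4>* buffers[2] = {&smem_0, &smem_1};

    float acc = 0.f;
    load_tile(*buffers[0], 0); // prologue: fill the first buffer
    for(int k = 0; k < num_loop; ++k)
    {
        if(k + 1 < num_loop)
            load_tile(*buffers[(k + 1) % 2], k + 1); // prefetch into the idle buffer
        acc += compute_tile(*buffers[k % 2]);        // consume the ready buffer
    }
    return acc;
}

int main() { std::printf("%f\n", pipelined_loop(4)); }

On the GPU the same rotation is done with two LDS buffers and overlapped global loads; the sketch only illustrates the buffer rotation, not the actual synchronization.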
@@ -69,9 +69,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
     static constexpr index_t NPerBlock = BlockGemmShape::kN;
     static constexpr index_t KPerBlock = BlockGemmShape::kK;

-    static constexpr index_t VectorSizeA = Policy::template GetVectorSizeA<Problem>();
-    static constexpr index_t VectorSizeB = Policy::template GetVectorSizeB<Problem>();
-    static constexpr index_t VectorSizeC = Policy::template GetVectorSizeC<Problem>();
+    static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
+    static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
+    static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }

     static constexpr bool kPadM = Problem::kPadM;
     static constexpr bool kPadN = Problem::kPadN;

@@ -117,9 +117,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
         constexpr index_t B_LDS_Read_Width = KPerXDL;

         constexpr index_t A_Buffer_Load_Inst_Num =
-            MPerBlock * KPerBlock / (BlockSize * VectorSizeA);
+            MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
         constexpr index_t B_Buffer_Load_Inst_Num =
-            NPerBlock * KPerBlock / (BlockSize * VectorSizeB);
+            NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());

         constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
         constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
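The pipeline hunks above replace the VectorSizeA/B/C data members with GetVectorSizeA/B/C() member functions and update the two instruction-count computations to call them; presumably this aligns GemmPipelineAgBgCrCompV4 with the member-function interface exposed by the other ck_tile pipelines, though the commit does not state the motivation. A standalone sketch of the pattern under that assumption (Policy, Pipeline, Problem, and NumLoadInstA are illustrative names, not the actual ck_tile types):

// Minimal sketch of exposing a policy value through a static constexpr member
// function instead of a static constexpr data member; the value remains usable
// in constant expressions, as in the A_Buffer_Load_Inst_Num computation above.
#include <cstdio>

struct Policy
{
    template <typename Problem>
    static constexpr int GetVectorSizeA()
    {
        return Problem::kVectorA;
    }
};

template <typename Problem>
struct Pipeline
{
    static constexpr int MPerBlock = 128;
    static constexpr int KPerBlock = 64;
    static constexpr int BlockSize = 256;

    // Member-function form, matching the interface used in the diff above.
    static constexpr int GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }

    static constexpr int NumLoadInstA()
    {
        // Same arithmetic shape as A_Buffer_Load_Inst_Num in the pipeline hunk.
        return MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
    }
};

struct Problem { static constexpr int kVectorA = 8; };

static_assert(Pipeline<Problem>::NumLoadInstA() == 4, "usable in constant expressions");

int main() { std::printf("%d\n", Pipeline<Problem>::NumLoadInstA()); }

Because the function is constexpr, call sites such as the instruction-count constants keep working in constant expressions exactly as they did with the data members.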