Commit 71352c44 authored by ThomasNing's avatar ThomasNing
Browse files

Solve the compiler issue on SHMEM conflict

parent 49316982
...@@ -216,14 +216,14 @@ int run_gemm_example(int argc, char* argv[]) ...@@ -216,14 +216,14 @@ int run_gemm_example(int argc, char* argv[])
std::string a_layout = arg_parser.get_str("a_layout"); std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout"); std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R") // if(a_layout == "R" && b_layout == "R")
{ // {
return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{}); // return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
} // }
else if(a_layout == "R" && b_layout == "C") // else if(a_layout == "R" && b_layout == "C")
{ // {
return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
} // }
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
// work. // work.
// else if(a_layout == "C" && b_layout == "C") // else if(a_layout == "C" && b_layout == "C")
...@@ -234,8 +234,8 @@ int run_gemm_example(int argc, char* argv[]) ...@@ -234,8 +234,8 @@ int run_gemm_example(int argc, char* argv[])
// { // {
// return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); // return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// } // }
else // else
{ // {
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); // throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
} // }
} }
...@@ -33,7 +33,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& ...@@ -33,7 +33,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
// Compute friendly for Intrawave scheduler // Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256; constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32; constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t N_Warp = 2;
......
...@@ -414,10 +414,10 @@ struct GemmKernel ...@@ -414,10 +414,10 @@ struct GemmKernel
* @tparam DstInMemOp Destination memory operation (default: set). * @tparam DstInMemOp Destination memory operation (default: set).
*/ */
template <memory_operation_enum DstInMemOp = memory_operation_enum::set> template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, CK_TILE_DEVICE static void RunGemmSinglePointer(const ADataType* a_ptr,
const BDataType* b_ptr, const BDataType* b_ptr,
CDataType* c_ptr, CDataType* c_ptr,
void* smem_ptr, void* smem_ptr_0,
const GemmKernelArgs& kargs, const GemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset, const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m, const index_t block_idx_m,
...@@ -436,20 +436,63 @@ struct GemmKernel ...@@ -436,20 +436,63 @@ struct GemmKernel
const auto& a_block_window = gemm_tile_windows.at(I0); const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1); const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& c_block_tile = const auto& c_block_tile = GemmPipeline{}.template operator()(
[&]() { a_block_window, b_block_window, num_loop, smem_ptr_0);
if constexpr(GemmPipeline::isDoubleSmemBuffer == true)
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I2);
constexpr bool is_output_c_reg_transposed =
EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) ||
(GemmPipeline::VectorSizeC % 2 == 0 &&
std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
is_output_c_reg_transposed))
{ {
__shared__ char smem_ptr_1[GetSmemSize()]; EpiloguePipeline{}
return GemmPipeline{}.template operator()( .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
a_block_window, b_block_window, num_loop, smem_ptr, smem_ptr_1); c_block_window, c_block_tile);
} }
else
{
return GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr);
} }
}();
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param c_ptr output C pointer
* @param kargs GEMM kernel arguments
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
* @tparam DstInMemOp Destination memory operation (default: set).
*/
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static void RunGemmDoublePointer(const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
void* smem_ptr_1,
const GemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
;
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& c_block_tile = GemmPipeline{}.template operator()(
a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
// Run Epilogue Pipeline // Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I2); auto& c_block_window = gemm_tile_windows.at(I2);
...@@ -479,16 +522,48 @@ struct GemmKernel ...@@ -479,16 +522,48 @@ struct GemmKernel
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr); CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
// allocate LDS // allocate LDS
__shared__ char smem_ptr[GetSmemSize()]; __shared__ char smem_ptr_0[GetSmemSize()];
__shared__ char smem_ptr_1[GetSmemSize()];
if(kargs.KBatch == 1) if(kargs.KBatch == 1)
{ {
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); if constexpr(GemmPipeline::isDoubleSmemBuffer == true)
{
RunGemmDoublePointer(a_ptr,
b_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else
{
RunGemmSinglePointer(
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
} }
else else
{ {
RunGemm<memory_operation_enum::atomic_add>( if constexpr(GemmPipeline::isDoubleSmemBuffer == true)
a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); {
RunGemmDoublePointer<memory_operation_enum::atomic_add>(a_ptr,
b_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else
{
RunGemmSinglePointer<memory_operation_enum::atomic_add>(
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
} }
} }
}; };
......
...@@ -268,9 +268,6 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -268,9 +268,6 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window); Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window); Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
printf("Tail Num: =====================================\n");
printf("%d \n", static_cast<int>(TailNum));
if(HasHotLoop) if(HasHotLoop)
{ {
// minus 2 because we have ping-pong double buffer. // minus 2 because we have ping-pong double buffer.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment