Commit 71352c44 authored by ThomasNing

Solve the compiler issue caused by the SHMEM conflict
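The previous RunGemm allocated the second LDS buffer (`smem_ptr_1`) inside an `if constexpr` branch of an immediately-invoked lambda whenever `GemmPipeline::isDoubleSmemBuffer` was set; the commit title describes the resulting compiler problem as a SHMEM conflict. The fix hoists both `__shared__` buffers up to the kernel's `operator()` and splits RunGemm into RunGemmSinglePointer and RunGemmDoublePointer, dispatching on `isDoubleSmemBuffer` at the call site. The commit also bumps the compute-friendly K_Tile from 32 to 64, pins the example to the Row/Col/Row layout path, and drops two leftover debug printf calls from GemmPipelineAgBgCrCompV4.

A minimal sketch of the before/after kernel shape (hypothetical, reduced code: `touch`, `gemm_like`, and the 1024-byte sizes are illustrative, and the exact compiler diagnostic is not recorded in this commit):

```cpp
// Hypothetical reduction of the issue, not code from this repository.
#include <hip/hip_runtime.h>

__device__ float touch(void* p) { return static_cast<char*>(p)[0]; }

#if 0 // BEFORE: the shape the old RunGemm had. The second __shared__ buffer
      // lives inside an if constexpr branch of an immediately-invoked
      // lambda, which is what appears to have triggered the compiler issue.
template <bool DoubleSmem>
__global__ void gemm_like(float* out)
{
    __shared__ char smem_0[1024];
    out[0] = [&]() {
        if constexpr(DoubleSmem)
        {
            __shared__ char smem_1[1024]; // the conflicting LDS allocation
            return touch(smem_0) + touch(smem_1);
        }
        else
        {
            return touch(smem_0);
        }
    }();
}
#endif

// AFTER: what this commit does. Both buffers are declared once at kernel
// scope, and the constexpr branch only selects which helper path runs.
template <bool DoubleSmem>
__global__ void gemm_like(float* out)
{
    __shared__ char smem_0[1024];
    __shared__ char smem_1[1024];
    if constexpr(DoubleSmem)
        out[0] = touch(smem_0) + touch(smem_1); // cf. RunGemmDoublePointer
    else
        out[0] = touch(smem_0); // cf. RunGemmSinglePointer
}
```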

parent 49316982
@@ -216,14 +216,14 @@ int run_gemm_example(int argc, char* argv[])
     std::string a_layout = arg_parser.get_str("a_layout");
     std::string b_layout = arg_parser.get_str("b_layout");
 
-    if(a_layout == "R" && b_layout == "R")
-    {
-        return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
-    }
-    else if(a_layout == "R" && b_layout == "C")
-    {
+    // if(a_layout == "R" && b_layout == "R")
+    // {
+    //     return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
+    // }
+    // else if(a_layout == "R" && b_layout == "C")
+    // {
         return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
-    }
+    // }
     // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
     // work.
     // else if(a_layout == "C" && b_layout == "C")
@@ -234,8 +234,8 @@ int run_gemm_example(int argc, char* argv[])
     // {
     //     return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
     // }
-    else
-    {
-        throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
-    }
+    // else
+    // {
+    //     throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
+    // }
 }
@@ -33,7 +33,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
         // Compute friendly for Intrawave scheduler
         constexpr ck_tile::index_t M_Tile = 256;
         constexpr ck_tile::index_t N_Tile = 256;
-        constexpr ck_tile::index_t K_Tile = 32;
+        constexpr ck_tile::index_t K_Tile = 64;
 
         constexpr ck_tile::index_t M_Warp = 2;
         constexpr ck_tile::index_t N_Warp = 2;
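For scale, here is a rough LDS budget for the tile shape above, a sketch assuming 2-byte A/B elements (fp16/bf16; the element types are not shown in this hunk) and ignoring any padding GetSmemSize() may add:

```cpp
// Back-of-the-envelope LDS use per pipeline buffer for M_Tile = N_Tile = 256,
// K_Tile = 64. The 2-byte element size is an assumption, not from the commit.
constexpr int M_Tile = 256, N_Tile = 256, K_Tile = 64;
constexpr int bytes_per_elem = 2;                              // fp16/bf16 assumption
constexpr int a_tile_bytes = M_Tile * K_Tile * bytes_per_elem; // 32 KiB
constexpr int b_tile_bytes = N_Tile * K_Tile * bytes_per_elem; // 32 KiB
constexpr int per_buffer   = a_tile_bytes + b_tile_bytes;      // 64 KiB
static_assert(per_buffer == 64 * 1024, "one smem buffer at K_Tile = 64");
// The ping-pong pipeline below keeps two such buffers (smem_ptr_0/smem_ptr_1),
// i.e. twice this amount of LDS per workgroup.
```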
@@ -414,10 +414,10 @@ struct GemmKernel
      * @tparam DstInMemOp Destination memory operation (default: set).
      */
     template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
-    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
+    CK_TILE_DEVICE static void RunGemmSinglePointer(const ADataType* a_ptr,
                                        const BDataType* b_ptr,
                                        CDataType* c_ptr,
-                                       void* smem_ptr,
+                                       void* smem_ptr_0,
                                        const GemmKernelArgs& kargs,
                                        const SplitKBatchOffset& splitk_batch_offset,
                                        const index_t block_idx_m,
@@ -436,20 +436,63 @@ struct GemmKernel
         const auto& a_block_window = gemm_tile_windows.at(I0);
         const auto& b_block_window = gemm_tile_windows.at(I1);
 
-        const auto& c_block_tile =
-            [&]() {
-                if constexpr(GemmPipeline::isDoubleSmemBuffer == true)
-                {
-                    __shared__ char smem_ptr_1[GetSmemSize()];
-                    return GemmPipeline{}.template operator()(
-                        a_block_window, b_block_window, num_loop, smem_ptr, smem_ptr_1);
-                }
-                else
-                {
-                    return GemmPipeline{}.template operator()(
-                        a_block_window, b_block_window, num_loop, smem_ptr);
-                }
-            }();
+        const auto& c_block_tile = GemmPipeline{}.template operator()(
+            a_block_window, b_block_window, num_loop, smem_ptr_0);
+
+        // Run Epilogue Pipeline
+        auto& c_block_window = gemm_tile_windows.at(I2);
+        constexpr bool is_output_c_reg_transposed =
+            EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC();
+        if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) ||
+                     (GemmPipeline::VectorSizeC % 2 == 0 &&
+                      std::is_same_v<CLayout, tensor_layout::gemm::RowMajor> &&
+                      is_output_c_reg_transposed))
+        {
+            EpiloguePipeline{}
+                .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
+                    c_block_window, c_block_tile);
+        }
+    }
+
+    /**
+     * @brief Runs a single GEMM problem cooperatively by whole workgroup.
+     *
+     * @param a_ptr input A pointer
+     * @param b_ptr input B pointer
+     * @param c_ptr output C pointer
+     * @param smem_ptr_0 pointer to the first LDS buffer
+     * @param smem_ptr_1 pointer to the second (ping-pong) LDS buffer
+     * @param kargs GEMM kernel arguments
+     * @param splitk_batch_offset offsets for the current split-K batch
+     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
+     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
+     *
+     * @tparam DstInMemOp Destination memory operation (default: set).
+     */
+    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
+    CK_TILE_DEVICE static void RunGemmDoublePointer(const ADataType* a_ptr,
+                                                    const BDataType* b_ptr,
+                                                    CDataType* c_ptr,
+                                                    void* smem_ptr_0,
+                                                    void* smem_ptr_1,
+                                                    const GemmKernelArgs& kargs,
+                                                    const SplitKBatchOffset& splitk_batch_offset,
+                                                    const index_t block_idx_m,
+                                                    const index_t block_idx_n)
+    {
+        // Create Gemm tensor views, pad views and tile windows
+        const auto& gemm_tensor_views_tuple =
+            MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
+        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
+        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+
+        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
+
+        // Run GEMM cooperatively by whole workgroup.
+        const auto& a_block_window = gemm_tile_windows.at(I0);
+        const auto& b_block_window = gemm_tile_windows.at(I1);
+
+        const auto& c_block_tile = GemmPipeline{}.template operator()(
+            a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
+
         // Run Epilogue Pipeline
         auto& c_block_window = gemm_tile_windows.at(I2);
@@ -479,16 +522,48 @@ struct GemmKernel
         CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
 
         // allocate LDS
-        __shared__ char smem_ptr[GetSmemSize()];
+        __shared__ char smem_ptr_0[GetSmemSize()];
+        __shared__ char smem_ptr_1[GetSmemSize()];
 
         if(kargs.KBatch == 1)
         {
-            RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+            if constexpr(GemmPipeline::isDoubleSmemBuffer == true)
+            {
+                RunGemmDoublePointer(a_ptr,
+                                     b_ptr,
+                                     c_ptr,
+                                     smem_ptr_0,
+                                     smem_ptr_1,
+                                     kargs,
+                                     splitk_batch_offset,
+                                     i_m,
+                                     i_n);
+            }
+            else
+            {
+                RunGemmSinglePointer(
+                    a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
+            }
         }
         else
         {
-            RunGemm<memory_operation_enum::atomic_add>(
-                a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+            if constexpr(GemmPipeline::isDoubleSmemBuffer == true)
+            {
+                RunGemmDoublePointer<memory_operation_enum::atomic_add>(a_ptr,
+                                                                        b_ptr,
+                                                                        c_ptr,
+                                                                        smem_ptr_0,
+                                                                        smem_ptr_1,
+                                                                        kargs,
+                                                                        splitk_batch_offset,
+                                                                        i_m,
+                                                                        i_n);
+            }
+            else
+            {
+                RunGemmSinglePointer<memory_operation_enum::atomic_add>(
+                    a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
+            }
         }
     }
 };
@@ -268,9 +268,6 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
             Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
             Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
 
-            printf("Tail Num: =====================================\n");
-            printf("%d \n", static_cast<int>(TailNum));
-
             if(HasHotLoop)
             {
                 // minus 2 because we have ping-pong double buffer.