"...composable_kernel.git" did not exist on "d626dccc952b90397baa60fd3633a3504e93f92a"
Commit d51f4e52 authored by dummycoderfe

use 32x32x8 ok, fix scratch store

parent bc4366d4
@@ -72,7 +72,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
         ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
     using CodegenPipelineProblem = ck_tile::
         GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
-    using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
+    using CodegenGemmPolicy = ck_tile::GemmPipelineAGmemBGmemCRegV1DefaultPolicy;
     using CodegenGemmPipeline =
         ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
     // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
...
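Note: the only change in this hunk is the policy alias — the example now pairs GemmPipelineAGmemBGmemCRegV1 with its own default policy instead of UniversalGemmPipelineAgBgCrPolicy. The sketch below (hypothetical names, not the CK Tile API) just illustrates the pattern being relied on: the pipeline takes the policy as a type parameter, so swapping the policy is a one-line, compile-time change.

```cpp
// Illustrative only: a pipeline parameterized by a policy type, mirroring the
// CodegenGemmPipeline / CodegenGemmPolicy aliases above (names are made up).
#include <cstdio>

struct DefaultPolicy { static constexpr int kPrefetchStages = 1; };

template <typename Problem, typename Policy = DefaultPolicy>
struct Pipeline
{
    static void describe() { std::printf("prefetch stages = %d\n", Policy::kPrefetchStages); }
};

struct MyProblem {};
using MyPipeline = Pipeline<MyProblem, DefaultPolicy>; // swapping the policy changes only this line

int main() { MyPipeline::describe(); }
```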
@@ -73,7 +73,7 @@ auto create_args(int argc, char* argv[])
         .insert("n", "4096", "n dimension")
         .insert("k", "2048", "k dimension")
         .insert("a_layout", "R", "A tensor data layout - Row by default")
-        .insert("b_layout", "R", "B tensor data layout - Row by default")
+        .insert("b_layout", "C", "B tensor data layout - Row by default")
         .insert("c_layout", "R", "C tensor data layout - Row by default")
         .insert("stride_a", "0", "Tensor A stride")
         .insert("stride_b", "0", "Tensor B stride")
...
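The default B layout flips to column-major here, so a run with no flags exercises the one layout combination left enabled in the dispatch hunk below; the help string still reads "Row by default" and was not updated in this commit. For reference, a minimal sketch (not CK Tile code) of what the two selectable layouts mean for a K x N B operand:

```cpp
// Illustrative only: element offset of B(k, n) under the two -b_layout choices.
#include <cstddef>

// Row-major ("R"): consecutive n are contiguous, stride between rows is N.
constexpr std::size_t offset_row_major(std::size_t k, std::size_t n, std::size_t N)
{
    return k * N + n;
}

// Column-major ("C"): consecutive k are contiguous, stride between columns is K.
constexpr std::size_t offset_col_major(std::size_t k, std::size_t n, std::size_t K)
{
    return n * K + k;
}

static_assert(offset_row_major(1, 2, /*N=*/4096) == 4098, "");
static_assert(offset_col_major(1, 2, /*K=*/2048) == 4097, "");
```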
@@ -194,22 +194,23 @@ int run_gemm_example(int argc, char* argv[])
     std::string a_layout = arg_parser.get_str("a_layout");
     std::string b_layout = arg_parser.get_str("b_layout");
 
-    if(a_layout == "R" && b_layout == "R")
-    {
-        return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
-    }
-    else if(a_layout == "R" && b_layout == "C")
+    // if(a_layout == "R" && b_layout == "R")
+    // {
+    //     return run_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
+    // }
+    // else
+    if(a_layout == "R" && b_layout == "C")
     {
         return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
     }
-    else if(a_layout == "C" && b_layout == "C")
-    {
-        return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
-    }
-    else if(a_layout == "C" && b_layout == "R")
-    {
-        return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
-    }
+    // else if(a_layout == "C" && b_layout == "C")
+    // {
+    //     return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
+    // }
+    // else if(a_layout == "C" && b_layout == "R")
+    // {
+    //     return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
+    // }
     else
     {
         throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
...
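After this change only the A row-major / B column-major path is dispatchable; every other combination falls through to the runtime_error. Not part of the commit, but a caller with data in another layout can still reach the surviving path by repacking on the host, since a row-major matrix stored column-major is its transpose — a minimal sketch:

```cpp
// Illustrative host-side workaround: repack a row-major K x N B matrix into
// column-major storage so the remaining (A = Row, B = Col) path can be used.
#include <cstddef>
#include <vector>

template <typename T>
std::vector<T> row_to_col_major(const std::vector<T>& b_row, std::size_t K, std::size_t N)
{
    std::vector<T> b_col(K * N);
    for(std::size_t k = 0; k < K; ++k)
        for(std::size_t n = 0; n < N; ++n)
            b_col[n * K + k] = b_row[k * N + n]; // column-major <- row-major
    return b_col;
}
```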
@@ -42,6 +42,9 @@ struct BlockGemmASmemBSmemCRegV1
                           KPerBlock == BlockGemmShape::kK,
                       "wrong!");
 
+        // if(threadIdx.x == 0 && blockIdx.x==0) {
+        //     printf("MPerBlock %d NPerBlock %d KPerBlock %d \n", MPerBlock, NPerBlock, KPerBlock);
+        // }
         constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
 
         using WG = remove_cvref_t<decltype(config.template at<0>())>;
@@ -60,6 +63,12 @@ struct BlockGemmASmemBSmemCRegV1
        const index_t iMWarp = get_warp_id() / NWarp;
        const index_t iNWarp = get_warp_id() % NWarp;
 
+        // if(threadIdx.x == 0 && blockIdx.x==0) {
+        //     printf("MWarp %d NWarp %d MIterPerWarp %d NIterPerWarp %d KIterPerWarp %d MPerBlockPerIter %d NPerBlockPerIter %d KPerBlockPerIter %d \n", MWarp, NWarp, MIterPerWarp, NIterPerWarp, KIterPerWarp, MPerBlockPerIter, NPerBlockPerIter, KPerBlockPerIter);
+        // }
+
+        // MWarp 2 NWarp 2 MIterPerWarp 4 NIterPerWarp 4 KIterPerWarp 4 MPerBlockPerIter 64 NPerBlockPerIter 64 KPerBlockPerIter 8
+
        // construct A-warp-window
        auto a_warp_window_tmp = make_tile_window(
            a_block_window.get_bottom_tensor_view(),
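The commented-out printf and the recorded line above capture the configuration this commit runs. Those numbers follow directly from a 256 x 256 x 32 block tile (inferred from the recorded values, not stated in the hunk) combined with the 32x32x8 warp GEMM and the 2x2 warp grid introduced later in this diff:

```cpp
// Derivation of the numbers recorded in the debug comment, assuming a
// 256 x 256 x 32 block tile (inferred) and a 32x32x8 warp GEMM with 2x2 warps.
constexpr int MPerBlock = 256, NPerBlock = 256, KPerBlock = 32;
constexpr int WarpM = 32, WarpN = 32, WarpK = 8; // WarpGemmMfmaF16F16F32M32N32K8
constexpr int MWarp = 2, NWarp = 2;

constexpr int MPerBlockPerIter = MWarp * WarpM;            // 64
constexpr int NPerBlockPerIter = NWarp * WarpN;            // 64
constexpr int KPerBlockPerIter = WarpK;                    // 8
constexpr int MIterPerWarp = MPerBlock / MPerBlockPerIter; // 4
constexpr int NIterPerWarp = NPerBlock / NPerBlockPerIter; // 4
constexpr int KIterPerWarp = KPerBlock / KPerBlockPerIter; // 4

static_assert(MIterPerWarp == 4 && NIterPerWarp == 4 && KIterPerWarp == 4,
              "matches the recorded debug output");
```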
@@ -136,7 +145,6 @@ struct BlockGemmASmemBSmemCRegV1
            constexpr auto c_warp_y_lengths =
                to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
            constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
-
            // hot loop:
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
...
@@ -40,7 +40,8 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
            return make_tuple(WarpGemmMfmaF16F16F32M32N32K16{}, 2, 2);
        }
#else
-        return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
+        return make_tuple(WarpGemmMfmaF16F16F32M32N32K8{}, 2, 2);
+        // return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
#endif
        }
        else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
...
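This is the "use 32x32x8" part of the commit title: the default policy now returns a 32x32x8 f16 MFMA warp GEMM with 2 warps along M and 2 along N, replacing the transposed-C 32x32x16 configuration with its 4x1 warp grid. The block GEMM earlier in this diff consumes the tuple positionally (warp GEMM, MWarp, NWarp) via config.template at<...>(); the standalone sketch below (std::tuple and a made-up tag type, not the CK Tile headers) shows the same unpacking pattern:

```cpp
// Illustrative only: unpacking a (WarpGemm, MWarp, NWarp) tuple at compile time.
#include <tuple>

struct WarpGemm32x32x8 { static constexpr int kM = 32, kN = 32, kK = 8; }; // stand-in tag

constexpr auto config = std::make_tuple(WarpGemm32x32x8{}, 2, 2); // (warp gemm, MWarp, NWarp)

using WG            = std::tuple_element_t<0, decltype(config)>;
constexpr int MWarp = std::get<1>(config);
constexpr int NWarp = std::get<2>(config);

static_assert(WG::kM == 32 && WG::kK == 8 && MWarp * NWarp == 4, "2x2 warps of 32x32x8");
```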
@@ -112,10 +112,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
    {
        constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
        constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
-
-        index_t smem_size = 0;
-        smem_size += smem_size_a + smem_size_b;
-        return smem_size;
+        return smem_size_a + smem_size_b;
    }
 
    template <typename Problem>
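The shared-memory size helper now returns the compile-time sum directly instead of accumulating it in a mutable index_t local; the result is identical, but the body stays a single constant expression, which is plausibly the "fix scratch store" half of the commit title. A minimal sketch of the same shape, with made-up tile sizes rather than the CK Tile formulas:

```cpp
// Minimal sketch (made-up sizes): the pipeline's LDS requirement is the A-tile
// footprint plus the B-tile footprint, both known at compile time, so a single
// constexpr sum suffices — no mutable local needed.
#include <cstddef>

constexpr std::size_t kSmemSizeA = 256 * 32 * sizeof(unsigned short); // e.g. 256x32 fp16 A tile
constexpr std::size_t kSmemSizeB = 256 * 32 * sizeof(unsigned short); // e.g. 256x32 fp16 B tile

constexpr std::size_t GetSmemSize() { return kSmemSizeA + kSmemSizeB; }

static_assert(GetSmemSize() == 32768, "two 16 KiB tiles -> 32 KiB of LDS in this made-up example");
```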
@@ -259,7 +256,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
        constexpr index_t M2 = get_warp_size() / K0;
        // coalesce reading for each blocks
        if constexpr(get_warp_size() % (M2 * K0) == 0)
-        {
+        {//Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
            constexpr index_t M1 = BlockSize / get_warp_size();
            static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
            static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
...
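Only a stray comment is appended to the opening brace here; the distribution logic itself is unchanged. For orientation, a worked instance of the quantities involved, under assumed values that the hunk itself does not fix:

```cpp
// Worked example with assumed values: get_warp_size() = 64, BlockSize = 256, K0 = 8.
constexpr int warp_size = 64;
constexpr int K0        = 8;
constexpr int BlockSize = 256;

constexpr int M2 = warp_size / K0;        // 8: rows covered by one warp per K0-wide access
constexpr int M1 = BlockSize / warp_size; // 4: warps per block

static_assert(warp_size % (M2 * K0) == 0, "the coalesced-read branch is taken");
static_assert(M2 == 8 && M1 == 4, "");
```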