Commit bbea596d authored by ThomasNing's avatar ThomasNing
Browse files

debug the code

parent 3935554a
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp) add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp)
target_compile_options(tile_example_gemm_basic PRIVATE
-mllvm -enable-noalias-to-md-conversion=0
)
...@@ -69,12 +69,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -69,12 +69,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using CodegenGemmTraits = using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, true, 3>;
using CodegenPipelineProblem = ck_tile:: using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>; GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
using CodegenGemmPipeline = using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>; ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM. // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
...@@ -92,11 +91,6 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -92,11 +91,6 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
std::cout << "Launching kernel with args:" std::cout << "Launching kernel with args:"
......
...@@ -40,6 +40,37 @@ struct BlockGemmARegBRegCRegV2 ...@@ -40,6 +40,37 @@ struct BlockGemmARegBRegCRegV2
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WG::kK; static constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
// C += A * B // C += A * B
template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor> template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
...@@ -51,8 +82,8 @@ struct BlockGemmARegBRegCRegV2 ...@@ -51,8 +82,8 @@ struct BlockGemmARegBRegCRegV2
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>, std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!"); "wrong!");
constexpr auto a_block_dstr_encode = MakeABlockDistribution(); constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode();
constexpr auto b_block_dstr_encode = MakeBBlockDistribution(); constexpr auto b_block_dstr_encode = MakeBBlockDistributionEncode();
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>, sequence<>,
...@@ -173,38 +204,21 @@ struct BlockGemmARegBRegCRegV2 ...@@ -173,38 +204,21 @@ struct BlockGemmARegBRegCRegV2
// auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr); // auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
// return c_block_tensor; // return c_block_tensor;
// } // }
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
CK_TILE_DEVICE static constexpr auto MakeABlockDistribution()
{ {
constexpr auto a_block_outer_dstr_encoding = constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
tile_distribution_encoding<sequence<NWarp>, sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>, tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 0>>, tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>, tuple<sequence<1, 1>>,
sequence<1, 2>, sequence<1, 2>,
sequence<0, 0>>{}; sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( return c_block_dstr_encode;
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
return a_block_dstr;
} }
CK_TILE_DEVICE static constexpr auto MakeBBlockDistribution()
{
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto b_block_dstr = make_static_tile_distribution(b_block_dstr_encode);
return b_block_dstr;
}
// C = A * B // C = A * B
template <typename ABlockTensor, typename BBlockTensor> template <typename ABlockTensor, typename BBlockTensor>
......
...@@ -149,8 +149,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -149,8 +149,8 @@ struct GemmPipelineAGmemBGmemCRegV1
const BDramBlockWindowTmp& b_dram_block_window_tmp, const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func, const BElementFunction& b_element_func,
index_t num_loop, index_t num_loop,
void* __restrict__ p_smem_0, void* p_smem_0,
void* __restrict__ p_smem_1) void* p_smem_1)
{ {
static_assert( static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> && std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
...@@ -238,8 +238,12 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -238,8 +238,12 @@ struct GemmPipelineAGmemBGmemCRegV1
block_sync_lds(); block_sync_lds();
constexpr auto ALdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeABlockDistribution()){}; // constexpr auto ALdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeABlockDistribution()){};
constexpr auto BLdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeBBlockDistribution()){}; // constexpr auto BLdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeBBlockDistribution()){};
constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution(
Policy::template BlockGemm<Problem>::MakeABlockDistributionEncode())){};
constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(
Policy::template BlockGemm<Problem>::MakeBBlockDistributionEncode())){};
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr)); using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
ALdsTile a_block_tile0; ALdsTile a_block_tile0;
...@@ -359,8 +363,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -359,8 +363,8 @@ struct GemmPipelineAGmemBGmemCRegV1
CK_TILE_DEVICE static auto run(const ADramBlockWindowTmp& a_dram_block_window_tmp, CK_TILE_DEVICE static auto run(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp, const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop, index_t num_loop,
void*__restrict__ p_smem_0, void* p_smem_0,
void*__restrict__ p_smem_1) void* p_smem_1)
{ {
return run( return run(
a_dram_block_window_tmp, a_dram_block_window_tmp,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment