Commit a8d88d8d authored by coderfeli's avatar coderfeli
Browse files

tmp before merge

parent c7d08b7c
...@@ -58,7 +58,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -58,7 +58,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
using GemmEpilogue = ck_tile::Default2DEpilogue< using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using CodegenGemmTraits = using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, true, 3>; ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, true, 2>;
using CodegenPipelineProblem = ck_tile:: using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>; GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPolicy = ck_tile::GemmPipelineAGmemBGmemCRegV1DefaultPolicy; using CodegenGemmPolicy = ck_tile::GemmPipelineAGmemBGmemCRegV1DefaultPolicy;
......
...@@ -189,9 +189,16 @@ struct BlockGemmARegBRegCRegV2 ...@@ -189,9 +189,16 @@ struct BlockGemmARegBRegCRegV2
CK_TILE_DEVICE static constexpr auto MakeABlockDistribution() CK_TILE_DEVICE static constexpr auto MakeABlockDistribution()
{ {
// M->N Warp // M->N Warp
// using AWarpDstrEncoding = tile_distribution_encoding<
// sequence<>,
// tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>, //<32>, <2, 8>
// tuple<sequence<2, 1>>,
// tuple<sequence<0, 0>>,
// sequence<2>,
// sequence<1>>;
constexpr auto a_block_outer_dstr_encoding = constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>, tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>, tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>, // <4, 2>, <2>
tuple<sequence<1, 0>>, tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>, tuple<sequence<1, 0>>,
sequence<1, 2>, sequence<1, 2>,
......
...@@ -254,13 +254,17 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -254,13 +254,17 @@ struct GemmPipelineAGmemBGmemCRegV1
// local prefetch 0 // local prefetch 0
// a b register tile for lds prefetch & mfma // a b register tile for lds prefetch & mfma
using ALdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeABlockDistribution()); constexpr auto ALdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeABlockDistribution()){};
using BLdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeBBlockDistribution()); constexpr auto BLdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeBBlockDistribution()){};
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr{})); using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr{})); using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
ALdsTile a_block_tile0; ALdsTile a_block_tile0;
BLdsTile b_block_tile0; BLdsTile b_block_tile0;
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0); auto a_lds_ld_window0 = make_tile_window_linear(a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ALdsTileDistr);
auto a_lds_ld_window1 = make_tile_window_linear(a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ALdsTileDistr);
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile(a_block_tile0, a_lds_ld_window0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0); Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
// LDS write 1 // LDS write 1
...@@ -281,7 +285,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -281,7 +285,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// ping // ping
{ {
block_sync_lds(); block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1); // Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile(a_block_tile1, a_lds_ld_window1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1); Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func); LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
...@@ -293,7 +298,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -293,7 +298,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// pong // pong
{ {
block_sync_lds(); block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0); // Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile(a_block_tile0, a_lds_ld_window0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0); Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func); LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
...@@ -311,7 +317,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -311,7 +317,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// 3 // 3
{ {
block_sync_lds(); block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1); // Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile(a_block_tile1, a_lds_ld_window1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1); Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func); LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func); LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
...@@ -320,7 +327,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -320,7 +327,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// 2 // 2
{ {
block_sync_lds(); block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0); // Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile(a_block_tile0, a_lds_ld_window0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0); Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1); block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
} }
...@@ -334,7 +342,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -334,7 +342,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// //tail 2 // //tail 2
{ {
block_sync_lds(); block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1); // Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile(a_block_tile1, a_lds_ld_window1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1); Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
} }
......
...@@ -70,21 +70,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -70,21 +70,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
return a_lds_block_desc; return a_lds_block_desc;
} }
// Returns a naive 3-D tensor descriptor for the A block tile.
//
// With M = Problem::BlockGemmShape::kM and K = Problem::BlockGemmShape::kK the
// descriptor has lengths (K / 8, M, 8) and strides (M * 8, 8, 1): K is cut into
// packs of 8 elements, and the elements of one pack are adjacent (stride 1).
// No padding is added between rows or packs.
//
// NOTE(review): the two trailing number<> arguments are presumably the
// guaranteed vector length / vector stride of the last dimension -- confirm
// against make_naive_tensor_descriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockLinearDescriptor()
{
    using namespace ck_tile;

    // Number of K elements grouped together in the innermost dimension.
    constexpr index_t kPack = 8;

    constexpr index_t kRows   = Problem::BlockGemmShape::kM;
    constexpr index_t kGroups = Problem::BlockGemmShape::kK / kPack;

    constexpr auto lengths = make_tuple(number<kGroups>{}, number<kRows>{}, number<kPack>{});
    constexpr auto strides = make_tuple(number<kRows * kPack>{}, number<kPack>{}, number<1>{});

    return make_naive_tensor_descriptor(lengths, strides, number<kPack>{}, number<1>{});
}
// 3d + padding // 3d + padding
template <typename Problem> template <typename Problem>
......
...@@ -170,7 +170,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution ...@@ -170,7 +170,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
using AWarpDstrEncoding = tile_distribution_encoding< using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>, sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>, tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>, // <32>, <2, 4>
tuple<sequence<2, 1>>, tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>, tuple<sequence<0, 0>>,
sequence<2>, sequence<2>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment