Commit a8d88d8d authored by coderfeli's avatar coderfeli
Browse files

tmp before merge

parent c7d08b7c
......@@ -58,7 +58,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, true, 3>;
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, true, 2>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPolicy = ck_tile::GemmPipelineAGmemBGmemCRegV1DefaultPolicy;
......
......@@ -189,9 +189,16 @@ struct BlockGemmARegBRegCRegV2
CK_TILE_DEVICE static constexpr auto MakeABlockDistribution()
{
// M->N Warp
// using AWarpDstrEncoding = tile_distribution_encoding<
// sequence<>,
// tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>, //<32>, <2, 8>
// tuple<sequence<2, 1>>,
// tuple<sequence<0, 0>>,
// sequence<2>,
// sequence<1>>;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>, // <4, 2>, <2>
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
......
......@@ -254,13 +254,17 @@ struct GemmPipelineAGmemBGmemCRegV1
// local prefetch 0
// a b register tile for lds prefetch & mfma
using ALdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeABlockDistribution());
using BLdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeBBlockDistribution());
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr{}));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr{}));
constexpr auto ALdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeABlockDistribution()){};
constexpr auto BLdsTileDistr = decltype(Policy::template BlockGemm<Problem>::MakeBBlockDistribution()){};
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
ALdsTile a_block_tile0;
BLdsTile b_block_tile0;
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
auto a_lds_ld_window0 = make_tile_window_linear(a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ALdsTileDistr);
auto a_lds_ld_window1 = make_tile_window_linear(a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}, ALdsTileDistr);
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile(a_block_tile0, a_lds_ld_window0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
// LDS write 1
......@@ -281,7 +285,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// ping
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile(a_block_tile1, a_lds_ld_window1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
......@@ -293,7 +298,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// pong
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile(a_block_tile0, a_lds_ld_window0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
LocalPrefill(a_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window1, b_global_load_tile, b_element_func);
......@@ -311,7 +317,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// 3
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile(a_block_tile1, a_lds_ld_window1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
LocalPrefill(a_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_lds_window0, b_global_load_tile, b_element_func);
......@@ -320,7 +327,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// 2
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window0, a_block_tile0);
load_tile(a_block_tile0, a_lds_ld_window0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window0, b_block_tile0);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
......@@ -334,7 +342,8 @@ struct GemmPipelineAGmemBGmemCRegV1
// //tail 2
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
// Policy::template BlockGemm<Problem>::PrefetchLds(a_lds_window1, a_block_tile1);
load_tile(a_block_tile1, a_lds_ld_window1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_lds_window1, b_block_tile1);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
......
......@@ -70,21 +70,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
return a_lds_block_desc;
}
// Builds a naive (explicit lengths + strides) 3-D tensor descriptor for the
// A block tile, arranged as (K / kKPack, M, kKPack) with fully packed strides
// (M * kKPack, kKPack, 1).
//
// Problem must expose BlockGemmShape::kM and BlockGemmShape::kK.
// Returns the descriptor object produced by make_naive_tensor_descriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockLinearDescriptor()
{
    using namespace ck_tile;

    // Per-block tile extents along M and K, taken from the problem's GEMM shape.
    constexpr index_t kM = Problem::BlockGemmShape::kM;
    constexpr index_t kK = Problem::BlockGemmShape::kK;

    // Length of the innermost (contiguous, stride-1) dimension.
    // NOTE(review): K is split with integer division, so this presumably
    // expects kK to be a multiple of kKPack -- confirm against callers.
    constexpr index_t kKPack = 8;

    // NOTE(review): the trailing number<kKPack>/number<1> arguments look like
    // vector length / vector stride hints for the last dimension -- confirm
    // against the make_naive_tensor_descriptor signature.
    return make_naive_tensor_descriptor(
        make_tuple(number<kK / kKPack>{}, number<kM>{}, number<kKPack>{}),
        make_tuple(number<kM * kKPack>{}, number<kKPack>{}, number<1>{}),
        number<kKPack>{},
        number<1>{});
}
// 3d + padding
template <typename Problem>
......
......@@ -170,7 +170,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>, // <32>, <2, 4>
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment