Commit 405c05c0 authored by dummycoderfe
Browse files

Add prefetch and fix output error

parent 6c270303
...@@ -205,8 +205,8 @@ struct BlockGemmARegBRegCRegV2 ...@@ -205,8 +205,8 @@ struct BlockGemmARegBRegCRegV2
} }
// Prefetch lds // Prefetch lds
template <typename BlockWindowTmp, typename BlockTensor> template <typename BlockWindow, typename BlockTensor>
CK_TILE_DEVICE static auto PrefetchLds(const BlockWindowTmp& block_window, BlockTensor& block_tensor) CK_TILE_DEVICE static auto PrefetchLds(const BlockWindow& block_window, BlockTensor& block_tensor)
{ {
auto tileDist = BlockTensor::get_tile_distribution();//.get_static_tile_distribution_encoding() auto tileDist = BlockTensor::get_tile_distribution();//.get_static_tile_distribution_encoding()
return load_tile(block_tensor, make_tile_window(block_window, tileDist)); return load_tile(block_tensor, make_tile_window(block_window, tileDist));
......
...@@ -37,17 +37,17 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -37,17 +37,17 @@ struct GemmPipelineAGmemBGmemCRegV1
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK; static constexpr bool kPadK = Problem::kPadK;
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize() // CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{ // {
return integer_least_multiple( // return integer_least_multiple(
sizeof(ADataType) * // sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(), // Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) * 2 + // 16) * 2 +
integer_least_multiple( // integer_least_multiple(
sizeof(BDataType) * // sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size(), // Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size(),
16) * 2; // 16) * 2;
} // }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{ {
...@@ -91,46 +91,78 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -91,46 +91,78 @@ struct GemmPipelineAGmemBGmemCRegV1
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!"); "wrong!");
////////////// global window & register /////////////////
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// A register tile for global load
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile a_global_load_tile;
BBlockTile b_global_load_tile;
// A tile in LDS // global prefetch 0
// global read 0
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(a_global_load_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_global_load_tile(i_j_idx)));
// });
// printf("\n");
// });
// printf("bbbbb\n");
// constexpr auto span_2d2 = decltype(b_global_load_tile)::get_distributed_spans();
// sweep_tile_span(span_2d2[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d2[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(b_global_load_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
////////////// LDS desc, window & register /////////////////
// AB LDS desc
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>(); constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>(); constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
constexpr index_t a_lds_block_space_size_aligned = constexpr index_t a_lds_block_space_size_aligned =
integer_least_multiple(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16); integer_least_multiple(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);
constexpr index_t b_lds_block_space_size_aligned = constexpr index_t b_lds_block_space_size_aligned =
integer_least_multiple(sizeof(BDataType) * b_lds_block_desc.get_element_space_size(), 16); integer_least_multiple(sizeof(BDataType) * b_lds_block_desc.get_element_space_size(), 16);
// A tile in LDS view
ADataType* p_a_lds0 = reinterpret_cast<ADataType*>(p_smem); ADataType* p_a_lds0 = reinterpret_cast<ADataType*>(p_smem);
ADataType* p_a_lds1 = reinterpret_cast<ADataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned); ADataType* p_a_lds1 = reinterpret_cast<ADataType*>(reinterpret_cast<char*>(p_a_lds0) + a_lds_block_space_size_aligned);
// B tile in LDS
BDataType* p_b_lds0 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2);
BDataType* p_b_lds1 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_b_lds0) + b_lds_block_space_size_aligned);
auto a_lds_block0 = make_tensor_view<address_space_enum::lds>(p_a_lds0, a_lds_block_desc); auto a_lds_block0 = make_tensor_view<address_space_enum::lds>(p_a_lds0, a_lds_block_desc);
auto b_lds_block0 = make_tensor_view<address_space_enum::lds>(p_b_lds0, b_lds_block_desc);
auto a_lds_block1 = make_tensor_view<address_space_enum::lds>(p_a_lds1, a_lds_block_desc); auto a_lds_block1 = make_tensor_view<address_space_enum::lds>(p_a_lds1, a_lds_block_desc);
// B tile in LDS view
BDataType* p_b_lds0 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_a_lds1) + a_lds_block_space_size_aligned);
BDataType* p_b_lds1 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_b_lds0) + b_lds_block_space_size_aligned);
auto b_lds_block0 = make_tensor_view<address_space_enum::lds>(p_b_lds0, b_lds_block_desc);
auto b_lds_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc); auto b_lds_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store // A LDS tile window for store
auto a_store_lds_window0 = make_tile_window( auto a_store_lds_window0 = make_tile_window(
a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}); a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_store_lds_window1 = make_tile_window( auto a_store_lds_window1 = make_tile_window(
a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0}); a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store // B LDS tile window for store
auto b_store_lds_window0 = make_tile_window( auto b_store_lds_window0 = make_tile_window(
b_lds_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0}); b_lds_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
...@@ -154,41 +186,46 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -154,41 +186,46 @@ struct GemmPipelineAGmemBGmemCRegV1
// Acc register tile // Acc register tile
auto c_block_tile = Policy::template BlockGemm<Problem>::MakeCBlockTile(); auto c_block_tile = Policy::template BlockGemm<Problem>::MakeCBlockTile();
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// a b register tile // a b register tile for lds prefetch & mfma
auto a_prefetch_tile0 = make_static_distributed_tensor<ADataType>(Policy::template BlockGemm<Problem>::MakeABlockDistribution()); auto a_block_tile0 = make_static_distributed_tensor<ADataType>(Policy::template BlockGemm<Problem>::MakeABlockDistribution());
auto a_prefetch_tile1 = make_static_distributed_tensor<ADataType>(Policy::template BlockGemm<Problem>::MakeABlockDistribution()); auto a_block_tile1 = make_static_distributed_tensor<ADataType>(Policy::template BlockGemm<Problem>::MakeABlockDistribution());
auto b_prefetch_tile0 = make_static_distributed_tensor<BDataType>(Policy::template BlockGemm<Problem>::MakeBBlockDistribution()); auto b_block_tile0 = make_static_distributed_tensor<BDataType>(Policy::template BlockGemm<Problem>::MakeBBlockDistribution());
auto b_prefetch_tile1 = make_static_distributed_tensor<BDataType>(Policy::template BlockGemm<Problem>::MakeBBlockDistribution()); auto b_block_tile1 = make_static_distributed_tensor<BDataType>(Policy::template BlockGemm<Problem>::MakeBBlockDistribution());
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile a_global_load_tile;
BBlockTile b_global_load_tile;
// prefetch
// global read 0
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0 // LDS write 0
LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func); LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func); LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func);
block_sync_lds();
// global read 1 // global read 1
GlobalPrefetch(a_global_load_tile, a_copy_dram_window); GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window); GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_sync_lds();
// local prefetch 0 // local prefetch 0
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_prefetch_tile0); Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_prefetch_tile0); Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_block_tile0);
// if (threadIdx.x == 0) {
// printf("aalds\n");
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// printf("bbbbblds\n");
// constexpr auto span_2d2 = decltype(b_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d2[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d2[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(b_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// }
// LDS write 1 // LDS write 1
LocalPrefill(a_store_lds_window1, a_global_load_tile, a_element_func); LocalPrefill(a_store_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window1, b_global_load_tile, b_element_func); LocalPrefill(b_store_lds_window1, b_global_load_tile, b_element_func);
...@@ -197,37 +234,31 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -197,37 +234,31 @@ struct GemmPipelineAGmemBGmemCRegV1
GlobalPrefetch(a_global_load_tile, a_copy_dram_window); GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window); GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
index_t iCounter = num_loop - 1; index_t iCounter = num_loop - 2;
while(iCounter > 2) while(iCounter > 1)
{ {
// ping // ping
{ {
block_sync_lds(); block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_prefetch_tile1); Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_prefetch_tile1);
LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func); LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func); LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window); GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window); GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_prefetch_tile0, b_prefetch_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
} }
__builtin_amdgcn_sched_barrier(0);
// pong // pong
{ {
block_sync_lds(); block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_prefetch_tile0); Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_prefetch_tile0); Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_block_tile0);
LocalPrefill(a_store_lds_window1, a_global_load_tile, a_element_func); LocalPrefill(a_store_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window1, b_global_load_tile, b_element_func); LocalPrefill(b_store_lds_window1, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window); GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window); GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_prefetch_tile1, b_prefetch_tile1); block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
} }
iCounter -= 2; iCounter -= 2;
} }
...@@ -236,38 +267,34 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -236,38 +267,34 @@ struct GemmPipelineAGmemBGmemCRegV1
// 3 // 3
{ {
block_sync_lds(); block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_prefetch_tile1); Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_prefetch_tile1);
LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func); LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func); LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func);
block_gemm(c_block_tile, a_prefetch_tile0, b_prefetch_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
__builtin_amdgcn_sched_barrier(0);
} }
// 2 // 2
{ {
block_sync_lds(); block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_prefetch_tile0); Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_prefetch_tile0); Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_block_tile0);
block_gemm(c_block_tile, a_prefetch_tile1, b_prefetch_tile1); block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
__builtin_amdgcn_sched_barrier(0);
} }
//1 //1
{ {
block_gemm(c_block_tile, a_prefetch_tile0, b_prefetch_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
} }
//tail 2 //tail 2
} else { } else {
{ {
block_sync_lds(); block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_prefetch_tile1); Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_prefetch_tile1); Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_block_tile1);
block_gemm(c_block_tile, a_prefetch_tile0, b_prefetch_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
__builtin_amdgcn_sched_barrier(0);
} }
// 2 // 2
{ {
block_gemm(c_block_tile, a_prefetch_tile1, b_prefetch_tile1); block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
} }
} }
return c_block_tile; return c_block_tile;
......
...@@ -97,16 +97,18 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -97,16 +97,18 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{ {
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * constexpr index_t smem_size_a = integer_least_multiple(sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size(); MakeALdsBlockDescriptor<Problem>().get_element_space_size(), 16)
* 2;
return smem_size_a; return smem_size_a;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{ {
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * constexpr index_t smem_size_b = integer_least_multiple(sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size(); MakeBLdsBlockDescriptor<Problem>().get_element_space_size(), 16)
* 2;
return smem_size_b; return smem_size_b;
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment