Commit 405c05c0 authored by dummycoderfe's avatar dummycoderfe
Browse files

add prefetch and fix output err

parent 6c270303
......@@ -205,8 +205,8 @@ struct BlockGemmARegBRegCRegV2
}
// Prefetch lds
template <typename BlockWindowTmp, typename BlockTensor>
CK_TILE_DEVICE static auto PrefetchLds(const BlockWindowTmp& block_window, BlockTensor& block_tensor)
template <typename BlockWindow, typename BlockTensor>
CK_TILE_DEVICE static auto PrefetchLds(const BlockWindow& block_window, BlockTensor& block_tensor)
{
auto tileDist = BlockTensor::get_tile_distribution();//.get_static_tile_distribution_encoding()
return load_tile(block_tensor, make_tile_window(block_window, tileDist));
......
......@@ -37,17 +37,17 @@ struct GemmPipelineAGmemBGmemCRegV1
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{
return integer_least_multiple(
sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) * 2 +
integer_least_multiple(
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size(),
16) * 2;
}
// CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
// {
// return integer_least_multiple(
// sizeof(ADataType) *
// Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
// 16) * 2 +
// integer_least_multiple(
// sizeof(BDataType) *
// Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size(),
// 16) * 2;
// }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
......@@ -91,46 +91,78 @@ struct GemmPipelineAGmemBGmemCRegV1
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
////////////// global window & register /////////////////
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// A register tile for global load
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile a_global_load_tile;
BBlockTile b_global_load_tile;
// A tile in LDS
// global prefetch 0
// global read 0
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
// if (threadIdx.x == 0) {
// constexpr auto span_2d = decltype(a_global_load_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_global_load_tile(i_j_idx)));
// });
// printf("\n");
// });
// printf("bbbbb\n");
// constexpr auto span_2d2 = decltype(b_global_load_tile)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(b_global_load_tile(i_j_idx)));
// });
// printf("\n");
// });
// }
////////////// LDS desc, window & register /////////////////
// AB LDS desc
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
constexpr index_t a_lds_block_space_size_aligned =
integer_least_multiple(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);
constexpr index_t b_lds_block_space_size_aligned =
integer_least_multiple(sizeof(BDataType) * b_lds_block_desc.get_element_space_size(), 16);
// A tile in LDS view
ADataType* p_a_lds0 = reinterpret_cast<ADataType*>(p_smem);
ADataType* p_a_lds1 = reinterpret_cast<ADataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned);
// B tile in LDS
BDataType* p_b_lds0 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_smem) + a_lds_block_space_size_aligned * 2);
BDataType* p_b_lds1 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_b_lds0) + b_lds_block_space_size_aligned);
ADataType* p_a_lds1 = reinterpret_cast<ADataType*>(reinterpret_cast<char*>(p_a_lds0) + a_lds_block_space_size_aligned);
auto a_lds_block0 = make_tensor_view<address_space_enum::lds>(p_a_lds0, a_lds_block_desc);
auto b_lds_block0 = make_tensor_view<address_space_enum::lds>(p_b_lds0, b_lds_block_desc);
auto a_lds_block1 = make_tensor_view<address_space_enum::lds>(p_a_lds1, a_lds_block_desc);
// B tile in LDS view
BDataType* p_b_lds0 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_a_lds1) + a_lds_block_space_size_aligned);
BDataType* p_b_lds1 = reinterpret_cast<BDataType*>(reinterpret_cast<char*>(p_b_lds0) + b_lds_block_space_size_aligned);
auto b_lds_block0 = make_tensor_view<address_space_enum::lds>(p_b_lds0, b_lds_block_desc);
auto b_lds_block1 = make_tensor_view<address_space_enum::lds>(p_b_lds1, b_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_store_lds_window0 = make_tile_window(
a_lds_block0, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_store_lds_window1 = make_tile_window(
a_lds_block1, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
auto b_store_lds_window0 = make_tile_window(
b_lds_block0, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
......@@ -154,41 +186,46 @@ struct GemmPipelineAGmemBGmemCRegV1
// Acc register tile
auto c_block_tile = Policy::template BlockGemm<Problem>::MakeCBlockTile();
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// a b register tile
auto a_prefetch_tile0 = make_static_distributed_tensor<ADataType>(Policy::template BlockGemm<Problem>::MakeABlockDistribution());
auto a_prefetch_tile1 = make_static_distributed_tensor<ADataType>(Policy::template BlockGemm<Problem>::MakeABlockDistribution());
auto b_prefetch_tile0 = make_static_distributed_tensor<BDataType>(Policy::template BlockGemm<Problem>::MakeBBlockDistribution());
auto b_prefetch_tile1 = make_static_distributed_tensor<BDataType>(Policy::template BlockGemm<Problem>::MakeBBlockDistribution());
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
// a b register tile for lds prefetch & mfma
auto a_block_tile0 = make_static_distributed_tensor<ADataType>(Policy::template BlockGemm<Problem>::MakeABlockDistribution());
auto a_block_tile1 = make_static_distributed_tensor<ADataType>(Policy::template BlockGemm<Problem>::MakeABlockDistribution());
auto b_block_tile0 = make_static_distributed_tensor<BDataType>(Policy::template BlockGemm<Problem>::MakeBBlockDistribution());
auto b_block_tile1 = make_static_distributed_tensor<BDataType>(Policy::template BlockGemm<Problem>::MakeBBlockDistribution());
ABlockTile a_global_load_tile;
BBlockTile b_global_load_tile;
// prefetch
// global read 0
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func);
block_sync_lds();
// global read 1
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_sync_lds();
// local prefetch 0
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_prefetch_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_prefetch_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_block_tile0);
// if (threadIdx.x == 0) {
// printf("aalds\n");
// constexpr auto span_2d = decltype(a_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(a_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// printf("bbbbblds\n");
// constexpr auto span_2d2 = decltype(b_block_tile0)::get_distributed_spans();
// sweep_tile_span(span_2d2[number<0>{}], [&](auto idx0) {
// sweep_tile_span(span_2d2[number<1>{}], [&](auto idx1) {
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
// printf("%f,", type_convert<float>(b_block_tile0(i_j_idx)));
// });
// printf("\n");
// });
// }
// LDS write 1
LocalPrefill(a_store_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window1, b_global_load_tile, b_element_func);
......@@ -197,37 +234,31 @@ struct GemmPipelineAGmemBGmemCRegV1
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
index_t iCounter = num_loop - 1;
while(iCounter > 2)
index_t iCounter = num_loop - 2;
while(iCounter > 1)
{
// ping
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_prefetch_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_prefetch_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_block_tile1);
LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_prefetch_tile0, b_prefetch_tile0);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
__builtin_amdgcn_sched_barrier(0);
// pong
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_prefetch_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_prefetch_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_block_tile0);
LocalPrefill(a_store_lds_window1, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window1, b_global_load_tile, b_element_func);
GlobalPrefetch(a_global_load_tile, a_copy_dram_window);
GlobalPrefetch(b_global_load_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_prefetch_tile1, b_prefetch_tile1);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
iCounter -= 2;
}
......@@ -236,38 +267,34 @@ struct GemmPipelineAGmemBGmemCRegV1
// 3
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_prefetch_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_prefetch_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_block_tile1);
LocalPrefill(a_store_lds_window0, a_global_load_tile, a_element_func);
LocalPrefill(b_store_lds_window0, b_global_load_tile, b_element_func);
block_gemm(c_block_tile, a_prefetch_tile0, b_prefetch_tile0);
__builtin_amdgcn_sched_barrier(0);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
// 2
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_prefetch_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_prefetch_tile0);
block_gemm(c_block_tile, a_prefetch_tile1, b_prefetch_tile1);
__builtin_amdgcn_sched_barrier(0);
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window0, a_block_tile0);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window0, b_block_tile0);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
//1
{
block_gemm(c_block_tile, a_prefetch_tile0, b_prefetch_tile0);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
//tail 2
} else {
{
block_sync_lds();
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_prefetch_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_prefetch_tile1);
block_gemm(c_block_tile, a_prefetch_tile0, b_prefetch_tile0);
__builtin_amdgcn_sched_barrier(0);
Policy::template BlockGemm<Problem>::PrefetchLds(a_load_lds_window1, a_block_tile1);
Policy::template BlockGemm<Problem>::PrefetchLds(b_load_lds_window1, b_block_tile1);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
}
// 2
{
block_gemm(c_block_tile, a_prefetch_tile1, b_prefetch_tile1);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
}
}
return c_block_tile;
......
......@@ -97,16 +97,18 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
constexpr index_t smem_size_a = integer_least_multiple(sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size(), 16)
* 2;
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
constexpr index_t smem_size_b = integer_least_multiple(sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size(), 16)
* 2;
return smem_size_b;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment