"vscode:/vscode.git/clone" did not exist on "fb30b7c7a2f3ac97aa8257f2b56e5e46e21fa8d3"
Commit 36da8a88 authored by Po Yen, Chen's avatar Po Yen, Chen
Browse files

Async copy K tile

parent 54d2e0a7
...@@ -60,8 +60,7 @@ struct BlockFmhaPipelineQRKSVS2Wave ...@@ -60,8 +60,7 @@ struct BlockFmhaPipelineQRKSVS2Wave
// ... together with tensor distribution. tensor dist should able to overwrite this // ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ = static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>(); kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() { static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>(); return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
...@@ -129,7 +128,7 @@ struct BlockFmhaPipelineQRKSVS2Wave ...@@ -129,7 +128,7 @@ struct BlockFmhaPipelineQRKSVS2Wave
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func, const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const KElementFunction& k_element_func, [[maybe_unused]] const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func, const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
...@@ -164,11 +163,19 @@ struct BlockFmhaPipelineQRKSVS2Wave ...@@ -164,11 +163,19 @@ struct BlockFmhaPipelineQRKSVS2Wave
"wrong!"); "wrong!");
// K tile in LDS // K tile in LDS
auto k_lds = make_tensor_view<address_space_enum::lds>( auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
reinterpret_cast<KDataType*>(smem_ptr), auto k_lds_store = make_tile_window(
Policy::template MakeKLdsStoreBlockDescriptor<Problem>()); make_tensor_view<address_space_enum::lds>(
auto k_lds_window_for_store = k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(number<0>{})),
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0BlockLength>{}), {0, 0}); Policy::template MakeKLdsStoreBlockDescriptor<Problem>(number<0>{}).get_lengths(),
{0, 0, 0});
auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
auto k_lds_load =
make_tile_window(k_lds_Load_view,
Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
{0, 0});
// V tile in LDS // V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>( auto v_lds = make_tensor_view<address_space_enum::lds>(
...@@ -279,15 +286,16 @@ struct BlockFmhaPipelineQRKSVS2Wave ...@@ -279,15 +286,16 @@ struct BlockFmhaPipelineQRKSVS2Wave
k_dram_block_window.get_window_origin(), k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load // load
k_dram_window.init_raw(); // this is necessary for async_load_tile_raw()
auto k_block_tile = load_tile(k_dram_window);
{ {
constexpr auto k_oob_ck = bool_constant<true>{};
constexpr auto k_pre_np = bool_constant<true>{};
async_load_tile_raw(k_lds_store, k_dram_window, k_oob_ck, k_pre_np);
clear_tile(s_acc); clear_tile(s_acc);
store_tile(k_lds_window_for_store, async_load_fence();
tile_elementwise_in(k_element_func, k_block_tile)); __builtin_amdgcn_sched_barrier(0);
} }
auto k_lds_window_for_load =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
...@@ -302,14 +310,15 @@ struct BlockFmhaPipelineQRKSVS2Wave ...@@ -302,14 +310,15 @@ struct BlockFmhaPipelineQRKSVS2Wave
} }
{ {
block_sync_lds(); __builtin_amdgcn_s_barrier();
static_for<0, k0_loops, 1>{}([&](auto i_k0) { static_for<0, k0_loops, 1>{}([&](auto i_k0) {
gemm_0(s_acc, gemm_0(s_acc,
get_slice_tile(q_tile, get_slice_tile(q_tile,
sequence<0, i_k0 * kK0>{}, sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}), sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window_for_load); get_slice_tile(k_lds_load,
move_tile_window(k_lds_window_for_load, {0, kK0}); sequence<0, i_k0 * kK0>{},
sequence<kN0, (i_k0 + 1) * kK0>{}));
}); });
} }
......
...@@ -107,7 +107,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy ...@@ -107,7 +107,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
// end copy from BlockFmhaPipelineQXCustomPolicy<true> // end copy from BlockFmhaPipelineQXCustomPolicy<true>
// start copy from BlockFmhaPipelineQXKSVSCustomPolicy // start copy from BlockFmhaPipelineQXKSVSCustomPolicy
static constexpr bool AsyncCopyK = false; static constexpr bool AsyncCopyK = true;
static constexpr bool AsyncCopyV = false; // TODO: this not supported yet static constexpr bool AsyncCopyV = false; // TODO: this not supported yet
static constexpr index_t NumPrefetchK = 1; static constexpr index_t NumPrefetchK = 1;
...@@ -255,7 +255,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy ...@@ -255,7 +255,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
else else
{ {
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile
// size
constexpr index_t kKPerBlock =
Problem::BlockFmhaShape::kK0BlockLength; // [POYENC] updated tile size
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t warpSize = ck_tile::get_warp_size(); constexpr index_t warpSize = ck_tile::get_warp_size();
...@@ -326,8 +329,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy ...@@ -326,8 +329,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
return q_block_dstr; return q_block_dstr;
} }
#if 0 // [POYENC] disabled since we are using #if 0
// MakeKLdsStoreBlockDescriptor/MakeVLdsStoreBlockDescriptor now
// TODO: this is used for non async copy desc. unify in the future // TODO: this is used for non async copy desc. unify in the future
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
...@@ -352,6 +354,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy ...@@ -352,6 +354,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
return k_lds_block_desc; return k_lds_block_desc;
} }
#endif
template <typename Problem, index_t IBuf = 0> template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto CK_TILE_HOST_DEVICE static constexpr auto
...@@ -359,7 +362,9 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy ...@@ -359,7 +362,9 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
{ {
// K is always k-major, we use async-copy to load into LDS // K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile size
constexpr index_t kKPerBlock =
Problem::BlockFmhaShape::kK0BlockLength; // [POYENC] updated tile size
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t warpSize = ck_tile::get_warp_size(); constexpr index_t warpSize = ck_tile::get_warp_size();
...@@ -407,61 +412,14 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy ...@@ -407,61 +412,14 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
return k_lds_block_desc_issues_warps_lanes; return k_lds_block_desc_issues_warps_lanes;
} }
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto
MakeKLdsLoadBlockDescriptor(number<IBuf> = number<0>{})
{
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad = KPack; // for async-copy, this pad is between warps
static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<NumWarps>{}, // n2
number<LaneGroups>{}, // n1
number<kKPerBlock / KPack>{}, // k0
number<KPack>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{},
number<warpSize * KVector + kPad>{},
number<kKPerBlock>{},
number<KPack>{},
number<1>{}),
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KPack>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(
make_merge_transform(
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
}
#else
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor()
{ {
// K is always k-major, we use async-copy to load into LDS // K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile size
constexpr index_t kKPerBlock =
Problem::BlockFmhaShape::kK0BlockLength; // [POYENC] updated tile size
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t warpSize = ck_tile::get_warp_size(); constexpr index_t warpSize = ck_tile::get_warp_size();
...@@ -510,8 +468,9 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy ...@@ -510,8 +468,9 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
return k_lds_block_desc; return k_lds_block_desc;
} }
#endif
#if 0 // [POYENC] disabled since we are using
// MakeKLdsStoreBlockDescriptor/MakeVLdsStoreBlockDescriptor now
// 3d + padding // 3d + padding
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
...@@ -640,7 +599,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy ...@@ -640,7 +599,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
else else
{ {
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile
// size
constexpr index_t kKPerBlock =
Problem::BlockFmhaShape::kK0BlockLength; // [POYENC] updated tile size
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps; constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t warpSize = ck_tile::get_warp_size(); constexpr index_t warpSize = ck_tile::get_warp_size();
...@@ -868,6 +830,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy ...@@ -868,6 +830,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
} }
// end copy from BlockFmhaPipelineQXKSVSCustomPolicy // end copy from BlockFmhaPipelineQXKSVSCustomPolicy
#if 0
// TODO: this is used for non async copy desc. unify in the future // TODO: this is used for non async copy desc. unify in the future
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsStoreBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsStoreBlockDescriptor()
...@@ -893,6 +856,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy ...@@ -893,6 +856,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
return k_lds_block_desc; return k_lds_block_desc;
} }
#endif
// 3d + padding // 3d + padding
template <typename Problem> template <typename Problem>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment