Commit 36da8a88 authored by Po Yen, Chen
Browse files

Async copy K tile

parent 54d2e0a7
......@@ -60,8 +60,7 @@ struct BlockFmhaPipelineQRKSVS2Wave
// ... together with tensor distribution. tensor dist should be able to overwrite this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
......@@ -129,7 +128,7 @@ struct BlockFmhaPipelineQRKSVS2Wave
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const KElementFunction& k_element_func,
[[maybe_unused]] const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
......@@ -164,11 +163,19 @@ struct BlockFmhaPipelineQRKSVS2Wave
"wrong!");
// K tile in LDS
auto k_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsStoreBlockDescriptor<Problem>());
auto k_lds_window_for_store =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0BlockLength>{}), {0, 0});
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
auto k_lds_store = make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(number<0>{})),
Policy::template MakeKLdsStoreBlockDescriptor<Problem>(number<0>{}).get_lengths(),
{0, 0, 0});
auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
auto k_lds_load =
make_tile_window(k_lds_Load_view,
Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
{0, 0});
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
......@@ -279,15 +286,16 @@ struct BlockFmhaPipelineQRKSVS2Wave
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
auto k_block_tile = load_tile(k_dram_window);
k_dram_window.init_raw(); // this is necessary for async_load_tile_raw()
{
constexpr auto k_oob_ck = bool_constant<true>{};
constexpr auto k_pre_np = bool_constant<true>{};
async_load_tile_raw(k_lds_store, k_dram_window, k_oob_ck, k_pre_np);
clear_tile(s_acc);
store_tile(k_lds_window_for_store,
tile_elementwise_in(k_element_func, k_block_tile));
async_load_fence();
__builtin_amdgcn_sched_barrier(0);
}
auto k_lds_window_for_load =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
......@@ -302,14 +310,15 @@ struct BlockFmhaPipelineQRKSVS2Wave
}
{
block_sync_lds();
__builtin_amdgcn_s_barrier();
static_for<0, k0_loops, 1>{}([&](auto i_k0) {
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window_for_load);
move_tile_window(k_lds_window_for_load, {0, kK0});
get_slice_tile(k_lds_load,
sequence<0, i_k0 * kK0>{},
sequence<kN0, (i_k0 + 1) * kK0>{}));
});
}
......
......@@ -107,7 +107,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
// end copy from BlockFmhaPipelineQXCustomPolicy<true>
// start copy from BlockFmhaPipelineQXKSVSCustomPolicy
static constexpr bool AsyncCopyK = false;
static constexpr bool AsyncCopyK = true;
static constexpr bool AsyncCopyV = false; // TODO: this not supported yet
static constexpr index_t NumPrefetchK = 1;
......@@ -255,7 +255,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
else
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile
// size
constexpr index_t kKPerBlock =
Problem::BlockFmhaShape::kK0BlockLength; // [POYENC] updated tile size
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t warpSize = ck_tile::get_warp_size();
......@@ -326,8 +329,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
return q_block_dstr;
}
#if 0 // [POYENC] disabled since we are using
// MakeKLdsStoreBlockDescriptor/MakeVLdsStoreBlockDescriptor now
#if 0
// TODO: this is used for non async copy desc. unify in the future
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
......@@ -352,6 +354,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
return k_lds_block_desc;
}
#endif
template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto
......@@ -359,7 +362,9 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
{
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile size
constexpr index_t kKPerBlock =
Problem::BlockFmhaShape::kK0BlockLength; // [POYENC] updated tile size
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t warpSize = ck_tile::get_warp_size();
......@@ -407,61 +412,14 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
return k_lds_block_desc_issues_warps_lanes;
}
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
// Builds the K-tile LDS descriptor variant that carries an explicit element
// offset for buffer slot IBuf (offset = IBuf * GetSingleSmemElementSpaceSize<Problem>()).
//
// The descriptor is first described as a 5-D strided layout
//   (NumIssues, NumWarps, LaneGroups, kKPerBlock / KPack, KPack)
// and then folded into a 2-D view:
//   dim 0 = merge(NumIssues, LaneGroups, NumWarps)
//   dim 1 = merge(kKPerBlock / KPack, KPack)
//
// Problem : supplies BlockFmhaShape (kN0, kK1, NumWarps) and kBlockSize.
// IBuf    : compile-time buffer index, passed as number<IBuf>; defaults to 0.
// Returns : the transformed 2-D tensor descriptor (constexpr).
template <typename Problem, index_t IBuf = 0>
CK_TILE_HOST_DEVICE static constexpr auto
MakeKLdsLoadBlockDescriptor(number<IBuf> = number<0>{})
{
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad = KPack; // for async-copy, this pad is between warps
// One issue of a full wave (warpSize lanes * KVector elements each) must cover
// a whole number of K rows, otherwise LanesPerK / LaneGroups below do not divide evenly.
static_assert(warpSize * KVector >= kKPerBlock && warpSize * KVector % kKPerBlock == 0);
constexpr index_t LanesPerK = kKPerBlock / KVector; // within a wave
constexpr index_t LaneGroups = warpSize / LanesPerK; // within a wave
// Number of N-rows each lane group has to cover, i.e. how many times the
// (LaneGroups * NumWarps)-row pattern repeats along N.
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
// Cross-check: same count derived from total tile elements divided by the
// elements moved per block-wide issue (kBlockSize lanes * KVector).
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
// 5-D strided layout. Each warp owns a contiguous chunk of
// warpSize * KVector elements followed by kPad elements of padding; each
// issue spans NumWarps such padded chunks.
// NOTE(review): the trailing number<KPack>{} / number<1>{} arguments are
// presumably the guaranteed vector length and vector stride - confirm against
// make_naive_tensor_descriptor_with_offset.
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<NumWarps>{}, // n2
number<LaneGroups>{}, // n1
number<kKPerBlock / KPack>{}, // k0
number<KPack>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + kPad)>{},
number<warpSize * KVector + kPad>{},
number<kKPerBlock>{},
number<KPack>{},
number<1>{}),
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KPack>{},
number<1>{});
// Fold to 2-D. Note the merge order for dim 0 is (n0, n1, n2) = source dims
// (0, 2, 1), which differs from the storage order (n0, n2, n1) declared above.
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(
make_merge_transform(
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<0, 2, 1>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
}
#else
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsLoadBlockDescriptor()
{
// K is always k-major, we use async-copy to load into LDS
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile size
constexpr index_t kKPerBlock =
Problem::BlockFmhaShape::kK0BlockLength; // [POYENC] updated tile size
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t warpSize = ck_tile::get_warp_size();
......@@ -510,8 +468,9 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
return k_lds_block_desc;
}
#endif
#if 0 // [POYENC] disabled since we are using
// MakeKLdsStoreBlockDescriptor/MakeVLdsStoreBlockDescriptor now
// 3d + padding
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
......@@ -640,7 +599,10 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
else
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
// constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // [POYENC] old tile
// size
constexpr index_t kKPerBlock =
Problem::BlockFmhaShape::kK0BlockLength; // [POYENC] updated tile size
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t warpSize = ck_tile::get_warp_size();
......@@ -868,6 +830,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
}
// end copy from BlockFmhaPipelineQXKSVSCustomPolicy
#if 0
// TODO: this is used for non async copy desc. unify in the future
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsStoreBlockDescriptor()
......@@ -893,6 +856,7 @@ struct BlockFmhaPipelineQRKSVS2WaveDefaultPolicy
return k_lds_block_desc;
}
#endif
// 3d + padding
template <typename Problem>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment