Unverified commit 2a30cfdd, authored by arai713 and committed by GitHub.

Merge branch 'develop' into codegen-enable-hiprtc

parents 9533a172 78195ccc
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename BlockFmhaShape_>
struct FmhaFwdTilePartitioner
{
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
static constexpr const char* name = "shb";
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
ck_tile::integer_divide_ceil(hdim_v_, kN1),
nhead_,
batch_size_);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
};
template <typename BlockFmhaShape_>
using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner<BlockFmhaShape_>;
template <typename BlockFmhaShape_>
struct FmhaFwdTilePartitioner_HBS
{
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
static constexpr const char* name = "hbs";
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
// TODO: this may need tuning
return dim3(nhead_,
batch_size_,
ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
ck_tile::integer_divide_ceil(hdim_v_, kN1));
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
const index_t i_block = blockIdx.z;
const index_t i_nhead = blockIdx.x;
const index_t i_batch = blockIdx.y;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
};
} // namespace ck_tile
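// --------------------------------------------------------------------------
// A minimal, self-contained host-side sketch of the "shb" partitioner's
// block-to-tile mapping above. The tile shape and problem sizes below are
// illustrative assumptions, not part of the commit.
#include <cstdio>

static int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int kM0 = 128, kN1 = 64;          // hypothetical seqlen_q / hdim_v tile sizes
    const int seqlen_q = 300, hdim_v = 128; // hypothetical problem sizes

    // GridSize() packs both tile counts into gridDim.x:
    const int num_tile_m0 = integer_divide_ceil(seqlen_q, kM0); // 3
    const int num_tile_n1 = integer_divide_ceil(hdim_v, kN1);   // 2
    const int grid_x      = num_tile_m0 * num_tile_n1;          // 6 blocks along x

    // operator() unpacks blockIdx.x back into (i_tile_m, i_tile_n) via the
    // quotient/modulus lambda:
    for(int i_block = 0; i_block < grid_x; ++i_block)
    {
        const int i_tile_m = i_block / num_tile_n1;
        const int i_tile_n = i_block - i_tile_m * num_tile_n1;
        std::printf("block %d -> tile (m=%d, n=%d)\n", i_block, i_tile_m, i_tile_n);
    }
    return 0;
}
// --------------------------------------------------------------------------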
...
@@ -53,6 +53,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
     using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
     using ODataType    = remove_cvref_t<typename Problem::ODataType>;
 
+    static constexpr index_t kNumWarps  = Problem::kNumWarps;
     static constexpr index_t kBlockSize = Problem::kBlockSize;
     static constexpr index_t kHeadDimV  = Problem::kHeadDimV;
@@ -117,7 +118,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
                const LSEElementFunction& lse_element_func,
                const OaccElementFunction& o_acc_element_func,
                index_t num_splits,
-               index_t seqlen_q,
                void* smem_ptr) const
     {
         // lse_acc tile in LDS
@@ -143,11 +143,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline
         // copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]).
         auto lse_acc_tile = load_tile(lse_acc_dram_window);
         store_tile(lse_acc_lds_write_window, lse_acc_tile);
-        block_sync_lds();
 
         auto lse_accum = make_static_distributed_tensor<LSEDataType>(
             Policy::template MakeLSEaccRegTileDistribution<Problem>());
 
+        __builtin_amdgcn_sched_barrier(0);
+        block_sync_lds();
         // copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
         // and fill up -INF values outside the [kM0, num_splits] region.
         {
@@ -264,46 +265,94 @@ struct BlockFmhaFwdSplitKVCombinePipeline
                 }
             });
         }
-        block_sync_lds();
 
         if constexpr(kStoreLSE)
         {
             store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
         }
 
-        auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
-        auto o_acc_dram_window =
+        auto o_acc_4_dist = Policy::template MakeOacc4DramTileDistribution<Problem>();
+        auto o_acc_4_dram_window =
             make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(),
                              o_acc_dram_block_window_tmp.get_window_lengths(),
                              o_acc_dram_block_window_tmp.get_window_origin(),
-                             o_acc_dist);
-
-        auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
-        clear_tile(o_acc);
-        const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0;
-
-        for(index_t i_split = 0; i_split < num_splits; ++i_split)
+                             o_acc_4_dist);
+
+        // shape=[4 * kM0, kN1]
+        auto o_acc_4 = make_static_distributed_tensor<OaccDataType>(o_acc_4_dist);
+        clear_tile(o_acc_4);
+
+        const index_t padded_num_splits = integer_divide_ceil(num_splits, kNumWarps) * kNumWarps;
+        __builtin_amdgcn_sched_barrier(0);
+        block_sync_lds();
+        // each warp handles a [kM0, kN1] tile
+        for(index_t split_start = 0; split_start < padded_num_splits; split_start += kNumWarps)
         {
-            auto o_tile = load_tile(o_acc_dram_window);
+            auto o_tile = load_tile(o_acc_4_dram_window);
+            const index_t i_split   = split_start + get_warp_id();
+            const index_t row_start = kM0 * get_warp_id();
             {
-                constexpr auto spans = decltype(o_acc)::get_distributed_spans();
+                constexpr auto spans = decltype(o_acc_4)::get_distributed_spans();
                 sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
                     sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
                         constexpr auto i_j_idx = make_tuple(idx0, idx1);
                         const auto x_indices = get_x_indices_from_distributed_indices(
-                            o_acc.get_tile_distribution(), i_j_idx);
+                            o_acc_4.get_tile_distribution(), i_j_idx);
                         const auto row = x_indices.at(number<0>{});
-                        const LSEDataType lse_scale = lse_acc_lds(row, i_split);
-                        o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx);
+                        const LSEDataType lse_scale = lse_acc_lds(row - row_start, i_split);
+                        o_acc_4(i_j_idx) += lse_scale * o_tile(i_j_idx);
                     });
                 });
             }
-            move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0});
+            move_tile_window(o_acc_4_dram_window, {kNumWarps * kM0, 0});
         }
+
+        // 4 o_acc tiles in LDS. shape=[4 * kM0, kN1]
+        OaccDataType* o_acc_4_lds_ptr = static_cast<OaccDataType*>(static_cast<void*>(
+            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeLSEacc<Problem>()));
+        {
+            auto o_acc_4_lds_window = [&]() {
+                auto desc = Policy::template MakeOacc4LdsBlockDescriptor<Problem>();
+                auto view = make_tensor_view<address_space_enum::lds>(o_acc_4_lds_ptr, desc);
+                return make_tile_window(view, desc.get_lengths(), {0, 0});
+            }();
+            store_tile(o_acc_4_lds_window, o_acc_4);
+        }
+
+        auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
+        auto o_acc_4_lds_window = [&]() {
+            auto desc = Policy::template MakeOacc4LdsBlockDescriptor<Problem>();
+            auto view = make_tensor_view<address_space_enum::lds>(o_acc_4_lds_ptr, desc);
+            return make_tile_window(view, desc.get_lengths(), {0, 0}, o_acc_dist);
+        }();
+
+        auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
+        clear_tile(o_acc);
+
+        __builtin_amdgcn_sched_barrier(0);
+        block_sync_lds();
+        static_for<0, kNumWarps, 1>{}([&](auto) {
+            auto o_acc_in = load_tile(o_acc_4_lds_window);
+            {
+                constexpr auto spans = decltype(o_acc)::get_distributed_spans();
+                sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
+                    sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
+                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
+                        o_acc(i_j_idx) += o_acc_in(i_j_idx);
+                    });
+                });
+            }
+            move_tile_window(o_acc_4_lds_window, {kM0, 0});
+        });
 
         o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
         return o_acc;
@@ -316,7 +365,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
                const OaccDramBlockWindow& o_acc_dram_block_window,
                LSEDramBlockWindow& lse_dram_block_window,
                index_t num_splits,
-               index_t seqlen_q,
                void* smem_ptr) const
     {
         return operator()(lse_acc_dram_block_window,
@@ -325,7 +373,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
                           identity{},
                           identity{},
                           num_splits,
-                          seqlen_q,
                           smem_ptr);
     }
 };
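// --------------------------------------------------------------------------
// A scalar sketch of the new two-stage combine introduced in the diff above:
// each warp first accumulates its own subset of splits into a private partial
// tile, then the per-warp partials (staged in LDS by the real kernel) are
// summed into the final o_acc. All names and values here are illustrative.
#include <cstdio>
#include <vector>

int main()
{
    const int kNumWarps  = 4;
    const int num_splits = 6;
    // round num_splits up to a multiple of kNumWarps, as the kernel does
    const int padded_num_splits = (num_splits + kNumWarps - 1) / kNumWarps * kNumWarps;

    // stand-in for one element of each split's o_acc tile, pre-scaled by lse_scale
    std::vector<float> split_value(padded_num_splits, 0.0f);
    for(int s = 0; s < num_splits; ++s)
        split_value[s] = 1.0f; // padding splits (s >= num_splits) contribute 0

    // stage 1: warp w accumulates splits w, w + kNumWarps, ...
    float partial[kNumWarps] = {};
    for(int split_start = 0; split_start < padded_num_splits; split_start += kNumWarps)
        for(int w = 0; w < kNumWarps; ++w)
            partial[w] += split_value[split_start + w];

    // stage 2: reduce the per-warp partials (done through LDS in the kernel)
    float o = 0.0f;
    for(int w = 0; w < kNumWarps; ++w)
        o += partial[w];

    std::printf("combined = %g (expected %d)\n", o, num_splits);
    return 0;
}
// --------------------------------------------------------------------------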
...
@@ -10,23 +10,38 @@ namespace ck_tile {
 struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
 {
-    template <index_t BlockSize, index_t M, index_t N, typename DataType>
+    template <index_t NumWarps, index_t M, index_t N, typename DataType>
+    CK_TILE_HOST_DEVICE static constexpr auto GetMaxNumWarpsForTile()
+    {
+        static_assert(NumWarps == 1 || NumWarps == 2 || NumWarps == 4);
+
+        constexpr index_t ElemPerThread = (M * N) / (NumWarps * get_warp_size());
+        if constexpr(0 < ElemPerThread)
+        {
+            return NumWarps;
+        }
+        else
+        { // try dividing the tile by a smaller number of warps
+            return GetMaxNumWarpsForTile<NumWarps / 2, M, N, DataType>();
+        }
+    }
+
+    template <index_t NumWarps, index_t M, index_t N, typename DataType>
     CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeForTile()
     {
-        constexpr index_t PixelsPerThread = (M * N) / BlockSize;
-        static_assert(0 < PixelsPerThread);
-
-        constexpr index_t MaxNPerThread = 16 / sizeof(DataType);
-        constexpr index_t NPerThread    = min(MaxNPerThread, PixelsPerThread);
-
-        return NPerThread;
+        constexpr index_t MaxNumWarps = GetMaxNumWarpsForTile<NumWarps, M, N, DataType>();
+        constexpr index_t ElemPerThread = (M * N) / (MaxNumWarps * get_warp_size());
+
+        constexpr index_t MaxNPerThread = 16 / sizeof(DataType);
+        return min(MaxNPerThread, ElemPerThread);
     }
 
     // alignment for dram lse tile (shape=[kMaxSplits, kM0])
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE()
     {
-        return GetVectorSizeForTile<Problem::kBlockSize,
+        return GetVectorSizeForTile<Problem::kNumWarps,
                                     Problem::kMaxSplits,
                                     Problem::kM0,
                                     typename Problem::LSEDataType>();
@@ -56,40 +71,54 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
     }
 
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeLSEacc()
     {
         return sizeof(typename Problem::LSEDataType) *
                MakeLSEaccLdsBlockDescriptor<Problem>().get_element_space_size();
     }
 
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOacc4()
+    {
+        return sizeof(typename Problem::OaccDataType) *
+               MakeOacc4LdsBlockDescriptor<Problem>().get_element_space_size();
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    {
+        return GetSmemSizeLSEacc<Problem>() + GetSmemSizeOacc4<Problem>();
+    }
+
     // shape=[kMaxSplits, kM0]
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution()
     {
         using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
 
-        constexpr index_t kBlockSize = Problem::kBlockSize;
-        constexpr index_t kNumWarps  = Problem::kNumWarps;
-        constexpr index_t kNPerBlock = Problem::kM0;
         constexpr index_t kMPerBlock = Problem::kMaxSplits;
+        constexpr index_t kNPerBlock = Problem::kM0;
+
+        constexpr index_t MaxNumWarps =
+            GetMaxNumWarpsForTile<Problem::kNumWarps, kNPerBlock, kMPerBlock, LSEDataType>();
+        constexpr index_t Replicate = Problem::kNumWarps / MaxNumWarps;
 
         constexpr index_t NPerThread =
-            GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
+            GetVectorSizeForTile<MaxNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
         constexpr index_t NThreads = kNPerBlock / NPerThread;
 
         constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads;
-        constexpr index_t MPerThread      = kMPerBlock / (kNumWarps * MThreadsPerWarp);
+        constexpr index_t MPerThread      = kMPerBlock / (MaxNumWarps * MThreadsPerWarp);
 
+        static_assert(MPerThread * MaxNumWarps * MThreadsPerWarp == kMPerBlock);
         static_assert(NThreads * NPerThread == kNPerBlock);
-        static_assert(MPerThread * kNumWarps * MThreadsPerWarp == kMPerBlock);
 
         return make_static_tile_distribution(
-            tile_distribution_encoding<sequence<1>,
-                                       tuple<sequence<MPerThread, kNumWarps, MThreadsPerWarp>,
+            tile_distribution_encoding<sequence<Replicate>,
+                                       tuple<sequence<MPerThread, MaxNumWarps, MThreadsPerWarp>,
                                              sequence<NThreads, NPerThread>>,
-                                       tuple<sequence<1>, sequence<1, 2>>,
-                                       tuple<sequence<1>, sequence<2, 0>>,
+                                       tuple<sequence<0, 1>, sequence<1, 2>>,
+                                       tuple<sequence<0, 1>, sequence<2, 0>>,
                                        sequence<1, 2>,
                                        sequence<0, 1>>{});
     }
@@ -100,17 +129,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
     {
         using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
 
-        constexpr index_t kBlockSize = Problem::kBlockSize;
-        constexpr index_t kMPerBlock = Problem::kMaxSplits;
-        constexpr index_t kNPerBlock = Problem::kM0;
+        constexpr index_t kMPerBlock = Problem::kM0;
+        constexpr index_t kNPerBlock = Problem::kMaxSplits;
 
         constexpr index_t NPack =
-            GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
+            GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
 
         constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
             make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
-            number<8>{},
+            number<NPack>{},
             number<1>{});
 
         constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor(
@@ -129,17 +156,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
     {
         using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
 
-        constexpr index_t kBlockSize = Problem::kBlockSize;
-        constexpr index_t kMPerBlock = Problem::kMaxSplits;
-        constexpr index_t kNPerBlock = Problem::kM0;
+        constexpr index_t kMPerBlock = Problem::kM0;
+        constexpr index_t kNPerBlock = Problem::kMaxSplits;
 
         constexpr index_t NPack =
-            GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
+            GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
 
         constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
             make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
-            number<8>{},
+            number<NPack>{},
             number<1>{});
 
         constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor(
@@ -152,33 +177,86 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
         return lse_acc_t_lds_block_desc;
     }
 
+    // 3d + padding, shape=[4 * kM0, kN1]
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution()
+    CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4LdsBlockDescriptor()
     {
-        constexpr index_t kBlockSize = Problem::kBlockSize;
-        constexpr index_t kNPerBlock = Problem::kMaxSplits;
+        using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
+
+        constexpr index_t kMPerBlock = 4 * Problem::kM0;
+        constexpr index_t kNPerBlock = Problem::kN1;
+
+        constexpr index_t NPack =
+            GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
+
+        constexpr auto o_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
+            make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
+            make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
+            number<8>{},
+            number<1>{});
+
+        constexpr auto o_acc_t_lds_block_desc = transform_tensor_descriptor(
+            o_acc_lds_block_desc_0,
+            make_tuple(make_pass_through_transform(kMPerBlock),
+                       make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
+            make_tuple(sequence<1>{}, sequence<0, 2>{}),
+            make_tuple(sequence<1>{}, sequence<0>{}));
+
+        return o_acc_t_lds_block_desc;
+    }
+
+    // shape=[kM0, kMaxSplits]
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution()
+    {
         constexpr index_t kMPerBlock = Problem::kM0;
+        constexpr index_t kNPerBlock = Problem::kMaxSplits;
 
-        constexpr index_t NThreads   = 4;
-        constexpr index_t NPerThread = kNPerBlock / NThreads;
+        constexpr index_t MaxNThreads = 8;
+        constexpr index_t NThreads    = min(kNPerBlock, MaxNThreads);
+        constexpr index_t NPerThread  = kNPerBlock / NThreads;
 
-        constexpr index_t MThreads   = kBlockSize / NThreads;
-        constexpr index_t MPerThread = kMPerBlock / MThreads;
-        constexpr index_t MWarps     = kBlockSize / get_warp_size();
+        constexpr index_t MPerThread = 1;
+        constexpr index_t MThreads   = kMPerBlock / MPerThread;
 
         constexpr index_t MThreadPerWarp = get_warp_size() / NThreads;
+        constexpr index_t MaxNumWarps    = (MThreads * NThreads) / get_warp_size();
+        constexpr index_t Replicate      = Problem::kNumWarps / MaxNumWarps;
 
+        static_assert(MaxNumWarps * MThreadPerWarp * MPerThread == kMPerBlock);
         static_assert(NThreads * NPerThread == kNPerBlock);
-        static_assert(MWarps * MThreadPerWarp * MPerThread == kMPerBlock);
 
         return make_static_tile_distribution(
-            tile_distribution_encoding<
-                sequence<1>,
-                tuple<sequence<MWarps, MThreadPerWarp, MPerThread>, sequence<NThreads, NPerThread>>,
-                tuple<sequence<1>, sequence<2, 1>>,
-                tuple<sequence<0>, sequence<0, 1>>,
-                sequence<1, 2>,
-                sequence<2, 1>>{});
+            tile_distribution_encoding<sequence<Replicate>,
                                       tuple<sequence<MaxNumWarps, MThreadPerWarp, MPerThread>,
                                             sequence<NThreads, NPerThread>>,
                                       tuple<sequence<0, 1>, sequence<2, 1>>,
                                       tuple<sequence<0, 0>, sequence<0, 1>>,
                                       sequence<1, 2>,
                                       sequence<2, 1>>{});
+    }
+
+    // similar to MakeOaccDramTileDistribution(), but duplicates the same 1-warp
+    // encoding 4 times along the M direction
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4DramTileDistribution()
+    {
+        constexpr index_t kMPerBlock = Problem::kM0; // the real kMPerBlock we want is (4 * kM0)
+        constexpr index_t kNPerBlock = Problem::kN1;
+        static_assert(get_warp_size() <= kMPerBlock * kNPerBlock);
+
+        constexpr index_t M1 = 1; // compose the encoding based on 1 warp
+        constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
+        constexpr index_t N0 = get_warp_size() / M2;
+        constexpr index_t N1 = kNPerBlock / N0;
+        constexpr index_t M0 = kMPerBlock / (M2 * M1);
+
+        return make_static_tile_distribution(
+            tile_distribution_encoding<sequence<1>,
+                                       tuple<sequence<4, M0, M1, M2>, sequence<N0, N1>>,
+                                       tuple<sequence<1, 1>, sequence<1, 2>>,
+                                       tuple<sequence<0, 2>, sequence<3, 0>>,
+                                       sequence<1, 2>,
+                                       sequence<1, 1>>{});
     }
 
     template <typename Problem>
@@ -187,6 +265,7 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
         constexpr index_t kBlockSize = Problem::kBlockSize;
         constexpr index_t kMPerBlock = Problem::kM0;
         constexpr index_t kNPerBlock = Problem::kN1;
+        static_assert(kBlockSize <= kMPerBlock * kNPerBlock);
 
         constexpr index_t M1 = kBlockSize / get_warp_size();
         constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
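// --------------------------------------------------------------------------
// A standalone constexpr sketch of the GetMaxNumWarpsForTile() /
// GetVectorSizeForTile() logic from the diff above, assuming a warp size of
// 64. The tile shapes below are illustrative, not taken from the commit.
#include <cstdio>

constexpr int warp_size = 64;

// halve the warp count until every thread owns at least one element
constexpr int max_num_warps_for_tile(int num_warps, int m, int n)
{
    return 0 < (m * n) / (num_warps * warp_size)
               ? num_warps
               : max_num_warps_for_tile(num_warps / 2, m, n);
}

constexpr int vector_size_for_tile(int num_warps, int m, int n, int elem_bytes)
{
    const int max_warps        = max_num_warps_for_tile(num_warps, m, n);
    const int elem_per_thread  = (m * n) / (max_warps * warp_size);
    const int max_n_per_thread = 16 / elem_bytes; // 16-byte vector accesses
    return max_n_per_thread < elem_per_thread ? max_n_per_thread : elem_per_thread;
}

int main()
{
    // a [16, 64] float tile with 4 warps: every thread keeps at least 1 element
    std::printf("warps=%d vec=%d\n",
                max_num_warps_for_tile(4, 16, 64),
                vector_size_for_tile(4, 16, 64, 4)); // warps=4 vec=4
    // an [8, 16] float tile: 4 warps would leave threads empty, so halve to 2
    std::printf("warps=%d vec=%d\n",
                max_num_warps_for_tile(4, 8, 16),
                vector_size_for_tile(4, 8, 16, 4)); // warps=2 vec=1
    return 0;
}
// --------------------------------------------------------------------------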
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// In this pipeline, Q, K and V are all located in LDS
template <typename Problem_,
typename Policy_ = BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy>
struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // whether q_tile loads the whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
// last-dimension vector length used to create the tensor view (and to decide the buffer_load
// vector length), together with the tensor distribution; the tensor distribution should be
// able to override this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentOacc =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
if constexpr(kQKHeaddim <= 32)
{
return 2;
}
else if constexpr(kQKHeaddim <= 64)
{
return 3;
}
else if constexpr(kQKHeaddim <= 128)
{
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kQKHeaddim <= 256)
{
return 1;
}
}
}();
static constexpr const char* name = "qr_nwarp_sshuffle";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowLengths,
typename KPageBlockNavigator,
typename VDramBlockWindowLengths,
typename VPageBlockNavigator,
typename BiasDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename LSEaccElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
const KPageBlockNavigator& k_page_block_navigator,
const KElementFunction& k_element_func,
const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
const VPageBlockNavigator& v_page_block_navigator,
const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
const LSEaccElementFunction& lse_acc_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KPageBlockNavigator::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VPageBlockNavigator::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kSubQKHeaddim ==
QDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN0 == KDramBlockWindowLengths{}[number<0>{}] &&
kK0 == KDramBlockWindowLengths{}[number<1>{}] &&
kN1 == VDramBlockWindowLengths{}[number<0>{}] &&
kK1 == VDramBlockWindowLengths{}[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// Q tile in LDS
QDataType* q_lds_ptr =
static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
// K tile in LDS
KDataType* k_lds_ptr =
static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
max(Policy::template GetSmemSizeQ<Problem>(),
Policy::template GetSmemSizeK<Problem>())),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// S tile in LDS
auto s_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<SaccDataType*>(reinterpret_cast<char*>(smem_ptr) +
max(Policy::template GetSmemSizeQ<Problem>(),
Policy::template GetSmemSizeK<Problem>())),
Policy::template MakeSLdsBlockDescriptor<Problem>());
auto s_write_lds_window = make_tile_window(
s_lds, Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
auto s_read_lds_window =
make_tile_window(s_lds,
Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0},
Policy::template MakeSRegTileDistribution<Problem>());
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>());
// load Q here; it will be stored into LDS to maximize throughput
auto origin_q = load_tile(q_dram_window);
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
auto o_acc = OaccBlockTileType{};
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(o_acc));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
// init M, L
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
{
const index_t logical_num_total_loop =
integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
if(logical_num_total_loop <= 0)
{
if constexpr(kStoreLSE)
{
auto lse_acc =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
if(get_thread_local_1d_id() < kM0)
{
store_tile(lse_acc_dram_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc));
}
}
// Note: o_acc is already cleared above, so just return it.
// Note: Q was loaded without a fence, but it is unused here, so ignore it.
return o_acc;
}
}
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
// make sure the first tile is completely located within a page-block (the page-block size
// must be divisible by kN0)
// relationship among the *_start variables: aligned_physical_seqlen_k_start <=
// physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
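// (illustrative example, assuming kN0 = 128 and kv_l2p_offset = 0: if
// physical_seqlen_k_start = 200, the aligned start is 128, so the first K/V tile lies
// entirely within one page-block and the extra columns [128, 200) are masked off by the
// kHasUnevenSplits path below)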
const index_t aligned_physical_seqlen_k_start =
[&, physical_seqlen_k_start_ = physical_seqlen_k_start] {
if constexpr(kIsPagedKV)
{
return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0);
}
else
{
return physical_seqlen_k_start_;
}
}();
const index_t num_total_loop =
integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}),
logical_seqlen_k_start - (physical_seqlen_k_start -
aligned_physical_seqlen_k_start)}, // M/N
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
v_dram_block_window_lengths,
{0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// store Q into LDS
__builtin_amdgcn_sched_barrier(0);
auto q_lds_window_for_store = make_tile_window(
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
store_tile(q_lds_window_for_store, origin_q);
__builtin_amdgcn_sched_barrier(0);
// load Q from LDS
__builtin_amdgcn_sched_barrier(0);
auto q_lds_window_for_load = make_tile_window(
q_lds,
Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
{0, 0},
Policy::template MakeQRegTileDistribution<Problem, decltype(gemm_0)>());
block_sync_lds();
auto q = load_tile(q_lds_window_for_load);
__builtin_amdgcn_sched_barrier(0);
auto q_tile = tile_elementwise_in(q_element_func, q);
// prefetch K tile
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(2 <= k0_loops);
static_assert(1 <= k1_loops);
auto k_dram_window = make_tile_window(
k_dram_block_window,
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window
// load the first tile of the first iteration and store it to LDS
auto k_block_tile = load_tile(k_dram_window);
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window(k_dram_window, {0, kK0});
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
do
{
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
// load the second tile of the first iteration
k_block_tile = load_tile(k_dram_window);
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
if constexpr(k0_loops > 2)
{
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
block_sync_lds();
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}),
k_lds_window);
block_sync_lds();
move_tile_window(k_dram_window, {0, kK0});
store_tile(
k_lds_window,
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
k_block_tile = load_tile(k_dram_window); // global read i + 2
});
}
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail
block_sync_lds();
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 2) * kK0>{},
sequence<kM0, (k0_loops - 1) * kK0>{}),
k_lds_window);
block_sync_lds();
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
block_sync_lds();
gemm_0(s_acc,
get_slice_tile(q_tile,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
k_lds_window);
}
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x += type_convert<SaccDataType>(bias_element_func(y));
#else
x += log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
#endif
},
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
// position_encoding accept only logical coordinates, do conversion here
position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only check in first/last iteration without increasing code size
if constexpr(kHasUnevenSplits)
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
set_tile_if(
s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&,
physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV)
{
return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col;
}
else
{
return physical_seqlen_k_end_ <= col;
}
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
// mask accept only logical coordinates, do conversion here
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}) - kv_l2p_offset,
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col - kv_l2p_offset);
});
}
}
__builtin_amdgcn_sched_barrier(0);
// load the first tile for next iteration
if(i_total_loops < num_total_loop - 1)
{
// move K tile windows
i_page_block_k = k_page_block_navigator.move_tile_window(
i_page_block_k, k_dram_block_window, {kN0, 0});
k_dram_window = make_tile_window(
k_dram_block_window,
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window
// load the first tile of the next iteration (stored to LDS at the end of this iteration)
k_block_tile = load_tile(k_dram_window);
}
__builtin_amdgcn_sched_barrier(0);
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
// shuffle S through LDS so that the tile layout is consistent with what gemm_1 requires
store_tile(s_write_lds_window, s);
block_sync_lds();
auto s_new = load_tile(s_read_lds_window);
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s_new,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s_new.get_tile_distribution()); // Pcompute{j}
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: the bias might be a materialized mask including -inf values, which needs
/// consideration
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max);
}
#else
p_compute(i_j_idx) = exp(s_new[i_j_idx] - get_validated_m(m[i_idx]));
#endif
});
});
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}();
#else
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this uses a different equation from the FA v2 paper,
// but produces the correct result.
// Is the equation in the paper wrong?
o_acc(i_j_idx) *= tmp;
});
});
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_prefetch);
store_tile(
v_lds_window,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
}
i_page_block_v =
v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&,
&i_page_block_v_ = i_page_block_v,
&v_dram_window_ = v_dram_window](auto i_k1) {
const auto v = load_tile(v_dram_window_); // load next v
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
v_lds_window);
block_sync_lds();
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v);
store_tile(v_lds_window,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
else
{
store_tile(v_lds_window,
tile_elementwise_in(v_element_func, v)); // store next v
}
i_page_block_v_ = v_page_block_navigator.move_tile_window(
i_page_block_v_, v_dram_window_, {0, kK1});
});
}
// tail
{
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, k1_loops * kK1>{}),
v_lds_window);
block_sync_lds();
}
__builtin_amdgcn_sched_barrier(0);
// load the first tile for next iteration
if(i_total_loops < num_total_loop - 1)
{
// store the first tile for next iteration to LDS
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window(k_dram_window, {0, kK0});
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
}
} while(++i_total_loops < num_total_loop);
if constexpr(kStoreLSE)
{
// store lse acc
auto lse_acc = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
}
else
{
lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
}
#else
lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
#endif
});
if(get_thread_local_1d_id() < kM0)
{
store_tile(lse_acc_dram_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc));
}
}
// finally, O
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = [&]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}
else
return 1 / l[i_idx];
}();
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowLengths,
typename KPageBlockNavigator,
typename VDramBlockWindowLengths,
typename VPageBlockNavigator,
typename BiasDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
const KPageBlockNavigator& k_page_block_navigator,
const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
const VPageBlockNavigator& v_page_block_navigator,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
void* smem_ptr) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_lengths,
k_page_block_navigator,
identity{},
v_dram_block_window_lengths,
v_page_block_navigator,
identity{},
bias_dram_block_window_tmp,
identity{},
lse_acc_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
num_splits,
i_split,
mask,
position_encoding,
scale_s,
kv_l2p_offset,
smem_ptr);
}
};
} // namespace ck_tile
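// --------------------------------------------------------------------------
// A scalar model of the streaming-softmax recurrence that STAGE 2/3 above
// implement per row (without the exp2/scale_s fast path): keep a running row
// max m and row sum l, and rescale the partial output by exp(m_old - m_new)
// whenever the max grows. Purely illustrative; not part of the commit.
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    // one query row attending to a stream of two key blocks (scores and values)
    const std::vector<std::vector<float>> s_blocks = {{0.5f, 1.0f}, {2.0f, -1.0f}};
    const std::vector<std::vector<float>> v_blocks = {{1.0f, 2.0f}, {3.0f, 4.0f}};

    float m = -INFINITY; // running row max
    float l = 0.0f;      // running row sum of exp
    float o = 0.0f;      // running (unnormalized) output

    for(std::size_t j = 0; j < s_blocks.size(); ++j)
    {
        float m_local = -INFINITY;
        for(float s : s_blocks[j])
            m_local = std::fmax(m_local, s);

        const float m_old = m;
        m = std::fmax(m_old, m_local);

        const float tmp = std::exp(m_old - m); // rescale factor for old l and o
        float rowsum_p = 0.0f, pv = 0.0f;
        for(std::size_t k = 0; k < s_blocks[j].size(); ++k)
        {
            const float p = std::exp(s_blocks[j][k] - m);
            rowsum_p += p;
            pv += p * v_blocks[j][k];
        }
        l = tmp * l + rowsum_p;
        o = tmp * o + pv; // kernel equivalent: o_acc *= tmp, then gemm_1 adds P*V
    }

    // final normalization and the lse value written when kStoreLSE is set
    std::printf("out = %g, lse = %g\n", o / l, m + std::log(l));
    return 0;
}
// --------------------------------------------------------------------------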
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
namespace ck_tile {
// In this pipeline, Q, K and V are all located in LDS
struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>
{
using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
// this should align with MakeQDramTileDistribution()
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
return min(ElemPerThread, MaxVectorSize);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
{
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
return static_cast<index_t>(16 / sizeof(OaccDataType));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
{
return BasePolicy::template MakeQDramTileDistribution<Problem, BlockGemm>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
{
// TODO: this is for 3d layout
using QDataType = remove_cvref_t<typename Problem::QDataType>;
return static_cast<index_t>(16 / sizeof(QDataType));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kKPack = min(ElemPerThread, GetSmemKPackQ<Problem>());
constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
q_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return q_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemNPackS()
{
using SDataType = remove_cvref_t<typename Problem::SaccDataType>;
return static_cast<index_t>(16 / sizeof(SDataType));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kNPack = GetSmemNPackS<Problem>();
constexpr auto s_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / kNPack>{}, number<kMPerBlock>{}, number<kNPack>{}),
make_tuple(number<(kMPerBlock + 1) * kNPack>{}, number<kNPack>{}, number<1>{}),
number<kNPack>{},
number<1>{});
constexpr auto s_lds_block_desc = transform_tensor_descriptor(
s_lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<kMPerBlock>{}),
make_merge_transform(make_tuple(number<kNPerBlock / kNPack>{}, number<kNPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return s_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSRegTileDistribution()
{
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
static_assert(MWarp == 1, "Check failed!");
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kTileK = Problem::BlockFmhaShape::kN0;
// K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm
constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t K1 = kKPerBlock / (K2 * K3);
constexpr index_t K0 = kTileK / kKPerBlock;
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t M1 = MWarp;
constexpr index_t M0 = kMPerBlock / (M2 * M1);
constexpr auto s2_block_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2, K3>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<1, 0>, sequence<2, 2>>,
sequence<1, 2, 2, 2>,
sequence<0, 0, 1, 3>>{};
constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding);
return s2_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
{
return MakeQLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::QDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
{
return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::KDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
{
return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::VDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeS()
{
return MakeSLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::SaccDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return max(GetSmemSizeQ<Problem>(), GetSmemSizeK<Problem>()) +
max(GetSmemSizeV<Problem>(), GetSmemSizeS<Problem>());
}
};
} // namespace ck_tile
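As a usage sketch of the policy above (hypothetical names; only the __shared__ pattern
is taken from this patch, mirroring the fused_moe kernel further below):

// MyFmhaPipeline is a placeholder for a pipeline that exposes this policy's
// GetSmemSize() as a static constexpr member, as the fmha pipelines in this patch do.
template <typename MyFmhaPipeline>
__global__ void fmha_fwd_kernel(/* kernel arguments elided */)
{
    // single LDS allocation sized at compile time; Q/K and V/S alias inside it
    __shared__ char smem[MyFmhaPipeline::GetSmemSize()];
    // ... build DRAM tile windows, then invoke: pipeline(..., smem);
}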
@@ -94,40 +94,56 @@ struct BlockFmhaFwdSplitKVPipelineProblem
     static constexpr bool kIsGroupMode = kIsGroupMode_;

     // attributes from traits
     static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
     static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
     static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
     static constexpr auto BiasEnum = Traits::BiasEnum;
     static constexpr bool kStoreLSE = Traits::kStoreLSE;
     static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
     static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
     static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
-    static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
+    static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
+    static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
+};
+
+// extract tile size attributes to remove dependency on traits
+template <typename OaccDataType_, ck_tile::index_t kN1_>
+struct BlockFmhaSplitKVCombinePipelineTileSizes
+{
+    static constexpr index_t MaxVectorSize = 16 / sizeof(OaccDataType_);
+
+    static constexpr index_t kN1 = kN1_;
+    static constexpr index_t NThreads = kN1 / MaxVectorSize;
+    static constexpr index_t kM0 = get_warp_size() / NThreads; // MThreadPerWarp
 };
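To make the new tile-size arithmetic concrete, a small compile-time check (values
assumed for illustration; get_warp_size() is 64 on the gfx9 wavefronts this library
typically targets):

// Hypothetical instantiation, not part of this patch:
using TS = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes<float, /*kN1_=*/32>;
static_assert(TS::MaxVectorSize == 4); // 16 B / sizeof(float)
static_assert(TS::NThreads == 8);      // 32 / 4 threads cover one kN1 row
static_assert(TS::kM0 == 8);           // 64 lanes / 8 = rows handled per warp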
 template <typename LSEDataType_,
           typename OaccDataType_,
           typename ODataType_,
           index_t HeadDimV_,
-          index_t kM0_,
-          index_t kN1_,
           bool kIsGroupMode_,
+          ck_tile::index_t kN1_,
           typename Traits_>
 struct BlockFmhaSplitKVCombinePipelineProblem
+    : BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_>
 {
+    using BaseType = BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_>;
     using LSEDataType = remove_cvref_t<LSEDataType_>;
     using OaccDataType = remove_cvref_t<OaccDataType_>;
     using ODataType = remove_cvref_t<ODataType_>;
     using Traits = remove_cvref_t<Traits_>;

-    static constexpr index_t kNumWarps = kM0_ / (get_warp_size() / 4);
-    static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
-    static constexpr bool kIsGroupMode = kIsGroupMode_;
+    static_assert(std::is_same_v<LSEDataType, OaccDataType>);

     static constexpr index_t kHeadDimV = HeadDimV_;
-    static constexpr index_t kM0 = kM0_;
-    static constexpr index_t kN1 = kN1_;
+    static constexpr bool kIsGroupMode = kIsGroupMode_;
+
+    using BaseType::kM0;
+    using BaseType::kN1;
+    static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0);

     // attributes from traits
     static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
@@ -136,6 +152,13 @@ struct BlockFmhaSplitKVCombinePipelineProblem
     static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
     static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
     static constexpr index_t kMaxSplits = Traits::kMaxSplits;
+    static_assert(8 <= kMaxSplits);
+
+    static constexpr index_t kNumWarps = 4; // always use 4 warps for each workgroup
+    static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
+
+    static_assert(get_warp_size() <= (kM0 * kMaxSplits) &&
+                  (kM0 * kMaxSplits) % get_warp_size() == 0);
 };

 template <typename QDataType_,
...
@@ -5,14 +5,14 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
+#include "ck_tile/ops/fmha/block/block_dropout.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"

 namespace ck_tile {

-/// NOTICE: we no-longer use this pipeline.
 // This pipeline is qkv all located in LDS
 template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy>
-struct [[deprecated]] BlockFmhaPipelineQSKSVS
+struct BlockFmhaPipelineQSKSVS
 {
     using Problem = remove_cvref_t<Problem_>;
     using Policy = remove_cvref_t<Policy_>;
@@ -51,6 +51,24 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
     static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
     static constexpr auto BiasEnum = Problem::BiasEnum;
     static constexpr bool kStoreLSE = Problem::kStoreLSE;
+    static constexpr bool kHasDropout = Problem::kHasDropout;
+
+    // last dimension vector length used to create the tensor view (and to decide the
+    // buffer_load vector length) together with the tensor distribution; the tensor
+    // distribution should be able to override this
+    static constexpr index_t kAlignmentQ =
+        kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
+    static constexpr index_t kAlignmentK =
+        kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
+    static constexpr index_t kAlignmentV = []() {
+        if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
+            return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
+        else
+            return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
+    }();
+    static constexpr index_t kAlignmentO =
+        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
+    static constexpr index_t kAlignmentBias =
+        kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
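All of the kAlignment* constants introduced above follow one rule: if the dimension
being loaded along may be padded, fall back to alignment 1 (element-wise loads),
otherwise take the vector width the policy reports. For instance (types assumed for
illustration): with half_t Q and an unpadded head dim, GetAlignmentQ() can permit
16 B / 2 B = 8-element dwordx4 buffer loads, while kPadHeadDimQ forces alignment 1
because a padded tail is not guaranteed to be vector-aligned.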
     static constexpr index_t kBlockPerCu = []() {
         if constexpr(Problem::kBlockPerCu != -1)
@@ -81,20 +99,18 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
     static constexpr const char* name = "qs";

+    using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
+
     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
     {
         return Policy::template GetSmemSize<Problem>();
     }
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
-    {
-        return Policy::template GetSmemSizeQ<Problem>();
-    }

     template <typename QDramBlockWindowTmp,
               typename KDramBlockWindowTmp,
               typename VDramBlockWindowTmp,
               typename BiasDramBlockWindowTmp,
+              typename RandValDramBlockWindowTmp,
               typename LSEDramBlockWindowTmp,
               typename QElementFunction,
               typename KElementFunction,
@@ -114,6 +130,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
               const VElementFunction& v_element_func,
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
               const BiasElementFunction& bias_element_func,
+              RandValDramBlockWindowTmp& /* unused_randval_dram_block_window_tmp */,
               LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
               const LSEElementFunction& lse_element_func,
               const SAccElementFunction& s_acc_element_func,
@@ -122,7 +139,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
               FmhaMask mask,
               PositionEncoding position_encoding,
               float scale_s,
-              void* smem_ptr) const
+              void* smem_ptr,
+              DropoutType& /* unused_dropout */) const
     {
         static_assert(
             std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -222,11 +240,11 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                                  {seqlen_k_start, 0});

         const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
-        auto bias_dram_window = make_tile_window(
-            bias_dram_block_window_tmp.get_bottom_tensor_view(),
-            bias_dram_block_window_tmp.get_window_lengths(),
-            {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
-            Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
+        auto bias_dram_window =
+            make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
+                             bias_dram_block_window_tmp.get_window_lengths(),
+                             {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
+                             Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());

         auto v_dram_window =
             make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -305,8 +323,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
             });
         }

-        const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
         { // tail
             block_sync_lds();
             gemm_0(s_acc, q_lds_window, k_lds_window);
             block_sync_lds();
@@ -318,6 +335,10 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
             gemm_0(s_acc, q_lds_window, k_lds_window);
         }

+        __builtin_amdgcn_sched_barrier(0);
+        const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
+        __builtin_amdgcn_sched_barrier(0);
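The two __builtin_amdgcn_sched_barrier(0) fences around the relocated V prefetch are
presumably there to pin the load between the tail QK gemm and the softmax stage, so
the compiler cannot hoist or sink it during instruction scheduling; the prefetch
itself merely moved from before the tail block to after it.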
         // STAGE 2, scale_s, add bias, mask, softmax
         if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
         {
@@ -439,6 +460,12 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                 p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
             block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});

+            const auto p =
+                cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
+
+            __builtin_amdgcn_sched_barrier(0);
+
             // l{j}, Oacc{j}
             constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
             sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
@@ -486,9 +513,6 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
             }
             move_tile_window(v_dram_window, {0, kK1});

-            const auto p =
-                cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
-
             // STAGE 3, KV gemm
             if constexpr(k1_loops > 1)
             {
@@ -583,6 +607,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
               typename KDramBlockWindowTmp,
               typename VDramBlockWindowTmp,
               typename BiasDramBlockWindowTmp,
+              typename RandValDramBlockWindowTmp,
               typename LSEDramBlockWindowTmp,
               typename PositionEncoding>
     CK_TILE_HOST_DEVICE auto
@@ -590,11 +615,13 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                const KDramBlockWindowTmp& k_dram_block_window_tmp,       // N0*K0 tile
                const VDramBlockWindowTmp& v_dram_block_window_tmp,       // N1*K1 tile
                const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
+               RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
                LSEDramBlockWindowTmp& lse_dram_block_window_tmp,         // M0*1 tile
                FmhaMask mask,
                PositionEncoding position_encoding,
                float scale_s,
-               void* smem_ptr) const
+               void* smem_ptr,
+               DropoutType& dropout) const
     {
         return operator()(q_dram_block_window_tmp,
                           identity{},
@@ -604,6 +631,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                           identity{},
                           bias_dram_block_window_tmp,
                           identity{},
+                          randval_dram_block_window_tmp,
                           lse_dram_block_window_tmp,
                           identity{},
                           identity{},
@@ -612,7 +640,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                           mask,
                           position_encoding,
                           scale_s,
-                          smem_ptr);
+                          smem_ptr,
+                          dropout);
     }
 };
...
@@ -9,11 +9,33 @@
 namespace ck_tile {

 // This pipeline is qkv all located in LDS
-using BlockFmhaPipelineQSKSVSDefaultPolicy =
-    BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false,
-                                        /* AsyncCopyK = */ false,
-                                        /* AsyncCopyV = */ false,
-                                        /* NumPrefetchK = */ 1,
-                                        /* NumPrefetchV = */ 1>;
+struct BlockFmhaPipelineQSKSVSDefaultPolicy
+    : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false,
+                                          /* AsyncCopyK = */ false,
+                                          /* AsyncCopyV = */ false,
+                                          /* NumPrefetchK = */ 1,
+                                          /* NumPrefetchV = */ 1>
+{
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
+    {
+        return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
+               sizeof(typename Problem::KDataType);
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
+    {
+        return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
+               sizeof(typename Problem::VDataType);
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    {
+        return max(GetSmemSizeQ<Problem>() + GetSmemSizeK<Problem>(), GetSmemSizeV<Problem>()) +
+               GetSmemSizeDropout<Problem>();
+    }
+};
 } // namespace ck_tile
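Turning the default policy from a using-alias into a derived struct is what makes the
three GetSmemSize* overrides above possible: an alias cannot add members, whereas the
struct still inherits everything else (GetSmemSizeQ and, presumably, the
GetSmemSizeDropout referenced in GetSmemSize) from BlockFmhaPipelineQXKSVSCustomPolicy
unchanged.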
@@ -41,52 +41,21 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
     {
+        constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
+
         using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
         constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
         using WG = remove_cvref_t<decltype(config.template at<0>())>;
-        return WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
+        return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane);
     }

     template <typename Problem, typename BlockGemm>
     CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
     {
-        constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
-        using WG = remove_cvref_t<decltype(config.template at<0>())>;
-        constexpr index_t MWarp = config.template at<1>();
-
-        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
-
-        constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
-        constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane;
-        constexpr index_t K0 = kKPerBlock / (K1 * K2);
-
-        constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
-        constexpr index_t M1 = MWarp;
-        constexpr index_t M0 = kMPerBlock / (M2 * M1);
-
-        if constexpr(1 < Problem::kNumGemm0Warps)
-        {
-            return make_static_tile_distribution(
-                tile_distribution_encoding<sequence<1>,
-                                           tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
-                                           tuple<sequence<1>, sequence<2, 1>>,
-                                           tuple<sequence<1>, sequence<1, 2>>,
-                                           sequence<1, 2, 2>,
-                                           sequence<0, 0, 2>>{});
-        }
-        else
-        {
-            static_assert(MWarp == 1);
-            return make_static_tile_distribution(
-                tile_distribution_encoding<sequence<1>,
-                                           tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
-                                           tuple<sequence<2, 1>>,
-                                           tuple<sequence<1, 2>>,
-                                           sequence<1, 2, 2>,
-                                           sequence<0, 0, 2>>{});
-        }
+        return BlockGemm::template MakeABlockTileDistribution<
+            Problem::BlockFmhaShape::kM0,
+            Problem::BlockFmhaShape::kSubQKHeaddim>();
     }

     template <typename Problem>
@@ -105,7 +74,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
         constexpr auto warp_gemm = []() {
             constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
-            static_assert(WarpGemmM == 16 || WarpGemmM == 32);
+            static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);

             if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
                          std::is_same_v<typename Problem::KDataType, half_t> &&
@@ -113,8 +82,10 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
             {
                 if constexpr(WarpGemmM == 32)
                     return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
-                else // WarpGemmM == 16
+                else if constexpr(WarpGemmM == 16)
                     return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
+                else // WarpGemmM == 4
+                    return WarpGemmMfmaF16F16F32M4N64K16{};
             }
             else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
                               std::is_same_v<typename Problem::KDataType, bf16_t> &&
@@ -122,8 +93,10 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
             {
                 if constexpr(WarpGemmM == 32)
                     return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
-                else // WarpGemmM == 16
+                else if constexpr(WarpGemmM == 16)
                     return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
+                else // WarpGemmM == 4
+                    return WarpGemmMfmaBf16Bf16F32M4N64K16{};
             }
             else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
                               std::is_same_v<typename Problem::KDataType, fp8_t> &&
@@ -152,9 +125,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
         }
     };

-/// NOTICE: we no-longer use this policy.
 template <>
-struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
+struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
 {
     static constexpr bool QLoadOnce = false;
@@ -174,8 +146,16 @@ struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
     {
-        using QDataType = remove_cvref_t<typename Problem::QDataType>;
-        return 16 / sizeof(QDataType);
+        constexpr index_t kBlockSize = Problem::kBlockSize;
+        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
+        constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
+
+        // this should align with MakeQDramTileDistribution()
+        constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
+        static_assert(0 < ElemPerThread);
+
+        return min(ElemPerThread, MaxVectorSize);
     }
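A quick sanity check of the reworked GetAlignmentQ() (tile numbers assumed for
illustration, not taken from this diff):

// Hypothetical shape: kM0 = 128, kK0 = 32, kBlockSize = 256, QDataType = half_t
// ElemPerThread   = (128 * 32) / 256 = 16 elements owned per thread
// MaxVectorSize   = 16 B / 2 B       = 8  elements per 16-byte buffer_load
// GetAlignmentQ() = min(16, 8)       = 8  -> dwordx4 loads remain possible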
     template <typename Problem>
@@ -184,19 +164,25 @@ struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
         using QDataType = remove_cvref_t<typename Problem::QDataType>;

         constexpr index_t kBlockSize = Problem::kBlockSize;
         constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
         constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;

-        constexpr index_t K1 = 16 / sizeof(QDataType); // use dwordx4. TODO: change this
-        constexpr index_t K0 = kKPerBlock / K1;
-        constexpr index_t M2 = get_warp_size() / K0;
-        constexpr index_t M1 = kBlockSize / get_warp_size();
-        constexpr index_t M0 = kMPerBlock / (M2 * M1);
+        constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
+
+        constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
+        static_assert(0 < ElemPerThread);
+        constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
+
+        constexpr index_t KPerThread = kMaxVecLoad;
+        constexpr index_t KThreads = kKPerBlock / KPerThread;
+        constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
+        constexpr index_t NumWarps = kBlockSize / get_warp_size();
+        constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);

         return make_static_tile_distribution(
             tile_distribution_encoding<sequence<1>,
-                                       tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
+                                       tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
+                                             sequence<KThreads, KPerThread>>,
                                        tuple<sequence<1>, sequence<1, 2>>,
                                        tuple<sequence<1>, sequence<2, 0>>,
                                        sequence<1, 2>,
@@ -243,18 +229,31 @@ struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
                 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
                 typename Problem::BlockFmhaShape::Gemm0WarpTile>>;

+        constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
+        static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
+
         constexpr auto warp_gemm = []() {
             if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
                          std::is_same_v<typename Problem::KDataType, half_t> &&
                          std::is_same_v<typename Problem::SaccDataType, float>)
             {
-                return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
+                if constexpr(WarpGemmM == 32)
+                    return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
+                else if constexpr(WarpGemmM == 16)
+                    return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
+                else // WarpGemmM == 4
+                    return WarpGemmMfmaF16F16F32M4N64K16{};
             }
             else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
                               std::is_same_v<typename Problem::KDataType, bf16_t> &&
                               std::is_same_v<typename Problem::SaccDataType, float>)
             {
-                return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
+                if constexpr(WarpGemmM == 32)
+                    return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
+                else if constexpr(WarpGemmM == 16)
+                    return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
+                else // WarpGemmM == 4
+                    return WarpGemmMfmaBf16Bf16F32M4N64K16{};
             }
             else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
                               std::is_same_v<typename Problem::KDataType, fp8_t> &&
...
@@ -43,8 +43,6 @@ struct TileFmhaShape
     static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);

-    static_assert(std::is_same_v<Gemm0WarpTile, Gemm1WarpTile>);
-
     static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
     static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
     static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
...
@@ -43,7 +43,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
           bool kDoFp8StaticQuant_,
           bool kIsPagedKV_,
           bool kHasUnevenSplits_,
-          index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
+          bool kMergeNumHeadGroupsSeqLenQ_ = false,
+          index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
 struct TileFmhaFwdSplitKVTraits
 {
     static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
@@ -56,8 +57,9 @@ struct TileFmhaFwdSplitKVTraits
     static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
     static constexpr bool kIsPagedKV = kIsPagedKV_;
     // determine if some split (length) is not divisible by tile size
     static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
-    static constexpr index_t kBlockPerCu = kBlockPerCu_;
+    static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
+    static constexpr index_t kBlockPerCu = kBlockPerCu_;
 };

 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -7,6 +7,7 @@
 #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
 #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
 #include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
+#include "ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp"
 #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
 #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
 #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
@@ -14,6 +15,6 @@
 #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
 #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
 #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
-#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/common/utils.hpp"
@@ -22,7 +22,7 @@
 // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
 // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
 //
-// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
+// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
 // * this could be larger than actual, since actual tokens are on GPU
 //
 // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
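// Reading the updated bound against the reference layout above: an empty expert still
// consumes one full unit tile of padding (exp-4 contributes the all-pad block
// [6, 6, 6, 6], i.e. M_a = 4 slots), so the per-expert worst case is M_a rather than
// the old M_a - 1; the new num_experts * M_a term covers this, and "- topk" tightens
// the bound slightly. The 28-entry sorted list above stays under either estimate.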
@@ -111,7 +111,7 @@ struct FusedMoeGemmHostArgs
     const void* num_sorted_tiles_ptr; // [1]

     index_t hidden_size;       // k
-    index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
+    index_t intermediate_size; // n / TP, for Gate/UP/Down
     index_t num_tokens;        // input number of tokens for current iteration
     index_t num_experts;       // number of groups
     index_t topk;              // need this?
@@ -178,7 +178,7 @@ struct FusedMoeGemmKernel
             return base_str;
         }();
-        return _SS_("fused_moe_") + _SS_(prec_str) + "_" +
+        return _SS_("fused_moe_") + _SS_(prec_str) + "_" + (IsGateOnly ? "g1u0_":"g1u1_") +
             _TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
             _TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
             _TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name);
@@ -204,7 +204,7 @@ struct FusedMoeGemmKernel
         const void* num_sorted_tiles_ptr;

         index_t hidden_size;       // k
-        index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
+        index_t intermediate_size; // n / TP, for Gate/Up/Down
         index_t num_tokens;        // input number of tokens for current iteration
         index_t num_experts;       // number of groups
         index_t topk;              // need this?
@@ -239,7 +239,7 @@ struct FusedMoeGemmKernel
     {
         if constexpr(UseUK)
         {
-            __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
+            __shared__ CK_TILE_LDS_ADDR char smem[GetSmemSize()];
             IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
                 *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
@@ -298,6 +298,9 @@ struct FusedMoeGemmKernel
             index_t token_id =
                 reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
+#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
+            token_id &= 0xffffff;
+#endif
             auto topk_weight = reinterpret_cast<const TopkWeightDataType*>(
                 kargs.sorted_weight_ptr)[sorted_token_id];
...