Commit c881136b authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Merge branch 'develop' into ck_tile/support-vllm-kcache-layout

parents c5e8e14f 4e076909
...@@ -29,6 +29,7 @@ struct static_distributed_tensor ...@@ -29,6 +29,7 @@ struct static_distributed_tensor
remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>; remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size(); static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid");
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension() CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
{ {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
...@@ -15,11 +15,14 @@ ...@@ -15,11 +15,14 @@
namespace ck_tile { namespace ck_tile {
/* /*
* a host side utility, arg parser for * a host side utility, arg parser for, either
* -[key0]=[value0] -[key1]=[value1] ... * -[key0] = [value0, value1, value2]
* or
* -[key0]=[value0] -[key1]=[value1] ...
*/ */
class ArgParser class ArgParser
{ {
public: public:
class Arg class Arg
{ {
...@@ -187,6 +190,45 @@ class ArgParser ...@@ -187,6 +190,45 @@ class ArgParser
return value; return value;
} }
std::vector<std::string> get_string_vec(const std::string& name,
                                        const std::string& delimiter = ",") const
{
    // Split the option string registered under 'name' into tokens separated
    // by 'delimiter'. Returns an empty vector when the option is empty.
    //
    // Fixes vs. previous version:
    //  - get_str(name) was looked up twice; now looked up once.
    //  - repeated substr+erase rewrote the string each iteration (O(n^2));
    //    now scans with a moving offset.
    //  - an empty delimiter made find() return 0 forever (infinite loop);
    //    now the whole string is returned as a single token.
    const std::string s = get_str(name);
    if(s.empty())
    {
        return {};
    }
    if(delimiter.empty())
    {
        return {s};
    }

    std::vector<std::string> tokens;
    std::size_t begin = 0;
    for(std::size_t end = s.find(delimiter, begin); end != std::string::npos;
        end = s.find(delimiter, begin))
    {
        tokens.push_back(s.substr(begin, end - begin));
        begin = end + delimiter.length();
    }
    tokens.push_back(s.substr(begin)); // trailing token (possibly empty)
    return tokens;
}
std::vector<int> get_int_vec(const std::string& name, const std::string& delimiter = ",") const
{
    // Parse a delimiter-separated option into a vector of ints.
    // get_string_vec() already returns {} for an empty/missing option, so no
    // separate emptiness check is needed here.
    const std::vector<std::string> args = get_string_vec(name, delimiter);

    std::vector<int> values;
    // reserve() takes a size_t; the previous static_cast<int> narrowed the
    // size only to widen it back implicitly.
    values.reserve(args.size());
    for(const std::string& token : args)
    {
        // std::atoi keeps the original best-effort semantics:
        // non-numeric tokens parse as 0.
        values.push_back(std::atoi(token.c_str()));
    }
    return values;
}
private: private:
std::unordered_map<std::string, Arg> input_map; std::unordered_map<std::string, Arg> input_map;
std::vector<std::string> keys; std::vector<std::string> keys;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -56,6 +56,13 @@ struct CShuffleEpilogue ...@@ -56,6 +56,13 @@ struct CShuffleEpilogue
// No additional shared memory needed // No additional shared memory needed
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
// Reports whether this epilogue writes the output tile transposed.
CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed()
{
    // TODO: For now CShuffle does not allow vectorized stores after the
    // permute step. Once that is fixed, this function should return true.
    return false;
}
template <typename OAccTile> template <typename OAccTile>
CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile) CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile)
{ {
...@@ -111,7 +118,9 @@ struct CShuffleEpilogue ...@@ -111,7 +118,9 @@ struct CShuffleEpilogue
} }
} }
template <typename ODramWindowTmp, typename OAccTile> template <typename ODramWindowTmp,
typename OAccTile,
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile) CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
{ {
const auto& current_window_origin = o_dram_window_tmp.get_window_origin(); const auto& current_window_origin = o_dram_window_tmp.get_window_origin();
...@@ -158,12 +167,26 @@ struct CShuffleEpilogue ...@@ -158,12 +167,26 @@ struct CShuffleEpilogue
// Store the tile data to the permuted location // Store the tile data to the permuted location
if constexpr(kPadM || kPadN) if constexpr(kPadM || kPadN)
{ {
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile)); if constexpr(out_memory_data_op == memory_operation_enum::set)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
else
{
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
buffer_store_fence(); buffer_store_fence();
} }
else else
{ {
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile)); if constexpr(out_memory_data_op == memory_operation_enum::set)
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
else
{
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
} }
} }
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -35,21 +35,39 @@ struct Default2DEpilogue ...@@ -35,21 +35,39 @@ struct Default2DEpilogue
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
// Reports that this epilogue does not transpose the output tile.
CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() { return false; }
// TODO: this function assume store out vector size is the same as OAccTile last dimension size // TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ? // how do we fix this ?
template <typename ODramWindowTmp, typename OAccTile> template <typename ODramWindowTmp,
typename OAccTile,
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile) CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile)
{ {
// TODO: this is ugly // TODO: this is ugly
if constexpr(UseRawStore && (kPadM || kPadN)) if constexpr(UseRawStore && (kPadM || kPadN))
{ {
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile)); if constexpr(out_memory_data_op == memory_operation_enum::set)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
else
{
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
buffer_store_fence(); buffer_store_fence();
} }
else else
{ {
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile)); if constexpr(out_memory_data_op == memory_operation_enum::set)
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
else
{
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
} }
} }
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,9 +14,7 @@ ...@@ -14,9 +14,7 @@
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
...@@ -28,6 +26,8 @@ ...@@ -28,6 +26,8 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
......
...@@ -10,10 +10,9 @@ ...@@ -10,10 +10,9 @@
namespace ck_tile { namespace ck_tile {
template <typename TilePartitioner_, typename FmhaPipeline_> template <typename FmhaPipeline_>
struct FmhaFwdAppendKVKernel struct FmhaFwdAppendKVKernel
{ {
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>; using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
...@@ -234,12 +233,25 @@ struct FmhaFwdAppendKVKernel ...@@ -234,12 +233,25 @@ struct FmhaFwdAppendKVKernel
return kargs; return kargs;
} }
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size, CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead, ck_tile::index_t nhead,
ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_knew) ck_tile::index_t seqlen_knew)
{ {
return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, seqlen_knew); // TODO: this may need tuning
return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, FmhaPipeline::kM0),
ck_tile::integer_divide_ceil(seqlen_knew, FmhaPipeline::kN0)),
nhead,
batch_size);
}
// Translate the 3D launch grid directly into (tile, head, batch) coordinates;
// kargs carries no information needed for this mapping.
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& /* kargs */)
{
    return ck_tile::make_tuple(static_cast<index_t>(blockIdx.x),
                               static_cast<index_t>(blockIdx.y),
                               static_cast<index_t>(blockIdx.z));
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
...@@ -247,7 +259,7 @@ struct FmhaFwdAppendKVKernel ...@@ -247,7 +259,7 @@ struct FmhaFwdAppendKVKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
{ {
// divide problem // divide problem
const auto [i_tile, i_nhead, i_batch] = TilePartitioner{}(); const auto [i_tile, i_nhead, i_batch] = GetTileIndex(kargs);
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0);
const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0); const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0);
......
...@@ -20,10 +20,9 @@ ...@@ -20,10 +20,9 @@
namespace ck_tile { namespace ck_tile {
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_> template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdKernel struct FmhaFwdKernel
{ {
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>; using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>; using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
...@@ -71,7 +70,8 @@ struct FmhaFwdKernel ...@@ -71,7 +70,8 @@ struct FmhaFwdKernel
using bfs = typename FmhaPipeline::BlockFmhaShape; using bfs = typename FmhaPipeline::BlockFmhaShape;
using g0br = typename bfs::Gemm0BlockWarps; using g0br = typename bfs::Gemm0BlockWarps;
using g1br = typename bfs::Gemm1BlockWarps; using g1br = typename bfs::Gemm1BlockWarps;
using gwt = typename bfs::Gemm0WarpTile; using g0wt = typename bfs::Gemm0WarpTile;
using g1wt = typename bfs::Gemm1WarpTile;
#define _SS_ std::string #define _SS_ std::string
#define _TS_ std::to_string #define _TS_ std::to_string
auto pn = [&] () { auto pn = [&] () {
...@@ -83,12 +83,13 @@ struct FmhaFwdKernel ...@@ -83,12 +83,13 @@ struct FmhaFwdKernel
return n.empty() ? n : std::string("p") + n; }(); return n.empty() ? n : std::string("p") + n; }();
return return
_SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) + _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_" "_" + (kIsGroupMode ? "group" : "batch") + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
...@@ -865,9 +866,75 @@ struct FmhaFwdKernel ...@@ -865,9 +866,75 @@ struct FmhaFwdKernel
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_, ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_, ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_) ck_tile::index_t hdim_v_,
bool has_padded_seqlen_k = false)
{ {
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_); // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
if(has_padded_seqlen_k)
{
// TODO: this may need tuning
return dim3(nhead_,
batch_size_,
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
}
else
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
nhead_,
batch_size_);
}
}
// Recover (tile_m, tile_n, head, batch) for this thread block.
//
// GridSize() launches two different grid layouts: when group mode supplies a
// seqlen_k pointer the grid is (nhead, batch, tiles); otherwise it is
// (tiles, nhead, batch). Previously both branches duplicated the entire
// divmod logic verbatim — only the blockIdx axis mapping differed. The
// decomposition is now written once.
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
{
    bool has_padded_seqlen_k = false;
    if constexpr(kIsGroupMode)
        has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);

    // Select the axis mapping matching the grid layout chosen in GridSize().
    const index_t i_block = has_padded_seqlen_k ? blockIdx.z : blockIdx.x;
    const index_t i_nhead = has_padded_seqlen_k ? blockIdx.x : blockIdx.y;
    const index_t i_batch = has_padded_seqlen_k ? blockIdx.y : blockIdx.z;

    // Decompose the linear tile id into (m, n) tile coordinates.
    const index_t num_tile_n1 =
        ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
    const index_t i_tile_m = i_block / num_tile_n1;
    const index_t i_tile_n = i_block - i_tile_m * num_tile_n1;

    return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
...@@ -883,8 +950,7 @@ struct FmhaFwdKernel ...@@ -883,8 +950,7 @@ struct FmhaFwdKernel
__shared__ char smem_ptr[GetSmemSize()]; __shared__ char smem_ptr[GetSmemSize()];
// divide problem // divide problem
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
......
...@@ -5,12 +5,13 @@ ...@@ -5,12 +5,13 @@
namespace ck_tile { namespace ck_tile {
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_> template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdSplitKVCombineKernel struct FmhaFwdSplitKVCombineKernel
{ {
using TilePartitioner = remove_cvref_t<TilePartitioner_>; using FmhaPipeline = remove_cvref_t<FmhaPipeline_>;
using FmhaPipeline = remove_cvref_t<FmhaPipeline_>; using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
static constexpr index_t kNumWarps = FmhaPipeline::kNumWarps;
static constexpr index_t kBlockSize = FmhaPipeline::kBlockSize; static constexpr index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0); static_assert(kBlockPerCu > 0);
...@@ -50,8 +51,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -50,8 +51,7 @@ struct FmhaFwdSplitKVCombineKernel
return return
_SS_("fmha_fwd_splitkv_combine_d") + _TS_(FmhaPipeline::kHeadDimV) + "_" + _SS_(t2s<ODataType>::name) + _SS_("fmha_fwd_splitkv_combine_d") + _TS_(FmhaPipeline::kHeadDimV) + "_" + _SS_(t2s<ODataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_" "_" + (kIsGroupMode ? "group" : "batch") + "_"
"b" + _TS_(FmhaPipeline::kM0) + "x" + "b" + _TS_(FmhaPipeline::kN1) + "_" +
_TS_(FmhaPipeline::kN1) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
_SS_(FmhaPipeline::name) + _SS_(FmhaPipeline::name) +
(pn.empty() ? "" : "_" + pn) + (pn.empty() ? "" : "_" + pn) +
...@@ -234,12 +234,35 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -234,12 +234,35 @@ struct FmhaFwdSplitKVCombineKernel
return kargs; return kargs;
} }
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size, CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead, ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q, ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v) ck_tile::index_t hdim_v)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1),
nhead,
batch_size);
}
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
{ {
return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v); const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
} }
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
...@@ -255,8 +278,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -255,8 +278,7 @@ struct FmhaFwdSplitKVCombineKernel
__shared__ char smem_ptr[GetSmemSize()]; __shared__ char smem_ptr[GetSmemSize()];
// divide problem // divide problem
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
...@@ -339,37 +361,56 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -339,37 +361,56 @@ struct FmhaFwdSplitKVCombineKernel
number<FmhaPipeline::kAlignmentOacc>{}, number<FmhaPipeline::kAlignmentOacc>{},
number<1>{}); number<1>{});
// read 4 * (kM0, kN1) o_acc tiles simultaneously by 4 warps
const auto o_acc_dram_view = pad_tensor_view( const auto o_acc_dram_view = pad_tensor_view(
o_acc_dram_naive, o_acc_dram_naive,
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}), make_tuple(
sequence<false, kPadSeqLenQ, kPadHeadDimV>{}); number<kNumWarps>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<true, kPadSeqLenQ, kPadHeadDimV>{});
const index_t padded_num_splits =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<0>{}];
const index_t padded_seqlen_q = const index_t padded_seqlen_q =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}]; o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
const index_t padded_hdim_v = const index_t padded_hdim_v =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}]; o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];
return transform_tensor_view( const index_t num_m_tiles = integer_divide_floor(padded_seqlen_q, FmhaPipeline::kM0);
// transform tensor view by following steps, given shape: (padded_num_splits,
// padded_seqlen_q, padded_hdim_v)
// 1. unmerge to (padded_num_splits, num_m_tiles, kM0, padded_hdim_v)
// 2. transpose to (num_m_tiles, padded_num_splits, kM0, padded_hdim_v)
// 3. merge to (num_m_tiles * padded_num_splits * kM0, padded_hdim_v)
auto transposed = transform_tensor_view(
o_acc_dram_view, o_acc_dram_view,
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_seqlen_q)), make_tuple(make_pass_through_transform(padded_num_splits),
make_unmerge_transform(make_tuple(num_m_tiles, FmhaPipeline::kM0)),
make_pass_through_transform(padded_hdim_v)), make_pass_through_transform(padded_hdim_v)),
make_tuple(sequence<0, 1>{}, sequence<2>{}), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<1>{}, sequence<0, 2>{}, sequence<3>{}));
return transform_tensor_view(
transposed,
make_tuple(make_merge_transform(
make_tuple(num_m_tiles, padded_num_splits, FmhaPipeline::kM0)),
make_pass_through_transform(padded_hdim_v)),
make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
}(); }();
auto lse_acc_dram_window = make_tile_window( auto lse_acc_dram_window = make_tile_window(
lse_acc_dram, lse_acc_dram,
[&]() { make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{}),
return make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{});
}(),
{0, i_m0}); {0, i_m0});
const index_t padded_num_splits =
integer_divide_ceil(kargs.num_splits, kNumWarps) * kNumWarps;
auto o_acc_dram_window = make_tile_window( auto o_acc_dram_window = make_tile_window(
o_acc_dram, o_acc_dram,
[&]() { make_tuple(number<kNumWarps * FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}); {i_tile_m * padded_num_splits * FmhaPipeline::kM0, i_n1});
}(),
{i_m0, i_n1});
// LSE DRAM window // LSE DRAM window
auto lse_dram_window = [&, i_nhead_ = i_nhead]() { auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
...@@ -410,7 +451,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -410,7 +451,6 @@ struct FmhaFwdSplitKVCombineKernel
identity{}, // lse_element_func identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
kargs.num_splits, kargs.num_splits,
kargs.seqlen_q,
smem_ptr); smem_ptr);
} }
else else
...@@ -419,7 +459,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -419,7 +459,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window, o_acc_dram_window,
lse_dram_window, lse_dram_window,
kargs.num_splits, kargs.num_splits,
kargs.seqlen_q,
smem_ptr); smem_ptr);
} }
}(); }();
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Partitions the split-KV combine kernel's work: one thread block per
// (kM0 x kN1) output tile, with heads on grid.y and batches on grid.z.
template <index_t kM0_, index_t kN1_>
struct FmhaFwdSplitKVCombineTilePartitioner
{
    static constexpr ck_tile::index_t kM0 = kM0_;
    static constexpr ck_tile::index_t kN1 = kN1_;

    // Grid dimensions covering the whole problem.
    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
                                                ck_tile::index_t nhead,
                                                ck_tile::index_t max_seqlen_q,
                                                ck_tile::index_t hdim_v)
    {
        // TODO: this may need tuning
        const auto num_tiles = ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
                               ck_tile::integer_divide_ceil(hdim_v, kN1);
        return dim3(num_tiles, nhead, batch_size);
    }

    // Recover (tile_m, tile_n, head, batch) from the block coordinates.
    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
    {
        const index_t tiles_per_row = ck_tile::integer_divide_ceil(hdim_v, kN1);
        const index_t linear_tile   = blockIdx.x;
        const index_t head_idx      = blockIdx.y;
        const index_t batch_idx     = blockIdx.z;

        const index_t tile_m = linear_tile / tiles_per_row;
        const index_t tile_n = linear_tile - tile_m * tiles_per_row;

        return ck_tile::make_tuple(tile_m, tile_n, head_idx, batch_idx);
    }
};
} // namespace ck_tile
...@@ -44,6 +44,7 @@ struct FmhaFwdSplitKVKernel ...@@ -44,6 +44,7 @@ struct FmhaFwdSplitKVKernel
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
...@@ -66,7 +67,8 @@ struct FmhaFwdSplitKVKernel ...@@ -66,7 +67,8 @@ struct FmhaFwdSplitKVKernel
using bfs = typename FmhaPipeline::BlockFmhaShape; using bfs = typename FmhaPipeline::BlockFmhaShape;
using g0br = typename bfs::Gemm0BlockWarps; using g0br = typename bfs::Gemm0BlockWarps;
using g1br = typename bfs::Gemm1BlockWarps; using g1br = typename bfs::Gemm1BlockWarps;
using gwt = typename bfs::Gemm0WarpTile; using g0wt = typename bfs::Gemm0WarpTile;
using g1wt = typename bfs::Gemm1WarpTile;
#define _SS_ std::string #define _SS_ std::string
#define _TS_ std::to_string #define _TS_ std::to_string
auto pn = [&] () { auto pn = [&] () {
...@@ -83,11 +85,12 @@ struct FmhaFwdSplitKVKernel ...@@ -83,11 +85,12 @@ struct FmhaFwdSplitKVKernel
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
"w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kDoFp8StaticQuant ? "_squant" : "") + (kIsPagedKV ? "_pagedkv" : "" ); (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "") + (kIsPagedKV ? "_pagedkv" : "" );
#undef _SS_ #undef _SS_
#undef _TS_ #undef _TS_
// clang-format on // clang-format on
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// "SHB" tile partitioner for the FMHA forward kernel:
// blockIdx.x carries the linearized (seqlen_q-tile, hdim_v-tile) index,
// blockIdx.y the head index, blockIdx.z the batch index.
template <typename BlockFmhaShape_>
struct FmhaFwdTilePartitioner
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    // block-tile extents mirrored from the FMHA block shape
    static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
    static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
    static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
    static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;

    static constexpr const char* name = "shb";

    // Grid: x = (#seqlen_q tiles) * (#hdim_v tiles), y = heads, z = batches.
    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_,
                                                ck_tile::index_t hdim_v_)
    {
        // TODO: this may need tuning
        const auto num_tile_m0 = ck_tile::integer_divide_ceil(seqlen_q_, kM0);
        const auto num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v_, kN1);
        return dim3(num_tile_m0 * num_tile_n1, nhead_, batch_size_);
    }

    // Recover (tile_m, tile_n, head, batch) from the launch coordinates.
    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
    {
        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);

        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        // divmod of the linear block index by the hdim_v tile count
        const index_t i_tile_m = i_block / num_tile_n1;
        const index_t i_tile_n = i_block - i_tile_m * num_tile_n1;

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
    }
};
// Alias: "SHB" (seqlen/hdim-head-batch grid layout) is the default partitioner.
template <typename BlockFmhaShape_>
using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner<BlockFmhaShape_>;
template <typename BlockFmhaShape_>
struct FmhaFwdTilePartitioner_HBS
{
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
static constexpr const char* name = "hbs";
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
// TODO: this may need tuning
return dim3(nhead_,
batch_size_,
ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
ck_tile::integer_divide_ceil(hdim_v_, kN1));
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
const index_t i_block = blockIdx.z;
const index_t i_nhead = blockIdx.x;
const index_t i_batch = blockIdx.y;
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
};
} // namespace ck_tile
...@@ -53,6 +53,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -53,6 +53,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr index_t kNumWarps = Problem::kNumWarps;
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kHeadDimV = Problem::kHeadDimV; static constexpr index_t kHeadDimV = Problem::kHeadDimV;
...@@ -117,7 +118,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -117,7 +118,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const LSEElementFunction& lse_element_func, const LSEElementFunction& lse_element_func,
const OaccElementFunction& o_acc_element_func, const OaccElementFunction& o_acc_element_func,
index_t num_splits, index_t num_splits,
index_t seqlen_q,
void* smem_ptr) const void* smem_ptr) const
{ {
// lse_acc tile in LDS // lse_acc tile in LDS
...@@ -143,11 +143,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -143,11 +143,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline
// copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]). // copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]).
auto lse_acc_tile = load_tile(lse_acc_dram_window); auto lse_acc_tile = load_tile(lse_acc_dram_window);
store_tile(lse_acc_lds_write_window, lse_acc_tile); store_tile(lse_acc_lds_write_window, lse_acc_tile);
block_sync_lds();
auto lse_accum = make_static_distributed_tensor<LSEDataType>( auto lse_accum = make_static_distributed_tensor<LSEDataType>(
Policy::template MakeLSEaccRegTileDistribution<Problem>()); Policy::template MakeLSEaccRegTileDistribution<Problem>());
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits]) // copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
// and fill up -INF values outside the [kM0, num_splits] region. // and fill up -INF values outside the [kM0, num_splits] region.
{ {
...@@ -264,46 +265,94 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -264,46 +265,94 @@ struct BlockFmhaFwdSplitKVCombinePipeline
} }
}); });
} }
block_sync_lds();
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum)); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
} }
auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>(); auto o_acc_4_dist = Policy::template MakeOacc4DramTileDistribution<Problem>();
auto o_acc_dram_window = auto o_acc_4_dram_window =
make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(), make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(),
o_acc_dram_block_window_tmp.get_window_lengths(), o_acc_dram_block_window_tmp.get_window_lengths(),
o_acc_dram_block_window_tmp.get_window_origin(), o_acc_dram_block_window_tmp.get_window_origin(),
o_acc_dist); o_acc_4_dist);
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
clear_tile(o_acc);
const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0; // shape=[4 * KM0, kN1]
auto o_acc_4 = make_static_distributed_tensor<OaccDataType>(o_acc_4_dist);
clear_tile(o_acc_4);
for(index_t i_split = 0; i_split < num_splits; ++i_split) const index_t padded_num_splits = integer_divide_ceil(num_splits, kNumWarps) * kNumWarps;
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
// each warp handles a [KM0, kN1] tile
for(index_t split_start = 0; split_start < padded_num_splits; split_start += kNumWarps)
{ {
auto o_tile = load_tile(o_acc_dram_window); auto o_tile = load_tile(o_acc_4_dram_window);
const index_t i_split = split_start + get_warp_id();
const index_t row_start = kM0 * get_warp_id();
{ {
constexpr auto spans = decltype(o_acc)::get_distributed_spans(); constexpr auto spans = decltype(o_acc_4)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) { sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) { sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1); constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices( const auto x_indices = get_x_indices_from_distributed_indices(
o_acc.get_tile_distribution(), i_j_idx); o_acc_4.get_tile_distribution(), i_j_idx);
const auto row = x_indices.at(number<0>{}); const auto row = x_indices.at(number<0>{});
const LSEDataType lse_scale = lse_acc_lds(row, i_split); const LSEDataType lse_scale = lse_acc_lds(row - row_start, i_split);
o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx); o_acc_4(i_j_idx) += lse_scale * o_tile(i_j_idx);
}); });
}); });
} }
move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0}); move_tile_window(o_acc_4_dram_window, {kNumWarps * kM0, 0});
}
// 4 o_acc tiles in LDS. shape=[4 * kM0, kN1]
OaccDataType* o_acc_4_lds_ptr = static_cast<OaccDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeLSEacc<Problem>()));
{
auto o_acc_4_lds_window = [&]() {
auto desc = Policy::template MakeOacc4LdsBlockDescriptor<Problem>();
auto view = make_tensor_view<address_space_enum::lds>(o_acc_4_lds_ptr, desc);
return make_tile_window(view, desc.get_lengths(), {0, 0});
}();
store_tile(o_acc_4_lds_window, o_acc_4);
} }
auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
auto o_acc_4_lds_window = [&]() {
auto desc = Policy::template MakeOacc4LdsBlockDescriptor<Problem>();
auto view = make_tensor_view<address_space_enum::lds>(o_acc_4_lds_ptr, desc);
return make_tile_window(view, desc.get_lengths(), {0, 0}, o_acc_dist);
}();
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
clear_tile(o_acc);
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
static_for<0, kNumWarps, 1>{}([&](auto) {
auto o_acc_in = load_tile(o_acc_4_lds_window);
{
constexpr auto spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) += o_acc_in(i_j_idx);
});
});
}
move_tile_window(o_acc_4_lds_window, {kM0, 0});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc); o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc; return o_acc;
...@@ -316,7 +365,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -316,7 +365,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const OaccDramBlockWindow& o_acc_dram_block_window, const OaccDramBlockWindow& o_acc_dram_block_window,
LSEDramBlockWindow& lse_dram_block_window, LSEDramBlockWindow& lse_dram_block_window,
index_t num_splits, index_t num_splits,
index_t seqlen_q,
void* smem_ptr) const void* smem_ptr) const
{ {
return operator()(lse_acc_dram_block_window, return operator()(lse_acc_dram_block_window,
...@@ -325,7 +373,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline ...@@ -325,7 +373,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
identity{}, identity{},
identity{}, identity{},
num_splits, num_splits,
seqlen_q,
smem_ptr); smem_ptr);
} }
}; };
......
...@@ -10,23 +10,38 @@ namespace ck_tile { ...@@ -10,23 +10,38 @@ namespace ck_tile {
struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{ {
template <index_t BlockSize, index_t M, index_t N, typename DataType> template <index_t NumWarps, index_t M, index_t N, typename DataType>
CK_TILE_HOST_DEVICE static constexpr auto GetMaxNumWarpsForTile()
{
static_assert(NumWarps == 1 || NumWarps == 2 || NumWarps == 4);
constexpr index_t ElemPerThread = (M * N) / (NumWarps * get_warp_size());
if constexpr(0 < ElemPerThread)
{
return NumWarps;
}
else
{ // try dividing tile by smaller # of warps
return GetMaxNumWarpsForTile<NumWarps / 2, M, N, DataType>();
}
}
template <index_t NumWarps, index_t M, index_t N, typename DataType>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeForTile() CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeForTile()
{ {
constexpr index_t PixelsPerThread = (M * N) / BlockSize; constexpr index_t MaxNumWarps = GetMaxNumWarpsForTile<NumWarps, M, N, DataType>();
static_assert(0 < PixelsPerThread);
constexpr index_t MaxNPerThread = 16 / sizeof(DataType); constexpr index_t ElemPerThread = (M * N) / (MaxNumWarps * get_warp_size());
constexpr index_t NPerThread = min(MaxNPerThread, PixelsPerThread);
return NPerThread; constexpr index_t MaxNPerThread = 16 / sizeof(DataType);
return min(MaxNPerThread, ElemPerThread);
} }
// alignment for dram lse tile (shape=[kMaxSplits, kM0]) // alignment for dram lse tile (shape=[kMaxSplits, kM0])
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE()
{ {
return GetVectorSizeForTile<Problem::kBlockSize, return GetVectorSizeForTile<Problem::kNumWarps,
Problem::kMaxSplits, Problem::kMaxSplits,
Problem::kM0, Problem::kM0,
typename Problem::LSEDataType>(); typename Problem::LSEDataType>();
...@@ -56,40 +71,54 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -56,40 +71,54 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeLSEacc()
{ {
return sizeof(typename Problem::LSEDataType) * return sizeof(typename Problem::LSEDataType) *
MakeLSEaccLdsBlockDescriptor<Problem>().get_element_space_size(); MakeLSEaccLdsBlockDescriptor<Problem>().get_element_space_size();
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOacc4()
{
return sizeof(typename Problem::OaccDataType) *
MakeOacc4LdsBlockDescriptor<Problem>().get_element_space_size();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return GetSmemSizeLSEacc<Problem>() + GetSmemSizeOacc4<Problem>();
}
// shape=[kMaxSplits, kM0] // shape=[kMaxSplits, kM0]
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution()
{ {
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNumWarps = Problem::kNumWarps;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t kMPerBlock = Problem::kMaxSplits; constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t MaxNumWarps =
GetMaxNumWarpsForTile<Problem::kNumWarps, kNPerBlock, kMPerBlock, LSEDataType>();
constexpr index_t Replicate = Problem::kNumWarps / MaxNumWarps;
constexpr index_t NPerThread = constexpr index_t NPerThread =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>(); GetVectorSizeForTile<MaxNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr index_t NThreads = kNPerBlock / NPerThread; constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads; constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads;
constexpr index_t MPerThread = kMPerBlock / (kNumWarps * MThreadsPerWarp); constexpr index_t MPerThread = kMPerBlock / (MaxNumWarps * MThreadsPerWarp);
static_assert(MPerThread * MaxNumWarps * MThreadsPerWarp == kMPerBlock);
static_assert(NThreads * NPerThread == kNPerBlock); static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MPerThread * kNumWarps * MThreadsPerWarp == kMPerBlock);
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>, tile_distribution_encoding<sequence<Replicate>,
tuple<sequence<MPerThread, kNumWarps, MThreadsPerWarp>, tuple<sequence<MPerThread, MaxNumWarps, MThreadsPerWarp>,
sequence<NThreads, NPerThread>>, sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<1, 2>>, tuple<sequence<0, 1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>, tuple<sequence<0, 1>, sequence<2, 0>>,
sequence<1, 2>, sequence<1, 2>,
sequence<0, 1>>{}); sequence<0, 1>>{});
} }
...@@ -100,17 +129,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -100,17 +129,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{ {
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t NPack = constexpr index_t NPack =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>(); GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}), make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}), make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{}, number<NPack>{},
number<1>{}); number<1>{});
constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor( constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor(
...@@ -129,17 +156,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -129,17 +156,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{ {
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t kMPerBlock = Problem::kMaxSplits;
constexpr index_t kNPerBlock = Problem::kM0;
constexpr index_t NPack = constexpr index_t NPack =
GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>(); GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}), make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}), make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{}, number<NPack>{},
number<1>{}); number<1>{});
constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor( constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor(
...@@ -152,33 +177,86 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -152,33 +177,86 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
return lse_acc_t_lds_block_desc; return lse_acc_t_lds_block_desc;
} }
// 3d + padding, shape=[4 * kM0, kN1]
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4LdsBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
constexpr index_t kNPerBlock = Problem::kMaxSplits; constexpr index_t kMPerBlock = 4 * Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t NPack =
GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
constexpr auto o_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto o_acc_t_lds_block_desc = transform_tensor_descriptor(
o_acc_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return o_acc_t_lds_block_desc;
}
// shape=[kM0, kMaxSplits]
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution()
{
constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kMaxSplits;
constexpr index_t NThreads = 4; constexpr index_t MaxNThreads = 8;
constexpr index_t NPerThread = kNPerBlock / NThreads; constexpr index_t NThreads = min(kNPerBlock, MaxNThreads);
constexpr index_t NPerThread = kNPerBlock / NThreads;
constexpr index_t MThreads = kBlockSize / NThreads; constexpr index_t MPerThread = 1;
constexpr index_t MPerThread = kMPerBlock / MThreads; constexpr index_t MThreads = kMPerBlock / MPerThread;
constexpr index_t MWarps = kBlockSize / get_warp_size();
constexpr index_t MThreadPerWarp = get_warp_size() / NThreads; constexpr index_t MThreadPerWarp = get_warp_size() / NThreads;
constexpr index_t MaxNumWarps = (MThreads * NThreads) / get_warp_size();
constexpr index_t Replicate = Problem::kNumWarps / MaxNumWarps;
static_assert(MaxNumWarps * MThreadPerWarp * MPerThread == kMPerBlock);
static_assert(NThreads * NPerThread == kNPerBlock); static_assert(NThreads * NPerThread == kNPerBlock);
static_assert(MWarps * MThreadPerWarp * MPerThread == kMPerBlock);
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding< tile_distribution_encoding<sequence<Replicate>,
sequence<1>, tuple<sequence<MaxNumWarps, MThreadPerWarp, MPerThread>,
tuple<sequence<MWarps, MThreadPerWarp, MPerThread>, sequence<NThreads, NPerThread>>, sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<2, 1>>, tuple<sequence<0, 1>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0, 1>>, tuple<sequence<0, 0>, sequence<0, 1>>,
sequence<1, 2>, sequence<1, 2>,
sequence<2, 1>>{}); sequence<2, 1>>{});
}
// similar to MakeOaccDramTileDistribution(), but duplicate same 1-warp encoding 4 times on M
// direction
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4DramTileDistribution()
{
constexpr index_t kMPerBlock = Problem::kM0; // real kMPerBlock we want is (4 * kM0)
constexpr index_t kNPerBlock = Problem::kN1;
static_assert(get_warp_size() <= kMPerBlock * kNPerBlock);
constexpr index_t M1 = 1; // compose encoding base on 1 warp
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = kNPerBlock / N0;
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<4, M0, M1, M2>, sequence<N0, N1>>,
tuple<sequence<1, 1>, sequence<1, 2>>,
tuple<sequence<0, 2>, sequence<3, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
} }
template <typename Problem> template <typename Problem>
...@@ -187,6 +265,7 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy ...@@ -187,6 +265,7 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kNPerBlock = Problem::kN1; constexpr index_t kNPerBlock = Problem::kN1;
static_assert(kBlockSize <= kMPerBlock * kNPerBlock);
constexpr index_t M1 = kBlockSize / get_warp_size(); constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size()); constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment