"...composable_kernel.git" did not exist on "055acaceddc1cbe796386a8ff39efc876a913a07"
Commit b3100b6f authored by danyao12's avatar danyao12
Browse files

remove FmhaBwdTilePartitioner

parent 9d78a6c5
...@@ -104,8 +104,7 @@ using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ...@@ -104,8 +104,7 @@ using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
false>>; false>>;
using fmha_bwd_dq_dk_dv_kernel_{F_idx} = using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdKTilePartitioner<{F_bn0}>, ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_{F_idx},
fmha_bwd_pipeline_{F_idx},
fmha_bwd_dk_epilogue_{F_idx}, fmha_bwd_dk_epilogue_{F_idx},
fmha_bwd_dv_epilogue_{F_idx}>; fmha_bwd_dv_epilogue_{F_idx}>;
...@@ -517,8 +516,7 @@ using fmha_bwd_dot_do_o_{F_idx} = ...@@ -517,8 +516,7 @@ using fmha_bwd_dot_do_o_{F_idx} =
typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>; typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
using fmha_bwd_dot_do_o_kernel_{F_idx} = using fmha_bwd_dot_do_o_kernel_{F_idx} =
ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwdQTilePartitioner</* BlockSize = */ 64>, ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_{F_idx}>;
fmha_bwd_dot_do_o_{F_idx}>;
using dot_do_o_trait_{F_idx} = using dot_do_o_trait_{F_idx} =
fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>; fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
...@@ -641,8 +639,7 @@ using fmha_bwd_convert_dq_{F_idx} = ...@@ -641,8 +639,7 @@ using fmha_bwd_convert_dq_{F_idx} =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>; typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;
using fmha_bwd_convert_dq_kernel_{F_idx} = using fmha_bwd_convert_dq_kernel_{F_idx} =
ck_tile::FmhaBwdConvertQGradKernel<ck_tile::FmhaBwdQTilePartitioner<{F_bm0}>, ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_{F_idx}>;
fmha_bwd_convert_dq_{F_idx}>;
using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim}, using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
{F_dtype}, {F_dtype},
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
#include "ck_tile/ops/fmha/block/block_masking.hpp" #include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp" #include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
......
...@@ -23,13 +23,9 @@ ...@@ -23,13 +23,9 @@
namespace ck_tile { namespace ck_tile {
template <typename TilePartitioner_, template <typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_>
typename FmhaPipeline_,
typename KGradEpiloguePipeline_,
typename VGradEpiloguePipeline_>
struct FmhaBwdDQDKDVKernel struct FmhaBwdDQDKDVKernel
{ {
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>; using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using KGradEpiloguePipeline = ck_tile::remove_cvref_t<KGradEpiloguePipeline_>; using KGradEpiloguePipeline = ck_tile::remove_cvref_t<KGradEpiloguePipeline_>;
using VGradEpiloguePipeline = ck_tile::remove_cvref_t<VGradEpiloguePipeline_>; using VGradEpiloguePipeline = ck_tile::remove_cvref_t<VGradEpiloguePipeline_>;
...@@ -536,7 +532,17 @@ struct FmhaBwdDQDKDVKernel ...@@ -536,7 +532,17 @@ struct FmhaBwdDQDKDVKernel
CK_TILE_HOST static constexpr auto CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_) GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
{ {
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_k_); return dim3(
batch_size_, nhead_, ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0));
}
CK_TILE_DEVICE static constexpr auto GetTileIndex()
{
const index_t i_block = blockIdx.z;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.x;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
} }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
...@@ -554,7 +560,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -554,7 +560,7 @@ struct FmhaBwdDQDKDVKernel
__shared__ char smem_ptr[GetSmemSize()]; __shared__ char smem_ptr[GetSmemSize()];
// divide problem // divide problem
const auto [i_tile_n, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_k); const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex();
const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0); const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0);
...@@ -1037,10 +1043,9 @@ struct FmhaBwdDQDKDVKernel ...@@ -1037,10 +1043,9 @@ struct FmhaBwdDQDKDVKernel
} }
}; };
template <typename TilePartitioner_, typename FmhaBwdOGradDotO_> template <typename FmhaBwdOGradDotO_>
struct FmhaBwdOGradDotOKernel struct FmhaBwdOGradDotOKernel
{ {
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
using FmhaBwdOGradDotO = ck_tile::remove_cvref_t<FmhaBwdOGradDotO_>; using FmhaBwdOGradDotO = ck_tile::remove_cvref_t<FmhaBwdOGradDotO_>;
static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize; static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu; static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu;
...@@ -1189,7 +1194,16 @@ struct FmhaBwdOGradDotOKernel ...@@ -1189,7 +1194,16 @@ struct FmhaBwdOGradDotOKernel
CK_TILE_HOST static constexpr auto CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{ {
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_); return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
}
CK_TILE_DEVICE static constexpr auto GetTileIndex()
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
} }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
...@@ -1199,7 +1213,7 @@ struct FmhaBwdOGradDotOKernel ...@@ -1199,7 +1213,7 @@ struct FmhaBwdOGradDotOKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
{ {
// divide problem // divide problem
const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q); const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);
...@@ -1286,10 +1300,9 @@ struct FmhaBwdOGradDotOKernel ...@@ -1286,10 +1300,9 @@ struct FmhaBwdOGradDotOKernel
} }
}; };
template <typename TilePartitioner_, typename FmhaBwdConvertQGrad_> template <typename FmhaBwdConvertQGrad_>
struct FmhaBwdConvertQGradKernel struct FmhaBwdConvertQGradKernel
{ {
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
using FmhaBwdConvertQGrad = ck_tile::remove_cvref_t<FmhaBwdConvertQGrad_>; using FmhaBwdConvertQGrad = ck_tile::remove_cvref_t<FmhaBwdConvertQGrad_>;
static constexpr ck_tile::index_t kBlockSize = FmhaBwdConvertQGrad::kBlockSize; static constexpr ck_tile::index_t kBlockSize = FmhaBwdConvertQGrad::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu; static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu;
...@@ -1439,7 +1452,16 @@ struct FmhaBwdConvertQGradKernel ...@@ -1439,7 +1452,16 @@ struct FmhaBwdConvertQGradKernel
CK_TILE_HOST static constexpr auto CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{ {
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_); return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
}
CK_TILE_DEVICE static constexpr auto GetTileIndex()
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
} }
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
...@@ -1449,7 +1471,7 @@ struct FmhaBwdConvertQGradKernel ...@@ -1449,7 +1471,7 @@ struct FmhaBwdConvertQGradKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
{ {
// divide problem // divide problem
const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q); const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <ck_tile::index_t kN0>
struct FmhaBwdKTilePartitioner
{
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
{
// TODO: this may need tuning
return dim3(batch_size_, nhead_, ck_tile::integer_divide_ceil(seqlen_k_, kN0));
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
{
const index_t i_block = blockIdx.z;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.x;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
};
template <ck_tile::index_t kM0>
struct FmhaBwdQTilePartitioner
{
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment