"docs/source/en/optimization/open_vino.md" did not exist on "a5d2ee9d474e35c874fcc2a3b1085012202c6b47"
Commit b3100b6f authored by danyao12's avatar danyao12
Browse files

remove FmhaBwdTilePartitioner

parent 9d78a6c5
......@@ -104,8 +104,7 @@ using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
false>>;
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdKTilePartitioner<{F_bn0}>,
fmha_bwd_pipeline_{F_idx},
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_{F_idx},
fmha_bwd_dk_epilogue_{F_idx},
fmha_bwd_dv_epilogue_{F_idx}>;
......@@ -517,8 +516,7 @@ using fmha_bwd_dot_do_o_{F_idx} =
typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
using fmha_bwd_dot_do_o_kernel_{F_idx} =
ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwdQTilePartitioner</* BlockSize = */ 64>,
fmha_bwd_dot_do_o_{F_idx}>;
ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_{F_idx}>;
using dot_do_o_trait_{F_idx} =
fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
......@@ -641,8 +639,7 @@ using fmha_bwd_convert_dq_{F_idx} =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;
using fmha_bwd_convert_dq_kernel_{F_idx} =
ck_tile::FmhaBwdConvertQGradKernel<ck_tile::FmhaBwdQTilePartitioner<{F_bm0}>,
fmha_bwd_convert_dq_{F_idx}>;
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_{F_idx}>;
using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
{F_dtype},
......
......@@ -8,7 +8,6 @@
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
......
......@@ -23,13 +23,9 @@
namespace ck_tile {
template <typename TilePartitioner_,
typename FmhaPipeline_,
typename KGradEpiloguePipeline_,
typename VGradEpiloguePipeline_>
template <typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_>
struct FmhaBwdDQDKDVKernel
{
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using KGradEpiloguePipeline = ck_tile::remove_cvref_t<KGradEpiloguePipeline_>;
using VGradEpiloguePipeline = ck_tile::remove_cvref_t<VGradEpiloguePipeline_>;
......@@ -536,7 +532,17 @@ struct FmhaBwdDQDKDVKernel
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
{
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_k_);
return dim3(
batch_size_, nhead_, ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0));
}
CK_TILE_DEVICE static constexpr auto GetTileIndex()
{
const index_t i_block = blockIdx.z;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.x;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
......@@ -554,7 +560,7 @@ struct FmhaBwdDQDKDVKernel
__shared__ char smem_ptr[GetSmemSize()];
// divide problem
const auto [i_tile_n, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_k);
const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex();
const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0);
......@@ -1037,10 +1043,9 @@ struct FmhaBwdDQDKDVKernel
}
};
template <typename TilePartitioner_, typename FmhaBwdOGradDotO_>
template <typename FmhaBwdOGradDotO_>
struct FmhaBwdOGradDotOKernel
{
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
using FmhaBwdOGradDotO = ck_tile::remove_cvref_t<FmhaBwdOGradDotO_>;
static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu;
......@@ -1189,7 +1194,16 @@ struct FmhaBwdOGradDotOKernel
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_);
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
}
CK_TILE_DEVICE static constexpr auto GetTileIndex()
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
......@@ -1199,7 +1213,7 @@ struct FmhaBwdOGradDotOKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// divide problem
const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q);
const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);
......@@ -1286,10 +1300,9 @@ struct FmhaBwdOGradDotOKernel
}
};
template <typename TilePartitioner_, typename FmhaBwdConvertQGrad_>
template <typename FmhaBwdConvertQGrad_>
struct FmhaBwdConvertQGradKernel
{
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
using FmhaBwdConvertQGrad = ck_tile::remove_cvref_t<FmhaBwdConvertQGrad_>;
static constexpr ck_tile::index_t kBlockSize = FmhaBwdConvertQGrad::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu;
......@@ -1439,7 +1452,16 @@ struct FmhaBwdConvertQGradKernel
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_);
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
}
CK_TILE_DEVICE static constexpr auto GetTileIndex()
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
......@@ -1449,7 +1471,7 @@ struct FmhaBwdConvertQGradKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// divide problem
const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q);
const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <ck_tile::index_t kN0>
struct FmhaBwdKTilePartitioner
{
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
{
// TODO: this may need tuning
return dim3(batch_size_, nhead_, ck_tile::integer_divide_ceil(seqlen_k_, kN0));
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
{
const index_t i_block = blockIdx.z;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.x;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
};
template <ck_tile::index_t kM0>
struct FmhaBwdQTilePartitioner
{
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment