Commit b2510c05 authored by danyao12's avatar danyao12
Browse files

fix dq alignment

parent da2dce18
......@@ -67,8 +67,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQGrad<Problem>();
static constexpr index_t kAlignmentQGrad = 1;
static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
......
......@@ -276,18 +276,6 @@ struct BlockFmhaBwdPipelineDefaultPolicy
return kVecLoad;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQGrad()
{
using BlockGemm = remove_cvref_t<decltype(GetSGradKTBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
using CWarpDstr = typename WG::CWarpDstr;
constexpr auto vec =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number<CWarpDstr::NDimY - 1>{});
return vec;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentKGrad()
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment