Commit b2510c05 authored by danyao12's avatar danyao12
Browse files

fix dq alignment

parent da2dce18
...@@ -67,8 +67,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ...@@ -67,8 +67,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentOGrad = static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad = static constexpr index_t kAlignmentQGrad = 1;
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQGrad<Problem>();
static constexpr index_t kAlignmentKGrad = static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>(); kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad = static constexpr index_t kAlignmentVGrad =
......
...@@ -276,18 +276,6 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -276,18 +276,6 @@ struct BlockFmhaBwdPipelineDefaultPolicy
return kVecLoad; return kVecLoad;
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQGrad()
{
using BlockGemm = remove_cvref_t<decltype(GetSGradKTBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
using CWarpDstr = typename WG::CWarpDstr;
constexpr auto vec =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number<CWarpDstr::NDimY - 1>{});
return vec;
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentKGrad() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentKGrad()
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment