Commit d99d4d56 authored by aska-0096's avatar aska-0096
Browse files

remove xor usage in q, do and ds

parent 545eec16
......@@ -1011,9 +1011,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t KPack = GetSmemKPackQ<Problem>();
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, KPack, false>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack, false>();
}
template <typename Problem>
......@@ -1077,7 +1077,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr index_t kKPack = GetAlignmentQ<Problem>();
constexpr index_t kKPackT = GetSmemKPackQT<Problem>();
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
......@@ -1218,9 +1218,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t KPack = GetSmemKPackOGrad<Problem>();
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, KPack, false>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack, false>();
}
template <typename Problem>
......@@ -1284,7 +1284,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
constexpr index_t kKPack = GetAlignmentOGrad<Problem>();
constexpr index_t kKPackT = GetSmemKPackOGradT<Problem>();
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
......@@ -1377,7 +1377,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPack = GetSmemKPackSGrad<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack, false>();
}
template <typename Problem>
......@@ -1924,20 +1924,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
// 16 * 32 / 64 / 8 = 1
static constexpr index_t SGradT_LDS_READ_P1 =
// kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / 2;
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
// 16 * 128 / 64 / 8 = 4
static constexpr index_t Q_LDS_READ =
kM0 * kK0 / (get_warp_size() * Gemm0MWarp) / GetAlignmentQ<Problem>();
kM0 * kK0 / (get_warp_size() * Gemm0MWarp) / GetSmemKPackQ<Problem>();
// 1
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
// 16 * 96 / 64 / 8 = 3
static constexpr index_t SGradT_LDS_READ_P2 =
// kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / 2;
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
// 16 * 128 / 64 / 8 = 4
static constexpr index_t OGrad_LDS_READ =
kM0 * kK2 / (get_warp_size() * Gemm2MWarp) / GetAlignmentOGrad<Problem>();
kM0 * kK2 / (get_warp_size() * Gemm2MWarp) / GetSmemKPackOGrad<Problem>();
// 1
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment