Commit d99d4d56 authored by aska-0096
Browse files

Remove XOR usage in Q, dO (OGrad), and dS (SGrad)

parent 545eec16
...@@ -1011,9 +1011,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1011,9 +1011,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t KPack = GetSmemKPackQ<Problem>(); constexpr index_t kKPack = GetSmemKPackQ<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, KPack, false>(); return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack, false>();
} }
template <typename Problem> template <typename Problem>
...@@ -1077,7 +1077,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1077,7 +1077,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPack = GetSmemKPackQ<Problem>(); constexpr index_t kKPack = GetAlignmentQ<Problem>();
constexpr index_t kKPackT = GetSmemKPackQT<Problem>(); constexpr index_t kKPackT = GetSmemKPackQT<Problem>();
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>(); return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
...@@ -1218,9 +1218,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1218,9 +1218,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t KPack = GetSmemKPackOGrad<Problem>(); constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, KPack, false>(); return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack, false>();
} }
template <typename Problem> template <typename Problem>
...@@ -1284,7 +1284,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1284,7 +1284,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>(); constexpr index_t kKPack = GetAlignmentOGrad<Problem>();
constexpr index_t kKPackT = GetSmemKPackOGradT<Problem>(); constexpr index_t kKPackT = GetSmemKPackOGradT<Problem>();
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>(); return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
...@@ -1377,7 +1377,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1377,7 +1377,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPack = GetSmemKPackSGrad<Problem>(); constexpr index_t kKPack = GetSmemKPackSGrad<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>(); return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack, false>();
} }
template <typename Problem> template <typename Problem>
...@@ -1924,20 +1924,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1924,20 +1924,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>(); kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
// 16 * 32 / 64 / 8 = 1 // 16 * 32 / 64 / 8 = 1
static constexpr index_t SGradT_LDS_READ_P1 = static constexpr index_t SGradT_LDS_READ_P1 =
// kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>(); kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / 2;
// 16 * 128 / 64 / 8 = 4 // 16 * 128 / 64 / 8 = 4
static constexpr index_t Q_LDS_READ = static constexpr index_t Q_LDS_READ =
kM0 * kK0 / (get_warp_size() * Gemm0MWarp) / GetAlignmentQ<Problem>(); kM0 * kK0 / (get_warp_size() * Gemm0MWarp) / GetSmemKPackQ<Problem>();
// 1 // 1
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
// 16 * 96 / 64 / 8 = 3 // 16 * 96 / 64 / 8 = 3
static constexpr index_t SGradT_LDS_READ_P2 = static constexpr index_t SGradT_LDS_READ_P2 =
// kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>(); kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / 2;
// 16 * 128 / 64 / 8 = 4 // 16 * 128 / 64 / 8 = 4
static constexpr index_t OGrad_LDS_READ = static constexpr index_t OGrad_LDS_READ =
kM0 * kK2 / (get_warp_size() * Gemm2MWarp) / GetAlignmentOGrad<Problem>(); kM0 * kK2 / (get_warp_size() * Gemm2MWarp) / GetSmemKPackOGrad<Problem>();
// 1 // 1
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment