Commit af695bee authored by danyao12's avatar danyao12
Browse files

Merge branch 'mha-train-develop-bwdopt-bias' into mha-train-develop-dropout8bit

parents 8ced5c4f 9e527364
...@@ -262,29 +262,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1 ...@@ -262,29 +262,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0, (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!"); "Invalid tuning param!");
const auto M = q_grid_desc_k0_m_k1.GetLength(I1); const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1); const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2); const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto Gemm1N = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2); const auto O = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
// This assumption reduces implementation complexity by categorizing 6 separate GEMMs into 3 // This assumption reduces implementation complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly // types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr) // P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr) // Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr) // dV_NOM / dK_NKM Gemm (Gemm2 crr)
if(Gemm1N != K) if(O != K)
{ {
std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n'; std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false; return false;
} }
if(!(M == y_grid_desc_m_o.GetLength(I0) && Gemm1N == y_grid_desc_m_o.GetLength(I1))) if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
{ {
return false; return false;
} }
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 && if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0)) O % Gemm1NPerBlock == 0))
{ {
return false; return false;
} }
......
...@@ -113,11 +113,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -113,11 +113,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr auto WaveSize = 64; static constexpr auto WaveSize = 64;
// K1 should be Number<...> // K1 should be Number<...>
// Gemm0 // Gemm0
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{}; static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{}; static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto K_K0 = Number<Gemm1NPerBlock / BK1Value>{}; static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{}; static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave); static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave); static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
...@@ -127,6 +126,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -127,6 +126,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static constexpr auto B1K1 = Number<B1K1Value>{}; static constexpr auto B1K1 = Number<B1K1Value>{};
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma; static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto K_K0 = Number<Gemm1NPerBlock / BK1Value>{};
static constexpr auto V_K3 = BK1; static constexpr auto V_K3 = BK1;
static constexpr auto V_K2 = mfma.num_input_blks; static constexpr auto V_K2 = mfma.num_input_blks;
static constexpr auto V_K1 = KPerBlock / V_K2 / V_K3; static constexpr auto V_K1 = KPerBlock / V_K2 / V_K3;
...@@ -307,29 +307,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2 ...@@ -307,29 +307,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0, (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!"); "Invalid tuning param!");
const auto M = q_grid_desc_k0_m_k1.GetLength(I1); const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1); const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2); const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto Gemm1N = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2); const auto O = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
// This assumption reduces implementation complexity by categorizing 6 separate GEMMs into 3 // This assumption reduces implementation complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly // types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr) // P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr) // Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr) // dV_NOM / dK_NKM Gemm (Gemm2 crr)
if(Gemm1N != K) if(O != K)
{ {
std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n'; std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false; return false;
} }
if(!(M == y_grid_desc_m_o.GetLength(I0) && Gemm1N == y_grid_desc_m_o.GetLength(I1))) if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
{ {
return false; return false;
} }
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 && if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0)) O % Gemm1NPerBlock == 0))
{ {
return false; return false;
} }
......
...@@ -261,29 +261,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -261,29 +261,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0, (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!"); "Invalid tuning param!");
const auto M = q_grid_desc_k0_m_k1.GetLength(I1); const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1); const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2); const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto Gemm1N = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2); const auto O = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
// This assumption reduces implementation complexity by categorizing 6 separate GEMMs into 3 // This assumption reduces implementation complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly // types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr) // P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr) // Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr) // dV_NOM / dK_NKM Gemm (Gemm2 crr)
if(Gemm1N != K) if(O != K)
{ {
std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n'; std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false; return false;
} }
if(!(M == y_grid_desc_m_o.GetLength(I0) && Gemm1N == y_grid_desc_m_o.GetLength(I1))) if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
{ {
return false; return false;
} }
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 && if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0)) O % Gemm1NPerBlock == 0))
{ {
return false; return false;
} }
......
...@@ -112,11 +112,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -112,11 +112,10 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr auto WaveSize = 64; static constexpr auto WaveSize = 64;
// K1 should be Number<...> // K1 should be Number<...>
// Gemm0 // Gemm0
static constexpr auto AK0 = Number<KPerBlock / AK1Value>{}; static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0 = Number<KPerBlock / BK1Value>{}; static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
static constexpr auto K_K0 = Number<Gemm1NPerBlock / BK1Value>{}; static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto AK1 = Number<AK1Value>{}; static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave); static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave); static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
...@@ -126,6 +125,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -126,6 +125,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr auto B1K1 = Number<B1K1Value>{}; static constexpr auto B1K1 = Number<B1K1Value>{};
static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma; static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
static constexpr auto K_K0 = Number<Gemm1NPerBlock / BK1Value>{};
static constexpr auto V_K3 = BK1; static constexpr auto V_K3 = BK1;
static constexpr auto V_K2 = mfma.num_input_blks; static constexpr auto V_K2 = mfma.num_input_blks;
static constexpr auto V_K1 = KPerBlock / V_K2 / V_K3; static constexpr auto V_K1 = KPerBlock / V_K2 / V_K3;
...@@ -306,29 +306,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -306,29 +306,29 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0, (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!"); "Invalid tuning param!");
const auto M = q_grid_desc_k0_m_k1.GetLength(I1); const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
const auto N = k_grid_desc_k0_n_k1.GetLength(I1); const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2); const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
const auto Gemm1N = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2); const auto O = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
// This assumption reduces implementation complexity by categorizing 6 separate GEMMs into 3 // This assumption reduces implementation complexity by categorizing 6 separate GEMMs into 3
// types of GEMM operations, therefore some code body can be reused accordingly // types of GEMM operations, therefore some code body can be reused accordingly
// P_MNK / dP_MNO Gemm (Gemm0 rcr) // P_MNK / dP_MNO Gemm (Gemm0 rcr)
// Y_MON / dQ_MKN Gemm (Gemm1 rrr) // Y_MON / dQ_MKN Gemm (Gemm1 rrr)
// dV_NOM / dK_NKM Gemm (Gemm2 crr) // dV_NOM / dK_NKM Gemm (Gemm2 crr)
if(Gemm1N != K) if(O != K)
{ {
std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n'; std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
return false; return false;
} }
if(!(M == y_grid_desc_m_o.GetLength(I0) && Gemm1N == y_grid_desc_m_o.GetLength(I1))) if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
{ {
return false; return false;
} }
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 && if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
Gemm1N % Gemm1NPerBlock == 0)) O % Gemm1NPerBlock == 0))
{ {
return false; return false;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment