Commit af695bee authored by danyao12

Merge branch 'mha-train-develop-bwdopt-bias' into mha-train-develop-dropout8bit

parents 8ced5c4f 9e527364
@@ -265,26 +265,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
         const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
         const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
         const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
-        const auto Gemm1N = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
+        const auto O = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
         // This assumption reduces implementation complexity by categorizing 6 separate GEMMs into 3
         // types of GEMM operations, therefore some code body can be reused accordingly
         // P_MNK / dP_MNO Gemm (Gemm0 rcr)
         // Y_MON / dQ_MKN Gemm (Gemm1 rrr)
         // dV_NOM / dK_NKM Gemm (Gemm2 crr)
-        if(Gemm1N != K)
+        if(O != K)
         {
             std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
             return false;
         }
-        if(!(M == y_grid_desc_m_o.GetLength(I0) && Gemm1N == y_grid_desc_m_o.GetLength(I1)))
+        if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
         {
             return false;
         }
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
-             Gemm1N % Gemm1NPerBlock == 0))
+             O % Gemm1NPerBlock == 0))
         {
             return false;
         }
...
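Taken together, these checks express the shape constraints of the attention backward problem as read from the grid descriptors: Q is M x K, K is N x K, V is N x O, and Y is M x O, with the additional requirement K == O (equal attention head size) and block-tile divisibility. Below is a minimal standalone C++ sketch of the same constraints on plain integers; the tile sizes and problem sizes are hypothetical illustration values, not the template arguments used by these gridwise kernels.

#include <iostream>

// Hypothetical block-tile sizes, chosen only for illustration.
constexpr int MPerBlock      = 128;
constexpr int NPerBlock      = 128;
constexpr int KPerBlock      = 32;
constexpr int Gemm1NPerBlock = 64;

// Mirrors the validity checks above: M, N are sequence lengths,
// K is the Q/K head size, O is the V/Y head size.
bool check_attention_sizes(int M, int N, int K, int O)
{
    if(O != K)
    {
        std::cerr << "SizeK must be equal to SizeO (equal attention head size)\n";
        return false;
    }
    return M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
           O % Gemm1NPerBlock == 0;
}

int main()
{
    std::cout << check_attention_sizes(512, 512, 64, 64) << '\n';  // 1: valid
    std::cout << check_attention_sizes(512, 512, 64, 128) << '\n'; // 0: K != O
}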
@@ -115,7 +115,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
     // Gemm0
     static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
     static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
-    static constexpr auto K_K0 = Number<Gemm1NPerBlock / BK1Value>{};
     static constexpr auto AK1 = Number<AK1Value>{};
     static constexpr auto BK1 = Number<BK1Value>{};
@@ -127,6 +126,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
     static constexpr auto B1K1 = Number<B1K1Value>{};
     static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
+    static constexpr auto K_K0 = Number<Gemm1NPerBlock / BK1Value>{};
     static constexpr auto V_K3 = BK1;
     static constexpr auto V_K2 = mfma.num_input_blks;
     static constexpr auto V_K1 = KPerBlock / V_K2 / V_K3;
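The two hunks above only relocate the K_K0 constant from the Gemm0 block down next to the V_K* constants; its definition is unchanged. As a rough worked example of how these constants relate, assuming purely hypothetical values KPerBlock = 64, Gemm1NPerBlock = 128, BK1Value = 8, and mfma.num_input_blks = 2:

// All values below are hypothetical, chosen only to illustrate the arithmetic.
constexpr int KPerBlock      = 64;
constexpr int Gemm1NPerBlock = 128;
constexpr int BK1Value       = 8;
constexpr int num_input_blks = 2; // stand-in for mfma.num_input_blks

constexpr int V_K3 = BK1Value;                  // 8
constexpr int V_K2 = num_input_blks;            // 2
constexpr int V_K1 = KPerBlock / V_K2 / V_K3;   // 64 / 2 / 8 = 4
constexpr int K_K0 = Gemm1NPerBlock / BK1Value; // 128 / 8 = 16

// Under these assumptions the V_K* factors tile KPerBlock exactly.
static_assert(V_K1 * V_K2 * V_K3 == KPerBlock, "V_K decomposition must cover KPerBlock");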
@@ -310,26 +310,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
         const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
         const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
         const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
-        const auto Gemm1N = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
+        const auto O = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
         // This assumption reduces implementation complexity by categorizing 6 separate GEMMs into 3
         // types of GEMM operations, therefore some code body can be reused accordingly
         // P_MNK / dP_MNO Gemm (Gemm0 rcr)
         // Y_MON / dQ_MKN Gemm (Gemm1 rrr)
         // dV_NOM / dK_NKM Gemm (Gemm2 crr)
-        if(Gemm1N != K)
+        if(O != K)
         {
             std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
             return false;
         }
-        if(!(M == y_grid_desc_m_o.GetLength(I0) && Gemm1N == y_grid_desc_m_o.GetLength(I1)))
+        if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
         {
             return false;
         }
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
-             Gemm1N % Gemm1NPerBlock == 0))
+             O % Gemm1NPerBlock == 0))
         {
             return false;
         }
...
@@ -264,26 +264,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
         const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
         const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
         const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
-        const auto Gemm1N = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
+        const auto O = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
         // This assumption reduces implementation complexity by categorizing 6 separate GEMMs into 3
         // types of GEMM operations, therefore some code body can be reused accordingly
         // P_MNK / dP_MNO Gemm (Gemm0 rcr)
         // Y_MON / dQ_MKN Gemm (Gemm1 rrr)
         // dV_NOM / dK_NKM Gemm (Gemm2 crr)
-        if(Gemm1N != K)
+        if(O != K)
         {
             std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
             return false;
         }
-        if(!(M == y_grid_desc_m_o.GetLength(I0) && Gemm1N == y_grid_desc_m_o.GetLength(I1)))
+        if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
         {
             return false;
         }
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
-             Gemm1N % Gemm1NPerBlock == 0))
+             O % Gemm1NPerBlock == 0))
         {
             return false;
         }
...
@@ -114,7 +114,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     // Gemm0
     static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
-    static constexpr auto K_K0 = Number<Gemm1NPerBlock / BK1Value>{};
     static constexpr auto AK1 = Number<AK1Value>{};
     static constexpr auto BK1 = Number<BK1Value>{};
@@ -126,6 +125,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
     static constexpr auto B1K1 = Number<B1K1Value>{};
     static constexpr auto mfma = MfmaSelector<GemmDataType, MPerXdl, NPerXdl>::selected_mfma;
+    static constexpr auto K_K0 = Number<Gemm1NPerBlock / BK1Value>{};
     static constexpr auto V_K3 = BK1;
     static constexpr auto V_K2 = mfma.num_input_blks;
     static constexpr auto V_K1 = KPerBlock / V_K2 / V_K3;
@@ -309,26 +309,26 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         const auto M = q_grid_desc_k0_m_k1.GetLength(I1);
         const auto N = k_grid_desc_k0_n_k1.GetLength(I1);
         const auto K = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
-        const auto Gemm1N = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
+        const auto O = v_grid_desc_o0_n_o1.GetLength(I0) * v_grid_desc_o0_n_o1.GetLength(I2);
         // This assumption reduces implementation complexity by categorizing 6 separate GEMMs into 3
         // types of GEMM operations, therefore some code body can be reused accordingly
         // P_MNK / dP_MNO Gemm (Gemm0 rcr)
         // Y_MON / dQ_MKN Gemm (Gemm1 rrr)
         // dV_NOM / dK_NKM Gemm (Gemm2 crr)
-        if(Gemm1N != K)
+        if(O != K)
         {
             std::cerr << "SizeK must be equal to SizeO (equal attention head size)" << '\n';
             return false;
         }
-        if(!(M == y_grid_desc_m_o.GetLength(I0) && Gemm1N == y_grid_desc_m_o.GetLength(I1)))
+        if(!(M == y_grid_desc_m_o.GetLength(I0) && O == y_grid_desc_m_o.GetLength(I1)))
         {
             return false;
         }
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
-             Gemm1N % Gemm1NPerBlock == 0))
+             O % Gemm1NPerBlock == 0))
         {
             return false;
         }
...