Unverified Commit 21ef37b4 authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

Merge pull request #889 from ROCmSoftwarePlatform/mha-train-develop-bwdopt-bias

Mha train develop bwdopt bias
parents 1f04cd2b db579ac9
......@@ -880,6 +880,18 @@ struct BlockwiseGemmXdlops_v2
b_thread_copy_.SetSrcCoord(b_origin);
}
template <typename SrcSliceMoveStepIdx>
__device__ void MoveABlockSrcSliceWindow(const SrcSliceMoveStepIdx& src_slice_move_step_idx)
{
a_thread_copy_.MoveSrcSliceWindow(a_block_desc_m0_m1_m2_k, src_slice_move_step_idx);
}
template <typename SrcSliceMoveStepIdx>
__device__ void MoveBBlockSrcSliceWindow(const SrcSliceMoveStepIdx& src_slice_move_step_idx)
{
b_thread_copy_.MoveSrcSliceWindow(b_block_desc_n0_n1_n2_k, src_slice_move_step_idx);
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
......
......@@ -273,6 +273,7 @@ template <index_t NumDimG,
index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1,
index_t BK1,
index_t B1K1,
......@@ -322,9 +323,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t V_O1 = 8;
static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2;
static constexpr index_t V_O1 = BK1;
static constexpr index_t Y_O1 = AK1;
static constexpr index_t Y_M1 = B1K1;
static constexpr auto padder = GemmGemmPadder<GemmSpec,
Number<MPerBlock>,
......@@ -565,8 +566,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return lse_grid_desc_mraw;
}
}
// D in Gemm0 C position
static auto MakeDGridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
// D0 in Gemm0 C position
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
......@@ -584,7 +585,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using D0GridDesc_M_N = decltype(MakeDGridDescriptor_M_N({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
......@@ -703,6 +704,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
Gemm2KPerBlock,
AK1,
BK1,
B1K1,
......@@ -855,8 +857,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
}
if constexpr(!is_same<D0DataType, void>::value)
{
const auto d0_grid_desc_m_n =
MakeDGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
const auto d0_grid_desc_m_n = MakeD0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
......@@ -1361,6 +1363,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<< MPerBlock << ", "
<< Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", "
<< Gemm2KPerBlock << ", "
<< B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
......
......@@ -274,6 +274,7 @@ template <index_t NumDimG,
index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1,
index_t BK1,
index_t B1K1,
......@@ -330,9 +331,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t V_O1 = 8;
static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2;
static constexpr index_t V_O1 = BK1;
static constexpr index_t Y_O1 = AK1;
static constexpr index_t Y_M1 = B1K1;
static constexpr auto padder = GemmGemmPadder<GemmSpec,
Number<MPerBlock>,
......@@ -517,8 +518,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
}
// D in Gemm0 C position
static auto MakeDGridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
// D0 in Gemm0 C position
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
......@@ -584,7 +585,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
......@@ -593,7 +594,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using D0GridDesc_M_N = decltype(MakeDGridDescriptor_M_N({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
......@@ -711,6 +712,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
Gemm2KPerBlock,
AK1,
BK1,
B1K1,
......@@ -809,7 +811,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeVGridDescriptor_O0_N_O1(
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
y_grid_desc_m_o_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
......@@ -868,8 +870,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
if constexpr(!is_same<D0DataType, void>::value)
{
const auto d0_grid_desc_m_n =
MakeDGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
const auto d0_grid_desc_m_n = MakeD0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
......@@ -1148,7 +1150,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const index_t c_m = arg.y_grid_desc_m_o_.GetLength(I0);
const index_t c_gemm1n = arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
const index_t b1_gemm1n =
arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
{
......@@ -1393,6 +1396,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<< MPerBlock << ", "
<< Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", "
<< Gemm2KPerBlock << ", "
<< B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
......
......@@ -236,6 +236,7 @@ template <index_t NumDimG,
index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1,
index_t BK1,
index_t B1K1,
......@@ -310,9 +311,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t V_O1 = 8;
static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2;
static constexpr index_t V_O1 = BK1;
static constexpr index_t Y_O1 = AK1;
static constexpr index_t Y_M1 = B1K1;
static constexpr auto padder = GemmGemmPadder<GemmSpec,
Number<MPerBlock>,
......@@ -497,7 +498,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return lse_grid_desc_mraw;
}
}
// D in Gemm0 C position
// D0 in Gemm0 C position
static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
......@@ -644,6 +645,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
Gemm2KPerBlock,
AK1,
BK1,
B1K1,
......@@ -1314,6 +1316,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
<< MPerBlock << ", "
<< Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", "
<< Gemm2KPerBlock << ", "
<< B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
......
......@@ -236,6 +236,7 @@ template <index_t NumDimG,
index_t KPerBlock, // Gemm0KPerBlock
index_t Gemm1NPerBlock,
index_t Gemm1KPerBlock,
index_t Gemm2KPerBlock,
index_t AK1,
index_t BK1,
index_t B1K1,
......@@ -317,9 +318,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr index_t V_O1 = 8;
static constexpr index_t Y_O1 = 8;
static constexpr index_t Y_M1 = 2;
static constexpr index_t V_O1 = BK1;
static constexpr index_t Y_O1 = AK1;
static constexpr index_t Y_M1 = B1K1;
static constexpr auto padder = GemmGemmPadder<GemmSpec,
Number<MPerBlock>,
......@@ -441,6 +442,69 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
//
// dP = dY * V^T
//
// YGrad in Gemm A position
static auto MakeYGradGridDescriptor_O0_M_O1(const std::vector<index_t>& y_gs_ms_os_lengths,
const std::vector<index_t>& y_gs_ms_os_strides)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(y_gs_ms_os_lengths, y_gs_ms_os_strides),
Number<Y_O1>{});
}
// V in Gemm B position
static auto MakeVGridDescriptor_O0_N_O1(const std::vector<index_t>& v_gs_os_ns_lengths,
const std::vector<index_t>& v_gs_os_ns_strides)
{
// v_gs_os_ns -> vgrad_gs_ns_os. O dims last because output is row-major.
// Here directly rearrange lengths/strides before constructing tensor descriptor to reduce
// transformation overhead
// TODO: This will be much easier when inputs are Gs, Ms, Ns, Os. So there's no need to
// extract subsequence and shuffle them.
const index_t num_dims = NumDimG + NumDimN + NumDimO;
// 0, 1, .. NumDimG - 1
std::vector<index_t> gs_ids(NumDimG);
std::iota(gs_ids.begin(), gs_ids.end(), 0);
// NumDimG, NumDimG + 1, ... NumDimG + NumDimO - 1
std::vector<index_t> os_ids(NumDimO);
std::iota(os_ids.begin(), os_ids.end(), NumDimG);
// NumDimG + NumDimO, NumDimG + NumDimO + 1, ... NumDimG + NumDimO + NumDimN - 1
std::vector<index_t> ns_ids(NumDimN);
std::iota(ns_ids.begin(), ns_ids.end(), NumDimG + NumDimO);
std::vector<index_t> ids_old2new;
ids_old2new.insert(ids_old2new.end(), gs_ids.begin(), gs_ids.end());
ids_old2new.insert(ids_old2new.end(), ns_ids.begin(), ns_ids.end());
ids_old2new.insert(ids_old2new.end(), os_ids.begin(), os_ids.end());
std::vector<index_t> v_gs_ns_os_lengths(num_dims), v_gs_ns_os_strides(num_dims);
for(int i = 0; i < num_dims; i++)
{
index_t id_new = ids_old2new[i];
v_gs_ns_os_lengths[i] = v_gs_os_ns_lengths[id_new];
v_gs_ns_os_strides[i] = v_gs_os_ns_strides[id_new];
}
const auto v_grid_desc_nraw_oraw =
MakeGridDescriptorPair<NumDimG, NumDimN, NumDimO, TensorSpecialization::Default>(
v_gs_ns_os_lengths, v_gs_ns_os_strides)
.second;
const auto v_grid_desc_n_o = PadTensorDescriptor(v_grid_desc_nraw_oraw,
make_tuple(NPerBlock, Gemm1NPerBlock),
Sequence<padder.PadN, padder.PadO>{});
// N_O to O0_N_O1; to refactor
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
//
......@@ -497,7 +561,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return lse_grid_desc_mraw;
}
}
// D in Gemm0 C position
// D0 in Gemm0 C position
static auto MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
......@@ -517,7 +581,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
......@@ -644,6 +708,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
Gemm2KPerBlock,
AK1,
BK1,
B1K1,
......@@ -849,7 +914,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
const auto z_grid_desc_m_n = DeviceOp::MakeZGridDescriptor_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
problem_desc.b1_gs_gemm1ns_gemm1ks_strides);
const auto y_grid_desc_m_o = Transform::MakeCGridDescriptor_M_N(
......@@ -1129,7 +1194,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
const index_t c_gemm1n = kernel_arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) *
kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
{
......@@ -1326,6 +1392,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
<< MPerBlock << ", "
<< Gemm1NPerBlock << ", "
<< Gemm1KPerBlock << ", "
<< Gemm2KPerBlock << ", "
<< B1K1 << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< "ASpec" << getTensorSpecializationString(ASpec) << ", "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment