Commit 774c9209 authored by letaoqin's avatar letaoqin
Browse files

change make c grid desc to c0 grid desc, because c is mxo, c0 is mxn

parent 522d8b2f
......@@ -591,7 +591,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
......@@ -647,7 +647,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
}
static auto MakeDGridDescriptor_M(index_t MRaw)
......@@ -677,7 +677,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
......@@ -685,7 +685,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {}));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
......@@ -936,7 +936,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
d_block_2_ctile_map_{
GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(d_y_grid_desc_m_o_)},
......@@ -975,7 +975,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
d0_grid_desc_g_m_n_ = Transform::MakeC0GridDescriptor_G_M_N(
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
......
......@@ -599,14 +599,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
}
// Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
......@@ -686,7 +686,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
......@@ -694,7 +694,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using DYGridDesc_M_O = decltype(DTransform::MakeCGridDescriptor_M_N({}, {}));
using DGridDesc_M = decltype(MakeDGridDescriptor_M(1));
......@@ -951,7 +951,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
d_block_2_ctile_map_{
GridwiseYDotYGrad::MakeDefaultBlock2CTileMap(d_y_grid_desc_m_o_)},
......@@ -990,7 +990,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
d0_grid_desc_g_m_n_ = Transform::MakeC0GridDescriptor_G_M_N(
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
......
......@@ -526,7 +526,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
......@@ -582,12 +582,12 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
......@@ -595,7 +595,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
......@@ -829,7 +829,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
y_grid_desc_mblock_mperblock_oblock_operblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
a_element_op_{a_element_op},
......@@ -875,7 +875,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
d0_grid_desc_g_m_n_ = Transform::MakeC0GridDescriptor_G_M_N(
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
......
......@@ -534,14 +534,14 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths,
const std::vector<index_t>& d_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
}
// Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
//
// dS_i_j = P_i_j .* (dP_i_j - dY_i dot Y_i)
......@@ -596,7 +596,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using D0GridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
......@@ -604,7 +604,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
......@@ -844,7 +844,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)},
z_grid_desc_g_m_n_{
Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
Transform::MakeC0GridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
y_grid_desc_mblock_mperblock_oblock_operblock_{},
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
a_element_op_{a_element_op},
......@@ -891,7 +891,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
d0_grid_desc_g_m_n_ = Transform::MakeC0GridDescriptor_G_M_N(
acc0_bias_gs_ms_ns_lengths, acc0_bias_gs_ms_ns_strides);
d0_n_length_stride_.push_back(acc0_bias_gs_ms_ns_lengths[NumDimG + NumDimM]);
......@@ -941,6 +941,10 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::cout << "ygrad_grid_desc_m0_o_m1_: " << ygrad_grid_desc_m0_o_m1_.GetLength(I0)
<< ", " << ygrad_grid_desc_m0_o_m1_.GetLength(I1) << ", "
<< ygrad_grid_desc_m0_o_m1_.GetLength(I2) << '\n';
std::cout << "d0_grid_desc_g_m_n_: " << d0_grid_desc_g_m_n_.GetLength(I0) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I1) << ", "
<< d0_grid_desc_g_m_n_.GetLength(I2) << '\n';
}
// pointers
......
......@@ -545,7 +545,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
static auto MakeLSEGridDescriptor_M(index_t MRaw)
......@@ -577,7 +577,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -586,7 +586,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -998,7 +998,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
const auto z_grid_desc_g_m_n = Transform::MakeC0GridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
......
......@@ -608,7 +608,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
static auto MakeLSEGridDescriptor_M(index_t MRaw)
......@@ -640,7 +640,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -649,7 +649,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -688,7 +688,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
......@@ -1069,7 +1069,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
const auto z_grid_desc_g_m_n = Transform::MakeC0GridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
......
......@@ -477,7 +477,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
static auto MakeLSEGridDescriptor_M(index_t MRaw)
......@@ -509,7 +509,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -518,7 +518,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -532,7 +532,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
......@@ -878,7 +878,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
const auto z_grid_desc_g_m_n = Transform::MakeC0GridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
......
......@@ -539,7 +539,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
return Transform::MakeC0GridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
}
static auto MakeLSEGridDescriptor_M(index_t MRaw)
......@@ -571,7 +571,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
return Transform::MakeC0GridDescriptor_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -580,7 +580,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const std::vector<ck::index_t>& acc0_bias_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
return Transform::MakeC0GridDescriptor_G_M_N(acc0_bias_gs_ms_ns_lengths,
acc0_bias_gs_ms_ns_strides);
}
......@@ -594,7 +594,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeC0GridDescriptor_G_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
......@@ -948,7 +948,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
const auto z_grid_desc_g_m_n = Transform::MakeC0GridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
problem_desc.b1_gs_gemm1ns_gemm1ks_lengths,
......
......@@ -119,6 +119,15 @@ struct GemmGemmPadder
c_desc_mraw_nraw, make_tuple(MPerTile_, OPerTile_), Sequence<PadM, PadO>{});
}
// C[M, Gemm1N] = C[M, N]
template <typename C0Desc_MRaw_NRaw>
__host__ __device__ constexpr auto
PadC0Descriptor_M_N(const C0Desc_MRaw_NRaw& c_desc_mraw_nraw) const
{
return PadTensorDescriptor(
c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
}
MPerTileType MPerTile_;
NPerTileType NPerTile_;
KPerTileType KPerTile_;
......
......@@ -282,6 +282,29 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
return matrix_padder.PadCDescriptor_M_N(
MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
}
//
// C0
//
static auto MakeC0GridDescriptorPair(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimN, CSpec>(c_gs_ms_ns_lengths_vec,
c_gs_ms_ns_strides_vec);
}
// TODO: rename to G_MRaw_NRaw
static auto MakeC0GridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
{
return MakeC0GridDescriptorPair(c_gs_ms_ns_lengths_vec, c_gs_ms_ns_strides_vec).first;
}
static auto MakeC0GridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
{
return matrix_padder.PadC0Descriptor_M_N(
MakeC0GridDescriptorPair(c_gs_ms_ns_lengths_vec, c_gs_ms_ns_strides_vec).second);
}
};
} // namespace tensor_operation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment