Commit 72a345c6 authored by letaoqin's avatar letaoqin
Browse files

add d desc

parent e0d6326b
......@@ -498,6 +498,13 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
return Transform::MakeB0GridDescriptor_BK0_N_BK1(v_grid_desc_n_o, Number<V_O1>{});
}
// Grid descriptor for D, which sits in the Gemm0 C position.
// The length and stride vectors are handed on unchanged to
// Transform::MakeCGridDescriptor_M_N, and its result is returned as-is.
static auto MakeDGridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths_vec,
                                    const std::vector<index_t>& d_gs_ms_ns_strides_vec)
{
    const auto& lengths = d_gs_ms_ns_lengths_vec;
    const auto& strides = d_gs_ms_ns_strides_vec;

    return Transform::MakeCGridDescriptor_M_N(lengths, strides);
}
// Z in Gemm0 C position
static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths_vec,
const std::vector<index_t>& z_gs_ms_ns_strides_vec)
......@@ -557,6 +564,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// Descriptor type aliases. Each alias is the return type of its maker function,
// obtained with decltype on a call that is never evaluated; the `{}` / `1`
// arguments only serve to select the overload.
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
// D0 with a leading G dimension: produced by Transform::MakeCGridDescriptor_G_M_N.
using D0GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {}));
// Y uses the generic C-position M x N maker from Transform.
using YGridDesc_M_O = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
// The LSE maker is invoked with a single integer argument instead of two lists.
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
......@@ -566,6 +574,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// C and Z share the same G_M_N maker, so these two aliases name the same type.
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
// M x N descriptor type for D0, taken from MakeDGridDescriptor_M_N
// (which forwards to Transform::MakeCGridDescriptor_M_N).
using D0GridDesc_M_N = decltype(MakeDGridDescriptor_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
// Derived from a value-initialized YGridDesc_M_O; only the resulting type is used here.
using YGradGridDesc_M0_O_M1 = decltype(MakeYGradGridDescriptor_M0_O_M1(YGridDesc_M_O{}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
......@@ -756,6 +765,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::tuple<unsigned long long, unsigned long long> seeds)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_d0_grid_{p_acc0_biases},
p_z_grid_{p_z_grid},
p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid},
......@@ -875,6 +885,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// pointers
// Raw pointers to the grid buffers, set in the constructor's initializer list.
// Of the pointers in this group, only p_z_grid_ has a non-const pointee.
const InputDataType* p_a_grid_;
const InputDataType* p_b_grid_;
// D0 buffer; the constructor initializes it from p_acc0_biases.
// NOTE(review): confirm that p_acc0_biases converts to const D0DataType* without a cast.
const D0DataType* p_d0_grid_;
ZDataType* p_z_grid_;
const InputDataType* p_b1_grid_;
const InputDataType* p_c_grid_;
......@@ -887,6 +898,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// tensor descriptor
// One descriptor object per tensor; each member's type is one of the decltype
// aliases declared in this struct.
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
// M x N descriptor for D0 (type comes from MakeDGridDescriptor_M_N).
D0GridDesc_M_N d0_grid_desc_m_n_;
ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_;
......@@ -897,6 +909,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
// batch offsets
// Descriptors that carry a leading G dimension in addition to the two matrix dimensions.
// NOTE(review): presumably G indexes the batch and these are used to compute
// per-batch base offsets — confirm against the code that consumes them.
AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_;
// D0 descriptor with G dimension (type from Transform::MakeCGridDescriptor_G_M_N).
D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment