Commit 51e102e5 authored by guangzlu
Browse files

modified arg names in bwd qloop

parent 7d6a8ec7
......@@ -123,7 +123,7 @@ __global__ void
c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_,
......@@ -159,7 +159,7 @@ __global__ void
c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_,
......@@ -661,7 +661,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
LSEGridDesc_M lse_grid_desc_m_;
KGridDesc_N_K k_grid_desc_n_k_;
YGradGridDesc_O0_M_O1 ygrad_grid_desc_o0_m_o1_;
......@@ -800,7 +800,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(k_grid_desc_n_k, BlockStart);
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
......@@ -813,7 +813,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
y_grid_desc_m_o);
}
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
z_grid_desc_m_n);
......@@ -869,7 +869,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
b1_grid_desc_bk0_n_bk1,
y_grid_desc_m_o,
y_grid_desc_mblock_mperblock_oblock_operblock,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
lse_grid_desc_m,
k_grid_desc_n_k,
ygrad_grid_desc_o0_m_o1,
......
......@@ -123,7 +123,7 @@ __global__ void
c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_,
......@@ -159,7 +159,7 @@ __global__ void
c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
arg_ptr[group_id].lse_grid_desc_m_,
......@@ -669,7 +669,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock_;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;
LSEGridDesc_M lse_grid_desc_m_;
KGridDesc_N_K k_grid_desc_n_k_;
YGradGridDesc_M0_O_M1 ygrad_grid_desc_m0_o_m1_;
......@@ -808,7 +808,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
typename GridwiseGemm::YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock
y_grid_desc_mblock_mperblock_oblock_operblock;
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
const index_t BlockStart = grid_size_;
const auto block_2_ctile_map = Block2CTileMap(k_grid_desc_n_k, BlockStart);
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
......@@ -821,7 +821,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
y_grid_desc_m_o);
}
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
z_grid_desc_m_n);
......@@ -877,7 +877,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
b1_grid_desc_bk0_n_bk1,
y_grid_desc_m_o,
y_grid_desc_mblock_mperblock_oblock_operblock,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
lse_grid_desc_m,
k_grid_desc_n_k,
ygrad_grid_desc_m0_o_m1,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment