Commit a3c14b5f authored by guangzlu
Browse files

removed global shuffle apis

parent 9e16e38e
......@@ -39,7 +39,6 @@ template <typename GridwiseGemm,
typename B1GridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5,
typename LSEGridDescriptor_M,
typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch,
......@@ -71,8 +70,6 @@ __global__ void
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
const LSEGridDescriptor_M lse_grid_desc_m,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
......@@ -130,7 +127,6 @@ __global__ void
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
lse_grid_desc_m,
block_2_ctile_map,
c0_matrix_mask,
......@@ -163,7 +159,6 @@ __global__ void
b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
lse_grid_desc_m,
block_2_ctile_map,
c0_matrix_mask,
......@@ -653,10 +648,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_ =
GridwiseGemm::MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5(
z_grid_desc_m_n_);
if(p_lse_grid == nullptr)
{
is_lse_storing_ = false;
......@@ -706,9 +697,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
typename GridwiseGemm::ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_;
// block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
......@@ -765,9 +753,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_,
auto is_dropout_,
auto is_lse_storing_) {
auto launch_kernel =
[&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distinguish A/B datatype
......@@ -785,7 +772,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
DeviceOp::B1GridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename GridwiseGemm::ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5,
DeviceOp::LSEGridDesc_M,
typename GridwiseGemm::DefaultBlock2CTileMap,
ComputeBasePtrOfStridedBatch,
......@@ -817,7 +803,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
arg.b1_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_,
arg.lse_grid_desc_m_,
arg.block_2_ctile_map_,
arg.batch_count_,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment