Commit a3c14b5f authored by guangzlu
Browse files

removed global shuffle apis

parent 9e16e38e
...@@ -39,7 +39,6 @@ template <typename GridwiseGemm, ...@@ -39,7 +39,6 @@ template <typename GridwiseGemm,
typename B1GridDesc_BK0_N_BK1, typename B1GridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5,
typename LSEGridDescriptor_M, typename LSEGridDescriptor_M,
typename Block2CTileMap, typename Block2CTileMap,
typename ComputeBasePtrOfStridedBatch, typename ComputeBasePtrOfStridedBatch,
...@@ -71,8 +70,6 @@ __global__ void ...@@ -71,8 +70,6 @@ __global__ void
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
const LSEGridDescriptor_M lse_grid_desc_m, const LSEGridDescriptor_M lse_grid_desc_m,
const Block2CTileMap block_2_ctile_map, const Block2CTileMap block_2_ctile_map,
const index_t batch_count, const index_t batch_count,
...@@ -130,7 +127,6 @@ __global__ void ...@@ -130,7 +127,6 @@ __global__ void
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
lse_grid_desc_m, lse_grid_desc_m,
block_2_ctile_map, block_2_ctile_map,
c0_matrix_mask, c0_matrix_mask,
...@@ -163,7 +159,6 @@ __global__ void ...@@ -163,7 +159,6 @@ __global__ void
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5,
lse_grid_desc_m, lse_grid_desc_m,
block_2_ctile_map, block_2_ctile_map,
c0_matrix_mask, c0_matrix_mask,
...@@ -653,10 +648,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -653,10 +648,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ = z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(z_grid_desc_m_n_);
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_ =
GridwiseGemm::MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5(
z_grid_desc_m_n_);
if(p_lse_grid == nullptr) if(p_lse_grid == nullptr)
{ {
is_lse_storing_ = false; is_lse_storing_ = false;
...@@ -706,9 +697,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -706,9 +697,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
typename GridwiseGemm::ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_;
// block-to-c-tile map // block-to-c-tile map
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
...@@ -765,72 +753,69 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle ...@@ -765,72 +753,69 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_, auto launch_kernel =
auto is_dropout_, [&](auto has_main_k_block_loop_, auto is_dropout_, auto is_lse_storing_) {
auto is_lse_storing_) { const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle<
const auto kernel = kernel_batched_multiheadattention_forward_xdl_cshuffle< GridwiseGemm,
GridwiseGemm, ADataType, // TODO: distinguish A/B datatype
ADataType, // TODO: distinguish A/B datatype CDataType,
CDataType, ZDataType,
ZDataType, LSEDataType,
LSEDataType, GemmAccDataType,
GemmAccDataType, AElementwiseOperation,
AElementwiseOperation, BElementwiseOperation,
BElementwiseOperation, AccElementwiseOperation,
AccElementwiseOperation, B1ElementwiseOperation,
B1ElementwiseOperation, CElementwiseOperation,
CElementwiseOperation, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::B1GridDesc_BK0_N_BK1,
DeviceOp::B1GridDesc_BK0_N_BK1, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, DeviceOp::LSEGridDesc_M,
typename GridwiseGemm::ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5, typename GridwiseGemm::DefaultBlock2CTileMap,
DeviceOp::LSEGridDesc_M, ComputeBasePtrOfStridedBatch,
typename GridwiseGemm::DefaultBlock2CTileMap, C0MatrixMask,
ComputeBasePtrOfStridedBatch, has_main_k_block_loop_,
C0MatrixMask, is_dropout_,
has_main_k_block_loop_, is_lse_storing_,
is_dropout_, Deterministic>;
is_lse_storing_,
Deterministic>; return launch_and_time_kernel(
stream_config,
return launch_and_time_kernel( kernel,
stream_config, dim3(grid_size),
kernel, dim3(BlockSize),
dim3(grid_size), 0,
dim3(BlockSize), arg.p_a_grid_,
0, arg.p_b_grid_,
arg.p_a_grid_, arg.p_b1_grid_,
arg.p_b_grid_, arg.p_c_grid_,
arg.p_b1_grid_, arg.p_z_grid_,
arg.p_c_grid_, arg.p_lse_grid_,
arg.p_z_grid_, arg.a_element_op_,
arg.p_lse_grid_, arg.b_element_op_,
arg.a_element_op_, arg.acc_element_op_,
arg.b_element_op_, arg.b1_element_op_,
arg.acc_element_op_, arg.c_element_op_,
arg.b1_element_op_, arg.a_grid_desc_ak0_m_ak1_,
arg.c_element_op_, arg.b_grid_desc_bk0_n_bk1_,
arg.a_grid_desc_ak0_m_ak1_, arg.b1_grid_desc_bk0_n_bk1_,
arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.b1_grid_desc_bk0_n_bk1_, arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.lse_grid_desc_m_,
arg.z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg.block_2_ctile_map_,
arg.z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_, arg.batch_count_,
arg.lse_grid_desc_m_, arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_),
arg.block_2_ctile_map_, arg.compute_base_ptr_of_batch_,
arg.batch_count_, arg.c0_matrix_mask_,
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_), arg.p_dropout_in_16bits_,
arg.compute_base_ptr_of_batch_, arg.p_dropout_rescale_,
arg.c0_matrix_mask_, arg.seed_,
arg.p_dropout_in_16bits_, arg.offset_,
arg.p_dropout_rescale_, arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
arg.seed_, arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
arg.offset_, };
arg.raw_lengths_mz_nz_kz_gemm1nz_[0],
arg.raw_lengths_mz_nz_kz_gemm1nz_[1]);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop // to concern Gemm0's loop
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment