Commit 13129772 authored by letaoqin
Browse files

update to grouped gemm

parent 5cc0fd88
...@@ -120,7 +120,7 @@ using DeviceGemmInstance = ...@@ -120,7 +120,7 @@ using DeviceGemmInstance =
8, 8,
8, 8,
true, true,
1, 8,
S<16, 16, 1>, // B1BlockTransfer S<16, 16, 1>, // B1BlockTransfer
S<0, 2, 1>, S<0, 2, 1>,
S<0, 2, 1>, S<0, 2, 1>,
......
...@@ -202,7 +202,7 @@ template <index_t NumDimG, ...@@ -202,7 +202,7 @@ template <index_t NumDimG,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
: public DeviceBatchedMultiheadAttentionInfer<NumDimG, : public DeviceBatchedMultiheadAttentionInfer<NumDimG,
NumDimM, NumDimM,
......
...@@ -109,7 +109,7 @@ __global__ void ...@@ -109,7 +109,7 @@ __global__ void
c_element_op, c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_, arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].block_2_ctile_map_, arg_ptr[group_id].block_2_ctile_map_,
...@@ -461,8 +461,7 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle ...@@ -461,8 +461,7 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 typename GridwiseGemm::D0GridDescriptor_M0_N0_N1_N2_M1_N3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -567,9 +566,8 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle ...@@ -567,9 +566,8 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
const D0GridDesc_M_N d0_grid_desc_m_n{DeviceOp::MakeD0GridDescriptor_M_N( const D0GridDesc_M_N d0_grid_desc_m_n{DeviceOp::MakeD0GridDescriptor_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides)}; tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides)};
const auto d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = const auto d0_grid_desc_m0_n0_m1_m2_n1_m3_ =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5( GridwiseGemm::MakeD0GridDescriptor_M0_N0_N1_N2_M1_N3(d0_grid_desc_m_n);
d0_grid_desc_m_n);
const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1( const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1(
problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides); problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides);
...@@ -619,7 +617,7 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle ...@@ -619,7 +617,7 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
p_c_grid, p_c_grid,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, d0_grid_desc_m0_n0_m1_m2_n1_m3_,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n), block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment