Commit 1311e039 authored by Chao Liu

refactor

parent 49b926b6
@@ -59,31 +59,31 @@ template <index_t BlockSize,
           typename AKMGridDesc,
           typename BKNGridDesc,
           typename CMNGridDesc,
-          index_t MPerBlock,
-          index_t NPerBlock,
+          index_t MPerBlockM1,
+          index_t NPerBlockN1,
           index_t KPerBlock,
-          index_t M1PerThread,
-          index_t N1PerThread,
+          index_t M1PerThreadM111,
+          index_t N1PerThreadN111,
           index_t KPerThread,
-          index_t M1N1ThreadClusterM100,
-          index_t M1N1ThreadClusterN100,
-          index_t M1N1ThreadClusterM101,
-          index_t M1N1ThreadClusterN101,
-          typename ABlockTransferThreadSliceLengths_K_M,
-          typename ABlockTransferThreadClusterLengths_K_M,
+          index_t M11N11ThreadClusterM1100,
+          index_t M11N11ThreadClusterN1100,
+          index_t M11N11ThreadClusterM1101,
+          index_t M11N11ThreadClusterN1101,
+          typename ABlockTransferThreadSliceLengths_K_M0_M1,
+          typename ABlockTransferThreadClusterLengths_K_M0_M1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
           index_t ABlockTransferSrcVectorDim,
           index_t ABlockTransferSrcScalarPerVector,
-          index_t ABlockTransferDstScalarPerVector_M,
+          index_t ABlockTransferDstScalarPerVector_M1,
           bool AThreadTransferSrcResetCoordinateAfterRun,
-          typename BBlockTransferThreadSliceLengths_K_N,
-          typename BBlockTransferThreadClusterLengths_K_N,
+          typename BBlockTransferThreadSliceLengths_K_N0_N1,
+          typename BBlockTransferThreadClusterLengths_K_N0_N1,
           typename BBlockTransferThreadClusterArrangeOrder,
           typename BBlockTransferSrcAccessOrder,
           index_t BBlockTransferSrcVectorDim,
           index_t BBlockTransferSrcScalarPerVector,
-          index_t BBlockTransferDstScalarPerVector_N,
+          index_t BBlockTransferDstScalarPerVector_N1,
           bool BThreadTransferSrcResetCoordinateAfterRun,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
@@ -102,20 +102,20 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
-        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
-                                                 Number<BBlockTransferDstScalarPerVector_N>{},
-                                                 Number<M1PerThread>{},
-                                                 Number<N1PerThread>{});
+        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M1>{},
+                                                 Number<BBlockTransferDstScalarPerVector_N1>{},
+                                                 Number<M1PerThreadM111>{},
+                                                 Number<N1PerThreadN111>{});

         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
+            make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
+            make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_aligned_space_size =
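With the concrete v4r4r2 configuration shown further down, the alignment above works out as a small worked example (assuming math::lcm is the ordinary least common multiple):

    // worked example with the device-side values below (illustrative, not from the commit):
    // max_lds_align = lcm(ABlockTransferDstScalarPerVector_M1 = 1,
    //                     BBlockTransferDstScalarPerVector_N1 = 1,
    //                     M1PerThreadM111 = 4,
    //                     N1PerThreadN111 = 4)
    //               = 4 elements,
    // so each row of the a/b LDS block descriptors is padded to a multiple of 4.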
@@ -139,12 +139,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
         return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
                 K == b_k_n_grid_desc.GetLength(I0)) &&
-               (M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0);
+               (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K % KPerBlock == 0);
     }

     __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
     {
-        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
+        const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1);

         return grid_size;
     }
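To make the grid-size arithmetic concrete, a minimal host-side sketch; the 1024 x 2048 problem size is hypothetical, while the tile sizes take the 128 x 128 values configured in the device code below:

    // illustrative only; index_t is the library's index type
    constexpr index_t MPerBlockM1 = 128;
    constexpr index_t NPerBlockN1 = 128;
    const index_t M = 1024, N = 2048; // hypothetical GEMM problem size
    const index_t grid_size =
        (M / MPerBlockM1) * (N / NPerBlockN1); // 8 * 16 = 128 workgroups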
@@ -169,7 +169,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
         const auto K = a_k_m_grid_desc.GetLength(I0);
         const auto M = a_k_m_grid_desc.GetLength(I1);

-        const auto M1 = Number<MPerBlock>{};
+        const auto M1 = Number<MPerBlockM1>{};
         const auto M0 = M / M1;

         const auto a_k_m0_m1_grid_desc = transform_dynamic_tensor_descriptor(
@@ -187,7 +187,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
         const auto K = b_k_n_grid_desc.GetLength(I0);
         const auto N = b_k_n_grid_desc.GetLength(I1);

-        const auto N1 = Number<NPerBlock>{};
+        const auto N1 = Number<NPerBlockN1>{};
         const auto N0 = N / N1;

         const auto b_k_n0_n1_grid_desc = transform_dynamic_tensor_descriptor(
@@ -205,14 +205,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
         const auto M = c_m_n_grid_desc.GetLength(I0);
         const auto N = c_m_n_grid_desc.GetLength(I1);

-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
+        constexpr auto M1 = Number<MPerBlockM1>{};
+        constexpr auto N1 = Number<NPerBlockN1>{};

         const auto M0 = M / M1;
         const auto N0 = N / N1;

-        constexpr auto M11 = Number<M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThread>{};
-        constexpr auto N11 = Number<M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThread>{};
+        constexpr auto M11 =
+            Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
+        constexpr auto N11 =
+            Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};

         constexpr auto M10 = M1 / M11;
         constexpr auto N10 = N1 / N11;
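Using the values the v4r4r2 device code below plugs in, the renamed decomposition reads as follows (a worked example, not part of the commit):

    // M1  = MPerBlockM1 = 128
    // M11 = M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111
    //     = 8 * 2 * 4 = 64
    // M10 = M1 / M11 = 128 / 64 = 2   (the N side works out identically: N11 = 64, N10 = 2)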
@@ -233,8 +235,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
         const auto M = c_m_n_grid_desc.GetLength(I0);
         const auto N = c_m_n_grid_desc.GetLength(I1);

-        constexpr auto M1 = Number<MPerBlock>{};
-        constexpr auto N1 = Number<NPerBlock>{};
+        constexpr auto M1 = Number<MPerBlockM1>{};
+        constexpr auto N1 = Number<NPerBlockN1>{};

         const auto M0 = M / M1;
         const auto N0 = N / N1;
@@ -285,38 +287,38 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
         const index_t n0_idx = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);

         // lds max alignment
-        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
-                                                 Number<BBlockTransferDstScalarPerVector_N>{},
-                                                 Number<M1PerThread>{},
-                                                 Number<N1PerThread>{});
+        constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M1>{},
+                                                 Number<BBlockTransferDstScalarPerVector_N1>{},
+                                                 Number<M1PerThreadM111>{},
+                                                 Number<N1PerThreadN111>{});

         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
+            make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
+            make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);

         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto a_k_m0_m1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlock>{}), max_lds_align);
+            make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}), max_lds_align);

         // B matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
         constexpr auto b_k_n0_n1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlock>{}), max_lds_align);
+            make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}), max_lds_align);

         // A matrix blockwise copy
         auto a_blockwise_copy =
             BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                                                    InMemoryDataOperation::Set,
-                                                   Sequence<KPerBlock, 1, MPerBlock>,
-                                                   ABlockTransferThreadSliceLengths_K_M,
-                                                   ABlockTransferThreadClusterLengths_K_M,
+                                                   Sequence<KPerBlock, 1, MPerBlockM1>,
+                                                   ABlockTransferThreadSliceLengths_K_M0_M1,
+                                                   ABlockTransferThreadClusterLengths_K_M0_M1,
                                                    ABlockTransferThreadClusterArrangeOrder,
                                                    FloatAB,
                                                    FloatAB,
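As a sanity check on the new 3-d copy configuration: the per-thread slice lengths times the thread-cluster lengths must tile the whole block slice. With the concrete Sequence values from the device code below (a worked example, not part of the commit):

    // slice   Sequence<4, 1, 1>   (ABlockTransferThreadSliceLengths_K_M0_M1)
    // cluster Sequence<2, 1, 128> (ABlockTransferThreadClusterLengths_K_M0_M1)
    // elementwise product: Sequence<8, 1, 128> = <KPerBlock, 1, MPerBlockM1>,
    // i.e. the full A block tile, copied by 2 * 1 * 128 = 256 threads = BlockSize.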
@@ -327,7 +329,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
                                                    ABlockTransferSrcVectorDim,
                                                    2,
                                                    ABlockTransferSrcScalarPerVector,
-                                                   ABlockTransferDstScalarPerVector_M,
+                                                   ABlockTransferDstScalarPerVector_M1,
                                                    1,
                                                    1,
                                                    AThreadTransferSrcResetCoordinateAfterRun,
@@ -340,9 +342,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
         auto b_blockwise_copy =
             BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
                                                    InMemoryDataOperation::Set,
-                                                   Sequence<KPerBlock, 1, NPerBlock>,
-                                                   BBlockTransferThreadSliceLengths_K_N,
-                                                   BBlockTransferThreadClusterLengths_K_N,
+                                                   Sequence<KPerBlock, 1, NPerBlockN1>,
+                                                   BBlockTransferThreadSliceLengths_K_N0_N1,
+                                                   BBlockTransferThreadClusterLengths_K_N0_N1,
                                                    BBlockTransferThreadClusterArrangeOrder,
                                                    FloatAB,
                                                    FloatAB,
@@ -353,7 +355,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
                                                    BBlockTransferSrcVectorDim,
                                                    2,
                                                    BBlockTransferSrcScalarPerVector,
-                                                   BBlockTransferDstScalarPerVector_N,
+                                                   BBlockTransferDstScalarPerVector_N1,
                                                    1,
                                                    1,
                                                    BThreadTransferSrcResetCoordinateAfterRun,
@@ -364,9 +366,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
         // GEMM definition
         //   c_mtx += transpose(a_mtx) * b_mtx
-        //     a_mtx[KPerBlock, MPerBlock] is in LDS
-        //     b_mtx[KPerBlock, NPerBlock] is in LDS
-        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
+        //     a_mtx[KPerBlock, MPerBlockM1] is in LDS
+        //     b_mtx[KPerBlock, NPerBlockN1] is in LDS
+        //     c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
         //     register
         const auto blockwise_gemm =
             BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
@@ -375,15 +377,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
                                                            FloatAcc,
                                                            decltype(a_k_m_block_desc),
                                                            decltype(b_k_n_block_desc),
-                                                           M1PerThread,
-                                                           N1PerThread,
+                                                           M1PerThreadM111,
+                                                           N1PerThreadN111,
                                                            KPerThread,
-                                                           M1N1ThreadClusterM100,
-                                                           M1N1ThreadClusterN100,
-                                                           M1N1ThreadClusterM101,
-                                                           M1N1ThreadClusterN101,
-                                                           M1PerThread,
-                                                           N1PerThread>{};
+                                                           M11N11ThreadClusterM1100,
+                                                           M11N11ThreadClusterN1100,
+                                                           M11N11ThreadClusterM1101,
+                                                           M11N11ThreadClusterN1101,
+                                                           M1PerThreadM111,
+                                                           N1PerThreadN111>{};

         constexpr auto c_m10_n10_m11_n11_thread_tensor_lengths =
             decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
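The comment in the hunk above is the contract the blockwise GEMM implements. A minimal scalar sketch of that contract for one block tile (illustrative only; the real BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2 is pipelined and register-blocked, but computes the same contraction):

    constexpr index_t KPerBlock = 8, MPerBlockM1 = 128, NPerBlockN1 = 128;
    float a[KPerBlock][MPerBlockM1];        // A tile in LDS, stored K-major
    float b[KPerBlock][NPerBlockN1];        // B tile in LDS, stored K-major
    float c[MPerBlockM1][NPerBlockN1] = {}; // C tile, held in registers

    for(index_t k = 0; k < KPerBlock; ++k)
        for(index_t m = 0; m < MPerBlockM1; ++m)
            for(index_t n = 0; n < NPerBlockN1; ++n)
                c[m][n] += a[k][m] * b[k][n]; // c_mtx += transpose(a_mtx) * b_mtx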
@@ -559,14 +561,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
         // output: register to global memory
         {
-            constexpr index_t M11 = M1PerThread * M1N1ThreadClusterM100 * M1N1ThreadClusterM101;
-            constexpr index_t N11 = N1PerThread * M1N1ThreadClusterN100 * M1N1ThreadClusterN101;
+            constexpr index_t M11 =
+                M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
+            constexpr index_t N11 =
+                N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;

-            constexpr index_t M10 = MPerBlock / M11;
-            constexpr index_t N10 = NPerBlock / N11;
+            constexpr index_t M10 = MPerBlockM1 / M11;
+            constexpr index_t N10 = NPerBlockN1 / N11;

-            constexpr index_t M111 = M1PerThread;
-            constexpr index_t N111 = N1PerThread;
+            constexpr index_t M111 = M1PerThreadM111;
+            constexpr index_t N111 = N1PerThreadN111;

             constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
                 make_dynamic_naive_tensor_descriptor_packed_v2(
...
@@ -83,32 +83,32 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
     // b thread copy 4x1
     constexpr index_t BlockSize = 256;

-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 8;
+    constexpr index_t GemmMPerBlockM1 = 128;
+    constexpr index_t GemmNPerBlockN1 = 128;
+    constexpr index_t GemmKPerBlock   = 8;

-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
+    constexpr index_t GemmM1PerThreadM111 = 4;
+    constexpr index_t GemmN1PerThreadN111 = 4;
+    constexpr index_t GemmKPerThread      = 1;

-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
+    constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
+    constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
+    constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
+    constexpr index_t GemmM11N11ThreadClusterN1100 = 8;

-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM0_GemmM1   = Sequence<4, 1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM0_GemmM1 = Sequence<2, 1, 128>;
+    using GemmABlockTransferThreadSliceLengths_K_M0_M1   = Sequence<4, 1, 1>;
+    using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>;

-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK  = 4;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM1 = 1;
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_K  = 4;
+    constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;

-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN0_GemmN1   = Sequence<4, 1, 1>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN0_GemmN1 = Sequence<2, 1, 128>;
+    using GemmBBlockTransferThreadSliceLengths_K_N0_N1   = Sequence<4, 1, 1>;
+    using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>;

-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN1 = 1;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN1 = 1;
+    constexpr index_t GemmBBlockTransferSrcScalarPerVector_N1 = 1;
+    constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;

-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN11 = 1;
+    constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 1;
#endif
const auto descs =
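A quick consistency check on the configuration above (arithmetic only, mirroring what the gridwise kernel's decomposition implies; not part of the commit):

    // threads per block from the M11/N11 cluster:
    //   M1100 * N1100 * M1101 * N1101 = 8 * 8 * 2 * 2 = 256 = BlockSize
    // per-block C tile coverage:
    //   M11 = 8 * 2 * 4 = 64, M10 = 128 / 64 = 2 (N side identical), so
    //   256 threads * (2 * 2 repeats) * (4 * 4 per-thread tile)
    //     = 16384 = 128 * 128 elements, exactly one MPerBlockM1 x NPerBlockN1 tile.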
@@ -174,35 +174,35 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
         decltype(wei_gemmk_gemmm_grid_desc),
         decltype(in_gemmk_gemmn_grid_desc),
         decltype(out_gemmm_gemmn_grid_desc),
-        GemmMPerBlock,
-        GemmNPerBlock,
+        GemmMPerBlockM1,
+        GemmNPerBlockN1,
         GemmKPerBlock,
-        GemmMPerThread,
-        GemmNPerThread,
+        GemmM1PerThreadM111,
+        GemmN1PerThreadN111,
         GemmKPerThread,
-        GemmMLevel1Cluster,
-        GemmNLevel1Cluster,
-        GemmMLevel0Cluster,
-        GemmNLevel0Cluster,
-        GemmABlockTransferThreadSliceLengths_GemmK_GemmM0_GemmM1,
-        GemmABlockTransferThreadClusterLengths_GemmK_GemmM0_GemmM1,
+        GemmM11N11ThreadClusterM1100,
+        GemmM11N11ThreadClusterN1100,
+        GemmM11N11ThreadClusterM1101,
+        GemmM11N11ThreadClusterN1101,
+        GemmABlockTransferThreadSliceLengths_K_M0_M1,
+        GemmABlockTransferThreadClusterLengths_K_M0_M1,
         Sequence<2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder
         Sequence<2, 1, 0>, // ABlockTransferSrcAccessOrder
         0,                 // ABlockTransferSrcVectorDim
-        GemmABlockTransferSrcScalarPerVector_GemmK,
-        GemmABlockTransferDstScalarPerVector_GemmM1,
+        GemmABlockTransferSrcScalarPerVector_K,
+        GemmABlockTransferDstScalarPerVector_M1,
         false, // don't move back src coordinate after threadwise copy
-        GemmBBlockTransferThreadSliceLengths_GemmK_GemmN0_GemmN1,
-        GemmBBlockTransferThreadClusterLengths_GemmK_GemmN0_GemmN1,
+        GemmBBlockTransferThreadSliceLengths_K_N0_N1,
+        GemmBBlockTransferThreadClusterLengths_K_N0_N1,
         Sequence<0, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
         Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder
         2,                 // BBlockTransferSrcVectorDim
-        GemmBBlockTransferSrcScalarPerVector_GemmN1,
-        GemmBBlockTransferDstScalarPerVector_GemmN1,
+        GemmBBlockTransferSrcScalarPerVector_N1,
+        GemmBBlockTransferDstScalarPerVector_N1,
         false, // don't move back src coordinate after threadwise copy
         Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
         5,                          // CThreadTransferSrcDstVectorDim
-        GemmCThreadTransferDstScalarPerVector_GemmN11,
+        GemmCThreadTransferDstScalarPerVector_N11,
         decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks),
         decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks),
         decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
...
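Finally, a hedged sketch of the divisibility constraints the renamed parameters must satisfy, mirroring CheckValidity and the M11/N11 decomposition above. The constants restate the device-side configuration; the static_asserts themselves are illustrative and not part of the commit:

    #include <cstdint>
    using index_t = std::int32_t; // stand-in for the library's index_t

    constexpr index_t BlockSize                = 256;
    constexpr index_t MPerBlockM1              = 128;
    constexpr index_t M1PerThreadM111          = 4;
    constexpr index_t M11N11ThreadClusterM1100 = 8, M11N11ThreadClusterN1100 = 8;
    constexpr index_t M11N11ThreadClusterM1101 = 2, M11N11ThreadClusterN1101 = 2;

    // the M11/N11 thread cluster must exactly fill the workgroup
    static_assert(M11N11ThreadClusterM1100 * M11N11ThreadClusterN1100 *
                      M11N11ThreadClusterM1101 * M11N11ThreadClusterN1101 == BlockSize,
                  "thread cluster does not match BlockSize");

    // M1 (= MPerBlockM1) must decompose evenly into M10 * M11
    static_assert(MPerBlockM1 % (M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 *
                                 M1PerThreadM111) == 0,
                  "MPerBlockM1 is not divisible by M11");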