"...resnet50_tensorflow.git" did not exist on "31ca3b97ebc1ca37b1d4db6ff3bf062fcbf16b5d"
Commit b3e80872 authored by Chao Liu

refactor

parent 6c37035f
@@ -20,13 +20,13 @@ template <index_t BlockSize,
           index_t MPerBlock,
           index_t NPerBlock,
           index_t KPerBlock,
-          index_t MPerThread,
-          index_t NPerThread,
+          index_t M1PerThread,
+          index_t N1PerThread,
           index_t KPerThread,
-          index_t MLevel0Cluster,
-          index_t NLevel0Cluster,
-          index_t MLevel1Cluster,
-          index_t NLevel1Cluster,
+          index_t M1N1ThreadClusterM10,
+          index_t M1N1ThreadClusterN10,
+          index_t M1N1ThreadClusterM11,
+          index_t M1N1ThreadClusterN11,
           typename ABlockTransferThreadSliceLengths_K_M,
           typename ABlockTransferThreadClusterLengths_K_M,
           typename ABlockTransferThreadClusterArrangeOrder,
@@ -80,8 +80,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r1(const FloatAB* p_a_global,
         throw std::runtime_error("wrong! GEMM size no divisible");
     }

-    constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
-    constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};
+    constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
+    constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};

     if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0))
     {
@@ -102,13 +102,13 @@ __host__ float launch_kernel_dynamic_gemm_v1r1(const FloatAB* p_a_global,
         MPerBlock,
         NPerBlock,
         KPerBlock,
-        MPerThread,
-        NPerThread,
+        M1PerThread,
+        N1PerThread,
         KPerThread,
-        MLevel0Cluster,
-        NLevel0Cluster,
-        MLevel1Cluster,
-        NLevel1Cluster,
+        M1N1ThreadClusterM10,
+        M1N1ThreadClusterN10,
+        M1N1ThreadClusterM11,
+        M1N1ThreadClusterN11,
         ABlockTransferThreadSliceLengths_K_M,
         ABlockTransferThreadClusterLengths_K_M,
         ABlockTransferThreadClusterArrangeOrder,
......
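The renaming above makes the tile hierarchy explicit: the per-block M tile splits as M = M0 × M1, where M1 is the slice covered by one pass of the whole thread cluster. A minimal sketch of the arithmetic, using the 128x128x8 configuration that appears later in this commit (the literal values are borrowed for illustration only):

```cpp
// Sketch only: constants taken from the "BlockSize = 256, 128x128x8" config
// below; the old-name comments follow the positional renames in this diff.
constexpr int MPerBlock            = 128;
constexpr int M1PerThread          = 4; // was MPerThread
constexpr int M1N1ThreadClusterM11 = 2; // was GemmMLevel0Cluster in the device code
constexpr int M1N1ThreadClusterM10 = 8; // was GemmMLevel1Cluster in the device code

// M1: the M extent owned by one pass of the thread cluster.
constexpr int M1 = M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10; // 64
// M0: how many M1 slices tile the block (the quantity formerly called MRepeat).
constexpr int M0 = MPerBlock / M1; // 2
static_assert(MPerBlock % M1 == 0, "GEMM size must be divisible");
```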
@@ -29,10 +29,10 @@ template <index_t BlockSize,
           index_t M1PerThread,
           index_t N1PerThread,
           index_t KPerThread,
-          index_t MLevel0ThreadCluster,
-          index_t NLevel0ThreadCluster,
-          index_t MLevel1ThreadCluster,
-          index_t NLevel1ThreadCluster,
+          index_t M1N1ThreadClusterM10,
+          index_t M1N1ThreadClusterN10,
+          index_t M1N1ThreadClusterM11,
+          index_t M1N1ThreadClusterN11,
           index_t AThreadCopyScalarPerVector_M1,
           index_t BThreadCopyScalarPerVector_N1,
           typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
@@ -62,8 +62,8 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                           CThreadDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

-        static_assert(BlockSize == MLevel0ThreadCluster * MLevel1ThreadCluster *
-                                       NLevel0ThreadCluster * NLevel1ThreadCluster,
+        static_assert(BlockSize == M1N1ThreadClusterM11 * M1N1ThreadClusterM10 *
+                                       M1N1ThreadClusterN11 * M1N1ThreadClusterN10,
                       "wrong! blocksize and cluster size not consistent");

         static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
@@ -78,6 +78,8 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
         constexpr index_t N1 = BBlockDesc{}.GetLength(I2);

         // 4-d data space into 4-d thread space
+        // upper: {1, M1N1ThreadClusterM10 * M1N1ThreadClusterM11, 1, M1N1ThreadClusterN10 *
+        // M1N1ThreadClusterN11} lower: {M0, M1, N0, N1}
         constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
             make_tuple(make_vectorize_transform(M0, 1),
                        make_vectorize_transform(M1PerThread, M1 / M1PerThread),
@@ -87,21 +89,27 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

         // thread position 4-d thread space
+        // upper: {M1N1ThreadClusterM10, M1N1ThreadClusterM11, M1N1ThreadClusterN10,
+        // M1N1ThreadClusterN11} lower: {1, M1N1ThreadClusterM10 * M1N1ThreadClusterM11, 1,
+        // M1N1ThreadClusterN10 * M1N1ThreadClusterN11}
         constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
             make_tuple(
                 make_freeze_transform(make_multi_index(0)),
-                make_unmerge_transform(make_tuple(MLevel1ThreadCluster, MLevel0ThreadCluster)),
+                make_unmerge_transform(make_tuple(M1N1ThreadClusterM10, M1N1ThreadClusterM11)),
                 make_freeze_transform(make_multi_index(0)),
-                make_unmerge_transform(make_tuple(NLevel1ThreadCluster, NLevel0ThreadCluster))),
+                make_unmerge_transform(make_tuple(M1N1ThreadClusterN10, M1N1ThreadClusterN11))),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));

         // 4-d thread space to 1-d thread space
+        // upper: {BlockSize}
+        // lower: {M1N1ThreadClusterM10, M1N1ThreadClusterM11, M1N1ThreadClusterN10,
+        // M1N1ThreadClusterN11}
         constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(MLevel1ThreadCluster,
-                                                       NLevel1ThreadCluster,
-                                                       MLevel0ThreadCluster,
-                                                       NLevel0ThreadCluster))),
+            make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM10,
+                                                       M1N1ThreadClusterN10,
+                                                       M1N1ThreadClusterM11,
+                                                       M1N1ThreadClusterN11))),
             make_tuple(Sequence<0, 2, 1, 3>{}),
             make_tuple(Sequence<0>{}));
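For reference, adaptor2 merges the four cluster lengths in the order (M10, N10, M11, N11) into a single thread id, and adaptor1 unmerges each pair back into per-level coordinates. A standalone sketch of the inverse mapping, assuming the merge composes row-major with the first listed length most significant (the cluster sizes are hypothetical, taken from the 128x128x8 config in this commit):

```cpp
#include <array>
#include <cstdio>

// Hypothetical cluster sizes (the 128x128x8 config in this commit).
constexpr int M10 = 8, N10 = 8, M11 = 2, N11 = 2;

// Inverse of adaptor2's merge order (M10, N10, M11, N11): peel the
// least-significant (last-listed) length first.
constexpr std::array<int, 4> decompose(int tid)
{
    const int n11 = tid % N11; tid /= N11;
    const int m11 = tid % M11; tid /= M11;
    const int n10 = tid % N10; tid /= N10;
    const int m10 = tid;
    // Reordered into lower-dim order {M10, M11, N10, N11}, matching
    // the Sequence<0, 2, 1, 3> lower-dimension ids above.
    return {m10, m11, n10, n11};
}

int main()
{
    static_assert(M10 * M11 * N10 * N11 == 256, "BlockSize mismatch");
    const auto c = decompose(77); // e.g. thread 77 -> m10=2 m11=0 n10=3 n11=1
    std::printf("m10=%d m11=%d n10=%d n11=%d\n", c[0], c[1], c[2], c[3]);
    return 0;
}
```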
@@ -221,10 +229,10 @@ template <index_t BlockSize,
           index_t M1PerThread,
           index_t N1PerThread,
           index_t KPerThread,
-          index_t MLevel0ThreadCluster,
-          index_t NLevel0ThreadCluster,
-          index_t MLevel1ThreadCluster,
-          index_t NLevel1ThreadCluster,
+          index_t M1N1ThreadClusterM11,
+          index_t M1N1ThreadClusterN11,
+          index_t M1N1ThreadClusterM10,
+          index_t M1N1ThreadClusterN10,
           index_t AThreadCopyScalarPerVector_M1,
           index_t BThreadCopyScalarPerVector_N1,
           typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
@@ -254,8 +262,8 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
                           CThreadDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

-        static_assert(BlockSize == MLevel0ThreadCluster * MLevel1ThreadCluster *
-                                       NLevel0ThreadCluster * NLevel1ThreadCluster,
+        static_assert(BlockSize == M1N1ThreadClusterM11 * M1N1ThreadClusterM10 *
+                                       M1N1ThreadClusterN11 * M1N1ThreadClusterN10,
                       "wrong! blocksize and cluster size not consistent");

         static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
@@ -287,18 +295,18 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
         constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
             make_tuple(
                 make_freeze_transform(make_multi_index(0)),
-                make_unmerge_transform(make_tuple(MLevel1ThreadCluster, MLevel0ThreadCluster)),
+                make_unmerge_transform(make_tuple(M1N1ThreadClusterM10, M1N1ThreadClusterM11)),
                 make_freeze_transform(make_multi_index(0)),
-                make_unmerge_transform(make_tuple(NLevel1ThreadCluster, NLevel0ThreadCluster))),
+                make_unmerge_transform(make_tuple(M1N1ThreadClusterN10, M1N1ThreadClusterN11))),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));

         // 4-d thread space to 1-d thread space
         constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(make_tuple(MLevel1ThreadCluster,
-                                                       NLevel1ThreadCluster,
-                                                       MLevel0ThreadCluster,
-                                                       NLevel0ThreadCluster))),
+            make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM10,
+                                                       M1N1ThreadClusterN10,
+                                                       M1N1ThreadClusterM11,
+                                                       M1N1ThreadClusterN11))),
             make_tuple(Sequence<0, 2, 1, 3>{}),
             make_tuple(Sequence<0>{}));
......
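Both blockwise GEMM variants keep the static_assert tying BlockSize to the product of the four renamed cluster lengths. A standalone restatement with assumed values, just to show the constraint the rename preserves:

```cpp
// Sketch with assumed values from the 256-thread config in this commit.
constexpr int BlockSize            = 256;
constexpr int M1N1ThreadClusterM10 = 8, M1N1ThreadClusterM11 = 2;
constexpr int M1N1ThreadClusterN10 = 8, M1N1ThreadClusterN11 = 2;

static_assert(BlockSize == M1N1ThreadClusterM11 * M1N1ThreadClusterM10 *
                               M1N1ThreadClusterN11 * M1N1ThreadClusterN10,
              "wrong! blocksize and cluster size not consistent"); // 2 * 8 * 2 * 8 == 256
```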
@@ -108,13 +108,13 @@ template <index_t BlockSize,
           index_t MPerBlock,
           index_t NPerBlock,
           index_t KPerBlock,
-          index_t MPerThread,
-          index_t NPerThread,
+          index_t M1PerThread,
+          index_t N1PerThread,
           index_t KPerThread,
-          index_t MLevel0Cluster,
-          index_t NLevel0Cluster,
-          index_t MLevel1Cluster,
-          index_t NLevel1Cluster,
+          index_t M1N1ThreadClusterM10,
+          index_t M1N1ThreadClusterN10,
+          index_t M1N1ThreadClusterM11,
+          index_t M1N1ThreadClusterN11,
           typename ABlockTransferThreadSliceLengths_K_M,
           typename ABlockTransferThreadClusterLengths_K_M,
           typename ABlockTransferThreadClusterArrangeOrder,
@@ -145,8 +145,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
     {
         constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
                                                  Number<BBlockTransferDstScalarPerVector_N>{},
-                                                 Number<MPerThread>{},
-                                                 Number<NPerThread>{});
+                                                 Number<M1PerThread>{},
+                                                 Number<N1PerThread>{});

         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
@@ -210,8 +210,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
         // lds max alignment
         constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
                                                  Number<BBlockTransferDstScalarPerVector_N>{},
-                                                 Number<MPerThread>{},
-                                                 Number<NPerThread>{});
+                                                 Number<M1PerThread>{},
+                                                 Number<N1PerThread>{});

         // A matrix in LDS memory, dst of blockwise copy
         // be careful of LDS alignment
@@ -284,34 +284,39 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         // register

         // sanity check
-        static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
-                          NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
-                      "wrong!");
+        static_assert(
+            MPerBlock % (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10) == 0 &&
+                NPerBlock % (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10) == 0,
+            "wrong!");

-        constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
-        constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
+        constexpr index_t M0PerThread =
+            MPerBlock / (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10);
+        constexpr index_t N0PerThread =
+            NPerBlock / (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10);

         constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
             a_k_m_block_desc,
-            make_tuple(
-                make_pass_through_transform(Number<KPerBlock>{}),
-                make_unmerge_transform(make_tuple(
-                    Number<MRepeat>{}, Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{}))),
+            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+                       make_unmerge_transform(make_tuple(
+                           Number<M0PerThread>{},
+                           Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{}))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

         constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
             b_k_n_block_desc,
-            make_tuple(
-                make_pass_through_transform(Number<KPerBlock>{}),
-                make_unmerge_transform(make_tuple(
-                    Number<NRepeat>{}, Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{}))),
+            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
+                       make_unmerge_transform(make_tuple(
+                           Number<N0PerThread>{},
+                           Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{}))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

         constexpr auto c_m0_m1_n0_n1_thread_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
-                Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
+            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<M0PerThread>{},
+                                                                      Number<M1PerThread>{},
+                                                                      Number<N0PerThread>{},
+                                                                      Number<N1PerThread>{}));

         const auto blockwise_gemm =
             BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2<BlockSize,
@@ -321,15 +326,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
                                                                decltype(a_k_m0_m1_block_desc),
                                                                decltype(b_k_n0_n1_block_desc),
                                                                decltype(c_m0_m1_n0_n1_thread_desc),
-                                                               MPerThread,
-                                                               NPerThread,
+                                                               M1PerThread,
+                                                               N1PerThread,
                                                                KPerThread,
-                                                               MLevel0Cluster,
-                                                               NLevel0Cluster,
-                                                               MLevel1Cluster,
-                                                               NLevel1Cluster,
-                                                               MPerThread,
-                                                               NPerThread>{};
+                                                               M1N1ThreadClusterM10,
+                                                               M1N1ThreadClusterN10,
+                                                               M1N1ThreadClusterM11,
+                                                               M1N1ThreadClusterN11,
+                                                               M1PerThread,
+                                                               N1PerThread>{};

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
@@ -345,9 +350,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
         auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
             c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());

-        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
-                                           decltype(c_m0_m1_n0_n1_thread_desc),
-                                           Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
+        ThreadwiseDynamicTensorSliceSet_v1<
+            FloatAcc,
+            decltype(c_m0_m1_n0_n1_thread_desc),
+            Sequence<M0PerThread, M1PerThread, N0PerThread, N1PerThread>>{}
             .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
@@ -479,8 +485,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
         // output: register to global memory
         {
-            constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
-            constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};
+            constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM10 * M1N1ThreadClusterM11>{};
+            constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN10 * M1N1ThreadClusterN11>{};

             // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
             constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
@@ -493,7 +499,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
                 FloatC,
                 decltype(c_m0_m1_n0_n1_thread_desc),
                 decltype(c_m0_m1_n0_n1_global_desc),
-                Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
+                Sequence<M0PerThread, M1PerThread, N0PerThread, N1PerThread>,
                 CThreadTransferSrcDstAccessOrder,
                 CThreadTransferSrcDstVectorDim,
                 CThreadTransferDstScalarPerVector,
......
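The MRepeat/NRepeat → M0PerThread/N0PerThread rename in the gridwise kernel is the same repeat count under the new naming. Worked numbers, assuming the active "cdata = 64, BlockSize = 256, 128x128x8" config from the device code below:

```cpp
// Assumed values: the "cdata = 64, BlockSize = 256, 128x128x8" config below.
constexpr int MPerBlock = 128, NPerBlock = 128;
constexpr int M1PerThread = 4, N1PerThread = 4;
constexpr int M1N1ThreadClusterM10 = 8, M1N1ThreadClusterM11 = 2;
constexpr int M1N1ThreadClusterN10 = 8, M1N1ThreadClusterN11 = 2;

constexpr int M0PerThread =
    MPerBlock / (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10); // 128 / 64 = 2
constexpr int N0PerThread =
    NPerBlock / (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10); // 128 / 64 = 2

// Per-thread C accumulator footprint, matching c_m0_m1_n0_n1_thread_desc:
// M0 x M1 x N0 x N1 = 2 x 4 x 2 x 4 = 64 elements.
static_assert(M0PerThread * M1PerThread * N0PerThread * N1PerThread == 64, "per-thread C tile");
```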
@@ -125,14 +125,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     constexpr index_t GemmNPerBlock = 64;
     constexpr index_t GemmKPerBlock = 4;

-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 2;
+    constexpr index_t GemmM1PerThread = 2;
+    constexpr index_t GemmN1PerThread = 2;
     constexpr index_t GemmKPerThread = 1;

-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
+    constexpr index_t GemmM1N1ThreadClusterM11 = 2;
+    constexpr index_t GemmM1N1ThreadClusterN11 = 2;
+    constexpr index_t GemmM1N1ThreadClusterM10 = 2;
+    constexpr index_t GemmM1N1ThreadClusterN10 = 8;

     constexpr index_t ThreadGemmDataPerReadM = 2;
     constexpr index_t ThreadGemmDataPerReadN = 2;
@@ -149,7 +149,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2;
+    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM11 = 2;
 #elif 0
     // cdata = 32, BlockSize = 64, 16x128x4
     constexpr index_t BlockSize = 64;
@@ -158,14 +158,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 4;

-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmM1PerThread = 2;
+    constexpr index_t GemmN1PerThread = 4;
     constexpr index_t GemmKPerThread = 1;

-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
+    constexpr index_t GemmM1N1ThreadClusterM11 = 2;
+    constexpr index_t GemmM1N1ThreadClusterN11 = 2;
+    constexpr index_t GemmM1N1ThreadClusterM10 = 2;
+    constexpr index_t GemmM1N1ThreadClusterN10 = 8;

     constexpr index_t ThreadGemmDataPerReadM = 2;
     constexpr index_t ThreadGemmDataPerReadN = 4;
@@ -182,7 +182,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2;
+    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM11 = 2;
 #elif 0
     // cdata = 64, BlockSize = 64, 16x256x2
     constexpr index_t BlockSize = 64;
@@ -191,14 +191,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     constexpr index_t GemmNPerBlock = 256;
     constexpr index_t GemmKPerBlock = 2;

-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmM1PerThread = 4;
+    constexpr index_t GemmN1PerThread = 4;
     constexpr index_t GemmKPerThread = 1;

-    constexpr index_t GemmMLevel0Cluster = 1;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 16;
+    constexpr index_t GemmM1N1ThreadClusterM11 = 1;
+    constexpr index_t GemmM1N1ThreadClusterN11 = 2;
+    constexpr index_t GemmM1N1ThreadClusterM10 = 2;
+    constexpr index_t GemmM1N1ThreadClusterN10 = 16;

     constexpr index_t ThreadGemmDataPerReadM = 4;
     constexpr index_t ThreadGemmDataPerReadN = 4;
@@ -215,7 +215,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 2;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
+    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM11 = 4;
 #elif 0
     // cdata = 64, BlockSize = 64, 16x256x4
     constexpr index_t BlockSize = 64;
@@ -224,14 +224,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     constexpr index_t GemmNPerBlock = 256;
     constexpr index_t GemmKPerBlock = 4;

-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmM1PerThread = 4;
+    constexpr index_t GemmN1PerThread = 4;
     constexpr index_t GemmKPerThread = 1;

-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 1;
-    constexpr index_t GemmNLevel1Cluster = 16;
+    constexpr index_t GemmM1N1ThreadClusterM11 = 2;
+    constexpr index_t GemmM1N1ThreadClusterN11 = 2;
+    constexpr index_t GemmM1N1ThreadClusterM10 = 1;
+    constexpr index_t GemmM1N1ThreadClusterN10 = 16;

     constexpr index_t ThreadGemmDataPerReadM = 4;
     constexpr index_t ThreadGemmDataPerReadN = 4;
@@ -248,7 +248,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
+    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM11 = 4;
 #elif 0
     // cdata = 64, BlockSize = 128, 32x256x4
     constexpr index_t BlockSize = 128;
@@ -257,14 +257,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     constexpr index_t GemmNPerBlock = 256;
     constexpr index_t GemmKPerBlock = 4;

-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmM1PerThread = 4;
+    constexpr index_t GemmN1PerThread = 4;
     constexpr index_t GemmKPerThread = 1;

-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 16;
+    constexpr index_t GemmM1N1ThreadClusterM11 = 2;
+    constexpr index_t GemmM1N1ThreadClusterN11 = 2;
+    constexpr index_t GemmM1N1ThreadClusterM10 = 2;
+    constexpr index_t GemmM1N1ThreadClusterN10 = 16;

     constexpr index_t ThreadGemmDataPerReadM = 4;
     constexpr index_t ThreadGemmDataPerReadN = 4;
@@ -281,7 +281,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
+    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM11 = 4;
 #elif 0
     // cdata = 64, BlockSize = 128, 32x256x8
     constexpr index_t BlockSize = 128;
@@ -290,14 +290,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     constexpr index_t GemmNPerBlock = 256;
     constexpr index_t GemmKPerBlock = 8;

-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmM1PerThread = 4;
+    constexpr index_t GemmN1PerThread = 4;
     constexpr index_t GemmKPerThread = 1;

-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 16;
+    constexpr index_t GemmM1N1ThreadClusterM11 = 2;
+    constexpr index_t GemmM1N1ThreadClusterN11 = 2;
+    constexpr index_t GemmM1N1ThreadClusterM10 = 2;
+    constexpr index_t GemmM1N1ThreadClusterN10 = 16;

     constexpr index_t ThreadGemmDataPerReadM = 4;
     constexpr index_t ThreadGemmDataPerReadN = 4;
@@ -314,7 +314,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
+    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM11 = 4;
 #elif 1
     // cdata = 64, BlockSize = 256, 128x128x8
     constexpr index_t BlockSize = 256;
@@ -323,14 +323,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 8;

-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmM1PerThread = 4;
+    constexpr index_t GemmN1PerThread = 4;
     constexpr index_t GemmKPerThread = 1;

-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
+    constexpr index_t GemmM1N1ThreadClusterM11 = 2;
+    constexpr index_t GemmM1N1ThreadClusterN11 = 2;
+    constexpr index_t GemmM1N1ThreadClusterM10 = 8;
+    constexpr index_t GemmM1N1ThreadClusterN10 = 8;

     using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<4, 1>;
     using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
@@ -344,7 +344,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
+    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM11 = 4;
 #elif 1
     // cdata = 64, BlockSize = 256, 128x128x16
     constexpr index_t BlockSize = 256;
@@ -353,14 +353,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 16;

-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmM1PerThread = 4;
+    constexpr index_t GemmN1PerThread = 4;
     constexpr index_t GemmKPerThread = 1;

-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
+    constexpr index_t GemmM1N1ThreadClusterM11 = 2;
+    constexpr index_t GemmM1N1ThreadClusterN11 = 2;
+    constexpr index_t GemmM1N1ThreadClusterM10 = 8;
+    constexpr index_t GemmM1N1ThreadClusterN10 = 8;

     using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<4, 2>;
     using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
@@ -374,11 +374,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;

-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
+    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM11 = 4;
 #endif

-    constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
-    constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
+    constexpr index_t GemmM1 =
+        GemmM1PerThread * GemmM1N1ThreadClusterM11 * GemmM1N1ThreadClusterM10;
+    constexpr index_t GemmN1 =
+        GemmN1PerThread * GemmM1N1ThreadClusterN11 * GemmM1N1ThreadClusterN10;

     const auto descs =
 #if 1
@@ -409,13 +411,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
         GemmMPerBlock,
         GemmNPerBlock,
         GemmKPerBlock,
-        GemmMPerThread,
-        GemmNPerThread,
+        GemmM1PerThread,
+        GemmN1PerThread,
         GemmKPerThread,
-        GemmMLevel0Cluster,
-        GemmNLevel0Cluster,
-        GemmMLevel1Cluster,
-        GemmNLevel1Cluster,
+        GemmM1N1ThreadClusterM10,
+        GemmM1N1ThreadClusterN10,
+        GemmM1N1ThreadClusterM11,
+        GemmM1N1ThreadClusterN11,
         GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
         GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
         Sequence<1, 0>,
@@ -435,7 +437,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
         // MoveSrcSliceWindow() to save addr computation
         Sequence<2, 3, 0, 1>,
         1,
-        GemmCThreadTransferDstScalarPerVector_GemmM1,
+        GemmCThreadTransferDstScalarPerVector_GemmM11,
         decltype(descs[I4]),
         decltype(descs[I5]),
         decltype(descs[I6]),
......