Commit b3e80872 authored by Chao Liu
Browse files

refactor

parent 6c37035f
......@@ -20,13 +20,13 @@ template <index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerThread,
index_t NPerThread,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
......@@ -80,8 +80,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r1(const FloatAB* p_a_global,
throw std::runtime_error("wrong! GEMM size no divisible");
}
constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};
constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0))
{
......@@ -102,13 +102,13 @@ __host__ float launch_kernel_dynamic_gemm_v1r1(const FloatAB* p_a_global,
MPerBlock,
NPerBlock,
KPerBlock,
MPerThread,
NPerThread,
M1PerThread,
N1PerThread,
KPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
......
......@@ -29,10 +29,10 @@ template <index_t BlockSize,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t MLevel0ThreadCluster,
index_t NLevel0ThreadCluster,
index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
index_t AThreadCopyScalarPerVector_M1,
index_t BThreadCopyScalarPerVector_N1,
typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
......@@ -62,8 +62,8 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
CThreadDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == MLevel0ThreadCluster * MLevel1ThreadCluster *
NLevel0ThreadCluster * NLevel1ThreadCluster,
static_assert(BlockSize == M1N1ThreadClusterM11 * M1N1ThreadClusterM10 *
M1N1ThreadClusterN11 * M1N1ThreadClusterN10,
"wrong! blocksize and cluster size not consistent");
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
......@@ -78,6 +78,8 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
// 4-d data space into 4-d thread space
// upper: {1, M1N1ThreadClusterM10 * M1N1ThreadClusterM11, 1,
//         M1N1ThreadClusterN10 * M1N1ThreadClusterN11}
// lower: {M0, M1, N0, N1}
constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
make_tuple(make_vectorize_transform(M0, 1),
make_vectorize_transform(M1PerThread, M1 / M1PerThread),
......@@ -87,21 +89,27 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// thread position 4-d thread space
// upper: {M1N1ThreadClusterM10, M1N1ThreadClusterM11, M1N1ThreadClusterN10,
//         M1N1ThreadClusterN11}
// lower: {1, M1N1ThreadClusterM10 * M1N1ThreadClusterM11, 1,
//         M1N1ThreadClusterN10 * M1N1ThreadClusterN11}
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(
make_freeze_transform(make_multi_index(0)),
make_unmerge_transform(make_tuple(MLevel1ThreadCluster, MLevel0ThreadCluster)),
make_unmerge_transform(make_tuple(M1N1ThreadClusterM10, M1N1ThreadClusterM11)),
make_freeze_transform(make_multi_index(0)),
make_unmerge_transform(make_tuple(NLevel1ThreadCluster, NLevel0ThreadCluster))),
make_unmerge_transform(make_tuple(M1N1ThreadClusterN10, M1N1ThreadClusterN11))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));
// 4-d thread space to 1-d thread space
// upper: {BlockSize}
// lower: {M1N1ThreadClusterM10, M1N1ThreadClusterM11, M1N1ThreadClusterN10,
// M1N1ThreadClusterN11}
constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MLevel1ThreadCluster,
NLevel1ThreadCluster,
MLevel0ThreadCluster,
NLevel0ThreadCluster))),
make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11))),
make_tuple(Sequence<0, 2, 1, 3>{}),
make_tuple(Sequence<0>{}));
......@@ -221,10 +229,10 @@ template <index_t BlockSize,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t MLevel0ThreadCluster,
index_t NLevel0ThreadCluster,
index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t AThreadCopyScalarPerVector_M1,
index_t BThreadCopyScalarPerVector_N1,
typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
......@@ -254,8 +262,8 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
CThreadDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == MLevel0ThreadCluster * MLevel1ThreadCluster *
NLevel0ThreadCluster * NLevel1ThreadCluster,
static_assert(BlockSize == M1N1ThreadClusterM11 * M1N1ThreadClusterM10 *
M1N1ThreadClusterN11 * M1N1ThreadClusterN10,
"wrong! blocksize and cluster size not consistent");
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
......@@ -287,18 +295,18 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(
make_freeze_transform(make_multi_index(0)),
make_unmerge_transform(make_tuple(MLevel1ThreadCluster, MLevel0ThreadCluster)),
make_unmerge_transform(make_tuple(M1N1ThreadClusterM10, M1N1ThreadClusterM11)),
make_freeze_transform(make_multi_index(0)),
make_unmerge_transform(make_tuple(NLevel1ThreadCluster, NLevel0ThreadCluster))),
make_unmerge_transform(make_tuple(M1N1ThreadClusterN10, M1N1ThreadClusterN11))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));
// 4-d thread space to 1-d thread space
constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MLevel1ThreadCluster,
NLevel1ThreadCluster,
MLevel0ThreadCluster,
NLevel0ThreadCluster))),
make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11))),
make_tuple(Sequence<0, 2, 1, 3>{}),
make_tuple(Sequence<0>{}));
......
......@@ -108,13 +108,13 @@ template <index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerThread,
index_t NPerThread,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
......@@ -145,8 +145,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
{
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{},
Number<MPerThread>{},
Number<NPerThread>{});
Number<M1PerThread>{},
Number<N1PerThread>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
......@@ -210,8 +210,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
// lds max alignment
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{},
Number<MPerThread>{},
Number<NPerThread>{});
Number<M1PerThread>{},
Number<N1PerThread>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
......@@ -284,34 +284,39 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");
static_assert(
MPerBlock % (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10) == 0 &&
NPerBlock % (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10) == 0,
"wrong!");
constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
constexpr index_t M0PerThread =
MPerBlock / (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10);
constexpr index_t N0PerThread =
NPerBlock / (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10);
constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
a_k_m_block_desc,
make_tuple(
make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{}))),
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(make_tuple(
Number<M0PerThread>{},
Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
b_k_n_block_desc,
make_tuple(
make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{}))),
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(make_tuple(
Number<N0PerThread>{},
Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
constexpr auto c_m0_m1_n0_n1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<M0PerThread>{},
Number<M1PerThread>{},
Number<N0PerThread>{},
Number<N1PerThread>{}));
const auto blockwise_gemm =
BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2<BlockSize,
......@@ -321,15 +326,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
decltype(a_k_m0_m1_block_desc),
decltype(b_k_n0_n1_block_desc),
decltype(c_m0_m1_n0_n1_thread_desc),
MPerThread,
NPerThread,
M1PerThread,
N1PerThread,
KPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
MPerThread,
NPerThread>{};
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
M1PerThread,
N1PerThread>{};
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
......@@ -345,9 +350,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_m0_m1_n0_n1_thread_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
ThreadwiseDynamicTensorSliceSet_v1<
FloatAcc,
decltype(c_m0_m1_n0_n1_thread_desc),
Sequence<M0PerThread, M1PerThread, N0PerThread, N1PerThread>>{}
.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
......@@ -479,8 +485,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
// output: register to global memory
{
constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};
constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM10 * M1N1ThreadClusterM11>{};
constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN10 * M1N1ThreadClusterN11>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
......@@ -493,7 +499,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
FloatC,
decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
Sequence<M0PerThread, M1PerThread, N0PerThread, N1PerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
......
......@@ -125,14 +125,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 2;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM1PerThread = 2;
constexpr index_t GemmN1PerThread = 2;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmM1N1ThreadClusterM11 = 2;
constexpr index_t GemmM1N1ThreadClusterN11 = 2;
constexpr index_t GemmM1N1ThreadClusterM10 = 2;
constexpr index_t GemmM1N1ThreadClusterN10 = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 2;
......@@ -149,7 +149,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM11 = 2;
#elif 0
// cdata = 32, BlockSize = 64, 16x128x4
constexpr index_t BlockSize = 64;
......@@ -158,14 +158,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 2;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM1PerThread = 2;
constexpr index_t GemmN1PerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmM1N1ThreadClusterM11 = 2;
constexpr index_t GemmM1N1ThreadClusterN11 = 2;
constexpr index_t GemmM1N1ThreadClusterM10 = 2;
constexpr index_t GemmM1N1ThreadClusterN10 = 8;
constexpr index_t ThreadGemmDataPerReadM = 2;
constexpr index_t ThreadGemmDataPerReadN = 4;
......@@ -182,7 +182,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM11 = 2;
#elif 0
// cdata = 64, BlockSize = 64, 16x256x2
constexpr index_t BlockSize = 64;
......@@ -191,14 +191,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 2;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM1PerThread = 4;
constexpr index_t GemmN1PerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 1;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t GemmM1N1ThreadClusterM11 = 1;
constexpr index_t GemmM1N1ThreadClusterN11 = 2;
constexpr index_t GemmM1N1ThreadClusterM10 = 2;
constexpr index_t GemmM1N1ThreadClusterN10 = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
......@@ -215,7 +215,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM11 = 4;
#elif 0
// cdata = 64, BlockSize = 64, 16x256x4
constexpr index_t BlockSize = 64;
......@@ -224,14 +224,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM1PerThread = 4;
constexpr index_t GemmN1PerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 1;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t GemmM1N1ThreadClusterM11 = 2;
constexpr index_t GemmM1N1ThreadClusterN11 = 2;
constexpr index_t GemmM1N1ThreadClusterM10 = 1;
constexpr index_t GemmM1N1ThreadClusterN10 = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
......@@ -248,7 +248,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM11 = 4;
#elif 0
// cdata = 64, BlockSize = 128, 32x256x4
constexpr index_t BlockSize = 128;
......@@ -257,14 +257,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM1PerThread = 4;
constexpr index_t GemmN1PerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t GemmM1N1ThreadClusterM11 = 2;
constexpr index_t GemmM1N1ThreadClusterN11 = 2;
constexpr index_t GemmM1N1ThreadClusterM10 = 2;
constexpr index_t GemmM1N1ThreadClusterN10 = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
......@@ -281,7 +281,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM11 = 4;
#elif 0
// cdata = 64, BlockSize = 128, 32x256x8
constexpr index_t BlockSize = 128;
......@@ -290,14 +290,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM1PerThread = 4;
constexpr index_t GemmN1PerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 16;
constexpr index_t GemmM1N1ThreadClusterM11 = 2;
constexpr index_t GemmM1N1ThreadClusterN11 = 2;
constexpr index_t GemmM1N1ThreadClusterM10 = 2;
constexpr index_t GemmM1N1ThreadClusterN10 = 16;
constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
......@@ -314,7 +314,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM11 = 4;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x8
constexpr index_t BlockSize = 256;
......@@ -323,14 +323,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM1PerThread = 4;
constexpr index_t GemmN1PerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmM1N1ThreadClusterM11 = 2;
constexpr index_t GemmM1N1ThreadClusterN11 = 2;
constexpr index_t GemmM1N1ThreadClusterM10 = 8;
constexpr index_t GemmM1N1ThreadClusterN10 = 8;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
......@@ -344,7 +344,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM11 = 4;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x16
constexpr index_t BlockSize = 256;
......@@ -353,14 +353,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM1PerThread = 4;
constexpr index_t GemmN1PerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t GemmM1N1ThreadClusterM11 = 2;
constexpr index_t GemmM1N1ThreadClusterN11 = 2;
constexpr index_t GemmM1N1ThreadClusterM10 = 8;
constexpr index_t GemmM1N1ThreadClusterN10 = 8;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
......@@ -374,11 +374,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM11 = 4;
#endif
constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t GemmM1 =
GemmM1PerThread * GemmM1N1ThreadClusterM11 * GemmM1N1ThreadClusterM10;
constexpr index_t GemmN1 =
GemmN1PerThread * GemmM1N1ThreadClusterN11 * GemmM1N1ThreadClusterN10;
const auto descs =
#if 1
......@@ -409,13 +411,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmM1PerThread,
GemmN1PerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmM1N1ThreadClusterM10,
GemmM1N1ThreadClusterN10,
GemmM1N1ThreadClusterM11,
GemmM1N1ThreadClusterN11,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
......@@ -435,7 +437,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
1,
GemmCThreadTransferDstScalarPerVector_GemmM1,
GemmCThreadTransferDstScalarPerVector_GemmM11,
decltype(descs[I4]),
decltype(descs[I5]),
decltype(descs[I6]),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment