Commit dfbe7e20 authored by Jing Zhang
Browse files

added tuning params

parent b3a4d179
...@@ -13,8 +13,6 @@ namespace ck { ...@@ -13,8 +13,6 @@ namespace ck {
// GemmK = C * Y * X // GemmK = C * Y * X
template <index_t GemmMPerBlock, template <index_t GemmMPerBlock,
index_t GemmNPerBlock, index_t GemmNPerBlock,
index_t GemmM1,
index_t GemmN1,
typename... Wei, typename... Wei,
typename... In, typename... In,
typename... Out, typename... Out,
...@@ -108,9 +106,6 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad( ...@@ -108,9 +106,6 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0); assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0);
const auto GemmM0 = GemmM / Number<GemmM1>{};
const auto GemmN0 = GemmN / Number<GemmN1>{};
const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor( const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc, out_gemmm_gemmn_global_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmM / 8, 2, 4)), make_tuple(make_unmerge_transform(make_tuple(GemmM / 8, 2, 4)),
......
...@@ -21,13 +21,9 @@ template <index_t BlockSize, ...@@ -21,13 +21,9 @@ template <index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t KPerBlock, index_t KPerBlock,
index_t MPerThread, index_t MPerWave,
index_t NPerThread, index_t NPerWave,
index_t KPerThread, index_t KPerWave,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
typename ABlockTransferThreadSliceLengths_K_M, typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M, typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
...@@ -81,10 +77,7 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global, ...@@ -81,10 +77,7 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
throw std::runtime_error("wrong! GEMM size no divisible"); throw std::runtime_error("wrong! GEMM size no divisible");
} }
constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{}; if(!(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0))
constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};
if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0))
{ {
throw std::runtime_error("wrong! GEMM size no divisible"); throw std::runtime_error("wrong! GEMM size no divisible");
} }
...@@ -103,13 +96,9 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global, ...@@ -103,13 +96,9 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
MPerThread, MPerWave,
NPerThread, NPerWave,
KPerThread, KPerWave,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
ABlockTransferThreadSliceLengths_K_M, ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M, ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
...@@ -141,6 +130,9 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global, ...@@ -141,6 +130,9 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0; const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
std::cerr << "has_main_k_block_loop = " << has_main_k_block_loop
<< " has_double_tail_k_block_loop = " << has_double_tail_k_block_loop << std::endl;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = 0; float ave_time = 0;
......
...@@ -15,9 +15,7 @@ template <index_t BlockSize, ...@@ -15,9 +15,7 @@ template <index_t BlockSize,
class BBlockDesc, class BBlockDesc,
index_t MPerWave, index_t MPerWave,
index_t NPerWave, index_t NPerWave,
index_t KPerWave, index_t KPerWave>
index_t MWaves,
index_t NWaves>
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
{ {
...@@ -32,6 +30,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -32,6 +30,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static constexpr index_t WaveSize = 64; static constexpr index_t WaveSize = 64;
static constexpr index_t MPerBlock = ABlockDesc{}.GetLength(I1); // A is transposed
static constexpr index_t NPerBlock = BBlockDesc{}.GetLength(I1);
static constexpr index_t MWaves = MPerBlock / MPerWave;
static constexpr index_t NWaves = NPerBlock / NPerWave;
__device__ constexpr auto GetOutputLayout() const { return XdlopsGemm.GetOutputLayout(); } __device__ constexpr auto GetOutputLayout() const { return XdlopsGemm.GetOutputLayout(); }
__device__ constexpr auto GetNumBlks() const __device__ constexpr auto GetNumBlks() const
...@@ -90,11 +93,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 ...@@ -90,11 +93,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent"); "wrong! K dimension not consistent");
constexpr index_t M = ABlockDesc{}.GetLength(I1); // A is transposed static_assert(MPerWave * MWaves == MPerBlock, "GemmMWaves * MPerWave != M");
constexpr index_t N = BBlockDesc{}.GetLength(I1); static_assert(NPerWave * NWaves == NPerBlock, "GemmNWaves * NPerWave != N");
static_assert(MPerWave * MWaves == M, "GemmMWaves * MPerWave != M");
static_assert(NPerWave * NWaves == N, "GemmNWaves * NPerWave != N");
static_assert(BlockSize == MWaves * NWaves * WaveSize, static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n"); "BlockSize != MWaves * NWaves * WaveSize\n");
......
...@@ -108,13 +108,9 @@ template <index_t BlockSize, ...@@ -108,13 +108,9 @@ template <index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t KPerBlock, index_t KPerBlock,
index_t MPerThread, index_t MPerWave,
index_t NPerThread, index_t NPerWave,
index_t KPerThread, index_t KPerWave,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
typename ABlockTransferThreadSliceLengths_K_M, typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M, typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferThreadClusterArrangeOrder,
...@@ -144,9 +140,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -144,9 +140,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{}, Number<BBlockTransferDstScalarPerVector_N>{});
Number<MPerThread>{},
Number<NPerThread>{});
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
...@@ -209,9 +203,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -209,9 +203,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// lds max alignment // lds max alignment
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{}, constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{}, Number<BBlockTransferDstScalarPerVector_N>{});
Number<MPerThread>{}, // Number<MPerThread>{},
Number<NPerThread>{}); // Number<NPerThread>{});
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
...@@ -284,30 +278,28 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -284,30 +278,28 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register // register
// sanity check // sanity check
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 && static_assert(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0, "wrong!");
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!"); // constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
// constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster); // constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
// a_k_m_block_desc,
constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor( // make_tuple(
a_k_m_block_desc, // make_pass_through_transform(Number<KPerBlock>{}),
make_tuple( // make_unmerge_transform(make_tuple(
make_pass_through_transform(Number<KPerBlock>{}), // Number<MRepeat>{}, Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{}))),
make_unmerge_transform(make_tuple( // make_tuple(Sequence<0>{}, Sequence<1>{}),
Number<MRepeat>{}, Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{}))), // make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{})); // constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
// b_k_n_block_desc,
constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor( // make_tuple(
b_k_n_block_desc, // make_pass_through_transform(Number<KPerBlock>{}),
make_tuple( // make_unmerge_transform(make_tuple(
make_pass_through_transform(Number<KPerBlock>{}), // Number<NRepeat>{}, Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{}))),
make_unmerge_transform(make_tuple( // make_tuple(Sequence<0>{}, Sequence<1>{}),
Number<NRepeat>{}, Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{}))), // make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
// constexpr auto c_m0_m1_n0_n1_thread_desc = // constexpr auto c_m0_m1_n0_n1_thread_desc =
// make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( // make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
...@@ -318,12 +310,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -318,12 +310,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
FloatAB, FloatAB,
decltype(a_k_m_block_desc), decltype(a_k_m_block_desc),
decltype(b_k_n_block_desc), decltype(b_k_n_block_desc),
64, // MPerWave, MPerWave,
64, // NPerWave, NPerWave,
1, // KPerWave, KPerWave>{};
1, // MWaves,
1 // NWaves,
>{};
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
...@@ -481,7 +470,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -481,7 +470,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
constexpr index_t K1 = OutputLayout.N1(); constexpr index_t K1 = OutputLayout.N1();
constexpr index_t K2 = OutputLayout.M0(); constexpr index_t K2 = OutputLayout.M0();
static_assert(K0 == 4 && K1 == 2 && K2 == 4, ""); // static_assert(K0 == 4 && K1 == 2 && K2 == 4, "");
constexpr auto c_m0_m1_m2_n_thread_desc = constexpr auto c_m0_m1_m2_n_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2( make_dynamic_naive_tensor_descriptor_packed_v2(
...@@ -490,7 +479,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1 ...@@ -490,7 +479,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
constexpr index_t BlkSize = OutputLayout.GetBlkSize(); constexpr index_t BlkSize = OutputLayout.GetBlkSize();
constexpr index_t NumBlks = OutputLayout.GetNumBlks(); constexpr index_t NumBlks = OutputLayout.GetNumBlks();
static_assert(BlkSize == 16 && NumBlks == 4, ""); // static_assert(BlkSize == 16 && NumBlks == 4, "");
// force unrolling the output loop to get ride of scratches // force unrolling the output loop to get ride of scratches
static_for<0, NumBlks, 1>{}([&](auto i) { static_for<0, NumBlks, 1>{}([&](auto i) {
......
...@@ -84,14 +84,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -84,14 +84,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmNPerBlock = 64; constexpr index_t GemmNPerBlock = 64;
constexpr index_t GemmKPerBlock = 8; constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread = 4; constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerThread = 4; constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPerThread = 1; constexpr index_t GemmKPerWave = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 2;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>; using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>; using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
...@@ -107,14 +102,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -107,14 +102,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster; constexpr index_t GemmM1 = GemmMPerWave;
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster; constexpr index_t GemmN1 = GemmNPerWave;
const auto descs = const auto descs =
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad<GemmMPerBlock, transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad<GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock>(
GemmM1,
GemmN1>(
wei_k_c_y_x_desc, wei_k_c_y_x_desc,
in_n_c_hi_wi_desc, in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc, out_n_k_ho_wo_desc,
...@@ -138,13 +131,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw ...@@ -138,13 +131,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
GemmMPerBlock, GemmMPerBlock,
GemmNPerBlock, GemmNPerBlock,
GemmKPerBlock, GemmKPerBlock,
GemmMPerThread, GemmMPerWave,
GemmNPerThread, GemmNPerWave,
GemmKPerThread, GemmKPerWave,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM, GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM, GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>, Sequence<1, 0>,
......
...@@ -25,11 +25,11 @@ int main(int argc, char* argv[]) ...@@ -25,11 +25,11 @@ int main(int argc, char* argv[])
using namespace ck; using namespace ck;
#if 1 #if 1
constexpr index_t N = 4; constexpr index_t N = 256;
constexpr index_t C = 16; constexpr index_t C = 256;
constexpr index_t HI = 4; constexpr index_t HI = 16;
constexpr index_t WI = 4; constexpr index_t WI = 16;
constexpr index_t K = 64; constexpr index_t K = 256;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment