Commit c564a605 authored by ltqin
Browse files

fix some tuning code

parent 573e1b64
...@@ -48,7 +48,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_ ...@@ -48,7 +48,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 1 #if 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16 // [M, N, K0, K1] = [128, 128, 4, 8] for fp32
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
...@@ -77,34 +77,6 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_ ...@@ -77,34 +77,6 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
constexpr index_t KBatch = 64; constexpr index_t KBatch = 64;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
// using vector load 4, so config's wo*ho must be a multiple of 4
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif #endif
const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment