Commit e790467d authored by ltqin's avatar ltqin
Browse files

fix tuning prometer for fp32

parent e4ed3740
...@@ -76,7 +76,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh ...@@ -76,7 +76,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0 #elif 1
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32 // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -94,19 +94,19 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh ...@@ -94,19 +94,19 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 2>; using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 2>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>; using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 2; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 2>; using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 2; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1 #elif 0
// [M, N, K0, K1] = [128, 128, 4, 8] // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmMPerBlock = 128;
...@@ -120,17 +120,17 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh ...@@ -120,17 +120,17 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
constexpr index_t MRepeat = 2; constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2; constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>; using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 4; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif #endif
...@@ -215,7 +215,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh ...@@ -215,7 +215,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
Sequence<0, 2, 1>, Sequence<0, 2, 1>,
1, 1,
GemmABlockTransferSrcScalarPerVector_GemmM, GemmABlockTransferSrcScalarPerVector_GemmM,
GemmABlockTransferDstScalarPerVector_GemmM, GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
...@@ -223,7 +223,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh ...@@ -223,7 +223,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
Sequence<0, 2, 1>, Sequence<0, 2, 1>,
1, 1,
GemmBBlockTransferSrcScalarPerVector_GemmN, GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN, GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1, 7, 5, 4, 6>, Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
7, 7,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment