Commit 4f827c76 authored by root's avatar root
Browse files

change vector dim of input to GemmN

parent 8a906f5f
...@@ -31,7 +31,7 @@ template <index_t BlockSize, ...@@ -31,7 +31,7 @@ template <index_t BlockSize,
index_t GemmABlockTransferDstScalarPerVector_GemmM, index_t GemmABlockTransferDstScalarPerVector_GemmM,
typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockTransferSrcScalarPerVector_GemmK, index_t GemmBBlockTransferSrcScalarPerVector_GemmN,
index_t GemmBBlockTransferDstScalarPerVector_GemmN, index_t GemmBBlockTransferDstScalarPerVector_GemmN,
index_t GemmCThreadTransferDstScalarPerVector_GemmM1> index_t GemmCThreadTransferDstScalarPerVector_GemmM1>
struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_chwn_cyxk_khwn_pad struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_chwn_cyxk_khwn_pad
...@@ -209,8 +209,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_chwn_cyxk_khwn_pad ...@@ -209,8 +209,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_chwn_cyxk_khwn_pad
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<1, 0>, Sequence<1, 0>,
Sequence<1, 0>, Sequence<1, 0>,
0, 1,
GemmBBlockTransferSrcScalarPerVector_GemmK, GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN, GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation // MoveSrcSliceWindow() to save addr computation
...@@ -701,7 +701,7 @@ template <index_t BlockSize, ...@@ -701,7 +701,7 @@ template <index_t BlockSize,
index_t GemmABlockTransferDstScalarPerVector_GemmM, index_t GemmABlockTransferDstScalarPerVector_GemmM,
typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockTransferSrcScalarPerVector_GemmK, index_t GemmBBlockTransferSrcScalarPerVector_GemmN,
index_t GemmBBlockTransferDstScalarPerVector_GemmN, index_t GemmBBlockTransferDstScalarPerVector_GemmN,
index_t GemmCThreadTransferDstScalarPerVector_GemmM1> index_t GemmCThreadTransferDstScalarPerVector_GemmM1>
struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_chwn_cyxk_khwn_1x1 struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_chwn_cyxk_khwn_1x1
...@@ -862,8 +862,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_chwn_cyxk_khwn_1x1 ...@@ -862,8 +862,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_chwn_cyxk_khwn_1x1
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<1, 0>, Sequence<1, 0>,
Sequence<1, 0>, Sequence<1, 0>,
0, 1,
GemmBBlockTransferSrcScalarPerVector_GemmK, GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN, GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation // MoveSrcSliceWindow() to save addr computation
......
...@@ -135,10 +135,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc ...@@ -135,10 +135,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 16>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 1; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 1; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 1;
...@@ -171,7 +171,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc ...@@ -171,7 +171,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>; using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2;
...@@ -204,7 +204,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc ...@@ -204,7 +204,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>; using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 2; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
...@@ -237,7 +237,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc ...@@ -237,7 +237,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>; using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
...@@ -270,7 +270,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc ...@@ -270,7 +270,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>; using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>; using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
...@@ -303,7 +303,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc ...@@ -303,7 +303,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>; using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>; using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
...@@ -333,7 +333,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc ...@@ -333,7 +333,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
...@@ -363,7 +363,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc ...@@ -363,7 +363,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>; using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8; constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
...@@ -396,7 +396,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc ...@@ -396,7 +396,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_chwn_cyxk_khwn(InDesc
GemmABlockTransferDstScalarPerVector_GemmM, GemmABlockTransferDstScalarPerVector_GemmM,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
GemmBBlockTransferSrcScalarPerVector_GemmK, GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN, GemmBBlockTransferDstScalarPerVector_GemmN,
GemmCThreadTransferDstScalarPerVector_GemmM1>{}; GemmCThreadTransferDstScalarPerVector_GemmM1>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment