Commit 760a234f authored by Chao Liu

Use StaticallyIndexedArray for the buffer in threadwise copy, in order to get rid of the alloca in the IR

parent 70d06fa9
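For readers outside this codebase, here is a minimal, self-contained sketch of the idea behind the commit message. The names below (StaticallyIndexedArraySketch, static_for_sketch, Number) are simplified stand-ins, not composable_kernel's actual API: a thread-local buffer that is only ever indexed with compile-time constants through a fully unrolled loop can be scalarized into registers, whereas a plain array indexed by a runtime loop variable is typically kept addressable and lowered to an alloca in the LLVM IR.

// Illustrative sketch only; simplified stand-ins for the library's types.
#include <cstddef>
#include <utility>

template <std::size_t I>
struct Number
{
    static constexpr std::size_t value = I;
};

// Runtime-indexed buffer: buf[i] with a runtime i keeps the array addressable,
// so the backend usually materializes it on the stack (an alloca in the IR).
template <typename T, std::size_t N, typename Src, typename Dst>
void copy_runtime_indexed(const Src& src, Dst& dst)
{
    T buf[N];
    for(std::size_t i = 0; i < N; ++i)
        buf[i] = src[i];
    for(std::size_t i = 0; i < N; ++i)
        dst[i] = buf[i];
}

// Statically indexed buffer: every access uses a compile-time index, so each
// element can be promoted to its own register and no stack object is needed.
template <typename T, std::size_t N>
struct StaticallyIndexedArraySketch
{
    T data[N];

    template <std::size_t I>
    constexpr T& At(Number<I>)
    {
        static_assert(I < N, "index out of range");
        return data[I];
    }
};

// Compile-time "loop": a C++17 fold expression that is fully unrolled.
template <std::size_t... Is, typename F>
void static_for_sketch(std::index_sequence<Is...>, F f)
{
    (f(Number<Is>{}), ...);
}

template <typename T, std::size_t N, typename Src, typename Dst>
void copy_statically_indexed(const Src& src, Dst& dst)
{
    StaticallyIndexedArraySketch<T, N> buf{};

    static_for_sketch(std::make_index_sequence<N>{},
                      [&](auto i) { buf.At(i) = src[decltype(i)::value]; });
    static_for_sketch(std::make_index_sequence<N>{},
                      [&](auto i) { dst[decltype(i)::value] = buf.At(i); });
}

The commit applies this idea to the intermediate buffer used by the threadwise copy, per the message above.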
@@ -173,9 +173,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
         make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

     // GEMM
-#if 1
-    using gridwise_gemm =
-        GridwiseDynamicGemm_km_kn_mn_v1<BlockSize,
+    using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v1<
+        BlockSize,
         Float,
         AccFloat,
         InMemoryDataOperation::Set,
@@ -196,6 +195,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
         0,
         GemmABlockTransferSrcScalarPerVector_GemmK,
         GemmABlockTransferDstScalarPerVector_GemmM,
+        true, // move back src coordinate after threadwise copy
         GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
         GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
         Sequence<0, 1>,
@@ -203,6 +203,8 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
         1,
         GemmBBlockTransferSrcScalarPerVector_GemmN,
         GemmBBlockTransferDstScalarPerVector_GemmN,
+        false, // don't move back src coordinate after threadwise copy, which will be fused with
+               // MoveSrcSliceWindow() to save addr computation
         Sequence<2, 3, 0, 1>,
         3,
         GemmCThreadTransferDstScalarPerVector_GemmN1>;
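The `true` / `false` arguments added above toggle whether the threadwise copy steps the source coordinate back to the slice origin after each copy, as the inline comments describe. A rough illustration follows, with a hypothetical 1-D coordinate standing in for the library's multi-dimensional, transform-backed coordinates (where every update costs real address arithmetic), of why skipping the move-back and fusing it into MoveSrcSliceWindow() saves one coordinate update per iteration:

// Hypothetical illustration only; the real coordinates go through tensor
// descriptor transforms, so each update involves genuine address computation.
struct Coord1D
{
    int offset;
};

// Flag = true: the copy walks forward over the slice, steps back to the slice
// origin, then MoveSrcSliceWindow() advances the window. Two updates per copy.
inline void iterate_with_move_back(Coord1D& src, int slice_len, int window_step)
{
    src.offset += slice_len;   // advanced while copying the slice
    src.offset -= slice_len;   // move back src coordinate after threadwise copy
    src.offset += window_step; // MoveSrcSliceWindow()
}

// Flag = false: the copy leaves the coordinate at the end of the slice and the
// back-step is folded into the window move. One update per copy.
inline void iterate_without_move_back(Coord1D& src, int slice_len, int window_step)
{
    src.offset += slice_len;               // advanced while copying the slice
    src.offset += window_step - slice_len; // fused back-step + MoveSrcSliceWindow()
}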
@@ -261,63 +263,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
                       p_out_global,
                       integral_constant<bool, false>{});
     }
-#else
-    using gridwise_gemm =
-        GridwiseDynamicGemm_km_kn_mn_v2<BlockSize,
-                                        Float,
-                                        AccFloat,
-                                        InMemoryDataOperation::Set,
-                                        GemmMPerBlock,
-                                        GemmNPerBlock,
-                                        GemmKPerBlock,
-                                        GemmMPerThread,
-                                        GemmNPerThread,
-                                        GemmKPerThread,
-                                        GemmMLevel0Cluster,
-                                        GemmNLevel0Cluster,
-                                        GemmMLevel1Cluster,
-                                        GemmNLevel1Cluster,
-                                        GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
-                                        GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
-                                        Sequence<1, 0>,
-                                        Sequence<1, 0>,
-                                        0,
-                                        GemmABlockTransferSrcScalarPerVector_GemmK,
-                                        GemmABlockTransferDstScalarPerVector_GemmM,
-                                        GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
-                                        GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
-                                        Sequence<0, 1>,
-                                        Sequence<0, 1>,
-                                        1,
-                                        GemmBBlockTransferSrcScalarPerVector_GemmN,
-                                        GemmBBlockTransferDstScalarPerVector_GemmN,
-                                        Sequence<2, 3, 0, 1>,
-                                        3,
-                                        GemmCThreadTransferDstScalarPerVector_GemmN1>;
-
-    const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
-
-    const auto kernel =
-        run_gridwise_operation<gridwise_gemm,
-                               decltype(wei_gemmk_gemmm_global_desc),
-                               const Float*,
-                               decltype(in_gemmk_gemmn_global_desc),
-                               const Float*,
-                               decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
-                               Float*>;
-
-    launch_kernel(kernel,
-                  dim3(GridSize),
-                  dim3(BlockSize),
-                  0,
-                  0,
-                  wei_gemmk_gemmm_global_desc,
-                  p_wei_global,
-                  in_gemmk_gemmn_global_desc,
-                  p_in_global,
-                  out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
-                  p_out_global);
-#endif
     }
 };
...
@@ -51,26 +51,24 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(
 }

 template <class InDesc, class WeiDesc, class OutDesc>
-constexpr std::size_t calculate_convolution_flops(InDesc, WeiDesc, OutDesc)
+constexpr std::size_t
+calculate_convolution_flops(const InDesc& in_desc, const WeiDesc& wei_desc, const OutDesc& out_desc)
 {
     using namespace ck;

-    constexpr auto wei_desc = WeiDesc{};
-    constexpr auto out_desc = OutDesc{};
-
     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};
     constexpr auto I2 = Number<2>{};
     constexpr auto I3 = Number<3>{};

-    constexpr index_t N = out_desc.GetLength(I0);
-    constexpr index_t K = out_desc.GetLength(I1);
-    constexpr index_t Ho = out_desc.GetLength(I2);
-    constexpr index_t Wo = out_desc.GetLength(I3);
-
-    constexpr index_t C = wei_desc.GetLength(I1);
-    constexpr index_t Y = wei_desc.GetLength(I2);
-    constexpr index_t X = wei_desc.GetLength(I3);
+    const index_t N = out_desc.GetLength(I0);
+    const index_t K = out_desc.GetLength(I1);
+    const index_t Ho = out_desc.GetLength(I2);
+    const index_t Wo = out_desc.GetLength(I3);
+
+    const index_t C = wei_desc.GetLength(I1);
+    const index_t Y = wei_desc.GetLength(I2);
+    const index_t X = wei_desc.GetLength(I3);

     return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
 }
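As a quick check on the returned formula: each of the N·K·Ho·Wo output elements accumulates over C·Y·X multiply-adds, and each multiply-add is counted as 2 FLOPs, which is where the leading std::size_t(2) comes from. The snippet below reproduces the arithmetic standalone with made-up layer sizes (not taken from this repository):

// Standalone check with made-up sizes: N=128, K=256, Ho=Wo=14, C=256, Y=X=3.
#include <cstdint>
#include <cstdio>

int main()
{
    const std::uint64_t N = 128, K = 256, Ho = 14, Wo = 14, C = 256, Y = 3, X = 3;
    const std::uint64_t flops = 2ull * N * K * Ho * Wo * C * Y * X;

    // Prints 29595009024 FLOPs, i.e. roughly 29.6 GFLOP for this layer.
    std::printf("%llu FLOPs (~%.1f GFLOP)\n",
                static_cast<unsigned long long>(flops), flops / 1.0e9);
    return 0;
}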
...
@@ -577,7 +577,7 @@ int main(int argc, char* argv[])
                                                            LeftPads{},
                                                            RightPads{},
                                                            nrepeat);
-#elif 0
+#elif 1
     device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
                                                                  in_nchw,
                                                                  wei_kcyx_desc,
...