"configs/vscode:/vscode.git/clone" did not exist on "90c07a3dfd99f14bfbc5b43f59b96ce48fc4d0ec"
Commit 760a234f authored by Chao Liu
Browse files

Use StaticallyIndexedArray for the buffer in threadwise copy, in order to get rid of alloca in the IR

parent 70d06fa9
...@@ -173,39 +173,41 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -173,39 +173,41 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// GEMM // GEMM
#if 1 using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v1<
using gridwise_gemm = BlockSize,
GridwiseDynamicGemm_km_kn_mn_v1<BlockSize, Float,
Float, AccFloat,
AccFloat, InMemoryDataOperation::Set,
InMemoryDataOperation::Set, GemmMPerBlock,
GemmMPerBlock, GemmNPerBlock,
GemmNPerBlock, GemmKPerBlock,
GemmKPerBlock, GemmMPerThread,
GemmMPerThread, GemmNPerThread,
GemmNPerThread, GemmKPerThread,
GemmKPerThread, GemmMLevel0Cluster,
GemmMLevel0Cluster, GemmNLevel0Cluster,
GemmNLevel0Cluster, GemmMLevel1Cluster,
GemmMLevel1Cluster, GemmNLevel1Cluster,
GemmNLevel1Cluster, GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM, GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM, Sequence<1, 0>,
Sequence<1, 0>, Sequence<1, 0>,
Sequence<1, 0>, 0,
0, GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferSrcScalarPerVector_GemmK, GemmABlockTransferDstScalarPerVector_GemmM,
GemmABlockTransferDstScalarPerVector_GemmM, true, // move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>, Sequence<0, 1>,
Sequence<0, 1>, Sequence<0, 1>,
1, 1,
GemmBBlockTransferSrcScalarPerVector_GemmN, GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN, GemmBBlockTransferDstScalarPerVector_GemmN,
Sequence<2, 3, 0, 1>, false, // don't move back src coordinate after threadwise copy, which will be fused with
3, // MoveSrcSliceWindow() to save addr computation
GemmCThreadTransferDstScalarPerVector_GemmN1>; Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1>;
const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock); const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
...@@ -261,63 +263,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -261,63 +263,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
p_out_global, p_out_global,
integral_constant<bool, false>{}); integral_constant<bool, false>{});
} }
#else
using gridwise_gemm =
GridwiseDynamicGemm_km_kn_mn_v2<BlockSize,
Float,
AccFloat,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1>;
const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global);
#endif
} }
}; };
......
...@@ -51,26 +51,24 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor( ...@@ -51,26 +51,24 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(
} }
template <class InDesc, class WeiDesc, class OutDesc> template <class InDesc, class WeiDesc, class OutDesc>
constexpr std::size_t calculate_convolution_flops(InDesc, WeiDesc, OutDesc) constexpr std::size_t
calculate_convolution_flops(const InDesc& in_desc, const WeiDesc& wei_desc, const OutDesc& out_desc)
{ {
using namespace ck; using namespace ck;
constexpr auto wei_desc = WeiDesc{};
constexpr auto out_desc = OutDesc{};
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr index_t N = out_desc.GetLength(I0); const index_t N = out_desc.GetLength(I0);
constexpr index_t K = out_desc.GetLength(I1); const index_t K = out_desc.GetLength(I1);
constexpr index_t Ho = out_desc.GetLength(I2); const index_t Ho = out_desc.GetLength(I2);
constexpr index_t Wo = out_desc.GetLength(I3); const index_t Wo = out_desc.GetLength(I3);
constexpr index_t C = wei_desc.GetLength(I1); const index_t C = wei_desc.GetLength(I1);
constexpr index_t Y = wei_desc.GetLength(I2); const index_t Y = wei_desc.GetLength(I2);
constexpr index_t X = wei_desc.GetLength(I3); const index_t X = wei_desc.GetLength(I3);
return std::size_t(2) * N * K * Ho * Wo * C * Y * X; return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
} }
......
...@@ -577,7 +577,7 @@ int main(int argc, char* argv[]) ...@@ -577,7 +577,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 0 #elif 1
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment