Commit 666bdad1 authored by root's avatar root
Browse files

clean

parent f744524e
...@@ -13,13 +13,13 @@ template <index_t BlockSize, ...@@ -13,13 +13,13 @@ template <index_t BlockSize,
typename Float, typename Float,
typename AccFloat, typename AccFloat,
index_t KPerBlock, index_t KPerBlock,
index_t HPerBlock, index_t HoPerBlock,
index_t WPerBlock, index_t WoPerBlock,
index_t CYXPerBlock, index_t EPerBlock,
index_t KPerThread, index_t KPerThread,
index_t HPerThread, index_t HPerThread,
index_t WPerThread, index_t WPerThread,
index_t CYXPerThread, index_t EPerThread,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM, typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM, typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockTransferSrcScalarPerVector_GemmK, index_t GemmABlockTransferSrcScalarPerVector_GemmK,
...@@ -123,10 +123,10 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -123,10 +123,10 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto CYX = C * Y * X; const auto E = C * Y * X;
if(!(K % KPerBlock == 0 && Ho % HPerBlock == 0 && Wo % WPerBlock == 0 && if(!(K % KPerBlock == 0 && Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0 &&
CYX % CYXPerBlock == 0)) E % EPerBlock == 0))
{ {
throw std::runtime_error("wrong! GEMM size no divisible"); throw std::runtime_error("wrong! GEMM size no divisible");
} }
...@@ -165,7 +165,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -165,7 +165,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
#if 1 #if 1
// GEMM // GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3< using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v2<
BlockSize, BlockSize,
Float, Float,
AccFloat, AccFloat,
...@@ -174,13 +174,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -174,13 +174,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(in_gemmk_n_ho_wo_global_desc), decltype(in_gemmk_n_ho_wo_global_desc),
decltype(out_gemmm_n_ho_wo_global_desc), decltype(out_gemmm_n_ho_wo_global_desc),
KPerBlock, KPerBlock,
HPerBlock, HoPerBlock,
WPerBlock, WoPerBlock,
CYXPerBlock, EPerBlock,
KPerThread, KPerThread,
HPerThread, HPerThread,
WPerThread, WPerThread,
CYXPerThread, EPerThread,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM, GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM, GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>, Sequence<1, 0>,
...@@ -205,11 +205,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -205,11 +205,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(a_k_m_global_move_slice_window_iterator_hack), decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>; decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (K / KPerBlock) * (Ho / HPerBlock) * (Wo / WPerBlock) * N; const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
const bool has_main_k_block_loop = (CYX + CYXPerBlock) / (2 * CYXPerBlock) > 1; const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
const bool has_double_tail_k_block_loop = (CYX / CYXPerBlock) % 2 == 0; const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;
index_t nrepeat = 100; index_t nrepeat = 100;
...@@ -225,6 +225,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -225,6 +225,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
for(index_t j = 0; j < nrepeat; ++j) for(index_t j = 0; j < nrepeat; ++j)
{ {
#if 0
{ {
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
...@@ -251,7 +252,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad ...@@ -251,7 +252,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
integral_constant<bool, true>{}, integral_constant<bool, true>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
} }
#if 0 #else
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
......
...@@ -57,10 +57,10 @@ struct ThreadwiseGemm_km_kn_mn_v3 ...@@ -57,10 +57,10 @@ struct ThreadwiseGemm_km_kn_mn_v3
constexpr auto H = BDesc{}.GetLength(I2); constexpr auto H = BDesc{}.GetLength(I2);
constexpr auto W = BDesc{}.GetLength(I3); constexpr auto W = BDesc{}.GetLength(I3);
constexpr auto CYX = ADesc{}.GetLength(I0); constexpr auto E = ADesc{}.GetLength(I0);
constexpr auto K = ADesc{}.GetLength(I1); constexpr auto K = ADesc{}.GetLength(I1);
static_for<0, CYX, 1>{}([&](auto e) { static_for<0, E, 1>{}([&](auto e) {
static_for<0, K, 1>{}([&](auto k) { static_for<0, K, 1>{}([&](auto k) {
static_for<0, H, 1>{}([&](auto h) { static_for<0, H, 1>{}([&](auto h) {
static_for<0, W, 1>{}([&](auto w) { static_for<0, W, 1>{}([&](auto w) {
......
...@@ -73,15 +73,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc ...@@ -73,15 +73,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
constexpr index_t KPerBlock = 16; constexpr index_t KPerBlock = 16;
constexpr index_t HPerBlock = 16; constexpr index_t HPerBlock = 16;
constexpr index_t WPerBlock = 16; constexpr index_t WPerBlock = 16;
constexpr index_t CYXPerBlock = 2 * 3 * 3; constexpr index_t CYXPerBlock = 4;
constexpr index_t KPerThread = 4; constexpr index_t KPerThread = 4;
constexpr index_t HPerThread = 2; constexpr index_t HPerThread = 2;
constexpr index_t WPerThread = 2; constexpr index_t WPerThread = 2;
constexpr index_t CYXPerThread = 2; constexpr index_t CYXPerThread = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<9, 16>; using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment