Commit 666bdad1 authored by root
Browse files

clean

parent f744524e
......@@ -13,13 +13,13 @@ template <index_t BlockSize,
typename Float,
typename AccFloat,
index_t KPerBlock,
index_t HPerBlock,
index_t WPerBlock,
index_t CYXPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t EPerBlock,
index_t KPerThread,
index_t HPerThread,
index_t WPerThread,
index_t CYXPerThread,
index_t EPerThread,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockTransferSrcScalarPerVector_GemmK,
......@@ -123,10 +123,10 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto CYX = C * Y * X;
const auto E = C * Y * X;
if(!(K % KPerBlock == 0 && Ho % HPerBlock == 0 && Wo % WPerBlock == 0 &&
CYX % CYXPerBlock == 0))
if(!(K % KPerBlock == 0 && Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0 &&
E % EPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
......@@ -165,7 +165,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
#if 1
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3<
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v2<
BlockSize,
Float,
AccFloat,
......@@ -174,13 +174,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(in_gemmk_n_ho_wo_global_desc),
decltype(out_gemmm_n_ho_wo_global_desc),
KPerBlock,
HPerBlock,
WPerBlock,
CYXPerBlock,
HoPerBlock,
WoPerBlock,
EPerBlock,
KPerThread,
HPerThread,
WPerThread,
CYXPerThread,
EPerThread,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
......@@ -205,11 +205,11 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
decltype(a_k_m_global_move_slice_window_iterator_hack),
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
const auto GridSize = (K / KPerBlock) * (Ho / HPerBlock) * (Wo / WPerBlock) * N;
const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
const bool has_main_k_block_loop = (CYX + CYXPerBlock) / (2 * CYXPerBlock) > 1;
const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
const bool has_double_tail_k_block_loop = (CYX / CYXPerBlock) % 2 == 0;
const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;
index_t nrepeat = 100;
......@@ -225,6 +225,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
for(index_t j = 0; j < nrepeat; ++j)
{
#if 0
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
......@@ -251,7 +252,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
#if 0
#else
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
......
......@@ -57,10 +57,10 @@ struct ThreadwiseGemm_km_kn_mn_v3
constexpr auto H = BDesc{}.GetLength(I2);
constexpr auto W = BDesc{}.GetLength(I3);
constexpr auto CYX = ADesc{}.GetLength(I0);
constexpr auto K = ADesc{}.GetLength(I1);
constexpr auto E = ADesc{}.GetLength(I0);
constexpr auto K = ADesc{}.GetLength(I1);
static_for<0, CYX, 1>{}([&](auto e) {
static_for<0, E, 1>{}([&](auto e) {
static_for<0, K, 1>{}([&](auto k) {
static_for<0, H, 1>{}([&](auto h) {
static_for<0, W, 1>{}([&](auto w) {
......
......@@ -73,15 +73,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
constexpr index_t KPerBlock = 16;
constexpr index_t HPerBlock = 16;
constexpr index_t WPerBlock = 16;
constexpr index_t CYXPerBlock = 2 * 3 * 3;
constexpr index_t CYXPerBlock = 4;
constexpr index_t KPerThread = 4;
constexpr index_t HPerThread = 2;
constexpr index_t WPerThread = 2;
constexpr index_t CYXPerThread = 2;
constexpr index_t CYXPerThread = 4;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<9, 16>;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment