Commit 937ad6c4 authored by root's avatar root
Browse files

inline, tuned

parent 5b242405
......@@ -165,7 +165,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
#if 1
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v2<
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3<
BlockSize,
Float,
AccFloat,
......@@ -189,13 +189,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
false, // don't move back src coordinate after threadwise copy
Sequence<3, 2, 1, 0>,
Sequence<0, 2, 3, 1>,
3,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<3, 2, 1, 0>,
Sequence<0, 2, 3, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(a_k_m_global_iterator_hacks),
......@@ -224,34 +224,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
for(index_t j = 0; j < nrepeat; ++j)
{
#if 0
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_n_ho_wo_global_desc),
const Float*,
decltype(out_gemmm_n_ho_wo_global_desc),
Float*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_n_ho_wo_global_desc,
p_in_global,
out_gemmm_n_ho_wo_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
#else
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
......@@ -360,7 +332,6 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
#endif
}
timer.End();
......
......@@ -62,16 +62,68 @@ struct ThreadwiseGemm_km_kn_mn_v3
static_for<0, E, 1>{}([&](auto e) {
static_for<0, K, 1>{}([&](auto k) {
static_for<0, H, 1>{}([&](auto h) {
static_for<0, W, 1>{}([&](auto w) {
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(e, k));
constexpr auto b_offset = BDesc{}.CalculateOffset(make_tuple(e, 0, h, w));
constexpr auto c_offset = CDesc{}.CalculateOffset(make_tuple(k, 0, h, w));
p_c[c_offset] +=
inner_product_with_conversion<FloatC>{}(p_a[a_offset], p_b[b_offset]);
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(e, k));
if constexpr(H == 2 && W == 2)
{
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 0));
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 1));
constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 0));
constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 1));
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 0));
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 1));
constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 0));
constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 1));
amd_assembly_outer_product_1x4(p_a[a_offset],
p_b[b_offset_0],
p_b[b_offset_1],
p_b[b_offset_2],
p_b[b_offset_3],
p_c[c_offset_0],
p_c[c_offset_1],
p_c[c_offset_2],
p_c[c_offset_3]);
}
else if constexpr(H == 4 && W == 1)
{
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 0));
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 0));
constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(e, 0, 2, 0));
constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(e, 0, 3, 0));
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 0));
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 0));
constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(k, 0, 2, 0));
constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(k, 0, 3, 0));
amd_assembly_outer_product_1x4(p_a[a_offset],
p_b[b_offset_0],
p_b[b_offset_1],
p_b[b_offset_2],
p_b[b_offset_3],
p_c[c_offset_0],
p_c[c_offset_1],
p_c[c_offset_2],
p_c[c_offset_3]);
}
else
{
static_for<0, H, 1>{}([&](auto h) {
static_for<0, W, 1>{}([&](auto w) {
constexpr auto b_offset =
BDesc{}.CalculateOffset(make_tuple(e, 0, h, w));
constexpr auto c_offset =
CDesc{}.CalculateOffset(make_tuple(k, 0, h, w));
p_c[c_offset] += inner_product_with_conversion<FloatC>{}(p_a[a_offset],
p_b[b_offset]);
});
});
});
}
});
});
}
......
......@@ -85,7 +85,7 @@
// experimental implementation
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
#endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
......
......@@ -68,19 +68,19 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
#endif
// cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 256;
constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 16;
constexpr index_t WoPerBlock = 16;
constexpr index_t EPerBlock = 4;
constexpr index_t EPerBlock = 2;
constexpr index_t KPerThread = 4;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 4;
constexpr index_t KPerThread = 16;
constexpr index_t HoPerThread = 4;
constexpr index_t WoPerThread = 1;
constexpr index_t EPerThread = 2;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<9, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
......@@ -89,7 +89,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(InDesc
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
constexpr auto conv_driver =
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment