Commit 3e5e4cf7 authored by Jing Zhang

merge v5r1 nchwc

parent d88c2b2d
@@ -13,8 +13,7 @@ template <index_t BlockSize,
           typename ABlockDesc_E1_K_E2,
           typename BBlockDesc_E1_N_Ho_Wo_E2,
           typename CThreadDesc_K_N_Ho_Wo,
-          index_t EPerThreadLoop,
-          index_t ThreadGemmADataPerRead_E2>
+          index_t EPerThreadLoop>
 struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
 {
     static constexpr auto I0 = Number<0>{};
@@ -42,7 +41,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
     static constexpr auto HPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2);
     static constexpr auto WPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3);

-    static constexpr index_t KPerThreadLoop = KPerThread;
+    static constexpr index_t KPerThreadLoop = 2;

     static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
         make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadLoop>{}, Number<E2>{}));
@@ -162,14 +161,14 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
     using AThreadCopy =
         ThreadwiseTensorSliceTransfer_v4<FloatA,
-                                         FloatB,
+                                         FloatA,
                                          ABlockDesc_E1_K_E2,
                                          decltype(a_thread_mtx_),
                                          Sequence<EPerThreadLoop, KPerThreadLoop, E2>,
                                          Sequence<0, 1, 2>,
                                          2,
-                                         ThreadGemmADataPerRead_E2,
-                                         ThreadGemmADataPerRead_E2>;
+                                         E2,
+                                         E2>;

     AThreadCopy a_thread_copy_;
 };
...
@@ -19,6 +19,7 @@ template <typename GridwiseGemm,
           typename BGridDesc_E0_E1_N_Ho_Wo_E2,
           typename CGridDesc_K_N_Ho_Wo,
           typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo,
+          bool HasMainE0BlockLoop,
           bool HasMainE1BlockLoop,
           bool HasDoubleTailE1BlockLoop>
 __global__ void
@@ -46,6 +47,7 @@ __global__ void
                      a_e0_e1_k_e2_grid_desc,
                      b_e0_e1_n_ho_wo_e2_grid_desc,
                      c_k_n_ho_wo_grid_desc,
+                     integral_constant<bool, HasMainE0BlockLoop>{},
                      integral_constant<bool, HasMainE1BlockLoop>{},
                      integral_constant<bool, HasDoubleTailE1BlockLoop>{});
 }
@@ -60,6 +62,7 @@ template <typename GridwiseGemm,
           typename BGridDesc_E0_E1_N_Ho_Wo_E2,
           typename CGridDesc_K_N_Ho_Wo,
           typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo,
+          bool HasMainE0BlockLoop,
           bool HasMainE1BlockLoop,
           bool HasDoubleTailE1BlockLoop>
 __global__ void
@@ -96,6 +99,7 @@ __global__ void
                      a_e0_e1_k_e2_grid_desc,
                      b_e0_e1_n_ho_wo_e2_grid_desc,
                      c_k_n_ho_wo_grid_desc,
+                     integral_constant<bool, HasMainE0BlockLoop>{},
                      integral_constant<bool, HasMainE1BlockLoop>{},
                      integral_constant<bool, HasDoubleTailE1BlockLoop>{});
 }
@@ -138,7 +142,8 @@ template <index_t BlockSize,
           typename BGlobalStepHacks,
           typename CGlobalStepHacks,
           typename AGlobalMoveSliceWindowStepHacks,
-          typename BGlobalMoveSliceWindowStepHacks>
+          typename BGlobalMoveSliceWindowStepHacks,
+          index_t activ_type = 0>
 struct GridwiseGemmDlops_km_kn_mn_v3
 {
     static constexpr auto I0 = Number<0>{};
@@ -167,7 +172,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
         return a_block_space_size * sizeof(FloatAB);
     }

-    template <bool HasMainE1BlockLoop, bool HasDoubleTailE1BlockLoop>
+    template <bool HasMainE0BlockLoop, bool HasMainE1BlockLoop, bool HasDoubleTailE1BlockLoop>
     __device__ static void Run(const FloatAB* __restrict__ p_a_global,
                                const FloatAB* __restrict__ p_b_global,
                                FloatC* __restrict__ p_c_global,
@@ -175,6 +180,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
                                const AGlobalDesc_E0_E1_K_E2& a_e0_e1_k_e2_global_desc,
                                const BGlobalDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_global_desc,
                                const CGlobalDesc_K_N_Ho_Wo& c_k_n_ho_wo_global_desc,
+                               integral_constant<bool, HasMainE0BlockLoop>,
                                integral_constant<bool, HasMainE1BlockLoop>,
                                integral_constant<bool, HasDoubleTailE1BlockLoop>)
     {
@@ -253,8 +259,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
             decltype(a_e1_k_e2_block_desc),
             decltype(b_e1_n_ho_wo_e2_block_desc),
             decltype(c_k_n_ho_wo_thread_desc),
-            EPerThread,
-            ABlockTransferDstScalarPerVector_E2>{};
+            EPerThread>{};

         auto c_thread_mtx_index =
             blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id());
@@ -357,8 +362,6 @@ struct GridwiseGemmDlops_km_kn_mn_v3
                      true>
             b_thread_even_buf, b_thread_odd_buf;

-        constexpr auto HasMainE0BlockLoop = false;
-
         if constexpr(HasMainE0BlockLoop)
         {
             const auto E0 = b_e0_e1_n_ho_wo_e2_global_desc.GetLength(I0);
@@ -566,6 +569,27 @@ struct GridwiseGemmDlops_km_kn_mn_v3
             }
         }

+        // activ
+        {
+            static_for<0, c_k_n_ho_wo_thread_desc.GetElementSpaceSize(), 1>{}([&](auto i) {
+                if constexpr(activ_type == 1)
+                {
+                    c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : 0.0;
+                }
+                else if constexpr(activ_type == 2)
+                {
+                    FloatAcc x = 1.0 + exp(-c_thread_buf[i]);
+
+                    asm volatile("\n \
+                    v_rcp_f32 %0, %1 \n"
+                                 : "=v"(x)
+                                 : "0"(x));
+
+                    c_thread_buf(i) = x;
+                }
+            });
+        }
+
         // output: register to global memory
         {
             // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
...
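Note on the new epilogue: the `activ_type` template parameter fuses an activation into `GridwiseGemmDlops_km_kn_mn_v3` right before the register-to-global write-out. A minimal host-side model of what it computes, assuming `activ_type` 0/1/2 mean identity/ReLU/sigmoid and treating the `v_rcp_f32` reciprocal as an exact `1/x`:

```cpp
#include <cmath>

// Host-side sketch of the fused epilogue (not the kernel code itself):
// activ_type == 1 clamps negatives (ReLU); activ_type == 2 is a sigmoid,
// where the kernel replaces the final division with one v_rcp_f32.
inline float activ_epilogue(float c, int activ_type)
{
    if(activ_type == 1)
        return c >= 0.0f ? c : 0.0f; // ReLU
    if(activ_type == 2)
        return 1.0f / (1.0f + std::exp(-c)); // sigmoid
    return c; // activ_type == 0: pass-through
}
```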
@@ -71,95 +71,84 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
         constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
         constexpr auto c_origin_idx = to_multi_index(COriginIdx{});

-        constexpr index_t Vec = 2;
-
-        static_for<0, K, 1>{}([&](auto k) {
-            static_for<0, E1, 1>{}([&](auto e) {
-                static_for<0, Ho, Vec>{}([&](auto h) {
-                    static_for<0, Wo, Vec>{}([&](auto w) {
-                        vector_type<FloatA, E2> a_vec;
-                        vector_type<FloatB, E2> b0_vec;
-                        vector_type<FloatB, E2> b1_vec;
-                        vector_type<FloatB, E2> b2_vec;
-                        vector_type<FloatB, E2> b3_vec;
-
-                        static_for<0, E2, 1>{}([&](auto v) {
-                            constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
-                                a_origin_idx + make_tuple(e, k, v));
-
-                            constexpr index_t b0_offset =
-                                BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
-                                    b_origin_idx + make_tuple(e, 0, h, w, v));
-                            constexpr index_t b1_offset =
-                                BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
-                                    b_origin_idx + make_tuple(e, 0, h, w + 1, v));
-                            constexpr index_t b2_offset =
-                                BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
-                                    b_origin_idx + make_tuple(e, 0, h + 1, w, v));
-                            constexpr index_t b3_offset =
-                                BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
-                                    b_origin_idx + make_tuple(e, 0, h + 1, w + 1, v));
-
-                            a_vec.template AsType<FloatA>()(v) = a_buf[Number<a_offset>{}];
-
-                            b0_vec.template AsType<FloatB>()(v) = b_buf[Number<b0_offset>{}];
-                            b1_vec.template AsType<FloatB>()(v) = b_buf[Number<b1_offset>{}];
-                            b2_vec.template AsType<FloatB>()(v) = b_buf[Number<b2_offset>{}];
-                            b3_vec.template AsType<FloatB>()(v) = b_buf[Number<b3_offset>{}];
-                        });
-
-                        using a_vector_t = typename vector_type<FloatA, E2>::type;
-                        using b_vector_t = typename vector_type<FloatB, E2>::type;
-
-                        constexpr index_t c0_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
-                            c_origin_idx + make_tuple(k, 0, h, w));
-                        constexpr index_t c1_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
-                            c_origin_idx + make_tuple(k, 0, h, w + 1));
-                        constexpr index_t c2_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
-                            c_origin_idx + make_tuple(k, 0, h + 1, w));
-                        constexpr index_t c3_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
-                            c_origin_idx + make_tuple(k, 0, h + 1, w + 1));
-
-                        amd_assembly_outer_product_1x4(a_vec.template AsType<a_vector_t>()[I0],
-                                                       b0_vec.template AsType<b_vector_t>()[I0],
-                                                       b1_vec.template AsType<b_vector_t>()[I0],
-                                                       b2_vec.template AsType<b_vector_t>()[I0],
-                                                       b3_vec.template AsType<b_vector_t>()[I0],
-                                                       c_buf(Number<c0_offset>{}),
-                                                       c_buf(Number<c1_offset>{}),
-                                                       c_buf(Number<c2_offset>{}),
-                                                       c_buf(Number<c3_offset>{}));
-
-                        // inner_product<a_vector_t, b_vector_t, FloatC>(
-                        //     a_vec.template AsType<a_vector_t>()[I0],
-                        //     b0_vec.template AsType<b_vector_t>()[I0],
-                        //     c_buf(Number<c0_offset>{}));
-                        // inner_product<a_vector_t, b_vector_t, FloatC>(
-                        //     a_vec.template AsType<a_vector_t>()[I0],
-                        //     b1_vec.template AsType<b_vector_t>()[I0],
-                        //     c_buf(Number<c1_offset>{}));
-                        // inner_product<a_vector_t, b_vector_t, FloatC>(
-                        //     a_vec.template AsType<a_vector_t>()[I0],
-                        //     b2_vec.template AsType<b_vector_t>()[I0],
-                        //     c_buf(Number<c2_offset>{}));
-                        // inner_product<a_vector_t, b_vector_t, FloatC>(
-                        //     a_vec.template AsType<a_vector_t>()[I0],
-                        //     b3_vec.template AsType<b_vector_t>()[I0],
-                        //     c_buf(Number<c3_offset>{}));
-                    });
-                });
-            });
-        });
+#if 1
+        constexpr index_t SubHW = 2;
+
+        static_for<0, K, 1>{}([&](auto k) {
+            static_for<0, Ho, SubHW>{}([&](auto h) {
+                static_for<0, Wo, SubHW>{}([&](auto w) {
+                    static_for<0, E1, 1>{}([&](auto e1) {
+                        static_for<0, E2, 1>{}([&](auto e2) {
+                            constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
+                                a_origin_idx + make_tuple(e1, k, e2));
+
+                            constexpr index_t b0_offset =
+                                BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
+                                    b_origin_idx + make_tuple(e1, 0, h, w, e2));
+                            constexpr index_t b1_offset =
+                                BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
+                                    b_origin_idx + make_tuple(e1, 0, h, w + 1, e2));
+                            constexpr index_t b2_offset =
+                                BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
+                                    b_origin_idx + make_tuple(e1, 0, h + 1, w, e2));
+                            constexpr index_t b3_offset =
+                                BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
+                                    b_origin_idx + make_tuple(e1, 0, h + 1, w + 1, e2));
+
+                            constexpr index_t c0_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
+                                c_origin_idx + make_tuple(k, 0, h, w));
+                            constexpr index_t c1_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
+                                c_origin_idx + make_tuple(k, 0, h, w + 1));
+                            constexpr index_t c2_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
+                                c_origin_idx + make_tuple(k, 0, h + 1, w));
+                            constexpr index_t c3_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
+                                c_origin_idx + make_tuple(k, 0, h + 1, w + 1));
+
+                            amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
+                                                           b_buf[Number<b0_offset>{}],
+                                                           b_buf[Number<b1_offset>{}],
+                                                           b_buf[Number<b2_offset>{}],
+                                                           b_buf[Number<b3_offset>{}],
+                                                           c_buf(Number<c0_offset>{}),
+                                                           c_buf(Number<c1_offset>{}),
+                                                           c_buf(Number<c2_offset>{}),
+                                                           c_buf(Number<c3_offset>{}));
+                        });
+                    });
+                });
+            });
+        });
+#else
+        static_for<0, K, 1>{}([&](auto k) {
+            static_for<0, Ho, 1>{}([&](auto h) {
+                static_for<0, Wo, 1>{}([&](auto w) {
+                    static_for<0, E1, 1>{}([&](auto e1) {
+                        static_for<0, E2, 1>{}([&](auto e2) {
+                            constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
+                                a_origin_idx + make_tuple(e1, k, e2));
+                            constexpr index_t b_offset =
+                                BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
+                                    b_origin_idx + make_tuple(e1, 0, h, w, e2));
+                            constexpr index_t c_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
+                                c_origin_idx + make_tuple(k, 0, h, w));
+
+                            inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
+                                                                  b_buf[Number<b_offset>{}],
+                                                                  c_buf(Number<c_offset>{}));
+                        });
+                    });
+                });
+            });
+        });
+#endif
     }
 };
...
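Note on the rewritten inner loop: scalars now feed `amd_assembly_outer_product_1x4` directly instead of being packed into `vector_type` registers first, and the `e1`/`e2` reduction loops moved innermost. A scalar model of the 1x4 outer product (an assumption about the intrinsic's semantics, consistent with the `inner_product` fallback in the `#else` path):

```cpp
// Sketch: c_i += a * b_i for i = 0..3; one A value is reused across the
// four C points of a 2x2 (h, w) sub-tile, which is why Ho/Wo step by 2.
inline void outer_product_1x4_ref(float a,
                                  float b0, float b1, float b2, float b3,
                                  float& c0, float& c1, float& c2, float& c3)
{
    c0 += a * b0;
    c1 += a * b1;
    c2 += a * b2;
    c3 += a * b3;
}
```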
@@ -6,6 +6,8 @@
 template <typename TInWei,
           typename TAcc,
           typename TOut,
+          ck::index_t InWeiVectorSize,
+          ck::index_t activ_type,
           typename InLengths,
           typename WeiLengths,
           typename OutLengths,
@@ -48,21 +50,11 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
     const auto Y = wei_k_c_y_x_lengths[I2];
     const auto X = wei_k_c_y_x_lengths[I3];

-    constexpr auto InWeiVectorSize = 8;
-
-#if 1
     const auto C0 = C / Number<InWeiVectorSize>{};
     const auto C1 = Number<InWeiVectorSize>{};

     const auto K0 = K / Number<InWeiVectorSize>{};
     const auto K1 = Number<InWeiVectorSize>{};
-#else
-    const auto C0 = 1;
-    const auto C1 = C;
-
-    const auto K0 = 1;
-    const auto K1 = K;
-#endif

     Tensor<TInWei> in_n_c0_hi_wi_c1(
         HostTensorDescriptor(std::initializer_list<index_t>{N, C0, Hi, Wi, C1}));
@@ -92,31 +84,55 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
     wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());

     const auto in_n_c0_hi_wi_c1_desc =
-        make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, C1));
+        make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, I1));
     const auto wei_k_c0_y_x_c1_desc =
-        make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, C1));
+        make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, I1));
     const auto out_n_k0_ho_wo_k1_desc =
         make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));

-#if 1
-    // cdata = 64, BlockSize = 64, 16x8x32x4
+#if 0
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t KPerBlock  = 32;
+    constexpr index_t HoPerBlock = 8;
+    constexpr index_t WoPerBlock = 64;
+
+    constexpr index_t E1        = C0 * 9;
+    constexpr index_t E2        = 1;
+    constexpr index_t EPerBlock = C0;
+
+    constexpr index_t KPerThread  = 16;
+    constexpr index_t HoPerThread = 2;
+    constexpr index_t WoPerThread = 2;
+    constexpr index_t EPerThread  = 1;
+
+    using ABlockTransferThreadSliceLengths_E0_E1_K_E2   = Sequence<1, 9, 1, E2>;
+    using ABlockTransferThreadClusterLengths_E0_E1_K_E2 = Sequence<1, EPerBlock, KPerBlock, 1>;
+
+    constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
+    constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
+
+    constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;
+
+    constexpr index_t CThreadTransferDstScalarPerVector_K = K1;
+#elif 1
     constexpr index_t BlockSize = 64;

     constexpr index_t KPerBlock  = 16;
     constexpr index_t HoPerBlock = 8;
     constexpr index_t WoPerBlock = 32;

-    constexpr index_t E1        = 2 * 9;
-    constexpr index_t E2        = C1;
-    constexpr index_t EPerBlock = 2;
+    constexpr index_t E1        = C0 * 9;
+    constexpr index_t E2        = 1;
+    constexpr index_t EPerBlock = C0;

-    constexpr index_t KPerThread  = KPerBlock;
+    constexpr index_t KPerThread  = 16;
     constexpr index_t HoPerThread = 2;
     constexpr index_t WoPerThread = 2;
     constexpr index_t EPerThread  = 1;

     using ABlockTransferThreadSliceLengths_E0_E1_K_E2   = Sequence<1, 9, 1, E2>;
-    using ABlockTransferThreadClusterLengths_E0_E1_K_E2 = Sequence<1, EPerBlock, 16, 1>;
+    using ABlockTransferThreadClusterLengths_E0_E1_K_E2 = Sequence<1, EPerBlock, KPerBlock, 1>;

     constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
     constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
@@ -129,7 +145,7 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
     constexpr auto conv_driver =
         DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad<
             BlockSize,
-            TInWei,
+            typename vector_type<TInWei, InWeiVectorSize>::type,
             TAcc,
             TOut,
             E1,
@@ -147,7 +163,8 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
             ABlockTransferSrcScalarPerVector_E2,
             ABlockTransferDstScalarPerVector_E2,
             BThreadTransferSrcScalarPerVector_E2,
-            CThreadTransferDstScalarPerVector_K>{};
+            CThreadTransferDstScalarPerVector_K,
+            activ_type>{};

     for(int i = 0; i < 5; i++)
     {
@@ -160,8 +177,10 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
             conv_dilations,
             in_left_pads,
             in_right_pads,
-            static_cast<TInWei*>(wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
-            static_cast<TInWei*>(in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
+            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
+                wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
+            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
+                in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
             static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()),
             nrepeat);
...
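Note on the layout change: `InWeiVectorSize` is now a template parameter, and the driver is instantiated with `vector_type<TInWei, InWeiVectorSize>::type`, so the C1 packed channels travel as one vector element (hence the descriptors' last extent dropping from `C1` to `I1`). A sketch of the underlying NCHW → N-C0-H-W-C1 repacking, with illustrative names that are not the library's API:

```cpp
#include <cstddef>
#include <vector>

// Repack NCHW into N x C0 x H x W x C1 where C1 = InWeiVectorSize and
// C0 = C / C1 (sketch; assumes C is divisible by the vector size).
std::vector<float> nchw_to_nc0hwc1(const std::vector<float>& in,
                                   std::size_t N, std::size_t C,
                                   std::size_t H, std::size_t W, std::size_t C1)
{
    const std::size_t C0 = C / C1;
    std::vector<float> out(in.size());
    for(std::size_t n = 0; n < N; ++n)
        for(std::size_t c = 0; c < C; ++c)
            for(std::size_t h = 0; h < H; ++h)
                for(std::size_t w = 0; w < W; ++w)
                {
                    const std::size_t c0 = c / C1, c1 = c % C1;
                    out[(((n * C0 + c0) * H + h) * W + w) * C1 + c1] =
                        in[((n * C + c) * H + h) * W + w];
                }
    return out;
}
```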
@@ -25,7 +25,8 @@ template <ck::index_t BlockSize,
           ck::index_t ABlockTransferSrcScalarPerVector_E2,
           ck::index_t ABlockTransferDstScalarPerVector_E2,
           ck::index_t BThreadTransferSrcScalarPerVector_E2,
-          ck::index_t CThreadTransferDstScalarPerVector_K>
+          ck::index_t CThreadTransferDstScalarPerVector_K,
+          ck::index_t activ_type>
 struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
 {
     template <typename... Wei,
@@ -76,8 +77,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
         const auto ConvDilationH = conv_dilations[I0];
         const auto ConvDilationW = conv_dilations[I1];

-        const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
-        const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
+        const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{};
+        const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{};

         const auto OutRightPadH = Hop - Ho;
         const auto OutRightPadW = Wop - Wo;
@@ -169,8 +170,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
             make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
             make_tuple(make_merge_transform(make_tuple(K0, K1)),
                        make_pass_through_transform(N),
-                       make_pad_transform(Ho, 0, OutRightPadH),
-                       make_pad_transform(Wo, 0, OutRightPadW)),
+                       make_pad_transform(Ho, I0, OutRightPadH),
+                       make_pad_transform(Wo, I0, OutRightPadW)),
             make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
@@ -225,6 +226,10 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
                        Sequence<0, 0, 0, 0, 0>{},
                        Sequence<0, 0, 0, 0, 0>{}));

+        static_assert(a_e0_e1_k_e2_grid_desc.IsKnownAtCompileTime(), "");
+        static_assert(b_e0_e1_n_ho_wo_e2_grid_desc.IsKnownAtCompileTime(), "");
+        static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), "");
+
         // GEMM
         using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
             BlockSize,
@@ -265,7 +270,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
             decltype(b_e0_e1_n_ho_wo_e2_global_step_hacks),
             decltype(c_k_n_ho_wo_global_tensor_step_hacks),
             decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
-            decltype(b_e0_e1_n_ho_wo_e2_global_move_slice_window_step_hack)>;
+            decltype(b_e0_e1_n_ho_wo_e2_global_move_slice_window_step_hack),
+            activ_type>;

         using AGridDesc_E0_E1_K_E2       = decltype(a_e0_e1_k_e2_grid_desc);
         using BGridDesc_E0_E1_N_Ho_Wo_E2 = decltype(b_e0_e1_n_ho_wo_e2_grid_desc);
@@ -273,12 +279,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
         const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;

-        const bool has_main_k_block_loop         = (E1 + E1PerBlock) / (2 * E1PerBlock) > 1;
-        const bool has_double_tail_k_block_loop  = (E1 / E1PerBlock) % 2 == 0;
+        const bool has_main_e0_block_loop        = E0 > 1;
+        const bool has_main_e1_block_loop        = (E1 + E1PerBlock) / (2 * E1PerBlock) > 1;
+        const bool has_double_tail_e1_block_loop = (E1 / E1PerBlock) % 2 == 0;

-        std::cerr << "has_main_k_block_loop = " << has_main_k_block_loop
-                  << " has_double_tail_k_block_loop = " << has_double_tail_k_block_loop
+        std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop
+                  << "has_main_e1_block_loop = " << has_main_e1_block_loop
+                  << " has_double_tail_e1_block_loop = " << has_double_tail_e1_block_loop
                   << std::endl;

         const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor =
@@ -292,110 +299,31 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
         float ave_time = 0;

 #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
-        if(has_main_k_block_loop && has_double_tail_k_block_loop)
-        {
-            const auto kernel =
-                kernel_gemm_dlops_v2<GridwiseGemm,
-                                     FloatAB,
-                                     FloatC,
-                                     remove_reference_t<AGridDesc_E0_E1_K_E2>,
-                                     remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
-                                     remove_reference_t<CGridDesc_K_N_Ho_Wo>,
-                                     remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
-                                     true,
-                                     true>;
-
-            ave_time = launch_and_time_kernel(kernel,
-                                              nrepeat,
-                                              dim3(grid_size),
-                                              dim3(BlockSize),
-                                              0,
-                                              p_a_grid,
-                                              p_b_grid,
-                                              p_c_grid,
-                                              a_e0_e1_k_e2_grid_desc,
-                                              b_e0_e1_n_ho_wo_e2_grid_desc,
-                                              c_k_n_hop_wop_grid_desc,
-                                              c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
-        }
-        else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
-        {
-            const auto kernel =
-                kernel_gemm_dlops_v2<GridwiseGemm,
-                                     FloatAB,
-                                     FloatC,
-                                     remove_reference_t<AGridDesc_E0_E1_K_E2>,
-                                     remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
-                                     remove_reference_t<CGridDesc_K_N_Ho_Wo>,
-                                     remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
-                                     true,
-                                     false>;
-
-            ave_time = launch_and_time_kernel(kernel,
-                                              nrepeat,
-                                              dim3(grid_size),
-                                              dim3(BlockSize),
-                                              0,
-                                              p_a_grid,
-                                              p_b_grid,
-                                              p_c_grid,
-                                              a_e0_e1_k_e2_grid_desc,
-                                              b_e0_e1_n_ho_wo_e2_grid_desc,
-                                              c_k_n_hop_wop_grid_desc,
-                                              c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
-        }
-        else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
-        {
-            const auto kernel =
-                kernel_gemm_dlops_v2<GridwiseGemm,
-                                     FloatAB,
-                                     FloatC,
-                                     remove_reference_t<AGridDesc_E0_E1_K_E2>,
-                                     remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
-                                     remove_reference_t<CGridDesc_K_N_Ho_Wo>,
-                                     remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
-                                     false,
-                                     true>;
-
-            ave_time = launch_and_time_kernel(kernel,
-                                              nrepeat,
-                                              dim3(grid_size),
-                                              dim3(BlockSize),
-                                              0,
-                                              p_a_grid,
-                                              p_b_grid,
-                                              p_c_grid,
-                                              a_e0_e1_k_e2_grid_desc,
-                                              b_e0_e1_n_ho_wo_e2_grid_desc,
-                                              c_k_n_hop_wop_grid_desc,
-                                              c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
-        }
-        else
-        {
-            const auto kernel =
-                kernel_gemm_dlops_v2<GridwiseGemm,
-                                     FloatAB,
-                                     FloatC,
-                                     remove_reference_t<AGridDesc_E0_E1_K_E2>,
-                                     remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
-                                     remove_reference_t<CGridDesc_K_N_Ho_Wo>,
-                                     remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
-                                     false,
-                                     false>;
-
-            ave_time = launch_and_time_kernel(kernel,
-                                              nrepeat,
-                                              dim3(grid_size),
-                                              dim3(BlockSize),
-                                              0,
-                                              p_a_grid,
-                                              p_b_grid,
-                                              p_c_grid,
-                                              a_e0_e1_k_e2_grid_desc,
-                                              b_e0_e1_n_ho_wo_e2_grid_desc,
-                                              c_k_n_hop_wop_grid_desc,
-                                              c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
-        }
+        const auto kernel =
+            kernel_gemm_dlops_v2<GridwiseGemm,
+                                 FloatAB,
+                                 FloatC,
+                                 remove_reference_t<AGridDesc_E0_E1_K_E2>,
+                                 remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
+                                 remove_reference_t<CGridDesc_K_N_Ho_Wo>,
+                                 remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
+                                 has_main_e0_block_loop,
+                                 has_main_e1_block_loop,
+                                 has_double_tail_e1_block_loop>;
+
+        ave_time = launch_and_time_kernel(kernel,
+                                          nrepeat,
+                                          dim3(grid_size),
+                                          dim3(BlockSize),
+                                          0,
+                                          p_a_grid,
+                                          p_b_grid,
+                                          p_c_grid,
+                                          a_e0_e1_k_e2_grid_desc,
+                                          b_e0_e1_n_ho_wo_e2_grid_desc,
+                                          c_k_n_hop_wop_grid_desc,
+                                          c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
         DeviceMem a_e0_e1_k_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K_E2));
         DeviceMem b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf(sizeof(BGridDesc_E0_E1_N_Ho_Wo_E2));
@@ -409,130 +337,35 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
         c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.ToDevice(
             &c_blockid_to_k_n_ho_wo_block_cluster_adaptor);

-        if(has_main_k_block_loop && has_double_tail_k_block_loop)
-        {
-            const auto kernel =
-                kernel_gemm_dlops_v2<GridwiseGemm,
-                                     FloatAB,
-                                     FloatC,
-                                     remove_reference_t<AGridDesc_E0_E1_K_E2>,
-                                     remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
-                                     remove_reference_t<CGridDesc_K_N_Ho_Wo>,
-                                     remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
-                                     true,
-                                     true>;
-
-            ave_time = launch_and_time_kernel(
-                kernel,
-                nrepeat,
-                dim3(grid_size),
-                dim3(BlockSize),
-                0,
-                p_a_grid,
-                p_b_grid,
-                p_c_grid,
-                cast_pointer_to_constant_address_space(
-                    a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
-                cast_pointer_to_constant_address_space(
-                    b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
-                cast_pointer_to_constant_address_space(
-                    c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
-                cast_pointer_to_constant_address_space(
-                    c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
-        }
-        else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
-        {
-            const auto kernel =
-                kernel_gemm_dlops_v2<GridwiseGemm,
-                                     FloatAB,
-                                     FloatC,
-                                     remove_reference_t<AGridDesc_E0_E1_K_E2>,
-                                     remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
-                                     remove_reference_t<CGridDesc_K_N_Ho_Wo>,
-                                     remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
-                                     true,
-                                     false>;
-
-            ave_time = launch_and_time_kernel(
-                kernel,
-                nrepeat,
-                dim3(grid_size),
-                dim3(BlockSize),
-                0,
-                p_a_grid,
-                p_b_grid,
-                p_c_grid,
-                cast_pointer_to_constant_address_space(
-                    a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
-                cast_pointer_to_constant_address_space(
-                    b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
-                cast_pointer_to_constant_address_space(
-                    c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
-                cast_pointer_to_constant_address_space(
-                    c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
-        }
-        else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
-        {
-            const auto kernel =
-                kernel_gemm_dlops_v2<GridwiseGemm,
-                                     FloatAB,
-                                     FloatC,
-                                     remove_reference_t<AGridDesc_E0_E1_K_E2>,
-                                     remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
-                                     remove_reference_t<CGridDesc_K_N_Ho_Wo>,
-                                     remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
-                                     false,
-                                     true>;
-
-            ave_time = launch_and_time_kernel(
-                kernel,
-                nrepeat,
-                dim3(grid_size),
-                dim3(BlockSize),
-                0,
-                p_a_grid,
-                p_b_grid,
-                p_c_grid,
-                cast_pointer_to_constant_address_space(
-                    a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
-                cast_pointer_to_constant_address_space(
-                    b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
-                cast_pointer_to_constant_address_space(
-                    c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
-                cast_pointer_to_constant_address_space(
-                    c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
-        }
-        else
-        {
-            const auto kernel =
-                kernel_gemm_dlops_v2<GridwiseGemm,
-                                     FloatAB,
-                                     FloatC,
-                                     remove_reference_t<AGridDesc_E0_E1_K_E2>,
-                                     remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
-                                     remove_reference_t<CGridDesc_K_N_Ho_Wo>,
-                                     remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
-                                     false,
-                                     false>;
-
-            ave_time = launch_and_time_kernel(
-                kernel,
-                nrepeat,
-                dim3(grid_size),
-                dim3(BlockSize),
-                0,
-                p_a_grid,
-                p_b_grid,
-                p_c_grid,
-                cast_pointer_to_constant_address_space(
-                    a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
-                cast_pointer_to_constant_address_space(
-                    b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
-                cast_pointer_to_constant_address_space(
-                    c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
-                cast_pointer_to_constant_address_space(
-                    c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
-        }
+        const auto kernel =
+            kernel_gemm_dlops_v2<GridwiseGemm,
+                                 FloatAB,
+                                 FloatC,
+                                 remove_reference_t<AGridDesc_E0_E1_K_E2>,
+                                 remove_reference_t<BGridDesc_E0_E1_N_Ho_Wo_E2>,
+                                 remove_reference_t<CGridDesc_K_N_Ho_Wo>,
+                                 remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo>,
+                                 has_main_e0_block_loop,
+                                 has_main_e1_block_loop,
+                                 has_double_tail_e1_block_loop>;
+
+        ave_time = launch_and_time_kernel(
+            kernel,
+            nrepeat,
+            dim3(grid_size),
+            dim3(BlockSize),
+            0,
+            p_a_grid,
+            p_b_grid,
+            p_c_grid,
+            cast_pointer_to_constant_address_space(
+                a_e0_e1_k_e2_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(
+                b_e0_e1_n_ho_wo_e2_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(
+                c_k_n_hop_wop_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(
+                c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
 #endif

         return ave_time;
     }
...
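Note on the collapsed dispatch: the four `if/else` kernel-selection branches could be replaced by a single instantiation because `has_main_e0_block_loop` and friends are `const bool`s with constant initializers, which C++ allows as non-type template arguments once the descriptor lengths are compile-time (that is presumably what the new `static_assert(...IsKnownAtCompileTime(), "")` lines pin down). A standalone sketch of the rule:

```cpp
// A const bool with a constant initializer is usable in constant
// expressions, so it can flow straight into a template argument.
template <bool Flag>
constexpr int dispatch() { return Flag ? 1 : 0; }

int example()
{
    constexpr int E0 = 3;                       // stand-in for a descriptor length
    const bool has_main_e0_block_loop = E0 > 1; // constant-initialized
    return dispatch<has_main_e0_block_loop>();  // OK: constant expression
}
```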
@@ -34,7 +34,7 @@ enum ConvForwardAlgo
     V4R4NCHW,      // 0
     V4R4R2NHWC,    // 1
     V6R1NCHW,      // 2
-    V5R1NCHWc,     // 3
+    V5R1NCHWC,     // 3
     V5R1NHWC,      // 4
     V4R4R2XDLNCHW, // 5
     V4R4R4XDLNHWC  // 6
@@ -105,13 +105,57 @@ int main(int argc, char* argv[])
     const bool do_log = std::stoi(argv[5]);
     const int nrepeat = std::stoi(argv[6]);

+    constexpr index_t activ_type = 0;
+
+#if 1
+    constexpr auto N  = Number<1>{};
+    constexpr auto C  = Number<16>{};
+    constexpr auto Hi = Number<1080>{};
+    constexpr auto Wi = Number<1920>{};
+    constexpr auto K  = Number<64>{};
+    constexpr auto Y  = Number<3>{};
+    constexpr auto X  = Number<3>{};
+#elif 0
     constexpr auto N  = Number<1>{};
     constexpr auto C  = Number<16>{};
-    constexpr auto Hi = Number<1080>{};
-    constexpr auto Wi = Number<1920>{};
-    constexpr auto K  = Number<16>{};
+    constexpr auto Hi = Number<540>{};
+    constexpr auto Wi = Number<960>{};
+    constexpr auto K  = Number<64>{};
     constexpr auto Y  = Number<3>{};
     constexpr auto X  = Number<3>{};
+#elif 0
+    constexpr auto N  = Number<1>{};
+    constexpr auto C  = Number<16>{};
+    constexpr auto Hi = Number<270>{};
+    constexpr auto Wi = Number<480>{};
+    constexpr auto K  = Number<64>{};
+    constexpr auto Y  = Number<3>{};
+    constexpr auto X  = Number<3>{};
+#elif 0
+    constexpr auto N  = Number<1>{};
+    constexpr auto C  = Number<16>{};
+    constexpr auto Hi = Number<135>{};
+    constexpr auto Wi = Number<240>{};
+    constexpr auto K  = Number<64>{};
+    constexpr auto Y  = Number<3>{};
+    constexpr auto X  = Number<3>{};
+#elif 0
+    constexpr auto N  = Number<1>{};
+    constexpr auto C  = Number<16>{};
+    constexpr auto Hi = Number<1440>{};
+    constexpr auto Wi = Number<2560>{};
+    constexpr auto K  = Number<64>{};
+    constexpr auto Y  = Number<3>{};
+    constexpr auto X  = Number<3>{};
+#elif 0
+    constexpr auto N  = Number<1>{};
+    constexpr auto C  = Number<16>{};
+    constexpr auto Hi = Number<2160>{};
+    constexpr auto Wi = Number<3840>{};
+    constexpr auto K  = Number<64>{};
+    constexpr auto Y  = Number<3>{};
+    constexpr auto X  = Number<3>{};
+#endif

     constexpr auto conv_stride_h = I1;
     constexpr auto conv_stride_w = I1;
@@ -345,7 +389,7 @@ int main(int argc, char* argv[])
 #endif

 #if USE_CONV_FWD_V5R1_NCHWC
-    if(algo == ConvForwardAlgo::V5R1NCHWc)
+    if(algo == ConvForwardAlgo::V5R1NCHWC)
     {
         if(layout != ConvTensorLayout::NCHW)
         {
@@ -356,7 +400,9 @@ int main(int argc, char* argv[])
         device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw<in_data_t,
                                                                            acc_data_t,
-                                                                           out_data_t>(tmp[I0],
+                                                                           out_data_t,
+                                                                           8,
+                                                                           activ_type>(tmp[I0],
                                                                                        tmp[I1],
                                                                                        tmp[I2],
                                                                                        tmp[I3],
@@ -459,7 +505,8 @@ int main(int argc, char* argv[])
                                make_tuple(conv_dilation_h, conv_dilation_w),
                                make_tuple(in_left_pad_h, in_left_pad_w),
                                make_tuple(in_right_pad_h, in_right_pad_w),
-                               layout);
+                               layout,
+                               activ_type);

         check_error(out_host, out_device);
...
 #pragma once
 #include "host_tensor.hpp"

+template <typename T>
+inline auto activ(T v, const ck::index_t activ_type)
+{
+    switch(activ_type)
+    {
+    case 0: return v;
+    case 1: return (v >= 0 ? v : 0);
+    case 2: return (1 / (1 + exp(-v)));
+    default: throw std::runtime_error("unsupported activ type"); break;
+    }
+}
+
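The host `activ` reference above maps `activ_type` 0/1/2 to identity, ReLU, and sigmoid. A quick illustrative check of those semantics:

```cpp
#include <cassert>
#include <cmath>

void activ_examples()
{
    assert(activ(-1.5, 0) == -1.5);               // 0: identity
    assert(activ(-1.5, 1) == 0.0);                // 1: ReLU clamps negatives
    assert(std::abs(activ(0.0, 2) - 0.5) < 1e-9); // 2: sigmoid(0) == 0.5
}
```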
 template <typename TIn,
           typename TWei,
           typename TOut,
@@ -15,7 +27,8 @@ void host_direct_convolution(const Tensor<TIn>& in,
                              const ConvDilations& conv_dilations,
                              const InLeftPads& in_left_pads,
                              const InRightPads&,
-                             const ConvTensorLayout layout = ConvTensorLayout::NCHW)
+                             const ConvTensorLayout layout = ConvTensorLayout::NCHW,
+                             const ck::index_t activ_type = 0)
 {
     using namespace ck;
@@ -41,7 +54,7 @@ void host_direct_convolution(const Tensor<TIn>& in,
                 }
             }
         }
-        out(n, k, ho, wo) = v;
+        out(n, k, ho, wo) = activ(v, activ_type);
     };

     auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
@@ -63,7 +76,7 @@ void host_direct_convolution(const Tensor<TIn>& in,
                 }
             }
         }
-        out(n, ho, wo, k) = v;
+        out(n, ho, wo, k) = activ(v, activ_type);
     };

     if(layout == ConvTensorLayout::NCHW)
@@ -88,6 +101,115 @@ void host_direct_convolution(const Tensor<TIn>& in,
     }
 }
+template <typename TIn,
+          typename TWei,
+          typename TOut,
+          typename ConvStrides,
+          typename ConvDilations,
+          typename InLeftPads,
+          typename InRightPads>
+void host_direct_convolution_add(const Tensor<TIn>& in,
+                                 const Tensor<TWei>& wei,
+                                 const Tensor<TOut>& add,
+                                 Tensor<TOut>& out,
+                                 const ConvStrides& conv_strides,
+                                 const ConvDilations& conv_dilations,
+                                 const InLeftPads& in_left_pads,
+                                 const InRightPads&,
+                                 const ConvTensorLayout layout = ConvTensorLayout::NCHW,
+                                 const ck::index_t activ_type = 0)
+{
+    using namespace ck;
+
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+
+    auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
+        double v = 0;
+        for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
+        {
+            for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
+            {
+                int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
+                for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
+                {
+                    int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
+                    if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
+                       wi < in.mDesc.GetLengths()[3])
+                    {
+                        v += static_cast<const double>(in(n, c, hi, wi)) *
+                             static_cast<const double>(wei(k, c, y, x));
+                    }
+                }
+            }
+        }
+
+        index_t hox2 = ho * 2;
+        index_t wox2 = wo * 2;
+
+        v = activ(v, activ_type);
+
+        out(n, k, hox2, wox2)         = v + add(n, k, hox2, wox2);
+        out(n, k, hox2, wox2 + 1)     = v + add(n, k, hox2, wox2 + 1);
+        out(n, k, hox2 + 1, wox2)     = v + add(n, k, hox2 + 1, wox2);
+        out(n, k, hox2 + 1, wox2 + 1) = v + add(n, k, hox2 + 1, wox2 + 1);
+    };
+
+    auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
+        double v = 0;
+        for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c)
+        {
+            for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y)
+            {
+                int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
+                for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x)
+                {
+                    int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
+                    if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
+                       wi < in.mDesc.GetLengths()[2])
+                    {
+                        v += static_cast<const double>(in(n, hi, wi, c)) *
+                             static_cast<const double>(wei(k, y, x, c));
+                    }
+                }
+            }
+        }
+
+        index_t hox2 = ho * 2;
+        index_t wox2 = wo * 2;
+
+        v = activ(v, activ_type);
+
+        out(n, k, hox2, wox2)         = v + add(n, k, hox2, wox2);
+        out(n, k, hox2, wox2 + 1)     = v + add(n, k, hox2, wox2 + 1);
+        out(n, k, hox2 + 1, wox2)     = v + add(n, k, hox2 + 1, wox2);
+        out(n, k, hox2 + 1, wox2 + 1) = v + add(n, k, hox2 + 1, wox2 + 1);
+    };
+
+    if(layout == ConvTensorLayout::NCHW)
+    {
+        make_ParallelTensorFunctor(f_nchw,
+                                   out.mDesc.GetLengths()[0],
+                                   out.mDesc.GetLengths()[1],
+                                   out.mDesc.GetLengths()[2] / 2,
+                                   out.mDesc.GetLengths()[3] /
+                                       2)(std::thread::hardware_concurrency());
+    }
+    else if(layout == ConvTensorLayout::NHWC)
+    {
+        make_ParallelTensorFunctor(f_nhwc,
+                                   out.mDesc.GetLengths()[0],
+                                   out.mDesc.GetLengths()[1],
+                                   out.mDesc.GetLengths()[2] / 2,
+                                   out.mDesc.GetLengths()[3] /
+                                       2)(std::thread::hardware_concurrency());
+    }
+    else
+    {
+        throw std::runtime_error("wrong! not supported layout");
+    }
+}
+
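Note on the new reference: `host_direct_convolution_add` walks the output at half resolution and writes each activated conv value into a 2x2 tile plus a residual, i.e. convolution + activation + nearest-neighbor 2x upsample + elementwise add. The write pattern in isolation (sketch, indices illustrative):

```cpp
// One activated conv value v at (ho, wo) covers a 2x2 tile of out,
// with the residual "add" tensor applied per output element.
template <typename Tensor4D>
void upsample2x_add_nchw(double v, int n, int k, int ho, int wo,
                         const Tensor4D& add, Tensor4D& out)
{
    const int h2 = ho * 2, w2 = wo * 2;
    for(int dh = 0; dh < 2; ++dh)
        for(int dw = 0; dw < 2; ++dw)
            out(n, k, h2 + dh, w2 + dw) = v + add(n, k, h2 + dh, w2 + dw);
}
```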
 template <typename TIn, typename TWei, typename TOut, typename InLeftPads, typename InRightPads>
 void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
                                    const Tensor<TWei>& wei_kcyx,
...
@@ -29,6 +29,8 @@ INIT=$4
 LOG=$5
 REPEAT=$6

+./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT
+
 ################################################ layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
@@ -51,7 +53,7 @@ REPEAT=$6
 #./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1
-./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 16 3 3 1080 1920 1 1 1 1 1 1 1 1
+#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 16 3 3 1080 1920 1 1 1 1 1 1 1 1
 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1 16 16 1 1 16 16 1 1 1 1 0 0 0 0

 ################################################ layout algo verify init log repeat M___ N___ K___
...