Commit ec381569 authored by Jing Zhang's avatar Jing Zhang
Browse files

add maxpool fusion

parent 0f276ac2
......@@ -875,7 +875,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
});
}
// bias
// Bias
{
constexpr auto bias_k0_k1_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<KPerThread>{}));
......@@ -976,7 +976,9 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
}
#endif
#if 0
// Resize_Add
if constexpr(add_type == 0)
{
constexpr auto HoPerThreadx2 = HoPerThread * 2;
constexpr auto WoPerThreadx2 = WoPerThread * 2;
......@@ -1069,7 +1071,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
InMemoryDataOperationEnum_t::Add,
1,
true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
make_multi_index(k_block_work_id,
......@@ -1088,6 +1090,97 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
d_global_buf,
d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks);
}
// MaxPool
else if constexpr(add_type == 1)
{
static_assert(HoPerThread % 2 == 0 && WoPerThread % 2 == 0, "");
constexpr auto HoPerThread_2 = HoPerThread / 2;
constexpr auto WoPerThread_2 = WoPerThread / 2;
constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{},
I1,
I1,
I1,
Number<HoPerThread_2>{},
I1,
I1,
Number<WoPerThread_2>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatC,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
true>
d_thread_buf;
#if 1
static_for<0, KPerThread, 1>{}([&](auto ki) {
static_for<0, HoPerThread_2, 1>{}([&](auto hi) {
static_for<0, WoPerThread_2, 1>{}([&](auto wi) {
constexpr index_t d_offset =
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.CalculateOffset(
make_tuple(0, ki, 0, 0, 0, hi, 0, 0, wi));
constexpr index_t c_offset_0 =
c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
make_tuple(ki, 0, hi * 2, wi * 2));
constexpr index_t c_offset_1 =
c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
make_tuple(ki, 0, hi * 2, wi * 2 + 1));
constexpr index_t c_offset_2 =
c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
make_tuple(ki, 0, hi * 2 + 1, wi * 2));
constexpr index_t c_offset_3 =
c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
make_tuple(ki, 0, hi * 2 + 1, wi * 2 + 1));
d_thread_buf(Number<d_offset>{}) = c_thread_buf[Number<c_offset_0>{}];
d_thread_buf(Number<d_offset>{}) = max(c_thread_buf[Number<c_offset_1>{}],
d_thread_buf(Number<d_offset>{}));
d_thread_buf(Number<d_offset>{}) = max(c_thread_buf[Number<c_offset_2>{}],
d_thread_buf(Number<d_offset>{}));
d_thread_buf(Number<d_offset>{}) = max(c_thread_buf[Number<c_offset_3>{}],
d_thread_buf(Number<d_offset>{}));
});
});
});
#endif
const index_t k_thread_data_on_global = k_thread_id * KPerThread;
constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{};
ThreadwiseTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc),
decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc),
Sequence<I1, KPerThread, I1, I1, I1, HoPerThread_2, I1, I1, WoPerThread_2>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
InMemoryDataOperationEnum_t::Set,
1,
true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
make_multi_index(k_block_work_id,
k_thread_data_on_global,
n_block_work_id,
ho_block_work_id,
ho_thread_id,
0,
wo_block_work_id,
wo_thread_id,
0))
.Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
d_thread_buf,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
d_global_buf,
d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks);
}
#endif
}
};
......
......@@ -14,6 +14,7 @@ include_directories(BEFORE
set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp)
set(CONV_FWD_DRIVER_OFFLINE_NCHWC_SOURCE src/conv_fwd_driver_offline_nchwc.cpp)
set(CONV_ADD_FWD_DRIVER_OFFLINE_NCHWC_SOURCE src/conv_add_fwd_driver_offline_nchwc.cpp)
set(CONV_MAXPOOL_FWD_DRIVER_OFFLINE_NCHWC_SOURCE src/conv_maxpool_fwd_driver_offline_nchwc.cpp)
set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp)
set(CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp)
set(GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp)
......@@ -21,6 +22,7 @@ set(GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp)
add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE})
add_executable(conv_fwd_driver_offline_nchwc ${CONV_FWD_DRIVER_OFFLINE_NCHWC_SOURCE})
add_executable(conv_add_fwd_driver_offline_nchwc ${CONV_ADD_FWD_DRIVER_OFFLINE_NCHWC_SOURCE})
add_executable(conv_maxpool_fwd_driver_offline_nchwc ${CONV_MAXPOOL_FWD_DRIVER_OFFLINE_NCHWC_SOURCE})
add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE})
add_executable(conv_wrw_driver_offline ${CONV_WRW_DRIVER_OFFLINE_SOURCE})
add_executable(gemm_driver_offline ${GEMM_DRIVER_OFFLINE_SOURCE})
......@@ -28,6 +30,7 @@ add_executable(gemm_driver_offline ${GEMM_DRIVER_OFFLINE_SOURCE})
target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor)
target_link_libraries(conv_fwd_driver_offline_nchwc PRIVATE host_tensor)
target_link_libraries(conv_add_fwd_driver_offline_nchwc PRIVATE host_tensor)
target_link_libraries(conv_maxpool_fwd_driver_offline_nchwc PRIVATE host_tensor)
target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor)
target_link_libraries(conv_wrw_driver_offline PRIVATE host_tensor)
target_link_libraries(gemm_driver_offline PRIVATE host_tensor)
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
ck::index_t activ_type,
typename InLengths,
typename WeiLengths,
typename MaxLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1(
const InLengths& in_n_c0_hi_wi_c1_lengths,
const WeiLengths& wei_k_c0_y_x_c1_lengths,
const MaxLengths& max_n_k0_hx_wx_k1_lengths,
const OutLengths& out_n_k0_ho_wo_k1_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c0_hi_wi_c1,
const Tensor<TInWei>& wei_k_c0_y_x_c1,
const Tensor<TOut>& bias_k0_k1,
Tensor<TOut>& out_n_k0_ho_wo_k1,
Tensor<TOut>& max_n_k0_hx_wx_k1,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
const auto N = out_n_k0_ho_wo_k1_lengths[I0];
const auto K0 = out_n_k0_ho_wo_k1_lengths[I1];
const auto Ho = out_n_k0_ho_wo_k1_lengths[I2];
const auto Wo = out_n_k0_ho_wo_k1_lengths[I3];
const auto K1 = out_n_k0_ho_wo_k1_lengths[I4];
const auto C0 = in_n_c0_hi_wi_c1_lengths[I1];
const auto Hi = in_n_c0_hi_wi_c1_lengths[I2];
const auto Wi = in_n_c0_hi_wi_c1_lengths[I3];
const auto C1 = in_n_c0_hi_wi_c1_lengths[I4];
const auto K = wei_k_c0_y_x_c1_lengths[I0];
const auto Y = wei_k_c0_y_x_c1_lengths[I2];
const auto X = wei_k_c0_y_x_c1_lengths[I3];
const auto Hx = max_n_k0_hx_wx_k1_lengths[I2];
const auto Wx = max_n_k0_hx_wx_k1_lengths[I3];
DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace());
DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace());
DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) *
out_n_k0_ho_wo_k1.mDesc.GetElementSpace());
DeviceMem max_n_k0_hx_wx_k1_device_buf(sizeof(TOut) *
max_n_k0_hx_wx_k1.mDesc.GetElementSpace());
in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data());
max_n_k0_hx_wx_k1_device_buf.ToDevice(max_n_k0_hx_wx_k1.mData.data());
constexpr index_t InWeiVectorSize = 8;
if(C1 % InWeiVectorSize != 0)
{
throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
}
#if 0
constexpr index_t BlockSize = 256;
constexpr index_t KPerBlock = 32;
constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 64;
constexpr index_t E1 = C0 * 9;
constexpr index_t E2 = 1;
constexpr index_t E1PerBlock = C0;
constexpr index_t KPerThread = 16;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 1;
using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, E2>;
using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;
constexpr index_t CThreadTransferDstScalarPerVector_K = K1;
#elif 1
constexpr auto BlockSize = 64;
constexpr auto KPerBlock = 16;
constexpr auto HoPerBlock = 8;
constexpr auto WoPerBlock = 32;
constexpr auto E1 = 2 * 9;
constexpr auto E2 = 1;
constexpr auto K2 = 2;
constexpr auto E1PerBlock = 2;
constexpr auto KPerThread = 16;
constexpr auto HoPerThread = 2;
constexpr auto WoPerThread = 2;
constexpr auto EPerThread = 1;
using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2 = Sequence<1, 9, 1, 1, E2>;
using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 =
Sequence<1, E1PerBlock, 1, KPerBlock, 1>;
constexpr auto ABlockTransferSrcScalarPerVector_E2 = E2;
constexpr auto ABlockTransferDstScalarPerVector_E2 = E2;
constexpr auto BThreadTransferSrcScalarPerVector_E2 = E2;
constexpr auto CThreadTransferDstScalarPerVector_K = 8;
#endif
const auto in_n_c0_hi_wi_c1_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, C1));
const auto wei_k_c0_y_x_c1_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, C1));
const auto max_n_k0_hx_wx_k1_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hx, Wx, K1));
const auto out_n_k0_ho_wo_k1_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));
static_assert(in_n_c0_hi_wi_c1_desc.IsKnownAtCompileTime(), "");
static_assert(wei_k_c0_y_x_c1_desc.IsKnownAtCompileTime(), "");
static_assert(max_n_k0_hx_wx_k1_desc.IsKnownAtCompileTime(), "");
static_assert(out_n_k0_ho_wo_k1_desc.IsKnownAtCompileTime(), "");
constexpr auto conv_driver =
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_maxpool<
BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type,
TAcc,
TOut,
E1,
E2,
K2,
KPerBlock,
HoPerBlock,
WoPerBlock,
E1PerBlock,
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
ABlockTransferSrcScalarPerVector_E2,
ABlockTransferDstScalarPerVector_E2,
BThreadTransferSrcScalarPerVector_E2,
CThreadTransferDstScalarPerVector_K,
activ_type>{};
for(int i = 0; i < 5; i++)
{
const auto ave_time =
conv_driver.Run(wei_k_c0_y_x_c1_desc,
in_n_c0_hi_wi_c1_desc,
out_n_k0_ho_wo_k1_desc,
max_n_k0_hx_wx_k1_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(max_n_k0_hx_wx_k1_device_buf.GetDeviceBuffer()),
nrepeat);
{
float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
max_n_k0_hx_wx_k1_device_buf.FromDevice(max_n_k0_hx_wx_k1.mData.data());
}
......@@ -300,7 +300,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Add,
InMemoryDataOperationEnum_t::Set,
decltype(a_e0_e1_k_e2_grid_desc),
decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
decltype(c_k_n_hop_wop_grid_desc),
......
#ifndef DRIVER_CONVOLUTION_MAXPOOL_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
#define DRIVER_CONVOLUTION_MAXPOOL_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NC0HWc1_KC0YXC1_NK0HWK1_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v2_add.hpp"
template <ck::index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
ck::index_t E1_,
ck::index_t E2_,
ck::index_t K2_,
ck::index_t KPerBlock,
ck::index_t HoPerBlock,
ck::index_t WoPerBlock,
ck::index_t E1PerBlock,
ck::index_t KPerThread,
ck::index_t HoPerThread,
ck::index_t WoPerThread,
ck::index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
typename ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
ck::index_t ABlockTransferSrcScalarPerVector_E2,
ck::index_t ABlockTransferDstScalarPerVector_E2,
ck::index_t BThreadTransferSrcScalarPerVector_E2,
ck::index_t CThreadTransferDstScalarPerVector_K,
ck::index_t activ_type>
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_maxpool
{
template <typename... Wei,
typename... In,
typename... Add,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ float Run(const ck::TensorDescriptor<Wei...>& wei_k_c0_y_x_c1_global_desc,
const ck::TensorDescriptor<In...>& in_n_c0_hi_wi_c1_global_desc,
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
const ck::TensorDescriptor<Add...>& max_n_k0_hx_wx_k1_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatC* __restrict__ p_bias_grid,
FloatC* __restrict__ p_c_grid,
FloatC* __restrict__ p_d_grid,
const int nrepeat) const
{
using namespace ck;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
const auto N = in_n_c0_hi_wi_c1_global_desc.GetLength(I0);
const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1);
const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2);
const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3);
// const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4);
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
const auto Hx = max_n_k0_hx_wx_k1_global_desc.GetLength(I2);
const auto Wx = max_n_k0_hx_wx_k1_global_desc.GetLength(I3);
const auto K = wei_k_c0_y_x_c1_global_desc.GetLength(I0);
const auto Y = wei_k_c0_y_x_c1_global_desc.GetLength(I2);
const auto X = wei_k_c0_y_x_c1_global_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
#if 1
const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{};
const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{};
#else
const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
#endif
const auto OutRightPadH = Hop - Ho;
const auto OutRightPadW = Wop - Wo;
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
const auto E = C0 * Y * X;
constexpr auto E1 = Number<E1_>{};
constexpr auto E2 = Number<E2_>{};
constexpr auto K2 = Number<K2_>{};
const auto E0 = E / E1;
// weight tensor
const auto a_e_k_e2_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C0 * Y * X, E2)),
make_tuple(make_pass_through_transform(K),
make_pass_through_transform(C0 * Y * X),
make_pass_through_transform(E2)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}));
const auto a_e0_e1_k_e2_grid_desc =
transform_tensor_descriptor(a_e_k_e2_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
make_pass_through_transform(K),
make_pass_through_transform(E2)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
// input tensor
const auto in_n_c0_hip_wip_e2_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C0, Hi, Wi, E2)),
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C0),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(E2)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_c0_y_ho_x_wo_e2_global_desc = transform_tensor_descriptor(
in_n_c0_hip_wip_e2_global_desc,
make_tuple(
make_pass_through_transform(N),
make_pass_through_transform(C0),
make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(E2)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{}));
const auto b_e_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
in_n_c0_y_ho_x_wo_e2_global_desc,
make_tuple(make_merge_transform(make_tuple(C0, Y, X)),
make_pass_through_transform(N),
make_pass_through_transform(Hop),
make_pass_through_transform(Wop),
make_pass_through_transform(E2)),
make_tuple(
Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto b_e0_e1_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
b_e_n_ho_wo_e2_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(E0, E1)),
make_pass_through_transform(N),
make_pass_through_transform(Hop),
make_pass_through_transform(Wop),
make_pass_through_transform(E2)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
// output tensor
const auto c_k_n_hop_wop_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_pad_transform(Ho, I0, OutRightPadH),
make_pad_transform(Wo, I0, OutRightPadW)),
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// max tensor
const auto d_k_n_hopx2_wopx2_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hx, Wx, K1)),
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_pad_transform(Hx, I0, Number<OutRightPadH/2>{}),
make_pad_transform(Wx, I0, Number<OutRightPadW/2>{})),
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
(E1 % E1PerBlock) == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
// clang-format off
// hack to control index calculation when iterating over a_e0_e1_k_e2_global tensor
constexpr auto a_e0_e1_k_e2_global_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto a_e0_e1_k_e2_global_move_slice_window_step_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
// hack to control index calculation when iterating over b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global tensor
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks =
make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})
);
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks =
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks =
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 1, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 2, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
// clang-format on
static_assert(a_e0_e1_k_e2_grid_desc.IsKnownAtCompileTime(), "");
static_assert(b_e0_e1_n_ho_wo_e2_grid_desc.IsKnownAtCompileTime(), "");
static_assert(d_k_n_hopx2_wopx2_grid_desc.IsKnownAtCompileTime(), "");
static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), "");
// GEMM
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3_add<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set,
decltype(a_e0_e1_k_e2_grid_desc),
decltype(b_e0_e1_n_ho_wo_e2_grid_desc),
decltype(c_k_n_hop_wop_grid_desc),
decltype(d_k_n_hopx2_wopx2_grid_desc),
E1,
E2,
K2,
KPerBlock,
HoPerBlock,
WoPerBlock,
E1PerBlock,
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
Sequence<2, 3, 0, 1, 4>,
Sequence<0, 1, 2, 3, 4>,
4,
ABlockTransferSrcScalarPerVector_E2,
ABlockTransferDstScalarPerVector_E2,
false, // don't move back src coordinate after threadwise copy
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, // E0, E1, N, H0, H1, H2, W0, W1, W2, E2
9,
BThreadTransferSrcScalarPerVector_E2,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, // K0, K1, N, H0, H1, I2, H2, W0, W1, I2, W2
1,
CThreadTransferDstScalarPerVector_K,
decltype(a_e0_e1_k_e2_global_step_hacks),
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks),
decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks),
decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks),
decltype(a_e0_e1_k_e2_global_move_slice_window_step_hack),
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_move_slice_window_step_hack),
activ_type,
1>;
const auto a_e0_e1_k0_k1_e2_grid_desc =
GridwiseGemm::MakeAE0E1K0K1E2GridDescriptor(a_e0_e1_k_e2_grid_desc);
const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
GridwiseGemm::MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(b_e0_e1_n_ho_wo_e2_grid_desc);
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc =
GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptor(d_k_n_hopx2_wopx2_grid_desc);
using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx =
decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
const bool has_main_e0_block_loop = E0 > 1;
std::cerr << "has_main_e0_block_loop = " << has_main_e0_block_loop << std::endl;
const auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
GridwiseGemm::MakeCBlockIdToKNHoWoBlockClusterAdaptor(c_k_n_hop_wop_grid_desc);
using CBlockIdToBlockClusterAdaptor_K_N_H_W =
decltype(c_blockid_to_k_n_h_w_block_cluster_adaptor);
float ave_time = 0;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
if(has_main_e0_block_loop)
{
const auto kernel = kernel_gemm_dlops_v2_add<
GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_bias_grid,
p_c_grid,
p_d_grid,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
c_blockid_to_k_n_h_w_block_cluster_adaptor);
}
else
{
const auto kernel = kernel_gemm_dlops_v2_add<
GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_bias_grid,
p_c_grid,
p_d_grid,
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
c_blockid_to_k_n_h_w_block_cluster_adaptor);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_e0_e1_k0_k1_e2_grid_desc_dev_buf(sizeof(AGridDesc_E0_E1_K0_K1_E2));
DeviceMem b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf(
sizeof(BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2));
DeviceMem c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf(
sizeof(CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2));
DeviceMem d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc_dev_buf(
sizeof(DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx));
DeviceMem c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf(
sizeof(CBlockIdToBlockClusterAdaptor_K_N_H_W));
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.ToDevice(&a_e0_e1_k0_k1_e2_grid_desc);
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.ToDevice(
&b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc);
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.ToDevice(
&c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc_dev_buf.ToDevice(
&d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.ToDevice(
&c_blockid_to_k_n_h_w_block_cluster_adaptor);
if(has_main_e0_block_loop)
{
const auto kernel = kernel_gemm_dlops_v2_add<
GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_bias_grid,
p_c_grid,
p_d_grid,
cast_pointer_to_constant_address_space(
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
else
{
const auto kernel = kernel_gemm_dlops_v2_add<
GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_E0_E1_K0_K1_E2>,
remove_reference_t<BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2>,
remove_reference_t<CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2>,
remove_reference_t<DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx>,
remove_reference_t<CBlockIdToBlockClusterAdaptor_K_N_H_W>,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_bias_grid,
p_c_grid,
p_d_grid,
cast_pointer_to_constant_address_space(
a_e0_e1_k0_k1_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_k_n_h_w_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
#endif
return ave_time;
}
};
#endif
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
//#include <half.hpp>
#include "config.hpp"
#include "debug.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "conv_common.hpp"
#include "host_conv.hpp"
#include "device_tensor.hpp"
#include "device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
#define USE_DYNAMIC_MODE 0
#define USE_CONV_FWD_V5R1_NCHWC 1
enum ConvForwardAlgo
{
V5R1NCHWC // 0
};
int main(int argc, char* argv[])
{
using namespace ck;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
#if USE_DYNAMIC_MODE
// dynamic mode
if(argc != 23)
{
printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n");
printf("rest: N, K0, K1, C0, C1, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
exit(1);
}
constexpr index_t activ_type = 0;
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[1]));
const bool do_verification = std::stoi(argv[2]);
const int init_method = std::stoi(argv[3]);
const bool do_log = std::stoi(argv[4]);
const int nrepeat = std::stoi(argv[5]);
const index_t N = std::stoi(argv[6]);
const index_t K0 = std::stoi(argv[7]);
const index_t K1 = std::stoi(argv[8]);
const index_t C0 = std::stoi(argv[9]);
const index_t C1 = std::stoi(argv[10]);
const index_t Y = std::stoi(argv[11]);
const index_t X = std::stoi(argv[12]);
const index_t Hi = std::stoi(argv[13]);
const index_t Wi = std::stoi(argv[14]);
const index_t conv_stride_h = std::stoi(argv[15]);
const index_t conv_stride_w = std::stoi(argv[16]);
const index_t conv_dilation_h = std::stoi(argv[17]);
const index_t conv_dilation_w = std::stoi(argv[18]);
const index_t in_left_pad_h = std::stoi(argv[19]);
const index_t in_left_pad_w = std::stoi(argv[20]);
const index_t in_right_pad_h = std::stoi(argv[21]);
const index_t in_right_pad_w = std::stoi(argv[22]);
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
const index_t XEff = (X - 1) * conv_dilation_w + 1;
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
#else
// static mode
if(argc < 6)
{
printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n");
exit(1);
}
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[1]));
const bool do_verification = std::stoi(argv[2]);
const int init_method = std::stoi(argv[3]);
const bool do_log = std::stoi(argv[4]);
const int nrepeat = std::stoi(argv[5]);
constexpr index_t activ_type = 0;
#if 0
constexpr auto N = Number<1>{};
constexpr auto Hi = Number<1080>{};
constexpr auto Wi = Number<1920>{};
constexpr auto Y = Number<3>{};
constexpr auto X = Number<3>{};
constexpr auto C0 = Number<2>{};
constexpr auto C1 = Number<8>{};
constexpr auto K1 = Number<8>{};
constexpr auto K0 = Number<8>{};
#elif 1
constexpr auto N = Number<1>{};
constexpr auto Hi = Number<540>{};
constexpr auto Wi = Number<960>{};
constexpr auto Y = Number<3>{};
constexpr auto X = Number<3>{};
constexpr auto C0 = Number<2>{};
constexpr auto C1 = Number<8>{};
constexpr auto K1 = Number<8>{};
constexpr auto K0 = Number<8>{};
#elif 0
constexpr auto N = Number<1>{};
constexpr auto Hi = Number<270>{};
constexpr auto Wi = Number<480>{};
constexpr auto Y = Number<3>{};
constexpr auto X = Number<3>{};
constexpr auto C0 = Number<2>{};
constexpr auto C1 = Number<8>{};
constexpr auto K1 = Number<8>{};
constexpr auto K0 = Number<8>{};
#elif 0
constexpr auto N = Number<1>{};
constexpr auto Hi = Number<135>{};
constexpr auto Wi = Number<240>{};
constexpr auto Y = Number<3>{};
constexpr auto X = Number<3>{};
constexpr auto C0 = Number<2>{};
constexpr auto C1 = Number<8>{};
constexpr auto K1 = Number<8>{};
constexpr auto K0 = Number<8>{};
#elif 1
constexpr auto N = Number<1>{};
constexpr auto Hi = Number<32>{};
constexpr auto Wi = Number<32>{};
constexpr auto Y = Number<3>{};
constexpr auto X = Number<3>{};
constexpr auto C0 = Number<2>{};
constexpr auto C1 = Number<8>{};
constexpr auto K1 = Number<8>{};
constexpr auto K0 = Number<8>{};
#endif
constexpr auto conv_stride_h = I1;
constexpr auto conv_stride_w = I1;
constexpr auto conv_dilation_h = I1;
constexpr auto conv_dilation_w = I1;
constexpr auto in_left_pad_h = I1;
constexpr auto in_left_pad_w = I1;
constexpr auto in_right_pad_h = I1;
constexpr auto in_right_pad_w = I1;
constexpr auto YEff = (Y - I1) * conv_dilation_h + I1;
constexpr auto XEff = (X - I1) * conv_dilation_w + I1;
constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1;
constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1;
constexpr auto Ho_2 = Number<Ho / 2>{};
constexpr auto Wo_2 = Number<Wo / 2>{};
#endif
#if 0
using in_data_t = float;
using acc_data_t = float;
using out_data_t = float;
#elif 1
using in_data_t = half_t;
using acc_data_t = float;
using out_data_t = half_t;
#elif 1
using in_data_t = int8_t;
using acc_data_t = int32_t;
using out_data_t = int8_t;
#endif
std::vector<std::size_t> in_lengths_host(5), wei_lengths_host(5), out_lengths_host(5),
max_lengths_host(5), bias_lengths_host(2);
in_lengths_host[0] = static_cast<std::size_t>(N);
in_lengths_host[1] = static_cast<std::size_t>(C0);
in_lengths_host[2] = static_cast<std::size_t>(Hi);
in_lengths_host[3] = static_cast<std::size_t>(Wi);
in_lengths_host[4] = static_cast<std::size_t>(C1);
wei_lengths_host[0] = static_cast<std::size_t>(K0 * K1);
wei_lengths_host[1] = static_cast<std::size_t>(C0);
wei_lengths_host[2] = static_cast<std::size_t>(Y);
wei_lengths_host[3] = static_cast<std::size_t>(X);
wei_lengths_host[4] = static_cast<std::size_t>(C1);
out_lengths_host[0] = static_cast<std::size_t>(N);
out_lengths_host[1] = static_cast<std::size_t>(K0);
out_lengths_host[2] = static_cast<std::size_t>(Ho);
out_lengths_host[3] = static_cast<std::size_t>(Wo);
out_lengths_host[4] = static_cast<std::size_t>(K1);
max_lengths_host[0] = static_cast<std::size_t>(N);
max_lengths_host[1] = static_cast<std::size_t>(K0);
max_lengths_host[2] = static_cast<std::size_t>(Ho_2);
max_lengths_host[3] = static_cast<std::size_t>(Wo_2);
max_lengths_host[4] = static_cast<std::size_t>(K1);
bias_lengths_host[0] = static_cast<std::size_t>(K0);
bias_lengths_host[1] = static_cast<std::size_t>(K1);
Tensor<in_data_t> in(in_lengths_host);
Tensor<in_data_t> wei(wei_lengths_host);
Tensor<out_data_t> bias(bias_lengths_host);
Tensor<out_data_t> out_device(out_lengths_host);
Tensor<out_data_t> out_host(out_lengths_host);
Tensor<in_data_t> max_device(max_lengths_host);
Tensor<in_data_t> max_host(max_lengths_host);
ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: ");
ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: ");
print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w));
print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w));
print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w));
print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w));
std::size_t num_thread = std::thread::hardware_concurrency();
switch(init_method)
{
case 0:
// no initialization
break;
case 1:
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
break;
case 2:
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
break;
case 3:
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
break;
case 4:
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
break;
case 5:
in.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
break;
default:
in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
auto gen_wei = [](auto... is) {
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
};
wei.GenerateTensorValue(gen_wei, num_thread);
}
bias.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
auto f_make_for_device_nchwc = [&]() {
const auto in_lengths_dev = make_tuple(N, C0, Hi, Wi, C1);
const auto wei_lengths_dev = make_tuple(K0 * K1, C0, Y, X, C1);
const auto max_lengths_dev = make_tuple(N, K0, Ho_2, Wo_2, K1);
const auto out_lengths_dev = make_tuple(N, K0, Ho, Wo, K1);
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
return make_tuple(in_lengths_dev,
wei_lengths_dev,
max_lengths_dev,
out_lengths_dev,
conv_strides_dev,
conv_dilations_dev,
in_left_pads_dev,
in_right_pads_dev);
};
#if USE_CONV_FWD_V5R1_NCHWC
if(algo == ConvForwardAlgo::V5R1NCHWC)
{
const auto tmp = f_make_for_device_nchwc();
device_convolution_maxpool_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1<in_data_t,
acc_data_t,
out_data_t,
activ_type>(
tmp[I0], // in_lengths_dev
tmp[I1], // wei_lengths_dev
tmp[I2], // max_lengths_dev
tmp[I3], // out_lengths_dev
tmp[I4], // conv_strides_dev
tmp[I5], // conv_dilations_dev
tmp[I6], // in_left_pads_dev
tmp[I7], // in_right_pads_dev
in,
wei,
bias,
out_device,
max_device,
nrepeat);
}
#endif
if(do_verification)
{
host_direct_convolution_maxpool_nchwc(in,
wei,
bias,
out_host,
max_host,
make_tuple(conv_stride_h, conv_stride_w),
make_tuple(conv_dilation_h, conv_dilation_w),
make_tuple(in_left_pad_h, in_left_pad_w),
make_tuple(in_right_pad_h, in_right_pad_w),
activ_type);
check_error(out_host, out_device);
check_error(max_host, max_device);
if(do_log)
{
// LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") <<
// std::endl;
LogRangeAsType<float>(std::cout << "max_host: ", max_host.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "max_device: ", max_device.mData, ",") << std::endl;
}
}
}
......@@ -226,6 +226,66 @@ void host_direct_convolution_add_nchwc(const Tensor<TIn>& in,
out_host.mDesc.GetLengths()[4])(std::thread::hardware_concurrency());
}
template <typename TIn,
typename TWei,
typename TOut,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void host_direct_convolution_maxpool_nchwc(const Tensor<TIn>& in,
const Tensor<TWei>& wei,
const Tensor<TOut>& bias,
Tensor<TOut>& out_host,
Tensor<TOut>& max_host,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads&,
const ck::index_t activ_type = 0)
{
using namespace ck;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
auto f_nchw = [&](auto n, auto k0, auto ho, auto wo, auto k1) {
double v = 0;
for(int c0 = 0; c0 < wei.mDesc.GetLengths()[1]; ++c0)
{
for(int c1 = 0; c1 < wei.mDesc.GetLengths()[4]; ++c1)
{
for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
{
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
{
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in.mDesc.GetLengths()[3])
{
v += static_cast<const double>(in(n, c0, hi, wi, c1)) *
static_cast<const double>(
wei(k0 * out_host.mDesc.GetLengths()[4] + k1, c0, y, x, c1));
}
}
}
}
}
v = activ(v, activ_type) + bias(k0, k1);
out_host(n, k0, ho, wo, k1) = v;
};
make_ParallelTensorFunctor(f_nchw,
out_host.mDesc.GetLengths()[0],
out_host.mDesc.GetLengths()[1],
out_host.mDesc.GetLengths()[2],
out_host.mDesc.GetLengths()[3],
out_host.mDesc.GetLengths()[4])(std::thread::hardware_concurrency());
}
template <typename TIn, typename TWei, typename TOut, typename InLeftPads, typename InRightPads>
void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
const Tensor<TWei>& wei_kcyx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment