Commit 1b4614be authored by Jing Zhang's avatar Jing Zhang
Browse files

static mode

parent d5de0968
......@@ -61,6 +61,13 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1
in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
constexpr index_t InWeiVectorSize = 8;
if(C1 % InWeiVectorSize != 0)
{
throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
}
#if 0
constexpr index_t BlockSize = 256;
......@@ -115,15 +122,17 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1
constexpr index_t CThreadTransferDstScalarPerVector_K = 8;
#endif
constexpr index_t InWeiVectorSize = 8;
const auto in_n_c0_hi_wi_c1_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, C1 / InWeiVectorSize));
make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, C1));
const auto wei_k_c0_y_x_c1_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, C1 / InWeiVectorSize));
make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, C1));
const auto out_n_k0_ho_wo_k1_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));
static_assert(in_n_c0_hi_wi_c1_desc.IsKnownAtCompileTime(), "");
static_assert(wei_k_c0_y_x_c1_desc.IsKnownAtCompileTime(), "");
static_assert(out_n_k0_ho_wo_k1_desc.IsKnownAtCompileTime(), "");
constexpr auto conv_driver =
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_outpad<
BlockSize,
......
......@@ -61,7 +61,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1);
const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2);
const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3);
// const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4);
//const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4);
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
......@@ -78,11 +78,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
// const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{};
// const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{};
#if 1
const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{};
const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{};
#else
const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
#endif
const auto OutRightPadH = Hop - Ho;
const auto OutRightPadW = Wop - Wo;
......@@ -119,8 +121,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
// input tensor
const auto in_n_c0_hip_wip_c1_global_desc = transform_tensor_descriptor(
in_n_c0_hi_wi_c1_global_desc,
const auto in_n_c0_hip_wip_e2_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C0, Hi, Wi, E2)),
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C0),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
......@@ -129,8 +131,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_c0_y_ho_x_wo_c1_global_desc = transform_tensor_descriptor(
in_n_c0_hip_wip_c1_global_desc,
const auto in_n_c0_y_ho_x_wo_e2_global_desc = transform_tensor_descriptor(
in_n_c0_hip_wip_e2_global_desc,
make_tuple(
make_pass_through_transform(N),
make_pass_through_transform(C0),
......@@ -142,7 +144,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6>{}));
const auto b_e_n_ho_wo_e2_grid_desc = transform_tensor_descriptor(
in_n_c0_y_ho_x_wo_c1_global_desc,
in_n_c0_y_ho_x_wo_e2_global_desc,
make_tuple(make_merge_transform(make_tuple(C0, Y, X)),
make_pass_through_transform(N),
make_pass_through_transform(Hop),
......@@ -251,9 +253,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
// clang-format on
// static_assert(a_e0_e1_k_e2_grid_desc.IsKnownAtCompileTime(), "");
// static_assert(b_e0_e1_n_ho_wo_e2_grid_desc.IsKnownAtCompileTime(), "");
// static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), "");
static_assert(a_e0_e1_k_e2_grid_desc.IsKnownAtCompileTime(), "");
static_assert(b_e0_e1_n_ho_wo_e2_grid_desc.IsKnownAtCompileTime(), "");
static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), "");
// GEMM
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
......
......@@ -15,7 +15,7 @@
#include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
#define USE_DYNAMIC_MODE 1
#define USE_DYNAMIC_MODE 0
#define USE_CONV_FWD_V5R1_NCHWC 1
enum ConvForwardAlgo
......@@ -90,7 +90,7 @@ int main(int argc, char* argv[])
const bool do_log = std::stoi(argv[4]);
const int nrepeat = std::stoi(argv[5]);
#if 1
#if 0
constexpr auto N = Number<1>{};
constexpr auto C0 = Number<2>{};
constexpr auto Hi = Number<1080>{};
......@@ -100,14 +100,16 @@ int main(int argc, char* argv[])
constexpr auto Y = Number<3>{};
constexpr auto X = Number<3>{};
constexpr auto K1 = Number<8>{};
#elif 0
#elif 1
constexpr auto N = Number<1>{};
constexpr auto C = Number<16>{};
constexpr auto C0 = Number<2>{};
constexpr auto Hi = Number<540>{};
constexpr auto Wi = Number<960>{};
constexpr auto K = Number<64>{};
constexpr auto C1 = Number<8>{};
constexpr auto K0 = Number<8>{};
constexpr auto Y = Number<3>{};
constexpr auto X = Number<3>{};
constexpr auto K1 = Number<8>{};
#elif 0
constexpr auto N = Number<1>{};
constexpr auto C = Number<16>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment