Commit d5de0968 authored by Jing Zhang's avatar Jing Zhang
Browse files

add dynamic support

parent 5ce317cb
...@@ -112,15 +112,15 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1 ...@@ -112,15 +112,15 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1
constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2; constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;
constexpr index_t CThreadTransferDstScalarPerVector_K = K1; constexpr index_t CThreadTransferDstScalarPerVector_K = 8;
#endif #endif
constexpr index_t InWeiVectorSize = C1; constexpr index_t InWeiVectorSize = 8;
const auto in_n_c0_hi_wi_c1_desc = make_naive_tensor_descriptor_packed( const auto in_n_c0_hi_wi_c1_desc =
make_tuple(N, C0, Hi, Wi, Number<C1 / InWeiVectorSize>{})); make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, C1 / InWeiVectorSize));
const auto wei_k_c0_y_x_c1_desc = make_naive_tensor_descriptor_packed( const auto wei_k_c0_y_x_c1_desc =
make_tuple(K, C0, Y, X, Number<C1 / InWeiVectorSize>{})); make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, C1 / InWeiVectorSize));
const auto out_n_k0_ho_wo_k1_desc = const auto out_n_k0_ho_wo_k1_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)); make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));
......
...@@ -61,7 +61,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -61,7 +61,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1); const auto C0 = in_n_c0_hi_wi_c1_global_desc.GetLength(I1);
const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2); const auto Hi = in_n_c0_hi_wi_c1_global_desc.GetLength(I2);
const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3); const auto Wi = in_n_c0_hi_wi_c1_global_desc.GetLength(I3);
const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4); // const auto C1 = in_n_c0_hi_wi_c1_global_desc.GetLength(I4);
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
...@@ -78,11 +78,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -78,11 +78,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
const auto ConvDilationH = conv_dilations[I0]; const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1]; const auto ConvDilationW = conv_dilations[I1];
const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{}; // const auto Hop = Number<(Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock>{};
const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{}; // const auto Wop = Number<(Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock>{};
// const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock; const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
// const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock; const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
const auto OutRightPadH = Hop - Ho; const auto OutRightPadH = Hop - Ho;
const auto OutRightPadW = Wop - Wo; const auto OutRightPadW = Wop - Wo;
...@@ -99,8 +99,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -99,8 +99,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
constexpr auto E2 = Number<E2_>{}; constexpr auto E2 = Number<E2_>{};
constexpr auto K2 = Number<K2_>{}; constexpr auto K2 = Number<K2_>{};
static_assert(E2 == C1, "");
const auto E0 = E / E1; const auto E0 = E / E1;
// weight tensor // weight tensor
...@@ -253,9 +251,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -253,9 +251,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
// clang-format on // clang-format on
static_assert(a_e0_e1_k_e2_grid_desc.IsKnownAtCompileTime(), ""); // static_assert(a_e0_e1_k_e2_grid_desc.IsKnownAtCompileTime(), "");
static_assert(b_e0_e1_n_ho_wo_e2_grid_desc.IsKnownAtCompileTime(), ""); // static_assert(b_e0_e1_n_ho_wo_e2_grid_desc.IsKnownAtCompileTime(), "");
static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), ""); // static_assert(c_k_n_hop_wop_grid_desc.IsKnownAtCompileTime(), "");
// GEMM // GEMM
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3< using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v3<
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <initializer_list> #include <initializer_list>
#include <cstdlib> #include <cstdlib>
#include <stdlib.h> #include <stdlib.h>
#include <half.hpp> //#include <half.hpp>
#include "config.hpp" #include "config.hpp"
#include "debug.hpp" #include "debug.hpp"
#include "print.hpp" #include "print.hpp"
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp" #include "device_convolution_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"
#define USE_DYNAMIC_MODE 0 #define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V5R1_NCHWC 1 #define USE_CONV_FWD_V5R1_NCHWC 1
enum ConvForwardAlgo enum ConvForwardAlgo
...@@ -37,35 +37,38 @@ int main(int argc, char* argv[]) ...@@ -37,35 +37,38 @@ int main(int argc, char* argv[])
#if USE_DYNAMIC_MODE #if USE_DYNAMIC_MODE
// dynamic mode // dynamic mode
if(argc != 21) if(argc != 23)
{ {
printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n"); printf("arg1 to 5: algo, do_verification, init_method, do_log, nrepeat\n");
printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); printf("rest: N, K0, K1, C0, C1, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n");
exit(1); exit(1);
} }
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[2])); const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[1]));
const bool do_verification = std::stoi(argv[3]); const bool do_verification = std::stoi(argv[2]);
const int init_method = std::stoi(argv[4]); const int init_method = std::stoi(argv[3]);
const bool do_log = std::stoi(argv[5]); const bool do_log = std::stoi(argv[4]);
const int nrepeat = std::stoi(argv[6]); const int nrepeat = std::stoi(argv[5]);
const index_t N = std::stoi(argv[7]); const index_t N = std::stoi(argv[6]);
const index_t K = std::stoi(argv[8]); const index_t K0 = std::stoi(argv[7]);
const index_t C = std::stoi(argv[9]); const index_t K1 = std::stoi(argv[8]);
const index_t Y = std::stoi(argv[10]); const index_t C0 = std::stoi(argv[9]);
const index_t X = std::stoi(argv[11]); const index_t C1 = std::stoi(argv[10]);
const index_t Hi = std::stoi(argv[12]); const index_t Y = std::stoi(argv[11]);
const index_t Wi = std::stoi(argv[13]); const index_t X = std::stoi(argv[12]);
const index_t Hi = std::stoi(argv[13]);
const index_t conv_stride_h = std::stoi(argv[14]); const index_t Wi = std::stoi(argv[14]);
const index_t conv_stride_w = std::stoi(argv[15]);
const index_t conv_dilation_h = std::stoi(argv[16]); const index_t conv_stride_h = std::stoi(argv[15]);
const index_t conv_dilation_w = std::stoi(argv[17]); const index_t conv_stride_w = std::stoi(argv[16]);
const index_t in_left_pad_h = std::stoi(argv[18]); const index_t conv_dilation_h = std::stoi(argv[17]);
const index_t in_left_pad_w = std::stoi(argv[19]); const index_t conv_dilation_w = std::stoi(argv[18]);
const index_t in_right_pad_h = std::stoi(argv[20]); const index_t in_left_pad_h = std::stoi(argv[19]);
const index_t in_right_pad_w = std::stoi(argv[21]); const index_t in_left_pad_w = std::stoi(argv[20]);
const index_t in_right_pad_h = std::stoi(argv[21]);
const index_t in_right_pad_w = std::stoi(argv[22]);
const index_t YEff = (Y - 1) * conv_dilation_h + 1; const index_t YEff = (Y - 1) * conv_dilation_h + 1;
const index_t XEff = (X - 1) * conv_dilation_w + 1; const index_t XEff = (X - 1) * conv_dilation_w + 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment