Commit ea8aa63f authored by Chao Liu's avatar Chao Liu
Browse files

adding bwd data v4r1 (multiple kernel launch)

parent ea484457
...@@ -205,7 +205,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated(InDesc, ...@@ -205,7 +205,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated(InDesc,
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>, float time =
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
......
...@@ -178,7 +178,7 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc, ...@@ -178,7 +178,7 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
float time = launch_kernel( float time = launch_and_time_kernel(
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<TInWei, gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<TInWei,
TOut, TOut,
accum_t, accum_t,
......
...@@ -17,18 +17,19 @@ ...@@ -17,18 +17,19 @@
#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp" #include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp" #include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp" #include "device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
using namespace ck; using namespace ck;
#if 0 #if 1
// 3x3 filter, 2x2 stride, 35x35 input // 3x3 filter, 2x2 stride, 35x35 input
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 1024; constexpr index_t C = 1024;
constexpr index_t HI = 35; constexpr index_t HI = 35;
constexpr index_t WI = 35; constexpr index_t WI = 35;
constexpr index_t K = 1024; constexpr index_t K = 128;
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
...@@ -251,8 +252,10 @@ int main(int argc, char* argv[]) ...@@ -251,8 +252,10 @@ int main(int argc, char* argv[])
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#elif 0 #elif 0
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
#elif 0
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
#else #else
device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#endif #endif
(in_nchw_desc, (in_nchw_desc,
in_nchw_device, in_nchw_device,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment