Commit 7d2efb23 authored by Chao Liu's avatar Chao Liu
Browse files

update embed usage

parent 039b5e7e
...@@ -111,8 +111,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw ...@@ -111,8 +111,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
in_n_c_hip_wip_global_desc, in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{}, make_tuple(PassThrough<N>{},
PassThrough<C>{}, PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{}, Embed<Hi + InLeftPads::At(0) + InRightPads::At(0),
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}), Sequence<Y, Ho>,
Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wi + InLeftPads::At(1) + InRightPads::At(1),
Sequence<X, Wo>,
Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......
...@@ -378,8 +378,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl ...@@ -378,8 +378,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
in_n_c_hip_wip_global_desc, in_n_c_hip_wip_global_desc,
make_tuple(UnMerge<Sequence<N0, N1>>{}, make_tuple(UnMerge<Sequence<N0, N1>>{},
UnMerge<Sequence<C0, C1>>{}, UnMerge<Sequence<C0, C1>>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{}, Embed<Hi + LeftPads::At(0) + RightPads::At(0),
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}), Sequence<Y, Ho>,
Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wi + LeftPads::At(1) + RightPads::At(1),
Sequence<X, Wo>,
Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{})); make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
......
...@@ -13,8 +13,8 @@ ...@@ -13,8 +13,8 @@
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "conv_common.hpp" #include "conv_common.hpp"
#include "host_conv_bwd_data.hpp" #include "host_conv_bwd_data.hpp"
//#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp" #include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp" #include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp" #include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
int main(int argc, char* argv[]) int main(int argc, char* argv[])
...@@ -65,7 +65,22 @@ int main(int argc, char* argv[]) ...@@ -65,7 +65,22 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 3>; using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>; using RightPads = Sequence<0, 3>;
#elif 0 #elif 1
// 3x3 filter, 2x2 stride
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 1024;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
// 3x3, 34x34 // 3x3, 34x34
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -365,7 +380,7 @@ int main(int argc, char* argv[]) ...@@ -365,7 +380,7 @@ int main(int argc, char* argv[])
#if 0 #if 0
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif 0 #elif 1
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#else #else
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment