"...composable_kernel.git" did not exist on "ce87bcc759aa8c8889d217c348d80024219610f1"
Commit 7d2efb23 authored by Chao Liu's avatar Chao Liu
Browse files

update embed usage

parent 039b5e7e
......@@ -111,8 +111,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
Embed<Hi + InLeftPads::At(0) + InRightPads::At(0),
Sequence<Y, Ho>,
Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wi + InLeftPads::At(1) + InRightPads::At(1),
Sequence<X, Wo>,
Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......
......@@ -378,8 +378,12 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
in_n_c_hip_wip_global_desc,
make_tuple(UnMerge<Sequence<N0, N1>>{},
UnMerge<Sequence<C0, C1>>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
Embed<Hi + LeftPads::At(0) + RightPads::At(0),
Sequence<Y, Ho>,
Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Wi + LeftPads::At(1) + RightPads::At(1),
Sequence<X, Wo>,
Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
......
......@@ -13,8 +13,8 @@
#include "device_tensor.hpp"
#include "conv_common.hpp"
#include "host_conv_bwd_data.hpp"
//#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
//#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
int main(int argc, char* argv[])
......@@ -65,7 +65,22 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;
#elif 0
#elif 1
// 3x3 filter, 2x2 stride
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 1024;
constexpr index_t Y = 3;
constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 1
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
......@@ -365,7 +380,7 @@ int main(int argc, char* argv[])
#if 0
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif 0
#elif 1
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#else
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment