Commit ecaad8c0 authored by Jing Zhang
Browse files

initialize forw/back merge

parent 11b848da
...@@ -5,7 +5,13 @@ ...@@ -5,7 +5,13 @@
#include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hip.hpp" #include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp" #include "gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc, class Strides, class Dilations> template <class T,
class InDesc,
class WeiDesc,
class OutDesc,
class Strides,
class Dilations,
index_t Direction>
void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& in_nchw, const Tensor<T>& in_nchw,
WeiDesc, WeiDesc,
...@@ -13,6 +19,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, ...@@ -13,6 +19,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
OutDesc, OutDesc,
Strides, Strides,
Dilations, Dilations,
Number<Direction>,
Tensor<T>& out_nkhw, Tensor<T>& out_nkhw,
index_t nrepeat) index_t nrepeat)
{ {
......
...@@ -499,7 +499,7 @@ int main(int argc, char* argv[]) ...@@ -499,7 +499,7 @@ int main(int argc, char* argv[])
constexpr index_t HDilation = 1; constexpr index_t HDilation = 1;
constexpr index_t WDilation = 1; constexpr index_t WDilation = 1;
constexpr index_t Direction = 1; // 1: Forward; 2:Backward constexpr index_t Direction = 2; // 1: Forward; 2:Backward
#if 0 #if 0
constexpr index_t N = 32; constexpr index_t N = 32;
constexpr index_t C = 128; constexpr index_t C = 128;
...@@ -680,7 +680,7 @@ int main(int argc, char* argv[]) ...@@ -680,7 +680,7 @@ int main(int argc, char* argv[])
auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor( auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
in_nchw_desc, wei_kcyx_desc, strides, dilations); in_nchw_desc, wei_kcyx_desc, strides, dilations);
auto wei_ckyx_back_desc = wei_kcyx_desc.ReorderGivenNew2Old(Sequence<1, 0, 2, 3>{}); // auto wei_ckyx_back_desc = wei_kcyx_desc.ReorderGivenNew2Old(Sequence<1, 0, 2, 3>{});
ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: "); ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
...@@ -756,11 +756,12 @@ int main(int argc, char* argv[]) ...@@ -756,11 +756,12 @@ int main(int argc, char* argv[])
#endif #endif
(out_nkhw_desc, (out_nkhw_desc,
out_nkhw, out_nkhw,
wei_ckyx_back_desc, wei_kcyx_desc,
wei_kcyx, wei_kcyx,
in_nchw_desc, in_nchw_desc,
strides, strides,
dilations, dilations,
Number<Direction>{},
in_nchw_device, in_nchw_device,
nrepeat); nrepeat);
......
...@@ -71,7 +71,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -71,7 +71,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
// to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3 out = ek // to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3 out = ek
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; constexpr auto wei_k_c_1_1_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0); constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
...@@ -83,8 +83,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -83,8 +83,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2); constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3); constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2); constexpr index_t Y = wei_k_c_1_1_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3); constexpr index_t X = wei_k_c_1_1_global_desc.GetLength(I3);
static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread"); static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");
...@@ -192,12 +192,11 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -192,12 +192,11 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// weight tensor // weight tensor
// tensor descriptor in device memory, src of blockwise copy // tensor descriptor in device memory, src of blockwise copy
#if 0 #if 1 // backward
constexpr auto wei_e_k_global_desc = constexpr auto wei_e_k_global_desc = wei_k_c_1_1_global_desc.Unfold(I1, I3);
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
#else #else
constexpr auto wei_e_k_global_desc = make_ConstantMergedTensorDescriptor( constexpr auto wei_e_k_global_desc =
wei_k_c_y_x_global_desc, Sequence<1, 2, 3>{}, Sequence<0>{}); wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{})
#endif #endif
// tensor descriptor in LDS, dst of blockwise copy // tensor descriptor in LDS, dst of blockwise copy
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment