Commit ecaad8c0 authored by Jing Zhang's avatar Jing Zhang
Browse files

initialize forw/back merge

parent 11b848da
......@@ -5,7 +5,13 @@
#include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc, class Strides, class Dilations>
template <class T,
class InDesc,
class WeiDesc,
class OutDesc,
class Strides,
class Dilations,
index_t Direction>
void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
......@@ -13,6 +19,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
OutDesc,
Strides,
Dilations,
Number<Direction>,
Tensor<T>& out_nkhw,
index_t nrepeat)
{
......
......@@ -499,7 +499,7 @@ int main(int argc, char* argv[])
constexpr index_t HDilation = 1;
constexpr index_t WDilation = 1;
constexpr index_t Direction = 1; // 1: Forward; 2:Backward
constexpr index_t Direction = 2; // 1: Forward; 2:Backward
#if 0
constexpr index_t N = 32;
constexpr index_t C = 128;
......@@ -680,7 +680,7 @@ int main(int argc, char* argv[])
auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
in_nchw_desc, wei_kcyx_desc, strides, dilations);
auto wei_ckyx_back_desc = wei_kcyx_desc.ReorderGivenNew2Old(Sequence<1, 0, 2, 3>{});
// auto wei_ckyx_back_desc = wei_kcyx_desc.ReorderGivenNew2Old(Sequence<1, 0, 2, 3>{});
ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
......@@ -756,11 +756,12 @@ int main(int argc, char* argv[])
#endif
(out_nkhw_desc,
out_nkhw,
wei_ckyx_back_desc,
wei_kcyx_desc,
wei_kcyx,
in_nchw_desc,
strides,
dilations,
Number<Direction>{},
in_nchw_device,
nrepeat);
......
......@@ -71,7 +71,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
// to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3 out = ek
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto wei_k_c_1_1_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
......@@ -83,8 +83,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_1_1_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_1_1_global_desc.GetLength(I3);
static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");
......@@ -192,12 +192,11 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
#if 0
constexpr auto wei_e_k_global_desc =
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
#if 1 // backward
constexpr auto wei_e_k_global_desc = wei_k_c_1_1_global_desc.Unfold(I1, I3);
#else
constexpr auto wei_e_k_global_desc = make_ConstantMergedTensorDescriptor(
wei_k_c_y_x_global_desc, Sequence<1, 2, 3>{}, Sequence<0>{});
constexpr auto wei_e_k_global_desc =
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{})
#endif
// tensor descriptor in LDS, dst of blockwise copy
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment