Commit 81ec25b4 authored by Jing Zhang
Browse files

add backward without exchange in-out

parent ecaad8c0
...@@ -13,7 +13,7 @@ template <class T, ...@@ -13,7 +13,7 @@ template <class T,
class Dilations, class Dilations,
index_t Direction> index_t Direction>
void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& in_nchw, Tensor<T>& in_nchw,
WeiDesc, WeiDesc,
const Tensor<T>& wei_kcyx, const Tensor<T>& wei_kcyx,
OutDesc, OutDesc,
...@@ -56,8 +56,9 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, ...@@ -56,8 +56,9 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
constexpr index_t N1 = 2; constexpr index_t N1 = 2;
constexpr index_t N2 = 4; constexpr index_t N2 = 4;
constexpr index_t B = N * mod_conv::integer_divide_ceil(Ho, Strides::Get(I0)) * //constexpr index_t B = N * mod_conv::integer_divide_ceil(Ho, Strides::Get(I0)) *
mod_conv::integer_divide_ceil(Wo, Strides::Get(I1)) / (N1 * N2); //mod_conv::integer_divide_ceil(Wo, Strides::Get(I1)) / (N1 * N2);
constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
#if 1 #if 1
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -160,5 +161,8 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, ...@@ -160,5 +161,8 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
usleep(std::min(time * 1000, float(10000))); usleep(std::min(time * 1000, float(10000)));
} }
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); if(Direction == 1)
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
else
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
} }
...@@ -754,15 +754,15 @@ int main(int argc, char* argv[]) ...@@ -754,15 +754,15 @@ int main(int argc, char* argv[])
#elif 1 #elif 1
device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
#endif #endif
(out_nkhw_desc, (in_nchw_desc,
out_nkhw, in_nchw_device,
wei_kcyx_desc, wei_kcyx_desc,
wei_kcyx, wei_kcyx,
in_nchw_desc, out_nkhw_desc,
strides, strides,
dilations, dilations,
Number<Direction>{}, Number<Direction>{},
in_nchw_device, out_nkhw,
nrepeat); nrepeat);
#elif 1 #elif 1
......
...@@ -46,10 +46,12 @@ template <index_t GridSize, ...@@ -46,10 +46,12 @@ template <index_t GridSize,
index_t WeiBlockCopyDstDataPerWrite_K> index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
{ {
__device__ void Run(const Float* const __restrict__ p_in_global, __device__ void Run(Float* const __restrict__ p_conv_in_global,
const Float* const __restrict__ p_wei_global, const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const Float* const __restrict__ p_conv_out_global) const
{ {
auto p_in_global = p_conv_out_global;
auto p_out_global = p_conv_in_global;
// this is a mess // this is a mess
// TODO: find a more elegant way of specifying (or calculating) performance parameters // TODO: find a more elegant way of specifying (or calculating) performance parameters
static_assert(N2 == GemmNPerThreadSubC, "wrong!"); static_assert(N2 == GemmNPerThreadSubC, "wrong!");
...@@ -69,10 +71,10 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -69,10 +71,10 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr auto True = integral_constant<bool, true>{}; constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; constexpr auto in_n_c_h_w_global_desc = OutGlobalDesc{};
// to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3 out = ek // to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3 out = ek
constexpr auto wei_k_c_1_1_global_desc = WeiGlobalDesc{}; constexpr auto wei_k_c_1_1_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; constexpr auto out_n_k_h_w_global_desc = InGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0); constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1); constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
...@@ -125,15 +127,6 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -125,15 +127,6 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// constexpr auto in_lengths_new = Sequence<N0, N1, N2, Ho, Wo>{}; // constexpr auto in_lengths_new = Sequence<N0, N1, N2, Ho, Wo>{};
// constexpr auto in_strides_new =
// Sequence<in_n0_n1_n2_h_w_global_desc.GetStride(I0),
// in_n0_n1_n2_h_w_global_desc.GetStride(I1),
// in_n0_n1_n2_h_w_global_desc.GetStride(I2),
// in_n0_n1_n2_h_w_global_desc.GetStride(I3),
// in_n0_n1_n2_h_w_global_desc.GetStride(I4)>{};
// constexpr auto in_n0_n1_n2_h_w_new_global_desc =
// make_ConstantTensorDescriptor(in_lengths_new, in_strides_new);
constexpr auto in_n0_n1_n2_h_w_new_global_desc = in_n0_n1_n2_h_w_global_desc; constexpr auto in_n0_n1_n2_h_w_new_global_desc = in_n0_n1_n2_h_w_global_desc;
// batch descriptor for device memory // batch descriptor for device memory
...@@ -141,17 +134,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -141,17 +134,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{}) constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
.Slice(I3, Number<X>{}) .Slice(I3, Number<X>{})
.Extract(Sequence<1, 2, 3>{}); .Extract(Sequence<1, 2, 3>{});
// constexpr auto in_win_lengths_new = Sequence<in_c_y_x_global_desc.GetLength(I0),
// in_c_y_x_global_desc.GetLength(I1),
// in_c_y_x_global_desc.GetLength(I2)>{};
// constexpr auto in_win_strides_new =
// Sequence<in_c_y_x_global_desc.GetStride(I0),
// in_c_y_x_global_desc.GetStride(I1),
// in_c_y_x_global_desc.GetStride(I2)>{};
// constexpr auto in_c_y_x_new_global_desc =
// make_ConstantTensorDescriptor(in_win_lengths_new, in_win_strides_new);
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
...@@ -192,12 +176,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -192,12 +176,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// weight tensor // weight tensor
// tensor descriptor in device memory, src of blockwise copy // tensor descriptor in device memory, src of blockwise copy
#if 1 // backward
constexpr auto wei_e_k_global_desc = wei_k_c_1_1_global_desc.Unfold(I1, I3); constexpr auto wei_e_k_global_desc = wei_k_c_1_1_global_desc.Unfold(I1, I3);
#else
constexpr auto wei_e_k_global_desc =
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{})
#endif
// tensor descriptor in LDS, dst of blockwise copy // tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
...@@ -267,7 +246,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -267,7 +246,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// choose GEMM implementation here // choose GEMM implementation here
const auto run_blockwise_gemm = [&](auto... Xs) { const auto run_blockwise_gemm = [&](auto... Xs) {
#if 1 #if 0
return blockwise_gemm.Run(Xs...); return blockwise_gemm.Run(Xs...);
#else #else
return blockwise_gemm.Run_asm(Xs...); return blockwise_gemm.Run_asm(Xs...);
......
#pragma once #pragma once
template <class GridwiseConvolution, class T> template <class GridwiseConvolution, class T>
__global__ void run_gridwise_convolution(const T* const __restrict__ p_in_global, __global__ void run_gridwise_convolution(T* const __restrict__ p_in_global,
const T* const __restrict__ p_wei_global, const T* const __restrict__ p_wei_global,
T* const __restrict__ p_out_global) T* const __restrict__ p_out_global)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment