Commit 81ec25b4 authored by Jing Zhang

add backward data pass without exchanging in/out tensors at the call site

parent ecaad8c0
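
The backward-data pass reuses the forward implicit-GEMM kernel: instead of exchanging the input/output tensors at the call site (as the previous code did), the kernel now swaps the pointer roles internally, which is why `in_nchw` and `p_in_global` lose their `const`. A minimal sketch of the idea; all names here are illustrative, not the repo's API:

```cpp
// Minimal sketch of the pointer-role swap used for the backward pass.
// run_conv, Direction, and all buffer names are illustrative, not the
// repo's API; the real kernel does this swap inside Run (see below).
#include <cstddef>

enum class Direction { Forward, BackwardData };

void run_conv(Direction dir, float* in, const float* wei, float* out, std::size_t n)
{
    // For backward data the roles flip: the GEMM reads the (forward)
    // output buffer and writes into the (forward) input buffer, which
    // is why `in` can no longer be const.
    const float* gemm_in  = (dir == Direction::Forward) ? in : out;
    float*       gemm_out = (dir == Direction::Forward) ? out : in;
    // ... launch the implicit-GEMM kernel with (gemm_in, wei, gemm_out) ...
    (void)gemm_in; (void)gemm_out; (void)wei; (void)n;
}

int main()
{
    float in[4] = {}, wei[4] = {}, out[4] = {};
    run_conv(Direction::BackwardData, in, wei, out, 4);
    return 0;
}
```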
......@@ -13,7 +13,7 @@ template <class T,
           class Dilations,
           index_t Direction>
 void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
-                                                        const Tensor<T>& in_nchw,
+                                                        Tensor<T>& in_nchw,
                                                         WeiDesc,
                                                         const Tensor<T>& wei_kcyx,
                                                         OutDesc,
......@@ -56,8 +56,9 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t N1 = 2;
     constexpr index_t N2 = 4;
-    constexpr index_t B = N * mod_conv::integer_divide_ceil(Ho, Strides::Get(I0)) *
-                          mod_conv::integer_divide_ceil(Wo, Strides::Get(I1)) / (N1 * N2);
+    // constexpr index_t B = N * mod_conv::integer_divide_ceil(Ho, Strides::Get(I0)) *
+    //                       mod_conv::integer_divide_ceil(Wo, Strides::Get(I1)) / (N1 * N2);
+    constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
 
 #if 1
     constexpr index_t BlockSize = 256;
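
With the roles swapped, the stride-aware ceiling division is dropped and `B` reduces to the exact quotient `(N * Ho * Wo) / (N1 * N2)`. A hedged sketch showing the two forms agree when the strides are 1 (`integer_divide_ceil` is re-derived here from its usual definition; the sizes are made up):

```cpp
// Sketch only: mod_conv::integer_divide_ceil is assumed to follow the
// common (a + b - 1) / b definition; example sizes are illustrative.
#include <cassert>

constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    constexpr int N = 64, Ho = 28, Wo = 28, N1 = 2, N2 = 4;
    // With unit strides, the ceiling division is a no-op...
    static_assert(integer_divide_ceil(Ho, 1) == Ho, "unit stride is a no-op");
    // ...so the old and new expressions for B coincide.
    constexpr int B = (N * Ho * Wo) / (N1 * N2);
    assert(B * N1 * N2 == N * Ho * Wo); // assumes N*Ho*Wo divisible by N1*N2
    return 0;
}
```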
......@@ -160,5 +161,8 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
         usleep(std::min(time * 1000, float(10000)));
     }
 
-    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
+    if(Direction == 1)
+        out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
+    else
+        in_nchw_device_buf.FromDevice(in_nchw.mData.data());
 }
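
`Direction` now selects which device buffer holds the result: the forward pass writes `out_nkhw`, while backward data writes into `in_nchw`, so only that buffer is copied back to the host. A simplified host-side sketch of the same pattern (`DeviceMem` and `FromDevice` are stand-ins for the repo's helpers):

```cpp
// Sketch of the direction-dependent copy-back; DeviceMem is a toy
// stand-in for the repo's device buffer wrapper.
#include <cstring>
#include <vector>

struct DeviceMem
{
    std::vector<float> data; // pretend this is device memory
    void FromDevice(float* host) const
    {
        std::memcpy(host, data.data(), data.size() * sizeof(float));
    }
};

void copy_result_back(int direction, const DeviceMem& in_buf, const DeviceMem& out_buf,
                      float* in_host, float* out_host)
{
    if(direction == 1)  // forward: the result is the output tensor
        out_buf.FromDevice(out_host);
    else                // backward data: the result was written into the input tensor
        in_buf.FromDevice(in_host);
}

int main()
{
    DeviceMem in_buf{std::vector<float>(8)}, out_buf{std::vector<float>(8)};
    std::vector<float> in_host(8), out_host(8);
    copy_result_back(/*direction=*/0, in_buf, out_buf, in_host.data(), out_host.data());
    return 0;
}
```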
......@@ -754,15 +754,15 @@ int main(int argc, char* argv[])
 #elif 1
     device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
 #endif
-    (out_nkhw_desc,
-     out_nkhw,
+    (in_nchw_desc,
+     in_nchw_device,
      wei_kcyx_desc,
      wei_kcyx,
-     in_nchw_desc,
+     out_nkhw_desc,
      strides,
      dilations,
      Number<Direction>{},
-     in_nchw_device,
+     out_nkhw,
      nrepeat);
 #elif 1
......
......@@ -46,10 +46,12 @@ template <index_t GridSize,
           index_t WeiBlockCopyDstDataPerWrite_K>
 struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
 {
-    __device__ void Run(const Float* const __restrict__ p_in_global,
+    __device__ void Run(Float* const __restrict__ p_conv_in_global,
                         const Float* const __restrict__ p_wei_global,
-                        Float* const __restrict__ p_out_global) const
+                        Float* const __restrict__ p_conv_out_global) const
     {
+        auto p_in_global  = p_conv_out_global;
+        auto p_out_global = p_conv_in_global;
         // this is a mess
         // TODO: find more elegant way of specifying (or calculating) performance parameters
         static_assert(N2 == GemmNPerThreadSubC, "wrong!");
......@@ -69,10 +71,10 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
         constexpr auto True = integral_constant<bool, true>{};
 
-        constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
+        constexpr auto in_n_c_h_w_global_desc = OutGlobalDesc{};
 
         // to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3) out = ek
         constexpr auto wei_k_c_1_1_global_desc = WeiGlobalDesc{};
-        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
+        constexpr auto out_n_k_h_w_global_desc = InGlobalDesc{};
 
         constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
         constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
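
Binding `in_n_c_h_w_global_desc` to `OutGlobalDesc` (and vice versa) means every quantity derived below, N, C, and the E/B merges, is computed from the opposite tensor, so the unchanged forward GEMM flow produces the backward-data result. A compile-time sketch of this rebinding, with a stand-in `Desc` (the repo's real `ConstantTensorDescriptor` also carries strides):

```cpp
#include <cstdio>

// Stand-in for the repo's compile-time tensor descriptor (hypothetical).
template <int N, int C, int H, int W>
struct Desc
{
    static constexpr int n = N, c = C, h = H, w = W;
};

// Forward naming is kept, but the bindings are swapped: what the kernel
// calls "in" is the convolution output for the backward-data pass.
template <class ConvInDesc, class ConvOutDesc>
void print_backward_bindings()
{
    using GemmInDesc  = ConvOutDesc; // mirrors: in_n_c_h_w_global_desc = OutGlobalDesc{}
    using GemmOutDesc = ConvInDesc;  // mirrors: out_n_k_h_w_global_desc = InGlobalDesc{}
    std::printf("gemm reads  %d x %d x %d x %d\n",
                GemmInDesc::n, GemmInDesc::c, GemmInDesc::h, GemmInDesc::w);
    std::printf("gemm writes %d x %d x %d x %d\n",
                GemmOutDesc::n, GemmOutDesc::c, GemmOutDesc::h, GemmOutDesc::w);
}

int main()
{
    print_backward_bindings<Desc<64, 256, 28, 28>, Desc<64, 128, 28, 28>>();
    return 0;
}
```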
......@@ -125,15 +127,6 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
-        // constexpr auto in_lengths_new = Sequence<N0, N1, N2, Ho, Wo>{};
-        // constexpr auto in_strides_new =
-        //     Sequence<in_n0_n1_n2_h_w_global_desc.GetStride(I0),
-        //              in_n0_n1_n2_h_w_global_desc.GetStride(I1),
-        //              in_n0_n1_n2_h_w_global_desc.GetStride(I2),
-        //              in_n0_n1_n2_h_w_global_desc.GetStride(I3),
-        //              in_n0_n1_n2_h_w_global_desc.GetStride(I4)>{};
-        // constexpr auto in_n0_n1_n2_h_w_new_global_desc =
-        //     make_ConstantTensorDescriptor(in_lengths_new, in_strides_new);
         constexpr auto in_n0_n1_n2_h_w_new_global_desc = in_n0_n1_n2_h_w_global_desc;
 
         // batch descriptor for device memory
......@@ -141,17 +134,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
         constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
                                                   .Slice(I3, Number<X>{})
                                                   .Extract(Sequence<1, 2, 3>{});
 
-        // constexpr auto in_win_lengths_new = Sequence<in_c_y_x_global_desc.GetLength(I0),
-        //                                              in_c_y_x_global_desc.GetLength(I1),
-        //                                              in_c_y_x_global_desc.GetLength(I2)>{};
-        // constexpr auto in_win_strides_new =
-        //     Sequence<in_c_y_x_global_desc.GetStride(I0),
-        //              in_c_y_x_global_desc.GetStride(I1),
-        //              in_c_y_x_global_desc.GetStride(I2)>{};
-        // constexpr auto in_c_y_x_new_global_desc =
-        //     make_ConstantTensorDescriptor(in_win_lengths_new, in_win_strides_new);
 
         // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
         constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
......@@ -192,12 +176,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
         // weight tensor
         //   tensor descriptor in device memory, src of blockwise copy
-#if 1 // backward
         constexpr auto wei_e_k_global_desc = wei_k_c_1_1_global_desc.Unfold(I1, I3);
-#else
-        constexpr auto wei_e_k_global_desc =
-            wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
-#endif
 
         //   tensor descriptor in LDS, dst of blockwise copy
         //     be careful of LDS alignment
......@@ -267,7 +246,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
         // choose GEMM implementation here
         const auto run_blockwise_gemm = [&](auto... Xs) {
-#if 1
+#if 0
             return blockwise_gemm.Run(Xs...);
 #else
             return blockwise_gemm.Run_asm(Xs...);
......
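
The `#if 1` → `#if 0` flip routes the blockwise GEMM through `Run_asm` instead of the C++ `Run`; a generic lambda forwards the arguments so call sites stay unchanged. A minimal, self-contained sketch of that compile-time dispatch pattern (the toy `BlockwiseGemm` is illustrative, not the repo's class):

```cpp
// Sketch of the preprocessor-selected GEMM dispatch via a generic lambda.
#include <cstdio>

struct BlockwiseGemm
{
    void Run(int a, int b) const     { std::printf("C++ path: %d\n", a + b); }
    void Run_asm(int a, int b) const { std::printf("asm path: %d\n", a + b); }
};

int main()
{
    BlockwiseGemm blockwise_gemm;
    // Generic lambda forwards any argument pack to the chosen implementation.
    const auto run_blockwise_gemm = [&](auto... Xs) {
#if 0
        return blockwise_gemm.Run(Xs...);
#else
        return blockwise_gemm.Run_asm(Xs...);
#endif
    };
    run_blockwise_gemm(1, 2);
    return 0;
}
```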
 #pragma once
 
 template <class GridwiseConvolution, class T>
-__global__ void run_gridwise_convolution(const T* const __restrict__ p_in_global,
+__global__ void run_gridwise_convolution(T* const __restrict__ p_in_global,
                                          const T* const __restrict__ p_wei_global,
                                          T* const __restrict__ p_out_global)
 {
......
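
The wrapper stays a thin `__global__` kernel templated on the gridwise functor; only the `const` on `p_in_global` changes so the functor may write input gradients. A hedged sketch of the same pattern with a toy functor (`ToyGridwise` is illustrative, not the repo's type):

```cpp
// Sketch of the templated kernel-wrapper pattern (CUDA/HIP syntax).
struct ToyGridwise
{
    __device__ void Run(float* in, const float* /*wei*/, float* out) const
    {
        // Roles swapped for backward data: write into the "input" buffer.
        in[threadIdx.x] = out[threadIdx.x];
    }
};

template <class GridwiseConvolution, class T>
__global__ void run_gridwise_convolution_sketch(T* const __restrict__ p_in,
                                                const T* const __restrict__ p_wei,
                                                T* const __restrict__ p_out)
{
    // All real work lives in the gridwise functor; p_in is non-const
    // because the backward-data pass writes gradients into it.
    GridwiseConvolution{}.Run(p_in, p_wei, p_out);
}

// Usage (host side), with d_in/d_wei/d_out as device pointers:
//   run_gridwise_convolution_sketch<ToyGridwise><<<1, 64>>>(d_in, d_wei, d_out);
```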