Commit 0b10c0bb authored by Jing Zhang's avatar Jing Zhang
Browse files

add backward

parent e4d2fc6f
......@@ -77,14 +77,14 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
using WeiBlockCopySubLengths_E_K = Sequence<1, 4>;
using WeiBlockCopyClusterLengths_E_K = Sequence<8, 32>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 4;
#endif
constexpr index_t GridSize =
......
......@@ -110,7 +110,7 @@ template <class TIn,
class UpperPads,
class Strides,
class Dilations>
void host_direct_convolution(const Tensor<TIn>& in_nchw,
void host_direct_convolution_forw(const Tensor<TIn>& in_nchw,
const Tensor<TWei>& wei_kcyx,
Tensor<TOut>& out_nkhw,
LowerPads,
......@@ -160,6 +160,66 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
f_par(std::thread::hardware_concurrency());
}
template <class TIn,
          class TWei,
          class TOut,
          class LowerPads,
          class UpperPads,
          class Strides,
          class Dilations>
// Host reference implementation of backward-data convolution: reconstructs the
// input gradient in_nchw from the output gradient out_nkhw and the filter
// wei_kcyx. Counterpart of host_direct_convolution_forw.
//   in_nchw  : [N, C, Hi, Wi] tensor written by this function
//   wei_kcyx : [K, C, Y, X]  filter tensor (read-only)
//   out_nkhw : [N, K, Ho, Wo] output-gradient tensor (read-only)
//   LowerPads/UpperPads/Strides/Dilations : compile-time Sequence<h, w> pairs
// NOTE(review): UpperPads is accepted for signature symmetry with the forward
// path but is not needed by the backward-data computation.
void host_direct_convolution_back(Tensor<TOut>& in_nchw,
                                  const Tensor<TWei>& wei_kcyx,
                                  const Tensor<TIn>& out_nkhw,
                                  LowerPads,
                                  UpperPads,
                                  Strides,
                                  Dilations)
{
    index_t h_pad_low = LowerPads{}.Get(Number<0>{});
    index_t w_pad_low = LowerPads{}.Get(Number<1>{});

    index_t stride_h = Strides{}.Get(Number<0>{});
    index_t stride_w = Strides{}.Get(Number<1>{});

    index_t dilation_h = Dilations{}.Get(Number<0>{});
    index_t dilation_w = Dilations{}.Get(Number<1>{});

    // For every input coordinate (n, c, hi, wi), accumulate the contributions
    // of all output positions that read this input in the forward pass:
    //   hi = ho * stride_h + y * dilation_h - h_pad_low
    // so ho is valid only when (hi + h_pad_low - y * dilation_h) is a
    // non-negative multiple of stride_h (and likewise for wo).
    auto f = [&](auto n, auto c, auto hi, auto wi) {
        double v = 0;
        // loop over k, y, x
        for(int k = 0; k < wei_kcyx.mDesc.GetLengths()[0]; ++k)
        {
            for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y)
            {
                // signed arithmetic: the numerator can legitimately be negative
                int h_tmp = int(hi) + int(h_pad_low) - y * dilation_h;

                // divisibility must be tested on the numerator BEFORE dividing;
                // truncating division would map e.g. -1/2 to 0 and pass ho >= 0
                if(h_tmp < 0 || h_tmp % stride_h != 0)
                    continue;

                int ho = h_tmp / stride_h;

                if(ho >= out_nkhw.mDesc.GetLengths()[2])
                    continue;

                for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x)
                {
                    int w_tmp = int(wi) + int(w_pad_low) - x * dilation_w;

                    if(w_tmp < 0 || w_tmp % stride_w != 0)
                        continue;

                    int wo = w_tmp / stride_w;

                    if(wo >= out_nkhw.mDesc.GetLengths()[3])
                        continue;

                    v += double(out_nkhw(n, k, ho, wo)) * double(wei_kcyx(k, c, y, x));
                }
            }
        }
        in_nchw(n, c, hi, wi) = v;
    };

    auto f_par = make_ParallelTensorFunctor(f,
                                            in_nchw.mDesc.GetLengths()[0],
                                            in_nchw.mDesc.GetLengths()[1],
                                            in_nchw.mDesc.GetLengths()[2],
                                            in_nchw.mDesc.GetLengths()[3]);

    f_par(std::thread::hardware_concurrency());
}
template <class TIn, class TWei, class TOut, class LowerPads, class UpperPads>
void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
const Tensor<TWei>& wei_kcyx,
......@@ -422,16 +482,18 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int main(int argc, char* argv[])
{
constexpr index_t U = 1;
constexpr index_t V = 1;
constexpr index_t HStride = 2;
constexpr index_t WStride = 2;
constexpr index_t Dh = 2;
constexpr index_t Dw = 2;
#if 0
constexpr index_t HDilation = 1;
constexpr index_t WDilation = 1;
constexpr index_t Direction = 2; //1: Forward; 2:Backward
#if 1
constexpr index_t N = 8;
constexpr index_t C = 16;
constexpr index_t HI = 20;
constexpr index_t WI = 20;
constexpr index_t C = 128;
constexpr index_t HI = 16;
constexpr index_t WI = 16;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 1;
......@@ -462,7 +524,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
#elif 0
// 3x3 filter, 28x28 image
constexpr index_t N = 128;
constexpr index_t C = 256;
......@@ -474,7 +536,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
#elif 1
// 1x1 filter, 28x28 image
constexpr index_t N = 128;
constexpr index_t C = 512;
......@@ -599,14 +661,16 @@ int main(int argc, char* argv[])
auto lower_pads = Sequence<HPad, WPad>{};
auto upper_pads = Sequence<HPad, WPad>{};
auto strides = Sequence<U, V>{};
auto dilations = Sequence<Dh, Dw>{};
auto strides = Sequence<HStride, WStride>{};
auto dilations = Sequence<HDilation, WDilation>{};
auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
in_nchw_desc, wei_kcyx_desc, strides, dilations);
auto wei_ckyx_back_desc = wei_kcyx_desc.ReorderGivenNew2Old(Sequence<1, 0, 2, 3>{});
ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
......@@ -614,9 +678,10 @@ int main(int argc, char* argv[])
using in_data_t = float;
using out_data_t = float;
Tensor<in_data_t> in_nchw(make_TensorDescriptor(in_nchw_desc));
Tensor<in_data_t> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
Tensor<out_data_t> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
Tensor<in_data_t> out_nkhw(make_TensorDescriptor(out_nkhw_desc));
Tensor<in_data_t> in_nchw_device(make_TensorDescriptor(in_nchw_desc));
Tensor<out_data_t> out_nkhw_device(make_TensorDescriptor(out_nkhw_desc));
Tensor<in_data_t> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
std::size_t num_thread = std::thread::hardware_concurrency();
......@@ -642,6 +707,7 @@ int main(int argc, char* argv[])
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 1
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
......@@ -673,15 +739,16 @@ int main(int argc, char* argv[])
#elif 1
device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
#endif
(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
(out_nkhw_desc,
out_nkhw,
wei_ckyx_back_desc,
wei_kcyx,
out_nkhw_desc,
in_nchw_desc,
strides,
dilations,
out_nkhw_device,
nrepeat);
in_nchw_device,
nrepeat
);
#elif 1
device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(in_nchw_desc,
......@@ -704,11 +771,18 @@ int main(int argc, char* argv[])
}
else
#endif
if(Direction == 1)
{
host_direct_convolution_forw(
in_nchw, wei_kcyx, out_nkhw, lower_pads, upper_pads, strides, dilations);
check_error(out_nkhw, out_nkhw_device);
}
else
{
host_direct_convolution(
in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads, strides, dilations);
host_direct_convolution_back(
in_nchw, wei_kcyx, out_nkhw, lower_pads, upper_pads, strides, dilations);
check_error(in_nchw, in_nchw_device);
}
check_error(out_nkhw_host, out_nkhw_device);
#if 0
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
......
......@@ -112,48 +112,49 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
// constexpr auto in_n0_n1_n2_h_w_global_desc = in_n_c_h_w_global_desc.Slice(I2,
// Number<Ho>{})
//.Slice(I3, Number<Wo>{})
//.Fold(I0, Number<N1>{}, Number<N2>{})
//.Extract(Sequence<0, 1, 2, 4, 5>{});
constexpr auto in_n0_n1_n2_h_w_global_desc = in_n_c_h_w_global_desc.Slice(I2,
Number<Ho>{})
.Slice(I3, Number<Wo>{})
.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
constexpr auto in_n0_n1_n2_h_w_global_desc =
in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
//constexpr auto in_n0_n1_n2_h_w_global_desc =
//in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{})
//.Extract(Sequence<0, 1, 2, 4, 5>{});
constexpr auto in_lengths_new = Sequence<N0, N1, N2, Ho, Wo>{};
//constexpr auto in_lengths_new = Sequence<N0, N1, N2, Ho, Wo>{};
constexpr auto in_strides_new =
Sequence<in_n0_n1_n2_h_w_global_desc.GetStride(I0),
in_n0_n1_n2_h_w_global_desc.GetStride(I1),
in_n0_n1_n2_h_w_global_desc.GetStride(I2),
in_n0_n1_n2_h_w_global_desc.GetStride(I3) * Strides{}.Get(I0),
in_n0_n1_n2_h_w_global_desc.GetStride(I4) * Strides{}.Get(I1)>{};
//constexpr auto in_strides_new =
//Sequence<in_n0_n1_n2_h_w_global_desc.GetStride(I0),
//in_n0_n1_n2_h_w_global_desc.GetStride(I1),
//in_n0_n1_n2_h_w_global_desc.GetStride(I2),
//in_n0_n1_n2_h_w_global_desc.GetStride(I3),
//in_n0_n1_n2_h_w_global_desc.GetStride(I4)>{};
constexpr auto in_n0_n1_n2_h_w_new_global_desc =
make_ConstantTensorDescriptor(in_lengths_new, in_strides_new);
//constexpr auto in_n0_n1_n2_h_w_new_global_desc =
//make_ConstantTensorDescriptor(in_lengths_new, in_strides_new);
constexpr auto in_n0_n1_n2_h_w_new_global_desc = in_n0_n1_n2_h_w_global_desc;
// batch descriptor for device memory
// to-do: add dilation: keep lengths, modify strides
constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
.Slice(I3, Number<X>{})
.Extract(Sequence<1, 2, 3>{});
constexpr auto in_win_lengths_new = Sequence<in_c_y_x_global_desc.GetLength(I0),
in_c_y_x_global_desc.GetLength(I1),
in_c_y_x_global_desc.GetLength(I2)>{};
//constexpr auto in_win_lengths_new = Sequence<in_c_y_x_global_desc.GetLength(I0),
//in_c_y_x_global_desc.GetLength(I1),
//in_c_y_x_global_desc.GetLength(I2)>{};
constexpr auto in_win_strides_new =
Sequence<in_c_y_x_global_desc.GetStride(I0),
in_c_y_x_global_desc.GetStride(I1) * Dilations{}.Get(I0),
in_c_y_x_global_desc.GetStride(I2) * Dilations{}.Get(I1)>{};
//constexpr auto in_win_strides_new =
//Sequence<in_c_y_x_global_desc.GetStride(I0),
//in_c_y_x_global_desc.GetStride(I1),
//in_c_y_x_global_desc.GetStride(I2)>{};
constexpr auto in_c_y_x_new_global_desc =
make_ConstantTensorDescriptor(in_win_lengths_new, in_win_strides_new);
//constexpr auto in_c_y_x_new_global_desc =
//make_ConstantTensorDescriptor(in_win_lengths_new, in_win_strides_new);
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
in_c_y_x_new_global_desc.Embed(in_n0_n1_n2_h_w_new_global_desc),
in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_new_global_desc),
Sequence<0, 1, 2>{},
Sequence<4>{},
Sequence<3, 6, 7>{},
......@@ -190,8 +191,15 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
#if 0
constexpr auto wei_e_k_global_desc =
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
#else
constexpr auto wei_e_k_global_desc =
make_ConstantMergedTensorDescriptor(wei_k_c_y_x_global_desc,
Sequence<1, 2, 3>{},
Sequence<0>{});
#endif
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
......@@ -356,7 +364,11 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
#if 0
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
#else
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
#endif
__syncthreads();
......@@ -383,7 +395,11 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// even iteration
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
#if 0
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
#else
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
#endif
__syncthreads();
......@@ -432,6 +448,34 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
.Fold(I0, Number<N1>{}, Number<N2>{});
constexpr auto out_lengths_new = Sequence<
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I0),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I1),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I2),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I3),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I4),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I5),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I6) / Strides{}.Get(I0),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetLength(I7) / Strides{}.Get(I1)
>{};
constexpr auto out_strides_new = Sequence<
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I0),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I1),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I2),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I3),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I4),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I5),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I6) * Strides{}.Get(I0),
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.GetStride(I7) * Strides{}.Get(I1)
>{};
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc = make_ConstantTensorDescriptor(
out_lengths_new, out_strides_new
);
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
......@@ -446,7 +490,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// output merged global tensor descriptor, for calculating origin of thread tensor
// in global memory
constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc.Unfold(I3, I5),
Sequence<3>{},
Sequence<1>{},
Sequence<0, 4, 5>{},
......@@ -462,7 +506,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
p_out_thread,
{0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
out_n0_n1_n2_k0_k1_k2_h_w_new_global_mem_desc,
p_out_thread_on_global,
{0, 0, 0, 0, 0, 0, 0, 0},
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment