Commit b530abc4 authored by Jing Zhang

add support for strides

parent eafdabba
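
Note: the commit threads a compile-time Strides pair through both the device and host paths (the new template parameter and the Sequence<U, V>{} argument below). A minimal stand-in for the Sequence/Number access pattern the diff relies on, purely illustrative and not the repository's implementation:

#include <cstddef>

using index_t = std::size_t;

// Hypothetical minimal stand-ins for the repository's Number<> and Sequence<> types,
// only to show how a Strides parameter such as Sequence<U, V> carries values at compile time.
template <index_t N>
struct Number
{
    static constexpr index_t value = N;
};

template <index_t... Is>
struct Sequence
{
    template <index_t I>
    static constexpr index_t Get(Number<I>)
    {
        constexpr index_t data[] = {Is...};
        return data[I];
    }
};

// Usage mirroring the diff: a kernel or host helper receives Strides as a type
// and reads the vertical stride from it.
template <class Strides>
constexpr index_t get_stride_h()
{
    return Strides{}.Get(Number<0>{});
}

static_assert(get_stride_h<Sequence<2, 2>>() == 2, "U = 2, as set in main()");
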
......@@ -5,12 +5,13 @@
#include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc>
template <class T, class InDesc, class WeiDesc, class OutDesc, class Strides>
void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Strides,
Tensor<T>& out_nkhw,
index_t nrepeat)
{
......@@ -100,6 +101,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
#endif
<GridSize,
BlockSize,
Strides,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
......
......@@ -103,12 +103,13 @@ auto make_TensorDescriptor(TConstTensorDesc)
return TensorDescriptor(lengths, strides);
}
template <class TIn, class TWei, class TOut, class LowerPads, class UpperPads>
template <class TIn, class TWei, class TOut, class LowerPads, class UpperPads, class Strides>
void host_direct_convolution(const Tensor<TIn>& in_nchw,
const Tensor<TWei>& wei_kcyx,
Tensor<TOut>& out_nkhw,
LowerPads,
UpperPads)
UpperPads,
Strides)
{
index_t h_pad_low = LowerPads{}.Get(Number<0>{});
index_t w_pad_low = LowerPads{}.Get(Number<1>{});
......@@ -116,16 +117,19 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
index_t h_pad_up = UpperPads{}.Get(Number<0>{});
index_t w_pad_up = UpperPads{}.Get(Number<1>{});
index_t stride_h = Strides{}.Get(Number<0>{});
index_t stride_w = Strides{}.Get(Number<1>{});
auto f = [&](auto n, auto k, auto ho, auto wo) {
double v = 0;
for(int c = 0; c < wei_kcyx.mDesc.GetLengths()[1]; ++c)
{
for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y)
{
int hi = ho + y - h_pad_low;
int hi = ho * stride_h + y - h_pad_low;
for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x)
{
int wi = wo + x - w_pad_low;
int wi = wo * stride_w + x - w_pad_low;
if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in_nchw.mDesc.GetLengths()[3])
{
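
Note: for readers without the Tensor<T> harness, the stride handling in host_direct_convolution amounts to the loop below. This is a hedged, self-contained sketch with illustrative names (naive_conv_nchw, flat std::vector buffers), and it assumes symmetric padding when computing the output extents:

#include <vector>

// Naive NCHW direct convolution with strides and zero padding; an illustration
// of the indexing host_direct_convolution now performs.
void naive_conv_nchw(const std::vector<float>& in, const std::vector<float>& wei,
                     std::vector<float>& out,
                     int N, int C, int HI, int WI, int K, int Y, int X,
                     int stride_h, int stride_w, int h_pad_low, int w_pad_low)
{
    // Output extents, assuming pad_low == pad_up on each spatial dimension.
    const int HO = (HI + 2 * h_pad_low - Y) / stride_h + 1;
    const int WO = (WI + 2 * w_pad_low - X) / stride_w + 1;

    for(int n = 0; n < N; ++n)
        for(int k = 0; k < K; ++k)
            for(int ho = 0; ho < HO; ++ho)
                for(int wo = 0; wo < WO; ++wo)
                {
                    double v = 0;
                    for(int c = 0; c < C; ++c)
                        for(int y = 0; y < Y; ++y)
                        {
                            const int hi = ho * stride_h + y - h_pad_low;
                            for(int x = 0; x < X; ++x)
                            {
                                const int wi = wo * stride_w + x - w_pad_low;
                                if(hi >= 0 && hi < HI && wi >= 0 && wi < WI)
                                    v += in[((n * C + c) * HI + hi) * WI + wi] *
                                         wei[((k * C + c) * Y + y) * X + x];
                            }
                        }
                    out[((n * K + k) * HO + ho) * WO + wo] = static_cast<float>(v);
                }
}
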
......@@ -408,14 +412,16 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int main(int argc, char* argv[])
{
constexpr index_t U = 2;
constexpr index_t V = 2;
#if 0
constexpr index_t N = 8;
constexpr index_t C = 16;
constexpr index_t HI = 3;
constexpr index_t WI = 18;
constexpr index_t HI = 16;
constexpr index_t WI = 16;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t Y = 1;
constexpr index_t X = 1;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
......@@ -443,7 +449,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
#elif 0
// 3x3 filter, 28x28 image
constexpr index_t N = 128;
constexpr index_t C = 256;
......@@ -455,7 +461,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
#elif 1
// 1x1 filter, 28x28 image
constexpr index_t N = 128;
constexpr index_t C = 512;
......@@ -580,10 +586,12 @@ int main(int argc, char* argv[])
auto lower_pads = Sequence<HPad, WPad>{};
auto upper_pads = Sequence<HPad, WPad>{};
auto strides = Sequence<U, V>{};
auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor(
in_nchw_desc, wei_kcyx_desc, lower_pads, upper_pads);
auto out_nkhw_desc =
get_convolution_output_default_4d_tensor_descriptor(in_nchw_desc, wei_kcyx_desc, strides);
ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
......@@ -651,7 +659,14 @@ int main(int argc, char* argv[])
#elif 1
device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
#endif
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
strides,
out_nkhw_device,
nrepeat);
#elif 1
device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(in_nchw_desc,
......@@ -667,7 +682,7 @@ int main(int argc, char* argv[])
if(do_verification)
{
#if 1
#if 0
if(Y == 3 && X == 3)
{
host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);
......@@ -675,7 +690,8 @@ int main(int argc, char* argv[])
else
#endif
{
host_direct_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);
host_direct_convolution(
in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads, strides);
}
check_error(out_nkhw_host, out_nkhw_device);
......
......@@ -2,8 +2,8 @@
#include "ConstantTensorDescriptor.hip.hpp"
// this is ugly, only for 4d
template <class InDesc, class WeiDesc>
constexpr auto get_convolution_output_default_4d_tensor_descriptor(InDesc, WeiDesc)
template <class InDesc, class WeiDesc, class Strides>
constexpr auto get_convolution_output_default_4d_tensor_descriptor(InDesc, WeiDesc, Strides)
{
constexpr auto in_desc = InDesc{};
constexpr auto wei_desc = WeiDesc{};
......@@ -26,8 +26,11 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(InDesc, WeiDe
constexpr auto Y = wei_desc.GetLength(I2);
constexpr auto X = wei_desc.GetLength(I3);
constexpr auto HO = HI + 1 - Y;
constexpr auto WO = WI + 1 - X;
constexpr index_t U = Strides{}.Get(I0);
constexpr index_t V = Strides{}.Get(I1);
constexpr auto HO = (HI - Y) / U + 1;
constexpr auto WO = (WI - X) / V + 1;
return make_ConstantTensorDescriptor_packed(Sequence<N, K, HO, WO>{});
}
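
Note: the integer division means output positions exist only where a full stride step fits; trailing input rows or columns that cannot start another window are simply dropped. A few compile-time checks of HO = (HI - Y) / U + 1 with illustrative shapes:

// Illustrative checks of the stride-aware output-extent formula.
static_assert((16 - 1) / 2 + 1 == 8, "16-wide input, 1-wide filter, stride 2 -> 8 outputs");
static_assert((28 - 3) / 1 + 1 == 26, "28-wide input, 3-wide filter, stride 1 -> 26 outputs");
static_assert((28 - 3) / 2 + 1 == 13, "with stride 2 the last partial step produces no output");
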
......
......@@ -10,6 +10,7 @@
// define B = merge(N, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Strides,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
......@@ -67,7 +68,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr auto True = integral_constant<bool, true>{};
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
// to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3) out = ek
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
......@@ -109,19 +111,37 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
constexpr auto in_n0_n1_n2_h_w_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Ho>{})
.Slice(I3, Number<Wo>{})
.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
// constexpr auto in_n0_n1_n2_h_w_global_desc = in_n_c_h_w_global_desc.Slice(I2,
// Number<Ho>{})
//.Slice(I3, Number<Wo>{})
//.Fold(I0, Number<N1>{}, Number<N2>{})
//.Extract(Sequence<0, 1, 2, 4, 5>{});
constexpr auto in_n0_n1_n2_h_w_global_desc =
in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
constexpr auto new_lengths = Sequence<N0, N1, N2, Ho, Wo>{};
constexpr auto new_strides =
Sequence<in_n0_n1_n2_h_w_global_desc.GetStride(I0),
in_n0_n1_n2_h_w_global_desc.GetStride(I1),
in_n0_n1_n2_h_w_global_desc.GetStride(I2),
in_n0_n1_n2_h_w_global_desc.GetStride(I3) * Strides{}.Get(I0),
in_n0_n1_n2_h_w_global_desc.GetStride(I4) * Strides{}.Get(I1)>{};
constexpr auto in_n0_n1_n2_h_w_new_global_desc =
make_ConstantTensorDescriptor(new_lengths, new_strides);
// batch descriptor for device memory
// to-do: add dilation: keep lengths, modify strides
constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
.Slice(I3, Number<X>{})
.Extract(Sequence<1, 2, 3>{});
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_new_global_desc),
Sequence<0, 1, 2>{},
Sequence<4>{},
Sequence<3, 6, 7>{},
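
Note: the kernel-side stride support keeps the [N0, N1, N2, Ho, Wo] lengths and instead scales the H and W tensor strides by Strides{}.Get(I0) and Strides{}.Get(I1) in new_strides above, so stepping through output coordinates lands on every U-th input row and V-th input column (the to-do comment suggests dilation could be handled the same way). A small offset check of that trick, with illustrative numbers:

#include <cassert>

int main()
{
    // Packed NCHW-style strides for a single H x W plane, with an illustrative WI = 16.
    const int WI = 16;
    const int h_stride = WI, w_stride = 1;

    // Stride-aware descriptor: lengths unchanged, strides scaled by U and V.
    const int U = 2, V = 2;
    const int h_stride_scaled = h_stride * U; // GetStride(I3) * Strides{}.Get(I0)
    const int w_stride_scaled = w_stride * V; // GetStride(I4) * Strides{}.Get(I1)

    // Addressing output coordinate (ho, wo) through the scaled strides reaches the
    // same element as input coordinate (ho * U, wo * V) through the original strides.
    const int ho = 3, wo = 5;
    assert(ho * h_stride_scaled + wo * w_stride_scaled ==
           (ho * U) * h_stride + (wo * V) * w_stride);

    return 0;
}
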
......@@ -246,7 +266,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// choose GEMM implementation here
const auto run_blockwise_gemm = [&](auto... Xs) {
#if 1
#if 0
return blockwise_gemm.Run(Xs...);
#else
return blockwise_gemm.Run_asm(Xs...);
......