"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "e5a87f35ad65c2a2c31bf6888293bfef1b963bed"
Commit e4d2fc6f authored by Jing Zhang's avatar Jing Zhang
Browse files

add support of dilation

parent b530abc4
...@@ -5,13 +5,14 @@ ...@@ -5,13 +5,14 @@
#include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hip.hpp" #include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp" #include "gridwise_convolution_implicit_gemm_v4_lds_double_buffer_nchw_kcyx_nkhw.hip.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc, class Strides> template <class T, class InDesc, class WeiDesc, class OutDesc, class Strides, class Dilations>
void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& in_nchw, const Tensor<T>& in_nchw,
WeiDesc, WeiDesc,
const Tensor<T>& wei_kcyx, const Tensor<T>& wei_kcyx,
OutDesc, OutDesc,
Strides, Strides,
Dilations,
Tensor<T>& out_nkhw, Tensor<T>& out_nkhw,
index_t nrepeat) index_t nrepeat)
{ {
...@@ -102,6 +103,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, ...@@ -102,6 +103,7 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
<GridSize, <GridSize,
BlockSize, BlockSize,
Strides, Strides,
Dilations,
T, T,
decltype(in_nchw_desc), decltype(in_nchw_desc),
decltype(wei_kcyx_desc), decltype(wei_kcyx_desc),
......
...@@ -103,13 +103,20 @@ auto make_TensorDescriptor(TConstTensorDesc) ...@@ -103,13 +103,20 @@ auto make_TensorDescriptor(TConstTensorDesc)
return TensorDescriptor(lengths, strides); return TensorDescriptor(lengths, strides);
} }
template <class TIn, class TWei, class TOut, class LowerPads, class UpperPads, class Strides> template <class TIn,
class TWei,
class TOut,
class LowerPads,
class UpperPads,
class Strides,
class Dilations>
void host_direct_convolution(const Tensor<TIn>& in_nchw, void host_direct_convolution(const Tensor<TIn>& in_nchw,
const Tensor<TWei>& wei_kcyx, const Tensor<TWei>& wei_kcyx,
Tensor<TOut>& out_nkhw, Tensor<TOut>& out_nkhw,
LowerPads, LowerPads,
UpperPads, UpperPads,
Strides) Strides,
Dilations)
{ {
index_t h_pad_low = LowerPads{}.Get(Number<0>{}); index_t h_pad_low = LowerPads{}.Get(Number<0>{});
index_t w_pad_low = LowerPads{}.Get(Number<1>{}); index_t w_pad_low = LowerPads{}.Get(Number<1>{});
...@@ -120,16 +127,19 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw, ...@@ -120,16 +127,19 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
index_t stride_h = Strides{}.Get(Number<0>{}); index_t stride_h = Strides{}.Get(Number<0>{});
index_t stride_w = Strides{}.Get(Number<1>{}); index_t stride_w = Strides{}.Get(Number<1>{});
index_t dilation_h = Dilations{}.Get(Number<0>{});
index_t dilation_w = Dilations{}.Get(Number<1>{});
auto f = [&](auto n, auto k, auto ho, auto wo) { auto f = [&](auto n, auto k, auto ho, auto wo) {
double v = 0; double v = 0;
for(int c = 0; c < wei_kcyx.mDesc.GetLengths()[1]; ++c) for(int c = 0; c < wei_kcyx.mDesc.GetLengths()[1]; ++c)
{ {
for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y) for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y)
{ {
int hi = ho * stride_h + y - h_pad_low; int hi = ho * stride_h + y * dilation_h - h_pad_low;
for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x) for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x)
{ {
int wi = wo * stride_w + x - w_pad_low; int wi = wo * stride_w + x * dilation_w - w_pad_low;
if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 && if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in_nchw.mDesc.GetLengths()[3]) wi < in_nchw.mDesc.GetLengths()[3])
{ {
...@@ -412,13 +422,16 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result) ...@@ -412,13 +422,16 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
constexpr index_t U = 2; constexpr index_t U = 1;
constexpr index_t V = 2; constexpr index_t V = 1;
constexpr index_t Dh = 2;
constexpr index_t Dw = 2;
#if 0 #if 0
constexpr index_t N = 8; constexpr index_t N = 8;
constexpr index_t C = 16; constexpr index_t C = 16;
constexpr index_t HI = 16; constexpr index_t HI = 20;
constexpr index_t WI = 16; constexpr index_t WI = 20;
constexpr index_t K = 128; constexpr index_t K = 128;
constexpr index_t Y = 1; constexpr index_t Y = 1;
constexpr index_t X = 1; constexpr index_t X = 1;
...@@ -449,7 +462,7 @@ int main(int argc, char* argv[]) ...@@ -449,7 +462,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 1
// 3x3 filter, 28x28 image // 3x3 filter, 28x28 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -461,7 +474,7 @@ int main(int argc, char* argv[]) ...@@ -461,7 +474,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 1 #elif 0
// 1x1 filter, 28x28 image // 1x1 filter, 28x28 image
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 512; constexpr index_t C = 512;
...@@ -586,12 +599,13 @@ int main(int argc, char* argv[]) ...@@ -586,12 +599,13 @@ int main(int argc, char* argv[])
auto lower_pads = Sequence<HPad, WPad>{}; auto lower_pads = Sequence<HPad, WPad>{};
auto upper_pads = Sequence<HPad, WPad>{}; auto upper_pads = Sequence<HPad, WPad>{};
auto strides = Sequence<U, V>{}; auto strides = Sequence<U, V>{};
auto dilations = Sequence<Dh, Dw>{};
auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{}); auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{}); auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
auto out_nkhw_desc = auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
get_convolution_output_default_4d_tensor_descriptor(in_nchw_desc, wei_kcyx_desc, strides); in_nchw_desc, wei_kcyx_desc, strides, dilations);
ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: "); ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
...@@ -665,6 +679,7 @@ int main(int argc, char* argv[]) ...@@ -665,6 +679,7 @@ int main(int argc, char* argv[])
wei_kcyx, wei_kcyx,
out_nkhw_desc, out_nkhw_desc,
strides, strides,
dilations,
out_nkhw_device, out_nkhw_device,
nrepeat); nrepeat);
...@@ -691,7 +706,7 @@ int main(int argc, char* argv[]) ...@@ -691,7 +706,7 @@ int main(int argc, char* argv[])
#endif #endif
{ {
host_direct_convolution( host_direct_convolution(
in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads, strides); in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads, strides, dilations);
} }
check_error(out_nkhw_host, out_nkhw_device); check_error(out_nkhw_host, out_nkhw_device);
......
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
#include "ConstantTensorDescriptor.hip.hpp" #include "ConstantTensorDescriptor.hip.hpp"
// this is ugly, only for 4d // this is ugly, only for 4d
template <class InDesc, class WeiDesc, class Strides> template <class InDesc, class WeiDesc, class Strides, class Dilations>
constexpr auto get_convolution_output_default_4d_tensor_descriptor(InDesc, WeiDesc, Strides) constexpr auto
get_convolution_output_default_4d_tensor_descriptor(InDesc, WeiDesc, Strides, Dilations)
{ {
constexpr auto in_desc = InDesc{}; constexpr auto in_desc = InDesc{};
constexpr auto wei_desc = WeiDesc{}; constexpr auto wei_desc = WeiDesc{};
...@@ -26,11 +27,14 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(InDesc, WeiDe ...@@ -26,11 +27,14 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(InDesc, WeiDe
constexpr auto Y = wei_desc.GetLength(I2); constexpr auto Y = wei_desc.GetLength(I2);
constexpr auto X = wei_desc.GetLength(I3); constexpr auto X = wei_desc.GetLength(I3);
constexpr index_t U = Strides{}.Get(I0); constexpr index_t stride_h = Strides{}.Get(I0);
constexpr index_t V = Strides{}.Get(I1); constexpr index_t stride_w = Strides{}.Get(I1);
constexpr auto HO = (HI - Y) / U + 1; constexpr index_t dilation_h = Dilations{}.Get(I0);
constexpr auto WO = (WI - X) / V + 1; constexpr index_t dilation_w = Dilations{}.Get(I1);
constexpr auto HO = (HI - Y - (Y - 1) * (dilation_h - 1)) / stride_h + 1;
constexpr auto WO = (WI - X - (X - 1) * (dilation_w - 1)) / stride_w + 1;
return make_ConstantTensorDescriptor_packed(Sequence<N, K, HO, WO>{}); return make_ConstantTensorDescriptor_packed(Sequence<N, K, HO, WO>{});
} }
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
template <index_t GridSize, template <index_t GridSize,
index_t BlockSize, index_t BlockSize,
class Strides, class Strides,
class Dilations,
class Float, class Float,
class InGlobalDesc, class InGlobalDesc,
class WeiGlobalDesc, class WeiGlobalDesc,
...@@ -121,9 +122,9 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -121,9 +122,9 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{}) in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{}); .Extract(Sequence<0, 1, 2, 4, 5>{});
constexpr auto new_lengths = Sequence<N0, N1, N2, Ho, Wo>{}; constexpr auto in_lengths_new = Sequence<N0, N1, N2, Ho, Wo>{};
constexpr auto new_strides = constexpr auto in_strides_new =
Sequence<in_n0_n1_n2_h_w_global_desc.GetStride(I0), Sequence<in_n0_n1_n2_h_w_global_desc.GetStride(I0),
in_n0_n1_n2_h_w_global_desc.GetStride(I1), in_n0_n1_n2_h_w_global_desc.GetStride(I1),
in_n0_n1_n2_h_w_global_desc.GetStride(I2), in_n0_n1_n2_h_w_global_desc.GetStride(I2),
...@@ -131,17 +132,28 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -131,17 +132,28 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
in_n0_n1_n2_h_w_global_desc.GetStride(I4) * Strides{}.Get(I1)>{}; in_n0_n1_n2_h_w_global_desc.GetStride(I4) * Strides{}.Get(I1)>{};
constexpr auto in_n0_n1_n2_h_w_new_global_desc = constexpr auto in_n0_n1_n2_h_w_new_global_desc =
make_ConstantTensorDescriptor(new_lengths, new_strides); make_ConstantTensorDescriptor(in_lengths_new, in_strides_new);
// batch descritpor for device memory // batch descritpor for device memory
// to-do: add dilation: keep lengths, modify strides // to-do: add dilation: keep lengths, modify strides
constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{}) constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
.Slice(I3, Number<X>{}) .Slice(I3, Number<X>{})
.Extract(Sequence<1, 2, 3>{}); .Extract(Sequence<1, 2, 3>{});
constexpr auto in_win_lengths_new = Sequence<in_c_y_x_global_desc.GetLength(I0),
in_c_y_x_global_desc.GetLength(I1),
in_c_y_x_global_desc.GetLength(I2)>{};
constexpr auto in_win_strides_new =
Sequence<in_c_y_x_global_desc.GetStride(I0),
in_c_y_x_global_desc.GetStride(I1) * Dilations{}.Get(I0),
in_c_y_x_global_desc.GetStride(I2) * Dilations{}.Get(I1)>{};
constexpr auto in_c_y_x_new_global_desc =
make_ConstantTensorDescriptor(in_win_lengths_new, in_win_strides_new);
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_new_global_desc), in_c_y_x_new_global_desc.Embed(in_n0_n1_n2_h_w_new_global_desc),
Sequence<0, 1, 2>{}, Sequence<0, 1, 2>{},
Sequence<4>{}, Sequence<4>{},
Sequence<3, 6, 7>{}, Sequence<3, 6, 7>{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment