You need to sign in or sign up before continuing.
Commit b1cb48a0 authored by Chao Liu's avatar Chao Liu
Browse files

added strides and dilations support to implicit gemm v4

parent 1566b317
...@@ -22,6 +22,8 @@ template <index_t GridSize, ...@@ -22,6 +22,8 @@ template <index_t GridSize,
class InGlobalDesc, class InGlobalDesc,
class WeiGlobalDesc, class WeiGlobalDesc,
class OutGlobalDesc, class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
index_t BPerBlock, index_t BPerBlock,
index_t KPerBlock, index_t KPerBlock,
index_t EPerBlock, index_t EPerBlock,
...@@ -117,15 +119,17 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer ...@@ -117,15 +119,17 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// input tensor // input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo] // tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
constexpr auto in_n0_n1_n2_h_w_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Ho>{}) constexpr auto in_n0_n1_n2_h_w_global_desc =
.Slice(I3, Number<Wo>{}) in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrides::Get(I0)>{})
.Fold(I0, Number<N1>{}, Number<N2>{}) .StridedSlice(I3, Number<Wo>{}, Number<ConvStrides::Get(I1)>{})
.Extract(Sequence<0, 1, 2, 4, 5>{}); .Fold(I0, Number<N1>{}, Number<N2>{})
.Extract(Sequence<0, 1, 2, 4, 5>{});
// batch descriptor for device memory // batch descriptor for device memory
constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{}) constexpr auto in_c_y_x_global_desc =
.Slice(I3, Number<X>{}) in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilations::Get(I0)>{})
.Extract(Sequence<1, 2, 3>{}); .StridedSlice(I3, Number<X>{}, Number<ConvDilations::Get(I1)>{})
.Extract(Sequence<1, 2, 3>{});
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
......
...@@ -320,6 +320,18 @@ struct ConstantTensorDescriptor ...@@ -320,6 +320,18 @@ struct ConstantTensorDescriptor
return ConstantTensorDescriptor<slice_lengths, Strides>{}; return ConstantTensorDescriptor<slice_lengths, Strides>{};
} }
// Take a strided slice of dimension IDim: the sliced view keeps SliceLength
// elements spaced SliceStride apart. Lengths/Strides are the enclosing
// descriptor's compile-time sequences; the result is a new descriptor type
// whose IDim length is SliceLength and whose IDim stride is scaled by
// SliceStride. Purely a compile-time type computation — no data is touched.
template <index_t IDim, index_t SliceLength, index_t SliceStride>
__host__ __device__ static constexpr auto
StridedSlice(Number<IDim>, Number<SliceLength>, Number<SliceStride>)
{
    constexpr auto idim = Number<IDim>{};

    // Stepping by SliceStride elements multiplies the dimension's stride.
    constexpr index_t stride_after_slice = Strides::Get(idim) * SliceStride;

    using SlicedLengths = decltype(Lengths::Modify(idim, Number<SliceLength>{}));
    using SlicedStrides = decltype(Strides::Modify(idim, Number<stride_after_slice>{}));

    return ConstantTensorDescriptor<SlicedLengths, SlicedStrides>{};
}
template <index_t IDim, index_t... FoldIntervals> template <index_t IDim, index_t... FoldIntervals>
__host__ __device__ static constexpr auto Fold(Number<IDim>, Number<FoldIntervals>...) __host__ __device__ static constexpr auto Fold(Number<IDim>, Number<FoldIntervals>...)
{ {
......
...@@ -36,11 +36,14 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(InDesc, WeiDe ...@@ -36,11 +36,14 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(InDesc, WeiDe
return make_ConstantTensorDescriptor_packed(Sequence<N, K, HO, WO>{}); return make_ConstantTensorDescriptor_packed(Sequence<N, K, HO, WO>{});
} }
template <class InDesc, class WeiDesc, class LowerPads, class UpperPads> template <class InDesc,
constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(InDesc, class WeiDesc,
WeiDesc, class ConvStrides,
LowerPads, class ConvDilations,
UpperPads) class LowerPads,
class UpperPads>
constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(
InDesc, WeiDesc, ConvStrides, ConvDilations, LowerPads, UpperPads)
{ {
constexpr auto in_desc = InDesc{}; constexpr auto in_desc = InDesc{};
constexpr auto wei_desc = WeiDesc{}; constexpr auto wei_desc = WeiDesc{};
...@@ -55,24 +58,27 @@ constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor( ...@@ -55,24 +58,27 @@ constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(
static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1), static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
"input & weight dimension not consistent"); "input & weight dimension not consistent");
constexpr auto N = in_desc.GetLength(I0); constexpr index_t N = in_desc.GetLength(I0);
constexpr auto HI = in_desc.GetLength(I2); constexpr index_t Hi = in_desc.GetLength(I2);
constexpr auto WI = in_desc.GetLength(I3); constexpr index_t Wi = in_desc.GetLength(I3);
constexpr auto K = wei_desc.GetLength(I0); constexpr index_t K = wei_desc.GetLength(I0);
constexpr auto Y = wei_desc.GetLength(I2); constexpr index_t Y = wei_desc.GetLength(I2);
constexpr auto X = wei_desc.GetLength(I3); constexpr index_t X = wei_desc.GetLength(I3);
constexpr auto HPadLow = LowerPads{}.Get(I0); constexpr index_t HPadLow = LowerPads{}.Get(I0);
constexpr auto WPadLow = LowerPads{}.Get(I1); constexpr index_t WPadLow = LowerPads{}.Get(I1);
constexpr auto HPadUp = UpperPads{}.Get(I0); constexpr index_t HPadUp = UpperPads{}.Get(I0);
constexpr auto WPadUp = UpperPads{}.Get(I1); constexpr index_t WPadUp = UpperPads{}.Get(I1);
constexpr auto HO = HI + HPadLow + HPadUp + 1 - Y; constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
constexpr auto WO = WI + WPadLow + WPadUp + 1 - X; constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;
return make_ConstantTensorDescriptor_packed(Sequence<N, K, HO, WO>{}); constexpr index_t Ho = (Hi + HPadLow + HPadUp - YEff) / ConvStrides{}[0] + 1;
constexpr index_t Wo = (Wi + WPadLow + WPadUp - XEff) / ConvStrides{}[1] + 1;
return make_ConstantTensorDescriptor_packed(Sequence<N, K, Ho, Wo>{});
} }
template <class InDesc, class WeiDesc, class OutDesc> template <class InDesc, class WeiDesc, class OutDesc>
......
...@@ -8,13 +8,20 @@ ...@@ -8,13 +8,20 @@
using namespace ck; using namespace ck;
template <class T, class InDesc, class WeiDesc, class OutDesc> template <class T,
class InDesc,
class WeiDesc,
class OutDesc,
class ConvStrides,
class ConvDilations>
void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
const Tensor<T>& in_nchw, const Tensor<T>& in_nchw,
WeiDesc, WeiDesc,
const Tensor<T>& wei_kcyx, const Tensor<T>& wei_kcyx,
OutDesc, OutDesc,
Tensor<T>& out_nkhw, Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
index_t nrepeat) index_t nrepeat)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -107,6 +114,8 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, ...@@ -107,6 +114,8 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
decltype(in_nchw_desc), decltype(in_nchw_desc),
decltype(wei_kcyx_desc), decltype(wei_kcyx_desc),
decltype(out_nkhw_desc), decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
BPerBlock, BPerBlock,
KPerBlock, KPerBlock,
CPerBlock, CPerBlock,
......
...@@ -103,10 +103,18 @@ auto make_TensorDescriptor(TConstTensorDesc) ...@@ -103,10 +103,18 @@ auto make_TensorDescriptor(TConstTensorDesc)
return TensorDescriptor(lengths, strides); return TensorDescriptor(lengths, strides);
} }
template <class TIn, class TWei, class TOut, class LowerPads, class UpperPads> template <class TIn,
class TWei,
class TOut,
class ConvStrides,
class ConvDilations,
class LowerPads,
class UpperPads>
void host_direct_convolution(const Tensor<TIn>& in_nchw, void host_direct_convolution(const Tensor<TIn>& in_nchw,
const Tensor<TWei>& wei_kcyx, const Tensor<TWei>& wei_kcyx,
Tensor<TOut>& out_nkhw, Tensor<TOut>& out_nkhw,
ConvStrides,
ConvDilations,
LowerPads, LowerPads,
UpperPads) UpperPads)
{ {
...@@ -122,10 +130,10 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw, ...@@ -122,10 +130,10 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
{ {
for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y) for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y)
{ {
int hi = ho + y - h_pad_low; int hi = ho * ConvStrides{}[0] + y * ConvDilations{}[0] - h_pad_low;
for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x) for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x)
{ {
int wi = wo + x - w_pad_low; int wi = wo * ConvStrides{}[1] + x * ConvDilations{}[1] - w_pad_low;
if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 && if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in_nchw.mDesc.GetLengths()[3]) wi < in_nchw.mDesc.GetLengths()[3])
{ {
...@@ -419,9 +427,9 @@ int main(int argc, char* argv[]) ...@@ -419,9 +427,9 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 1 #elif 0
// 3x3, 34x34 // 3x3, 34x34
constexpr index_t N = 64; constexpr index_t N = 128;
constexpr index_t C = 256; constexpr index_t C = 256;
constexpr index_t HI = 34; constexpr index_t HI = 34;
constexpr index_t WI = 34; constexpr index_t WI = 34;
...@@ -429,6 +437,9 @@ int main(int argc, char* argv[]) ...@@ -429,6 +437,9 @@ int main(int argc, char* argv[])
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 0
...@@ -453,6 +464,9 @@ int main(int argc, char* argv[]) ...@@ -453,6 +464,9 @@ int main(int argc, char* argv[])
constexpr index_t Y = 3; constexpr index_t Y = 3;
constexpr index_t X = 3; constexpr index_t X = 3;
using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>;
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 0
...@@ -583,7 +597,7 @@ int main(int argc, char* argv[]) ...@@ -583,7 +597,7 @@ int main(int argc, char* argv[])
auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{}); auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{}); auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor( auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor(
in_nchw_desc, wei_kcyx_desc, lower_pads, upper_pads); in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, lower_pads, upper_pads);
ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: "); ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
...@@ -645,9 +659,17 @@ int main(int argc, char* argv[]) ...@@ -645,9 +659,17 @@ int main(int argc, char* argv[])
#elif 1 #elif 1
device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
#endif #endif
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); (in_nchw_desc,
in_nchw,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw_device,
ConvStrides{},
ConvDilations{},
nrepeat);
#elif 1 #elif 0
device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(in_nchw_desc, device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
...@@ -662,14 +684,21 @@ int main(int argc, char* argv[]) ...@@ -662,14 +684,21 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
#if 1 #if 1
if(Y == 3 && X == 3) if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 &&
ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1)
{ {
host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads); host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);
} }
else else
#endif #endif
{ {
host_direct_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads); host_direct_convolution(in_nchw,
wei_kcyx,
out_nkhw_host,
ConvStrides{},
ConvDilations{},
lower_pads,
upper_pads);
} }
check_error(out_nkhw_host, out_nkhw_device); check_error(out_nkhw_host, out_nkhw_device);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment