"src/git@developer.sourcefind.cn:guobj/qwen_lmdeploy.git" did not exist on "e357c71fae0ac571d86e394743a07d71c163222b"
Commit b1cb48a0 authored by Chao Liu

added strides and dilations support to implicit gemm v4

parent 1566b317
@@ -22,6 +22,8 @@ template <index_t GridSize,
           class InGlobalDesc,
           class WeiGlobalDesc,
           class OutGlobalDesc,
+          class ConvStrides,
+          class ConvDilations,
           index_t BPerBlock,
           index_t KPerBlock,
           index_t EPerBlock,
@@ -117,15 +119,17 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
         // input tensor
         // tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
-        constexpr auto in_n0_n1_n2_h_w_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Ho>{})
-                                                         .Slice(I3, Number<Wo>{})
-                                                         .Fold(I0, Number<N1>{}, Number<N2>{})
-                                                         .Extract(Sequence<0, 1, 2, 4, 5>{});
+        constexpr auto in_n0_n1_n2_h_w_global_desc =
+            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrides::Get(I0)>{})
+                .StridedSlice(I3, Number<Wo>{}, Number<ConvStrides::Get(I1)>{})
+                .Fold(I0, Number<N1>{}, Number<N2>{})
+                .Extract(Sequence<0, 1, 2, 4, 5>{});
 
         // batch descriptor for device memory
-        constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
-                                                  .Slice(I3, Number<X>{})
-                                                  .Extract(Sequence<1, 2, 3>{});
+        constexpr auto in_c_y_x_global_desc =
+            in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilations::Get(I0)>{})
+                .StridedSlice(I3, Number<X>{}, Number<ConvDilations::Get(I1)>{})
+                .Extract(Sequence<1, 2, 3>{});
 
         // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
         constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
...
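Reading the two new descriptors together: the [Ho, Wo] slice steps through the input with the convolution strides, while the [Y, X] window slice steps with the dilations, so their composed offsets give the usual input coordinate hi = ho * stride + y * dilation. A one-function sketch of that arithmetic, using plain ints rather than the repo's descriptor machinery (names here are illustrative only):

#include <cassert>

// Illustrative only: the input row read for output row `ho` and filter row `y`.
int input_row(int ho, int y, int stride, int dilation) { return ho * stride + y * dilation; }

int main()
{
    // With stride 2 and dilation 1 (the Sequence<2, 2> / Sequence<1, 1>
    // configuration added to main() below), output row 5 and filter row 2
    // read input row 5 * 2 + 2 * 1 = 12.
    assert(input_row(5, 2, 2, 1) == 12);
    return 0;
}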
@@ -320,6 +320,18 @@ struct ConstantTensorDescriptor
         return ConstantTensorDescriptor<slice_lengths, Strides>{};
     }
 
+    template <index_t IDim, index_t SliceLength, index_t SliceStride>
+    __host__ __device__ static constexpr auto
+    StridedSlice(Number<IDim>, Number<SliceLength>, Number<SliceStride>)
+    {
+        constexpr index_t new_stride = Strides::Get(Number<IDim>{}) * SliceStride;
+
+        using new_lengths = decltype(Lengths::Modify(Number<IDim>{}, Number<SliceLength>{}));
+        using new_strides = decltype(Strides::Modify(Number<IDim>{}, Number<new_stride>{}));
+
+        return ConstantTensorDescriptor<new_lengths, new_strides>{};
+    }
+
     template <index_t IDim, index_t... FoldIntervals>
     __host__ __device__ static constexpr auto Fold(Number<IDim>, Number<FoldIntervals>...)
     {
...
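StridedSlice is pure compile-time bookkeeping: the sliced dimension's length becomes SliceLength and its stride is multiplied by SliceStride, so no data moves. A minimal analogue with a plain struct (a hypothetical Desc2d, not the repo's ConstantTensorDescriptor; needs C++14 for the constexpr mutation):

#include <cstdio>

// Hypothetical 2-d stand-in for ConstantTensorDescriptor, for illustration only.
struct Desc2d
{
    int lengths[2];
    int strides[2];

    constexpr int offset(int i0, int i1) const { return i0 * strides[0] + i1 * strides[1]; }
};

// Same rule as StridedSlice above: new_stride = old_stride * slice_stride.
constexpr Desc2d strided_slice(Desc2d d, int dim, int slice_length, int slice_stride)
{
    d.lengths[dim] = slice_length;
    d.strides[dim] *= slice_stride;
    return d;
}

int main()
{
    // Slice a packed 8x8 view to 3 rows taken every 2nd row: row i of the
    // slice is row 2*i of the original, so one row-step now covers 16 elements.
    constexpr Desc2d full{{8, 8}, {8, 1}};
    constexpr Desc2d view = strided_slice(full, 0, 3, 2);
    static_assert(view.offset(1, 0) == 16, "stride doubled by the slice");
    std::printf("lengths = {%d, %d}, strides = {%d, %d}\n",
                view.lengths[0], view.lengths[1], view.strides[0], view.strides[1]);
}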
@@ -36,11 +36,14 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(InDesc, WeiDe
     return make_ConstantTensorDescriptor_packed(Sequence<N, K, HO, WO>{});
 }
 
-template <class InDesc, class WeiDesc, class LowerPads, class UpperPads>
-constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(InDesc,
-                                                                                WeiDesc,
-                                                                                LowerPads,
-                                                                                UpperPads)
+template <class InDesc,
+          class WeiDesc,
+          class ConvStrides,
+          class ConvDilations,
+          class LowerPads,
+          class UpperPads>
+constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(
+    InDesc, WeiDesc, ConvStrides, ConvDilations, LowerPads, UpperPads)
 {
     constexpr auto in_desc  = InDesc{};
     constexpr auto wei_desc = WeiDesc{};
@@ -55,24 +58,27 @@ constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(
     static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
                   "input & weight dimension not consistent");
 
-    constexpr auto N  = in_desc.GetLength(I0);
-    constexpr auto HI = in_desc.GetLength(I2);
-    constexpr auto WI = in_desc.GetLength(I3);
-
-    constexpr auto K = wei_desc.GetLength(I0);
-    constexpr auto Y = wei_desc.GetLength(I2);
-    constexpr auto X = wei_desc.GetLength(I3);
-
-    constexpr auto HPadLow = LowerPads{}.Get(I0);
-    constexpr auto WPadLow = LowerPads{}.Get(I1);
-    constexpr auto HPadUp  = UpperPads{}.Get(I0);
-    constexpr auto WPadUp  = UpperPads{}.Get(I1);
-
-    constexpr auto HO = HI + HPadLow + HPadUp + 1 - Y;
-    constexpr auto WO = WI + WPadLow + WPadUp + 1 - X;
-
-    return make_ConstantTensorDescriptor_packed(Sequence<N, K, HO, WO>{});
+    constexpr index_t N  = in_desc.GetLength(I0);
+    constexpr index_t Hi = in_desc.GetLength(I2);
+    constexpr index_t Wi = in_desc.GetLength(I3);
+
+    constexpr index_t K = wei_desc.GetLength(I0);
+    constexpr index_t Y = wei_desc.GetLength(I2);
+    constexpr index_t X = wei_desc.GetLength(I3);
+
+    constexpr index_t HPadLow = LowerPads{}.Get(I0);
+    constexpr index_t WPadLow = LowerPads{}.Get(I1);
+    constexpr index_t HPadUp  = UpperPads{}.Get(I0);
+    constexpr index_t WPadUp  = UpperPads{}.Get(I1);
+
+    constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
+    constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;
+
+    constexpr index_t Ho = (Hi + HPadLow + HPadUp - YEff) / ConvStrides{}[0] + 1;
+    constexpr index_t Wo = (Wi + WPadLow + WPadUp - XEff) / ConvStrides{}[1] + 1;
+
+    return make_ConstantTensorDescriptor_packed(Sequence<N, K, Ho, Wo>{});
 }
 
 template <class InDesc, class WeiDesc, class OutDesc>
...
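Dilation enters the output size only through the effective filter extent (Y - 1) * dilation + 1, and stride only through the final division. A standalone sanity check of the same formula (a hypothetical helper, not part of the repo):

#include <cstdio>

// Ho = (Hi + pad_low + pad_up - YEff) / stride + 1, with YEff = (Y - 1) * dilation + 1.
constexpr int conv_out_len(int in, int filter, int pad_low, int pad_up, int stride, int dilation)
{
    return (in + pad_low + pad_up - ((filter - 1) * dilation + 1)) / stride + 1;
}

int main()
{
    // The 3x3, 34x34, stride-2, no-padding case from main() below: (34 - 3) / 2 + 1 = 16.
    static_assert(conv_out_len(34, 3, 0, 0, 2, 1) == 16, "Ho");
    // Dilation 2 shrinks the output further: effective filter 5, (34 - 5) / 2 + 1 = 15.
    static_assert(conv_out_len(34, 3, 0, 0, 2, 2) == 15, "Ho with dilation 2");
    std::printf("Ho(stride 2, dilation 1) = %d\n", conv_out_len(34, 3, 0, 0, 2, 1));
}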
@@ -8,13 +8,20 @@
 using namespace ck;
 
-template <class T, class InDesc, class WeiDesc, class OutDesc>
+template <class T,
+          class InDesc,
+          class WeiDesc,
+          class OutDesc,
+          class ConvStrides,
+          class ConvDilations>
 void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
                                                         const Tensor<T>& in_nchw,
                                                         WeiDesc,
                                                         const Tensor<T>& wei_kcyx,
                                                         OutDesc,
                                                         Tensor<T>& out_nkhw,
+                                                        ConvStrides,
+                                                        ConvDilations,
                                                         index_t nrepeat)
 {
     constexpr auto I0 = Number<0>{};
@@ -107,6 +114,8 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
         decltype(in_nchw_desc),
         decltype(wei_kcyx_desc),
         decltype(out_nkhw_desc),
+        ConvStrides,
+        ConvDilations,
         BPerBlock,
         KPerBlock,
         CPerBlock,
...
@@ -103,10 +103,18 @@ auto make_TensorDescriptor(TConstTensorDesc)
     return TensorDescriptor(lengths, strides);
 }
 
-template <class TIn, class TWei, class TOut, class LowerPads, class UpperPads>
+template <class TIn,
+          class TWei,
+          class TOut,
+          class ConvStrides,
+          class ConvDilations,
+          class LowerPads,
+          class UpperPads>
 void host_direct_convolution(const Tensor<TIn>& in_nchw,
                              const Tensor<TWei>& wei_kcyx,
                              Tensor<TOut>& out_nkhw,
+                             ConvStrides,
+                             ConvDilations,
                              LowerPads,
                              UpperPads)
 {
@@ -122,10 +130,10 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
         {
             for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y)
             {
-                int hi = ho + y - h_pad_low;
+                int hi = ho * ConvStrides{}[0] + y * ConvDilations{}[0] - h_pad_low;
                 for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x)
                 {
-                    int wi = wo + x - w_pad_low;
+                    int wi = wo * ConvStrides{}[1] + x * ConvDilations{}[1] - w_pad_low;
                     if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
                        wi < in_nchw.mDesc.GetLengths()[3])
                     {
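Those two changed lines carry the entire semantic change on the host side: the input coordinate generalizes from hi = ho + y - pad to hi = ho * stride + y * dilation - pad. For reference, a self-contained loop nest in the same spirit, with plain vectors standing in for the repo's Tensor<T> (all names illustrative, not the repo's API):

#include <cstdio>
#include <vector>

// Hypothetical standalone reference that mirrors host_direct_convolution's
// index arithmetic: hi = ho * stride + y * dilation - pad.
void naive_conv_nchw(const std::vector<float>& in, int N, int C, int Hi, int Wi,
                     const std::vector<float>& wei, int K, int Y, int X,
                     std::vector<float>& out, int Ho, int Wo,
                     int stride_h, int stride_w, int dil_h, int dil_w,
                     int pad_h, int pad_w)
{
    for(int n = 0; n < N; ++n)
        for(int k = 0; k < K; ++k)
            for(int ho = 0; ho < Ho; ++ho)
                for(int wo = 0; wo < Wo; ++wo)
                {
                    float v = 0.f;
                    for(int c = 0; c < C; ++c)
                        for(int y = 0; y < Y; ++y)
                            for(int x = 0; x < X; ++x)
                            {
                                const int hi = ho * stride_h + y * dil_h - pad_h;
                                const int wi = wo * stride_w + x * dil_w - pad_w;
                                if(hi >= 0 && hi < Hi && wi >= 0 && wi < Wi)
                                    v += in[((n * C + c) * Hi + hi) * Wi + wi] *
                                         wei[((k * C + c) * Y + y) * X + x];
                            }
                    out[((n * K + k) * Ho + ho) * Wo + wo] = v;
                }
}

int main()
{
    // 1x1x5x5 input of ones, 1x1x3x3 kernel of ones, stride 2, no padding:
    // Ho = Wo = (5 - 3) / 2 + 1 = 2, and every output is 9 (full 3x3 overlap).
    std::vector<float> in(25, 1.f), wei(9, 1.f), out(4, 0.f);
    naive_conv_nchw(in, 1, 1, 5, 5, wei, 1, 3, 3, out, 2, 2, 2, 2, 1, 1, 0, 0);
    std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);
}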
@@ -419,9 +427,9 @@ int main(int argc, char* argv[])
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
-#elif 1
+#elif 0
     // 3x3, 34x34
-    constexpr index_t N = 64;
+    constexpr index_t N = 128;
     constexpr index_t C  = 256;
     constexpr index_t HI = 34;
     constexpr index_t WI = 34;
@@ -429,6 +437,9 @@ int main(int argc, char* argv[])
     constexpr index_t Y = 3;
     constexpr index_t X = 3;
 
+    using ConvStrides   = Sequence<2, 2>;
+    using ConvDilations = Sequence<1, 1>;
+
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
 #elif 0
@@ -453,6 +464,9 @@ int main(int argc, char* argv[])
     constexpr index_t Y = 3;
     constexpr index_t X = 3;
 
+    using ConvStrides   = Sequence<2, 2>;
+    using ConvDilations = Sequence<1, 1>;
+
     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
 #elif 0
@@ -583,7 +597,7 @@ int main(int argc, char* argv[])
     auto in_nchw_desc  = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
     auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
     auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor(
-        in_nchw_desc, wei_kcyx_desc, lower_pads, upper_pads);
+        in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, lower_pads, upper_pads);
 
     ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
     ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
@@ -645,9 +659,17 @@ int main(int argc, char* argv[])
 #elif 1
     device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
 #endif
-    (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
+    (in_nchw_desc,
+     in_nchw,
+     wei_kcyx_desc,
+     wei_kcyx,
+     out_nkhw_desc,
+     out_nkhw_device,
+     ConvStrides{},
+     ConvDilations{},
+     nrepeat);
-#elif 1
+#elif 0
     device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(in_nchw_desc,
                                                              in_nchw,
                                                              wei_kcyx_desc,
@@ -662,14 +684,21 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
 #if 1
-        if(Y == 3 && X == 3)
+        if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 &&
+           ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1)
         {
             host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);
         }
         else
 #endif
         {
-            host_direct_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);
+            host_direct_convolution(in_nchw,
+                                    wei_kcyx,
+                                    out_nkhw_host,
+                                    ConvStrides{},
+                                    ConvDilations{},
+                                    lower_pads,
+                                    upper_pads);
         }
 
         check_error(out_nkhw_host, out_nkhw_device);
...