Commit bc0c18ad authored by Jing Zhang
Browse files

clean code for 1x1

parent fccf044a
...@@ -12,24 +12,26 @@ template <class T, ...@@ -12,24 +12,26 @@ template <class T,
class Strides, class Strides,
class Dilations, class Dilations,
index_t Direction> index_t Direction>
void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, void device_convolution_implicit_gemm_v4_nchw_kc1x1_nkhw(InDesc,
Tensor<T>& in_nchw, Tensor<T>& in_nchw,
WeiDesc, WeiDesc,
const Tensor<T>& wei_kcyx, const Tensor<T>& wei_kc,
OutDesc, OutDesc,
Strides, Strides,
Dilations, Dilations,
Number<Direction>, Number<Direction>,
Tensor<T>& out_nkhw, Tensor<T>& out_nkhw,
index_t nrepeat) index_t nrepeat)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{}; constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{}; static_assert(WeiDesc{}.GetLength(I2) == 1, "1x1 filter only");
static_assert(WeiDesc{}.GetLength(I3) == 1, "1x1 filter only");
constexpr auto wei_kc_desc = WeiDesc{}.Extract(Sequence<0, 1>{});
constexpr auto out_nkhw_desc = OutDesc{}; constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t Hi = in_nchw_desc.GetLength(I2); constexpr index_t Hi = in_nchw_desc.GetLength(I2);
...@@ -39,18 +41,16 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, ...@@ -39,18 +41,16 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
constexpr index_t Ho = out_nkhw_desc.GetLength(I2); constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3); constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0); constexpr index_t K = wei_kc_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1); constexpr index_t C = wei_kc_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
std::size_t data_sz = sizeof(T); std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); DeviceMem wei_kc_device_buf(data_sz * wei_kc.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data()); in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); wei_kc_device_buf.ToDevice(wei_kc.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
constexpr index_t N1 = 2; constexpr index_t N1 = 2;
...@@ -104,55 +104,51 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc, ...@@ -104,55 +104,51 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
constexpr auto gridwise_conv = constexpr auto gridwise_conv =
#if 0 GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kc1x1_nkhw<
GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw GridSize,
#else BlockSize,
GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw Strides,
#endif Dilations,
<GridSize, Direction,
BlockSize, T,
Strides, decltype(in_nchw_desc),
Dilations, decltype(wei_kc_desc),
Direction, decltype(out_nkhw_desc),
T, BPerBlock,
decltype(in_nchw_desc), KPerBlock,
decltype(wei_kcyx_desc), CPerBlock,
decltype(out_nkhw_desc), N1,
BPerBlock, N2,
KPerBlock, GemmMPerThreadSubC,
CPerBlock, GemmNPerThreadSubC,
N1, GemmMLevel0Cluster,
N2, GemmNLevel0Cluster,
GemmMPerThreadSubC, GemmMLevel1Cluster,
GemmNPerThreadSubC, GemmNLevel1Cluster,
GemmMLevel0Cluster, GemmKPerThreadLoop,
GemmNLevel0Cluster, GemmDataPerReadA,
GemmMLevel1Cluster, GemmDataPerReadB,
GemmNLevel1Cluster, InBlockCopySubLengths_E_N1_B_N2,
GemmKPerThreadLoop, InBlockCopyClusterLengths_E_N1_B_N2,
GemmDataPerReadA, InBlockCopyThreadClusterArrangeOrder,
GemmDataPerReadB, InBlockCopySrcAccessOrder,
InBlockCopySubLengths_E_N1_B_N2, InBlockCopyDstAccessOrder,
InBlockCopyClusterLengths_E_N1_B_N2, InBlockCopySrcDataPerRead_B,
InBlockCopyThreadClusterArrangeOrder, InBlockCopyDstDataPerWrite_N2,
InBlockCopySrcAccessOrder, WeiBlockCopySubLengths_E_K,
InBlockCopyDstAccessOrder, WeiBlockCopyClusterLengths_E_K,
InBlockCopySrcDataPerRead_B, WeiBlockCopyThreadClusterArrangeOrder,
InBlockCopyDstDataPerWrite_N2, WeiBlockCopySrcAccessOrder,
WeiBlockCopySubLengths_E_K, WeiBlockCopyDstAccessOrder,
WeiBlockCopyClusterLengths_E_K, WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyThreadClusterArrangeOrder, WeiBlockCopyDstDataPerWrite_K>{};
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>{};
float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>, float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()), static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()), static_cast<T*>(wei_kc_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())); static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n", printf("Elapsed time : %f ms, %f TFlop/s\n",
......
...@@ -752,7 +752,7 @@ int main(int argc, char* argv[]) ...@@ -752,7 +752,7 @@ int main(int argc, char* argv[])
#elif 0 #elif 0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
#elif 1 #elif 1
device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw device_convolution_implicit_gemm_v4_nchw_kc1x1_nkhw
#endif #endif
(in_nchw_desc, (in_nchw_desc,
in_nchw_device, in_nchw_device,
......
...@@ -67,18 +67,18 @@ struct GetInGlobalMergeDesc<true, InType, N1, N2, Ho, Wo, Strides, Dilations> ...@@ -67,18 +67,18 @@ struct GetInGlobalMergeDesc<true, InType, N1, N2, Ho, Wo, Strides, Dilations>
constexpr auto in_n0_n1_n2_h_w_new_global_desc = constexpr auto in_n0_n1_n2_h_w_new_global_desc =
make_ConstantTensorDescriptor(in_lengths_new, in_strides_new); make_ConstantTensorDescriptor(in_lengths_new, in_strides_new);
constexpr auto in_c_1_1_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<1>{}) constexpr auto in_c_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<1>{})
.Slice(I3, Number<1>{}) .Slice(I3, Number<1>{})
.Extract(Sequence<1, 2, 3>{}); .Extract(Sequence<1, 2, 3>{});
constexpr auto in_win_lengths_new = Sequence<in_c_1_1_global_desc.GetLength(I0), constexpr auto in_win_lengths_new = Sequence<in_c_global_desc.GetLength(I0),
in_c_1_1_global_desc.GetLength(I1), in_c_global_desc.GetLength(I1),
in_c_1_1_global_desc.GetLength(I2)>{}; in_c_global_desc.GetLength(I2)>{};
constexpr auto in_win_strides_new = constexpr auto in_win_strides_new =
Sequence<in_c_1_1_global_desc.GetStride(I0), Sequence<in_c_global_desc.GetStride(I0),
in_c_1_1_global_desc.GetStride(I1) * Dilations{}.Get(I0), in_c_global_desc.GetStride(I1) * Dilations{}.Get(I0),
in_c_1_1_global_desc.GetStride(I2) * Dilations{}.Get(I1)>{}; in_c_global_desc.GetStride(I2) * Dilations{}.Get(I1)>{};
constexpr auto in_c_1_1_new_global_desc = constexpr auto in_c_1_1_new_global_desc =
make_ConstantTensorDescriptor(in_win_lengths_new, in_win_strides_new); make_ConstantTensorDescriptor(in_win_lengths_new, in_win_strides_new);
...@@ -122,11 +122,11 @@ struct GetInGlobalMergeDesc<false, InType, N1, N2, Ho, Wo, Strides, Dilations> ...@@ -122,11 +122,11 @@ struct GetInGlobalMergeDesc<false, InType, N1, N2, Ho, Wo, Strides, Dilations>
constexpr auto in_n0_n1_n2_h_w_new_global_desc = in_n0_n1_n2_h_w_global_desc; constexpr auto in_n0_n1_n2_h_w_new_global_desc = in_n0_n1_n2_h_w_global_desc;
constexpr auto in_c_1_1_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<1>{}) constexpr auto in_c_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<1>{})
.Slice(I3, Number<1>{}) .Slice(I3, Number<1>{})
.Extract(Sequence<1, 2, 3>{}); .Extract(Sequence<1, 2, 3>{});
constexpr auto in_c_1_1_new_global_desc = in_c_1_1_global_desc; constexpr auto in_c_1_1_new_global_desc = in_c_global_desc;
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
...@@ -233,7 +233,7 @@ template <index_t GridSize, ...@@ -233,7 +233,7 @@ template <index_t GridSize,
class WeiBlockCopyDstAccessOrder, class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E, index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K> index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kc1x1_nkhw
{ {
__device__ void Run(Float* const __restrict__ p_conv_in_global, __device__ void Run(Float* const __restrict__ p_conv_in_global,
const Float* const __restrict__ p_wei_global, const Float* const __restrict__ p_wei_global,
...@@ -266,7 +266,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -266,7 +266,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
OutGlobalDescType<Direction == 1, InGlobalDesc, OutGlobalDesc>{}.Type; OutGlobalDescType<Direction == 1, InGlobalDesc, OutGlobalDesc>{}.Type;
// to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3 out = ek // to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3 out = ek
constexpr auto wei_k_c_1_1_global_desc = WeiGlobalDesc{}; constexpr auto wei_k_c_global_desc = WeiGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0); constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1); constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
...@@ -277,8 +277,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -277,8 +277,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2); constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3); constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
// constexpr index_t Y = wei_k_c_1_1_global_desc.GetLength(I2); // constexpr index_t Y = wei_k_c_global_desc.GetLength(I2);
// constexpr index_t X = wei_k_c_1_1_global_desc.GetLength(I3); // constexpr index_t X = wei_k_c_global_desc.GetLength(I3);
static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread"); static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");
...@@ -306,9 +306,9 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -306,9 +306,9 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// batch descritpor for device memory // batch descritpor for device memory
// to-do: add dilation: keep lengths, modify strides // to-do: add dilation: keep lengths, modify strides
constexpr auto in_c_1_1_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<1>{}) constexpr auto in_c_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<1>{})
.Slice(I3, Number<1>{}) .Slice(I3, Number<1>{})
.Extract(Sequence<1, 2, 3>{}); .Extract(Sequence<1, 2, 3>{});
constexpr bool fwd = Direction == 1; constexpr bool fwd = Direction == 1;
constexpr auto in_e_n1_b_n2_global_merged_desc = constexpr auto in_e_n1_b_n2_global_merged_desc =
GetInGlobalMergeDesc<fwd, GetInGlobalMergeDesc<fwd,
...@@ -352,7 +352,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw ...@@ -352,7 +352,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// weight tensor // weight tensor
// tensor descriptor in device memory, src of blockwise copy // tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc = wei_k_c_1_1_global_desc.Unfold(I1, I3); constexpr auto wei_e_k_global_desc = wei_k_c_global_desc;
// tensor descriptor in LDS, dst of blockwise copy // tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment