Commit bc0c18ad authored by Jing Zhang's avatar Jing Zhang
Browse files

clean code for 1x1

parent fccf044a
......@@ -12,24 +12,26 @@ template <class T,
class Strides,
class Dilations,
index_t Direction>
void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kcyx,
OutDesc,
Strides,
Dilations,
Number<Direction>,
Tensor<T>& out_nkhw,
index_t nrepeat)
void device_convolution_implicit_gemm_v4_nchw_kc1x1_nkhw(InDesc,
Tensor<T>& in_nchw,
WeiDesc,
const Tensor<T>& wei_kc,
OutDesc,
Strides,
Dilations,
Number<Direction>,
Tensor<T>& out_nkhw,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_nchw_desc = InDesc{};
constexpr auto wei_kcyx_desc = WeiDesc{};
constexpr auto in_nchw_desc = InDesc{};
static_assert(WeiDesc{}.GetLength(I2) == 1, "1x1 filter only");
static_assert(WeiDesc{}.GetLength(I3) == 1, "1x1 filter only");
constexpr auto wei_kc_desc = WeiDesc{}.Extract(Sequence<0, 1>{});
constexpr auto out_nkhw_desc = OutDesc{};
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
......@@ -39,18 +41,16 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
constexpr index_t K = wei_kc_desc.GetLength(I0);
constexpr index_t C = wei_kc_desc.GetLength(I1);
std::size_t data_sz = sizeof(T);
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
DeviceMem wei_kc_device_buf(data_sz * wei_kc.mDesc.GetElementSpace());
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
wei_kc_device_buf.ToDevice(wei_kc.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
constexpr index_t N1 = 2;
......@@ -104,55 +104,51 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
#else
GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
#endif
<GridSize,
BlockSize,
Strides,
Dilations,
Direction,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
BPerBlock,
KPerBlock,
CPerBlock,
N1,
N2,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2,
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>{};
GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kc1x1_nkhw<
GridSize,
BlockSize,
Strides,
Dilations,
Direction,
T,
decltype(in_nchw_desc),
decltype(wei_kc_desc),
decltype(out_nkhw_desc),
BPerBlock,
KPerBlock,
CPerBlock,
N1,
N2,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmDataPerReadA,
GemmDataPerReadB,
InBlockCopySubLengths_E_N1_B_N2,
InBlockCopyClusterLengths_E_N1_B_N2,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
InBlockCopySrcDataPerRead_B,
InBlockCopyDstDataPerWrite_N2,
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>{};
float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kc_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
......
......@@ -752,7 +752,7 @@ int main(int argc, char* argv[])
#elif 0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
#elif 1
device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
device_convolution_implicit_gemm_v4_nchw_kc1x1_nkhw
#endif
(in_nchw_desc,
in_nchw_device,
......
......@@ -67,18 +67,18 @@ struct GetInGlobalMergeDesc<true, InType, N1, N2, Ho, Wo, Strides, Dilations>
constexpr auto in_n0_n1_n2_h_w_new_global_desc =
make_ConstantTensorDescriptor(in_lengths_new, in_strides_new);
constexpr auto in_c_1_1_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<1>{})
.Slice(I3, Number<1>{})
.Extract(Sequence<1, 2, 3>{});
constexpr auto in_c_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<1>{})
.Slice(I3, Number<1>{})
.Extract(Sequence<1, 2, 3>{});
constexpr auto in_win_lengths_new = Sequence<in_c_1_1_global_desc.GetLength(I0),
in_c_1_1_global_desc.GetLength(I1),
in_c_1_1_global_desc.GetLength(I2)>{};
constexpr auto in_win_lengths_new = Sequence<in_c_global_desc.GetLength(I0),
in_c_global_desc.GetLength(I1),
in_c_global_desc.GetLength(I2)>{};
constexpr auto in_win_strides_new =
Sequence<in_c_1_1_global_desc.GetStride(I0),
in_c_1_1_global_desc.GetStride(I1) * Dilations{}.Get(I0),
in_c_1_1_global_desc.GetStride(I2) * Dilations{}.Get(I1)>{};
Sequence<in_c_global_desc.GetStride(I0),
in_c_global_desc.GetStride(I1) * Dilations{}.Get(I0),
in_c_global_desc.GetStride(I2) * Dilations{}.Get(I1)>{};
constexpr auto in_c_1_1_new_global_desc =
make_ConstantTensorDescriptor(in_win_lengths_new, in_win_strides_new);
......@@ -122,11 +122,11 @@ struct GetInGlobalMergeDesc<false, InType, N1, N2, Ho, Wo, Strides, Dilations>
constexpr auto in_n0_n1_n2_h_w_new_global_desc = in_n0_n1_n2_h_w_global_desc;
constexpr auto in_c_1_1_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<1>{})
.Slice(I3, Number<1>{})
.Extract(Sequence<1, 2, 3>{});
constexpr auto in_c_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<1>{})
.Slice(I3, Number<1>{})
.Extract(Sequence<1, 2, 3>{});
constexpr auto in_c_1_1_new_global_desc = in_c_1_1_global_desc;
constexpr auto in_c_1_1_new_global_desc = in_c_global_desc;
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
......@@ -233,7 +233,7 @@ template <index_t GridSize,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K>
struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kc1x1_nkhw
{
__device__ void Run(Float* const __restrict__ p_conv_in_global,
const Float* const __restrict__ p_wei_global,
......@@ -266,7 +266,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
OutGlobalDescType<Direction == 1, InGlobalDesc, OutGlobalDesc>{}.Type;
// to-do: backward data: 1) ckyx: yx unfold, 2) merge cyx = e, 3 out = ek
constexpr auto wei_k_c_1_1_global_desc = WeiGlobalDesc{};
constexpr auto wei_k_c_global_desc = WeiGlobalDesc{};
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
......@@ -277,8 +277,8 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
// constexpr index_t Y = wei_k_c_1_1_global_desc.GetLength(I2);
// constexpr index_t X = wei_k_c_1_1_global_desc.GetLength(I3);
// constexpr index_t Y = wei_k_c_global_desc.GetLength(I2);
// constexpr index_t X = wei_k_c_global_desc.GetLength(I3);
static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread");
......@@ -306,9 +306,9 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// batch descritpor for device memory
// to-do: add dilation: keep lengths, modify strides
constexpr auto in_c_1_1_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<1>{})
.Slice(I3, Number<1>{})
.Extract(Sequence<1, 2, 3>{});
constexpr auto in_c_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<1>{})
.Slice(I3, Number<1>{})
.Extract(Sequence<1, 2, 3>{});
constexpr bool fwd = Direction == 1;
constexpr auto in_e_n1_b_n2_global_merged_desc =
GetInGlobalMergeDesc<fwd,
......@@ -352,7 +352,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc = wei_k_c_1_1_global_desc.Unfold(I1, I3);
constexpr auto wei_e_k_global_desc = wei_k_c_global_desc;
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment