#pragma once
#include "constant_tensor_descriptor.cuh"
#include "blockwise_winograd_transform.cuh"
#include "threadwise_winograd_transform.cuh"

// Gridwise Winograd convolution kernel (work in progress).
//
// Descriptor indexing used below:
//   input  (n, c, hi, wi), weights (k, c, s, r), output (n, k, ho, wo).
// S and R (filter height/width) are read from dims 2 and 3 of the weight descriptor.
//
// Work split:
//   * Each block owns an NPerBlock x KPerBlock x HoPerBlock x WoPerBlock output tile, where
//     HoPerBlock = OutTileSizeH * YPerBlock and WoPerBlock = OutTileSizeW * XPerBlock.
//     blockIdx.x is decomposed into (n, k, y, x) block-work ids, n slowest / x fastest.
//   * Inside a block, threadIdx.x is decomposed the same way into thread-work ids; each
//     thread owns NPerThread x KPerThread output channels of one OutTileSizeH x OutTileSizeW
//     Winograd tile.
//
// Per CPerBlock-wide slice of input channels, the block transforms its tiles into shared
// memory, then every thread accumulates its transformed output in per-thread arrays. At the
// end each thread reverse-transforms its tile and writes it to global memory.
//
// NOTE(review): preconditions that this kernel does not check itself:
//   * gridDim.x  == NBlockWork * KBlockWork * YBlockWork * XBlockWork
//   * blockDim.x == NThreadWork * KThreadWork * YThreadWork * XThreadWork
//   * the channel count is a multiple of CPerBlock and the output lengths are multiples
//     of the per-block tile sizes. The block-work counts are rounded up, but no partial-tile
//     guard is visible here. TODO confirm against the launcher / threadwise_4d_tensor_copy.
// NOTE(review): in this copy of the file the template parameter list of the kernel and
// several template argument lists (Sequence<...>{}, the blockwise/threadwise helper calls)
// are missing; restore them from version control.
template
__global__ void gridwise_winograd_convolution(InGlobalDesc,
                                              TFloat* const __restrict__ p_in_global,
                                              WeiGlobalDesc,
                                              TFloat* const __restrict__ p_wei_global,
                                              OutGlobalDesc,
                                              TFloat* __restrict__ p_out_global)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_global_desc  = InGlobalDesc{};
    constexpr auto wei_global_desc = WeiGlobalDesc{};
    constexpr auto out_global_desc = OutGlobalDesc{};

    constexpr unsigned S = wei_global_desc.GetLength(I2);
    constexpr unsigned R = wei_global_desc.GetLength(I3);

    // output / input spatial extent covered by one block
    constexpr unsigned HoPerBlock = OutTileSizeH * YPerBlock;
    constexpr unsigned WoPerBlock = OutTileSizeW * XPerBlock;

    constexpr unsigned HiPerBlock = HoPerBlock + S - 1;
    constexpr unsigned WiPerBlock = WoPerBlock + R - 1;

    // spatial size of one Winograd input tile
    constexpr unsigned InTileSizeH = OutTileSizeH + S - 1;
    constexpr unsigned InTileSizeW = OutTileSizeW + R - 1;

    // divide block work (ceil-div over each output dimension)
    constexpr unsigned NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
    constexpr unsigned KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
    constexpr unsigned YBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
    constexpr unsigned XBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;

    const unsigned block_id = blockIdx.x;

    unsigned itmp                  = block_id;
    const unsigned n_block_work_id = itmp / (KBlockWork * YBlockWork * XBlockWork);
    itmp -= n_block_work_id * (KBlockWork * YBlockWork * XBlockWork);
    const unsigned k_block_work_id = itmp / (YBlockWork * XBlockWork);
    itmp -= k_block_work_id * (YBlockWork * XBlockWork);
    const unsigned y_block_work_id = itmp / XBlockWork;
    const unsigned x_block_work_id = itmp - y_block_work_id * XBlockWork;

    const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
    const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
    const unsigned y_block_data_begin = y_block_work_id * YPerBlock;
    const unsigned x_block_data_begin = x_block_work_id * XPerBlock;

    const unsigned ho_block_data_begin = y_block_data_begin * OutTileSizeH;
    const unsigned wo_block_data_begin = x_block_data_begin * OutTileSizeW;

    // minus padding (no padding is subtracted yet; only used by the disabled input transform)
    const unsigned hi_block_data_begin = ho_block_data_begin;
    const unsigned wi_block_data_begin = wo_block_data_begin;

    // divide thread work
    constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
    constexpr unsigned KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
    constexpr unsigned YThreadWork = YPerBlock;
    constexpr unsigned XThreadWork = XPerBlock;

    const unsigned thread_id = threadIdx.x;

    itmp                            = thread_id;
    const unsigned n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
    itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork);
    const unsigned k_thread_work_id = itmp / (YThreadWork * XThreadWork);
    itmp -= k_thread_work_id * (YThreadWork * XThreadWork);
    const unsigned y_thread_work_id = itmp / XThreadWork;
    const unsigned x_thread_work_id = itmp - y_thread_work_id * XThreadWork;

    const unsigned n_thread_data_begin = n_thread_work_id * NPerThread;
    const unsigned k_thread_data_begin = k_thread_work_id * KPerThread;
    const unsigned y_thread_data_begin = y_thread_work_id;
    const unsigned x_thread_data_begin = x_thread_work_id;

    // block data: transformed input / weight tiles shared by the whole block
    constexpr auto in_transform_block_desc = make_ConstantTensorDescriptor(Sequence{});

    constexpr auto wei_transform_block_desc = make_ConstantTensorDescriptor(Sequence{});

    __shared__ TFloat p_in_transform_block[in_transform_block_desc.GetElementSpace()];
    __shared__ TFloat p_wei_transform_block[wei_transform_block_desc.GetElementSpace()];

    // thread data: per-thread views that reuse the block descriptors' strides
    // (presumably sub-tiles of the shared buffers — TODO confirm lengths once restored)
    constexpr auto in_transform_thread_block_desc =
        make_ConstantTensorDescriptor(Sequence{}, in_transform_block_desc.GetStrides());

    constexpr auto wei_transform_thread_block_desc =
        make_ConstantTensorDescriptor(Sequence{}, wei_transform_block_desc.GetStrides());

    constexpr auto out_transform_thread_desc = make_ConstantTensorDescriptor(Sequence{});

    constexpr auto out_thread_desc = make_ConstantTensorDescriptor(Sequence{});

    // same lengths as the per-thread output, but with global-memory strides
    constexpr auto out_thread_global_desc =
        make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_global_desc.GetStrides());

    // per-thread (automatic) accumulation and output buffers
    TFloat p_out_transform_thread[out_transform_thread_desc.GetElementSpace()];
    TFloat p_out_thread[out_thread_desc.GetElementSpace()];

#if 0
    if(blockIdx.x == 0 && threadIdx.x == 0)
    {
        printf("in_transform_block_size %u, wei_transform_block_size %u, out_transform_thread_size "
               "%u, out_thread_size %u \n",
               unsigned(in_transform_block_desc.GetElementSpace()),
               unsigned(wei_transform_block_desc.GetElementSpace()),
               unsigned(out_transform_thread_desc.GetElementSpace()),
               unsigned(out_thread_desc.GetElementSpace()));
    }
#endif

    // set threadwise output transform tensor to 0
    threadwise_4d_tensor_set_zero(out_transform_thread_desc, p_out_transform_thread);

    // The loop bound is a compile-time descriptor length, so every thread of the block runs
    // the same number of iterations and reaches both barriers below.
    for(unsigned c_block_data_begin = 0; c_block_data_begin < in_global_desc.GetLength(I1);
        c_block_data_begin += CPerBlock)
    {
#if 0
        // blockwise transform input
        blockwise_winograd_transform_input(
            p_in_global + in_global_desc.Get1dIndex(n_block_data_begin,
                                                    c_block_data_begin,
                                                    hi_block_data_begin,
                                                    wi_block_data_begin),
            p_in_transform_block);
#endif
        // NOTE(review): with the input transform disabled above, p_in_transform_block is
        // never written, so the multiplication below reads uninitialized shared memory.

        // blockwise transform weights
        blockwise_winograd_transform_weight(
            p_wei_global + wei_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
            p_wei_transform_block);

        // Make the block's shared-memory writes visible before any thread reads them: each
        // thread below reads at offsets derived from its own k/n/y/x ids, which generally
        // differ from the entries it wrote itself.
        __syncthreads();

        for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
        {
            // threadwise point multiplication
            threadwise_winograd_calculate_transformed_output<
                TFloat,
                decltype(in_transform_thread_block_desc),
                decltype(wei_transform_thread_block_desc),
                decltype(out_transform_thread_desc),
                InTileSizeH,
                InTileSizeW,
                S,
                R,
                OutTileSizeH,
                OutTileSizeW>(in_transform_thread_block_desc,
                              p_in_transform_block +
                                  in_transform_block_desc.Get1dIndex(
                                      n_thread_data_begin,
                                      c_thread_data,
                                      y_thread_data_begin * InTileSizeH,
                                      x_thread_data_begin * InTileSizeW),
                              wei_transform_thread_block_desc,
                              p_wei_transform_block +
                                  wei_transform_block_desc.Get1dIndex(
                                      k_thread_data_begin, c_thread_data, 0, 0),
                              out_transform_thread_desc,
                              p_out_transform_thread);
        }

        // Every thread must finish reading the shared tiles before the next iteration
        // overwrites them.
        __syncthreads();
    }

    // transform back
    threadwise_winograd_reverse_transform_output(
        out_transform_thread_desc, p_out_transform_thread, out_thread_desc, p_out_thread);

    // copy output tensor from the per-thread buffer to global mem
    threadwise_4d_tensor_copy(
        out_thread_desc,
        p_out_thread,
        out_thread_global_desc,
        p_out_global +
            out_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
                                       k_block_data_begin + k_thread_data_begin,
                                       ho_block_data_begin + y_thread_data_begin * OutTileSizeH,
                                       wo_block_data_begin + x_thread_data_begin * OutTileSizeW));
}