Commit b97af4ec authored by Jehandad Khan's avatar Jehandad Khan
Browse files

merge forward and wrw code in kernel

parent 59252249
......@@ -11,6 +11,37 @@
namespace ck {
enum struct ConvolutionDirection { Forward, BackwardWeights};
template<ConvolutionDirection conv_dir, typename WeiDesc>
struct make_WeiDesc
{
};
template<typename WeiDesc>
struct make_WeiDesc <ConvolutionDirection::Forward, WeiDesc>
{
__device__ constexpr
auto
get(WeiDesc& desc)
{
constexpr auto I1 = Number<1>{};
constexpr auto I3 = Number<3>{};
return WeiDesc{}.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
}
};
template<typename WeiDesc>
struct make_WeiDesc<ConvolutionDirection::BackwardWeights,WeiDesc>
{
__device__ constexpr
auto
get(WeiDesc& desc)
{
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
return make_ConstantMergedTensorDescriptor(
desc.Unfold(I2, I3), Sequence<1, 2>{}, Sequence<0>{});
}
};
// define B = merge(N0, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
......@@ -48,7 +79,8 @@ template <index_t GridSize,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K,
index_t OutThreadCopyDataPerAccess_W>
index_t OutThreadCopyDataPerAccess_W,
ConvolutionDirection conv_dir>
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(const Float* const __restrict__ p_in_global,
......@@ -196,8 +228,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_merged_desc = make_ConstantMergedTensorDescriptor(
wei_k_c_y_x_global_desc.Unfold(I2, I3), Sequence<1, 2>{}, Sequence<0>{});
constexpr auto wei_e_k_global_desc = make_WeiDesc<conv_dir, decltype(wei_k_c_y_x_global_desc)>{}.get(wei_k_c_y_x_global_desc);
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
......@@ -211,7 +242,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
// this copy operator already have blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
decltype(wei_e_k_global_merged_desc),
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
......@@ -336,7 +367,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
static_if<conv_dir == ConvolutionDirection::BackwardWeights>{}(
[&] (auto fwd){
fwd(blockwise_wei_copy).MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
}).Else( [&](auto fwd) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
});
__syncthreads();
......@@ -356,12 +393,16 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
// LDS double buffer: tail
{
// even iteration
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
static_if<conv_dir == ConvolutionDirection::BackwardWeights>{}(
[&] (auto){
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
}).Else( [&] (auto fwd){
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
});
__syncthreads();
......
......@@ -217,7 +217,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K,
OutThreadCopyDataPerAccess_W>{};
OutThreadCopyDataPerAccess_W, ConvolutionDirection::BackwardWeights>{};
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment