"test/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "d4fe76e7e44e0887c8d0e4e0ed0b6f0a876adfd9"
Commit b97af4ec authored by Jehandad Khan
Browse files

merge forward and wrw code in kernel

parent 59252249
...@@ -11,6 +11,37 @@ ...@@ -11,6 +11,37 @@
namespace ck { namespace ck {
// Compile-time tag naming the convolution pass a kernel is instantiated for.
// It is used as a non-type template argument to pick a make_WeiDesc
// specialization.
enum class ConvolutionDirection
{
    Forward,        // first enumerator (underlying value 0)
    BackwardWeights // second enumerator (underlying value 1)
};
// Primary template of the weight-descriptor factory.
// It has no members: get() is supplied only by the partial specializations
// on ConvolutionDirection, so an unspecialized direction has no get() to call.
template <ConvolutionDirection conv_dir, class WeiDesc>
struct make_WeiDesc
{
};
// Weight-descriptor factory for ConvolutionDirection::Forward.
template <class WeiDesc>
struct make_WeiDesc<ConvolutionDirection::Forward, WeiDesc>
{
    // Builds the descriptor by calling Unfold(Number<1>, Number<3>) and then
    // ReorderGivenNew2Old(Sequence<1, 0>) on a default-constructed WeiDesc.
    //
    // desc: accepted so the signature matches the BackwardWeights
    //       specialization, but never read here; only the type WeiDesc is used.
    //
    // NOTE(review): given the call site names (wei_k_c_y_x_global_desc ->
    // wei_e_k_global_desc) this presumably folds C, Y, X into one E dimension
    // and swaps the two remaining dimensions into (E, K) order - confirm
    // against the Unfold / ReorderGivenNew2Old definitions.
    __device__ constexpr auto get(WeiDesc& desc)
    {
        constexpr auto idx1 = Number<1>{};
        constexpr auto idx3 = Number<3>{};

        return WeiDesc{}
            .Unfold(idx1, idx3)
            .ReorderGivenNew2Old(Sequence<1, 0>{});
    }
};
// Weight-descriptor factory for ConvolutionDirection::BackwardWeights.
template <class WeiDesc>
struct make_WeiDesc<ConvolutionDirection::BackwardWeights, WeiDesc>
{
    // Builds a merged tensor descriptor from desc.Unfold(Number<2>, Number<3>),
    // passing Sequence<1, 2> and Sequence<0> as the two trailing arguments of
    // make_ConstantMergedTensorDescriptor.
    //
    // desc: the weight descriptor to unfold; unlike the Forward
    //       specialization, the argument itself is used here.
    //
    // NOTE(review): the two Sequence arguments presumably list which unfolded
    // dimensions are merged into each output dimension (E from dims 1 and 2,
    // K from dim 0) - confirm against make_ConstantMergedTensorDescriptor.
    __device__ constexpr auto get(WeiDesc& desc)
    {
        constexpr auto idx2 = Number<2>{};
        constexpr auto idx3 = Number<3>{};

        return make_ConstantMergedTensorDescriptor(desc.Unfold(idx2, idx3),
                                                   Sequence<1, 2>{},
                                                   Sequence<0>{});
    }
};
// define B = merge(N0, Ho, Wo) // define B = merge(N0, Ho, Wo)
template <index_t GridSize, template <index_t GridSize,
index_t BlockSize, index_t BlockSize,
...@@ -48,7 +79,8 @@ template <index_t GridSize, ...@@ -48,7 +79,8 @@ template <index_t GridSize,
class WeiBlockCopyDstAccessOrder, class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E, index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K, index_t WeiBlockCopyDstDataPerWrite_K,
index_t OutThreadCopyDataPerAccess_W> index_t OutThreadCopyDataPerAccess_W,
ConvolutionDirection conv_dir>
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
{ {
__device__ void Run(const Float* const __restrict__ p_in_global, __device__ void Run(const Float* const __restrict__ p_in_global,
...@@ -196,8 +228,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -196,8 +228,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
// weight tensor // weight tensor
// tensor descriptor in device memory, src of blockwise copy // tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_merged_desc = make_ConstantMergedTensorDescriptor( constexpr auto wei_e_k_global_desc = make_WeiDesc<conv_dir, decltype(wei_k_c_y_x_global_desc)>{}.get(wei_k_c_y_x_global_desc);
wei_k_c_y_x_global_desc.Unfold(I2, I3), Sequence<1, 2>{}, Sequence<0>{});
// tensor descriptor in LDS, dst of blockwise copy // tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment // be careful of LDS alignment
...@@ -211,7 +242,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -211,7 +242,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
// this copy operator already have blockwise offset built-in // this copy operator already have blockwise offset built-in
auto blockwise_wei_copy = auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize, BlockwiseGenericTensorSliceCopy_v1<BlockSize,
decltype(wei_e_k_global_merged_desc), decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc), decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()), decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K, WeiBlockCopySubLengths_E_K,
...@@ -336,7 +367,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -336,7 +367,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True); blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True); static_if<conv_dir == ConvolutionDirection::BackwardWeights>{}(
[&] (auto fwd){
fwd(blockwise_wei_copy).MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
}).Else( [&](auto fwd) {
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
});
__syncthreads(); __syncthreads();
...@@ -356,12 +393,16 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer ...@@ -356,12 +393,16 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
// LDS double buffer: tail // LDS double buffer: tail
{ {
// even iteration
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
// even iteration
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True); blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True); static_if<conv_dir == ConvolutionDirection::BackwardWeights>{}(
[&] (auto){
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
}).Else( [&] (auto fwd){
p_wei_block_on_global += EPerBlock * fwd(wei_e_k_global_desc).GetStride(I0);
});
__syncthreads(); __syncthreads();
......
...@@ -217,7 +217,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, ...@@ -217,7 +217,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
WeiBlockCopyDstAccessOrder, WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E, WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K, WeiBlockCopyDstDataPerWrite_K,
OutThreadCopyDataPerAccess_W>{}; OutThreadCopyDataPerAccess_W, ConvolutionDirection::BackwardWeights>{};
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>, float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
dim3(GridSize), dim3(GridSize),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment