Commit d2490b49 authored by Chao Liu's avatar Chao Liu
Browse files

rename

parent d78fe365
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
......@@ -41,7 +41,7 @@ template <index_t GridSize,
typename WeiBlockCopyClusterLengths_K_E,
index_t WeiBlockCopyDataPerAccess_E,
index_t InThreadCopyDataPerAccess_B>
struct GridwiseConvolutionBackwardDataImplicitGemm_v1_nchw_kcyx_nkhw_lds_double_buffer
struct GridwiseConvolutionBackwardDataImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
{
__device__ void Run(Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
......
......@@ -3,7 +3,7 @@
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v1_nchw_kcyx_nkhw_lds_double_buffer.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp"
template <typename T,
typename InDesc,
......@@ -13,17 +13,17 @@ template <typename T,
typename ConvDilations,
typename LeftPads,
typename RightPads>
void device_convolution_bwd_data_implicit_gemm_v1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
Tensor<T>& in_nchw,
WeiDesc wei_kcyx_desc,
const Tensor<T>& wei_kcyx,
OutDesc out_nkhw_desc,
const Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
LeftPads,
RightPads,
std::size_t nrepeat)
void device_convolution_backward_data_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc in_nchw_desc,
Tensor<T>& in_nchw,
WeiDesc wei_kcyx_desc,
const Tensor<T>& wei_kcyx,
OutDesc out_nkhw_desc,
const Tensor<T>& out_nkhw,
ConvStrides,
ConvDilations,
LeftPads,
RightPads,
std::size_t nrepeat)
{
using namespace ck;
......@@ -85,7 +85,7 @@ void device_convolution_bwd_data_implicit_gemm_v1_nchw_kcyx_nkhw(InDesc in_nchw_
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto gridwise_conv =
GridwiseConvolutionBackwardDataImplicitGemm_v1_nchw_kcyx_nkhw_lds_double_buffer<
GridwiseConvolutionBackwardDataImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer<
GridSize,
BlockSize,
T,
......
......@@ -13,7 +13,7 @@
#include "device_tensor.hpp"
#include "conv_common.hpp"
#include "host_conv_bwd_data.hpp"
#include "device_convolution_bwd_data_implicit_gemm_v1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
int main(int argc, char* argv[])
{
......@@ -96,7 +96,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
#elif 1
// 1x1 filter, 8x8 image
// cudnn@V100 83%, ck@V100 75%, ck@P100 78%, ck@VII 65%
constexpr index_t N = 128;
......@@ -344,19 +344,17 @@ int main(int argc, char* argv[])
#endif
}
#if 1
device_convolution_bwd_data_implicit_gemm_v1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw_device,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
#endif
device_convolution_backward_data_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw_device,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
if(do_verification)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment