Commit 249c5d6d authored by Chao Liu's avatar Chao Liu
Browse files

nvidia build

parent ea8aa63f
......@@ -76,7 +76,10 @@ void launch_kernel(F kernel,
cudaStream_t stream_id,
Args... args)
{
cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id);
const void* f = reinterpret_cast<const void*>(kernel);
void* p_args[] = {&args...};
cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id);
}
template <typename... Args, typename F>
......
......@@ -5,6 +5,10 @@
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
namespace launcher {
using namespace ck;
template <typename T,
typename InDesc,
typename WeiDesc,
......@@ -121,22 +125,18 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
for(index_t i = 0; i < nrepeat; ++i)
{
float time =
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
gridwise_conv,
const_cast<T* const __restrict__>(
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
const_cast<const T* const __restrict__>(
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
const_cast<const T* const __restrict__>(
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
gridwise_conv,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
......@@ -147,3 +147,5 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
}
} // namespace launcher
......@@ -5,6 +5,10 @@
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp"
namespace launcher {
using namespace ck;
template <typename T,
typename InDesc,
typename WeiDesc,
......@@ -129,21 +133,18 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
for(index_t i = 0; i < nrepeat; ++i)
{
float time =
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
gridwise_conv,
const_cast<T* const __restrict__>(
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
const_cast<const T* const __restrict__>(
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
const_cast<const T* const __restrict__>(
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
gridwise_conv,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
......@@ -154,3 +155,5 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
}
} // namespace launcher
......@@ -5,6 +5,10 @@
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
namespace launcher {
using namespace ck;
template <typename T,
typename InDesc,
typename WeiDesc,
......@@ -217,21 +221,18 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
for(index_t i = 0; i < nrepeat; ++i)
{
float time =
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
gridwise_conv,
const_cast<T* const __restrict__>(
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
const_cast<const T* const __restrict__>(
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
const_cast<const T* const __restrict__>(
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
gridwise_conv,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
......@@ -242,3 +243,5 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
}
} // namespace launcher
......@@ -5,6 +5,10 @@
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp"
namespace launcher {
using namespace ck;
template <typename T,
typename InDesc,
typename WeiDesc,
......@@ -84,36 +88,6 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
// BlockSize = 256, each thread hold 64 data
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmThreadGemmDataPerReadM = 4;
constexpr index_t GemmThreadGemmDataPerReadN = 4;
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif
......@@ -186,21 +160,18 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
for(index_t i = 0; i < nrepeat; ++i)
{
float time =
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
gridwise_conv,
const_cast<T* const __restrict__>(
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
const_cast<const T* const __restrict__>(
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
const_cast<const T* const __restrict__>(
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
gridwise_conv,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms, %f TFlop/s\n",
time,
......@@ -211,3 +182,5 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
}
} // namespace launcher
......@@ -5,6 +5,10 @@
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
namespace launcher {
using namespace ck;
template <typename T,
typename InDesc,
typename WeiDesc,
......@@ -25,8 +29,6 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
InRightPads,
std::size_t nrepeat)
{
using namespace ck;
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
......@@ -207,12 +209,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
0,
0,
gridwise_conv,
const_cast<T* const __restrict__>(
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
const_cast<const T* const __restrict__>(
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
const_cast<const T* const __restrict__>(
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
});
});
......@@ -229,3 +228,5 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
}
} // namespace launcher
......@@ -21,12 +21,12 @@
int main(int argc, char* argv[])
{
using namespace ck;
using namespace launcher;
#if 1
// 3x3 filter, 2x2 stride, 35x35 input
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t C = 128;
constexpr index_t HI = 35;
constexpr index_t WI = 35;
constexpr index_t K = 128;
......@@ -253,7 +253,7 @@ int main(int argc, char* argv[])
#elif 0
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
#elif 0
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw
#else
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment