Commit d8c89b68 authored by Chao Liu
Browse files

refactor driver for conv

parent fd160c63
This diff is collapsed.
...@@ -184,6 +184,27 @@ struct TensorAdaptor ...@@ -184,6 +184,27 @@ struct TensorAdaptor
return get_container_subset(idx_hidden, BottomDimensionHiddenIds{}); return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
} }
// Debug helper: dumps this adaptor through printf. Declared __host__ __device__, so it can be
// called from either side. The whole dump is wrapped in "{" ... "}".
// NOTE(review): no newline is emitted anywhere in this function; callers presumably add their
// own line break - confirm.
__host__ __device__ void Print() const
{
printf("{");
printf("TensorAdaptor, ");
// One section per transform index i. static_for presumably expands the lambda for every
// i in [0, ntransform_) with step 1 - TODO confirm against static_for's definition.
static_for<0, ntransform_, 1>{}([&](auto i) {
printf("transforms: ");
// delegate to the i-th transform's own Print()
transforms_[i].Print();
printf("LowerDimensionHiddenIds:");
// i-th entry of the per-transform lower hidden-id lists
LowerDimensionHiddenIdss{}.At(i).Print();
printf("UpperDimensionHiddenIds:");
// i-th entry of the per-transform upper hidden-id lists
UpperDimensionHiddenIdss{}.At(i).Print();
});
// The bottom/top hidden-id lists are printed via static Print() calls on their types.
printf("BottomDimensionHiddenIds:");
BottomDimensionHiddenIds::Print();
printf("TopDimensionHiddenIds:");
TopDimensionHiddenIds::Print();
printf("}");
}
private: private:
Transforms transforms_; Transforms transforms_;
ElementSize element_size_; ElementSize element_size_;
......
...@@ -12,7 +12,36 @@ ...@@ -12,7 +12,36 @@
namespace ck { namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
// Kernel entry point for the dynamic GEMM, variant that receives the tensor descriptors
// by value as kernel arguments.
//
// Template parameters:
//   GridwiseGemm            - type that is value-initialized here and whose Run() does the work
//   AGlobalDesc/BGlobalDesc/CGlobalDesc/CBlockClusterDesc - descriptor types passed by value
//   FloatA/FloatB/FloatC    - element types behind the three data pointers
//   HasMainKBlockLoop, HasDoubleTailKBlockLoop - compile-time flags, forwarded to Run() as
//                             integral_constant<bool, ...> tag objects
//
// The kernel itself contains no logic beyond forwarding every argument, in the same order,
// to GridwiseGemm::Run().
template <typename GridwiseGemm,
          typename AGlobalDesc,
          typename FloatA,
          typename BGlobalDesc,
          typename FloatB,
          typename CGlobalDesc,
          typename FloatC,
          typename CBlockClusterDesc,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void kernel_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc,
                                       const FloatA* __restrict__ p_a_global,
                                       const BGlobalDesc b_k_n_global_desc,
                                       const FloatB* __restrict__ p_b_global,
                                       const CGlobalDesc c_m0_m1_n0_n1_global_desc,
                                       FloatC* __restrict__ p_c_global,
                                       const CBlockClusterDesc c_block_cluster_desc)
{
    // Tag types that carry the two boolean template flags into Run() as types.
    using MainKBlockLoopTag       = integral_constant<bool, HasMainKBlockLoop>;
    using DoubleTailKBlockLoopTag = integral_constant<bool, HasDoubleTailKBlockLoop>;

    GridwiseGemm gridwise_gemm{};

    gridwise_gemm.Run(a_k_m_global_desc,
                      p_a_global,
                      b_k_n_global_desc,
                      p_b_global,
                      c_m0_m1_n0_n1_global_desc,
                      p_c_global,
                      c_block_cluster_desc,
                      MainKBlockLoopTag{},
                      DoubleTailKBlockLoopTag{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by __CONSTANT__ void pointer // pass tensor descriptor by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to // __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
// non-modifiable parameter address space, so compiler can enable corresponding optimization // non-modifiable parameter address space, so compiler can enable corresponding optimization
...@@ -26,13 +55,13 @@ template <typename GridwiseGemm, ...@@ -26,13 +55,13 @@ template <typename GridwiseGemm,
typename CBlockClusterDesc, typename CBlockClusterDesc,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc, __global__ void kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc,
const FloatA* __restrict__ p_a_global, const FloatA* __restrict__ p_a_global,
const void __CONSTANT__* p_b_k_n_global_desc, const void __CONSTANT__* p_b_k_n_global_desc,
const FloatB* __restrict__ p_b_global, const FloatB* __restrict__ p_b_global,
const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc, const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
const void __CONSTANT__* p_c_block_cluster_desc) const void __CONSTANT__* p_c_block_cluster_desc)
{ {
// first cast void __CONSTANT__ void* to void* // first cast void __CONSTANT__ void* to void*
// second cast void* to Desc* // second cast void* to Desc*
......
...@@ -46,6 +46,7 @@ void launch_kernel(F kernel, ...@@ -46,6 +46,7 @@ void launch_kernel(F kernel,
template <typename... Args, typename F> template <typename... Args, typename F>
float launch_and_time_kernel(F kernel, float launch_and_time_kernel(F kernel,
int nrepeat,
dim3 grid_dim, dim3 grid_dim,
dim3 block_dim, dim3 block_dim,
std::size_t lds_byte, std::size_t lds_byte,
...@@ -54,15 +55,32 @@ float launch_and_time_kernel(F kernel, ...@@ -54,15 +55,32 @@ float launch_and_time_kernel(F kernel,
{ {
KernelTimer timer; KernelTimer timer;
timer.Start(); printf("%s: block_dim {%d, %d, %d}, grid_dim {%d, %d, %d} \n",
__func__,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z);
printf("Warm up\n");
// warm up
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
timer.End(); printf("Start running %d times...\n", nrepeat);
timer.Start();
hipGetLastError(); for(int i = 0; i < nrepeat; ++i)
{
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
}
return timer.GetElapsedTime(); timer.End();
return timer.GetElapsedTime() / nrepeat;
} }
#elif CK_DEVICE_BACKEND_NVIDIA #elif CK_DEVICE_BACKEND_NVIDIA
......
...@@ -29,8 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( ...@@ -29,8 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
{ {
using namespace ck; using namespace ck;
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw" std::cout << __func__ << std::endl;
<< std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
...@@ -459,50 +468,91 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( ...@@ -459,50 +468,91 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
#endif #endif
constexpr auto conv_driver = constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
const auto descs =
#if 1 #if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad
#elif 0 #elif 0
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad
#elif 1 #else
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1
#endif #endif
<BlockSize, <GemmMPerBlock, GemmNPerBlock, GemmM1, GemmN1>(wei_k_c_y_x_desc,
typename vector_type<TInWei, InWeiVectorSize>::type, in_n_c_hi_wi_desc,
TAcc, out_n_k_ho_wo_desc,
TOut, conv_strides,
GemmMPerBlock, conv_dilations,
GemmNPerBlock, in_left_pads,
GemmKPerBlock, in_right_pads);
GemmMPerThread,
GemmNPerThread, float ave_time = launch_kernel_dynamic_gemm_v1<
GemmKPerThread, BlockSize,
GemmMLevel0Cluster, typename vector_type<TInWei, InWeiVectorSize>::type,
GemmNLevel0Cluster, TAcc,
GemmMLevel1Cluster, TOut,
GemmNLevel1Cluster, InMemoryDataOperation::Set,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM, decltype(descs[I0]),
GemmABlockTransferThreadClusterLengths_GemmK_GemmM, decltype(descs[I1]),
GemmABlockTransferSrcScalarPerVector_GemmK, decltype(descs[I2]),
GemmABlockTransferDstScalarPerVector_GemmM, decltype(descs[I3]),
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, GemmMPerBlock,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, GemmNPerBlock,
GemmBBlockTransferSrcScalarPerVector_GemmN, GemmKPerBlock,
GemmBBlockTransferDstScalarPerVector_GemmN, GemmMPerThread,
GemmCThreadTransferDstScalarPerVector_GemmN1>{}; GemmNPerThread,
GemmKPerThread,
conv_driver.Run(wei_k_c_y_x_desc, GemmMLevel0Cluster,
in_n_c_hi_wi_desc, GemmNLevel0Cluster,
out_n_k_ho_wo_desc, GemmMLevel1Cluster,
conv_strides, GemmNLevel1Cluster,
conv_dilations, GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
in_left_pads, GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
in_right_pads, Sequence<1, 0>,
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( Sequence<1, 0>,
wei_k_c_y_x_device_buf.GetDeviceBuffer()), 0,
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( GemmABlockTransferSrcScalarPerVector_GemmK,
in_n_c_hi_wi_device_buf.GetDeviceBuffer()), GemmABlockTransferDstScalarPerVector_GemmM,
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer())); false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(descs[I4]),
decltype(descs[I5]),
decltype(descs[I6]),
decltype(descs[I7]),
decltype(descs[I8])>(static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
descs[I0],
descs[I1],
descs[I2],
descs[I3],
descs[I4],
descs[I5],
descs[I6],
descs[I7],
descs[I8],
nrepeat);
float perf = (float)calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
// copy result back to host
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
} }
...@@ -29,13 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -29,13 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
{ {
using namespace ck; using namespace ck;
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk" std::cout << __func__ << std::endl;
<< std::endl;
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
constexpr auto N = OutDesc::GetLengths()[I0]; constexpr auto N = OutDesc::GetLengths()[I0];
constexpr auto K = OutDesc::GetLengths()[I1]; constexpr auto K = OutDesc::GetLengths()[I1];
...@@ -372,51 +376,89 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -372,51 +376,89 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
#endif #endif
constexpr auto conv_driver = constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
const auto descs =
#if 1 #if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad
#elif 0 #else
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_no_pad transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1
#elif 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
#endif #endif
<BlockSize, <GemmMPerBlock, GemmNPerBlock, GemmM1, GemmN1>(wei_k_y_x_c0_desc,
typename vector_type<TInWei, InWeiVectorSize>::type, in_n_hi_wi_c0_desc,
TAcc, out_n_ho_wo_k_desc,
TOut, conv_strides,
GemmMPerBlock, conv_dilations,
GemmNPerBlock, in_left_pads,
GemmKPerBlock, in_right_pads);
GemmMPerThread,
GemmNPerThread, float ave_time = launch_kernel_dynamic_gemm_v1<
GemmKPerThread, BlockSize,
GemmMLevel0Cluster, typename vector_type<TInWei, InWeiVectorSize>::type,
GemmNLevel0Cluster, TAcc,
GemmMLevel1Cluster, TOut,
GemmNLevel1Cluster, InMemoryDataOperation::Set,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM, decltype(descs[I0]),
GemmABlockTransferThreadClusterLengths_GemmK_GemmM, decltype(descs[I1]),
GemmABlockTransferSrcScalarPerVector_GemmK, decltype(descs[I2]),
GemmABlockTransferDstScalarPerVector_GemmM, decltype(descs[I3]),
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, GemmMPerBlock,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, GemmNPerBlock,
GemmBBlockTransferSrcScalarPerVector_GemmK, GemmKPerBlock,
GemmBBlockTransferDstScalarPerVector_GemmN, GemmMPerThread,
GemmCThreadTransferDstScalarPerVector_GemmM1>{}; GemmNPerThread,
GemmKPerThread,
conv_driver.Run(wei_k_y_x_c0_desc, GemmMLevel0Cluster,
in_n_hi_wi_c0_desc, GemmNLevel0Cluster,
out_n_ho_wo_k_desc, GemmMLevel1Cluster,
conv_strides, GemmNLevel1Cluster,
conv_dilations, GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
in_left_pads, GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
in_right_pads, Sequence<1, 0>,
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( Sequence<1, 0>,
wei_k_y_x_c_device_buf.GetDeviceBuffer()), 0,
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>( GemmABlockTransferSrcScalarPerVector_GemmK,
in_n_hi_wi_c_device_buf.GetDeviceBuffer()), GemmABlockTransferDstScalarPerVector_GemmM,
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer())); false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmBBlockTransferSrcScalarPerVector_GemmK,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
1,
GemmCThreadTransferDstScalarPerVector_GemmM1,
decltype(descs[I4]),
decltype(descs[I5]),
decltype(descs[I6]),
decltype(descs[I7]),
decltype(descs[I8])>(static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
descs[I0],
descs[I1],
descs[I2],
descs[I3],
descs[I4],
descs[I5],
descs[I6],
descs[I7],
descs[I8],
nrepeat);
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
// copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) { auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) {
......
...@@ -210,7 +210,7 @@ int main(int argc, char* argv[]) ...@@ -210,7 +210,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 0 #elif 1
// 3x3, 71x71 // 3x3, 71x71
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 192; constexpr index_t C = 192;
...@@ -225,7 +225,7 @@ int main(int argc, char* argv[]) ...@@ -225,7 +225,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 1 #elif 0
// 7x1, 17x17 // 7x1, 17x17
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
...@@ -724,7 +724,7 @@ int main(int argc, char* argv[]) ...@@ -724,7 +724,7 @@ int main(int argc, char* argv[])
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 0 #elif 1
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t, device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
in_vector_size, in_vector_size,
acc_data_t, acc_data_t,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment