Commit fa479ce4 authored by Chao Liu's avatar Chao Liu
Browse files

modify gridwise dynamic gemm looping

parent 8e35a579
...@@ -211,9 +211,369 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -211,9 +211,369 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock); const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool is_even_number_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
if(is_even_number_k_block_loop) const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
};
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t BlockSize,
typename Float,
typename AccFloat,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
index_t GemmMPerThread,
index_t GemmNPerThread,
index_t GemmKPerThread,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
typename GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
typename GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
index_t GemmABlockTransferSrcScalarPerVector_GemmK,
index_t GemmABlockTransferDstScalarPerVector_GemmM,
typename GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
typename GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
index_t GemmBBlockTransferSrcScalarPerVector_GemmN,
index_t GemmBBlockTransferDstScalarPerVector_GemmN,
index_t GemmCThreadTransferDstScalarPerVector_GemmN1>
struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
{
template <typename... Wei, typename... In, typename... Out>
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const MultiIndex<2> conv_strides,
const MultiIndex<2> conv_dilations,
const MultiIndex<2> in_left_pads,
const MultiIndex<2> in_right_pads,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_in_global,
Float* __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
const index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
const index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
const index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
const index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
const index_t ConvStrideH = conv_strides[I0];
const index_t ConvStrideW = conv_strides[I1];
const index_t ConvDilationH = conv_dilations[I0];
const index_t ConvDilationW = conv_dilations[I1];
const index_t InLeftPadH = in_left_pads[I0];
const index_t InLeftPadW = in_left_pads[I1];
const index_t InRightPadH = in_right_pads[I0];
const index_t InRightPadW = in_right_pads[I1];
// weight tensor
#if 0
// TODO implement graph optimization of tensor descriptor transformation
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
wei_k_c_y_x_global_desc,
make_tuple(DynamicPassThrough{K}, DynamicMerge<3>{make_multi_index(C, Y, X)}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
#else
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed<2>(make_multi_index(K, C * Y * X)),
make_tuple(DynamicPassThrough{K}, DynamicPassThrough{C * Y * X}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
#endif
// input tensor
// debug: don't do padding
const auto in_n_c_hip_wip_global_desc = in_n_c_hi_wi_global_desc;
const index_t Hip = in_n_c_hip_wip_global_desc.GetLength(I2);
const index_t Wip = in_n_c_hip_wip_global_desc.GetLength(I3);
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(DynamicPassThrough{N},
DynamicPassThrough{C},
DynamicEmbed<2>{make_multi_index(Y, Ho),
make_multi_index(ConvDilationH, ConvStrideH)},
DynamicEmbed<2>{make_multi_index(X, Wo),
make_multi_index(ConvDilationW, ConvStrideW)}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(DynamicMerge<3>{make_multi_index(C, Y, X)},
DynamicMerge<3>{make_multi_index(N, Ho, Wo)}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
#if 0
//TODO: implement graph optimization of tensor descriptor transformation
const auto out_gemmm_gemmn_global_desc =
transform_dynamic_tensor_descriptor(out_n_k_ho_wo_global_desc,
make_tuple(DynamicPassThrough{K}, DynamicMerge<3>{make_mult_index(N, Ho, Wo)}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#else
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed<3>(make_multi_index(N, K, Ho * Wo)),
make_tuple(DynamicPassThrough{K}, DynamicMerge<2>{make_multi_index(N, Ho * Wo)}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
#endif
const index_t GemmM = out_gemmm_gemmn_global_desc.GetLength(I0);
const index_t GemmN = out_gemmm_gemmn_global_desc.GetLength(I1);
const index_t GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0);
if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
GemmK % GemmKPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
const index_t GemmM0 = GemmM / GemmM1;
const index_t GemmN0 = GemmN / GemmN1;
const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc =
transform_dynamic_tensor_descriptor(
out_gemmm_gemmn_global_desc,
make_tuple(DynamicUnMerge<2>{make_multi_index(GemmM0, GemmM1)},
DynamicUnMerge<2>{make_multi_index(GemmN0, GemmN1)}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// GEMM
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v1<
BlockSize,
Float,
AccFloat,
InMemoryDataOperation::Set,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThread,
GemmNPerThread,
GemmKPerThread,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
Sequence<1, 0>,
Sequence<1, 0>,
0,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_GemmM,
true, // move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
Sequence<0, 1>,
Sequence<0, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1>;
const index_t GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1;
const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_gemmk_gemmm_global_desc),
const Float*,
decltype(in_gemmk_gemmn_global_desc),
const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_gemmk_gemmm_global_desc,
p_wei_global,
in_gemmk_gemmn_global_desc,
p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
run_gridwise_operation<gridwise_gemm, run_gridwise_operation<gridwise_gemm,
...@@ -223,6 +583,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -223,6 +583,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
const Float*, const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, Float*,
integral_constant<bool, false>,
integral_constant<bool, true>>; integral_constant<bool, true>>;
launch_kernel(kernel, launch_kernel(kernel,
...@@ -236,6 +597,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -236,6 +597,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
p_in_global, p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global, p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{}); integral_constant<bool, true>{});
} }
else else
...@@ -248,6 +610,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -248,6 +610,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
const Float*, const Float*,
decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc),
Float*, Float*,
integral_constant<bool, false>,
integral_constant<bool, false>>; integral_constant<bool, false>>;
launch_kernel(kernel, launch_kernel(kernel,
...@@ -261,6 +624,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw ...@@ -261,6 +624,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
p_in_global, p_in_global,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
p_out_global, p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{}); integral_constant<bool, false>{});
} }
} }
......
...@@ -73,7 +73,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -73,7 +73,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float); return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float);
} }
template <typename... ADesc, typename... BDesc, typename... CDesc, bool IsEvenNumberKBlockLoop> template <typename... ADesc,
typename... BDesc,
typename... CDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc, __device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
const Float* __restrict__ p_a_global, const Float* __restrict__ p_a_global,
const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc, const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
...@@ -81,7 +85,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -81,7 +85,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc, const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global, Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block, Float* __restrict__ p_shared_block,
integral_constant<bool, IsEvenNumberKBlockLoop>) const integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -264,88 +269,91 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -264,88 +269,91 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
} }
#endif #endif
#if 1 if constexpr(HasMainKBlockLoop)
Float* p_a_block_even = p_a_block_double; {
Float* p_b_block_even = p_b_block_double; Float* p_a_block_even = p_a_block_double;
Float* p_b_block_even = p_b_block_double;
Float* p_a_block_odd = p_a_block_double + a_block_space_size; Float* p_a_block_odd = p_a_block_double + a_block_space_size;
Float* p_b_block_odd = p_b_block_double + b_block_space_size; Float* p_b_block_odd = p_b_block_double + b_block_space_size;
// LDS double buffer: main body index_t k_block_data_begin = 0;
for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock;
k_block_data_begin += 2 * KPerBlock)
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
__syncthreads(); // LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
// LDS double buffer: load next data from device mem __syncthreads();
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global);
// LDS double buffer: GEMM on current data // LDS double buffer: load next data from device mem
blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread); a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global);
// LDS double buffer: store next data to LDS // LDS double buffer: GEMM on current data
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd); blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
// odd iteration // LDS double buffer: store next data to LDS
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step); a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step); b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd);
__syncthreads(); // odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
// LDS double buffer: load next data from device mem __syncthreads();
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global);
// LDS double buffer: GEMM on current data // LDS double buffer: load next data from device mem
blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread); a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
// LDS double buffer: store next data to LDS k_block_data_begin += 2 * KPerBlock;
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even); } while(k_block_data_begin < K - 2 * KPerBlock);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even);
} }
#endif
#if 1 #if 1
// LDS double buffer: tail // LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{ {
if constexpr(IsEvenNumberKBlockLoop) // if has 2 iteration left a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
{ b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, a_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, b_block_slice_copy_step);
__syncthreads(); __syncthreads();
// LDS double buffer: load last data from device mem // LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global); a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global);
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global); b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global);
// LDS double buffer: GEMM on 2nd-last data // LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS // LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size); a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size); b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size);
__syncthreads(); __syncthreads();
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double + a_block_space_size, blockwise_gemm.Run(p_a_block_double + a_block_space_size,
p_b_block_double + b_block_space_size, p_b_block_double + b_block_space_size,
p_c_thread); p_c_thread);
} }
else // if has 1 iteration left else // if has 1 iteration left
{ {
__syncthreads(); __syncthreads();
// LDS double buffer: GEMM on last data // LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
} }
#endif #endif
...@@ -398,14 +406,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -398,14 +406,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
} }
} }
template <typename... ADesc, typename... BDesc, typename... CDesc, bool IsEvenNumberKBlockLoop> template <typename... ADesc,
typename... BDesc,
typename... CDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc, __device__ void Run(const DynamicTensorDescriptor<ADesc...>& a_k_m_global_desc,
const Float* __restrict__ p_a_global, const Float* __restrict__ p_a_global,
const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc, const DynamicTensorDescriptor<BDesc...>& b_k_n_global_desc,
const Float* __restrict__ p_b_global, const Float* __restrict__ p_b_global,
const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc, const DynamicTensorDescriptor<CDesc...>& c_m0_m1_n0_n1_global_desc,
Float* __restrict__ p_c_global, Float* __restrict__ p_c_global,
integral_constant<bool, IsEvenNumberKBlockLoop>) const integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{ {
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float); constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float);
...@@ -418,7 +431,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1 ...@@ -418,7 +431,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
c_m0_m1_n0_n1_global_desc, c_m0_m1_n0_n1_global_desc,
p_c_global, p_c_global,
p_shared_block, p_shared_block,
integral_constant<bool, IsEvenNumberKBlockLoop>{}); integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
}; };
} // namespace ck } // namespace ck
......
...@@ -87,7 +87,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -87,7 +87,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0 #elif 1
// cdata = 64, BlockSize = 256, 128x128x4 // cdata = 64, BlockSize = 256, 128x128x4
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -99,10 +99,10 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, ...@@ -99,10 +99,10 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerThread = 4; constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1; constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 8;
constexpr index_t ThreadGemmDataPerReadM = 4; constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4; constexpr index_t ThreadGemmDataPerReadN = 4;
......
...@@ -54,6 +54,36 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc ...@@ -54,6 +54,36 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
const auto in_right_pads = to_multi_index(InRightPads{}); const auto in_right_pads = to_multi_index(InRightPads{});
#if 1 #if 1
// cdata = 64, BlockSize = 256, 128x128x4
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 8;
constexpr index_t GemmNLevel1Cluster = 8;
using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>;
using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x8 // cdata = 64, BlockSize = 256, 128x128x8
constexpr index_t BlockSize = 256; constexpr index_t BlockSize = 256;
...@@ -107,29 +137,34 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc ...@@ -107,29 +137,34 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
constexpr auto conv_driver = DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw< constexpr auto conv_driver =
BlockSize, #if 0 // debug
TDevice, DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw
TDevice, #else
GemmMPerBlock, DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
GemmNPerBlock, #endif
GemmKPerBlock, <BlockSize,
GemmMPerThread, TDevice,
GemmNPerThread, TDevice,
GemmKPerThread, GemmMPerBlock,
GemmMLevel0Cluster, GemmNPerBlock,
GemmNLevel0Cluster, GemmKPerBlock,
GemmMLevel1Cluster, GemmMPerThread,
GemmNLevel1Cluster, GemmNPerThread,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM, GemmKPerThread,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM, GemmMLevel0Cluster,
GemmABlockTransferSrcScalarPerVector_GemmK, GemmNLevel0Cluster,
GemmABlockTransferDstScalarPerVector_GemmM, GemmMLevel1Cluster,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, GemmNLevel1Cluster,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
GemmBBlockTransferSrcScalarPerVector_GemmN, GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
GemmBBlockTransferDstScalarPerVector_GemmN, GemmABlockTransferSrcScalarPerVector_GemmK,
GemmCThreadTransferDstScalarPerVector_GemmN1>{}; GemmABlockTransferDstScalarPerVector_GemmM,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmN,
GemmCThreadTransferDstScalarPerVector_GemmN1>{};
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
......
...@@ -22,22 +22,22 @@ int main(int argc, char* argv[]) ...@@ -22,22 +22,22 @@ int main(int argc, char* argv[])
{ {
using namespace ck; using namespace ck;
#if 0 #if 1
// 1x1, 8x8 // 3x3, 35x35, stride 2
constexpr index_t N = 2; constexpr index_t N = 128;
constexpr index_t C = 24; constexpr index_t C = 192;
constexpr index_t HI = 8; constexpr index_t HI = 35;
constexpr index_t WI = 8; constexpr index_t WI = 35;
constexpr index_t K = 128; constexpr index_t K = 384;
constexpr index_t Y = 1; constexpr index_t Y = 3;
constexpr index_t X = 1; constexpr index_t X = 3;
using ConvStrides = Sequence<1, 1>; using ConvStrides = Sequence<2, 2>;
using ConvDilations = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 1
// 3x3, 71x71 // 3x3, 71x71
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 192; constexpr index_t C = 192;
...@@ -127,7 +127,7 @@ int main(int argc, char* argv[]) ...@@ -127,7 +127,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<3, 0>; using LeftPads = Sequence<3, 0>;
using RightPads = Sequence<3, 0>; using RightPads = Sequence<3, 0>;
#elif 1 #elif 0
// 1x7, 17x17 // 1x7, 17x17
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
...@@ -217,7 +217,7 @@ int main(int argc, char* argv[]) ...@@ -217,7 +217,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 0 #elif 1
// 3x3, 35x35, stride 2 // 3x3, 35x35, stride 2
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 288; constexpr index_t C = 288;
...@@ -352,7 +352,7 @@ int main(int argc, char* argv[]) ...@@ -352,7 +352,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>; using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>;
#elif 1 #elif 0
// 3x3, 28x28 // 3x3, 28x28
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 128; constexpr index_t C = 128;
...@@ -367,7 +367,7 @@ int main(int argc, char* argv[]) ...@@ -367,7 +367,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>; using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>;
#elif 0 #elif 1
// 3x3, 14x14 // 3x3, 14x14
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -567,17 +567,17 @@ int main(int argc, char* argv[]) ...@@ -567,17 +567,17 @@ int main(int argc, char* argv[])
#if 0 #if 0
device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
wei_kcyx, wei_kcyx,
out_nkhw_desc, out_nkhw_desc,
out_nkhw_device, out_nkhw_device,
ConvStrides{}, ConvStrides{},
ConvDilations{}, ConvDilations{},
LeftPads{}, LeftPads{},
RightPads{}, RightPads{},
nrepeat); nrepeat);
#elif 1 #elif 0
device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc, device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw, in_nchw,
wei_kcyx_desc, wei_kcyx_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment