Commit 31a440b9 authored by Jing Zhang
Browse files

fixed; tuning

parent b505370e
...@@ -19,8 +19,8 @@ template <typename GridwiseGemm, ...@@ -19,8 +19,8 @@ template <typename GridwiseGemm,
typename BGridDesc_E0_E1_N_Ho_Wo_E2, typename BGridDesc_E0_E1_N_Ho_Wo_E2,
typename CGridDesc_K_N_Ho_Wo, typename CGridDesc_K_N_Ho_Wo,
typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo, typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo,
bool HasMainKBlockLoop, bool HasMainE1BlockLoop,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailE1BlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -46,8 +46,8 @@ __global__ void ...@@ -46,8 +46,8 @@ __global__ void
a_e0_e1_k_e2_grid_desc, a_e0_e1_k_e2_grid_desc,
b_e0_e1_n_ho_wo_e2_grid_desc, b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_ho_wo_grid_desc, c_k_n_ho_wo_grid_desc,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainE1BlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailE1BlockLoop>{});
} }
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by CONSTANT void pointer // pass tensor descriptor by CONSTANT void pointer
...@@ -60,8 +60,8 @@ template <typename GridwiseGemm, ...@@ -60,8 +60,8 @@ template <typename GridwiseGemm,
typename BGridDesc_E0_E1_N_Ho_Wo_E2, typename BGridDesc_E0_E1_N_Ho_Wo_E2,
typename CGridDesc_K_N_Ho_Wo, typename CGridDesc_K_N_Ho_Wo,
typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo, typename CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo,
bool HasMainKBlockLoop, bool HasMainE1BlockLoop,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailE1BlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -96,8 +96,8 @@ __global__ void ...@@ -96,8 +96,8 @@ __global__ void
a_e0_e1_k_e2_grid_desc, a_e0_e1_k_e2_grid_desc,
b_e0_e1_n_ho_wo_e2_grid_desc, b_e0_e1_n_ho_wo_e2_grid_desc,
c_k_n_ho_wo_grid_desc, c_k_n_ho_wo_grid_desc,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainE1BlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailE1BlockLoop>{});
} }
#endif #endif
...@@ -109,8 +109,8 @@ template <index_t BlockSize, ...@@ -109,8 +109,8 @@ template <index_t BlockSize,
typename AGlobalDesc_E0_E1_K_E2, typename AGlobalDesc_E0_E1_K_E2,
typename BGlobalDesc_E0_E1_N_Ho_Wo_E2, typename BGlobalDesc_E0_E1_N_Ho_Wo_E2,
typename CGlobalDesc_K_N_Ho_Wo, typename CGlobalDesc_K_N_Ho_Wo,
index_t E1, index_t E1_,
index_t E2, index_t E2_,
index_t KPerBlock, index_t KPerBlock,
index_t HoPerBlock, index_t HoPerBlock,
index_t WoPerBlock, index_t WoPerBlock,
...@@ -148,6 +148,9 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -148,6 +148,9 @@ struct GridwiseGemmDlops_km_kn_mn_v3
static constexpr auto I4 = Number<4>{}; static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{}; static constexpr auto I5 = Number<5>{};
static constexpr auto E1 = Number<E1_>{};
static constexpr auto E2 = Number<E2_>{};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{}; constexpr auto max_lds_align = Number<ABlockTransferDstScalarPerVector_E2>{};
...@@ -164,7 +167,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -164,7 +167,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
return a_block_space_size * sizeof(FloatAB); return a_block_space_size * sizeof(FloatAB);
} }
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainE1BlockLoop, bool HasDoubleTailE1BlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_global, __device__ static void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
...@@ -172,8 +175,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -172,8 +175,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const AGlobalDesc_E0_E1_K_E2& a_e0_e1_k_e2_global_desc, const AGlobalDesc_E0_E1_K_E2& a_e0_e1_k_e2_global_desc,
const BGlobalDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_global_desc, const BGlobalDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_global_desc,
const CGlobalDesc_K_N_Ho_Wo& c_k_n_ho_wo_global_desc, const CGlobalDesc_K_N_Ho_Wo& c_k_n_ho_wo_global_desc,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainE1BlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) integral_constant<bool, HasDoubleTailE1BlockLoop>)
{ {
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_global, a_e0_e1_k_e2_global_desc.GetElementSpaceSize()); p_a_global, a_e0_e1_k_e2_global_desc.GetElementSpaceSize());
...@@ -192,7 +195,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -192,7 +195,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const auto Wo = b_e0_e1_n_ho_wo_e2_global_desc.GetLength(I4); const auto Wo = b_e0_e1_n_ho_wo_e2_global_desc.GetLength(I4);
// divide block work by [M, N] // divide block work by [M, N]
#if 1 #if 0
const auto ho_block_work_num = Ho / Number<HoPerBlock>{}; const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
const auto wo_block_work_num = Wo / Number<WoPerBlock>{}; const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num; const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
...@@ -356,110 +359,205 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -356,110 +359,205 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const auto E0 = b_e0_e1_n_ho_wo_e2_global_desc.GetLength(I0); const auto E0 = b_e0_e1_n_ho_wo_e2_global_desc.GetLength(I0);
index_t e0_block_data_begin = 0; constexpr auto HasMainE0BlockLoop = false;
// do if constexpr(HasMainE0BlockLoop)
//{
// LDS double buffer: preload data
{ {
a_blockwise_copy.RunRead( index_t e0_block_data_begin = 0;
a_e0_e1_k_e2_global_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc, do
b_global_buf, {
b_e0_e1_n_ho_wo_e2_thread_desc, // LDS double buffer: preload data
make_tuple(I0, I0, I0, I0, I0, I0), {
b_thread_even_buf, a_blockwise_copy.RunRead(
b_e0_e1_n_ho_wo_e2_global_step_hacks); a_e0_e1_k_e2_global_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
b_thread_even_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks);
a_blockwise_copy.RunWrite(a_e0_e1_k_e2_block_desc, a_block_buf);
}
__syncthreads();
if constexpr(HasMainE1BlockLoop)
{
index_t e1_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
b_thread_odd_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
b_thread_even_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0, 0));
e1_block_data_begin += 2 * EPerBlock;
} while(e1_block_data_begin < E1 - 2 * EPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left
{
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
b_thread_odd_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0, 0));
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
}
a_blockwise_copy.MoveSrcSliceWindow(a_e0_e1_k_e2_global_desc,
a_block_slice_copy_step,
AGlobalMoveSliceWindowStepHacks{});
blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - EPerBlock), 0, 0));
a_blockwise_copy.RunWrite(a_e0_e1_k_e2_block_desc, a_block_buf); b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
} b_thread_slice_copy_step);
__syncthreads(); e0_block_data_begin += 1;
if constexpr(HasMainKBlockLoop) } while(e0_block_data_begin < E0);
}
else
{ {
index_t e1_block_data_begin = 0; // LDS double buffer: preload data
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{ {
// even iteration a_blockwise_copy.RunRead(
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc, a_e0_e1_k_e2_global_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks);
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc, b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_global_buf, b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc, b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0),
b_thread_odd_buf, b_thread_even_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks); b_e0_e1_n_ho_wo_e2_global_step_hacks);
// LDS double buffer: GEMM on current data a_blockwise_copy.RunWrite(a_e0_e1_k_e2_block_desc, a_block_buf);
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); }
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0, 0)); __syncthreads();
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc, if constexpr(HasMainE1BlockLoop)
b_thread_slice_copy_step); {
index_t e1_block_data_begin = 0;
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc, // LDS double buffer: main body
b_global_buf, // use Do-While loop instead of For loop to simplify control flow
b_e0_e1_n_ho_wo_e2_thread_desc, do
make_tuple(I0, I0, I0, I0, I0, I0), {
b_thread_even_buf, // even iteration
b_e0_e1_n_ho_wo_e2_global_step_hacks); b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
b_thread_slice_copy_step);
// LDS double buffer: GEMM on current data b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
b_thread_odd_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0, 0)); // LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
e1_block_data_begin += 2 * EPerBlock; blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0, 0));
} while(e1_block_data_begin < E1 - 2 * EPerBlock); b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
} b_thread_slice_copy_step);
// LDS double buffer: tail b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left b_global_buf,
{ b_e0_e1_n_ho_wo_e2_thread_desc,
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc, make_tuple(I0, I0, I0, I0, I0, I0),
b_thread_slice_copy_step); b_thread_even_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks);
b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc, // LDS double buffer: GEMM on current data
b_global_buf, blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
b_thread_odd_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0, 0));
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0, 0)); e1_block_data_begin += 2 * EPerBlock;
// LDS double buffer: GEMM on last data } while(e1_block_data_begin < E1 - 2 * EPerBlock);
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); }
}
else // if has 1 iteration left
{
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
}
// a_blockwise_copy.MoveSrcSliceWindow(a_e0_e1_k_e2_global_desc, // LDS double buffer: tail
// a_block_slice_copy_step, if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left
// AGlobalMoveSliceWindowStepHacks{}); {
b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc,
b_thread_slice_copy_step);
// blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - EPerBlock), 0, 0)); b_threadwise_transfer.Run(b_e0_e1_n_ho_wo_e2_global_desc,
b_global_buf,
b_e0_e1_n_ho_wo_e2_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
b_thread_odd_buf,
b_e0_e1_n_ho_wo_e2_global_step_hacks);
// b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_ho_wo_e2_global_desc, // LDS double buffer: GEMM on 2nd-last data
// b_thread_slice_copy_step); blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
// e0_block_data_begin += 1; blockwise_gemm.MoveABlockSliceWindow(make_tuple(EPerBlock, 0, 0));
//} while(e0_block_data_begin < E0); // LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
}
}
// output: register to global memory // output: register to global memory
{ {
......
...@@ -93,17 +93,17 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nhwc_kyxc_nhwk( ...@@ -93,17 +93,17 @@ void device_convolution_forward_implicit_gemm_v5r1_dlops_nhwc_kyxc_nhwk(
constexpr index_t HoPerBlock = 8; constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 32; constexpr index_t WoPerBlock = 32;
constexpr index_t E1 = 2 * 9; constexpr index_t E1 = 4 * 9;
constexpr index_t E2 = 8; constexpr index_t E2 = 4;
constexpr index_t EPerBlock = 2; constexpr index_t EPerBlock = 4;
constexpr index_t KPerThread = KPerBlock; constexpr index_t KPerThread = KPerBlock;
constexpr index_t HoPerThread = 2; constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = 1; constexpr index_t EPerThread = EPerBlock;
using ABlockTransferThreadSliceLengths_E0_E1_K_E2 = Sequence<1, 9, 1, E2>; using ABlockTransferThreadSliceLengths_E0_E1_K_E2 = Sequence<1, 9, 1, E2>;
using ABlockTransferThreadClusterLengths_E0_E1_K_E2 = Sequence<1, 2, 16, 1>; using ABlockTransferThreadClusterLengths_E0_E1_K_E2 = Sequence<1, EPerBlock, 16, 1>;
constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2; constexpr index_t ABlockTransferSrcScalarPerVector_E2 = E2;
constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2; constexpr index_t ABlockTransferDstScalarPerVector_E2 = E2;
......
...@@ -95,7 +95,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp ...@@ -95,7 +95,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp
const auto C0 = C / E2; const auto C0 = C / E2;
const auto E = Y * X * C0; const auto E = Y * X * C0;
const auto E0 = E / E1; const auto E0 = E / E1;
// weight tensor // weight tensor
...@@ -198,13 +197,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp ...@@ -198,13 +197,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto b_e0_e1_n_ho_wo_e2_global_move_slice_window_step_hack = constexpr auto b_e0_e1_n_ho_wo_e2_global_move_slice_window_step_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
...@@ -243,18 +242,18 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp ...@@ -243,18 +242,18 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp
EPerThread, EPerThread,
ABlockTransferThreadSliceLengths_E0_E1_K_E2, ABlockTransferThreadSliceLengths_E0_E1_K_E2,
ABlockTransferThreadClusterLengths_E0_E1_K_E2, ABlockTransferThreadClusterLengths_E0_E1_K_E2,
Sequence<2, 0, 1, 3>, Sequence<0, 1, 2, 3>,
Sequence<2, 0, 1, 3>, Sequence<0, 1, 2, 3>,
3, 3,
ABlockTransferSrcScalarPerVector_E2, ABlockTransferSrcScalarPerVector_E2,
ABlockTransferDstScalarPerVector_E2, ABlockTransferDstScalarPerVector_E2,
false, // don't move back src coordinate after threadwise copy false, // don't move back src coordinate after threadwise copy
Sequence<0, 2, 3, 4, 1, 5>, Sequence<2, 0, 1, 3, 4, 5>,
5, 5,
BThreadTransferSrcScalarPerVector_E2, BThreadTransferSrcScalarPerVector_E2,
false, // don't move back src coordinate after threadwise copy, which will be fused with false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation // MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 1, 0>, Sequence<1, 2, 3, 0>,
0, 0,
CThreadTransferDstScalarPerVector_K, CThreadTransferDstScalarPerVector_K,
decltype(a_e0_e1_k_e2_global_step_hacks), decltype(a_e0_e1_k_e2_global_step_hacks),
...@@ -269,17 +268,17 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp ...@@ -269,17 +268,17 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp
const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; const auto grid_size = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
const bool has_main_k_block_loop = (E1 + E1PerBlock) / (2 * E1PerBlock) > 1; constexpr bool has_main_k_block_loop = (E1 + E1PerBlock) / (2 * E1PerBlock) > 1;
const bool has_double_tail_k_block_loop = (E1 / E1PerBlock) % 2 == 0; constexpr bool has_double_tail_k_block_loop = (E1 / E1PerBlock) % 2 == 0;
std::cerr << "has_main_k_block_loop = " << has_main_k_block_loop std::cerr << "has_main_k_block_loop = " << has_main_k_block_loop
<< " has_double_tail_k_block_loop = " << has_double_tail_k_block_loop << " has_double_tail_k_block_loop = " << has_double_tail_k_block_loop
<< std::endl; << std::endl;
const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor = const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(I0, I0))), make_single_stage_tensor_adaptor(make_tuple(make_pass_through_transform(I0)),
make_tuple(Sequence<0, 1>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
using CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo = using CBlockIdToBlockClusterAdaptor_K_N_Ho_Wo =
...@@ -288,7 +287,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp ...@@ -288,7 +287,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = 0; float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop) if constexpr(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
...@@ -314,7 +313,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp ...@@ -314,7 +313,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp
c_k_n_hop_wop_grid_desc, c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor); c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if constexpr(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
...@@ -340,7 +339,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp ...@@ -340,7 +339,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp
c_k_n_hop_wop_grid_desc, c_k_n_hop_wop_grid_desc,
c_blockid_to_k_n_ho_wo_block_cluster_adaptor); c_blockid_to_k_n_ho_wo_block_cluster_adaptor);
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if constexpr(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
...@@ -409,7 +408,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp ...@@ -409,7 +408,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp
float ave_time = 0; float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop) if constexpr(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
...@@ -440,7 +439,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp ...@@ -440,7 +439,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if constexpr(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
...@@ -471,7 +470,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp ...@@ -471,7 +470,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nhwc_kyxc_nhwk_outp
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); c_blockid_to_k_n_ho_wo_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if constexpr(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = const auto kernel =
kernel_gemm_dlops_v2<GridwiseGemm, kernel_gemm_dlops_v2<GridwiseGemm,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment