Commit d99e020d authored by Chao Liu

overhauling fwd-v4r4

parent 4b21c0fd
@@ -121,44 +121,10 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
     const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
         make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
 
-    // hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
-    constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
-                   make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
-
-    constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
-
-    // hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
-    constexpr auto in_gemmk_gemmn_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
-
-    constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
-        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
-
-    // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
-    // tensor hack for NKHW format
-    constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 1, 0, 0>{},
-                              Sequence<0, 0, 1, 0, 0>{}),
-                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 2, 0, 0>{},
-                              Sequence<0, 0, 2, 0, 0>{}));
-
     return make_tuple(wei_gemmk_gemmm_global_desc,
                       in_gemmk_gemmn_global_desc,
                       out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
-                      out_gemm_block_cluster_desc,
-                      wei_gemmk_gemmm_global_iterator_hacks,
-                      in_gemmk_gemmn_global_iterator_hacks,
-                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
-                      wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
-                      in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
+                      out_gemm_block_cluster_desc);
 }
 } // namespace ck
...
@@ -74,7 +74,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
     }
 
     __host__ __device__ static constexpr auto
-    MakeBKN0N1BlockDescriptor(const BKNBlockDesc& n_k_n_block_desc)
+    MakeBKN0N1BlockDescriptor(const BKNBlockDesc& b_k_n_block_desc)
     {
         const auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
             BKNBlockDesc{},
...
@@ -17,9 +17,9 @@ template <typename GridwiseGemm,
           typename FloatA,
           typename FloatB,
           typename FloatC,
-          typename AGlobalDesc,
-          typename BGlobalDesc,
-          typename CGlobalDesc,
+          typename AKMGridDesc,
+          typename BKNGridDesc,
+          typename CM0M1N0N1GridDesc,
           typename CBlockClusterDesc,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
@@ -27,20 +27,20 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_dynamic_gemm_v1r2(const FloatA* __restrict__ p_a_global,
-                             const FloatB* __restrict__ p_b_global,
-                             FloatC* __restrict__ p_c_global,
-                             const AGlobalDesc a_k_m_global_desc,
-                             const BGlobalDesc b_k_n_global_desc,
-                             const CGlobalDesc c_m0_m1_n0_n1_global_desc,
+    kernel_dynamic_gemm_v1r2(const FloatA* __restrict__ p_a_grid,
+                             const FloatB* __restrict__ p_b_grid,
+                             FloatC* __restrict__ p_c_grid,
+                             const AKMGridDesc a_k_m_grid_desc,
+                             const BKNGridDesc b_k_n_grid_desc,
+                             const CM0M1N0N1GridDesc c_m0_m1_n0_n1_grid_desc,
                              const CBlockClusterDesc c_block_cluster_desc)
 {
-    GridwiseGemm::Run(p_a_global,
-                      p_b_global,
-                      p_c_global,
-                      a_k_m_global_desc,
-                      b_k_n_global_desc,
-                      c_m0_m1_n0_n1_global_desc,
+    GridwiseGemm::Run(p_a_grid,
+                      p_b_grid,
+                      p_c_grid,
+                      a_k_m_grid_desc,
+                      b_k_n_grid_desc,
+                      c_m0_m1_n0_n1_grid_desc,
                       c_block_cluster_desc,
                       integral_constant<bool, HasMainKBlockLoop>{},
                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
@@ -53,9 +53,9 @@ template <typename GridwiseGemm,
           typename FloatA,
           typename FloatB,
           typename FloatC,
-          typename AGlobalDesc,
-          typename BGlobalDesc,
-          typename CGlobalDesc,
+          typename AKMGridDesc,
+          typename BKNGridDesc,
+          typename CM0M1N0N1GridDesc,
           typename CBlockClusterDesc,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
@@ -63,33 +63,33 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_dynamic_gemm_v1r2(const FloatA* __restrict__ p_a_global,
-                             const FloatB* __restrict__ p_b_global,
-                             FloatC* __restrict__ p_c_global,
-                             const void __CONSTANT__* p_a_k_m_global_desc,
-                             const void __CONSTANT__* p_b_k_n_global_desc,
-                             const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
+    kernel_dynamic_gemm_v1r2(const FloatA* __restrict__ p_a_grid,
+                             const FloatB* __restrict__ p_b_grid,
+                             FloatC* __restrict__ p_c_grid,
+                             const void __CONSTANT__* p_a_k_m_grid_desc,
+                             const void __CONSTANT__* p_b_k_n_grid_desc,
+                             const void __CONSTANT__* p_c_m0_m1_n0_n1_grid_desc,
                              const void __CONSTANT__* p_c_block_cluster_desc)
 {
     // first cast void __CONSTANT__ void* to void*
     // second cast void* to Desc*
     // the copy constructor of tensor descriptor doesn't take address_space(4)
-    const auto a_k_m_global_desc =
-        *reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k_m_global_desc);
-    const auto b_k_n_global_desc =
-        *reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k_n_global_desc);
-    const auto c_m0_m1_n0_n1_global_desc =
-        *reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_n0_n1_global_desc);
+    const auto a_k_m_grid_desc =
+        *reinterpret_cast<const AKMGridDesc*>((const void*)p_a_k_m_grid_desc);
+    const auto b_k_n_grid_desc =
+        *reinterpret_cast<const BKNGridDesc*>((const void*)p_b_k_n_grid_desc);
+    const auto c_m0_m1_n0_n1_grid_desc =
+        *reinterpret_cast<const CM0M1N0N1GridDesc*>((const void*)p_c_m0_m1_n0_n1_grid_desc);
     const auto c_block_cluster_desc =
         *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
 
-    GridwiseGemm::Run(p_a_global,
-                      p_b_global,
-                      p_c_global,
-                      a_k_m_global_desc,
-                      b_k_n_global_desc,
-                      c_m0_m1_n0_n1_global_desc,
+    GridwiseGemm::Run(p_a_grid,
+                      p_b_grid,
+                      p_c_grid,
+                      a_k_m_grid_desc,
+                      b_k_n_grid_desc,
+                      c_m0_m1_n0_n1_grid_desc,
                       c_block_cluster_desc,
                       integral_constant<bool, HasMainKBlockLoop>{},
                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
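A side note on the hunk above, since the cast sequence is easy to misread: the kernel receives each descriptor through a pointer in the constant address space, and the descriptor's copy constructor cannot bind to an address_space(4) object, so the pointer is first decayed to a generic `const void*` and only then reinterpreted. A minimal sketch of the idiom, assuming a HIP/ROCm target where `__CONSTANT__` is this repo's macro for the constant-address-space qualifier:

```cpp
// Hedged sketch of the double-cast idiom above (assumption: __CONSTANT__
// expands to an address_space(4) qualifier, as in this repo's HIP build).
template <typename Desc>
__device__ Desc load_descriptor(const void __CONSTANT__* p)
{
    // step 1: the C-style cast drops the address_space(4) qualifier
    // step 2: reinterpret the generic pointer as the concrete descriptor type
    return *reinterpret_cast<const Desc*>((const void*)p);
}
```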
@@ -101,9 +101,9 @@ template <index_t BlockSize,
           typename FloatAcc,
           typename FloatC,
           InMemoryDataOperation CGlobalMemoryDataOperation,
-          typename AGlobalDesc,
-          typename BGlobalDesc,
-          typename CGlobalDesc,
+          typename AKMGridDesc,
+          typename BKNGridDesc,
+          typename CM0M1N0N1GridDesc,
           typename CBlockClusterDesc,
           index_t MPerBlock,
           index_t NPerBlock,
@@ -134,13 +134,18 @@ template <index_t BlockSize,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
           index_t CThreadTransferDstScalarPerVector,
-          typename AGlobalIteratorHacks,
-          typename BGlobalIteratorHacks,
-          typename CGlobalIteratorHacks,
-          typename AGlobalMoveSliceWindowIteratorHacks,
-          typename BGlobalMoveSliceWindowIteratorHacks>
+          typename AGridIteratorHacks,
+          typename BGridIteratorHacks,
+          typename CGridIteratorHacks,
+          typename AGridMoveSliceWindowIteratorHacks,
+          typename BGridMoveSliceWindowIteratorHacks>
 struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
 {
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
+
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
         constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
@@ -168,33 +173,71 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
     }
 
+    __host__ __device__ static constexpr auto
+    MakeAKM0M1BlockClusterizedGridDescriptor(const AKMGridDesc& a_k_m_grid_desc)
+    {
+        const auto K = a_k_m_grid_desc.GetLength(I0);
+        const auto M = a_k_m_grid_desc.GetLength(I1);
+
+        const auto M1 = Number<MPerBlock>{};
+        const auto M0 = M / M1;
+
+        const auto a_k_m0_m1_block_clusterized_grid_desc = transform_dynamic_tensor_descriptor(
+            a_k_m_grid_desc,
+            make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
+
+        return a_k_m0_m1_block_clusterized_grid_desc;
+    }
+
+    __host__ __device__ static constexpr auto
+    MakeBKN0N1BlockClusterizedGridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
+    {
+        const auto K = b_k_n_grid_desc.GetLength(I0);
+        const auto N = b_k_n_grid_desc.GetLength(I1);
+
+        const auto N1 = Number<NPerBlock>{};
+        const auto N0 = N / N1;
+
+        const auto b_k_n0_n1_block_clusterized_grid_desc = transform_dynamic_tensor_descriptor(
+            b_k_n_grid_desc,
+            make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
+
+        return b_k_n0_n1_block_clusterized_grid_desc;
+    }
+
+    using AKM0M1GridDesc = decltype(MakeAKM0M1BlockClusterizedGridDescriptor(AKMGridDesc{}));
+    using BKN0N1GridDesc = decltype(MakeBKN0N1BlockClusterizedGridDescriptor(BKNGridDesc{}));
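For orientation on the two helpers just added: the unmerge transform splits the block-tiled dimension in two, with the fast dimension pinned to the tile size. A standalone sketch of the index arithmetic, using hypothetical sizes not taken from this commit:

```cpp
// Illustration of the (K, M) -> (K, M0, M1) split performed by
// MakeAKM0M1BlockClusterizedGridDescriptor: M1 is pinned to MPerBlock,
// M0 = M / M1, and a flat index m maps to the coordinate pair (m / M1, m % M1).
#include <cstdio>

int main()
{
    const int MPerBlock = 128;           // assumed tile size, i.e. M1
    const int M         = 512;           // assumed grid length (divides evenly)
    const int M0        = M / MPerBlock; // number of block tiles along M -> 4
    const int m         = 300;           // an example flat index

    std::printf("(m0, m1) = (%d, %d) within lengths (%d, %d)\n",
                m / MPerBlock, m % MPerBlock, M0, MPerBlock); // -> (2, 44)
    return 0;
}
```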
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
-                               const FloatAB* __restrict__ p_b_global,
-                               FloatC* __restrict__ p_c_global,
-                               const AGlobalDesc& a_k_m_global_desc,
-                               const BGlobalDesc& b_k_n_global_desc,
-                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
-                               const CBlockClusterDesc& c_block_cluster_desc,
+    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
+                               const FloatAB* __restrict__ p_b_grid,
+                               FloatC* __restrict__ p_c_grid,
+                               const AKMGridDesc& a_k_m_grid_desc,
+                               const BKNGridDesc& b_k_n_grid_desc,
+                               const CM0M1N0N1GridDesc& c_m0_m1_n0_n1_grid_desc,
+                               const CBlockClusterDesc& c_block_cluster_desc,
+#if 0
+                               const AKM0M1GridDesc& a_k_m0_m1_grid_desc,
+                               const BKN0N1GridDesc& b_k_n0_n1_grid_desc,
+#endif
                                FloatAB* __restrict__ p_shared_block,
                                integral_constant<bool, HasMainKBlockLoop>,
                                integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};
-
         const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
-            p_a_global, a_k_m_global_desc.GetElementSpaceSize());
+            p_a_grid, a_k_m_grid_desc.GetElementSpaceSize());
         const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
-            p_b_global, b_k_n_global_desc.GetElementSpaceSize());
-        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
-            p_c_global, c_m0_m1_n0_n1_global_desc.GetElementSpaceSize());
+            p_b_grid, b_k_n_grid_desc.GetElementSpaceSize());
+        auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_c_grid, c_m0_m1_n0_n1_grid_desc.GetElementSpaceSize());
 
-        const auto K = a_k_m_global_desc.GetLength(I0);
-        const auto M = a_k_m_global_desc.GetLength(I1);
-        const auto N = b_k_n_global_desc.GetLength(I1);
+        const auto K = a_k_m_grid_desc.GetLength(I0);
+        const auto M = a_k_m_grid_desc.GetLength(I1);
+        const auto N = b_k_n_grid_desc.GetLength(I1);
 
         // divide block work by [M, N]
         const auto block_work_idx =
@@ -233,7 +276,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             ABlockTransferThreadClusterArrangeOrder,
             FloatAB,
             FloatAB,
-            decltype(a_k_m_global_desc),
+            decltype(a_k_m_grid_desc),
             decltype(a_k_m_block_desc),
             ABlockTransferSrcAccessOrder,
             Sequence<0, 1>,
@@ -245,7 +288,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             1,
             AThreadTransferSrcResetCoordinateAfterRun,
             true>(
-            a_k_m_global_desc,
+            a_k_m_grid_desc,
             make_multi_index(0, m_block_data_idx_on_global),
             a_k_m_block_desc,
             make_multi_index(0, 0));
@@ -260,7 +303,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             BBlockTransferThreadClusterArrangeOrder,
             FloatAB,
             FloatAB,
-            decltype(b_k_n_global_desc),
+            decltype(b_k_n_grid_desc),
             decltype(b_k_n_block_desc),
             BBlockTransferSrcAccessOrder,
             Sequence<0, 1>,
@@ -272,7 +315,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             1,
             BThreadTransferSrcResetCoordinateAfterRun,
             true>(
-            b_k_n_global_desc,
+            b_k_n_grid_desc,
             make_multi_index(0, n_block_data_idx_on_global),
             b_k_n_block_desc,
             make_multi_index(0, 0));
@@ -328,15 +371,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
 
         // hack to control index calculation when iterating over A and B matrix for threadwise copy
-        constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
-        constexpr auto b_k_n_global_iterator_hacks = BGlobalIteratorHacks{};
+        constexpr auto a_k_m_global_iterator_hacks = AGridIteratorHacks{};
+        constexpr auto b_k_n_global_iterator_hacks = BGridIteratorHacks{};
 
         // hack to control index calculation when move slice window for A and B matrix for
         // threadwise copy
         constexpr auto a_k_m_global_move_slice_window_iterator_hack =
-            AGlobalMoveSliceWindowIteratorHacks{};
+            AGridMoveSliceWindowIteratorHacks{};
         constexpr auto b_k_n_global_move_slice_window_iterator_hack =
-            BGlobalMoveSliceWindowIteratorHacks{};
+            BGridMoveSliceWindowIteratorHacks{};
 
         auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
             p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
@@ -350,8 +393,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_k_m_grid_desc, a_global_buf, a_k_m_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);
 
             a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
             b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
@@ -366,10 +409,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             do
             {
                 // even iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
+                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_grid_desc,
                                                     a_block_slice_copy_step,
                                                     a_k_m_global_move_slice_window_iterator_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
+                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_grid_desc,
                                                     b_block_slice_copy_step,
                                                     b_k_n_global_move_slice_window_iterator_hack);
@@ -377,9 +420,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                 // LDS doubel buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
-                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+                    a_k_m_grid_desc, a_global_buf, a_k_m_global_iterator_hacks);
                 b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
+                    b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);
 
                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(
@@ -390,10 +433,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                 b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
 
                 // odd iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
+                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_grid_desc,
                                                     a_block_slice_copy_step,
                                                     a_k_m_global_move_slice_window_iterator_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
+                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_grid_desc,
                                                     b_block_slice_copy_step,
                                                     b_k_n_global_move_slice_window_iterator_hack);
@@ -401,9 +444,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                 // LDS doubel buffer: load next data from device mem
                 a_blockwise_copy.RunRead(
-                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+                    a_k_m_grid_desc, a_global_buf, a_k_m_global_iterator_hacks);
                 b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
+                    b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);
 
                 // LDS double buffer: GEMM on current data
                 blockwise_gemm.Run(
@@ -420,18 +463,18 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         // LDS double buffer: tail
         if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
         {
-            a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
+            a_blockwise_copy.MoveSrcSliceWindow(a_k_m_grid_desc,
                                                 a_block_slice_copy_step,
                                                 a_k_m_global_move_slice_window_iterator_hack);
-            b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
+            b_blockwise_copy.MoveSrcSliceWindow(b_k_n_grid_desc,
                                                 b_block_slice_copy_step,
                                                 b_k_n_global_move_slice_window_iterator_hack);
 
             __syncthreads();
 
             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_k_m_grid_desc, a_global_buf, a_k_m_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);
 
             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(
@@ -462,7 +505,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN10 * M1N1ThreadClusterN11>{};
 
         // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
-        constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
+        constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGridIteratorHacks{};
 
         const auto c_thread_data_idx_on_block =
             blockwise_gemm.CalculateCThreadOriginDataIndex(get_thread_local_1d_id());
@@ -470,7 +513,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         ThreadwiseDynamicTensorSliceTransfer_v1r3<FloatAcc,
                                                   FloatC,
                                                   decltype(c_m0_m1_n0_n1_thread_desc),
-                                                  decltype(c_m0_m1_n0_n1_global_desc),
+                                                  decltype(c_m0_m1_n0_n1_grid_desc),
                                                   decltype(c_m0_m1_n0_n1_thread_tensor_lengths),
                                                   CThreadTransferSrcDstAccessOrder,
                                                   CThreadTransferSrcDstVectorDim,
@@ -478,7 +521,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                                                   CGlobalMemoryDataOperation,
                                                   1,
                                                   true>{
-            c_m0_m1_n0_n1_global_desc,
+            c_m0_m1_n0_n1_grid_desc,
             make_multi_index(m_block_data_idx_on_global / M1 + c_thread_data_idx_on_block[I0],
                              c_thread_data_idx_on_block[I1],
                              n_block_data_idx_on_global / N1 + c_thread_data_idx_on_block[I2],
@@ -486,19 +529,19 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             .Run(c_m0_m1_n0_n1_thread_desc,
                  make_tuple(I0, I0, I0, I0),
                  c_thread_buf,
-                 c_m0_m1_n0_n1_global_desc,
-                 c_global_buf,
+                 c_m0_m1_n0_n1_grid_desc,
+                 c_grid_buf,
                  c_m0_m1_n0_n1_global_tensor_iterator_hacks);
         }
     }
 
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
-                               const FloatAB* __restrict__ p_b_global,
-                               FloatC* __restrict__ p_c_global,
-                               const AGlobalDesc& a_k_m_global_desc,
-                               const BGlobalDesc& b_k_n_global_desc,
-                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
+                               const FloatAB* __restrict__ p_b_grid,
+                               FloatC* __restrict__ p_c_grid,
+                               const AKMGridDesc& a_k_m_grid_desc,
+                               const BKNGridDesc& b_k_n_grid_desc,
+                               const CM0M1N0N1GridDesc& c_m0_m1_n0_n1_grid_desc,
                                const CBlockClusterDesc& c_block_cluster_desc,
                                integral_constant<bool, HasMainKBlockLoop>,
                                integral_constant<bool, HasDoubleTailKBlockLoop>)
@@ -507,12 +550,12 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         __shared__ FloatAB p_shared_block[shared_block_size];
 
-        Run(p_a_global,
-            p_b_global,
-            p_c_global,
-            a_k_m_global_desc,
-            b_k_n_global_desc,
-            c_m0_m1_n0_n1_global_desc,
+        Run(p_a_grid,
+            p_b_grid,
+            p_c_grid,
+            a_k_m_grid_desc,
+            b_k_n_grid_desc,
+            c_m0_m1_n0_n1_grid_desc,
             c_block_cluster_desc,
             p_shared_block,
             integral_constant<bool, HasMainKBlockLoop>{},
...
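Before moving to the device-side file, a compact view of the schedule the gridwise GEMM above implements: while the blockwise GEMM consumes the K tile already resident in one LDS buffer, the blockwise copies prefetch the next K tile into the other buffer, alternating roles every iteration. A self-contained, hypothetical condensation (control flow only, no real tensors):

```cpp
// Hypothetical condensation of the LDS double-buffering loop: the global read
// for tile k+1 overlaps the GEMM on tile k, ping-ponging between two buffers.
#include <cstdio>

int main()
{
    const int k_blocks = 6; // assumed number of KPerBlock tiles
    bool even = true;       // which LDS buffer holds the tile being consumed

    std::printf("preload tile 0 -> even buffer\n");
    for(int k = 1; k < k_blocks; ++k)
    {
        std::printf("read tile %d -> %s buffer | gemm on %s buffer\n",
                    k, even ? "odd" : "even", even ? "even" : "odd");
        even = !even; // swap roles, like the even/odd halves of the do-loop
    }
    std::printf("tail: gemm on %s buffer\n", even ? "even" : "odd");
    return 0;
}
```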
@@ -485,6 +485,35 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
         in_left_pads,
         in_right_pads);
 
+    // hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
+    constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
+                   make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
+
+    constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
+
+    // hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
+    constexpr auto in_gemmk_gemmn_global_iterator_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
+                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
+
+    constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
+        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
+
+    // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
+    // tensor hack for NKHW format
+    constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 1, 0, 0>{},
+                              Sequence<0, 0, 1, 0, 0>{}),
+                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 2, 0, 0>{},
+                              Sequence<0, 0, 2, 0, 0>{}));
+
     for(index_t i = 0; i < 5; ++i)
     {
         float ave_time = launch_kernel_dynamic_gemm_v1r2<
@@ -527,25 +556,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
             Sequence<2, 3, 0, 1>,
             3,
             GemmCThreadTransferDstScalarPerVector_GemmN1,
-            decltype(descs[I4]),
-            decltype(descs[I5]),
-            decltype(descs[I6]),
-            decltype(descs[I7]),
-            decltype(descs[I8])>(static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
-                                     wei_k_c_y_x_device_buf.GetDeviceBuffer()),
-                                 static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
-                                     in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
-                                 static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
-                                 descs[I0],
-                                 descs[I1],
-                                 descs[I2],
-                                 descs[I3],
-                                 descs[I4],
-                                 descs[I5],
-                                 descs[I6],
-                                 descs[I7],
-                                 descs[I8],
-                                 nrepeat);
+            decltype(wei_gemmk_gemmm_global_iterator_hacks),
+            decltype(in_gemmk_gemmn_global_iterator_hacks),
+            decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks),
+            decltype(wei_gemmk_gemmm_global_move_slice_window_iterator_hacks),
+            decltype(in_gemmk_gemmn_global_move_slice_window_iterator_hacks)>(
+            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
+                wei_k_c_y_x_device_buf.GetDeviceBuffer()),
+            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
+                in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
+            static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
+            descs[I0],
+            descs[I1],
+            descs[I2],
+            descs[I3],
+            wei_gemmk_gemmm_global_iterator_hacks,
+            in_gemmk_gemmn_global_iterator_hacks,
+            out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
+            wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
+            in_gemmk_gemmn_global_move_slice_window_iterator_hacks,
+            nrepeat);
 
         float perf = (float)calculate_convolution_flops(
                          in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
...
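Note how the relocated iterator hacks are forwarded in the hunk above: each constexpr hack object is passed once by type, via decltype in the template argument list, and once by value in the runtime argument list. A hedged standalone sketch of that pattern, with hypothetical names standing in for the Sequence/Tuple tags:

```cpp
// Sketch of passing an empty constexpr tag both by type and by value
// (hypothetical names; not code from this commit).
#include <cstdio>

struct IteratorHacks {}; // empty tag type: the "hack" lives in the type itself

template <typename Hacks>
void launch(Hacks) // the kernel template specializes on Hacks; the value is empty
{
    std::printf("instantiated with the hack tag\n");
}

int main()
{
    constexpr auto hacks = IteratorHacks{};
    launch<decltype(hacks)>(hacks); // type and value travel together
    return 0;
}
```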