Commit d99e020d authored by Chao Liu

overhauling fwd-v4r4

parent 4b21c0fd
@@ -121,44 +121,10 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
    const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
        make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
-    // hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
-    constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
-                   make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
-    constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
-    // hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
-    constexpr auto in_gemmk_gemmn_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
-                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
-                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
-    constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
-        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
-    // hack to control index calculation when iterating over the
-    // out_gemmm0_gemmm1_gemmn0_gemmn1_global tensor; hack for NKHW format
-    constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
-        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 1, 0, 0>{},
-                              Sequence<0, 0, 1, 0, 0>{}),
-                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 2, 0, 0>{},
-                              Sequence<0, 0, 2, 0, 0>{}));
    return make_tuple(wei_gemmk_gemmm_global_desc,
                      in_gemmk_gemmn_global_desc,
                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
-                      out_gemm_block_cluster_desc,
-                      wei_gemmk_gemmm_global_iterator_hacks,
-                      in_gemmk_gemmn_global_iterator_hacks,
-                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
-                      wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
-                      in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
+                      out_gemm_block_cluster_desc);
}
} // namespace ck
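The cluster descriptor above carries only the two tile counts, GemmM / GemmMPerBlock and GemmN / GemmNPerBlock. A minimal sketch of the flat-block-id to tile-coordinate mapping such a descriptor enables, assuming row-major tile ordering (the actual ordering is internal to make_cluster_descriptor_v2 and may differ):

// Hypothetical stand-in for ck's cluster descriptor; names are illustrative.
struct BlockCluster2D
{
    int m_tiles; // GemmM / GemmMPerBlock
    int n_tiles; // GemmN / GemmNPerBlock

    // Map a flat workgroup id to the (m_tile, n_tile) it should compute,
    // assuming row-major ordering of tiles.
    void TileIndex(int block_id, int& m_tile, int& n_tile) const
    {
        m_tile = block_id / n_tiles;
        n_tile = block_id % n_tiles;
    }
};

With GemmM = 1024, GemmN = 512 and 128x128 tiles, this would give an 8x4 tile grid, so block_id 9 computes tile (2, 1).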
@@ -74,7 +74,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
    }
    __host__ __device__ static constexpr auto
-    MakeBKN0N1BlockDescriptor(const BKNBlockDesc& n_k_n_block_desc)
+    MakeBKN0N1BlockDescriptor(const BKNBlockDesc& b_k_n_block_desc)
{
const auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
BKNBlockDesc{},
@@ -17,9 +17,9 @@ template <typename GridwiseGemm,
          typename FloatA,
          typename FloatB,
          typename FloatC,
-          typename AGlobalDesc,
-          typename BGlobalDesc,
-          typename CGlobalDesc,
+          typename AKMGridDesc,
+          typename BKNGridDesc,
+          typename CM0M1N0N1GridDesc,
typename CBlockClusterDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
@@ -27,20 +27,20 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-    kernel_dynamic_gemm_v1r2(const FloatA* __restrict__ p_a_global,
-                             const FloatB* __restrict__ p_b_global,
-                             FloatC* __restrict__ p_c_global,
-                             const AGlobalDesc a_k_m_global_desc,
-                             const BGlobalDesc b_k_n_global_desc,
-                             const CGlobalDesc c_m0_m1_n0_n1_global_desc,
+    kernel_dynamic_gemm_v1r2(const FloatA* __restrict__ p_a_grid,
+                             const FloatB* __restrict__ p_b_grid,
+                             FloatC* __restrict__ p_c_grid,
+                             const AKMGridDesc a_k_m_grid_desc,
+                             const BKNGridDesc b_k_n_grid_desc,
+                             const CM0M1N0N1GridDesc c_m0_m1_n0_n1_grid_desc,
                              const CBlockClusterDesc c_block_cluster_desc)
{
-    GridwiseGemm::Run(p_a_global,
-                      p_b_global,
-                      p_c_global,
-                      a_k_m_global_desc,
-                      b_k_n_global_desc,
-                      c_m0_m1_n0_n1_global_desc,
+    GridwiseGemm::Run(p_a_grid,
+                      p_b_grid,
+                      p_c_grid,
+                      a_k_m_grid_desc,
+                      b_k_n_grid_desc,
+                      c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
@@ -53,9 +53,9 @@ template <typename GridwiseGemm,
          typename FloatA,
          typename FloatB,
          typename FloatC,
-          typename AGlobalDesc,
-          typename BGlobalDesc,
-          typename CGlobalDesc,
+          typename AKMGridDesc,
+          typename BKNGridDesc,
+          typename CM0M1N0N1GridDesc,
typename CBlockClusterDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
@@ -63,33 +63,33 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-    kernel_dynamic_gemm_v1r2(const FloatA* __restrict__ p_a_global,
-                             const FloatB* __restrict__ p_b_global,
-                             FloatC* __restrict__ p_c_global,
-                             const void __CONSTANT__* p_a_k_m_global_desc,
-                             const void __CONSTANT__* p_b_k_n_global_desc,
-                             const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
+    kernel_dynamic_gemm_v1r2(const FloatA* __restrict__ p_a_grid,
+                             const FloatB* __restrict__ p_b_grid,
+                             FloatC* __restrict__ p_c_grid,
+                             const void __CONSTANT__* p_a_k_m_grid_desc,
+                             const void __CONSTANT__* p_b_k_n_grid_desc,
+                             const void __CONSTANT__* p_c_m0_m1_n0_n1_grid_desc,
                              const void __CONSTANT__* p_c_block_cluster_desc)
{
    // first cast void __CONSTANT__* to void*,
    // then cast void* to Desc*:
    // the copy constructor of the tensor descriptor doesn't take address_space(4)
-    const auto a_k_m_global_desc =
-        *reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k_m_global_desc);
-    const auto b_k_n_global_desc =
-        *reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k_n_global_desc);
-    const auto c_m0_m1_n0_n1_global_desc =
-        *reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_n0_n1_global_desc);
+    const auto a_k_m_grid_desc =
+        *reinterpret_cast<const AKMGridDesc*>((const void*)p_a_k_m_grid_desc);
+    const auto b_k_n_grid_desc =
+        *reinterpret_cast<const BKNGridDesc*>((const void*)p_b_k_n_grid_desc);
+    const auto c_m0_m1_n0_n1_grid_desc =
+        *reinterpret_cast<const CM0M1N0N1GridDesc*>((const void*)p_c_m0_m1_n0_n1_grid_desc);
    const auto c_block_cluster_desc =
        *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
-    GridwiseGemm::Run(p_a_global,
-                      p_b_global,
-                      p_c_global,
-                      a_k_m_global_desc,
-                      b_k_n_global_desc,
-                      c_m0_m1_n0_n1_global_desc,
+    GridwiseGemm::Run(p_a_grid,
+                      p_b_grid,
+                      p_c_grid,
+                      a_k_m_grid_desc,
+                      b_k_n_grid_desc,
+                      c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
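The two-step cast above can be wrapped in a small helper; a sketch, assuming __CONSTANT__ expands to an address-space qualifier (e.g. __attribute__((address_space(4))) under HIP) that the descriptor's copy constructor cannot consume directly:

// Hypothetical helper mirroring the cast sequence above: strip the
// address-space qualifier with a plain (const void*) cast first, then
// reinterpret as the concrete descriptor type and copy it by value.
template <typename Desc>
__device__ Desc load_descriptor(const void __CONSTANT__* p)
{
    return *reinterpret_cast<const Desc*>((const void*)p);
}

With it, the kernel body would shrink to lines like const auto a_k_m_grid_desc = load_descriptor<AKMGridDesc>(p_a_k_m_grid_desc);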
@@ -101,9 +101,9 @@ template <index_t BlockSize,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
-          typename AGlobalDesc,
-          typename BGlobalDesc,
-          typename CGlobalDesc,
+          typename AKMGridDesc,
+          typename BKNGridDesc,
+          typename CM0M1N0N1GridDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
@@ -134,13 +134,18 @@ template <index_t BlockSize,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
-          typename AGlobalIteratorHacks,
-          typename BGlobalIteratorHacks,
-          typename CGlobalIteratorHacks,
-          typename AGlobalMoveSliceWindowIteratorHacks,
-          typename BGlobalMoveSliceWindowIteratorHacks>
+          typename AGridIteratorHacks,
+          typename BGridIteratorHacks,
+          typename CGridIteratorHacks,
+          typename AGridMoveSliceWindowIteratorHacks,
+          typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
@@ -168,33 +173,71 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
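As a worked example of the factor of 2 in the return value (one A/B buffer pair for even loop iterations, one for odd), assuming KPerBlock = 8, MPerBlock = NPerBlock = 128, FloatAB = float, and that max_lds_align introduces no extra padding:

// a_block_space_size ~ KPerBlock * MPerBlock = 8 * 128 = 1024 elements
// b_block_space_size ~ KPerBlock * NPerBlock = 8 * 128 = 1024 elements
// LDS bytes = 2 * (1024 + 1024) * sizeof(float) = 16384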
+    __host__ __device__ static constexpr auto
+    MakeAKM0M1BlockClusterizedGridDescriptor(const AKMGridDesc& a_k_m_grid_desc)
+    {
+        const auto K = a_k_m_grid_desc.GetLength(I0);
+        const auto M = a_k_m_grid_desc.GetLength(I1);
+
+        const auto M1 = Number<MPerBlock>{};
+        const auto M0 = M / M1;
+
+        const auto a_k_m0_m1_block_clusterized_grid_desc = transform_dynamic_tensor_descriptor(
+            a_k_m_grid_desc,
+            make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
+
+        return a_k_m0_m1_block_clusterized_grid_desc;
+    }
+
+    __host__ __device__ static constexpr auto
+    MakeBKN0N1BlockClusterizedGridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
+    {
+        const auto K = b_k_n_grid_desc.GetLength(I0);
+        const auto N = b_k_n_grid_desc.GetLength(I1);
+
+        const auto N1 = Number<NPerBlock>{};
+        const auto N0 = N / N1;
+
+        const auto b_k_n0_n1_block_clusterized_grid_desc = transform_dynamic_tensor_descriptor(
+            b_k_n_grid_desc,
+            make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
+
+        return b_k_n0_n1_block_clusterized_grid_desc;
+    }
+
+    using AKM0M1GridDesc = decltype(MakeAKM0M1BlockClusterizedGridDescriptor(AKMGridDesc{}));
+    using BKN0N1GridDesc = decltype(MakeBKN0N1BlockClusterizedGridDescriptor(BKNGridDesc{}));
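Concretely, the unmerge transform splits the block-partitioned axis in two: M becomes (M0, M1) with M1 = MPerBlock and M0 = M / MPerBlock. A worked example under assumed sizes:

// Assuming K = 8, M = 256, MPerBlock = 128:
//   M1 = 128, M0 = 256 / 128 = 2
//   a_k_m_grid_desc      : [K, M]      = [8, 256]
//   a_k_m0_m1_grid_desc  : [K, M0, M1] = [8, 2, 128]
// Element (k, m) of the original view is addressed as (k, m / 128, m % 128).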
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
-                               const FloatAB* __restrict__ p_b_global,
-                               FloatC* __restrict__ p_c_global,
-                               const AGlobalDesc& a_k_m_global_desc,
-                               const BGlobalDesc& b_k_n_global_desc,
-                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
+                               const FloatAB* __restrict__ p_b_grid,
+                               FloatC* __restrict__ p_c_grid,
+                               const AKMGridDesc& a_k_m_grid_desc,
+                               const BKNGridDesc& b_k_n_grid_desc,
+                               const CM0M1N0N1GridDesc& c_m0_m1_n0_n1_grid_desc,
                               const CBlockClusterDesc& c_block_cluster_desc,
+#if 0
+                               const AKM0M1GridDesc& a_k_m0_m1_grid_desc,
+                               const BKN0N1GridDesc& b_k_n0_n1_grid_desc,
+#endif
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};
        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
-            p_a_global, a_k_m_global_desc.GetElementSpaceSize());
+            p_a_grid, a_k_m_grid_desc.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
-            p_b_global, b_k_n_global_desc.GetElementSpaceSize());
-        auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
-            p_c_global, c_m0_m1_n0_n1_global_desc.GetElementSpaceSize());
+            p_b_grid, b_k_n_grid_desc.GetElementSpaceSize());
+        auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
+            p_c_grid, c_m0_m1_n0_n1_grid_desc.GetElementSpaceSize());
-        const auto K = a_k_m_global_desc.GetLength(I0);
-        const auto M = a_k_m_global_desc.GetLength(I1);
-        const auto N = b_k_n_global_desc.GetLength(I1);
+        const auto K = a_k_m_grid_desc.GetLength(I0);
+        const auto M = a_k_m_grid_desc.GetLength(I1);
+        const auto N = b_k_n_grid_desc.GetLength(I1);
// divide block work by [M, N]
const auto block_work_idx =
@@ -233,7 +276,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
-            decltype(a_k_m_global_desc),
+            decltype(a_k_m_grid_desc),
decltype(a_k_m_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
@@ -245,7 +288,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
            1,
            AThreadTransferSrcResetCoordinateAfterRun,
            true>(
-            a_k_m_global_desc,
+            a_k_m_grid_desc,
make_multi_index(0, m_block_data_idx_on_global),
a_k_m_block_desc,
make_multi_index(0, 0));
@@ -260,7 +303,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
-            decltype(b_k_n_global_desc),
+            decltype(b_k_n_grid_desc),
decltype(b_k_n_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1>,
@@ -272,7 +315,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
            1,
            BThreadTransferSrcResetCoordinateAfterRun,
            true>(
-            b_k_n_global_desc,
+            b_k_n_grid_desc,
make_multi_index(0, n_block_data_idx_on_global),
b_k_n_block_desc,
make_multi_index(0, 0));
@@ -328,15 +371,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
        // hack to control index calculation when iterating over the A and B matrices for threadwise copy
-        constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
-        constexpr auto b_k_n_global_iterator_hacks = BGlobalIteratorHacks{};
+        constexpr auto a_k_m_global_iterator_hacks = AGridIteratorHacks{};
+        constexpr auto b_k_n_global_iterator_hacks = BGridIteratorHacks{};
        // hack to control index calculation when moving the slice window over the A and B
        // matrices for threadwise copy
        constexpr auto a_k_m_global_move_slice_window_iterator_hack =
-            AGlobalMoveSliceWindowIteratorHacks{};
+            AGridMoveSliceWindowIteratorHacks{};
        constexpr auto b_k_n_global_move_slice_window_iterator_hack =
-            BGlobalMoveSliceWindowIteratorHacks{};
+            BGridMoveSliceWindowIteratorHacks{};
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
@@ -350,8 +393,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
        // LDS double buffer: preload data into LDS
        {
-            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_k_m_grid_desc, a_global_buf, a_k_m_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
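The preload above is the prologue of the double-buffered pipeline; the main loop that follows alternates between the two LDS buffers. Schematically (an outline of the control flow below, not literal code):

// preload: read K-tile 0 from global memory, write it to the even LDS buffer
// while at least 2 K-tiles remain:
//   even iteration: move A/B windows, read tile i+1 into registers,
//                   GEMM on the even buffer, write tile i+1 to the odd buffer
//   odd iteration:  move windows, read tile i+2 into registers,
//                   GEMM on the odd buffer, write tile i+2 to the even buffer
// tail: GEMM on the last one or two tiles already resident in LDS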
@@ -366,10 +409,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
            do
            {
                // even iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
+                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_grid_desc,
                                                    a_block_slice_copy_step,
                                                    a_k_m_global_move_slice_window_iterator_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
+                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_grid_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
@@ -377,9 +420,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(
-                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+                    a_k_m_grid_desc, a_global_buf, a_k_m_global_iterator_hacks);
                b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
+                    b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
@@ -390,10 +433,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
                // odd iteration
-                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
+                a_blockwise_copy.MoveSrcSliceWindow(a_k_m_grid_desc,
                                                    a_block_slice_copy_step,
                                                    a_k_m_global_move_slice_window_iterator_hack);
-                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
+                b_blockwise_copy.MoveSrcSliceWindow(b_k_n_grid_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
@@ -401,9 +444,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(
-                    a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
+                    a_k_m_grid_desc, a_global_buf, a_k_m_global_iterator_hacks);
                b_blockwise_copy.RunRead(
-                    b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
+                    b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
@@ -420,18 +463,18 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // if 2 iterations are left
        {
-            a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
+            a_blockwise_copy.MoveSrcSliceWindow(a_k_m_grid_desc,
                                                a_block_slice_copy_step,
                                                a_k_m_global_move_slice_window_iterator_hack);
-            b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
+            b_blockwise_copy.MoveSrcSliceWindow(b_k_n_grid_desc,
                                                b_block_slice_copy_step,
                                                b_k_n_global_move_slice_window_iterator_hack);
            __syncthreads();
            // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
-            b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
+            a_blockwise_copy.RunRead(a_k_m_grid_desc, a_global_buf, a_k_m_global_iterator_hacks);
+            b_blockwise_copy.RunRead(b_k_n_grid_desc, b_global_buf, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
@@ -462,7 +505,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
        constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN10 * M1N1ThreadClusterN11>{};
        // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
-        constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
+        constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGridIteratorHacks{};
const auto c_thread_data_idx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(get_thread_local_1d_id());
@@ -470,7 +513,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
        ThreadwiseDynamicTensorSliceTransfer_v1r3<FloatAcc,
                                                  FloatC,
                                                  decltype(c_m0_m1_n0_n1_thread_desc),
-                                                  decltype(c_m0_m1_n0_n1_global_desc),
+                                                  decltype(c_m0_m1_n0_n1_grid_desc),
decltype(c_m0_m1_n0_n1_thread_tensor_lengths),
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
@@ -478,7 +521,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                                                  CGlobalMemoryDataOperation,
                                                  1,
                                                  true>{
-            c_m0_m1_n0_n1_global_desc,
+            c_m0_m1_n0_n1_grid_desc,
make_multi_index(m_block_data_idx_on_global / M1 + c_thread_data_idx_on_block[I0],
c_thread_data_idx_on_block[I1],
n_block_data_idx_on_global / N1 + c_thread_data_idx_on_block[I2],
@@ -486,19 +529,19 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
            .Run(c_m0_m1_n0_n1_thread_desc,
                 make_tuple(I0, I0, I0, I0),
                 c_thread_buf,
-                 c_m0_m1_n0_n1_global_desc,
-                 c_global_buf,
+                 c_m0_m1_n0_n1_grid_desc,
+                 c_grid_buf,
c_m0_m1_n0_n1_global_tensor_iterator_hacks);
}
}
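For the C write-out origin computed above, a small worked example under assumed sizes. By symmetry with the N1 definition earlier, M1 here would be M1PerThread * M1N1ThreadClusterM10 * M1N1ThreadClusterM11, the M extent covered by one M0 step, so dividing the block's global M offset by M1 gives its starting M0 coordinate:

// Assuming M1 = 64 and a block with m_block_data_idx_on_global = 256:
//   block M0 origin = 256 / 64 = 4
// Each thread then adds its own (m0, m1, n0, n1) offset, obtained from
// blockwise_gemm.CalculateCThreadOriginDataIndex(get_thread_local_1d_id()).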
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
-                               const FloatAB* __restrict__ p_b_global,
-                               FloatC* __restrict__ p_c_global,
-                               const AGlobalDesc& a_k_m_global_desc,
-                               const BGlobalDesc& b_k_n_global_desc,
-                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
+                               const FloatAB* __restrict__ p_b_grid,
+                               FloatC* __restrict__ p_c_grid,
+                               const AKMGridDesc& a_k_m_grid_desc,
+                               const BKNGridDesc& b_k_n_grid_desc,
+                               const CM0M1N0N1GridDesc& c_m0_m1_n0_n1_grid_desc,
const CBlockClusterDesc& c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
@@ -507,12 +550,12 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
        __shared__ FloatAB p_shared_block[shared_block_size];
-        Run(p_a_global,
-            p_b_global,
-            p_c_global,
-            a_k_m_global_desc,
-            b_k_n_global_desc,
-            c_m0_m1_n0_n1_global_desc,
+        Run(p_a_grid,
+            p_b_grid,
+            p_c_grid,
+            a_k_m_grid_desc,
+            b_k_n_grid_desc,
+            c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
@@ -485,6 +485,35 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
in_left_pads,
in_right_pads);
+    // hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor
+    constexpr auto wei_gemmk_gemmm_global_iterator_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
+                   make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
+
+    constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{};
+
+    // hack to control index calculation when iterating over in_gemmk_gemmn_global tensor
+    constexpr auto in_gemmk_gemmn_global_iterator_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}),
+                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{}));
+
+    constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
+        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{};
+
+    // hack to control index calculation when iterating over the
+    // out_gemmm0_gemmm1_gemmn0_gemmn1_global tensor; hack for NKHW format
+    constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 1, 0, 0>{},
+                              Sequence<0, 0, 1, 0, 0>{}),
+                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 2, 0, 0>{},
+                              Sequence<0, 0, 2, 0, 0>{}));
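A hedged reading of these magic sequences, inferred from their usage here rather than from a documented contract:

// Each Sequence appears to carry one entry per hidden dimension of the
// corresponding transformed descriptor. An all-zero sequence (as used for
// the weight tensor) requests the default coordinate update, while the
// 1/2 entries on the input and output tensors mark dimensions that take a
// specialized step when the slice window moves forward or backward.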
for(index_t i = 0; i < 5; ++i)
{
float ave_time = launch_kernel_dynamic_gemm_v1r2<
@@ -527,25 +556,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
            Sequence<2, 3, 0, 1>,
            3,
            GemmCThreadTransferDstScalarPerVector_GemmN1,
-            decltype(descs[I4]),
-            decltype(descs[I5]),
-            decltype(descs[I6]),
-            decltype(descs[I7]),
-            decltype(descs[I8])>(static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
-                                     wei_k_c_y_x_device_buf.GetDeviceBuffer()),
-                                 static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
-                                     in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
-                                 static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
-                                 descs[I0],
-                                 descs[I1],
-                                 descs[I2],
-                                 descs[I3],
-                                 descs[I4],
-                                 descs[I5],
-                                 descs[I6],
-                                 descs[I7],
-                                 descs[I8],
-                                 nrepeat);
+            decltype(wei_gemmk_gemmm_global_iterator_hacks),
+            decltype(in_gemmk_gemmn_global_iterator_hacks),
+            decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks),
+            decltype(wei_gemmk_gemmm_global_move_slice_window_iterator_hacks),
+            decltype(in_gemmk_gemmn_global_move_slice_window_iterator_hacks)>(
+            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
+                wei_k_c_y_x_device_buf.GetDeviceBuffer()),
+            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
+                in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
+            static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
+            descs[I0],
+            descs[I1],
+            descs[I2],
+            descs[I3],
+            wei_gemmk_gemmm_global_iterator_hacks,
+            in_gemmk_gemmn_global_iterator_hacks,
+            out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
+            wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
+            in_gemmk_gemmn_global_move_slice_window_iterator_hacks,
+            nrepeat);
float perf = (float)calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /