Commit f63f1636 authored by Chao Liu

refactor

parent 51cdcee6
@@ -69,44 +69,8 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
     constexpr auto I2 = Number<2>{};
     constexpr auto I3 = Number<3>{};
 
-    const auto M = a_k_m_grid_desc.GetLength(I1);
-    const auto N = b_k_n_grid_desc.GetLength(I1);
-    const auto K = a_k_m_grid_desc.GetLength(I0);
-
-    if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
-    {
-        throw std::runtime_error("wrong! GEMM size no divisible");
-    }
-
-    constexpr auto M1 = Number<MPerBlock>{};
-    constexpr auto N1 = Number<NPerBlock>{};
-
-    const auto M0 = M / M1;
-    const auto N0 = N / N1;
-
-    constexpr auto M11 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
-    constexpr auto N11 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
-
-    constexpr auto M10 = M1 / M11;
-    constexpr auto N10 = N1 / N11;
-
-    const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
-        c_m_n_grid_desc,
-        make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
-                   make_unmerge_transform(make_tuple(N0, N10, N11))),
-        make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
-
-    using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);
-
-    // out_gemm_block_cluster_desc
-    const auto c_block_cluster_desc =
-        make_cluster_descriptor_v2(make_tuple(M / Number<MPerBlock>{}, N / Number<NPerBlock>{}));
-
-    using CBlockClusterDesc = decltype(c_block_cluster_desc);
-
     // GEMM
-    using gridwise_gemm =
+    using GridwiseGemm =
         GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2<BlockSize,
                                                 FloatAB,
                                                 FloatAcc,
@@ -114,8 +78,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                                 CGlobalMemoryDataOperation,
                                                 AKMGridDesc,
                                                 BKNGridDesc,
-                                                CM0M10M11N0N10N11GridDesc,
-                                                CBlockClusterDesc,
+                                                CMNGridDesc,
                                                 MPerBlock,
                                                 NPerBlock,
                                                 KPerBlock,
@@ -151,6 +114,26 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                                 AGridMoveSliceWindowIteratorHacks,
                                                 BGridMoveSliceWindowIteratorHacks>;
 
+    const auto M = a_k_m_grid_desc.GetLength(I1);
+    const auto N = b_k_n_grid_desc.GetLength(I1);
+    const auto K = a_k_m_grid_desc.GetLength(I0);
+
+    if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
+    {
+        throw std::runtime_error("wrong! GEMM size no divisible");
+    }
+
+    // c_m0_m10_m11_n0_n10_n11_grid_desc
+    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
+        GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
+
+    using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);
+
+    // c_block_cluster_adaptor
+    const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
+
+    using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
+
     const auto GridSize = (M / MPerBlock) * (N / NPerBlock);
 
     const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;
@@ -161,13 +144,13 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
     if(has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                                      FloatAB,
                                                      FloatC,
                                                      remove_reference_t<AKMGridDesc>,
                                                      remove_reference_t<BKNGridDesc>,
                                                      remove_reference_t<CM0M10M11N0N10N11GridDesc>,
-                                                     remove_reference_t<CBlockClusterDesc>,
+                                                     remove_reference_t<CBlockClusterAdaptor>,
                                                      true,
                                                      true>;
 
@@ -183,17 +166,17 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                           a_k_m_grid_desc,
                                           b_k_n_grid_desc,
                                           c_m0_m10_m11_n0_n10_n11_grid_desc,
-                                          c_block_cluster_desc);
+                                          c_block_cluster_adaptor);
     }
     else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
     {
-        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                                      FloatAB,
                                                      FloatC,
                                                      remove_reference_t<AKMGridDesc>,
                                                      remove_reference_t<BKNGridDesc>,
                                                      remove_reference_t<CM0M10M11N0N10N11GridDesc>,
-                                                     remove_reference_t<CBlockClusterDesc>,
+                                                     remove_reference_t<CBlockClusterAdaptor>,
                                                      true,
                                                      false>;
 
@@ -209,17 +192,17 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                           a_k_m_grid_desc,
                                           b_k_n_grid_desc,
                                           c_m0_m10_m11_n0_n10_n11_grid_desc,
-                                          c_block_cluster_desc);
+                                          c_block_cluster_adaptor);
     }
     else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
-        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                                      FloatAB,
                                                      FloatC,
                                                      remove_reference_t<AKMGridDesc>,
                                                      remove_reference_t<BKNGridDesc>,
                                                      remove_reference_t<CM0M10M11N0N10N11GridDesc>,
-                                                     remove_reference_t<CBlockClusterDesc>,
+                                                     remove_reference_t<CBlockClusterAdaptor>,
                                                      false,
                                                      true>;
 
@@ -235,17 +218,17 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                           a_k_m_grid_desc,
                                           b_k_n_grid_desc,
                                           c_m0_m10_m11_n0_n10_n11_grid_desc,
-                                          c_block_cluster_desc);
+                                          c_block_cluster_adaptor);
     }
    else
    {
-        const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
+        const auto kernel = kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                                      FloatAB,
                                                      FloatC,
                                                      remove_reference_t<AKMGridDesc>,
                                                      remove_reference_t<BKNGridDesc>,
                                                      remove_reference_t<CM0M10M11N0N10N11GridDesc>,
-                                                     remove_reference_t<CBlockClusterDesc>,
+                                                     remove_reference_t<CBlockClusterAdaptor>,
                                                      false,
                                                      false>;
 
@@ -261,7 +244,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                           a_k_m_grid_desc,
                                           b_k_n_grid_desc,
                                           c_m0_m10_m11_n0_n10_n11_grid_desc,
-                                          c_block_cluster_desc);
+                                          c_block_cluster_adaptor);
     }
 
     return ave_time;
...
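The driver-side half of the change above follows one pattern: the C-descriptor construction no longer lives in the driver; the driver calls the gridwise struct's static factories and recovers the concrete types with decltype. A minimal self-contained sketch of that factory-plus-decltype pattern (CMNDesc, GridwiseGemmLike, and driver are hypothetical illustration names, not code from this repo):

    #include <cstdio>
    #include <tuple>

    // Hypothetical stand-in for the repo's dynamic C(M, N) tensor descriptor.
    struct CMNDesc
    {
        int M;
        int N;
    };

    template <int MPerBlock, int NPerBlock>
    struct GridwiseGemmLike
    {
        // Static factory: the one place that knows how the block-cluster
        // mapping is built from the C(M, N) descriptor.
        static constexpr auto MakeCBlockClusterAdaptor(const CMNDesc& c_m_n_desc)
        {
            // (M0, N0) = (M / MPerBlock, N / NPerBlock) blocks tile the output.
            return std::make_tuple(c_m_n_desc.M / MPerBlock, c_m_n_desc.N / NPerBlock);
        }
    };

    // Driver side: construct through the factory and let decltype name the
    // result type, instead of restating the transform chain by hand.
    template <typename GridwiseGemm>
    void driver(const CMNDesc& c_m_n_desc)
    {
        const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_desc);
        using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);

        static_assert(std::tuple_size<CBlockClusterAdaptor>::value == 2, "two grid dims");
        std::printf("tiles: %d x %d\n",
                    std::get<0>(c_block_cluster_adaptor),
                    std::get<1>(c_block_cluster_adaptor));
    }

    int main()
    {
        driver<GridwiseGemmLike<128, 128>>(CMNDesc{1024, 512}); // prints "tiles: 8 x 4"
    }

The payoff is that the transform chain lives in exactly one place, which is what lets the commit drop CM0M10M11N0N10N11GridDesc and CBlockClusterDesc from the gridwise struct's template-parameter list and pass only CMNGridDesc.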
@@ -18,7 +18,7 @@ template <typename GridwiseGemm,
           typename AKMGridDesc,
           typename BKNGridDesc,
           typename CM0M10M11N0N10N11GridDesc,
-          typename CBlockClusterDesc,
+          typename CBlockClusterAdaptor,
           bool HasMainKBlockLoop,
           bool HasDoubleTailKBlockLoop>
 __global__ void
@@ -31,7 +31,7 @@ __global__ void
     const AKMGridDesc a_k_m_grid_desc,
     const BKNGridDesc b_k_n_grid_desc,
     const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
-    const CBlockClusterDesc c_block_cluster_desc)
+    const CBlockClusterAdaptor c_block_cluster_desc)
 {
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -57,8 +57,7 @@ template <index_t BlockSize,
           InMemoryDataOperation CGlobalMemoryDataOperation,
           typename AKMGridDesc,
           typename BKNGridDesc,
-          typename CM0M10M11N0N10N11GridDesc,
-          typename CBlockClusterDesc,
+          typename CMNGridDesc,
           index_t MPerBlock,
           index_t NPerBlock,
           index_t KPerBlock,
@@ -163,15 +162,55 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         return b_k_n0_n1_block_clusterized_grid_desc;
     }
 
-#if 0
     __host__ __device__ static constexpr auto
-    MakeCM0M10M11N0N10N11GridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
+    MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
     {
+        const auto M = c_m_n_grid_desc.GetLength(I0);
+        const auto N = c_m_n_grid_desc.GetLength(I1);
+
+        constexpr auto M1 = Number<MPerBlock>{};
+        constexpr auto N1 = Number<NPerBlock>{};
+
+        const auto M0 = M / M1;
+        const auto N0 = N / N1;
+
+        constexpr auto M11 = Number<M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThread>{};
+        constexpr auto N11 = Number<M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThread>{};
+
+        constexpr auto M10 = M1 / M11;
+        constexpr auto N10 = N1 / N11;
+
+        const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
+            c_m_n_grid_desc,
+            make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
+                       make_unmerge_transform(make_tuple(N0, N10, N11))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
+
+        return c_m0_m10_m11_n0_n10_n11_grid_desc;
+    }
+
+    __host__ __device__ static constexpr auto
+    MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
+    {
+        const auto M = c_m_n_grid_desc.GetLength(I0);
+        const auto N = c_m_n_grid_desc.GetLength(I1);
+
+        constexpr auto M1 = Number<MPerBlock>{};
+        constexpr auto N1 = Number<NPerBlock>{};
+
+        const auto M0 = M / M1;
+        const auto N0 = N / N1;
+
+        const auto c_block_cluster_adaptor = make_cluster_descriptor_v2(make_tuple(M0, N0));
+
+        return c_block_cluster_adaptor;
     }
-#endif
 
     using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
     using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{}));
+    using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
+    using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
 
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
@@ -181,7 +220,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                                const AKMGridDesc& a_k_m_grid_desc,
                                const BKNGridDesc& b_k_n_grid_desc,
                                const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
-                               const CBlockClusterDesc& c_block_cluster_desc,
+                               const CBlockClusterAdaptor& c_block_cluster_desc,
                                integral_constant<bool, HasMainKBlockLoop>,
                                integral_constant<bool, HasDoubleTailKBlockLoop>)
     {
@@ -506,8 +545,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                 1,
                 c_m10_n10_m11_n11_thread_tensor_lengths[I2],
                 c_m10_n10_m11_n11_thread_tensor_lengths[I3]>,
-                Sequence<3, 4, 5, 0, 1, 2>, // TODO: CThreadTransferSrcDstAccessOrder
-                5,                          // TODO: CThreadTransferSrcDstVectorDim
+                CThreadTransferSrcDstAccessOrder,
+                CThreadTransferSrcDstVectorDim,
                 CThreadTransferDstScalarPerVector,
                 CGlobalMemoryDataOperation,
                 1,
...
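For reference, MakeCM0M10M11N0N10N11GridDescriptor above performs two exact splits per dimension: M = M0 * M1 with M1 = MPerBlock, then M1 = M10 * M11 with M11 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThread (and symmetrically for N), which is why the driver insists that MPerBlock, NPerBlock, and KPerBlock divide the problem sizes. A tiny standalone check of that index arithmetic; the tile numbers are illustrative, not taken from this commit:

    int main()
    {
        // Illustrative tile parameters, not from this commit.
        constexpr int MPerBlock             = 128; // M1
        constexpr int M1N1ThreadClusterM100 = 8;
        constexpr int M1N1ThreadClusterM101 = 2;
        constexpr int M1PerThread           = 4;

        constexpr int M = 1024; // problem size along M

        constexpr int M1  = MPerBlock;
        constexpr int M0  = M / M1;   // 8 blocks along M
        constexpr int M11 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThread; // 64
        constexpr int M10 = M1 / M11; // 2

        // The unmerge is exact only because every factor divides evenly,
        // which is what the driver's divisibility check guarantees.
        static_assert(M0 * M10 * M11 == M, "M must factor as M0 * M10 * M11");

        // A flat m index decomposes as m = (m0 * M10 + m10) * M11 + m11.
        constexpr int m   = 777;
        constexpr int m0  = m / M1;   // 6
        constexpr int m1  = m % M1;   // 9
        constexpr int m10 = m1 / M11; // 0
        constexpr int m11 = m1 % M11; // 9
        static_assert((m0 * M10 + m10) * M11 + m11 == m, "unmerge round-trips");
    }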
@@ -551,8 +551,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
         GemmBBlockTransferDstScalarPerVector_GemmN,
         false, // don't move back src coordinate after threadwise copy, which will be fused with
                // MoveSrcSliceWindow() to save addr computation
-        Sequence<2, 3, 0, 1>,
-        3,
+        Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
+        5,                          // CThreadTransferSrcDstVectorDim
         GemmCThreadTransferDstScalarPerVector_GemmN1,
         decltype(wei_gemmk_gemmm_grid_iterator_hacks),
         decltype(in_gemmk_gemmn_grid_iterator_hacks),
...
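The two Sequence changes are one refactor seen from both sides: the gridwise kernel stops hardcoding the TODO values and takes CThreadTransferSrcDstAccessOrder and CThreadTransferSrcDstVectorDim as template parameters, while this convolution driver now supplies the 6-d order Sequence<3, 4, 5, 0, 1, 2> and vector dimension 5 (N11 in the M0/M10/M11/N0/N10/N11 layout) in place of the old 4-d Sequence<2, 3, 0, 1> and dimension 3. A minimal sketch of hoisting such hardcoded constants into template parameters; all names here are hypothetical stand-ins:

    // Compile-time integer sequence, standing in for the repo's Sequence<...>.
    template <int... Is>
    struct Seq
    {
        static constexpr int size = sizeof...(Is);
    };

    // Before: the copy hardcoded its access order and vector dim internally.
    // After: both are injected, so each driver can match its descriptor rank.
    template <typename AccessOrder, int VectorDim>
    struct ThreadwiseCopyLike
    {
        static_assert(VectorDim < AccessOrder::size, "vector dim must index a real dim");
        static constexpr int vector_dim = VectorDim;
    };

    // Old 4-d driver configuration vs. the new 6-d one from this commit:
    using FourDimCopy = ThreadwiseCopyLike<Seq<2, 3, 0, 1>, 3>;
    using SixDimCopy  = ThreadwiseCopyLike<Seq<3, 4, 5, 0, 1, 2>, 5>;

    static_assert(SixDimCopy::vector_dim == 5, "vectorize along dim 5 (N11)");

    int main() {}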