Commit f63f1636 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 51cdcee6
......@@ -69,44 +69,8 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto M = a_k_m_grid_desc.GetLength(I1);
const auto N = b_k_n_grid_desc.GetLength(I1);
const auto K = a_k_m_grid_desc.GetLength(I0);
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M11 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
constexpr auto N11 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);
// out_gemm_block_cluster_desc
const auto c_block_cluster_desc =
make_cluster_descriptor_v2(make_tuple(M / Number<MPerBlock>{}, N / Number<NPerBlock>{}));
using CBlockClusterDesc = decltype(c_block_cluster_desc);
// GEMM
using gridwise_gemm =
using GridwiseGemm =
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2<BlockSize,
FloatAB,
FloatAcc,
......@@ -114,8 +78,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
CGlobalMemoryDataOperation,
AKMGridDesc,
BKNGridDesc,
CM0M10M11N0N10N11GridDesc,
CBlockClusterDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
......@@ -151,6 +114,26 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
const auto M = a_k_m_grid_desc.GetLength(I1);
const auto N = b_k_n_grid_desc.GetLength(I1);
const auto K = a_k_m_grid_desc.GetLength(I0);
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
// c_m0_m10_m11_n0_n10_n11_grid_desc
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);
// c_block_cluster_adaptor
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const auto GridSize = (M / MPerBlock) * (N / NPerBlock);
const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;
......@@ -161,13 +144,13 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
const auto kernel = kernel_dynamic_gemm_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CBlockClusterAdaptor>,
true,
true>;
......@@ -183,17 +166,17 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_block_cluster_desc);
c_block_cluster_adaptor);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
const auto kernel = kernel_dynamic_gemm_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CBlockClusterAdaptor>,
true,
false>;
......@@ -209,17 +192,17 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_block_cluster_desc);
c_block_cluster_adaptor);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
const auto kernel = kernel_dynamic_gemm_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CBlockClusterAdaptor>,
false,
true>;
......@@ -235,17 +218,17 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_block_cluster_desc);
c_block_cluster_adaptor);
}
else
{
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
const auto kernel = kernel_dynamic_gemm_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CBlockClusterAdaptor>,
false,
false>;
......@@ -261,7 +244,7 @@ __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_block_cluster_desc);
c_block_cluster_adaptor);
}
return ave_time;
......
......@@ -18,7 +18,7 @@ template <typename GridwiseGemm,
typename AKMGridDesc,
typename BKNGridDesc,
typename CM0M10M11N0N10N11GridDesc,
typename CBlockClusterDesc,
typename CBlockClusterAdaptor,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
......@@ -31,7 +31,7 @@ __global__ void
const AKMGridDesc a_k_m_grid_desc,
const BKNGridDesc b_k_n_grid_desc,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockClusterDesc c_block_cluster_desc)
const CBlockClusterAdaptor c_block_cluster_desc)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
......@@ -57,8 +57,7 @@ template <index_t BlockSize,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AKMGridDesc,
typename BKNGridDesc,
typename CM0M10M11N0N10N11GridDesc,
typename CBlockClusterDesc,
typename CMNGridDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
......@@ -163,15 +162,55 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
return b_k_n0_n1_block_clusterized_grid_desc;
}
#if 0
__host__ __device__ static constexpr auto
MakeCM0M10M11N0N10N11GridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M11 = Number<M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThread>{};
constexpr auto N11 = Number<M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThread>{};
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return c_m0_m10_m11_n0_n10_n11_grid_desc;
}
__host__ __device__ static constexpr auto
MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto c_block_cluster_adaptor = make_cluster_descriptor_v2(make_tuple(M0, N0));
return c_block_cluster_adaptor;
}
#endif
using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{}));
using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{}));
using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
......@@ -181,7 +220,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
const AKMGridDesc& a_k_m_grid_desc,
const BKNGridDesc& b_k_n_grid_desc,
const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockClusterDesc& c_block_cluster_desc,
const CBlockClusterAdaptor& c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
......@@ -506,8 +545,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
1,
c_m10_n10_m11_n11_thread_tensor_lengths[I2],
c_m10_n10_m11_n11_thread_tensor_lengths[I3]>,
Sequence<3, 4, 5, 0, 1, 2>, // TODO: CThreadTransferSrcDstAccessOrder
5, // TODO: CThreadTransferSrcDstVectorDim
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
......
......@@ -551,8 +551,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
GemmBBlockTransferDstScalarPerVector_GemmN,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(wei_gemmk_gemmm_grid_iterator_hacks),
decltype(in_gemmk_gemmn_grid_iterator_hacks),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment