Commit 51cdcee6 authored by Chao Liu
Browse files

refactor

parent 888a0a95
...@@ -50,18 +50,18 @@ template <index_t BlockSize, ...@@ -50,18 +50,18 @@ template <index_t BlockSize,
typename CGridIteratorHacks, typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks, typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks> typename BGridMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid, __host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid, const FloatAB* p_b_grid,
FloatC* p_c_grid, FloatC* p_c_grid,
const AKMGridDesc& a_k_m_grid_desc, const AKMGridDesc& a_k_m_grid_desc,
const BKNGridDesc& b_k_n_grid_desc, const BKNGridDesc& b_k_n_grid_desc,
const CMNGridDesc& c_m_n_grid_desc, const CMNGridDesc& c_m_n_grid_desc,
AGridIteratorHacks, AGridIteratorHacks,
BGridIteratorHacks, BGridIteratorHacks,
CGridIteratorHacks, CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks, AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowIteratorHacks,
index_t nrepeat) index_t nrepeat)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -78,26 +78,6 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -78,26 +78,6 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
throw std::runtime_error("wrong! GEMM size no divisible"); throw std::runtime_error("wrong! GEMM size no divisible");
} }
const auto M1Old = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
const auto N1Old = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
if(!(MPerBlock % M1Old == 0 && NPerBlock % N1Old == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
const auto M0Old = M / M1Old;
const auto N0Old = N / N1Old;
const auto c_m0_m1_n0_n1_grid_desc = transform_dynamic_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0Old, M1Old)),
make_unmerge_transform(make_tuple(N0Old, N1Old))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
using CM0M1N0N1GridDesc = decltype(c_m0_m1_n0_n1_grid_desc);
constexpr auto M1 = Number<MPerBlock>{}; constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{}; constexpr auto N1 = Number<NPerBlock>{};
...@@ -134,9 +114,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -134,9 +114,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
AKMGridDesc, AKMGridDesc,
BKNGridDesc, BKNGridDesc,
CM0M1N0N1GridDesc,
CBlockClusterDesc,
CM0M10M11N0N10N11GridDesc, CM0M10M11N0N10N11GridDesc,
CBlockClusterDesc,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
...@@ -183,14 +162,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -183,14 +162,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKMGridDesc>, remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>, remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockClusterDesc>,
true, true,
true>; true>;
...@@ -205,21 +182,18 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -205,21 +182,18 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_c_grid, p_c_grid,
a_k_m_grid_desc, a_k_m_grid_desc,
b_k_n_grid_desc, b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc,
c_block_cluster_desc, c_block_cluster_desc);
c_m0_m10_m11_n0_n10_n11_grid_desc);
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKMGridDesc>, remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>, remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockClusterDesc>,
true, true,
false>; false>;
...@@ -234,21 +208,18 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -234,21 +208,18 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_c_grid, p_c_grid,
a_k_m_grid_desc, a_k_m_grid_desc,
b_k_n_grid_desc, b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc,
c_block_cluster_desc, c_block_cluster_desc);
c_m0_m10_m11_n0_n10_n11_grid_desc);
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKMGridDesc>, remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>, remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockClusterDesc>,
false, false,
true>; true>;
...@@ -263,21 +234,18 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -263,21 +234,18 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_c_grid, p_c_grid,
a_k_m_grid_desc, a_k_m_grid_desc,
b_k_n_grid_desc, b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc,
c_block_cluster_desc, c_block_cluster_desc);
c_m0_m10_m11_n0_n10_n11_grid_desc);
} }
else else
{ {
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AKMGridDesc>, remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>, remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockClusterDesc>,
false, false,
false>; false>;
...@@ -292,9 +260,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid, ...@@ -292,9 +260,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_c_grid, p_c_grid,
a_k_m_grid_desc, a_k_m_grid_desc,
b_k_n_grid_desc, b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc,
c_block_cluster_desc, c_block_cluster_desc);
c_m0_m10_m11_n0_n10_n11_grid_desc);
} }
return ave_time; return ave_time;
......
...@@ -13,37 +13,39 @@ ...@@ -13,37 +13,39 @@
namespace ck { namespace ck {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatA, typename FloatAB,
typename FloatB,
typename FloatC, typename FloatC,
typename AKMGridDesc, typename AKMGridDesc,
typename BKNGridDesc, typename BKNGridDesc,
typename CM10M11N10N11GridDesc,
typename CBlockClusterDesc,
typename CM0M10M11N0N10N11GridDesc, typename CM0M10M11N0N10N11GridDesc,
typename CBlockClusterDesc,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_v1r2(const FloatA* __restrict__ p_a_grid, kernel_dynamic_gemm_v1r2(const FloatAB* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
const AKMGridDesc a_k_m_grid_desc, const AKMGridDesc a_k_m_grid_desc,
const BKNGridDesc b_k_n_grid_desc, const BKNGridDesc b_k_n_grid_desc,
const CM10M11N10N11GridDesc c_m10_n10_m11_n11_grid_desc, const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockClusterDesc c_block_cluster_desc, const CBlockClusterDesc c_block_cluster_desc)
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc)
{ {
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid, GridwiseGemm::Run(p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
p_shared_block,
a_k_m_grid_desc, a_k_m_grid_desc,
b_k_n_grid_desc, b_k_n_grid_desc,
c_m10_n10_m11_n11_grid_desc,
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc, c_m0_m10_m11_n0_n10_n11_grid_desc,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
...@@ -55,9 +57,8 @@ template <index_t BlockSize, ...@@ -55,9 +57,8 @@ template <index_t BlockSize,
InMemoryDataOperation CGlobalMemoryDataOperation, InMemoryDataOperation CGlobalMemoryDataOperation,
typename AKMGridDesc, typename AKMGridDesc,
typename BKNGridDesc, typename BKNGridDesc,
typename CM10M11N10N11GridDesc,
typename CBlockClusterDesc,
typename CM0M10M11N0N10N11GridDesc, typename CM0M10M11N0N10N11GridDesc,
typename CBlockClusterDesc,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t KPerBlock, index_t KPerBlock,
...@@ -176,12 +177,11 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2 ...@@ -176,12 +177,11 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AKMGridDesc& a_k_m_grid_desc, const AKMGridDesc& a_k_m_grid_desc,
const BKNGridDesc& b_k_n_grid_desc, const BKNGridDesc& b_k_n_grid_desc,
const CM10M11N10N11GridDesc& c_m10_n10_m11_n11_grid_desc, const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockClusterDesc& c_block_cluster_desc, const CBlockClusterDesc& c_block_cluster_desc,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>, integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) integral_constant<bool, HasDoubleTailKBlockLoop>)
{ {
...@@ -190,7 +190,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2 ...@@ -190,7 +190,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_grid, b_k_n_grid_desc.GetElementSpaceSize()); p_b_grid, b_k_n_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_grid, c_m10_n10_m11_n11_grid_desc.GetElementSpaceSize()); p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
const auto K = a_k_m_grid_desc.GetLength(I0); const auto K = a_k_m_grid_desc.GetLength(I0);
const auto M = a_k_m_grid_desc.GetLength(I1); const auto M = a_k_m_grid_desc.GetLength(I1);
...@@ -526,35 +526,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2 ...@@ -526,35 +526,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
CGridIteratorHacks{}); CGridIteratorHacks{});
} }
} }
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AKMGridDesc& a_k_m_grid_desc,
const BKNGridDesc& b_k_n_grid_desc,
const CM10M11N10N11GridDesc& c_m10_n10_m11_n11_grid_desc,
const CBlockClusterDesc& c_block_cluster_desc,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
Run(p_a_grid,
p_b_grid,
p_c_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m10_n10_m11_n11_grid_desc,
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
}; };
} // namespace ck } // namespace ck
......
...@@ -515,7 +515,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw( ...@@ -515,7 +515,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = launch_kernel_dynamic_gemm_v1r2< float ave_time = driver_dynamic_gemm_v1r2<
BlockSize, BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type, typename vector_type<TInWei, InWeiVectorSize>::type,
TAcc, TAcc,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment