"include/ck/host_utility" did not exist on "d1db6a0c3ea190996bdae37adda191f746bfc34e"
Commit 51cdcee6 authored by Chao Liu

refactor

parent 888a0a95
......@@ -50,7 +50,7 @@ template <index_t BlockSize,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
__host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AKMGridDesc& a_k_m_grid_desc,
......@@ -78,26 +78,6 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
throw std::runtime_error("wrong! GEMM size no divisible");
}
const auto M1Old = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
const auto N1Old = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
if(!(MPerBlock % M1Old == 0 && NPerBlock % N1Old == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
const auto M0Old = M / M1Old;
const auto N0Old = N / N1Old;
const auto c_m0_m1_n0_n1_grid_desc = transform_dynamic_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0Old, M1Old)),
make_unmerge_transform(make_tuple(N0Old, N1Old))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
using CM0M1N0N1GridDesc = decltype(c_m0_m1_n0_n1_grid_desc);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
......@@ -134,9 +114,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
CGlobalMemoryDataOperation,
AKMGridDesc,
BKNGridDesc,
CM0M1N0N1GridDesc,
CBlockClusterDesc,
CM0M10M11N0N10N11GridDesc,
CBlockClusterDesc,
MPerBlock,
NPerBlock,
KPerBlock,
......@@ -183,14 +162,12 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
true>;
......@@ -205,21 +182,18 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_c_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc);
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_block_cluster_desc);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
false>;
......@@ -234,21 +208,18 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_c_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc);
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_block_cluster_desc);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
true>;
......@@ -263,21 +234,18 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_c_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc);
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_block_cluster_desc);
}
else
{
const auto kernel = kernel_dynamic_gemm_v1r2<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AKMGridDesc>,
remove_reference_t<BKNGridDesc>,
remove_reference_t<CM0M1N0N1GridDesc>,
remove_reference_t<CBlockClusterDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
false>;
......@@ -292,9 +260,8 @@ __host__ float launch_kernel_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
p_c_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m0_m1_n0_n1_grid_desc,
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc);
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_block_cluster_desc);
}
return ave_time;
......
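The driver above cannot branch on the K-loop structure inside the kernel, so it selects one of four compile-time instantiations of kernel_dynamic_gemm_v1r2 from the two run-time flags has_main_k_block_loop and has_double_tail_k_block_loop (computed elsewhere in the driver, not shown in this excerpt). Below is a minimal host-only sketch of that dispatch pattern; kernel_stub and driver_stub are placeholder names, not CK code.

```cpp
// Minimal host-only sketch of the four-way dispatch in driver_dynamic_gemm_v1r2.
// kernel_stub / driver_stub are placeholders, not CK code.
#include <cstdio>

// Stands in for kernel_dynamic_gemm_v1r2<gridwise_gemm, ..., HasMainKBlockLoop,
// HasDoubleTailKBlockLoop>: the loop structure is baked in at compile time.
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
void kernel_stub()
{
    std::printf("main_k_block_loop=%d double_tail_k_block_loop=%d\n",
                int(HasMainKBlockLoop), int(HasDoubleTailKBlockLoop));
}

// Stands in for the dispatch section of the driver: the two flags are only
// known at run time, so one of four instantiations is picked and launched.
void driver_stub(bool has_main_k_block_loop, bool has_double_tail_k_block_loop)
{
    if(has_main_k_block_loop && has_double_tail_k_block_loop)
        kernel_stub<true, true>();
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
        kernel_stub<true, false>();
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
        kernel_stub<false, true>();
    else
        kernel_stub<false, false>();
}

int main() { driver_stub(true, false); }
```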
......@@ -13,37 +13,39 @@
namespace ck {
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatAB,
typename FloatC,
typename AKMGridDesc,
typename BKNGridDesc,
typename CM10M11N10N11GridDesc,
typename CBlockClusterDesc,
typename CM0M10M11N0N10N11GridDesc,
typename CBlockClusterDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1r2(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
kernel_dynamic_gemm_v1r2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AKMGridDesc a_k_m_grid_desc,
const BKNGridDesc b_k_n_grid_desc,
const CM10M11N10N11GridDesc c_m10_n10_m11_n11_grid_desc,
const CBlockClusterDesc c_block_cluster_desc,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc)
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockClusterDesc c_block_cluster_desc)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m10_n10_m11_n11_grid_desc,
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
......@@ -55,9 +57,8 @@ template <index_t BlockSize,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AKMGridDesc,
typename BKNGridDesc,
typename CM10M11N10N11GridDesc,
typename CBlockClusterDesc,
typename CM0M10M11N0N10N11GridDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
......@@ -176,12 +177,11 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AKMGridDesc& a_k_m_grid_desc,
const BKNGridDesc& b_k_n_grid_desc,
const CM10M11N10N11GridDesc& c_m10_n10_m11_n11_grid_desc,
const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockClusterDesc& c_block_cluster_desc,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
......@@ -190,7 +190,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_grid, b_k_n_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_grid, c_m10_n10_m11_n11_grid_desc.GetElementSpaceSize());
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
const auto K = a_k_m_grid_desc.GetLength(I0);
const auto M = a_k_m_grid_desc.GetLength(I1);
......@@ -526,35 +526,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
CGridIteratorHacks{});
}
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AKMGridDesc& a_k_m_grid_desc,
const BKNGridDesc& b_k_n_grid_desc,
const CM10M11N10N11GridDesc& c_m10_n10_m11_n11_grid_desc,
const CBlockClusterDesc& c_block_cluster_desc,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
Run(p_a_grid,
p_b_grid,
p_c_grid,
a_k_m_grid_desc,
b_k_n_grid_desc,
c_m10_n10_m11_n11_grid_desc,
c_block_cluster_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
};
} // namespace ck
......
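Besides unifying FloatA/FloatB into FloatAB and reordering the C descriptor before the block-cluster descriptor, this file moves the __shared__ workspace allocation out of the gridwise struct and into the __global__ kernel: the Run() overload that allocated LDS internally is deleted, and the kernel now declares p_shared_block itself and passes the pointer to GridwiseGemm::Run. A simplified HIP sketch of that pattern follows; GridwiseGemmStub, kernel_stub, and the fixed 1024-float workspace are placeholders, and the Run signature is abbreviated relative to the real one.

```cpp
#include <hip/hip_runtime.h>

// Placeholder for GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2; in CK the byte
// count depends on the block tile sizes, 1024 floats is just an example.
struct GridwiseGemmStub
{
    __host__ __device__ static constexpr unsigned GetSharedMemoryNumberOfByte()
    {
        return 1024 * sizeof(float); // placeholder size
    }

    __device__ static void Run(const float* a, const float* b, float* c, float* p_shared)
    {
        // block-wise GEMM body elided
        (void)a; (void)b; (void)c; (void)p_shared;
    }
};

template <typename GridwiseGemm>
__global__ void kernel_stub(const float* a, const float* b, float* c)
{
    // The kernel, not the gridwise struct, owns the LDS allocation and hands
    // the raw pointer down to Run().
    constexpr unsigned shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(float);

    __shared__ float p_shared_block[shared_block_size];

    GridwiseGemm::Run(a, b, c, p_shared_block);
}
```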
......@@ -515,7 +515,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nchw_kcyx_nkhw(
for(index_t i = 0; i < 5; ++i)
{
float ave_time = launch_kernel_dynamic_gemm_v1r2<
float ave_time = driver_dynamic_gemm_v1r2<
BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type,
TAcc,
......
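At the call site, only the driver's name changes; the convolution driver still invokes it repeatedly and reads back an average time per launch. The actual measurement lives inside CK's launch helpers (not shown here); the sketch below is only a generic hipEvent-based way such an average could be produced, with dummy_kernel and time_dummy_kernel as hypothetical names.

```cpp
// Generic hipEvent timing sketch (not CK's actual timing helper): one way to
// return an average kernel time in milliseconds, like ave_time above.
#include <hip/hip_runtime.h>

__global__ void dummy_kernel(float* p, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        p[i] = p[i] * 2.0f;
}

float time_dummy_kernel(float* p_dev, int n, int nrepeat)
{
    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, nullptr);

    for(int i = 0; i < nrepeat; ++i)
    {
        hipLaunchKernelGGL(
            dummy_kernel, dim3((n + 255) / 256), dim3(256), 0, nullptr, p_dev, n);
    }

    hipEventRecord(stop, nullptr);
    hipEventSynchronize(stop);

    float total_ms = 0.f;
    hipEventElapsedTime(&total_ms, start, stop);

    hipEventDestroy(start);
    hipEventDestroy(stop);

    return total_ms / nrepeat; // average time per launch
}
```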