Commit 6c37035f authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 3a7fd7d6
...@@ -4,8 +4,7 @@ ...@@ -4,8 +4,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm.hpp" #include "gridwise_dynamic_gemm_v1r1.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace ck { namespace ck {
...@@ -52,19 +51,19 @@ template <index_t BlockSize, ...@@ -52,19 +51,19 @@ template <index_t BlockSize,
typename CGlobalIteratorHacks, typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks, typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks> typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, __host__ float launch_kernel_dynamic_gemm_v1r1(const FloatAB* p_a_global,
const FloatAB* p_b_global, const FloatAB* p_b_global,
FloatC* p_c_global, FloatC* p_c_global,
const AGlobalDesc& a_k_m_global_desc, const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc, const BGlobalDesc& b_k_n_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc, const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc& c_block_cluster_desc, const CBlockClusterDesc& c_block_cluster_desc,
AGlobalIteratorHacks, AGlobalIteratorHacks,
BGlobalIteratorHacks, BGlobalIteratorHacks,
CGlobalIteratorHacks, CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks, AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks, BGlobalMoveSliceWindowIteratorHacks,
index_t nrepeat) index_t nrepeat)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
...@@ -91,49 +90,49 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -91,49 +90,49 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
// GEMM // GEMM
using gridwise_gemm = using gridwise_gemm =
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<BlockSize, GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
AGlobalDesc, AGlobalDesc,
BGlobalDesc, BGlobalDesc,
CGlobalDesc, CGlobalDesc,
CBlockClusterDesc, CBlockClusterDesc,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
MPerThread, MPerThread,
NPerThread, NPerThread,
KPerThread, KPerThread,
MLevel0Cluster, MLevel0Cluster,
NLevel0Cluster, NLevel0Cluster,
MLevel1Cluster, MLevel1Cluster,
NLevel1Cluster, NLevel1Cluster,
ABlockTransferThreadSliceLengths_K_M, ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M, ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M, ABlockTransferDstScalarPerVector_M,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N, BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N, BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N, BBlockTransferDstScalarPerVector_N,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGlobalIteratorHacks, AGlobalIteratorHacks,
BGlobalIteratorHacks, BGlobalIteratorHacks,
CGlobalIteratorHacks, CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks, AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks>; BGlobalMoveSliceWindowIteratorHacks>;
const auto GridSize = (M / MPerBlock) * (N / NPerBlock); const auto GridSize = (M / MPerBlock) * (N / NPerBlock);
...@@ -146,16 +145,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -146,16 +145,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>, remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>, remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
true, true,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -173,16 +172,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -173,16 +172,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>, remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>, remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
true, true,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -200,16 +199,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -200,16 +199,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>, remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>, remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
false, false,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -227,16 +226,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -227,16 +226,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
} }
else else
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>, remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>, remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
false, false,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -269,16 +268,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -269,16 +268,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>, remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>, remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
true, true,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
...@@ -297,16 +296,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -297,16 +296,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>, remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>, remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
true, true,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
...@@ -325,16 +324,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -325,16 +324,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>, remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>, remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
false, false,
true>; true>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
...@@ -353,16 +352,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -353,16 +352,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
} }
else else
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>, remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>, remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
false, false,
false>; false>;
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
kernel, kernel,
......
...@@ -27,13 +27,13 @@ __global__ void ...@@ -27,13 +27,13 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global, kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global, const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
const AGlobalDesc a_k_m_global_desc, const AGlobalDesc a_k_m_global_desc,
const BGlobalDesc b_k_n_global_desc, const BGlobalDesc b_k_n_global_desc,
const CGlobalDesc c_m0_m1_n0_n1_global_desc, const CGlobalDesc c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc c_block_cluster_desc) const CBlockClusterDesc c_block_cluster_desc)
{ {
GridwiseGemm::Run(p_a_global, GridwiseGemm::Run(p_a_global,
p_b_global, p_b_global,
...@@ -63,13 +63,13 @@ __global__ void ...@@ -63,13 +63,13 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global, kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global, const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
const void __CONSTANT__* p_a_k_m_global_desc, const void __CONSTANT__* p_a_k_m_global_desc,
const void __CONSTANT__* p_b_k_n_global_desc, const void __CONSTANT__* p_b_k_n_global_desc,
const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc, const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
const void __CONSTANT__* p_c_block_cluster_desc) const void __CONSTANT__* p_c_block_cluster_desc)
{ {
// first cast void __CONSTANT__ void* to void* // first cast void __CONSTANT__ void* to void*
// second cast void* to Desc* // second cast void* to Desc*
...@@ -139,7 +139,7 @@ template <index_t BlockSize, ...@@ -139,7 +139,7 @@ template <index_t BlockSize,
typename CGlobalIteratorHacks, typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks, typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks> typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
{ {
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_gemm_v1.hpp" #include "driver_dynamic_gemm_v1r1.hpp"
template <class TInWei, template <class TInWei,
ck::index_t InWeiVectorSize, ck::index_t InWeiVectorSize,
...@@ -490,7 +490,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( ...@@ -490,7 +490,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = launch_kernel_dynamic_gemm_v1< float ave_time = launch_kernel_dynamic_gemm_v1r1<
BlockSize, BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type, typename vector_type<TInWei, InWeiVectorSize>::type,
TAcc, TAcc,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp" #include "transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_v1.hpp" #include "driver_dynamic_gemm_v1r1.hpp"
template <class TInWei, template <class TInWei,
ck::index_t InWeiVectorSize, ck::index_t InWeiVectorSize,
...@@ -396,7 +396,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -396,7 +396,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = launch_kernel_dynamic_gemm_v1< float ave_time = launch_kernel_dynamic_gemm_v1r1<
BlockSize, BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type, typename vector_type<TInWei, InWeiVectorSize>::type,
TAcc, TAcc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment