"vscode:/vscode.git/clone" did not exist on "7d5c47aa41e33e4aa35e5f2847e0c53571f42528"
Commit 6c37035f authored by Chao Liu
Browse files

refactor

parent 3a7fd7d6
...@@ -4,8 +4,7 @@ ...@@ -4,8 +4,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp" #include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm.hpp" #include "gridwise_dynamic_gemm_v1r1.hpp"
#include "gridwise_operation_wrapper.hpp"
namespace ck { namespace ck {
...@@ -52,7 +51,7 @@ template <index_t BlockSize, ...@@ -52,7 +51,7 @@ template <index_t BlockSize,
typename CGlobalIteratorHacks, typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks, typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks> typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, __host__ float launch_kernel_dynamic_gemm_v1r1(const FloatAB* p_a_global,
const FloatAB* p_b_global, const FloatAB* p_b_global,
FloatC* p_c_global, FloatC* p_c_global,
const AGlobalDesc& a_k_m_global_desc, const AGlobalDesc& a_k_m_global_desc,
...@@ -91,7 +90,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -91,7 +90,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
// GEMM // GEMM
using gridwise_gemm = using gridwise_gemm =
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<BlockSize, GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
...@@ -146,7 +145,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -146,7 +145,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
...@@ -173,7 +172,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -173,7 +172,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
...@@ -200,7 +199,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -200,7 +199,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
...@@ -227,7 +226,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -227,7 +226,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
} }
else else
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
...@@ -269,7 +268,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -269,7 +268,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
...@@ -297,7 +296,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -297,7 +296,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
...@@ -325,7 +324,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -325,7 +324,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
...@@ -353,7 +352,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -353,7 +352,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
} }
else else
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB, FloatAB,
FloatAB, FloatAB,
FloatC, FloatC,
......
...@@ -27,7 +27,7 @@ __global__ void ...@@ -27,7 +27,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global, kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global, const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
const AGlobalDesc a_k_m_global_desc, const AGlobalDesc a_k_m_global_desc,
...@@ -63,7 +63,7 @@ __global__ void ...@@ -63,7 +63,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global, kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global, const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
const void __CONSTANT__* p_a_k_m_global_desc, const void __CONSTANT__* p_a_k_m_global_desc,
...@@ -139,7 +139,7 @@ template <index_t BlockSize, ...@@ -139,7 +139,7 @@ template <index_t BlockSize,
typename CGlobalIteratorHacks, typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks, typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks> typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
{ {
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{ {
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_gemm_v1.hpp" #include "driver_dynamic_gemm_v1r1.hpp"
template <class TInWei, template <class TInWei,
ck::index_t InWeiVectorSize, ck::index_t InWeiVectorSize,
...@@ -490,7 +490,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( ...@@ -490,7 +490,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = launch_kernel_dynamic_gemm_v1< float ave_time = launch_kernel_dynamic_gemm_v1r1<
BlockSize, BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type, typename vector_type<TInWei, InWeiVectorSize>::type,
TAcc, TAcc,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "device.hpp" #include "device.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp" #include "transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_v1.hpp" #include "driver_dynamic_gemm_v1r1.hpp"
template <class TInWei, template <class TInWei,
ck::index_t InWeiVectorSize, ck::index_t InWeiVectorSize,
...@@ -396,7 +396,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( ...@@ -396,7 +396,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
for(index_t i = 0; i < 5; ++i) for(index_t i = 0; i < 5; ++i)
{ {
float ave_time = launch_kernel_dynamic_gemm_v1< float ave_time = launch_kernel_dynamic_gemm_v1r1<
BlockSize, BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type, typename vector_type<TInWei, InWeiVectorSize>::type,
TAcc, TAcc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment