Commit 6c37035f authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent 3a7fd7d6
......@@ -4,8 +4,7 @@
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_dynamic_gemm_v1r1.hpp"
namespace ck {
......@@ -52,7 +51,7 @@ template <index_t BlockSize,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
__host__ float launch_kernel_dynamic_gemm_v1r1(const FloatAB* p_a_global,
const FloatAB* p_b_global,
FloatC* p_c_global,
const AGlobalDesc& a_k_m_global_desc,
......@@ -91,7 +90,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
// GEMM
using gridwise_gemm =
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1<BlockSize,
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1<BlockSize,
FloatAB,
FloatAcc,
FloatC,
......@@ -146,7 +145,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
......@@ -173,7 +172,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
......@@ -200,7 +199,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
......@@ -227,7 +226,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
}
else
{
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
......@@ -269,7 +268,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
......@@ -297,7 +296,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
......@@ -325,7 +324,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
......@@ -353,7 +352,7 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
}
else
{
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
......
......@@ -27,7 +27,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global,
kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc a_k_m_global_desc,
......@@ -63,7 +63,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global,
kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const void __CONSTANT__* p_a_k_m_global_desc,
......@@ -139,7 +139,7 @@ template <index_t BlockSize,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
......
......@@ -2,7 +2,7 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_gemm_v1.hpp"
#include "driver_dynamic_gemm_v1r1.hpp"
template <class TInWei,
ck::index_t InWeiVectorSize,
......@@ -490,7 +490,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
for(index_t i = 0; i < 5; ++i)
{
float ave_time = launch_kernel_dynamic_gemm_v1<
float ave_time = launch_kernel_dynamic_gemm_v1r1<
BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type,
TAcc,
......
......@@ -2,7 +2,7 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_v1.hpp"
#include "driver_dynamic_gemm_v1r1.hpp"
template <class TInWei,
ck::index_t InWeiVectorSize,
......@@ -396,7 +396,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
for(index_t i = 0; i < 5; ++i)
{
float ave_time = launch_kernel_dynamic_gemm_v1<
float ave_time = launch_kernel_dynamic_gemm_v1r1<
BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type,
TAcc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment