Commit aac345ab authored by Chao Liu's avatar Chao Liu
Browse files

refactor gridwise gemm

parent 9d8b39a7
......@@ -146,16 +146,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
const FloatAB*,
remove_reference_t<BGlobalDesc>,
const FloatAB*,
remove_reference_t<CGlobalDesc>,
FloatC*,
remove_reference_t<CBlockClusterDesc>,
integral_constant<bool, true>,
integral_constant<bool, true>>;
true,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
......@@ -163,28 +163,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3(BlockSize),
0,
0,
a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_block_cluster_desc,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
const FloatAB*,
remove_reference_t<BGlobalDesc>,
const FloatAB*,
remove_reference_t<CGlobalDesc>,
FloatC*,
remove_reference_t<CBlockClusterDesc>,
integral_constant<bool, true>,
integral_constant<bool, false>>;
true,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
......@@ -192,28 +190,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3(BlockSize),
0,
0,
a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_block_cluster_desc,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
const FloatAB*,
remove_reference_t<BGlobalDesc>,
const FloatAB*,
remove_reference_t<CGlobalDesc>,
FloatC*,
remove_reference_t<CBlockClusterDesc>,
integral_constant<bool, false>,
integral_constant<bool, true>>;
false,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
......@@ -221,28 +217,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3(BlockSize),
0,
0,
a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_block_cluster_desc,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
else
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
const FloatAB*,
remove_reference_t<BGlobalDesc>,
const FloatAB*,
remove_reference_t<CGlobalDesc>,
FloatC*,
remove_reference_t<CBlockClusterDesc>,
integral_constant<bool, false>,
integral_constant<bool, false>>;
false,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
......@@ -250,15 +244,13 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3(BlockSize),
0,
0,
a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_block_cluster_desc,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
return ave_time;
......@@ -277,13 +269,13 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
const FloatAB*,
remove_reference_t<BGlobalDesc>,
const FloatAB*,
remove_reference_t<CGlobalDesc>,
FloatC*,
remove_reference_t<CBlockClusterDesc>,
true,
true>;
......@@ -295,23 +287,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
p_a_global,
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
p_b_global,
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
remove_reference_t<AGlobalDesc>,
FloatAB,
remove_reference_t<BGlobalDesc>,
FloatAB,
remove_reference_t<CGlobalDesc>,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
false>;
......@@ -323,23 +315,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
p_a_global,
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
p_b_global,
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
remove_reference_t<AGlobalDesc>,
FloatAB,
remove_reference_t<BGlobalDesc>,
FloatAB,
remove_reference_t<CGlobalDesc>,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
true>;
......@@ -351,23 +343,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
p_a_global,
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
p_b_global,
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else
{
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
remove_reference_t<AGlobalDesc>,
FloatAB,
remove_reference_t<BGlobalDesc>,
FloatAB,
remove_reference_t<CGlobalDesc>,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
false>;
......@@ -379,12 +371,12 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3(BlockSize),
0,
0,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
p_a_global,
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
p_b_global,
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
......
......@@ -14,12 +14,12 @@ namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename AGlobalDesc,
typename FloatA,
typename BGlobalDesc,
typename FloatB,
typename CGlobalDesc,
typename FloatC,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
......@@ -27,35 +27,36 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc,
const FloatA* __restrict__ p_a_global,
const BGlobalDesc b_k_n_global_desc,
kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
const CGlobalDesc c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global,
const AGlobalDesc a_k_m_global_desc,
const BGlobalDesc b_k_n_global_desc,
const CGlobalDesc c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc c_block_cluster_desc)
{
GridwiseGemm{}.Run(a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
GridwiseGemm::Run(
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
// non-modifiable parameter address space, so compiler can enable corresponding optimization
template <typename GridwiseGemm,
typename AGlobalDesc,
typename FloatA,
typename BGlobalDesc,
typename FloatB,
typename CGlobalDesc,
typename FloatC,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
......@@ -63,12 +64,12 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc,
const FloatA* __restrict__ p_a_global,
const void __CONSTANT__* p_b_k_n_global_desc,
kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global,
const void __CONSTANT__* p_a_k_m_global_desc,
const void __CONSTANT__* p_b_k_n_global_desc,
const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
const void __CONSTANT__* p_c_block_cluster_desc)
{
// first cast void __CONSTANT__ void* to void*
......@@ -84,15 +85,16 @@ __global__ void
const auto c_block_cluster_desc =
*reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
GridwiseGemm{}.Run(a_k_m_global_desc,
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
GridwiseGemm::Run(
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#endif
......@@ -169,16 +171,17 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global,
const CBlockClusterDesc& c_block_cluster_desc,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
__device__ static void Run(
const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -514,26 +517,28 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global,
const CBlockClusterDesc& c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
__device__ static void Run(
const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
Run(a_k_m_global_desc,
Run(
p_a_global,
b_k_n_global_desc,
p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment