Commit aac345ab authored by Chao Liu's avatar Chao Liu
Browse files

refactor gridwise gemm

parent 9d8b39a7
...@@ -146,16 +146,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -146,16 +146,16 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AGlobalDesc>,
const FloatAB*,
remove_reference_t<BGlobalDesc>, remove_reference_t<BGlobalDesc>,
const FloatAB*,
remove_reference_t<CGlobalDesc>, remove_reference_t<CGlobalDesc>,
FloatC*,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
integral_constant<bool, true>, true,
integral_constant<bool, true>>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -163,28 +163,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -163,28 +163,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
a_k_m_global_desc,
p_a_global, p_a_global,
b_k_n_global_desc,
p_b_global, p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global, p_c_global,
c_block_cluster_desc, a_k_m_global_desc,
integral_constant<bool, true>{}, b_k_n_global_desc,
integral_constant<bool, true>{}); c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AGlobalDesc>,
const FloatAB*,
remove_reference_t<BGlobalDesc>, remove_reference_t<BGlobalDesc>,
const FloatAB*,
remove_reference_t<CGlobalDesc>, remove_reference_t<CGlobalDesc>,
FloatC*,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
integral_constant<bool, true>, true,
integral_constant<bool, false>>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -192,28 +190,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -192,28 +190,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
a_k_m_global_desc,
p_a_global, p_a_global,
b_k_n_global_desc,
p_b_global, p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global, p_c_global,
c_block_cluster_desc, a_k_m_global_desc,
integral_constant<bool, true>{}, b_k_n_global_desc,
integral_constant<bool, false>{}); c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AGlobalDesc>,
const FloatAB*,
remove_reference_t<BGlobalDesc>, remove_reference_t<BGlobalDesc>,
const FloatAB*,
remove_reference_t<CGlobalDesc>, remove_reference_t<CGlobalDesc>,
FloatC*,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
integral_constant<bool, false>, false,
integral_constant<bool, true>>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -221,28 +217,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -221,28 +217,26 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
a_k_m_global_desc,
p_a_global, p_a_global,
b_k_n_global_desc,
p_b_global, p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global, p_c_global,
c_block_cluster_desc, a_k_m_global_desc,
integral_constant<bool, false>{}, b_k_n_global_desc,
integral_constant<bool, true>{}); c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
} }
else else
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AGlobalDesc>,
const FloatAB*,
remove_reference_t<BGlobalDesc>, remove_reference_t<BGlobalDesc>,
const FloatAB*,
remove_reference_t<CGlobalDesc>, remove_reference_t<CGlobalDesc>,
FloatC*,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
integral_constant<bool, false>, false,
integral_constant<bool, false>>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
...@@ -250,15 +244,13 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -250,15 +244,13 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
a_k_m_global_desc,
p_a_global, p_a_global,
b_k_n_global_desc,
p_b_global, p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global, p_c_global,
c_block_cluster_desc, a_k_m_global_desc,
integral_constant<bool, false>{}, b_k_n_global_desc,
integral_constant<bool, false>{}); c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
} }
return ave_time; return ave_time;
...@@ -277,13 +269,13 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -277,13 +269,13 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
if(has_main_k_block_loop && has_double_tail_k_block_loop) if(has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = run_gridwise_operation<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>, remove_reference_t<AGlobalDesc>,
const FloatAB*,
remove_reference_t<BGlobalDesc>, remove_reference_t<BGlobalDesc>,
const FloatAB*,
remove_reference_t<CGlobalDesc>, remove_reference_t<CGlobalDesc>,
FloatC*,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
true, true,
true>; true>;
...@@ -295,23 +287,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -295,23 +287,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
p_a_global, p_a_global,
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
p_b_global, p_b_global,
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
p_c_global, p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
remove_reference_t<AGlobalDesc>,
FloatAB, FloatAB,
remove_reference_t<BGlobalDesc>,
FloatAB, FloatAB,
remove_reference_t<CGlobalDesc>,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
true, true,
false>; false>;
...@@ -323,23 +315,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -323,23 +315,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
p_a_global, p_a_global,
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
p_b_global, p_b_global,
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
p_c_global, p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
remove_reference_t<AGlobalDesc>,
FloatAB, FloatAB,
remove_reference_t<BGlobalDesc>,
FloatAB, FloatAB,
remove_reference_t<CGlobalDesc>,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
false, false,
true>; true>;
...@@ -351,23 +343,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -351,23 +343,23 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
p_a_global, p_a_global,
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
p_b_global, p_b_global,
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
p_c_global, p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
} }
else else
{ {
const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm, const auto kernel = kernel_dynamic_gemm_v1<gridwise_gemm,
remove_reference_t<AGlobalDesc>,
FloatAB, FloatAB,
remove_reference_t<BGlobalDesc>,
FloatAB, FloatAB,
remove_reference_t<CGlobalDesc>,
FloatC, FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>, remove_reference_t<CBlockClusterDesc>,
false, false,
false>; false>;
...@@ -379,12 +371,12 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, ...@@ -379,12 +371,12 @@ __host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global,
dim3(BlockSize), dim3(BlockSize),
0, 0,
0, 0,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
p_a_global, p_a_global,
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
p_b_global, p_b_global,
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
p_c_global, p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
} }
......
...@@ -14,12 +14,12 @@ namespace ck { ...@@ -14,12 +14,12 @@ namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename AGlobalDesc,
typename FloatA, typename FloatA,
typename BGlobalDesc,
typename FloatB, typename FloatB,
typename CGlobalDesc,
typename FloatC, typename FloatC,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc, typename CBlockClusterDesc,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
...@@ -27,35 +27,36 @@ __global__ void ...@@ -27,35 +27,36 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc, kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global,
const FloatA* __restrict__ p_a_global,
const BGlobalDesc b_k_n_global_desc,
const FloatB* __restrict__ p_b_global, const FloatB* __restrict__ p_b_global,
const CGlobalDesc c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
const AGlobalDesc a_k_m_global_desc,
const BGlobalDesc b_k_n_global_desc,
const CGlobalDesc c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc c_block_cluster_desc) const CBlockClusterDesc c_block_cluster_desc)
{ {
GridwiseGemm{}.Run(a_k_m_global_desc, GridwiseGemm::Run(
p_a_global, p_a_global,
b_k_n_global_desc, p_b_global,
p_b_global, p_c_global,
c_m0_m1_n0_n1_global_desc, a_k_m_global_desc,
p_c_global, b_k_n_global_desc,
c_block_cluster_desc, c_m0_m1_n0_n1_global_desc,
integral_constant<bool, HasMainKBlockLoop>{}, c_block_cluster_desc,
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by __CONSTANT__ void pointer // pass tensor descriptor by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to // __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
// non-modifiable parameter address space, so compiler can enable corresponding optimization // non-modifiable parameter address space, so compiler can enable corresponding optimization
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename AGlobalDesc,
typename FloatA, typename FloatA,
typename BGlobalDesc,
typename FloatB, typename FloatB,
typename CGlobalDesc,
typename FloatC, typename FloatC,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc, typename CBlockClusterDesc,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop> bool HasDoubleTailKBlockLoop>
...@@ -63,12 +64,12 @@ __global__ void ...@@ -63,12 +64,12 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc, kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global,
const FloatA* __restrict__ p_a_global,
const void __CONSTANT__* p_b_k_n_global_desc,
const FloatB* __restrict__ p_b_global, const FloatB* __restrict__ p_b_global,
const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
const void __CONSTANT__* p_a_k_m_global_desc,
const void __CONSTANT__* p_b_k_n_global_desc,
const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
const void __CONSTANT__* p_c_block_cluster_desc) const void __CONSTANT__* p_c_block_cluster_desc)
{ {
// first cast void __CONSTANT__ void* to void* // first cast void __CONSTANT__ void* to void*
...@@ -84,15 +85,16 @@ __global__ void ...@@ -84,15 +85,16 @@ __global__ void
const auto c_block_cluster_desc = const auto c_block_cluster_desc =
*reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc); *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
GridwiseGemm{}.Run(a_k_m_global_desc, GridwiseGemm::Run(
p_a_global, p_a_global,
b_k_n_global_desc, p_b_global,
p_b_global, p_c_global,
c_m0_m1_n0_n1_global_desc, a_k_m_global_desc,
p_c_global, b_k_n_global_desc,
c_block_cluster_desc, c_m0_m1_n0_n1_global_desc,
integral_constant<bool, HasMainKBlockLoop>{}, c_block_cluster_desc,
integral_constant<bool, HasDoubleTailKBlockLoop>{}); integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
} }
#endif #endif
...@@ -169,16 +171,17 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -169,16 +171,17 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
} }
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc, __device__ static void Run(
const FloatAB* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc, const FloatAB* __restrict__ p_b_global,
const FloatAB* __restrict__ p_b_global, FloatC* __restrict__ p_c_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc, const AGlobalDesc& a_k_m_global_desc,
FloatC* __restrict__ p_c_global, const BGlobalDesc& b_k_n_global_desc,
const CBlockClusterDesc& c_block_cluster_desc, const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
FloatAB* __restrict__ p_shared_block, const CBlockClusterDesc& c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>, FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -514,26 +517,28 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 ...@@ -514,26 +517,28 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
} }
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_k_m_global_desc, __device__ static void Run(
const FloatAB* __restrict__ p_a_global, const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_k_n_global_desc, const FloatAB* __restrict__ p_b_global,
const FloatAB* __restrict__ p_b_global, FloatC* __restrict__ p_c_global,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc, const AGlobalDesc& a_k_m_global_desc,
FloatC* __restrict__ p_c_global, const BGlobalDesc& b_k_n_global_desc,
const CBlockClusterDesc& c_block_cluster_desc, const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
integral_constant<bool, HasMainKBlockLoop>, const CBlockClusterDesc& c_block_cluster_desc,
integral_constant<bool, HasDoubleTailKBlockLoop>) const integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{ {
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB); constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size]; __shared__ FloatAB p_shared_block[shared_block_size];
Run(a_k_m_global_desc, Run(
p_a_global, p_a_global,
b_k_n_global_desc,
p_b_global, p_b_global,
c_m0_m1_n0_n1_global_desc,
p_c_global, p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc, c_block_cluster_desc,
p_shared_block, p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{}, integral_constant<bool, HasMainKBlockLoop>{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment