Commit af13f822 authored by Chao Liu

use address_space(4) in kernel signature to fix performance issue when passing tensor descriptor from host to kernel by (void) pointers
parent fcbb9788
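For context before the diff: `__CONSTANT__` is a macro (defined near the bottom of this commit) that expands to clang's `__attribute__((address_space(4)))`, the AMDGPU constant address space. Annotating a kernel pointer parameter with it tells the compiler the pointee cannot be modified during the dispatch, which typically lets uniform descriptor reads become scalar loads instead of per-thread flat loads. A minimal sketch of the idea, assuming hipcc/clang on AMDGPU; `Desc` and `read_desc` are hypothetical stand-ins for the tensor descriptors in the diff:

// Minimal sketch (not from the repository) of passing a host-built descriptor
// to a kernel through an address_space(4) pointer, assuming hipcc/clang.
#include <hip/hip_runtime.h>

#define __CONSTANT__ __attribute__((address_space(4)))

struct Desc // hypothetical stand-in for a dynamic tensor descriptor
{
    int lengths[4];
    int strides[4];
};

__global__ void read_desc(const Desc __CONSTANT__* p_desc_const, int* p_out)
{
    // cast away address_space(4) before dereferencing, mirroring the diff below;
    // the descriptor is then copied by value into the kernel's local state
    const Desc desc = *(const Desc*)p_desc_const;

    p_out[threadIdx.x] = desc.lengths[0] * desc.strides[0];
}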
@@ -397,131 +397,115 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
 {
     if(has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel =
-            run_gridwise_operation<gridwise_gemm,
-                                   const decltype(wei_gemmk_gemmm_global_desc)*,
-                                   const FloatAB*,
-                                   const decltype(in_gemmk_gemmn_global_desc)*,
-                                   const FloatAB*,
-                                   const decltype(
-                                       out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
-                                   FloatC*,
-                                   integral_constant<bool, true>,
-                                   integral_constant<bool, true>>;
+        const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
+                                                         ADesc,
+                                                         FloatAB,
+                                                         BDesc,
+                                                         FloatAB,
+                                                         CDesc,
+                                                         FloatC,
+                                                         true,
+                                                         true>;
         launch_kernel(kernel,
                       dim3(GridSize),
                       dim3(BlockSize),
                       0,
                       0,
-                      reinterpret_cast<const ADesc*>(
+                      (const ADesc __CONSTANT__*)reinterpret_cast<const ADesc*>(
                           wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
                       p_wei_global,
-                      reinterpret_cast<const BDesc*>(
+                      (const BDesc __CONSTANT__*)reinterpret_cast<const BDesc*>(
                           in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
                       p_in_global,
-                      reinterpret_cast<const CDesc*>(
+                      (const CDesc __CONSTANT__*)reinterpret_cast<const CDesc*>(
                           out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
                               .GetDeviceBuffer()),
-                      p_out_global,
-                      integral_constant<bool, true>{},
-                      integral_constant<bool, true>{});
+                      p_out_global);
     }
     else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
     {
-        const auto kernel =
-            run_gridwise_operation<gridwise_gemm,
-                                   const decltype(wei_gemmk_gemmm_global_desc)*,
-                                   const FloatAB*,
-                                   const decltype(in_gemmk_gemmn_global_desc)*,
-                                   const FloatAB*,
-                                   const decltype(
-                                       out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
-                                   FloatC*,
-                                   integral_constant<bool, true>,
-                                   integral_constant<bool, false>>;
+        const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
+                                                         ADesc,
+                                                         FloatAB,
+                                                         BDesc,
+                                                         FloatAB,
+                                                         CDesc,
+                                                         FloatC,
+                                                         true,
+                                                         false>;
         launch_kernel(kernel,
                       dim3(GridSize),
                       dim3(BlockSize),
                       0,
                       0,
-                      reinterpret_cast<const ADesc*>(
+                      (const ADesc __CONSTANT__*)reinterpret_cast<const ADesc*>(
                           wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
                       p_wei_global,
-                      reinterpret_cast<const BDesc*>(
+                      (const BDesc __CONSTANT__*)reinterpret_cast<const BDesc*>(
                           in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
                       p_in_global,
-                      reinterpret_cast<const CDesc*>(
+                      (const CDesc __CONSTANT__*)reinterpret_cast<const CDesc*>(
                           out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
                               .GetDeviceBuffer()),
-                      p_out_global,
-                      integral_constant<bool, true>{},
-                      integral_constant<bool, false>{});
+                      p_out_global);
     }
     else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel =
-            run_gridwise_operation<gridwise_gemm,
-                                   const decltype(wei_gemmk_gemmm_global_desc)*,
-                                   const FloatAB*,
-                                   const decltype(in_gemmk_gemmn_global_desc)*,
-                                   const FloatAB*,
-                                   const decltype(
-                                       out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
-                                   FloatC*,
-                                   integral_constant<bool, false>,
-                                   integral_constant<bool, true>>;
+        const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
+                                                         ADesc,
+                                                         FloatAB,
+                                                         BDesc,
+                                                         FloatAB,
+                                                         CDesc,
+                                                         FloatC,
+                                                         false,
+                                                         true>;
        launch_kernel(kernel,
                      dim3(GridSize),
                      dim3(BlockSize),
                      0,
                      0,
-                     reinterpret_cast<const ADesc*>(
+                     (const ADesc __CONSTANT__*)reinterpret_cast<const ADesc*>(
                          wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
                      p_wei_global,
-                     reinterpret_cast<const BDesc*>(
+                     (const BDesc __CONSTANT__*)reinterpret_cast<const BDesc*>(
                          in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
                      p_in_global,
-                     reinterpret_cast<const CDesc*>(
+                     (const CDesc __CONSTANT__*)reinterpret_cast<const CDesc*>(
                          out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
                              .GetDeviceBuffer()),
-                     p_out_global,
-                     integral_constant<bool, false>{},
-                     integral_constant<bool, true>{});
+                     p_out_global);
     }
     else
     {
-        const auto kernel =
-            run_gridwise_operation<gridwise_gemm,
-                                   const decltype(wei_gemmk_gemmm_global_desc)*,
-                                   const FloatAB*,
-                                   const decltype(in_gemmk_gemmn_global_desc)*,
-                                   const FloatAB*,
-                                   const decltype(
-                                       out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc)*,
-                                   FloatC*,
-                                   integral_constant<bool, false>,
-                                   integral_constant<bool, false>>;
+        const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
+                                                         ADesc,
+                                                         FloatAB,
+                                                         BDesc,
+                                                         FloatAB,
+                                                         CDesc,
+                                                         FloatC,
+                                                         false,
+                                                         false>;
        launch_kernel(kernel,
                      dim3(GridSize),
                      dim3(BlockSize),
                      0,
                      0,
-                     reinterpret_cast<const ADesc*>(
+                     (const ADesc __CONSTANT__*)reinterpret_cast<const ADesc*>(
                          wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()),
                      p_wei_global,
-                     reinterpret_cast<const BDesc*>(
+                     (const BDesc __CONSTANT__*)reinterpret_cast<const BDesc*>(
                          in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()),
                      p_in_global,
-                     reinterpret_cast<const CDesc*>(
+                     (const CDesc __CONSTANT__*)reinterpret_cast<const CDesc*>(
                          out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
                              .GetDeviceBuffer()),
-                     p_out_global,
-                     integral_constant<bool, false>{},
-                     integral_constant<bool, false>{});
+                     p_out_global);
     }
 }
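The hunk above is the typed-pointer path: each descriptor is staged in its own device buffer (`*_device_buf.GetDeviceBuffer()` returns the raw `void*`), reinterpreted as a descriptor pointer, then given the `__CONSTANT__` qualifier so the argument type matches the kernel signature. A hedged host-side sketch of that staging pattern, using plain `hipMalloc`/`hipMemcpy` in place of the repository's device-buffer helper; names are illustrative and the caller owns the buffer:

// Host-side staging sketch, assuming hipcc/clang. Desc stands for a concrete
// descriptor type such as decltype(wei_gemmk_gemmm_global_desc).
#include <hip/hip_runtime.h>

#define __CONSTANT__ __attribute__((address_space(4)))

template <typename Desc>
const Desc __CONSTANT__* stage_descriptor(const Desc& host_desc)
{
    void* dev_buf = nullptr;
    hipMalloc(&dev_buf, sizeof(Desc));                                   // device staging buffer
    hipMemcpy(dev_buf, &host_desc, sizeof(Desc), hipMemcpyHostToDevice); // copy descriptor bytes

    // reinterpret the raw buffer pointer as a descriptor pointer, then add the
    // address-space qualifier so it matches the kernel parameter type
    return (const Desc __CONSTANT__*)reinterpret_cast<const Desc*>(dev_buf);
}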
@@ -564,111 +548,115 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
 {
     if(has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                   const void*,
-                                                   const FloatAB*,
-                                                   const void*,
-                                                   const FloatAB*,
-                                                   const void*,
-                                                   FloatC*,
-                                                   integral_constant<bool, true>,
-                                                   integral_constant<bool, true>>;
-        launch_kernel(kernel,
-                      dim3(GridSize),
-                      dim3(BlockSize),
-                      0,
-                      0,
-                      wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
-                      p_wei_global,
-                      in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
-                      p_in_global,
-                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                          .GetDeviceBuffer(),
-                      p_out_global,
-                      integral_constant<bool, true>{},
-                      integral_constant<bool, true>{});
+        const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
+                                                         ADesc,
+                                                         FloatAB,
+                                                         BDesc,
+                                                         FloatAB,
+                                                         CDesc,
+                                                         FloatC,
+                                                         true,
+                                                         true>;
+        launch_kernel(
+            kernel,
+            dim3(GridSize),
+            dim3(BlockSize),
+            0,
+            0,
+            (void __CONSTANT__*)
+                wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
+            p_wei_global,
+            (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
+            p_in_global,
+            (void __CONSTANT__*)
+                out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
+                    .GetDeviceBuffer(),
+            p_out_global);
     }
     else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                   const void*,
-                                                   const FloatAB*,
-                                                   const void*,
-                                                   const FloatAB*,
-                                                   const void*,
-                                                   FloatC*,
-                                                   integral_constant<bool, true>,
-                                                   integral_constant<bool, false>>;
-        launch_kernel(kernel,
-                      dim3(GridSize),
-                      dim3(BlockSize),
-                      0,
-                      0,
-                      wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
-                      p_wei_global,
-                      in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
-                      p_in_global,
-                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                          .GetDeviceBuffer(),
-                      p_out_global,
-                      integral_constant<bool, true>{},
-                      integral_constant<bool, false>{});
+        const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
+                                                         ADesc,
+                                                         FloatAB,
+                                                         BDesc,
+                                                         FloatAB,
+                                                         CDesc,
+                                                         FloatC,
+                                                         true,
+                                                         false>;
+        launch_kernel(
+            kernel,
+            dim3(GridSize),
+            dim3(BlockSize),
+            0,
+            0,
+            (void __CONSTANT__*)
+                wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
+            p_wei_global,
+            (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
+            p_in_global,
+            (void __CONSTANT__*)
+                out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
+                    .GetDeviceBuffer(),
+            p_out_global);
     }
     else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                   const void*,
-                                                   const FloatAB*,
-                                                   const void*,
-                                                   const FloatAB*,
-                                                   const void*,
-                                                   FloatC*,
-                                                   integral_constant<bool, false>,
-                                                   integral_constant<bool, true>>;
-        launch_kernel(kernel,
-                      dim3(GridSize),
-                      dim3(BlockSize),
-                      0,
-                      0,
-                      wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
-                      p_wei_global,
-                      in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
-                      p_in_global,
-                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                          .GetDeviceBuffer(),
-                      p_out_global,
-                      integral_constant<bool, false>{},
-                      integral_constant<bool, true>{});
+        const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
+                                                         ADesc,
+                                                         FloatAB,
+                                                         BDesc,
+                                                         FloatAB,
+                                                         CDesc,
+                                                         FloatC,
+                                                         false,
+                                                         true>;
+        launch_kernel(
+            kernel,
+            dim3(GridSize),
+            dim3(BlockSize),
+            0,
+            0,
+            (void __CONSTANT__*)
+                wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
+            p_wei_global,
+            (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
+            p_in_global,
+            (void __CONSTANT__*)
+                out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
+                    .GetDeviceBuffer(),
+            p_out_global);
     }
     else
     {
-        const auto kernel = run_gridwise_operation<gridwise_gemm,
-                                                   const void*,
-                                                   const FloatAB*,
-                                                   const void*,
-                                                   const FloatAB*,
-                                                   const void*,
-                                                   FloatC*,
-                                                   integral_constant<bool, false>,
-                                                   integral_constant<bool, false>>;
-        launch_kernel(kernel,
-                      dim3(GridSize),
-                      dim3(BlockSize),
-                      0,
-                      0,
-                      wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
-                      p_wei_global,
-                      in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
-                      p_in_global,
-                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
-                          .GetDeviceBuffer(),
-                      p_out_global,
-                      integral_constant<bool, false>{},
-                      integral_constant<bool, false>{});
+        const auto kernel = run_gridwise_dynamic_gemm_v1<gridwise_gemm,
+                                                         ADesc,
+                                                         FloatAB,
+                                                         BDesc,
+                                                         FloatAB,
+                                                         CDesc,
+                                                         FloatC,
+                                                         false,
+                                                         false>;
+        launch_kernel(
+            kernel,
+            dim3(GridSize),
+            dim3(BlockSize),
+            0,
+            0,
+            (void __CONSTANT__*)
+                wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(),
+            p_wei_global,
+            (void __CONSTANT__*)in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(),
+            p_in_global,
+            (void __CONSTANT__*)
+                out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf
+                    .GetDeviceBuffer(),
+            p_out_global);
     }
 }
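This second hunk is the `void*` variant of the same launch: the descriptor type is erased at the call site and recovered inside the kernel (see the `run_gridwise_dynamic_gemm_v1` overloads added below), while the `__CONSTANT__` qualifier survives the type erasure because it is attached to the `void` pointee. A hedged end-to-end sketch of that flow, with hypothetical names, assuming hipcc on AMDGPU:

// Standalone sketch of the __CONSTANT__ void* path; Desc and use_desc are
// illustrative, not the repository's types.
#include <hip/hip_runtime.h>
#include <cstdio>

#define __CONSTANT__ __attribute__((address_space(4)))

struct Desc // hypothetical descriptor
{
    int length;
    int stride;
};

// type-erased signature: the descriptor type is a template parameter, the
// pointer itself is just 'const void __CONSTANT__*'
template <typename DescT>
__global__ void use_desc(const void __CONSTANT__* p_desc_const, int* p_out)
{
    // two-step cast, as in the diff: address_space(4) -> generic void* -> typed pointer
    const DescT desc = *reinterpret_cast<const DescT*>((const void*)p_desc_const);

    p_out[0] = desc.length * desc.stride;
}

int main()
{
    Desc h_desc{8, 3};
    void* d_desc = nullptr;
    int* d_out   = nullptr;
    hipMalloc(&d_desc, sizeof(Desc));
    hipMalloc(&d_out, sizeof(int));
    hipMemcpy(d_desc, &h_desc, sizeof(Desc), hipMemcpyHostToDevice);

    hipLaunchKernelGGL(
        (use_desc<Desc>), dim3(1), dim3(1), 0, 0, (void __CONSTANT__*)d_desc, d_out);

    int h_out = 0;
    hipMemcpy(&h_out, d_out, sizeof(int), hipMemcpyDeviceToHost);
    printf("%d\n", h_out); // expect 24
    return 0;
}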
@@ -11,6 +11,110 @@
 namespace ck {
+#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
+// pass tensor descriptor by value
+template <typename GridwiseGemm,
+          typename AGlobalDesc,
+          typename FloatA,
+          typename BGlobalDesc,
+          typename FloatB,
+          typename CGlobalDesc,
+          typename FloatC,
+          bool HasMainKBlockLoop,
+          bool HasDoubleTailKBlockLoop>
+__global__ void run_gridwise_dynamic_gemm_v1(const AGlobalDesc a_k_m_global_desc,
+                                             const FloatA* __restrict__ p_a_global,
+                                             const BGlobalDesc b_k_n_global_desc,
+                                             const FloatB* __restrict__ p_b_global,
+                                             const CGlobalDesc c_m0_m1_n0_n1_global_desc,
+                                             FloatC* __restrict__ p_c_global)
+{
+    GridwiseGemm{}.Run(a_k_m_global_desc,
+                       p_a_global,
+                       b_k_n_global_desc,
+                       p_b_global,
+                       c_m0_m1_n0_n1_global_desc,
+                       p_c_global,
+                       integral_constant<bool, HasMainKBlockLoop>{},
+                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
+}
+#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_POINTER
+// pass tensor descriptor by __CONSTANT__ pointer
+// __CONSTANT__ is needed to inform the compiler that the pointers in the kernel signature
+// point to the non-modifiable kernel-parameter address space, so the compiler can enable
+// the corresponding optimizations
+template <typename GridwiseGemm,
+          typename AGlobalDesc,
+          typename FloatA,
+          typename BGlobalDesc,
+          typename FloatB,
+          typename CGlobalDesc,
+          typename FloatC,
+          bool HasMainKBlockLoop,
+          bool HasDoubleTailKBlockLoop>
+__global__ void
+run_gridwise_dynamic_gemm_v1(const AGlobalDesc __CONSTANT__* p_a_k_m_global_desc,
+                             const FloatA* __restrict__ p_a_global,
+                             const BGlobalDesc __CONSTANT__* p_b_k_n_global_desc,
+                             const FloatB* __restrict__ p_b_global,
+                             const CGlobalDesc __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
+                             FloatC* __restrict__ p_c_global)
+{
+    // cast the pointers to address_space(1), because the tensor descriptor's copy
+    // constructor is written for address_space(1)
+    const auto a_k_m_global_desc         = *(const AGlobalDesc*)p_a_k_m_global_desc;
+    const auto b_k_n_global_desc         = *(const BGlobalDesc*)p_b_k_n_global_desc;
+    const auto c_m0_m1_n0_n1_global_desc = *(const CGlobalDesc*)p_c_m0_m1_n0_n1_global_desc;
+
+    GridwiseGemm{}.Run(a_k_m_global_desc,
+                       p_a_global,
+                       b_k_n_global_desc,
+                       p_b_global,
+                       c_m0_m1_n0_n1_global_desc,
+                       p_c_global,
+                       integral_constant<bool, HasMainKBlockLoop>{},
+                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
+}
+#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
+// pass tensor descriptor by __CONSTANT__ void pointer
+// __CONSTANT__ is needed to inform the compiler that the void pointers in the kernel
+// signature point to the non-modifiable kernel-parameter address space, so the compiler
+// can enable the corresponding optimizations
+template <typename GridwiseGemm,
+          typename AGlobalDesc,
+          typename FloatA,
+          typename BGlobalDesc,
+          typename FloatB,
+          typename CGlobalDesc,
+          typename FloatC,
+          bool HasMainKBlockLoop,
+          bool HasDoubleTailKBlockLoop>
+__global__ void run_gridwise_dynamic_gemm_v1(const void __CONSTANT__* p_a_k_m_global_desc,
+                                             const FloatA* __restrict__ p_a_global,
+                                             const void __CONSTANT__* p_b_k_n_global_desc,
+                                             const FloatB* __restrict__ p_b_global,
+                                             const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
+                                             FloatC* __restrict__ p_c_global)
+{
+    // first cast the void __CONSTANT__* to a generic void*, then cast the void* to the
+    // descriptor pointer type; the tensor descriptor's copy constructor doesn't take
+    // address_space(4) pointers
+    const auto a_k_m_global_desc =
+        *reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k_m_global_desc);
+    const auto b_k_n_global_desc =
+        *reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k_n_global_desc);
+    const auto c_m0_m1_n0_n1_global_desc =
+        *reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_n0_n1_global_desc);
+
+    GridwiseGemm{}.Run(a_k_m_global_desc,
+                       p_a_global,
+                       b_k_n_global_desc,
+                       p_b_global,
+                       c_m0_m1_n0_n1_global_desc,
+                       p_c_global,
+                       integral_constant<bool, HasMainKBlockLoop>{},
+                       integral_constant<bool, HasDoubleTailKBlockLoop>{});
+}
+#endif
 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
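All three wrapper variants above funnel into the same `GridwiseGemm::Run`, selecting the main-loop/tail-loop code path through `integral_constant` tags rather than runtime booleans, which is why the driver instantiates four distinct kernels. A small illustration of that tag-dispatch idiom, using `std::integral_constant` (ck defines its own equivalent in the `ck` namespace):

// The two bools become part of the overload's type, so the branch is resolved at
// compile time and the untaken loop variant is never emitted into the kernel.
#include <type_traits>
#include <cstdio>

void run(std::integral_constant<bool, true> /*has_main_k_block_loop*/)
{
    std::puts("main k-block loop variant");
}

void run(std::integral_constant<bool, false> /*has_main_k_block_loop*/)
{
    std::puts("no main k-block loop variant");
}

int main()
{
    run(std::integral_constant<bool, true>{});  // prints "main k-block loop variant"
    run(std::integral_constant<bool, false>{}); // prints "no main k-block loop variant"
    return 0;
}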
@@ -427,7 +531,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         }
     }
-    // pass tensor descriptor by reference
     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
     __device__ void Run(const AGlobalDesc& a_k_m_global_desc,
                         const FloatAB* __restrict__ p_a_global,
@@ -452,57 +555,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                            integral_constant<bool, HasMainKBlockLoop>{},
                            integral_constant<bool, HasDoubleTailKBlockLoop>{});
     }
-
-    // pass tensor descriptors by pointers
-    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const AGlobalDesc* p_a_k_m_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const BGlobalDesc* p_b_k_n_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const CGlobalDesc* p_c_m0_m1_n0_n1_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
-    {
-        const auto a_k_m_global_desc         = *p_a_k_m_global_desc;
-        const auto b_k_n_global_desc         = *p_b_k_n_global_desc;
-        const auto c_m0_m1_n0_n1_global_desc = *p_c_m0_m1_n0_n1_global_desc;
-
-        Run(a_k_m_global_desc,
-            p_a_global,
-            b_k_n_global_desc,
-            p_b_global,
-            c_m0_m1_n0_n1_global_desc,
-            p_c_global,
-            integral_constant<bool, HasMainKBlockLoop>{},
-            integral_constant<bool, HasDoubleTailKBlockLoop>{});
-    }
-
-    // pass tensor descriptors by void*
-    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
-    __device__ void Run(const void* p_a_k_m_global_desc,
-                        const FloatAB* __restrict__ p_a_global,
-                        const void* p_b_k_n_global_desc,
-                        const FloatAB* __restrict__ p_b_global,
-                        const void* p_c_m0_m1_n0_n1_global_desc,
-                        FloatC* __restrict__ p_c_global,
-                        integral_constant<bool, HasMainKBlockLoop>,
-                        integral_constant<bool, HasDoubleTailKBlockLoop>) const
-    {
-        const auto a_k_m_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_k_m_global_desc);
-        const auto b_k_n_global_desc = *reinterpret_cast<const BGlobalDesc*>(p_b_k_n_global_desc);
-        const auto c_m0_m1_n0_n1_global_desc =
-            *reinterpret_cast<const CGlobalDesc*>(p_c_m0_m1_n0_n1_global_desc);
-
-        Run(a_k_m_global_desc,
-            p_a_global,
-            b_k_n_global_desc,
-            p_b_global,
-            c_m0_m1_n0_n1_global_desc,
-            p_c_global,
-            integral_constant<bool, HasMainKBlockLoop>{},
-            integral_constant<bool, HasDoubleTailKBlockLoop>{});
-    }
 };
 } // namespace ck
@@ -7,6 +7,9 @@
 #endif
 #include "bfloat16_dev.hpp"
+// address space for kernel parameters
+#define __CONSTANT__ __attribute__((address_space(4)))
 // device backend
 #define CK_DEVICE_BACKEND_AMD 1
@@ -105,9 +108,9 @@
 #endif
 // pass tensor descriptor by value, pointer or void*
-#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1
+#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0
 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_POINTER 0
-#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0
+#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1
 // hack: has underlying assumptions that need to be satisfied, otherwise it's a bug
 // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
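The three `CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_*` flags select which `run_gridwise_dynamic_gemm_v1` definition is compiled, so exactly one should be 1 at a time; this commit switches the default from by-value to by-`__CONSTANT__`-void-pointer. A hypothetical guard, not present in the source, that would enforce this invariant:

// hypothetical sanity check (not in the repository): the #if/#elif chain silently
// compiles no kernel wrapper at all if every flag is 0, so a guard like this could
// catch a misconfigured build early
static_assert(CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE +
                      CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_POINTER +
                      CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER ==
                  1,
              "exactly one tensor-descriptor passing mode must be enabled");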