Commit 4ad62d7f authored by Jing Zhang's avatar Jing Zhang
Browse files

use CK_CONSTANT_ADDRESS_SPACE instead of global constant

parent 69add6ff
...@@ -17,10 +17,6 @@ namespace ck { ...@@ -17,10 +17,6 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
#define CK_GEMM_DESCS_CONSTANT_BUFF_SIZE 1048576 // 1MB for 1000 gemm_descs
__constant__ static char gemm_descs_const_[CK_GEMM_DESCS_CONSTANT_BUFF_SIZE];
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
...@@ -33,7 +29,8 @@ __global__ void ...@@ -33,7 +29,8 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_gemm_xdlops_v2r3(const index_t group_count, kernel_grouped_gemm_xdlops_v2r3(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op) const CElementwiseOperation c_element_op)
...@@ -43,7 +40,8 @@ __global__ void ...@@ -43,7 +40,8 @@ __global__ void
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(gemm_descs_const_); const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
index_t group_id = 0; index_t group_id = 0;
for(index_t i = 0; i < group_count; i++) for(index_t i = 0; i < group_count; i++)
...@@ -465,18 +463,22 @@ struct DeviceGroupedGemmXdl ...@@ -465,18 +463,22 @@ struct DeviceGroupedGemmXdl
} }
} }
if(sizeof(GemmDescKernelArg) * arg.gemm_desc_kernel_arg_.size() > KernelTimer timer;
CK_GEMM_DESCS_CONSTANT_BUFF_SIZE)
{ timer.Start();
throw std::runtime_error("wrong! too many gemms");
} void* gemm_descs_const_;
hipGetErrorString(hipMalloc(
&gemm_descs_const_, arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg)));
hipGetErrorString( hipGetErrorString(
hipMemcpyToSymbol(HIP_SYMBOL(gemm_descs_const_), hipMemcpy(gemm_descs_const_,
arg.gemm_desc_kernel_arg_.data(), arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg), arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
0, hipMemcpyHostToDevice));
hipMemcpyHostToDevice)); timer.End();
std::cout << "HipMemCpy time: " << timer.GetElapsedTime() << std::endl;
float ave_time = 0; float ave_time = 0;
...@@ -492,15 +494,17 @@ struct DeviceGroupedGemmXdl ...@@ -492,15 +494,17 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation, CElementwiseOperation,
true>; true>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(
nrepeat, kernel,
dim3(arg.grid_size_), nrepeat,
dim3(BlockSize), dim3(arg.grid_size_),
0, dim3(BlockSize),
arg.gemm_desc_kernel_arg_.size(), 0,
arg.a_element_op_, cast_pointer_to_constant_address_space(gemm_descs_const_),
arg.b_element_op_, arg.gemm_desc_kernel_arg_.size(),
arg.c_element_op_); arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
} }
else else
{ {
...@@ -514,15 +518,17 @@ struct DeviceGroupedGemmXdl ...@@ -514,15 +518,17 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation, CElementwiseOperation,
false>; false>;
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(
nrepeat, kernel,
dim3(arg.grid_size_), nrepeat,
dim3(BlockSize), dim3(arg.grid_size_),
0, dim3(BlockSize),
arg.gemm_desc_kernel_arg_.size(), 0,
arg.a_element_op_, cast_pointer_to_constant_address_space(gemm_descs_const_),
arg.b_element_op_, arg.gemm_desc_kernel_arg_.size(),
arg.c_element_op_); arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
} }
return ave_time; return ave_time;
...@@ -546,10 +552,6 @@ struct DeviceGroupedGemmXdl ...@@ -546,10 +552,6 @@ struct DeviceGroupedGemmXdl
if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
return false; return false;
if(sizeof(GemmDescKernelArg) * arg.gemm_desc_kernel_arg_.size() >
CK_GEMM_DESCS_CONSTANT_BUFF_SIZE)
return false;
return true; return true;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment