"...composable_kernel_rocm.git" did not exist on "500fa9951297c033a9c4c1d300b03895a46528d2"
Commit 4ad62d7f authored by Jing Zhang
Browse files

Use a CK_CONSTANT_ADDRESS_SPACE kernel argument instead of a global __constant__ buffer

parent 69add6ff
......@@ -17,10 +17,6 @@ namespace ck {
namespace tensor_operation {
namespace device {
#define CK_GEMM_DESCS_CONSTANT_BUFF_SIZE 1048576 // 1MB for 1000 gemm_descs
__constant__ static char gemm_descs_const_[CK_GEMM_DESCS_CONSTANT_BUFF_SIZE];
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
......@@ -33,7 +29,8 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r3(const index_t group_count,
kernel_grouped_gemm_xdlops_v2r3(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
......@@ -43,7 +40,8 @@ __global__ void
const index_t block_id = get_block_1d_id();
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(gemm_descs_const_);
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
index_t group_id = 0;
for(index_t i = 0; i < group_count; i++)
......@@ -465,18 +463,22 @@ struct DeviceGroupedGemmXdl
}
}
if(sizeof(GemmDescKernelArg) * arg.gemm_desc_kernel_arg_.size() >
CK_GEMM_DESCS_CONSTANT_BUFF_SIZE)
{
throw std::runtime_error("wrong! too many gemms");
}
KernelTimer timer;
timer.Start();
void* gemm_descs_const_;
hipGetErrorString(hipMalloc(
&gemm_descs_const_, arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg)));
hipGetErrorString(
hipMemcpyToSymbol(HIP_SYMBOL(gemm_descs_const_),
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
0,
hipMemcpyHostToDevice));
hipMemcpy(gemm_descs_const_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
hipMemcpyHostToDevice));
timer.End();
std::cout << "HipMemCpy time: " << timer.GetElapsedTime() << std::endl;
float ave_time = 0;
......@@ -492,15 +494,17 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(gemm_descs_const_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
}
else
{
......@@ -514,15 +518,17 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(gemm_descs_const_),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
}
return ave_time;
......@@ -546,10 +552,6 @@ struct DeviceGroupedGemmXdl
if(ck::type_convert<ck::index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
return false;
if(sizeof(GemmDescKernelArg) * arg.gemm_desc_kernel_arg_.size() >
CK_GEMM_DESCS_CONSTANT_BUFF_SIZE)
return false;
return true;
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment