Commit 52c79ace authored by Adam Osewski

Change Run API to accept a user-provided workspace buffer.

parent 21fbf2ce
......@@ -5,6 +5,7 @@
#include <iostream>
#include <sstream>
#include <tuple>
#include "ck/ck.hpp"
#include "ck/host_utility/device_prop.hpp"
......@@ -431,7 +432,161 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
    // Assume we want to have at most 2 waves per SIMD
    static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);

    //
    // @brief      Launch Grouped Gemm kernel.
    //
    // @note       This overload uses a user-provided device workspace buffer for the
    //             kernel arguments.
    //
    // @param[in]  arg            The structure containing kernel arguments (in host memory).
    // @param[in]  dev_gemm_args  Pointer to device memory holding the kernel arguments.
    // @param[in]  stream_config  The device stream configuration.
    //
    // @return     The average kernel execution time (if time measurement is enabled).
    //
    float Run(const Argument& arg,
              const void* dev_gemm_args,
              const StreamConfig& stream_config = StreamConfig{})
    {
        auto [all_have_kbatch_gt_one, all_have_main_k0_block_loop] =
            CheckArgument(arg, stream_config);

        if(dev_gemm_args != nullptr)
        {
            // Attach the user-provided buffer; p_workspace_ is a non-owning handle,
            // so casting away constness here does not transfer ownership.
            const_cast<Argument&>(arg).p_workspace_ = const_cast<void*>(dev_gemm_args);
        }
        else
        {
            std::ostringstream err;
            err << "The gemm arguments workspace buffer is not allocated!"
                << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
            throw std::runtime_error(err.str());
        }

        if(all_have_kbatch_gt_one)
        {
            // Split-K accumulates partial results with atomics, so the output
            // tiles have to be zero-initialized first.
            for(const auto& gemm_arg : arg.gemm_kernel_args_)
            {
                hip_check_error(hipMemset(
                    gemm_arg.p_c_grid, 0, gemm_arg.M * gemm_arg.N * sizeof(EDataType)));
            }
        }
        float ave_time = 0;

        if(all_have_main_k0_block_loop)
        {
            if(all_have_kbatch_gt_one)
            {
                ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, true>(
                    arg, stream_config);
            }
            else
            {
                ave_time =
                    DispatchKernel<InMemoryDataOperationEnum::Set, true>(arg, stream_config);
            }
        }
        else
        {
            if(all_have_kbatch_gt_one)
            {
                ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, false>(
                    arg, stream_config);
            }
            else
            {
                ave_time =
                    DispatchKernel<InMemoryDataOperationEnum::Set, false>(arg, stream_config);
            }
        }

        return ave_time;
    }
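    // Caller-side sketch for this overload (object names `gemm`, `invoker`, and
    // `argument` are assumptions, not part of this header): the caller owns the
    // device buffer and must have filled it with the packed per-group
    // KernelArguments, since this overload only attaches the pointer and never
    // copies into it.
    //
    //     void* dev_gemm_args = nullptr;
    //     hip_check_error(hipMalloc(&dev_gemm_args, gemm.GetWorkSpaceSize(&argument)));
    //     // ... copy the kernel arguments into dev_gemm_args ...
    //     float avg_time = invoker.Run(argument, dev_gemm_args, StreamConfig{nullptr, true});
    //     hip_check_error(hipFree(dev_gemm_args));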
    //
    // @brief      Launch Grouped Gemm kernel.
    //
    // @note       This overload uses the device workspace buffer attached to the argument
    //             for the kernel arguments. The user should call @see GetWorkSpaceSize and
    //             @see SetWorkSpacePointer on the arg parameter to properly allocate
    //             this buffer.
    //
    // @param[in]  arg            The structure containing kernel arguments (in host memory).
    // @param[in]  stream_config  The device stream configuration.
    //
    // @return     The average kernel execution time (if time measurement is enabled).
    //
    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
    {
        auto [all_have_kbatch_gt_one, all_have_main_k0_block_loop] =
            CheckArgument(arg, stream_config);

        if(arg.p_workspace_ != nullptr)
        {
            // Stage the per-group kernel arguments in the device workspace.
            hip_check_error(
                hipMemcpyWithStream(arg.p_workspace_,
                                    arg.gemm_kernel_args_.data(),
                                    arg.gemm_kernel_args_.size() * sizeof(KernelArguments),
                                    hipMemcpyHostToDevice,
                                    stream_config.stream_id_));
        }
        else
        {
            std::ostringstream err;
            err << "The gemm arguments workspace buffer is not allocated!"
                << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
            throw std::runtime_error(err.str());
        }

        if(all_have_kbatch_gt_one)
        {
            // Split-K accumulates partial results with atomics; zero the output first.
            for(const auto& gemm_arg : arg.gemm_kernel_args_)
            {
                hip_check_error(hipMemset(
                    gemm_arg.p_c_grid, 0, gemm_arg.M * gemm_arg.N * sizeof(EDataType)));
            }
        }
        float ave_time = 0;

        if(all_have_main_k0_block_loop)
        {
            if(all_have_kbatch_gt_one)
            {
                ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, true>(
                    arg, stream_config);
            }
            else
            {
                ave_time =
                    DispatchKernel<InMemoryDataOperationEnum::Set, true>(arg, stream_config);
            }
        }
        else
        {
            if(all_have_kbatch_gt_one)
            {
                ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, false>(
                    arg, stream_config);
            }
            else
            {
                ave_time =
                    DispatchKernel<InMemoryDataOperationEnum::Set, false>(arg, stream_config);
            }
        }

        return ave_time;
    }
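    // Caller-side sketch for this overload (object names `gemm`, `invoker`, and
    // `argument` are assumptions; DeviceMem is CK's host-side device-buffer
    // utility, assumed available to the caller). GetWorkSpaceSize and
    // SetWorkSpacePointer are the queries referenced in the note above:
    //
    //     DeviceMem workspace(gemm.GetWorkSpaceSize(&argument));
    //     gemm.SetWorkSpacePointer(&argument, workspace.GetDeviceBuffer());
    //     float avg_time = invoker.Run(argument, StreamConfig{nullptr, true});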
    float Run(const BaseArgument* p_arg,
              const StreamConfig& stream_config = StreamConfig{}) override
    {
        return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
    }
    private:
    auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
    {
        index_t K0 = GridwiseGemm::CalculateK0(arg.gemm_kernel_args_[0].K, arg.K_BATCH);
        bool all_have_kbatch_gt_one = arg.K_BATCH > 1;
......@@ -492,131 +647,53 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                throw std::runtime_error(err.str());
            }
        }

        return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k0_block_loop);
    }
    template <InMemoryDataOperationEnum CGlobalMemoryDataOperation, bool HasMainKBlockLoop>
    float DispatchKernel(const Argument& arg, const StreamConfig& stream_config) const
    {
        const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                           KernelArguments,
                                                           ADataType,
                                                           BDataType,
                                                           EDataType,
                                                           HasMainKBlockLoop,
                                                           CGlobalMemoryDataOperation>;
        return LaunchKernel(kernel, arg, stream_config);
    }

    template <typename KernelFunction>
    float LaunchKernel(const KernelFunction& kernel,
                       const Argument& arg,
                       const StreamConfig& stream_config) const
    {
        // Calculate max number of workgroups that can simultaneously reside on the CU.
        int num_blocks                = 0;
        size_t dyn_shared_mem_per_blk = 0;
        hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));

        int cu_count = getAvailableComputeUnitCount(stream_config);

        if(stream_config.log_level_ > 0)
        {
            std::cout << "MaxActiveBlocksPerCU: " << num_blocks
                      << ", available CUs count: " << cu_count << ", grid size: "
                      << ck::math::min(num_blocks, CU_BLOCKS) * cu_count *
                             BLOCK_SUBSCRIPTION_FACTOR
                      << std::endl;
        }

        return launch_and_time_kernel(
            stream_config,
            kernel,
            dim3(cu_count * ck::math::min(num_blocks, CU_BLOCKS) * BLOCK_SUBSCRIPTION_FACTOR),
            dim3(BlockSize),
            0,
            arg.p_workspace_,
            arg.grid_size_,
            arg.K_BATCH);
    }
};
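For intuition, here is a minimal, self-contained sketch of how LaunchKernel sizes the grid from the occupancy query. All numeric values are assumptions for illustration; only the formula mirrors the launch code above.

    #include <algorithm>
    #include <cstdio>

    int main()
    {
        const int cu_count                  = 104; // available compute units (assumed)
        const int num_blocks                = 4;   // assumed occupancy query result
        const int CU_BLOCKS                 = 2;   // cap: at most 2 waves per SIMD (see constant above)
        const int BLOCK_SUBSCRIPTION_FACTOR = 1;   // assumed oversubscription factor
        // Same formula as LaunchKernel's dim3 grid argument.
        const int grid = cu_count * std::min(num_blocks, CU_BLOCKS) * BLOCK_SUBSCRIPTION_FACTOR;
        std::printf("grid size: %d workgroups\n", grid); // prints 208
    }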