"vscode:/vscode.git/clone" did not exist on "4125756e88e82370c197fecf28e9f0b4d7eee6c3"
Commit 52c79ace authored by Adam Osewski's avatar Adam Osewski
Browse files

Change Run API to accept a user-provided workspace buffer.

parent 21fbf2ce
@@ -5,6 +5,7 @@
#include <iostream>
#include <sstream>
#include <tuple>
#include "ck/ck.hpp"
#include "ck/host_utility/device_prop.hpp"
@@ -431,7 +432,161 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
        // Assume we want to have at most 2 waves per SIMD
        static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES);
        //
        // @brief Launch Grouped Gemm kernel.
        //
        // @note This overload uses a user-provided device workspace buffer for the
        //       kernel arguments. The caller must have already copied the kernel
        //       arguments into that buffer.
        //
        // @param[in] arg            The structure containing kernel arguments (in host memory).
        // @param[in] dev_gemm_args  The pointer to device memory with kernel arguments.
        // @param[in] stream_config  The device stream configuration.
        //
        // @return The average kernel execution time (if time measurement is enabled).
        //
        float Run(const Argument& arg,
                  const void* dev_gemm_args,
                  const StreamConfig& stream_config = StreamConfig{})
        {
            auto [all_have_kbatch_gt_one, all_have_main_k0_block_loop] =
                CheckArgument(arg, stream_config);

            if(dev_gemm_args != nullptr)
            {
                // Install the user-provided buffer as the kernel-argument workspace.
                // const_cast is needed because Run receives the argument by const
                // reference and the buffer as a pointer to const.
                const_cast<Argument&>(arg).p_workspace_ = const_cast<void*>(dev_gemm_args);
            }
            else
            {
                std::ostringstream err;
                err << "The gemm arguments workspace buffer is not allocated!"
                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                throw std::runtime_error(err.str());
            }
            if(all_have_kbatch_gt_one)
            {
                // Zero the output tiles first, since split-K partial results are
                // accumulated into C with atomic adds.
                for(const auto& gemm_arg : arg.gemm_kernel_args_)
                {
                    hip_check_error(hipMemset(
                        gemm_arg.p_c_grid, 0, gemm_arg.M * gemm_arg.N * sizeof(EDataType)));
                }
            }
            float ave_time = 0;

            if(all_have_main_k0_block_loop)
            {
                if(all_have_kbatch_gt_one)
                {
                    ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, true>(
                        arg, stream_config);
                }
                else
                {
                    ave_time =
                        DispatchKernel<InMemoryDataOperationEnum::Set, true>(arg, stream_config);
                }
            }
            else
            {
                if(all_have_kbatch_gt_one)
                {
                    ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, false>(
                        arg, stream_config);
                }
                else
                {
                    ave_time =
                        DispatchKernel<InMemoryDataOperationEnum::Set, false>(arg, stream_config);
                }
            }

            return ave_time;
        }
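With this overload the caller owns the kernel-argument buffer: Run only installs the pointer and launches, so the per-group kernel arguments must already be resident in device memory. A minimal usage sketch under that assumption; gemm_invoker and arg stand in for a constructed invoker and argument and are not names from this file:

    // Allocate a device buffer for the per-group kernel arguments and fill it.
    void* dev_gemm_args     = nullptr;
    const size_t args_bytes = arg.gemm_kernel_args_.size() * sizeof(KernelArguments);
    hip_check_error(hipMalloc(&dev_gemm_args, args_bytes));
    hip_check_error(hipMemcpy(dev_gemm_args,
                              arg.gemm_kernel_args_.data(),
                              args_bytes,
                              hipMemcpyHostToDevice));

    // Launch; Run throws std::runtime_error if the buffer pointer is null.
    float avg_time = gemm_invoker.Run(arg, dev_gemm_args, StreamConfig{nullptr, true});

    hip_check_error(hipFree(dev_gemm_args));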
        //
        // @brief Launch Grouped Gemm kernel.
        //
        // @note This overload uses the device workspace buffer attached to the
        //       argument for the kernel arguments. The user should call
        //       @see GetWorkSpaceSize and @see SetWorkSpacePointer on the argument
        //       to properly allocate this buffer.
        //
        // @param[in] arg            The structure containing kernel arguments (in host memory).
        // @param[in] stream_config  The device stream configuration.
        //
        // @return The average kernel execution time (if time measurement is enabled).
        //
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            auto [all_have_kbatch_gt_one, all_have_main_k0_block_loop] =
                CheckArgument(arg, stream_config);
            if(arg.p_workspace_ != nullptr)
            {
                // Copy the host-side kernel arguments into the workspace buffer
                // previously attached with SetWorkSpacePointer.
                hip_check_error(
                    hipMemcpyWithStream(arg.p_workspace_,
                                        arg.gemm_kernel_args_.data(),
                                        arg.gemm_kernel_args_.size() * sizeof(KernelArguments),
                                        hipMemcpyHostToDevice,
                                        stream_config.stream_id_));
            }
            else
            {
                std::ostringstream err;
                err << "The gemm arguments workspace buffer is not allocated!"
                    << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                throw std::runtime_error(err.str());
            }

            if(all_have_kbatch_gt_one)
            {
                // Zero the output tiles first, since split-K partial results are
                // accumulated into C with atomic adds.
                for(const auto& gemm_arg : arg.gemm_kernel_args_)
                {
                    hip_check_error(hipMemset(
                        gemm_arg.p_c_grid, 0, gemm_arg.M * gemm_arg.N * sizeof(EDataType)));
                }
            }
            float ave_time = 0;

            if(all_have_main_k0_block_loop)
            {
                if(all_have_kbatch_gt_one)
                {
                    ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, true>(
                        arg, stream_config);
                }
                else
                {
                    ave_time =
                        DispatchKernel<InMemoryDataOperationEnum::Set, true>(arg, stream_config);
                }
            }
            else
            {
                if(all_have_kbatch_gt_one)
                {
                    ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, false>(
                        arg, stream_config);
                }
                else
                {
                    ave_time =
                        DispatchKernel<InMemoryDataOperationEnum::Set, false>(arg, stream_config);
                }
            }

            return ave_time;
        }
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
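Taken together, the second overload and the polymorphic override follow the library's usual workspace protocol: query the size, allocate, attach, run. A hedged sketch of that flow through the type-erased interface; gemm_op and p_arg are illustrative stand-ins for a device operation instance and the BaseArgument it created, and MakeInvokerPointer is assumed to be available on the operation:

    // Query how many bytes the kernel-argument workspace needs.
    const std::size_t ws_bytes = gemm_op.GetWorkSpaceSize(p_arg.get());

    void* p_workspace = nullptr;
    hip_check_error(hipMalloc(&p_workspace, ws_bytes));

    // Attach the buffer; the Run overload above then copies the host-side
    // kernel arguments into it with hipMemcpyWithStream before launching.
    gemm_op.SetWorkSpacePointer(p_arg.get(), p_workspace);

    auto invoker_ptr = gemm_op.MakeInvokerPointer();
    float avg_time   = invoker_ptr->Run(p_arg.get(), StreamConfig{nullptr, true});

    hip_check_error(hipFree(p_workspace));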
    private:
        auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
        {
            index_t K0 = GridwiseGemm::CalculateK0(arg.gemm_kernel_args_[0].K, arg.K_BATCH);
            bool all_have_kbatch_gt_one = arg.K_BATCH > 1;
@@ -492,131 +647,53 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                    throw std::runtime_error(err.str());
                }
            }
+           return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k0_block_loop);
+       }
-           if(arg.p_workspace_ != nullptr)
-           {
-               hip_check_error(
-                   hipMemcpyWithStream(arg.p_workspace_,
-                                       arg.gemm_kernel_args_.data(),
-                                       arg.gemm_kernel_args_.size() * sizeof(KernelArguments),
-                                       hipMemcpyHostToDevice,
-                                       stream_config.stream_id_));
-           }
-           else
-           {
-               std::ostringstream err;
-               err << "The argument workspace buffer is not allocated!"
-                   << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
-               throw std::runtime_error(err.str());
-           }
-           float ave_time = 0;
-           const auto Run = [&](const auto& kernel) {
-               if(all_have_kbatch_gt_one)
-               {
-                   for(const auto& gemm_arg : arg.gemm_kernel_args_)
-                   {
-                       hip_check_error(hipMemset(
-                           gemm_arg.p_c_grid, 0, gemm_arg.M * gemm_arg.N * sizeof(EDataType)));
-                   }
-               }
-               // Calculate max number of workgroups that can simultaneously reside on the CU.
-               int num_blocks = 0;
-               size_t dyn_shared_mem_per_blk = 0;
-               hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
-                   &num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));
-               int cu_count = getAvailableComputeUnitCount(stream_config);
-               if(stream_config.log_level_ > 0)
-               {
-                   std::cout << "MaxActiveBlocksPerCU: " << num_blocks
-                             << ", available CUs count: " << cu_count << ", grid size: "
-                             << ck::math::min(num_blocks, CU_BLOCKS) * cu_count *
-                                    BLOCK_SUBSCRIPTION_FACTOR
-                             << std::endl;
-               }
-               ave_time =
-                   launch_and_time_kernel(stream_config,
-                                          kernel,
-                                          dim3(cu_count * ck::math::min(num_blocks, CU_BLOCKS) *
-                                               BLOCK_SUBSCRIPTION_FACTOR),
-                                          dim3(BlockSize),
-                                          0,
-                                          arg.p_workspace_,
-                                          arg.grid_size_,
-                                          arg.K_BATCH);
-           };
-           if(all_have_main_k0_block_loop)
-           {
-               if(all_have_kbatch_gt_one)
-               {
-                   const auto kernel =
-                       kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                      KernelArguments,
-                                                      ADataType,
-                                                      BDataType,
-                                                      EDataType,
-                                                      true,
-                                                      InMemoryDataOperationEnum::AtomicAdd>;
-                   Run(kernel);
-               }
-               else
-               {
-                   const auto kernel =
-                       kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                      KernelArguments,
-                                                      ADataType,
-                                                      BDataType,
-                                                      EDataType,
-                                                      true,
-                                                      InMemoryDataOperationEnum::Set>;
-                   Run(kernel);
-               }
-           }
-           else
-           {
-               if(all_have_kbatch_gt_one)
-               {
-                   const auto kernel =
-                       kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                      KernelArguments,
-                                                      ADataType,
-                                                      BDataType,
-                                                      EDataType,
-                                                      false,
-                                                      InMemoryDataOperationEnum::AtomicAdd>;
-                   Run(kernel);
-               }
-               else
-               {
-                   const auto kernel =
-                       kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                      KernelArguments,
-                                                      ADataType,
-                                                      BDataType,
-                                                      EDataType,
-                                                      false,
-                                                      InMemoryDataOperationEnum::Set>;
-                   Run(kernel);
-               }
-           }
-           return ave_time;
-       }
-
-       float Run(const BaseArgument* p_arg,
-                 const StreamConfig& stream_config = StreamConfig{}) override
-       {
-           return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
-       }
+
+       template <InMemoryDataOperationEnum CGlobalMemoryDataOperation, bool HasMainKBlockLoop>
+       float DispatchKernel(const Argument& arg, const StreamConfig& stream_config) const
+       {
+           const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
+                                                              KernelArguments,
+                                                              ADataType,
+                                                              BDataType,
+                                                              EDataType,
+                                                              HasMainKBlockLoop,
+                                                              CGlobalMemoryDataOperation>;
+           return LaunchKernel(kernel, arg, stream_config);
+       }
+
+       template <typename KernelFunction>
+       float LaunchKernel(const KernelFunction& kernel,
+                          const Argument& arg,
+                          const StreamConfig& stream_config) const
+       {
+           // Calculate max number of workgroups that can simultaneously reside on the CU.
+           int num_blocks                = 0;
+           size_t dyn_shared_mem_per_blk = 0;
+           hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
+               &num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk));
+           int cu_count = getAvailableComputeUnitCount(stream_config);
+
+           if(stream_config.log_level_ > 0)
+           {
+               std::cout << "MaxActiveBlocksPerCU: " << num_blocks
+                         << ", available CUs count: " << cu_count << ", grid size: "
+                         << ck::math::min(num_blocks, CU_BLOCKS) * cu_count *
+                                BLOCK_SUBSCRIPTION_FACTOR
+                         << std::endl;
+           }
+
+           return launch_and_time_kernel(
+               stream_config,
+               kernel,
+               dim3(cu_count * ck::math::min(num_blocks, CU_BLOCKS) * BLOCK_SUBSCRIPTION_FACTOR),
+               dim3(BlockSize),
+               0,
+               arg.p_workspace_,
+               arg.grid_size_,
+               arg.K_BATCH);
+       }
    };
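The grid sizing in LaunchKernel is occupancy-driven: the runtime reports how many workgroups of BlockSize threads can be resident per compute unit, the count is capped at CU_BLOCKS (at most two waves per SIMD, per the comment above), and the grid covers every CU times BLOCK_SUBSCRIPTION_FACTOR. A standalone HIP sketch of the same calculation; the kernel and the two constants are placeholders, and the CU count is read from device properties rather than from the stream as getAvailableComputeUnitCount does:

    #include <hip/hip_runtime.h>
    #include <algorithm>
    #include <cstdio>

    __global__ void dummy_kernel(int* out) { out[blockIdx.x] = threadIdx.x; }

    int main()
    {
        constexpr int BlockSize                 = 256;
        constexpr int CU_BLOCKS                 = 8; // cap on resident workgroups per CU
        constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1; // oversubscription multiplier

        // Max number of workgroups of BlockSize threads resident on one CU.
        int num_blocks = 0;
        if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
               &num_blocks, dummy_kernel, BlockSize, /*dynSharedMemPerBlk=*/0) != hipSuccess)
            return 1;

        // Number of compute units on the current device.
        hipDeviceProp_t props;
        int device = 0;
        if(hipGetDevice(&device) != hipSuccess ||
           hipGetDeviceProperties(&props, device) != hipSuccess)
            return 1;

        // Persistent-kernel style grid: one slot per resident workgroup on every CU.
        const int grid_size = props.multiProcessorCount * std::min(num_blocks, CU_BLOCKS) *
                              BLOCK_SUBSCRIPTION_FACTOR;
        std::printf("grid size: %d\n", grid_size);
        return 0;
    }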