Commit 3d345953 authored by Adam Osewski's avatar Adam Osewski
Browse files

Update API.

parent 0e33fbdf
......@@ -12,57 +12,6 @@ namespace ck {
namespace tensor_operation {
namespace device {
/**
 * @brief Argument pack describing one GEMM problem.
 *
 * A device-side vector of these structures is what the GroupedGEMM
 * entry point kernel receives as its argument list.
 */
struct GemmKernelArguments
{
    __host__ __device__ GemmKernelArguments(const void* a_ptr,
                                            const void* b_ptr,
                                            void* c_ptr,
                                            index_t m,
                                            index_t n,
                                            index_t k,
                                            index_t stride_a,
                                            index_t stride_b,
                                            index_t stride_c)
        : p_a_grid{a_ptr},
          p_b_grid{b_ptr},
          p_c_grid{c_ptr},
          M{m},
          N{n},
          K{k},
          StrideA{stride_a},
          StrideB{stride_b},
          StrideC{stride_c}
    {
    }

    // Device pointers to the A/B input tiles and the C output tile.
    const void* p_a_grid;
    const void* p_b_grid;
    void* p_c_grid;
    // GEMM problem sizes.
    index_t M;
    index_t N;
    index_t K;
    // Leading-dimension strides of A, B and C.
    index_t StrideA;
    index_t StrideB;
    index_t StrideC;

    /// Dumps the problem description to stdout (host-side debugging aid).
    void Print() const
    {
        std::cout << "arg {";
        std::cout << "M:" << M << ", ";
        std::cout << "N:" << N << ", ";
        std::cout << "K:" << K << ", ";
        std::cout << "SA:" << StrideA << ", ";
        std::cout << "SB:" << StrideB << ", ";
        std::cout << "SC:" << StrideC << "}" << std::endl;
    }
};
struct GemmDesc
{
ck::index_t M_, N_, K_;
......
......@@ -8,6 +8,57 @@ namespace ck {
namespace tensor_operation {
namespace device {
/**
 * @brief Argument pack describing one GEMM problem within a group.
 *
 * A device-side vector of these structures is what the GroupedGEMM
 * entry point kernel receives as its argument list.
 */
struct GroupedGemmKernelArguments
{
    __host__ __device__ GroupedGemmKernelArguments(const void* a_ptr,
                                                   const void* b_ptr,
                                                   void* c_ptr,
                                                   index_t m,
                                                   index_t n,
                                                   index_t k,
                                                   index_t stride_a,
                                                   index_t stride_b,
                                                   index_t stride_c)
        : p_a_grid{a_ptr},
          p_b_grid{b_ptr},
          p_c_grid{c_ptr},
          M{m},
          N{n},
          K{k},
          StrideA{stride_a},
          StrideB{stride_b},
          StrideC{stride_c}
    {
    }

    // Device pointers to the A/B input tiles and the C output tile.
    const void* p_a_grid;
    const void* p_b_grid;
    void* p_c_grid;
    // GEMM problem sizes.
    index_t M;
    index_t N;
    index_t K;
    // Leading-dimension strides of A, B and C.
    index_t StrideA;
    index_t StrideB;
    index_t StrideC;

    /// Dumps the problem description to stdout (host-side debugging aid).
    void Print() const
    {
        std::cout << "arg {";
        std::cout << "M:" << M << ", ";
        std::cout << "N:" << N << ", ";
        std::cout << "K:" << K << ", ";
        std::cout << "SA:" << StrideA << ", ";
        std::cout << "SB:" << StrideB << ", ";
        std::cout << "SC:" << StrideC << "}" << std::endl;
    }
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
......@@ -31,7 +82,23 @@ struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout,
BElementwiseOperation,
CElementwiseOperation>
{
//------------------------------------------------------------------------//
// @brief Sets the k batch size.
//
// @param p_arg Pointer to the Argument we're going to change.
// @param[in] kbatch The kbatch value.
//
virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
//------------------------------------------------------------------------//
//
// @brief Sets the device kernel arguments pointer.
//
// @param p_arg The pointer to the Argument we're going to update.
// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
// arguments.
//
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* p_dev_kernel_args) const = 0;
};
} // namespace device
......
......@@ -265,7 +265,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using GridwiseGemmArg = typename GridwiseGemm::Argument;
using KernelArguments = GemmKernelArguments;
using KernelArguments = GroupedGemmKernelArguments;
using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
// Block2CTileMap configuration parameter.
......@@ -366,6 +366,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
index_t skipped_group_count_;
// The overall number of output tiles to be processed.
index_t grid_size_;
const void* p_dev_gemm_args_;
std::vector<KernelArguments> gemm_kernel_args_;
};
......@@ -384,8 +385,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
//
// @brief Launch Grouped Gemm kernel.
//
// @note This function overload is using user provided device workspace buffer for
// kernel arguments.
// @note This function overload is using user provided device buffer for kernel
// arguments.
//
// @param[in] arg The structure containing kernel arguments (in host memory).
// @param[in] dev_gemm_args The pointer to device memory with kernel arguments.
......@@ -400,11 +401,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
auto [all_have_kbatch_gt_one, all_have_main_k0_block_loop] =
CheckArgument(arg, stream_config);
if(dev_gemm_args != nullptr)
{
arg.p_workspace_ = dev_gemm_args;
}
else
if(dev_gemm_args == nullptr)
{
std::ostringstream err;
err << "The gemm arguments workspace buffer is not allocated!"
......@@ -428,12 +425,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
if(all_have_kbatch_gt_one)
{
ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, true>(
arg, stream_config);
arg, dev_gemm_args, stream_config);
}
else
{
ave_time =
DispatchKernel<InMemoryDataOperationEnum::Set, true>(arg, stream_config);
ave_time = DispatchKernel<InMemoryDataOperationEnum::Set, true>(
arg, dev_gemm_args, stream_config);
}
}
else
......@@ -441,12 +438,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
if(all_have_kbatch_gt_one)
{
ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, false>(
arg, stream_config);
arg, dev_gemm_args, stream_config);
}
else
{
ave_time =
DispatchKernel<InMemoryDataOperationEnum::Set, false>(arg, stream_config);
ave_time = DispatchKernel<InMemoryDataOperationEnum::Set, false>(
arg, dev_gemm_args, stream_config);
}
}
......@@ -467,9 +464,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
//
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
auto [all_have_kbatch_gt_one, all_have_main_k0_block_loop] =
CheckArgument(arg, stream_config);
if(arg.p_workspace_ != nullptr)
{
hip_check_error(
......@@ -487,45 +481,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
throw std::runtime_error(err.str());
}
if(all_have_kbatch_gt_one)
{
for(const auto& gemm_arg : arg.gemm_kernel_args_)
{
hip_check_error(hipMemset(
gemm_arg.p_c_grid, 0, gemm_arg.M * gemm_arg.N * sizeof(EDataType)));
}
}
float ave_time = 0;
if(all_have_main_k0_block_loop)
{
if(all_have_kbatch_gt_one)
{
ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, true>(
arg, stream_config);
}
else
{
ave_time =
DispatchKernel<InMemoryDataOperationEnum::Set, true>(arg, stream_config);
}
}
else
{
if(all_have_kbatch_gt_one)
{
ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, false>(
arg, stream_config);
}
else
{
ave_time =
DispatchKernel<InMemoryDataOperationEnum::Set, false>(arg, stream_config);
}
}
return ave_time;
return Run(arg, arg.p_workspace_, stream_config);
}
float Run(const BaseArgument* p_arg,
......@@ -600,7 +556,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
}
template <InMemoryDataOperationEnum CGlobalMemoryDataOperation, bool HasMainKBlockLoop>
float DispatchKernel(const Argument& arg, const StreamConfig& stream_config) const
float DispatchKernel(const Argument& arg,
const void* dev_gemm_args,
const StreamConfig& stream_config) const
{
const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
KernelArguments,
......@@ -772,11 +730,20 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
}
/// @brief Updates the split-K batch count on a concrete Argument.
/// @param arg    The device op's Argument to modify.
/// @param kbatch New k-batch value; forwarded to Argument::UpdateKBatch.
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
/// @brief Records the device-memory buffer holding the per-group kernel arguments.
/// @param arg               The device op's Argument to update.
/// @param p_dev_kernel_args Pointer to device memory with the kernel arguments;
///                          only stored here — ownership stays with the caller.
static void SetDeviceKernelArgs(Argument& arg, const void* p_dev_kernel_args)
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
}
/// @brief Polymorphic entry point: sets the k-batch size via the static helper.
/// NOTE(review): dynamic_cast returns nullptr if p_arg is not this op's Argument;
/// the dereference below would then be undefined behavior — confirm callers
/// always pass the matching Argument type.
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{
return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
}
/// @brief Polymorphic entry point: stores the device kernel-arguments pointer.
/// NOTE(review): dynamic_cast returns nullptr if p_arg is not this op's Argument;
/// the dereference below would then be undefined behavior — confirm callers
/// always pass the matching Argument type.
void SetDeviceKernelArgs(BaseArgument* p_arg, const void* p_dev_kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
}
};
} // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment