Commit 3d345953 authored by Adam Osewski's avatar Adam Osewski
Browse files

Update API.

parent 0e33fbdf
...@@ -12,57 +12,6 @@ namespace ck { ...@@ -12,57 +12,6 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
/**
 * @brief Structure representing single GEMM problem arguments.
 *
 * The pointer to the vector of those structures is passed
 * to the GroupedGEMM entry point kernel.
 */
struct GemmKernelArguments
{
    __host__ __device__ GemmKernelArguments(const void* a_ptr,
                                            const void* b_ptr,
                                            void* c_ptr,
                                            index_t m,
                                            index_t n,
                                            index_t k,
                                            index_t stride_a,
                                            index_t stride_b,
                                            index_t stride_c)
        : p_a_grid{a_ptr},
          p_b_grid{b_ptr},
          p_c_grid{c_ptr},
          M{m},
          N{n},
          K{k},
          StrideA{stride_a},
          StrideB{stride_b},
          StrideC{stride_c}
    {
    }

    const void* p_a_grid; ///< Pointer to input tensor A.
    const void* p_b_grid; ///< Pointer to input tensor B.
    void* p_c_grid;       ///< Pointer to output tensor C.
    index_t M;            ///< GEMM problem M dimension.
    index_t N;            ///< GEMM problem N dimension.
    index_t K;            ///< GEMM problem K dimension.
    index_t StrideA;      ///< Leading-dimension stride of A.
    index_t StrideB;      ///< Leading-dimension stride of B.
    index_t StrideC;      ///< Leading-dimension stride of C.

    /// Dump the problem sizes and strides to stdout (host-side debug helper).
    void Print() const
    {
        std::cout << "arg {"
                  << "M:" << M << ", "
                  << "N:" << N << ", "
                  << "K:" << K << ", "
                  << "SA:" << StrideA << ", "
                  << "SB:" << StrideB << ", "
                  << "SC:" << StrideC << "}" << std::endl;
    }
};
struct GemmDesc struct GemmDesc
{ {
ck::index_t M_, N_, K_; ck::index_t M_, N_, K_;
......
...@@ -8,6 +8,57 @@ namespace ck { ...@@ -8,6 +8,57 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
/**
 * @brief Structure representing single GEMM problem arguments.
 *
 * The pointer to the vector of those structures is passed
 * to the GroupedGEMM entry point kernel.
 */
struct GroupedGemmKernelArguments
{
    __host__ __device__ GroupedGemmKernelArguments(const void* a_ptr,
                                                   const void* b_ptr,
                                                   void* c_ptr,
                                                   index_t m,
                                                   index_t n,
                                                   index_t k,
                                                   index_t stride_a,
                                                   index_t stride_b,
                                                   index_t stride_c)
        : p_a_grid{a_ptr},
          p_b_grid{b_ptr},
          p_c_grid{c_ptr},
          M{m},
          N{n},
          K{k},
          StrideA{stride_a},
          StrideB{stride_b},
          StrideC{stride_c}
    {
    }

    const void* p_a_grid; ///< Pointer to input tensor A.
    const void* p_b_grid; ///< Pointer to input tensor B.
    void* p_c_grid;       ///< Pointer to output tensor C.
    index_t M;            ///< GEMM problem M dimension.
    index_t N;            ///< GEMM problem N dimension.
    index_t K;            ///< GEMM problem K dimension.
    index_t StrideA;      ///< Leading-dimension stride of A.
    index_t StrideB;      ///< Leading-dimension stride of B.
    index_t StrideC;      ///< Leading-dimension stride of C.

    /// Dump the problem sizes and strides to stdout (host-side debug helper).
    void Print() const
    {
        std::cout << "arg {"
                  << "M:" << M << ", "
                  << "N:" << N << ", "
                  << "K:" << K << ", "
                  << "SA:" << StrideA << ", "
                  << "SB:" << StrideB << ", "
                  << "SC:" << StrideC << "}" << std::endl;
    }
};
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename DsLayout, typename DsLayout,
...@@ -31,7 +82,23 @@ struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout, ...@@ -31,7 +82,23 @@ struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation> CElementwiseOperation>
{ {
//------------------------------------------------------------------------//
// @brief Sets the k batch size.
//
// @param p_arg Pointer to the Argument we're going to change.
// @param[in] kbatch The kbatch value.
//
virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0; virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
//------------------------------------------------------------------------//
//
// @brief Sets the device kernel arguments pointer.
//
// @param p_arg The pointer to the Argument we're going to update.
// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
// arguments.
//
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* p_dev_kernel_args) const = 0;
}; };
} // namespace device } // namespace device
......
...@@ -265,7 +265,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -265,7 +265,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using GridwiseGemmArg = typename GridwiseGemm::Argument; using GridwiseGemmArg = typename GridwiseGemm::Argument;
using KernelArguments = GemmKernelArguments; using KernelArguments = GroupedGemmKernelArguments;
using Block2ETileMapKSplit = using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>; BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
// Block2CTileMap configuration parameter. // Block2CTileMap configuration parameter.
...@@ -366,6 +366,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -366,6 +366,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
index_t skipped_group_count_; index_t skipped_group_count_;
// The overall number of output tiles to be processed. // The overall number of output tiles to be processed.
index_t grid_size_; index_t grid_size_;
const void* p_dev_gemm_args_;
std::vector<KernelArguments> gemm_kernel_args_; std::vector<KernelArguments> gemm_kernel_args_;
}; };
...@@ -384,8 +385,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -384,8 +385,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
// //
// @brief Launch Grouped Gemm kernel. // @brief Launch Grouped Gemm kernel.
// //
// @note This function overload is using user provided device workspace buffer for // @note This function overload is using user provided device buffer for kernel
// kernel arguments. // arguments.
// //
// @param[in] arg The structure containing kernel arguments (in host memory). // @param[in] arg The structure containing kernel arguments (in host memory).
// @param[in] dev_gemm_args The point to device memory with kernel arguments. // @param[in] dev_gemm_args The point to device memory with kernel arguments.
...@@ -400,11 +401,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -400,11 +401,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
auto [all_have_kbatch_gt_one, all_have_main_k0_block_loop] = auto [all_have_kbatch_gt_one, all_have_main_k0_block_loop] =
CheckArgument(arg, stream_config); CheckArgument(arg, stream_config);
if(dev_gemm_args != nullptr) if(dev_gemm_args == nullptr)
{
arg.p_workspace_ = dev_gemm_args;
}
else
{ {
std::ostringstream err; std::ostringstream err;
err << "The gemm arguments workspace buffer is not allocated!" err << "The gemm arguments workspace buffer is not allocated!"
...@@ -428,12 +425,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -428,12 +425,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
if(all_have_kbatch_gt_one) if(all_have_kbatch_gt_one)
{ {
ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, true>( ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, true>(
arg, stream_config); arg, dev_gemm_args, stream_config);
} }
else else
{ {
ave_time = ave_time = DispatchKernel<InMemoryDataOperationEnum::Set, true>(
DispatchKernel<InMemoryDataOperationEnum::Set, true>(arg, stream_config); arg, dev_gemm_args, stream_config);
} }
} }
else else
...@@ -441,12 +438,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -441,12 +438,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
if(all_have_kbatch_gt_one) if(all_have_kbatch_gt_one)
{ {
ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, false>( ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, false>(
arg, stream_config); arg, dev_gemm_args, stream_config);
} }
else else
{ {
ave_time = ave_time = DispatchKernel<InMemoryDataOperationEnum::Set, false>(
DispatchKernel<InMemoryDataOperationEnum::Set, false>(arg, stream_config); arg, dev_gemm_args, stream_config);
} }
} }
...@@ -467,9 +464,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -467,9 +464,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
// //
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
auto [all_have_kbatch_gt_one, all_have_main_k0_block_loop] =
CheckArgument(arg, stream_config);
if(arg.p_workspace_ != nullptr) if(arg.p_workspace_ != nullptr)
{ {
hip_check_error( hip_check_error(
...@@ -487,45 +481,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -487,45 +481,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
throw std::runtime_error(err.str()); throw std::runtime_error(err.str());
} }
if(all_have_kbatch_gt_one) return Run(arg, arg.p_workspace_, stream_config);
{
for(const auto& gemm_arg : arg.gemm_kernel_args_)
{
hip_check_error(hipMemset(
gemm_arg.p_c_grid, 0, gemm_arg.M * gemm_arg.N * sizeof(EDataType)));
}
}
float ave_time = 0;
if(all_have_main_k0_block_loop)
{
if(all_have_kbatch_gt_one)
{
ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, true>(
arg, stream_config);
}
else
{
ave_time =
DispatchKernel<InMemoryDataOperationEnum::Set, true>(arg, stream_config);
}
}
else
{
if(all_have_kbatch_gt_one)
{
ave_time = DispatchKernel<InMemoryDataOperationEnum::AtomicAdd, false>(
arg, stream_config);
}
else
{
ave_time =
DispatchKernel<InMemoryDataOperationEnum::Set, false>(arg, stream_config);
}
}
return ave_time;
} }
float Run(const BaseArgument* p_arg, float Run(const BaseArgument* p_arg,
...@@ -600,7 +556,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -600,7 +556,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
} }
template <InMemoryDataOperationEnum CGlobalMemoryDataOperation, bool HasMainKBlockLoop> template <InMemoryDataOperationEnum CGlobalMemoryDataOperation, bool HasMainKBlockLoop>
float DispatchKernel(const Argument& arg, const StreamConfig& stream_config) const float DispatchKernel(const Argument& arg,
const void* dev_gemm_args,
const StreamConfig& stream_config) const
{ {
const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm, const auto kernel = kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
KernelArguments, KernelArguments,
...@@ -772,11 +730,20 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -772,11 +730,20 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
} }
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); } static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
static void SetDeviceKernelArgs(Argument& arg, const void* p_dev_kernel_args)
{
arg.p_dev_gemm_args_ = p_dev_kernel_args;
}
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{ {
return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch); return SetKBatchSize(*dynamic_cast<Argument*>(p_arg), kbatch);
} }
void SetDeviceKernelArgs(BaseArgument* p_arg, const void* p_dev_kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), p_dev_kernel_args);
}
}; };
} // namespace device } // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment