"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "787267c4c53b25efa139b65a82d6939c69243792"
Commit b3512749 authored by Adam Osewski's avatar Adam Osewski
Browse files

Move Gemm KernelArguments to device op interface.

parent 61862fb4
...@@ -12,6 +12,57 @@ namespace ck { ...@@ -12,6 +12,57 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
/**
* @brief Structure representing single GEMM problem arguments.
*
* The pointer to the vector of those structures is passed
* to the GroupedGEMM entry point kernel.
*/
struct GemmKernelArguments
{
__host__ __device__ GemmKernelArguments(const void* p_a_grid_,
const void* p_b_grid_,
void* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_}
{
}
const void* p_a_grid;
const void* p_b_grid;
void* p_c_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
void Print() const
{
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << "}" << std::endl;
}
};
struct GemmDesc struct GemmDesc
{ {
ck::index_t M_, N_, K_; ck::index_t M_, N_, K_;
......
...@@ -265,62 +265,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -265,62 +265,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
using GridwiseGemmArg = typename GridwiseGemm::Argument; using GridwiseGemmArg = typename GridwiseGemm::Argument;
using KernelArguments = GemmKernelArguments;
using Block2ETileMapKSplit = using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>; BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
// Block2CTileMap configuration parameter. // Block2CTileMap configuration parameter.
static constexpr index_t B2E_M01 = 8; static constexpr index_t B2E_M01 = 8;
/**
* @brief Structure representing single GEMM problem arguments.
*
* The pointer to the vector of those structures is passed
* to the GroupedGEMM entry point kernel.
*/
struct KernelArguments
{
__host__ __device__ KernelArguments(const void* p_a_grid_,
const void* p_b_grid_,
void* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_}
{
}
const void* p_a_grid;
const void* p_b_grid;
void* p_c_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
void Print() const
{
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << "}" << std::endl;
}
};
static constexpr index_t DefaultKBatch = 1; static constexpr index_t DefaultKBatch = 1;
// Argument // Argument
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment