Commit 1b78ca0d authored by Po-Yen, Chen
Browse files

Move 'Argument' into GridwiseGemm

parent cbc49dc2
...@@ -120,69 +120,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -120,69 +120,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
LoopSched, LoopSched,
PipelineVer>; PipelineVer>;
using AGridDesc_K0_M_K1 = decltype(GridwiseGemm::MakeAGridDescriptor_K0_M_K1(1, 1, 1, 1)); using Argument = typename GridwiseGemm::Argument;
using BGridDesc_K0_N_K1 = decltype(GridwiseGemm::MakeBGridDescriptor_K0_N_K1(1, 1, 1, 1));
using CGridDesc_M_N = decltype(GridwiseGemm::MakeCGridDescriptor_M_N(1, 1, 1, 1, 1));
// Argument
//
// Bundles everything one GEMM launch needs: the A/B/C grid pointers, the
// M/N/K and stride values exactly as handed to the constructor, and the
// padded M/N values obtained from GridwiseGemm.
struct Argument : public BaseArgument
{
    // Stores the given pointers and integer parameters verbatim, then derives
    // MPadded / NPadded from M_ / N_ through GridwiseGemm::CalculateMPadded
    // and GridwiseGemm::CalculateNPadded.
    Argument(const ADataType* p_a_grid,
             const BDataType* p_b_grid,
             CDataType* p_c_grid,
             index_t M_,
             index_t N_,
             index_t K_,
             index_t StrideA_,
             index_t StrideB_,
             index_t StrideC_)
        : p_a_grid_(p_a_grid),
          p_b_grid_(p_b_grid),
          p_c_grid_(p_c_grid),
          M(M_),
          N(N_),
          K(K_),
          StrideA(StrideA_),
          StrideB(StrideB_),
          StrideC(StrideC_),
          MPadded(GridwiseGemm::CalculateMPadded(M_)),
          NPadded(GridwiseGemm::CalculateNPadded(N_))
    {
    }

    // Writes all integer fields (not the pointers) to stdout on a single line.
    __host__ void Print() const
    {
        printf("M = %d, N = %d, K = %d, SA = %d, SB = %d, SC = %d, MP = %d, NP = %d\n",
               M, N, K,
               StrideA, StrideB, StrideC,
               MPadded, NPadded);
    }

    // Grid pointers for the A, B (read-only) and C (writable) operands.
    const ADataType* p_a_grid_;
    const BDataType* p_b_grid_;
    CDataType* p_c_grid_;

    // Copies of the constructor's M_, N_ and K_.
    index_t M;
    index_t N;
    index_t K;

    // Copies of the constructor's StrideA_, StrideB_ and StrideC_.
    index_t StrideA;
    index_t StrideB;
    index_t StrideC;

    // Results of GridwiseGemm::CalculateMPadded(M_) / CalculateNPadded(N_).
    index_t MPadded;
    index_t NPadded;
};
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
using Argument = DeviceGemmXdl::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if DEBUG_LOG #if DEBUG_LOG
...@@ -216,30 +158,30 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout, ...@@ -216,30 +158,30 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.K))
{ {
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, Argument, true>; const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, true>;
ave_time = launch_and_time_kernel(stream_config, ave_time = launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(gdx, gdy, gdz), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_a_grid_, arg.p_a_grid,
arg.p_b_grid_, arg.p_b_grid,
arg.p_c_grid_, arg.p_c_grid,
arg); arg);
} }
else else
{ {
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, Argument, false>; const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, false>;
ave_time = launch_and_time_kernel(stream_config, ave_time = launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(gdx, gdy, gdz), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_a_grid_, arg.p_a_grid,
arg.p_b_grid_, arg.p_b_grid,
arg.p_c_grid_, arg.p_c_grid,
arg); arg);
} }
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
namespace ck { namespace ck {
template <typename GridwiseGemm, typename Argument, bool HasMainKBlockLoop> template <typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...@@ -24,7 +24,7 @@ __global__ void ...@@ -24,7 +24,7 @@ __global__ void
kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::FloatAB* __restrict__ p_a_grid, kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::FloatAB* __restrict__ p_a_grid,
const typename GridwiseGemm::FloatAB* __restrict__ p_b_grid, const typename GridwiseGemm::FloatAB* __restrict__ p_b_grid,
typename GridwiseGemm::FloatC* __restrict__ p_c_grid, typename GridwiseGemm::FloatC* __restrict__ p_c_grid,
const Argument karg) const typename GridwiseGemm::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__))
...@@ -220,6 +220,58 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -220,6 +220,58 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
} }
} }
// Argument
//
// Bundles the A/B/C grid pointers together with the M/N/K and stride values
// handed to the constructor, plus the padded M/N values computed by this
// gridwise GEMM's CalculateMPadded / CalculateNPadded.
struct Argument : public tensor_operation::device::BaseArgument
{
    // Host-only constructor: copies every parameter into the matching member
    // and derives MPadded / NPadded from M_ / N_.
    __host__ Argument(const FloatAB* p_a_grid_,
                      const FloatAB* p_b_grid_,
                      FloatC* p_c_grid_,
                      index_t M_,
                      index_t N_,
                      index_t K_,
                      index_t StrideA_,
                      index_t StrideB_,
                      index_t StrideC_)
        : p_a_grid(p_a_grid_),
          p_b_grid(p_b_grid_),
          p_c_grid(p_c_grid_),
          M(M_),
          N(N_),
          K(K_),
          StrideA(StrideA_),
          StrideB(StrideB_),
          StrideC(StrideC_),
          MPadded(CalculateMPadded(M_)),
          NPadded(CalculateNPadded(N_))
    {
    }

    // Streams all integer fields (not the pointers) to std::cout as one
    // "arg {...}" line; std::endl terminates the line and flushes the stream.
    __host__ void Print() const
    {
        std::cout << "arg {M:" << M << ", N:" << N << ", K:" << K
                  << ", SA:" << StrideA << ", SB:" << StrideB << ", SC:" << StrideC
                  << ", MP:" << MPadded << ", NP:" << NPadded << "}" << std::endl;
    }

    // Grid pointers for the A, B (read-only) and C (writable) operands.
    const FloatAB* p_a_grid;
    const FloatAB* p_b_grid;
    FloatC* p_c_grid;

    // Copies of the constructor's M_, N_ and K_.
    index_t M;
    index_t N;
    index_t K;

    // Copies of the constructor's StrideA_, StrideB_ and StrideC_.
    index_t StrideA;
    index_t StrideB;
    index_t StrideC;

    // Results of CalculateMPadded(M_) / CalculateNPadded(N_).
    index_t MPadded;
    index_t NPadded;
};
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment