Commit 9be8900f authored by Po-Yen, Chen
Browse files

Push-down class 'GridwiseGemm::Argument' fields

parent 7bae1691
......@@ -451,26 +451,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t MPadded_,
index_t NPadded_,
index_t KPadded_,
index_t AK0_,
index_t BK0_)
: Parent(nullptr,
nullptr,
nullptr,
M_,
index_t StrideC_)
: Parent(M_,
N_,
K_,
StrideA_,
StrideB_,
StrideC_,
MPadded_,
NPadded_,
KPadded_,
AK0_,
BK0_),
GridwiseGemm::CalculateMPadded(M_),
GridwiseGemm::CalculateNPadded(N_),
GridwiseGemm::CalculateKPadded(K_),
GridwiseGemm::CalculateAK0(K_),
GridwiseGemm::CalculateBK0(K_)),
p_a_grid_real_{p_a_grid_real},
p_a_grid_imag_{p_a_grid_imag},
p_b_grid_real_{p_b_grid_real},
......@@ -510,15 +502,13 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
// Invoker
struct Invoker : public BaseInvoker
{
// void Print(const Argument& karg) { karg.Print(); }
void Print(const Argument& karg) { karg.Print(); }
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{
Argument karg = arg;
if(stream_config.log_level_ > 0)
{
// Print(karg);
Print(karg);
}
if(!GridwiseGemm::CheckValidity(karg))
......@@ -575,17 +565,25 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, true>;
karg.p_a_grid = karg.p_a_grid_real_;
karg.p_b_grid = karg.p_b_grid_real_;
karg.p_c_grid = karg.p_aux_grid_;
ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
karg.p_a_grid = karg.p_a_grid_imag_;
karg.p_b_grid = karg.p_b_grid_imag_;
karg.p_c_grid = karg.p_aux_2_grid_;
ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_real_,
karg.p_b_grid_real_,
karg.p_aux_grid_,
karg);
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_imag_,
karg.p_b_grid_imag_,
karg.p_aux_2_grid_,
karg);
// c_real = aux - aux_2
ave_time += launch_and_time_kernel(
......@@ -601,17 +599,25 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
make_tuple(karg.p_c_grid_real_),
Subtract{});
karg.p_a_grid = karg.p_a_grid_real_;
karg.p_b_grid = karg.p_b_grid_imag_;
karg.p_c_grid = karg.p_aux_grid_;
ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
karg.p_a_grid = karg.p_a_grid_imag_;
karg.p_b_grid = karg.p_b_grid_real_;
karg.p_c_grid = karg.p_aux_2_grid_;
ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_real_,
karg.p_b_grid_imag_,
karg.p_aux_grid_,
karg);
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_imag_,
karg.p_b_grid_real_,
karg.p_aux_2_grid_,
karg);
// c_imag = aux + aux_2
ave_time += launch_and_time_kernel(
......@@ -631,17 +637,25 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, false>;
karg.p_a_grid = karg.p_a_grid_real_;
karg.p_b_grid = karg.p_b_grid_real_;
karg.p_c_grid = karg.p_aux_grid_;
ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
karg.p_a_grid = karg.p_a_grid_imag_;
karg.p_b_grid = karg.p_b_grid_imag_;
karg.p_c_grid = karg.p_aux_2_grid_;
ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_real_,
karg.p_b_grid_real_,
karg.p_aux_grid_,
karg);
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_imag_,
karg.p_b_grid_imag_,
karg.p_aux_2_grid_,
karg);
// c_real = aux - aux_2
ave_time += launch_and_time_kernel(
......@@ -657,17 +671,25 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
make_tuple(karg.p_c_grid_real_),
Subtract{});
karg.p_a_grid = karg.p_a_grid_real_;
karg.p_b_grid = karg.p_b_grid_imag_;
karg.p_c_grid = karg.p_aux_grid_;
ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
karg.p_a_grid = karg.p_a_grid_imag_;
karg.p_b_grid = karg.p_b_grid_real_;
karg.p_c_grid = karg.p_aux_2_grid_;
ave_time += launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_real_,
karg.p_b_grid_imag_,
karg.p_aux_grid_,
karg);
ave_time += launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_imag_,
karg.p_b_grid_real_,
karg.p_aux_2_grid_,
karg);
// c_imag = aux + aux_2
ave_time += launch_and_time_kernel(
......@@ -741,12 +763,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
K,
StrideA,
StrideB,
StrideC,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K),
GridwiseGemm::CalculateAK0(K),
GridwiseGemm::CalculateBK0(K)};
StrideC};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -782,12 +799,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
K,
StrideA,
StrideB,
StrideC,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K),
GridwiseGemm::CalculateAK0(K),
GridwiseGemm::CalculateBK0(K));
StrideC);
}
// polymorphic
......
......@@ -130,7 +130,40 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
LoopSched,
PipelineVer>;
using Argument = typename GridwiseGemm::Argument;
// Device-level kernel argument for DeviceGemm_Xdl_CShuffle.
//
// Extends GridwiseGemm::Argument with the A/B/C grid pointers. The base class
// receives the problem sizes and strides, plus the padded extents and AK0/BK0
// values obtained from the GridwiseGemm::Calculate* helpers, so callers only
// supply the raw M/N/K and strides.
struct Argument : public GridwiseGemm::Argument
{
    using Parent = typename GridwiseGemm::Argument;

    // Grid pointers held by this derived type; the invoker hands them to the
    // kernel as separate launch arguments next to the argument object.
    const ADataType* p_a_grid;
    const BDataType* p_b_grid;
    CDataType* p_c_grid;

    // p_a_grid_ / p_b_grid_ / p_c_grid_ : pointers stored verbatim in the members above.
    // M_, N_, K_                        : problem sizes forwarded to Parent.
    // StrideA_, StrideB_, StrideC_      : strides forwarded to Parent.
    Argument(const ADataType* p_a_grid_,
             const BDataType* p_b_grid_,
             CDataType* p_c_grid_,
             index_t M_,
             index_t N_,
             index_t K_,
             index_t StrideA_,
             index_t StrideB_,
             index_t StrideC_)
        : Parent(M_, N_, K_,
                 StrideA_, StrideB_, StrideC_,
                 // Derived quantities are computed here so call sites do not
                 // have to repeat the Calculate* calls.
                 GridwiseGemm::CalculateMPadded(M_),
                 GridwiseGemm::CalculateNPadded(N_),
                 GridwiseGemm::CalculateKPadded(K_),
                 GridwiseGemm::CalculateAK0(K_),
                 GridwiseGemm::CalculateBK0(K_)),
          p_a_grid(p_a_grid_),
          p_b_grid(p_b_grid_),
          p_c_grid(p_c_grid_)
    {
    }
};
// Invoker
struct Invoker : public BaseInvoker
......@@ -160,15 +193,29 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, true>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
karg);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, false>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
karg);
}
return ave_time;
......@@ -212,20 +259,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
BElementwiseOperation,
CElementwiseOperation)
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K),
GridwiseGemm::CalculateAK0(K),
GridwiseGemm::CalculateBK0(K)};
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
}
static auto MakeInvoker() { return Invoker{}; }
......@@ -252,12 +286,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
K,
StrideA,
StrideB,
StrideC,
GridwiseGemm::CalculateMPadded(M),
GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K),
GridwiseGemm::CalculateAK0(K),
GridwiseGemm::CalculateBK0(K));
StrideC);
}
// polymorphic
......
......@@ -22,13 +22,17 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdl_cshuffle_v1_simplified(typename GridwiseGemm::Argument karg)
kernel_gemm_xdl_cshuffle_v1_simplified(
const typename GridwiseGemm::FloatAB* __restrict__ p_a_grid,
const typename GridwiseGemm::FloatAB* __restrict__ p_b_grid,
typename GridwiseGemm::FloatC* __restrict__ p_c_grid,
typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(karg, p_shared);
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, karg);
#else
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
......@@ -37,10 +41,10 @@ __global__ void
template <typename ALayout,
typename BLayout,
typename CLayout,
typename FloatAB,
typename FloatAB_,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatC,
typename FloatC_,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
......@@ -96,6 +100,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static constexpr auto AK1_c = Number<AK1Value>{};
static constexpr auto BK1_c = Number<BK1Value>{};
using FloatAB = FloatAB_;
using FloatC = FloatC_;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
#if defined(INTEGER_DIVIDE_CEIL)
......@@ -390,10 +397,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// Argument
struct Argument : public tensor_operation::device::BaseArgument
{
__host__ Argument(const FloatAB* p_a_grid_,
const FloatAB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
__host__ Argument(index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
......@@ -404,10 +408,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
index_t KPadded_,
index_t AK0_,
index_t BK0_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
M{M_},
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
......@@ -446,10 +447,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__host__ __device__ ~Argument() override {}
// private:
const FloatAB* p_a_grid;
const FloatAB* p_b_grid;
FloatC* p_c_grid;
index_t M;
index_t N;
index_t K;
......@@ -673,12 +670,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
template <bool HasMainKBlockLoop>
__device__ static void Run(const Argument& karg, void* __restrict__ p_shared)
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const Argument& karg)
{
const FloatAB* p_a_grid = karg.p_a_grid;
const FloatAB* p_b_grid = karg.p_b_grid;
FloatC* p_c_grid = karg.p_c_grid;
#define CREATE_DESCS_ON_HOST 1
#if CREATE_DESCS_ON_HOST
const auto a_grid_desc_ak0_m_ak1 = karg.a_grid_desc_ak0_m_ak1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment