"...composable_kernel_rocm.git" did not exist on "f6934e0bf4460c7ad97c57d5f4a645e426048b1d"
Commit 148d9e57 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Move kernel arg type definition into GridwiseGemm

parent affdca9d
......@@ -130,106 +130,11 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
LoopSched,
PipelineVer>;
// Grid descriptor types, deduced from the return types of the GridwiseGemm factory
// functions. The integer arguments only appear inside decltype (an unevaluated context),
// so their values are placeholders that select the overload and nothing more.
using AGridDesc_AK0_M_AK1 =
decltype(GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(1, 1, 1, 1, 1, 1));
using BGridDesc_BK0_N_BK1 =
decltype(GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(1, 1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(GridwiseGemm::MakeCGridDescriptor_M_N(1, 1, 1, 1, 1));
// Argument
// Bundle of everything one GEMM launch consumes: the A/B/C grid pointers, the problem
// sizes and strides, the padded sizes and K0 values derived from them through
// GridwiseGemm, and the three grid descriptors built from those values.
//
// NOTE(review): the type is polymorphic (its destructor overrides a base-class one) and
// the kernel launch hands it over by value; device code should restrict itself to reading
// the data members -- confirm no virtual call is made on the device side.
struct Argument : public BaseArgument
{
    // Stores the caller-supplied pointers, sizes and strides, then derives the padded
    // sizes, the K0 values and the grid descriptors from them.
    //
    // p_a_grid_ / p_b_grid_ / p_c_grid_ : grid pointers for A, B (read-only) and C
    // M_, N_, K_                        : GEMM problem sizes
    // StrideA_, StrideB_, StrideC_      : strides forwarded to the descriptor factories
    __host__ Argument(const ADataType* p_a_grid_,
                      const BDataType* p_b_grid_,
                      CDataType* p_c_grid_,
                      index_t M_,
                      index_t N_,
                      index_t K_,
                      index_t StrideA_,
                      index_t StrideB_,
                      index_t StrideC_)
        : p_a_grid{p_a_grid_},
          p_b_grid{p_b_grid_},
          p_c_grid{p_c_grid_},
          M{M_},
          N{N_},
          K{K_},
          StrideA{StrideA_},
          StrideB{StrideB_},
          StrideC{StrideC_},
          MPadded{GridwiseGemm::CalculateMPadded(M_)},
          NPadded{GridwiseGemm::CalculateNPadded(N_)},
          KPadded{GridwiseGemm::CalculateKPadded(K_)},
          AK0{GridwiseGemm::CalculateAK0(K_)},
          BK0{GridwiseGemm::CalculateBK0(K_)},
          // The descriptor factories are fed values recomputed from the constructor
          // arguments rather than the members initialized just above.
          a_grid_desc_ak0_m_ak1{
              GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(M_,
                                                          GridwiseGemm::CalculateMPadded(M_),
                                                          K_,
                                                          GridwiseGemm::CalculateKPadded(K_),
                                                          StrideA_,
                                                          GridwiseGemm::CalculateAK0(K_))},
          b_grid_desc_bk0_n_bk1{
              GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(K_,
                                                          GridwiseGemm::CalculateKPadded(K_),
                                                          N_,
                                                          GridwiseGemm::CalculateNPadded(N_),
                                                          StrideB_,
                                                          GridwiseGemm::CalculateBK0(K_))},
          c_grid_desc_m_n{
              GridwiseGemm::MakeCGridDescriptor_M_N(M_,
                                                    GridwiseGemm::CalculateMPadded(M_),
                                                    N_,
                                                    GridwiseGemm::CalculateNPadded(N_),
                                                    StrideC_)}
    {
    }

    // Writes the scalar members (sizes, strides, padded sizes, K0 values) to std::cout on
    // a single line. Host-only: the body uses std::cout, which is not usable from device
    // code, so the function must not carry the __device__ qualifier.
    __host__ void Print() const
    {
        std::cout << "arg {"
                  << "M:" << M << ", "
                  << "N:" << N << ", "
                  << "K:" << K << ", "
                  << "SA:" << StrideA << ", "
                  << "SB:" << StrideB << ", "
                  << "SC:" << StrideC << ", "
                  << "MP:" << MPadded << ", "
                  << "NP:" << NPadded << ", "
                  << "KP:" << KPadded << ", "
                  << "AK0:" << AK0 << ", "
                  << "BK0:" << BK0 << "}" << std::endl;
    }

    // Copy construction and destruction are callable from both host and device.
    __host__ __device__ Argument(const Argument&) = default;

    __host__ __device__ ~Argument() override {}

    // private:
    // Grid pointers as passed to the constructor.
    const ADataType* p_a_grid;
    const BDataType* p_b_grid;
    CDataType* p_c_grid;

    // Problem sizes and strides as passed to the constructor.
    index_t M;
    index_t N;
    index_t K;
    index_t StrideA;
    index_t StrideB;
    index_t StrideC;

    // Values derived through GridwiseGemm::Calculate*().
    index_t MPadded;
    index_t NPadded;
    index_t KPadded;
    index_t AK0;
    index_t BK0;

    // Grid descriptors produced by the GridwiseGemm::Make*GridDescriptor*() factories.
    AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
    BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
    CGridDesc_M_N c_grid_desc_m_n;
};
// The argument type used by this device op is the one declared inside GridwiseGemm.
using Argument = typename GridwiseGemm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
void Print(const Argument& karg) { karg.Print(); }
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
......@@ -253,16 +158,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, Argument, true>;
const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, true>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, Argument, false>;
const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, false>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
}
......
......@@ -17,12 +17,12 @@
namespace ck {
template <typename GridwiseGemm, typename Argument, bool HasMainKBlockLoop>
template <typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdl_cshuffle_v1_simplified(Argument karg)
kernel_gemm_xdl_cshuffle_v1_simplified(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
......@@ -383,6 +383,85 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
}
// Grid descriptor types, deduced from the return types of the Make*GridDescriptor*()
// factories. The integer arguments only appear inside decltype (an unevaluated context),
// so their values are placeholders that select the overload and nothing more.
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1, 1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1));
// Argument
struct Argument : public tensor_operation::device::BaseArgument
{
__host__ Argument(const FloatAB* p_a_grid_,
const FloatAB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
KPadded{CalculateKPadded(K_)},
AK0{CalculateAK0(K_)},
BK0{CalculateBK0(K_)},
a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(
M_, CalculateMPadded(M_), K_, CalculateKPadded(K_), StrideA_, CalculateAK0(K_))},
b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(
K_, CalculateKPadded(K_), N_, CalculateNPadded(N_), StrideB_, CalculateBK0(K_))},
c_grid_desc_m_n{MakeCGridDescriptor_M_N(
M_, CalculateMPadded(M_), N_, CalculateNPadded(N_), StrideC_)}
{
}
__host__ __device__ void Print() const
{
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KP:" << KPadded << ", "
<< "AK0:" << AK0 << ", "
<< "BK0:" << BK0 << "}" << std::endl;
}
__host__ __device__ Argument(const Argument&) = default;
__host__ __device__ ~Argument() override {}
// private:
const FloatAB* p_a_grid;
const FloatAB* p_b_grid;
FloatC* p_c_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
index_t MPadded;
index_t NPadded;
index_t KPadded;
index_t AK0;
index_t BK0;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
CGridDesc_M_N c_grid_desc_m_n;
};
// Pipeline type: the cv/ref-stripped return type of GridwiseGemmPipeline_Selector
// instantiated with this struct's PipelineVer, NumGemmKPrefetchStage and LoopSched.
// FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
......@@ -447,7 +526,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Argument>
__host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
......@@ -590,7 +668,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
print_bytes(memory, sizeof(T));
}
template <bool HasMainKBlockLoop, typename Argument>
template <bool HasMainKBlockLoop>
__device__ static void Run(const Argument& karg, void* __restrict__ p_shared)
{
const FloatAB* p_a_grid = karg.p_a_grid;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment