Commit ef5afc55 authored by Po-Yen, Chen
Browse files

Reserve kernel arg as whole object in interfaces

parent 613dcc6b
......@@ -18,7 +18,7 @@ struct BaseArgument
BaseArgument(const BaseArgument&) = default;
BaseArgument& operator=(const BaseArgument&) = default;
virtual ~BaseArgument() {}
__host__ __device__ virtual ~BaseArgument() {}
void* p_workspace_ = nullptr;
};
......
......@@ -348,54 +348,58 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
M_{M},
N_{N},
K_{K},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(M,
GridwiseGemm::CalculateMPadded(M),
K,
GridwiseGemm::CalculateKPadded(K),
StrideA,
GridwiseGemm::CalculateAK0(K))},
b_grid_desc_bk0_n_bk1_{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(K,
GridwiseGemm::CalculateKPadded(K),
N,
GridwiseGemm::CalculateNPadded(N),
StrideB,
GridwiseGemm::CalculateBK0(K))},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(M,
GridwiseGemm::CalculateMPadded(M),
N,
GridwiseGemm::CalculateNPadded(N),
StrideC)},
kraw_{K}
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_},
K{K_},
a_grid_desc_ak0_m_ak1{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(M_,
GridwiseGemm::CalculateMPadded(M_),
K_,
GridwiseGemm::CalculateKPadded(K_),
StrideA_,
GridwiseGemm::CalculateAK0(K_))},
b_grid_desc_bk0_n_bk1{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(K_,
GridwiseGemm::CalculateKPadded(K_),
N_,
GridwiseGemm::CalculateNPadded(N_),
StrideB_,
GridwiseGemm::CalculateBK0(K_))},
c_grid_desc_m_n{DeviceOp::MakeCGridDescriptor_M_N(M_,
GridwiseGemm::CalculateMPadded(M_),
N_,
GridwiseGemm::CalculateNPadded(N_),
StrideC_)},
kraw_{K_}
{
}
__host__ __device__ Argument(const Argument&) = default;
__host__ __device__ ~Argument() override {}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
index_t M_;
index_t N_;
index_t K_;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
const ADataType* p_a_grid;
const BDataType* p_b_grid;
CDataType* p_c_grid;
index_t M;
index_t N;
index_t K;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
CGridDesc_M_N c_grid_desc_m_n;
index_t kraw_;
};
......@@ -408,78 +412,47 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{
#if DEBUG_LOG
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< karg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< karg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< karg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< karg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< karg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< karg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< karg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
// std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
// std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
// std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ",
// "
// << karg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(
karg.a_grid_desc_ak0_m_ak1_, karg.b_grid_desc_bk0_n_bk1_, karg.c_grid_desc_m_n_))
karg.a_grid_desc_ak0_m_ak1, karg.b_grid_desc_bk0_n_bk1, karg.c_grid_desc_m_n))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M_, karg.N_);
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
const auto K = GridwiseGemm::CalculateAK0(karg.K_) * AK1;
const auto K = GridwiseGemm::CalculateAK0(karg.K) * AK1;
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::CGridDesc_M_N,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_,
karg.p_b_grid_,
karg.p_c_grid_,
karg.a_grid_desc_ak0_m_ak1_,
karg.b_grid_desc_bk0_n_bk1_,
karg.c_grid_desc_m_n_);
const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, Argument, true>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::CGridDesc_M_N,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_,
karg.p_b_grid_,
karg.p_c_grid_,
karg.a_grid_desc_ak0_m_ak1_,
karg.b_grid_desc_bk0_n_bk1_,
karg.c_grid_desc_m_n_);
const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, Argument, false>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
}
return ave_time;
......@@ -516,7 +489,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
}
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_);
arg.a_grid_desc_ak0_m_ak1, arg.b_grid_desc_bk0_n_bk1, arg.c_grid_desc_m_n);
}
// polymorphic
......
......@@ -17,41 +17,19 @@
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
bool HasMainKBlockLoop>
template <typename GridwiseGemm, typename Argument, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N c_grid_desc_m_n)
kernel_gemm_xdl_cshuffle_v1(const Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m_n);
GridwiseGemm::template Run<HasMainKBlockLoop>(karg, p_shared);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_m_n;
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
......@@ -322,15 +300,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using Block2CTileMap = remove_cvref_t<decltype(MakeBlock2CTileMap(CGridDesc_M_N{}))>;
template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N c_grid_desc_m_n)
template <bool HasMainKBlockLoop, typename Argument>
__device__ static void Run(const Argument karg, void* __restrict__ p_shared)
{
const FloatAB* p_a_grid = karg.p_a_grid;
const FloatAB* p_b_grid = karg.p_b_grid;
FloatC* p_c_grid = karg.p_c_grid;
const auto& a_grid_desc_ak0_m_ak1 = karg.a_grid_desc_ak0_m_ak1;
const auto& b_grid_desc_bk0_n_bk1 = karg.b_grid_desc_bk0_n_bk1;
const auto& c_grid_desc_m_n = karg.c_grid_desc_m_n;
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment