"...composable_kernel.git" did not exist on "ad09ebdb531285c35f7c45be68db7fd52b5dc082"
Commit ef5afc55 authored by Po-Yen, Chen
Browse files

Reserve kernel arg as whole object in interfaces

parent 613dcc6b
...@@ -18,7 +18,7 @@ struct BaseArgument ...@@ -18,7 +18,7 @@ struct BaseArgument
BaseArgument(const BaseArgument&) = default; BaseArgument(const BaseArgument&) = default;
BaseArgument& operator=(const BaseArgument&) = default; BaseArgument& operator=(const BaseArgument&) = default;
virtual ~BaseArgument() {} __host__ __device__ virtual ~BaseArgument() {}
void* p_workspace_ = nullptr; void* p_workspace_ = nullptr;
}; };
......
...@@ -348,54 +348,58 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -348,54 +348,58 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const ADataType* p_a_grid, __host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid, const BDataType* p_b_grid_,
CDataType* p_c_grid, CDataType* p_c_grid_,
index_t M, index_t M_,
index_t N, index_t N_,
index_t K, index_t K_,
index_t StrideA, index_t StrideA_,
index_t StrideB, index_t StrideB_,
index_t StrideC) index_t StrideC_)
: p_a_grid_{p_a_grid}, : p_a_grid{p_a_grid_},
p_b_grid_{p_b_grid}, p_b_grid{p_b_grid_},
p_c_grid_{p_c_grid}, p_c_grid{p_c_grid_},
M_{M}, M{M_},
N_{N}, N{N_},
K_{K}, K{K_},
a_grid_desc_ak0_m_ak1_{ a_grid_desc_ak0_m_ak1{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(M, DeviceOp::MakeAGridDescriptor_AK0_M_AK1(M_,
GridwiseGemm::CalculateMPadded(M), GridwiseGemm::CalculateMPadded(M_),
K, K_,
GridwiseGemm::CalculateKPadded(K), GridwiseGemm::CalculateKPadded(K_),
StrideA, StrideA_,
GridwiseGemm::CalculateAK0(K))}, GridwiseGemm::CalculateAK0(K_))},
b_grid_desc_bk0_n_bk1_{ b_grid_desc_bk0_n_bk1{
DeviceOp::MakeBGridDescriptor_BK0_N_BK1(K, DeviceOp::MakeBGridDescriptor_BK0_N_BK1(K_,
GridwiseGemm::CalculateKPadded(K), GridwiseGemm::CalculateKPadded(K_),
N, N_,
GridwiseGemm::CalculateNPadded(N), GridwiseGemm::CalculateNPadded(N_),
StrideB, StrideB_,
GridwiseGemm::CalculateBK0(K))}, GridwiseGemm::CalculateBK0(K_))},
c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(M, c_grid_desc_m_n{DeviceOp::MakeCGridDescriptor_M_N(M_,
GridwiseGemm::CalculateMPadded(M), GridwiseGemm::CalculateMPadded(M_),
N, N_,
GridwiseGemm::CalculateNPadded(N), GridwiseGemm::CalculateNPadded(N_),
StrideC)}, StrideC_)},
kraw_{K} kraw_{K_}
{ {
} }
__host__ __device__ Argument(const Argument&) = default;
__host__ __device__ ~Argument() override {}
// private: // private:
const ADataType* p_a_grid_; const ADataType* p_a_grid;
const BDataType* p_b_grid_; const BDataType* p_b_grid;
CDataType* p_c_grid_; CDataType* p_c_grid;
index_t M_; index_t M;
index_t N_; index_t N;
index_t K_; index_t K;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n;
index_t kraw_; index_t kraw_;
}; };
...@@ -408,78 +412,47 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -408,78 +412,47 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{ {
#if DEBUG_LOG #if DEBUG_LOG
{ {
std::cout << "arg.a_grid_desc_ak0_m_ak1_{" // std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< karg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " // << karg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< karg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " // << karg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< karg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; // << karg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{" // std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< karg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " // << karg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< karg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " // << karg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< karg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; // << karg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " // std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ",
<< karg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; // "
// << karg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
#endif #endif
if(!GridwiseGemm::CheckValidity( if(!GridwiseGemm::CheckValidity(
karg.a_grid_desc_ak0_m_ak1_, karg.b_grid_desc_bk0_n_bk1_, karg.c_grid_desc_m_n_)) karg.a_grid_desc_ak0_m_ak1, karg.b_grid_desc_bk0_n_bk1, karg.c_grid_desc_m_n))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
} }
index_t gdx, gdy, gdz; index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M_, karg.N_); std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
const auto K = GridwiseGemm::CalculateAK0(karg.K_) * AK1; const auto K = GridwiseGemm::CalculateAK0(karg.K) * AK1;
float ave_time = 0; float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, Argument, true>;
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ave_time = launch_and_time_kernel(
CDataType, stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::CGridDesc_M_N,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_,
karg.p_b_grid_,
karg.p_c_grid_,
karg.a_grid_desc_ak0_m_ak1_,
karg.b_grid_desc_bk0_n_bk1_,
karg.c_grid_desc_m_n_);
} }
else else
{ {
const auto kernel = const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, Argument, false>;
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, ave_time = launch_and_time_kernel(
ADataType, // TODO: distiguish A/B datatype stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
CDataType,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::CGridDesc_M_N,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_,
karg.p_b_grid_,
karg.p_c_grid_,
karg.a_grid_desc_ak0_m_ak1_,
karg.b_grid_desc_bk0_n_bk1_,
karg.c_grid_desc_m_n_);
} }
return ave_time; return ave_time;
...@@ -516,7 +489,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -516,7 +489,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
} }
return GridwiseGemm::CheckValidity( return GridwiseGemm::CheckValidity(
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_); arg.a_grid_desc_ak0_m_ak1, arg.b_grid_desc_bk0_n_bk1, arg.c_grid_desc_m_n);
} }
// polymorphic // polymorphic
......
...@@ -17,41 +17,19 @@ ...@@ -17,41 +17,19 @@
namespace ck { namespace ck {
template <typename GridwiseGemm, template <typename GridwiseGemm, typename Argument, bool HasMainKBlockLoop>
typename FloatAB,
typename FloatC,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, kernel_gemm_xdl_cshuffle_v1(const Argument karg)
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N c_grid_desc_m_n)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(karg, p_shared);
p_b_grid,
p_c_grid,
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_m_n);
#else #else
ignore = p_a_grid; ignore = karg;
ignore = p_b_grid;
ignore = p_c_grid;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = c_grid_desc_m_n;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
...@@ -322,15 +300,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -322,15 +300,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using Block2CTileMap = remove_cvref_t<decltype(MakeBlock2CTileMap(CGridDesc_M_N{}))>; using Block2CTileMap = remove_cvref_t<decltype(MakeBlock2CTileMap(CGridDesc_M_N{}))>;
template <bool HasMainKBlockLoop> template <bool HasMainKBlockLoop, typename Argument>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const Argument karg, void* __restrict__ p_shared)
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const CGridDesc_M_N c_grid_desc_m_n)
{ {
const FloatAB* p_a_grid = karg.p_a_grid;
const FloatAB* p_b_grid = karg.p_b_grid;
FloatC* p_c_grid = karg.p_c_grid;
const auto& a_grid_desc_ak0_m_ak1 = karg.a_grid_desc_ak0_m_ak1;
const auto& b_grid_desc_bk0_n_bk1 = karg.b_grid_desc_bk0_n_bk1;
const auto& c_grid_desc_m_n = karg.c_grid_desc_m_n;
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment