"...git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "cc179404dde0c2787c6e108a12745d7c8f1a1dcc"
Commit 5decb4a7 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Merge branch 'feature/integrage-karg-simplification-pr' into...

Merge branch 'feature/integrage-karg-simplification-pr' into feature/simplify-karg-for-device-gemm-xdl
parents 7c2b82ca 64b9b6a0
...@@ -168,9 +168,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -168,9 +168,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
// Argument // Argument
struct Argument : public GridwiseGemm::Argument struct Argument : public tensor_operation::device::BaseArgument, public GridwiseGemm::Problem
{ {
using Parent = typename GridwiseGemm::Argument; using Problem = typename GridwiseGemm::Problem;
Argument(const ADataType* p_a_grid_real_, Argument(const ADataType* p_a_grid_real_,
const ADataType* p_a_grid_imag_, const ADataType* p_a_grid_imag_,
...@@ -185,7 +185,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -185,7 +185,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
index_t StrideA_, index_t StrideA_,
index_t StrideB_, index_t StrideB_,
index_t StrideC_) index_t StrideC_)
: Parent(M_, N_, K_, StrideA_, StrideB_, StrideC_), : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
p_a_grid_real{p_a_grid_real_}, p_a_grid_real{p_a_grid_real_},
p_a_grid_imag{p_a_grid_imag_}, p_a_grid_imag{p_a_grid_imag_},
p_b_grid_real{p_b_grid_real_}, p_b_grid_real{p_b_grid_real_},
...@@ -225,22 +225,22 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -225,22 +225,22 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(stream_config.log_level_ > 0) if(stream_config.log_level_ > 0)
{ {
karg.Print(); arg.Print();
} }
if(!GridwiseGemm::CheckValidity(karg)) if(!GridwiseGemm::CheckValidity(arg))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
} }
index_t gdx, gdy, gdz; index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M, karg.N); std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
const auto K = GridwiseGemm::CalculateAK0(karg.K) * AK1; const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
float ave_time = 0; float ave_time = 0;
...@@ -284,27 +284,28 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -284,27 +284,28 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, true>; const auto kernel =
kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, ADataType, CDataType, true>;
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(gdx, gdy, gdz), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
karg.p_a_grid_real, arg.p_a_grid_real,
karg.p_b_grid_real, arg.p_b_grid_real,
karg.p_aux_grid, arg.p_aux_grid,
karg); arg);
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(gdx, gdy, gdz), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
karg.p_a_grid_imag, arg.p_a_grid_imag,
karg.p_b_grid_imag, arg.p_b_grid_imag,
karg.p_aux_2_grid, arg.p_aux_2_grid,
karg); arg);
// c_real = aux - aux_2 // c_real = aux - aux_2
ave_time += launch_and_time_kernel( ave_time += launch_and_time_kernel(
...@@ -313,11 +314,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -313,11 +314,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m), make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
make_tuple(karg.c_grid_desc_m), make_tuple(arg.c_grid_desc_m),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid), make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
const_cast<const CDataType*>(karg.p_aux_2_grid)), const_cast<const CDataType*>(arg.p_aux_2_grid)),
make_tuple(karg.p_c_grid_real), make_tuple(arg.p_c_grid_real),
Subtract{}); Subtract{});
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(stream_config,
...@@ -325,20 +326,20 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -325,20 +326,20 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
karg.p_a_grid_real, arg.p_a_grid_real,
karg.p_b_grid_imag, arg.p_b_grid_imag,
karg.p_aux_grid, arg.p_aux_grid,
karg); arg);
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(gdx, gdy, gdz), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
karg.p_a_grid_imag, arg.p_a_grid_imag,
karg.p_b_grid_real, arg.p_b_grid_real,
karg.p_aux_2_grid, arg.p_aux_2_grid,
karg); arg);
// c_imag = aux + aux_2 // c_imag = aux + aux_2
ave_time += launch_and_time_kernel( ave_time += launch_and_time_kernel(
...@@ -347,36 +348,37 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -347,36 +348,37 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m), make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
make_tuple(karg.c_grid_desc_m), make_tuple(arg.c_grid_desc_m),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid), make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
const_cast<const CDataType*>(karg.p_aux_2_grid)), const_cast<const CDataType*>(arg.p_aux_2_grid)),
make_tuple(karg.p_c_grid_imag), make_tuple(arg.p_c_grid_imag),
Add{}); Add{});
} }
else else
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, false>; const auto kernel =
kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, ADataType, CDataType, false>;
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(gdx, gdy, gdz), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
karg.p_a_grid_real, arg.p_a_grid_real,
karg.p_b_grid_real, arg.p_b_grid_real,
karg.p_aux_grid, arg.p_aux_grid,
karg); arg);
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(gdx, gdy, gdz), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
karg.p_a_grid_imag, arg.p_a_grid_imag,
karg.p_b_grid_imag, arg.p_b_grid_imag,
karg.p_aux_2_grid, arg.p_aux_2_grid,
karg); arg);
// c_real = aux - aux_2 // c_real = aux - aux_2
ave_time += launch_and_time_kernel( ave_time += launch_and_time_kernel(
...@@ -385,11 +387,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -385,11 +387,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m), make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
make_tuple(karg.c_grid_desc_m), make_tuple(arg.c_grid_desc_m),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid), make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
const_cast<const CDataType*>(karg.p_aux_2_grid)), const_cast<const CDataType*>(arg.p_aux_2_grid)),
make_tuple(karg.p_c_grid_real), make_tuple(arg.p_c_grid_real),
Subtract{}); Subtract{});
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(stream_config,
...@@ -397,20 +399,20 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -397,20 +399,20 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
karg.p_a_grid_real, arg.p_a_grid_real,
karg.p_b_grid_imag, arg.p_b_grid_imag,
karg.p_aux_grid, arg.p_aux_grid,
karg); arg);
ave_time += launch_and_time_kernel(stream_config, ave_time += launch_and_time_kernel(stream_config,
kernel, kernel,
dim3(gdx, gdy, gdz), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
karg.p_a_grid_imag, arg.p_a_grid_imag,
karg.p_b_grid_real, arg.p_b_grid_real,
karg.p_aux_2_grid, arg.p_aux_2_grid,
karg); arg);
// c_imag = aux + aux_2 // c_imag = aux + aux_2
ave_time += launch_and_time_kernel( ave_time += launch_and_time_kernel(
...@@ -419,11 +421,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -419,11 +421,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz), dim3(gdx, gdy, gdz),
dim3(BlockSize), dim3(BlockSize),
0, 0,
make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m), make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m),
make_tuple(karg.c_grid_desc_m), make_tuple(arg.c_grid_desc_m),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid), make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
const_cast<const CDataType*>(karg.p_aux_2_grid)), const_cast<const CDataType*>(arg.p_aux_2_grid)),
make_tuple(karg.p_c_grid_imag), make_tuple(arg.p_c_grid_imag),
Add{}); Add{});
} }
...@@ -444,9 +446,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -444,9 +446,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
return true; return true;
} }
static bool IsSupportedArgument(const Argument& karg) static bool IsSupportedArgument(const Argument& arg)
{ {
return GridwiseGemm::CheckValidity(karg); return GridwiseGemm::CheckValidity(arg);
} }
// polymorphic // polymorphic
......
...@@ -130,80 +130,43 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -130,80 +130,43 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
LoopSched, LoopSched,
PipelineVer>; PipelineVer>;
struct Argument : public GridwiseGemm::Argument using Argument = typename GridwiseGemm::Argument;
{
using Parent = typename GridwiseGemm::Argument;
Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: Parent(M_, N_, K_, StrideA_, StrideB_, StrideC_),
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_}
{
}
const ADataType* p_a_grid;
const BDataType* p_b_grid;
CDataType* p_c_grid;
};
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
if(stream_config.log_level_ > 0) if(stream_config.log_level_ > 0)
{ {
karg.Print(); arg.Print();
} }
if(!GridwiseGemm::CheckValidity(karg)) if(!GridwiseGemm::CheckValidity(arg))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
} }
index_t gdx, gdy, gdz; index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg.M, karg.N); std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
const auto K = GridwiseGemm::CalculateAK0(karg.K) * AK1; const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
float ave_time = 0; float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, true>; const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, true>;
ave_time = launch_and_time_kernel(stream_config, ave_time = launch_and_time_kernel(
kernel, stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
karg);
} }
else else
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, false>; const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, false>;
ave_time = launch_and_time_kernel(stream_config, ave_time = launch_and_time_kernel(
kernel, stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid,
karg.p_b_grid,
karg.p_c_grid,
karg);
} }
return ave_time; return ave_time;
...@@ -223,9 +186,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout, ...@@ -223,9 +186,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
return true; return true;
} }
static bool IsSupportedArgument(const Argument& karg) static bool IsSupportedArgument(const Argument& arg)
{ {
return GridwiseGemm::CheckValidity(karg); return GridwiseGemm::CheckValidity(arg);
} }
// polymorphic // polymorphic
......
...@@ -25,32 +25,49 @@ __global__ void ...@@ -25,32 +25,49 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_gemm_xdl_cshuffle_v1_simplified( kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
const typename GridwiseGemm::FloatAB* __restrict__ p_a_grid,
const typename GridwiseGemm::FloatAB* __restrict__ p_b_grid,
typename GridwiseGemm::FloatC* __restrict__ p_c_grid,
typename GridwiseGemm::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__)) defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, karg); GridwiseGemm::template Run<HasMainKBlockLoop>(
karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
#else
ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename GridwiseGemm, typename FloatAB, typename FloatC, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdl_cshuffle_v2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
typename GridwiseGemm::Problem problem)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
#else #else
ignore = p_a_grid; ignore = p_a_grid;
ignore = p_b_grid; ignore = p_b_grid;
ignore = p_c_grid; ignore = p_c_grid;
ignore = karg; ignore = problem;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
} }
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
typename CLayout, typename CLayout,
typename FloatAB_, typename FloatAB,
typename FloatGemmAcc, typename FloatGemmAcc,
typename FloatCShuffle, typename FloatCShuffle,
typename FloatC_, typename FloatC,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
...@@ -106,9 +123,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -106,9 +123,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static constexpr auto AK1Number = Number<AK1Value>{}; static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{}; static constexpr auto BK1Number = Number<BK1Value>{};
using FloatAB = FloatAB_;
using FloatC = FloatC_;
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ static auto CalculateGridSize(index_t M, index_t N) __host__ static auto CalculateGridSize(index_t M, index_t N)
...@@ -392,10 +406,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -392,10 +406,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
} }
} }
// Argument struct Problem
struct Argument : public tensor_operation::device::BaseArgument
{ {
__host__ Argument(index_t M_, __host__ Problem(index_t M_,
index_t N_, index_t N_,
index_t K_, index_t K_,
index_t StrideA_, index_t StrideA_,
...@@ -419,7 +432,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -419,7 +432,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__host__ void Print() const __host__ void Print() const
{ {
std::cout << "arg {" std::cout << "problem {"
<< "M:" << M << ", " << "M:" << M << ", "
<< "N:" << N << ", " << "N:" << N << ", "
<< "K:" << K << ", " << "K:" << K << ", "
...@@ -450,6 +463,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -450,6 +463,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
index_t NBlock; index_t NBlock;
}; };
// Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem
{
__host__ Argument(const FloatAB* p_a_grid_,
const FloatAB* p_b_grid_,
FloatC* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_}
{
}
const FloatAB* p_a_grid;
const FloatAB* p_b_grid;
FloatC* p_c_grid;
};
// FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
...@@ -513,7 +550,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -513,7 +550,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ static constexpr bool CheckValidity(const Argument& karg) __host__ static constexpr bool CheckValidity(const Problem& problem)
{ {
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0, (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
...@@ -524,7 +561,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -524,7 +561,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{ {
if(!(karg.M % MPerBlock == 0)) if(!(problem.M % MPerBlock == 0))
{ {
return false; return false;
} }
...@@ -535,7 +572,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -535,7 +572,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{ {
if(!(karg.N % NPerBlock == 0)) if(!(problem.N % NPerBlock == 0))
{ {
return false; return false;
} }
...@@ -546,15 +583,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -546,15 +583,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding) GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding)
{ {
if(!(CalculateKPadded(karg.K) % AK1Value == 0) || if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
!(CalculateKPadded(karg.K) % BK1Value == 0)) !(CalculateKPadded(problem.K) % BK1Value == 0))
{ {
return false; return false;
} }
} }
else else
{ {
if(!(karg.K % AK1Value == 0) || !(karg.K % BK1Value == 0)) if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
{ {
return false; return false;
} }
...@@ -562,14 +599,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -562,14 +599,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{ {
if(karg.K % ABlockTransferSrcScalarPerVector != 0) if(problem.K % ABlockTransferSrcScalarPerVector != 0)
{ {
return false; return false;
} }
} }
else else
{ {
if(karg.M % ABlockTransferSrcScalarPerVector != 0) if(problem.M % ABlockTransferSrcScalarPerVector != 0)
{ {
return false; return false;
} }
...@@ -577,14 +614,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -577,14 +614,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
if(karg.N % BBlockTransferSrcScalarPerVector != 0) if(problem.N % BBlockTransferSrcScalarPerVector != 0)
{ {
return false; return false;
} }
} }
else else
{ {
if(karg.K % BBlockTransferSrcScalarPerVector != 0) if(problem.K % BBlockTransferSrcScalarPerVector != 0)
{ {
return false; return false;
} }
...@@ -592,21 +629,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -592,21 +629,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{ {
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{ {
return false; return false;
} }
} }
else else
{ {
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{ {
return false; return false;
} }
} }
// check gridwise gemm pipeline // check gridwise gemm pipeline
const auto num_k_loop = (CalculateAK0(karg.K) * AK1Value) / KPerBlock; const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop)) if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{ {
...@@ -646,7 +683,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -646,7 +683,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
void* __restrict__ p_shared, void* __restrict__ p_shared,
const Argument& karg) const Problem& problem)
{ {
#if ENABLE_DUMP_CLOCK #if ENABLE_DUMP_CLOCK
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -656,15 +693,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -656,15 +693,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
#endif #endif
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
karg.M, karg.MPadded, karg.K, karg.KPadded, karg.StrideA, karg.AK0); problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
karg.K, karg.KPadded, karg.N, karg.NPadded, karg.StrideB, karg.BK0); problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC); problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, karg.MBlock, karg.NBlock); c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
...@@ -678,7 +715,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -678,7 +715,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const CElementwiseOperation c_element_op{}; const CElementwiseOperation c_element_op{};
// divide block work by [M, N] // divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N}; const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N};
const auto block_work_idx = const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment