Commit 9acad4f4 authored by Po-Yen, Chen

Merge branch 'feature/integrage-karg-simplification-pr' into feature/simplify-karg-for-device-gemm-xdl
parents 312c581b 6570ef7a
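For orientation: the CGEMM device op below computes a complex matrix product with four real GEMM launches and two auxiliary workspace buffers. A minimal host-side sketch of that decomposition, in plain C++ with scalar loops standing in for the XDL kernel launches (real_gemm is a stand-in for illustration, not a CK API):

#include <cstddef>
#include <vector>

// Stand-in for one real GEMM launch: aux = a * b, all dense row-major.
static void real_gemm(const std::vector<float>& a, const std::vector<float>& b,
                      std::vector<float>& aux, std::size_t M, std::size_t N, std::size_t K)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(std::size_t k = 0; k < K; ++k)
                acc += a[m * K + k] * b[k * N + n];
            aux[m * N + n] = acc;
        }
}

// (a_r + i*a_i)(b_r + i*b_i) = (a_r*b_r - a_i*b_i) + i*(a_r*b_i + a_i*b_r):
// the same four launches and two elementwise combines as the Invoker below.
void cgemm_4gemm(const std::vector<float>& a_r, const std::vector<float>& a_i,
                 const std::vector<float>& b_r, const std::vector<float>& b_i,
                 std::vector<float>& c_r, std::vector<float>& c_i,
                 std::size_t M, std::size_t N, std::size_t K)
{
    std::vector<float> aux(M * N), aux_2(M * N); // the two workspace halves

    real_gemm(a_r, b_r, aux, M, N, K);
    real_gemm(a_i, b_i, aux_2, M, N, K);
    for(std::size_t j = 0; j < M * N; ++j) c_r[j] = aux[j] - aux_2[j]; // c_real = aux - aux_2

    real_gemm(a_r, b_i, aux, M, N, K);
    real_gemm(a_i, b_r, aux_2, M, N, K);
    for(std::size_t j = 0; j < M * N; ++j) c_i[j] = aux[j] + aux_2[j]; // c_imag = aux + aux_2
}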
@@ -18,7 +18,7 @@ struct BaseArgument
    BaseArgument(const BaseArgument&) = default;
    BaseArgument& operator=(const BaseArgument&) = default;

-   __host__ __device__ virtual ~BaseArgument() {}
+   virtual ~BaseArgument() {}

    void* p_workspace_ = nullptr;
};
...
@@ -172,12 +172,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
    {
        using Parent = typename GridwiseGemm::Argument;

-       Argument(const ADataType* p_a_grid_real,
-                const ADataType* p_a_grid_imag,
-                const BDataType* p_b_grid_real,
-                const BDataType* p_b_grid_imag,
-                CDataType* p_c_grid_real,
-                CDataType* p_c_grid_imag,
+       Argument(const ADataType* p_a_grid_real_,
+                const ADataType* p_a_grid_imag_,
+                const BDataType* p_b_grid_real_,
+                const BDataType* p_b_grid_imag_,
+                CDataType* p_c_grid_real_,
+                CDataType* p_c_grid_imag_,
                 CDataType* p_workspace,
                 index_t M_,
                 index_t N_,
@@ -185,63 +185,51 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                 index_t StrideA_,
                 index_t StrideB_,
                 index_t StrideC_)
-           : Parent(M_,
-                    N_,
-                    K_,
-                    StrideA_,
-                    StrideB_,
-                    StrideC_,
-                    GridwiseGemm::CalculateMPadded(M_),
-                    GridwiseGemm::CalculateNPadded(N_),
-                    GridwiseGemm::CalculateKPadded(K_),
-                    GridwiseGemm::CalculateAK0(K_),
-                    GridwiseGemm::CalculateBK0(K_)),
-             p_a_grid_real_{p_a_grid_real},
-             p_a_grid_imag_{p_a_grid_imag},
-             p_b_grid_real_{p_b_grid_real},
-             p_b_grid_imag_{p_b_grid_imag},
-             p_c_grid_real_{p_c_grid_real},
-             p_c_grid_imag_{p_c_grid_imag},
-             p_aux_grid_{p_workspace}
+           : Parent(M_, N_, K_, StrideA_, StrideB_, StrideC_),
+             p_a_grid_real{p_a_grid_real_},
+             p_a_grid_imag{p_a_grid_imag_},
+             p_b_grid_real{p_b_grid_real_},
+             p_b_grid_imag{p_b_grid_imag_},
+             p_c_grid_real{p_c_grid_real_},
+             p_c_grid_imag{p_c_grid_imag_},
+             p_aux_grid{p_workspace}
        {
            const index_t grid_size = std::get<1>(GridwiseGemm::CalculateGridSize(M_, N_));

            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
-               c_grid_desc_m_ =
+               c_grid_desc_m =
                    DeviceOp::MakeDescriptor_M({M_, N_}, {StrideC_, I1}, grid_size, BlockSize);
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
-               c_grid_desc_m_ =
+               c_grid_desc_m =
                    DeviceOp::MakeDescriptor_M({M_, N_}, {I1, StrideC_}, grid_size, BlockSize);
            }

-           p_aux_2_grid_ = p_workspace + Parent::c_grid_desc_m_n.GetElementSpaceSize();
+           p_aux_2_grid = p_workspace + GetCElementSpaceSize(M_, N_, StrideC_);
        }

        // private:
-       const ADataType* p_a_grid_real_;
-       const ADataType* p_a_grid_imag_;
-       const BDataType* p_b_grid_real_;
-       const BDataType* p_b_grid_imag_;
-       CDataType* p_c_grid_real_;
-       CDataType* p_c_grid_imag_;
-       CDataType* p_aux_grid_;
-       CDataType* p_aux_2_grid_;
-       CGridDesc_M c_grid_desc_m_;
+       const ADataType* p_a_grid_real;
+       const ADataType* p_a_grid_imag;
+       const BDataType* p_b_grid_real;
+       const BDataType* p_b_grid_imag;
+       CDataType* p_c_grid_real;
+       CDataType* p_c_grid_imag;
+       CDataType* p_aux_grid;
+       CDataType* p_aux_2_grid;
+       CGridDesc_M c_grid_desc_m;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
-       void Print(const Argument& karg) { karg.Print(); }
-
        float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
-               Print(karg);
+               karg.Print();
            }

            if(!GridwiseGemm::CheckValidity(karg))
@@ -303,9 +291,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                dim3(gdx, gdy, gdz),
                                dim3(BlockSize),
                                0,
-                               karg.p_a_grid_real_,
-                               karg.p_b_grid_real_,
-                               karg.p_aux_grid_,
+                               karg.p_a_grid_real,
+                               karg.p_b_grid_real,
+                               karg.p_aux_grid,
                                karg);

                ave_time += launch_and_time_kernel(stream_config,
@@ -313,9 +301,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                dim3(gdx, gdy, gdz),
                                dim3(BlockSize),
                                0,
-                               karg.p_a_grid_imag_,
-                               karg.p_b_grid_imag_,
-                               karg.p_aux_2_grid_,
+                               karg.p_a_grid_imag,
+                               karg.p_b_grid_imag,
+                               karg.p_aux_2_grid,
                                karg);

                // c_real = aux - aux_2
@@ -325,11 +313,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                dim3(gdx, gdy, gdz),
                                dim3(BlockSize),
                                0,
-                               make_tuple(karg.c_grid_desc_m_, karg.c_grid_desc_m_),
-                               make_tuple(karg.c_grid_desc_m_),
-                               make_tuple(const_cast<const CDataType*>(karg.p_aux_grid_),
-                                          const_cast<const CDataType*>(karg.p_aux_2_grid_)),
-                               make_tuple(karg.p_c_grid_real_),
+                               make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m),
+                               make_tuple(karg.c_grid_desc_m),
+                               make_tuple(const_cast<const CDataType*>(karg.p_aux_grid),
+                                          const_cast<const CDataType*>(karg.p_aux_2_grid)),
+                               make_tuple(karg.p_c_grid_real),
                                Subtract{});

                ave_time += launch_and_time_kernel(stream_config,
@@ -337,9 +325,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                dim3(gdx, gdy, gdz),
                                dim3(BlockSize),
                                0,
-                               karg.p_a_grid_real_,
-                               karg.p_b_grid_imag_,
-                               karg.p_aux_grid_,
+                               karg.p_a_grid_real,
+                               karg.p_b_grid_imag,
+                               karg.p_aux_grid,
                                karg);

                ave_time += launch_and_time_kernel(stream_config,
@@ -347,9 +335,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                dim3(gdx, gdy, gdz),
                                dim3(BlockSize),
                                0,
-                               karg.p_a_grid_imag_,
-                               karg.p_b_grid_real_,
-                               karg.p_aux_2_grid_,
+                               karg.p_a_grid_imag,
+                               karg.p_b_grid_real,
+                               karg.p_aux_2_grid,
                                karg);

                // c_imag = aux + aux_2
@@ -359,11 +347,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                dim3(gdx, gdy, gdz),
                                dim3(BlockSize),
                                0,
-                               make_tuple(karg.c_grid_desc_m_, karg.c_grid_desc_m_),
-                               make_tuple(karg.c_grid_desc_m_),
-                               make_tuple(const_cast<const CDataType*>(karg.p_aux_grid_),
-                                          const_cast<const CDataType*>(karg.p_aux_2_grid_)),
-                               make_tuple(karg.p_c_grid_imag_),
+                               make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m),
+                               make_tuple(karg.c_grid_desc_m),
+                               make_tuple(const_cast<const CDataType*>(karg.p_aux_grid),
+                                          const_cast<const CDataType*>(karg.p_aux_2_grid)),
+                               make_tuple(karg.p_c_grid_imag),
                                Add{});
            }
            else
@@ -375,9 +363,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                dim3(gdx, gdy, gdz),
                                dim3(BlockSize),
                                0,
-                               karg.p_a_grid_real_,
-                               karg.p_b_grid_real_,
-                               karg.p_aux_grid_,
+                               karg.p_a_grid_real,
+                               karg.p_b_grid_real,
+                               karg.p_aux_grid,
                                karg);

                ave_time += launch_and_time_kernel(stream_config,
@@ -385,9 +373,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                dim3(gdx, gdy, gdz),
                                dim3(BlockSize),
                                0,
-                               karg.p_a_grid_imag_,
-                               karg.p_b_grid_imag_,
-                               karg.p_aux_2_grid_,
+                               karg.p_a_grid_imag,
+                               karg.p_b_grid_imag,
+                               karg.p_aux_2_grid,
                                karg);

                // c_real = aux - aux_2
@@ -397,11 +385,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                dim3(gdx, gdy, gdz),
                                dim3(BlockSize),
                                0,
-                               make_tuple(karg.c_grid_desc_m_, karg.c_grid_desc_m_),
-                               make_tuple(karg.c_grid_desc_m_),
-                               make_tuple(const_cast<const CDataType*>(karg.p_aux_grid_),
-                                          const_cast<const CDataType*>(karg.p_aux_2_grid_)),
-                               make_tuple(karg.p_c_grid_real_),
+                               make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m),
+                               make_tuple(karg.c_grid_desc_m),
+                               make_tuple(const_cast<const CDataType*>(karg.p_aux_grid),
+                                          const_cast<const CDataType*>(karg.p_aux_2_grid)),
+                               make_tuple(karg.p_c_grid_real),
                                Subtract{});

                ave_time += launch_and_time_kernel(stream_config,
@@ -409,9 +397,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                dim3(gdx, gdy, gdz),
                                dim3(BlockSize),
                                0,
-                               karg.p_a_grid_real_,
-                               karg.p_b_grid_imag_,
-                               karg.p_aux_grid_,
+                               karg.p_a_grid_real,
+                               karg.p_b_grid_imag,
+                               karg.p_aux_grid,
                                karg);

                ave_time += launch_and_time_kernel(stream_config,
@@ -419,9 +407,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                dim3(gdx, gdy, gdz),
                                dim3(BlockSize),
                                0,
-                               karg.p_a_grid_imag_,
-                               karg.p_b_grid_real_,
-                               karg.p_aux_2_grid_,
+                               karg.p_a_grid_imag,
+                               karg.p_b_grid_real,
+                               karg.p_aux_2_grid,
                                karg);

                // c_imag = aux + aux_2
@@ -431,11 +419,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                dim3(gdx, gdy, gdz),
                                dim3(BlockSize),
                                0,
-                               make_tuple(karg.c_grid_desc_m_, karg.c_grid_desc_m_),
-                               make_tuple(karg.c_grid_desc_m_),
-                               make_tuple(const_cast<const CDataType*>(karg.p_aux_grid_),
-                                          const_cast<const CDataType*>(karg.p_aux_2_grid_)),
-                               make_tuple(karg.p_c_grid_imag_),
+                               make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m),
+                               make_tuple(karg.c_grid_desc_m),
+                               make_tuple(const_cast<const CDataType*>(karg.p_aux_grid),
+                                          const_cast<const CDataType*>(karg.p_aux_2_grid)),
+                               make_tuple(karg.p_c_grid_imag),
                                Add{});
            }
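The Subtract{} and Add{} objects passed to the elementwise combine launches above are plain binary functors. A sketch of the shape such a functor takes; the destination-first call signature is an assumption here, and the exact CK element_wise interface may differ:

// Hypothetical minimal functors in the style of the Subtract{} / Add{} arguments above.
struct Subtract
{
    template <typename T>
    __host__ __device__ constexpr void operator()(T& dst, const T& lhs, const T& rhs) const
    {
        dst = lhs - rhs; // c_real = aux - aux_2
    }
};

struct Add
{
    template <typename T>
    __host__ __device__ constexpr void operator()(T& dst, const T& lhs, const T& rhs) const
    {
        dst = lhs + rhs; // c_imag = aux + aux_2
    }
};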
@@ -561,6 +549,14 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
        return str.str();
    }

+   static std::size_t GetCElementSpaceSize(index_t M, index_t N, index_t StrideC)
+   {
+       const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(
+           M, GridwiseGemm::CalculateMPadded(M), N, GridwiseGemm::CalculateNPadded(N), StrideC);
+
+       return c_grid_desc_m_n.GetElementSpaceSize();
+   }
+
    std::size_t GetWorkspaceSize(index_t M,
                                 index_t N,
                                 [[maybe_unused]] index_t K,
@@ -568,10 +564,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                 [[maybe_unused]] index_t StrideB,
                                 index_t StrideC) override
    {
-       const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(
-           M, GridwiseGemm::CalculateMPadded(M), N, GridwiseGemm::CalculateNPadded(N), StrideC);
-
-       return 2 * sizeof(CDataType) * c_grid_desc_m_n.GetElementSpaceSize();
+       return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
    }
};
...
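With the refactor above, the Argument constructor and GetWorkspaceSize now share GetCElementSpaceSize, so the aux-buffer offset and the reported workspace size cannot drift apart. A back-of-envelope sizing sketch, under the simplifying assumption that the padded C descriptor is packed (element space = MPadded * NPadded); pad_to is a stand-in for GridwiseGemm::Calculate{M,N}Padded with 128x128 tiles:

#include <cstddef>
#include <cstdio>

// Assumed stand-in for the padded-extent helpers, 128-wide tiles.
constexpr std::size_t pad_to(std::size_t x, std::size_t per_block)
{
    return (x + per_block - 1) / per_block * per_block;
}

int main()
{
    const std::size_t M = 1000, N = 2000;
    // Packed-layout assumption: element space of the padded C descriptor.
    const std::size_t c_elems = pad_to(M, 128) * pad_to(N, 128); // 1024 * 2048
    // Two aux buffers (aux and aux_2), as in GetWorkspaceSize above.
    const std::size_t bytes = 2 * sizeof(float) * c_elems;
    std::printf("workspace: %zu bytes\n", bytes); // 16777216 = 16 MiB
    return 0;
}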
@@ -143,17 +143,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
                 index_t StrideA_,
                 index_t StrideB_,
                 index_t StrideC_)
-           : Parent(M_,
-                    N_,
-                    K_,
-                    StrideA_,
-                    StrideB_,
-                    StrideC_,
-                    GridwiseGemm::CalculateMPadded(M_),
-                    GridwiseGemm::CalculateNPadded(N_),
-                    GridwiseGemm::CalculateKPadded(K_),
-                    GridwiseGemm::CalculateAK0(K_),
-                    GridwiseGemm::CalculateBK0(K_)),
+           : Parent(M_, N_, K_, StrideA_, StrideB_, StrideC_),
             p_a_grid{p_a_grid_},
             p_b_grid{p_b_grid_},
             p_c_grid{p_c_grid_}
@@ -168,13 +158,11 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
    // Invoker
    struct Invoker : public BaseInvoker
    {
-       void Print(const Argument& karg) { karg.Print(); }
-
        float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
-               Print(karg);
+               karg.Print();
            }

            if(!GridwiseGemm::CheckValidity(karg))
...
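The same karg simplification as in the CGEMM file: callers now forward only the six user-visible GEMM parameters, and every derived quantity is computed once inside the Argument constructor. A condensed, self-contained sketch of the pattern; the Calculate* bodies below are placeholder tile arithmetic, not the real CK helpers:

// Derived quantities move out of every caller's initializer list and
// into the Argument constructor itself.
struct Argument
{
    Argument(int M_, int N_, int K_, int StrideA_, int StrideB_, int StrideC_)
        : M{M_}, N{N_}, K{K_}, StrideA{StrideA_}, StrideB{StrideB_}, StrideC{StrideC_},
          MPadded{CalculateMPadded(M_)}, // was: passed in by each device op
          NPadded{CalculateNPadded(N_)},
          KPadded{CalculateKPadded(K_)},
          AK0{CalculateAK0(K_)},
          BK0{CalculateBK0(K_)}
    {
    }

    // Placeholder pad-to-tile helpers; the real ones live in GridwiseGemm.
    static int CalculateMPadded(int M) { return (M + 127) / 128 * 128; }
    static int CalculateNPadded(int N) { return (N + 127) / 128 * 128; }
    static int CalculateKPadded(int K) { return (K + 31) / 32 * 32; }
    static int CalculateAK0(int K) { return K / 8; }
    static int CalculateBK0(int K) { return K / 8; }

    int M, N, K, StrideA, StrideB, StrideC;
    int MPadded, NPadded, KPadded, AK0, BK0;
};

With this shape, the device op's Parent(...) call shrinks to the six user arguments, and there is a single place where the padding rules can change.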
@@ -98,21 +98,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
-   static constexpr auto AK0_c = Number<KPerBlock / AK1Value>{};
-   static constexpr auto BK0_c = Number<KPerBlock / BK1Value>{};
-   static constexpr auto AK1_c = Number<AK1Value>{};
-   static constexpr auto BK1_c = Number<BK1Value>{};
+   static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
+   static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
+   static constexpr auto AK1Number = Number<AK1Value>{};
+   static constexpr auto BK1Number = Number<BK1Value>{};

    using FloatAB = FloatAB_;
    using FloatC = FloatC_;

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-#if defined(INTEGER_DIVIDE_CEIL)
-#error "macro INTEGER_DIVIDE_CEIL() was already defined somewhere else"
-#endif
-#define INTEGER_DIVIDE_CEIL(x, y) (((x) + (y)-1) / (y))
-
    __host__ static auto CalculateGridSize(index_t M, index_t N)
    {
        return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
@@ -120,19 +115,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
    __host__ static auto CalculateMPadded(index_t M)
    {
-       return INTEGER_DIVIDE_CEIL(M, MPerBlock) * MPerBlock;
+       return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
    }

    __host__ static auto CalculateNPadded(index_t N)
    {
-       return INTEGER_DIVIDE_CEIL(N, NPerBlock) * NPerBlock;
+       return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
    }

    __host__ static auto CalculateKPadded(index_t K)
    {
-       return INTEGER_DIVIDE_CEIL(K, KPerBlock) * KPerBlock;
+       return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
    }

-#undef INTEGER_DIVIDE_CEIL
-
    __host__ static auto CalculateAK0(index_t K)
    {
@@ -143,14 +137,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
           GemmSpec == GemmSpecialization::KPadding ||
           GemmSpec == GemmSpecialization::NKPadding)
        {
-           assert(CalculateKPadded(K) % AK1Value == 0);
-
            return CalculateKPadded(K) / AK1Value;
        }
        else
        {
-           assert(K % AK1Value == 0);
-
            return K / AK1Value;
        }
    }
@@ -164,19 +154,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
           GemmSpec == GemmSpecialization::KPadding ||
           GemmSpec == GemmSpecialization::MKPadding)
        {
-           assert(CalculateKPadded(K) % BK1Value == 0);
-
            return CalculateKPadded(K) / BK1Value;
        }
        else
        {
-           assert(K % BK1Value == 0);
-
            return K / BK1Value;
        }
    }

-   __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
+   __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
        index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
@@ -258,7 +244,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        }
    }

-   __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
+   __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
        index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
@@ -393,10 +379,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        }
    }

-   using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1, 1, 1, 1));
-   using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1, 1, 1, 1));
-   using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1));
-
    // Argument
    struct Argument : public tensor_operation::device::BaseArgument
    {
@@ -405,32 +387,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 index_t K_,
                 index_t StrideA_,
                 index_t StrideB_,
-                index_t StrideC_,
-                index_t MPadded_,
-                index_t NPadded_,
-                index_t KPadded_,
-                index_t AK0_,
-                index_t BK0_)
+                index_t StrideC_)
            : M{M_},
              N{N_},
              K{K_},
              StrideA{StrideA_},
              StrideB{StrideB_},
              StrideC{StrideC_},
-             MPadded{MPadded_},
-             NPadded{NPadded_},
-             KPadded{KPadded_},
-             AK0{AK0_},
-             BK0{BK0_},
-             a_grid_desc_ak0_m_ak1{
-                 MakeAGridDescriptor_AK0_M_AK1(M_, MPadded_, K_, KPadded_, StrideA_, AK0_)},
-             b_grid_desc_bk0_n_bk1{
-                 MakeBGridDescriptor_BK0_N_BK1(K_, KPadded_, N_, NPadded_, StrideB_, BK0_)},
-             c_grid_desc_m_n{MakeCGridDescriptor_M_N(M_, MPadded_, N_, NPadded_, StrideC_)}
+             MPadded{CalculateMPadded(M_)},
+             NPadded{CalculateNPadded(N_)},
+             KPadded{CalculateKPadded(K_)},
+             AK0{CalculateAK0(K_)},
+             BK0{CalculateBK0(K_)}
        {
        }

-       __host__ __device__ void Print() const
+       __host__ void Print() const
        {
            std::cout << "arg {"
                      << "M:" << M << ", "
@@ -446,10 +418,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                      << "BK0:" << BK0 << "}" << std::endl;
        }

-       __host__ __device__ Argument(const Argument&) = default;
-       __host__ __device__ ~Argument() override {}
-
        index_t M;
        index_t N;
        index_t K;
@@ -461,33 +429,29 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        index_t KPadded;
        index_t AK0;
        index_t BK0;
-       AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
-       BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
-       CGridDesc_M_N c_grid_desc_m_n;
    };

    // FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm
    using GridwiseGemmPipe = remove_cvref_t<decltype(
        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

-   __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
+   __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
-           make_tuple(AK0_c, Number<MPerBlock>{}, AK1_c),
-           make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1_c, AK1_c, I1));
+           make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
+           make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
    }

-   __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
+   __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
-           make_tuple(BK0_c, Number<NPerBlock>{}, BK1_c),
-           make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1_c, BK1_c, I1));
+           make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
+           make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
    }

-   __host__ __device__ static constexpr auto
-   GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
+   __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
@@ -502,14 +466,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }

-   __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
+   __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // lds max alignment
-       constexpr auto max_lds_align = math::lcm(AK1_c, BK1_c);
+       constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
@@ -530,7 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
    }

    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
-   __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
+   __host__ static constexpr bool CheckValidity(const Argument& karg)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
@@ -558,6 +522,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            }
        }

+       if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
+                    GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding ||
+                    GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
+                    GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding)
+       {
+           if(!(CalculateKPadded(karg.K) % AK1Value == 0) ||
+              !(CalculateKPadded(karg.K) % BK1Value == 0))
+           {
+               return false;
+           }
+       }
+       else
+       {
+           if(!(karg.K % AK1Value == 0) || !(karg.K % BK1Value == 0))
+           {
+               return false;
+           }
+       }
+
        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
@@ -615,7 +598,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        return true;
    }

-   __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
+   __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;
@@ -623,7 +606,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
    }

    template <typename CGridDesc>
-   __host__ __device__ static constexpr auto
+   __device__ static constexpr auto
    MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
@@ -645,33 +628,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
    // return block_id to C matrix tile idx (m0, n0) mapping
    using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;

-   __host__ __device__ static void print_bytes(const uint8_t* memory, std::size_t size)
-   {
-       (void)memory;
-       (void)size;
-
-       for(std::size_t idx = 0; idx < size; ++idx)
-       {
-           if(idx % 10 == 0)
-           {
-               printf("\n");
-           }
-           printf("0x%02X ", static_cast<unsigned>(memory[idx]));
-       }
-       printf("\n");
-   }
-
-   template <typename T>
-   __host__ __device__ static void print_bytes(const T& obj)
-   {
-       uint8_t memory[sizeof(T)];
-       memcpy(memory, &obj, sizeof(T));
-
-       print_bytes(memory, sizeof(T));
-   }
-
    template <bool HasMainKBlockLoop>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
@@ -684,22 +640,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        asm volatile("; [POYENC] kernel start" ::);
        __builtin_amdgcn_sched_barrier(0);

-#define CREATE_DESCS_ON_HOST 1
-#if CREATE_DESCS_ON_HOST
-       const auto a_grid_desc_ak0_m_ak1 = karg.a_grid_desc_ak0_m_ak1;
-       const auto b_grid_desc_bk0_n_bk1 = karg.b_grid_desc_bk0_n_bk1;
-       const auto c_grid_desc_m_n       = karg.c_grid_desc_m_n;
-#else
        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
            karg.M, karg.MPadded, karg.K, karg.KPadded, karg.StrideA, karg.AK0);
        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
            karg.K, karg.KPadded, karg.N, karg.NPadded, karg.StrideB, karg.BK0);
        const auto c_grid_desc_m_n =
            MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC);
-#endif
-
-       // if (blockIdx.x == 0 && threadIdx.x == 0) {
-       //     print_bytes(a_grid_desc_ak0_m_ak1);
-       // }

        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
@@ -737,7 +683,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        // lds max alignment
-       constexpr auto max_lds_align = math::lcm(AK1_c, BK1_c);
+       constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
@@ -751,7 +697,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            AElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
-           Sequence<AK0_c, MPerBlock, AK1_c>,
+           Sequence<AK0Number, MPerBlock, AK1Number>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
@@ -782,7 +728,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            BElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
-           Sequence<BK0_c, NPerBlock, BK1_c>,
+           Sequence<BK0Number, NPerBlock, BK1Number>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
@@ -815,7 +761,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        // register
        // sanity check
        constexpr index_t KPack =
-           math::max(math::lcm(AK1_c, BK1_c),
+           math::max(math::lcm(AK1Number, BK1Number),
                      MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
@@ -844,8 +790,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

-       constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1_c, 0, 0);
-       constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1_c, 0, 0);
+       constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
+       constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);

        // gridwise GEMM pipeline
        static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
...
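Finally, the helper swap above replaces the removed INTEGER_DIVIDE_CEIL macro with math::integer_divide_ceil. A minimal equivalent plus the padding arithmetic it feeds, as a compile-time sanity check (assuming positive operands, as the macro did):

#include <cassert>

// Minimal equivalent of ck::math::integer_divide_ceil for positive integers,
// matching the removed macro (((x) + (y) - 1) / (y)) without its pitfalls:
// no double evaluation of arguments, no #define leaking out of the header.
constexpr int integer_divide_ceil(int x, int y) { return (x + y - 1) / y; }

constexpr int pad_to_block(int x, int per_block)
{
    return integer_divide_ceil(x, per_block) * per_block;
}

int main()
{
    static_assert(integer_divide_ceil(1000, 128) == 8);
    static_assert(pad_to_block(1000, 128) == 1024); // CalculateMPadded(1000) with MPerBlock = 128
    static_assert(pad_to_block(1024, 128) == 1024); // already aligned: unchanged
    assert(pad_to_block(1, 64) == 64);
    return 0;
}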