Commit 9acad4f4 authored by Po-Yen, Chen

Merge branch 'feature/integrage-karg-simplification-pr' into feature/simplify-karg-for-device-gemm-xdl
parents 312c581b 6570ef7a
@@ -18,7 +18,7 @@ struct BaseArgument
BaseArgument(const BaseArgument&) = default;
BaseArgument& operator=(const BaseArgument&) = default;
__host__ __device__ virtual ~BaseArgument() {}
virtual ~BaseArgument() {}
void* p_workspace_ = nullptr;
};
@@ -172,12 +172,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
{
using Parent = typename GridwiseGemm::Argument;
Argument(const ADataType* p_a_grid_real,
const ADataType* p_a_grid_imag,
const BDataType* p_b_grid_real,
const BDataType* p_b_grid_imag,
CDataType* p_c_grid_real,
CDataType* p_c_grid_imag,
Argument(const ADataType* p_a_grid_real_,
const ADataType* p_a_grid_imag_,
const BDataType* p_b_grid_real_,
const BDataType* p_b_grid_imag_,
CDataType* p_c_grid_real_,
CDataType* p_c_grid_imag_,
CDataType* p_workspace,
index_t M_,
index_t N_,
@@ -185,63 +185,51 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: Parent(M_,
N_,
K_,
StrideA_,
StrideB_,
StrideC_,
GridwiseGemm::CalculateMPadded(M_),
GridwiseGemm::CalculateNPadded(N_),
GridwiseGemm::CalculateKPadded(K_),
GridwiseGemm::CalculateAK0(K_),
GridwiseGemm::CalculateBK0(K_)),
p_a_grid_real_{p_a_grid_real},
p_a_grid_imag_{p_a_grid_imag},
p_b_grid_real_{p_b_grid_real},
p_b_grid_imag_{p_b_grid_imag},
p_c_grid_real_{p_c_grid_real},
p_c_grid_imag_{p_c_grid_imag},
p_aux_grid_{p_workspace}
: Parent(M_, N_, K_, StrideA_, StrideB_, StrideC_),
p_a_grid_real{p_a_grid_real_},
p_a_grid_imag{p_a_grid_imag_},
p_b_grid_real{p_b_grid_real_},
p_b_grid_imag{p_b_grid_imag_},
p_c_grid_real{p_c_grid_real_},
p_c_grid_imag{p_c_grid_imag_},
p_aux_grid{p_workspace}
{
const index_t grid_size = std::get<0>(GridwiseGemm::CalculateGridSize(M_, N_));
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
c_grid_desc_m_ =
c_grid_desc_m =
DeviceOp::MakeDescriptor_M({M_, N_}, {StrideC_, I1}, grid_size, BlockSize);
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
c_grid_desc_m_ =
c_grid_desc_m =
DeviceOp::MakeDescriptor_M({M_, N_}, {I1, StrideC_}, grid_size, BlockSize);
}
p_aux_2_grid_ = p_workspace + Parent::c_grid_desc_m_n.GetElementSpaceSize();
p_aux_2_grid = p_workspace + GetCElementSpaceSize(M_, N_, StrideC_);
}
// private:
const ADataType* p_a_grid_real_;
const ADataType* p_a_grid_imag_;
const BDataType* p_b_grid_real_;
const BDataType* p_b_grid_imag_;
CDataType* p_c_grid_real_;
CDataType* p_c_grid_imag_;
CDataType* p_aux_grid_;
CDataType* p_aux_2_grid_;
CGridDesc_M c_grid_desc_m_;
const ADataType* p_a_grid_real;
const ADataType* p_a_grid_imag;
const BDataType* p_b_grid_real;
const BDataType* p_b_grid_imag;
CDataType* p_c_grid_real;
CDataType* p_c_grid_imag;
CDataType* p_aux_grid;
CDataType* p_aux_2_grid;
CGridDesc_M c_grid_desc_m;
};
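// The four kernel launches in Run() below implement the standard complex
// multiplication identity, with p_aux_grid and p_aux_2_grid holding the two
// partial products that the elementwise kernels then combine:
//
//   c_real = a_real * b_real - a_imag * b_imag   // Subtract{}
//   c_imag = a_real * b_imag + a_imag * b_real   // Add{}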
// Invoker
struct Invoker : public BaseInvoker
{
void Print(const Argument& karg) { karg.Print(); }
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
Print(karg);
karg.Print();
}
if(!GridwiseGemm::CheckValidity(karg))
@@ -303,9 +291,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_real_,
karg.p_b_grid_real_,
karg.p_aux_grid_,
karg.p_a_grid_real,
karg.p_b_grid_real,
karg.p_aux_grid,
karg);
ave_time += launch_and_time_kernel(stream_config,
@@ -313,9 +301,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_imag_,
karg.p_b_grid_imag_,
karg.p_aux_2_grid_,
karg.p_a_grid_imag,
karg.p_b_grid_imag,
karg.p_aux_2_grid,
karg);
// c_real = aux - aux_2
@@ -325,11 +313,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(karg.c_grid_desc_m_, karg.c_grid_desc_m_),
make_tuple(karg.c_grid_desc_m_),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid_),
const_cast<const CDataType*>(karg.p_aux_2_grid_)),
make_tuple(karg.p_c_grid_real_),
make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m),
make_tuple(karg.c_grid_desc_m),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid),
const_cast<const CDataType*>(karg.p_aux_2_grid)),
make_tuple(karg.p_c_grid_real),
Subtract{});
ave_time += launch_and_time_kernel(stream_config,
@@ -337,9 +325,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_real_,
karg.p_b_grid_imag_,
karg.p_aux_grid_,
karg.p_a_grid_real,
karg.p_b_grid_imag,
karg.p_aux_grid,
karg);
ave_time += launch_and_time_kernel(stream_config,
@@ -347,9 +335,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_imag_,
karg.p_b_grid_real_,
karg.p_aux_2_grid_,
karg.p_a_grid_imag,
karg.p_b_grid_real,
karg.p_aux_2_grid,
karg);
// c_imag = aux + aux_2
@@ -359,11 +347,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(karg.c_grid_desc_m_, karg.c_grid_desc_m_),
make_tuple(karg.c_grid_desc_m_),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid_),
const_cast<const CDataType*>(karg.p_aux_2_grid_)),
make_tuple(karg.p_c_grid_imag_),
make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m),
make_tuple(karg.c_grid_desc_m),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid),
const_cast<const CDataType*>(karg.p_aux_2_grid)),
make_tuple(karg.p_c_grid_imag),
Add{});
}
else
@@ -375,9 +363,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_real_,
karg.p_b_grid_real_,
karg.p_aux_grid_,
karg.p_a_grid_real,
karg.p_b_grid_real,
karg.p_aux_grid,
karg);
ave_time += launch_and_time_kernel(stream_config,
@@ -385,9 +373,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_imag_,
karg.p_b_grid_imag_,
karg.p_aux_2_grid_,
karg.p_a_grid_imag,
karg.p_b_grid_imag,
karg.p_aux_2_grid,
karg);
// c_real = aux - aux_2
@@ -397,11 +385,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(karg.c_grid_desc_m_, karg.c_grid_desc_m_),
make_tuple(karg.c_grid_desc_m_),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid_),
const_cast<const CDataType*>(karg.p_aux_2_grid_)),
make_tuple(karg.p_c_grid_real_),
make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m),
make_tuple(karg.c_grid_desc_m),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid),
const_cast<const CDataType*>(karg.p_aux_2_grid)),
make_tuple(karg.p_c_grid_real),
Subtract{});
ave_time += launch_and_time_kernel(stream_config,
@@ -409,9 +397,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_real_,
karg.p_b_grid_imag_,
karg.p_aux_grid_,
karg.p_a_grid_real,
karg.p_b_grid_imag,
karg.p_aux_grid,
karg);
ave_time += launch_and_time_kernel(stream_config,
@@ -419,9 +407,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
karg.p_a_grid_imag_,
karg.p_b_grid_real_,
karg.p_aux_2_grid_,
karg.p_a_grid_imag,
karg.p_b_grid_real,
karg.p_aux_2_grid,
karg);
// c_imag = aux + aux_2
@@ -431,11 +419,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
make_tuple(karg.c_grid_desc_m_, karg.c_grid_desc_m_),
make_tuple(karg.c_grid_desc_m_),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid_),
const_cast<const CDataType*>(karg.p_aux_2_grid_)),
make_tuple(karg.p_c_grid_imag_),
make_tuple(karg.c_grid_desc_m, karg.c_grid_desc_m),
make_tuple(karg.c_grid_desc_m),
make_tuple(const_cast<const CDataType*>(karg.p_aux_grid),
const_cast<const CDataType*>(karg.p_aux_2_grid)),
make_tuple(karg.p_c_grid_imag),
Add{});
}
@@ -561,6 +549,14 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
return str.str();
}
static std::size_t GetCElementSpaceSize(index_t M, index_t N, index_t StrideC)
{
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(
M, GridwiseGemm::CalculateMPadded(M), N, GridwiseGemm::CalculateNPadded(N), StrideC);
return c_grid_desc_m_n.GetElementSpaceSize();
}
std::size_t GetWorkspaceSize(index_t M,
index_t N,
[[maybe_unused]] index_t K,
@@ -568,10 +564,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
[[maybe_unused]] index_t StrideB,
index_t StrideC) override
{
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(
M, GridwiseGemm::CalculateMPadded(M), N, GridwiseGemm::CalculateNPadded(N), StrideC);
return 2 * sizeof(CDataType) * c_grid_desc_m_n.GetElementSpaceSize();
return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
}
};
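// Caller-side sketch (illustrative only: device_op, M, N, K and the strides
// are hypothetical names, and the hipMalloc/hip_check_error usage is an
// assumption, not part of this diff). The workspace handed to the op must
// cover the two auxiliary C-sized buffers, which is exactly what
// GetWorkspaceSize() reports:
//
//   DeviceOp device_op;
//   void* p_workspace = nullptr;
//   hip_check_error(hipMalloc(&p_workspace,
//       device_op.GetWorkspaceSize(M, N, K, StrideA, StrideB, StrideC)));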
@@ -143,17 +143,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_)
: Parent(M_,
N_,
K_,
StrideA_,
StrideB_,
StrideC_,
GridwiseGemm::CalculateMPadded(M_),
GridwiseGemm::CalculateNPadded(N_),
GridwiseGemm::CalculateKPadded(K_),
GridwiseGemm::CalculateAK0(K_),
GridwiseGemm::CalculateBK0(K_)),
: Parent(M_, N_, K_, StrideA_, StrideB_, StrideC_),
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_}
@@ -168,13 +158,11 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// Invoker
struct Invoker : public BaseInvoker
{
void Print(const Argument& karg) { karg.Print(); }
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
Print(karg);
karg.Print();
}
if(!GridwiseGemm::CheckValidity(karg))
@@ -98,21 +98,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto AK0_c = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0_c = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1_c = Number<AK1Value>{};
static constexpr auto BK1_c = Number<BK1Value>{};
static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
using FloatAB = FloatAB_;
using FloatC = FloatC_;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
#if defined(INTEGER_DIVIDE_CEIL)
#error "macro INTEGER_DIVIDE_CEIL() was already defined somewhere else"
#endif
#define INTEGER_DIVIDE_CEIL(x, y) (((x) + (y)-1) / (y))
__host__ static auto CalculateGridSize(index_t M, index_t N)
{
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
@@ -120,19 +115,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__host__ static auto CalculateMPadded(index_t M)
{
return INTEGER_DIVIDE_CEIL(M, MPerBlock) * MPerBlock;
return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
}
__host__ static auto CalculateNPadded(index_t N)
{
return INTEGER_DIVIDE_CEIL(N, NPerBlock) * NPerBlock;
return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
}
__host__ static auto CalculateKPadded(index_t K)
{
return INTEGER_DIVIDE_CEIL(K, KPerBlock) * KPerBlock;
return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
}
#undef INTEGER_DIVIDE_CEIL
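// For reference, the removed macro and its replacement compute the same
// quantity; a sketch of the presumed utility (the actual definition lives in
// ck's math helpers and may differ in signature):
//
//   template <typename X, typename Y>
//   __host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
//   {
//       return (x + y - 1) / y; // identical to INTEGER_DIVIDE_CEIL(x, y)
//   }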
__host__ static auto CalculateAK0(index_t K)
{
@@ -143,14 +137,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
assert(CalculateKPadded(K) % AK1Value == 0);
return CalculateKPadded(K) / AK1Value;
}
else
{
assert(K % AK1Value == 0);
return K / AK1Value;
}
}
@@ -164,19 +154,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
assert(CalculateKPadded(K) % BK1Value == 0);
return CalculateKPadded(K) / BK1Value;
}
else
{
assert(K % BK1Value == 0);
return K / BK1Value;
}
}
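// Worked example with hypothetical values: for KPerBlock = 32, BK1Value = 8
// and K = 70 under a K-padding specialization, CalculateKPadded(70) = 96, so
// BK0 = 96 / 8 = 12; in the non-padding branch K itself must already be a
// multiple of BK1Value (70 would be rejected by CheckValidity below).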
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
__device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
{
const auto a_grid_desc_mraw_kraw = [&]() {
@@ -258,7 +244,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
}
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{
const auto b_grid_desc_nraw_kraw = [&]() {
@@ -393,10 +379,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1, 1, 1, 1));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1, 1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1));
// Argument
struct Argument : public tensor_operation::device::BaseArgument
{
@@ -405,32 +387,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t MPadded_,
index_t NPadded_,
index_t KPadded_,
index_t AK0_,
index_t BK0_)
index_t StrideC_)
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
MPadded{MPadded_},
NPadded{NPadded_},
KPadded{KPadded_},
AK0{AK0_},
BK0{BK0_},
a_grid_desc_ak0_m_ak1{
MakeAGridDescriptor_AK0_M_AK1(M_, MPadded_, K_, KPadded_, StrideA_, AK0_)},
b_grid_desc_bk0_n_bk1{
MakeBGridDescriptor_BK0_N_BK1(K_, KPadded_, N_, NPadded_, StrideB_, BK0_)},
c_grid_desc_m_n{MakeCGridDescriptor_M_N(M_, MPadded_, N_, NPadded_, StrideC_)}
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
KPadded{CalculateKPadded(K_)},
AK0{CalculateAK0(K_)},
BK0{CalculateBK0(K_)}
{
}
__host__ __device__ void Print() const
__host__ void Print() const
{
std::cout << "arg {"
<< "M:" << M << ", "
@@ -446,10 +418,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<< "BK0:" << BK0 << "}" << std::endl;
}
__host__ __device__ Argument(const Argument&) = default;
__host__ __device__ ~Argument() override {}
index_t M;
index_t N;
index_t K;
@@ -461,33 +429,29 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
index_t KPadded;
index_t AK0;
index_t BK0;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
CGridDesc_M_N c_grid_desc_m_n;
};
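// Construction sketch for the simplified argument (template arguments elided;
// Gemm, M, N, K and the strides are placeholders): the caller passes only the
// raw problem sizes, and the padded sizes plus AK0/BK0 are derived inside.
//
//   using Gemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1</* ... */>;
//   typename Gemm::Argument karg{M, N, K, StrideA, StrideB, StrideC};
//   // karg.MPadded == Gemm::CalculateMPadded(M), karg.AK0 == Gemm::CalculateAK0(K)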
// FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(AK0_c, Number<MPerBlock>{}, AK1_c),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1_c, AK1_c, I1));
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
make_tuple(BK0_c, Number<NPerBlock>{}, BK1_c),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1_c, BK1_c, I1));
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
}
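// Both LDS block descriptors stride the leading (K0) dimension by
// (MPerBlock + ABlockLdsExtraM) * AK1 resp. (NPerBlock + BBlockLdsExtraN) * BK1,
// i.e. each row staged in LDS is padded by the ExtraM/ExtraN elements; the
// presumed intent of that padding is to stagger rows and reduce LDS bank
// conflicts (assumption, not stated in this file).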
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
__device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
@@ -502,14 +466,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1_c, BK1_c);
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
@@ -530,7 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
__host__ static constexpr bool CheckValidity(const Argument& karg)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
@@ -558,6 +522,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
}
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding)
{
if(!(CalculateKPadded(karg.K) % AK1Value == 0) ||
!(CalculateKPadded(karg.K) % BK1Value == 0))
{
return false;
}
}
else
{
if(!(karg.K % AK1Value == 0) || !(karg.K % BK1Value == 0))
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
@@ -615,7 +598,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
@@ -623,7 +606,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
template <typename CGridDesc>
__host__ __device__ static constexpr auto
__device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
@@ -645,33 +628,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// return block_id to C matrix tile idx (m0, n0) mapping
using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
__host__ __device__ static void print_bytes(const uint8_t* memory, std::size_t size)
{
(void)memory;
(void)size;
for(std::size_t idx = 0; idx < size; ++idx)
{
if(idx % 10 == 0)
{
printf("\n");
}
printf("0x%02X ", static_cast<unsigned>(memory[idx]));
}
printf("\n");
}
template <typename T>
__host__ __device__ static void print_bytes(const T& obj)
{
uint8_t memory[sizeof(T)];
memcpy(memory, &obj, sizeof(T));
print_bytes(memory, sizeof(T));
}
template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
@@ -684,22 +640,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
asm volatile("; [POYENC] kernel start" ::);
__builtin_amdgcn_sched_barrier(0);
#define CREATE_DESCS_ON_HOST 1
#if CREATE_DESCS_ON_HOST
const auto a_grid_desc_ak0_m_ak1 = karg.a_grid_desc_ak0_m_ak1;
const auto b_grid_desc_bk0_n_bk1 = karg.b_grid_desc_bk0_n_bk1;
const auto c_grid_desc_m_n = karg.c_grid_desc_m_n;
#else
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
karg.M, karg.MPadded, karg.K, karg.KPadded, karg.StrideA, karg.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
karg.K, karg.KPadded, karg.N, karg.NPadded, karg.StrideB, karg.BK0);
const auto c_grid_desc_m_n =
MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC);
#endif
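// The grid descriptors above are rebuilt on the device from karg's scalar
// fields (sizes, strides, padded sizes and the AK0/BK0 values), so only those
// scalars need to reach the kernel instead of the full descriptor objects.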
// if (blockIdx.x == 0 && threadIdx.x == 0) {
// print_bytes(a_grid_desc_ak0_m_ak1);
// }
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
@@ -737,7 +683,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1_c, BK1_c);
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
@@ -751,7 +697,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0_c, MPerBlock, AK1_c>,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
@@ -782,7 +728,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0_c, NPerBlock, BK1_c>,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
@@ -815,7 +761,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// register
// sanity check
constexpr index_t KPack =
math::max(math::lcm(AK1_c, BK1_c),
math::max(math::lcm(AK1Number, BK1Number),
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
@@ -844,8 +790,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1_c, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1_c, 0, 0);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
// gridwise GEMM pipeline
static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);