Commit 0cf90eaf authored by Po-Yen, Chen

Remove unnecessary type parameters

parent 670ce6b9
@@ -299,6 +299,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
+ ALayout,
+ BLayout,
+ CLayout,
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
@@ -308,9 +311,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CElementwiseOperation,
GemmSpec,
InMemoryDataOperationEnum::Set,
- AGridDesc_AK0_M_AK1,
- BGridDesc_BK0_N_BK1,
- CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
@@ -363,21 +363,30 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
M{M_},
N{N_},
K{K_},
+ StrideA{StrideA_},
+ StrideB{StrideB_},
+ StrideC{StrideC_},
+ MPadded{GridwiseGemm::CalculateMPadded(M_)},
+ NPadded{GridwiseGemm::CalculateNPadded(N_)},
+ KPadded{GridwiseGemm::CalculateKPadded(K_)},
+ AK0{GridwiseGemm::CalculateAK0(K_)},
+ BK0{GridwiseGemm::CalculateBK0(K_)},
a_grid_desc_ak0_m_ak1{
- DeviceOp::MakeAGridDescriptor_AK0_M_AK1(M_,
+ GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(M_,
+ GridwiseGemm::CalculateMPadded(M_),
K_,
+ GridwiseGemm::CalculateKPadded(K_),
StrideA_,
+ GridwiseGemm::CalculateAK0(K_))},
b_grid_desc_bk0_n_bk1{
- DeviceOp::MakeBGridDescriptor_BK0_N_BK1(K_,
+ GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(K_,
+ GridwiseGemm::CalculateKPadded(K_),
N_,
+ GridwiseGemm::CalculateNPadded(N_),
StrideB_,
+ GridwiseGemm::CalculateBK0(K_))},
- c_grid_desc_m_n{DeviceOp::MakeCGridDescriptor_M_N(M_,
+ c_grid_desc_m_n{
+ GridwiseGemm::MakeCGridDescriptor_M_N(M_,
+ GridwiseGemm::CalculateMPadded(M_),
N_,
+ GridwiseGemm::CalculateNPadded(N_),
@@ -385,7 +394,58 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{
}
- __host__ __device__ Argument(const Argument&) = default;
__host__ __device__ void Print() const
{
printf("arg {M: %d, N: %d, K: %d, "
"SA: %d, SB: %d, SC: %d, "
"MP: %d, NP: %d, KP: %d, "
"AK0: %d, BK0: %d}\n",
M,
N,
K,
StrideA,
StrideB,
StrideC,
MPadded,
NPadded,
KPadded,
AK0,
BK0);
// std::cout << "arg {"
// << "M:" << M << ", "
// << "N:" << N << ", "
// << "K:" << K << ", "
// << "SA:" << StrideA << ", "
// << "SB:" << StrideB << ", "
// << "SC:" << StrideC << ", "
// << "MP:" << MPadded << ", "
// << "NP:" << NPadded << ", "
// << "KP:" << KPadded << ", "
// << "AK0:" << AK0 << ", "
// << "BK0:" << BK0 << "}" << std::endl;
}
__host__ __device__ Argument(const Argument& other)
: p_a_grid{other.p_a_grid},
p_b_grid{other.p_b_grid},
p_c_grid{other.p_c_grid},
M{other.M},
N{other.N},
K{other.K},
StrideA{other.StrideA},
StrideB{other.StrideB},
StrideC{other.StrideC},
MPadded{other.MPadded},
NPadded{other.NPadded},
KPadded{other.KPadded},
AK0{other.AK0},
BK0{other.BK0},
a_grid_desc_ak0_m_ak1{other.a_grid_desc_ak0_m_ak1},
b_grid_desc_bk0_n_bk1{other.b_grid_desc_bk0_n_bk1},
c_grid_desc_m_n{other.c_grid_desc_m_n}
{
}
__host__ __device__ ~Argument() override {}
@@ -396,6 +456,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
index_t MPadded;
index_t NPadded;
index_t KPadded;
index_t AK0;
index_t BK0;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
CGridDesc_M_N c_grid_desc_m_n;
@@ -406,28 +474,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{
using Argument = DeviceOp::Argument;
void Print(const Argument& karg) { karg.Print(); }
float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
{
#if DEBUG_LOG
if(stream_config.log_level_ > 0)
{
- // std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
- // << karg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
- // << karg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
- // << karg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
- // std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
- // << karg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
- // << karg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
- // << karg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
- // std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
- // << karg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+ Print(karg);
}
#endif
- if(!GridwiseGemm::CheckValidity(
- karg.a_grid_desc_ak0_m_ak1, karg.b_grid_desc_bk0_n_bk1, karg.c_grid_desc_m_n))
+ if(!GridwiseGemm::CheckValidity(karg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
@@ -441,14 +497,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
- const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, Argument, true>;
+ const auto kernel =
+ kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, Argument, true>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
}
else
{
- const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, Argument, false>;
+ const auto kernel =
+ kernel_gemm_xdl_cshuffle_v1_simplified<GridwiseGemm, Argument, false>;
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
}
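The two kernel instantiations above bake the main-K-loop predicate in as a compile-time template argument and pick one of them at launch time. A minimal standalone sketch of this runtime-to-compile-time dispatch pattern (the names here are illustrative, not part of the diff):

template <bool HasMainKBlockLoop>
void run_gemm_impl() { /* kernel body, compiled separately for each flag value */ }

void dispatch(int num_k_loop)
{
    // resolve the runtime loop count to one of the two compile-time variants once,
    // outside the hot path, so the kernel itself carries no branch on the flag
    if(num_k_loop > 1) { run_gemm_impl<true>(); }
    else               { run_gemm_impl<false>(); }
}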
@@ -472,22 +530,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& karg)
{
- if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
- {
- return false;
- }
- if((karg.K % AK1 != 0 || karg.K % BK1 != 0) &&
- !(GemmSpec == GemmSpecialization::MKPadding ||
- GemmSpec == GemmSpecialization::NKPadding ||
- GemmSpec == GemmSpecialization::MNKPadding ||
- GemmSpec == GemmSpecialization::KPadding))
- {
- return false;
- }
- return GridwiseGemm::CheckValidity(
- karg.a_grid_desc_ak0_m_ak1, karg.b_grid_desc_bk0_n_bk1, karg.c_grid_desc_m_n);
+ return GridwiseGemm::CheckValidity(karg);
}
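A hedged host-side usage sketch of how these pieces fit together, assuming a fully specialized DeviceGemm_Xdl_CShuffle instance named gemm and already-allocated device buffers (none of this appears in the diff itself):

// auto gemm    = DeviceOp{};
// auto arg     = gemm.MakeArgument(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
//                                  a_element_op, b_element_op, c_element_op);
// if(!gemm.IsSupportedArgument(arg)) { /* fall back to another instance */ }
// auto invoker   = gemm.MakeInvoker();
// float ave_time = invoker.Run(arg, StreamConfig{nullptr, true});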
// polymorphic
@@ -22,7 +22,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
- kernel_gemm_xdl_cshuffle_v1(const Argument karg)
+ kernel_gemm_xdl_cshuffle_v1_simplified(Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -33,7 +33,10 @@ __global__ void
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
- template <typename FloatAB,
+ template <typename ALayout,
+ typename BLayout,
+ typename CLayout,
+ typename FloatAB,
typename FloatGemmAcc,
typename FloatCShuffle,
typename FloatC,
@@ -42,9 +45,6 @@ template <typename FloatAB,
typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
- typename AGridDesc_AK0_M_AK1,
- typename BGridDesc_BK0_N_BK1,
- typename CGridDesc_M_N,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
@@ -90,10 +90,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
- static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
- static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
- static constexpr auto AK1 = Number<AK1Value>{};
- static constexpr auto BK1 = Number<BK1Value>{};
+ static constexpr auto AK0_ = Number<KPerBlock / AK1Value>{};
+ static constexpr auto BK0_ = Number<KPerBlock / BK1Value>{};
+ static constexpr auto AK1_ = Number<AK1Value>{};
+ static constexpr auto BK1_ = Number<BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -102,29 +102,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
#endif
#define INTEGER_DIVIDE_CEIL(x, y) (((x) + (y)-1) / (y))
- __host__ __device__ static auto CalculateGridSize(index_t M, index_t N)
+ __host__ static auto CalculateGridSize(index_t M, index_t N)
{
// reference the implementation of class 'BlockToCTileMap_M00_N0_M01Adapt'
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
}
- __host__ __device__ static auto CalculateMPadded(index_t M)
+ __host__ static auto CalculateMPadded(index_t M)
{
return INTEGER_DIVIDE_CEIL(M, MPerBlock) * MPerBlock;
}
- __host__ __device__ static auto CalculateNPadded(index_t N)
+ __host__ static auto CalculateNPadded(index_t N)
{
return INTEGER_DIVIDE_CEIL(N, NPerBlock) * NPerBlock;
}
- __host__ __device__ static auto CalculateKPadded(index_t K)
+ __host__ static auto CalculateKPadded(index_t K)
{
return INTEGER_DIVIDE_CEIL(K, KPerBlock) * KPerBlock;
}
#undef INTEGER_DIVIDE_CEIL
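The three helpers above round each dimension up to the next multiple of its block tile. A minimal self-contained sketch of the arithmetic, assuming example tile sizes MPerBlock = 128 and KPerBlock = 32:

// integer ceil-divide, then scale back up: the smallest multiple of `tile` >= x
constexpr int pad_to_tile(int x, int tile) { return ((x + tile - 1) / tile) * tile; }

static_assert(pad_to_tile(1000, 128) == 1024, "MPadded for M = 1000");
static_assert(pad_to_tile(60, 32) == 64, "KPadded for K = 60");
static_assert(pad_to_tile(256, 128) == 256, "already aligned: unchanged");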
- __host__ __device__ static auto CalculateAK0(index_t K)
+ __host__ static auto CalculateAK0(index_t K)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
@@ -133,19 +132,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
- assert(CalculateKPadded(K) % AK1 == 0);
+ assert(CalculateKPadded(K) % AK1Value == 0);
- return CalculateKPadded(K) / AK1;
+ return CalculateKPadded(K) / AK1Value;
}
else
{
- assert(K % AK1 == 0);
+ assert(K % AK1Value == 0);
- return K / AK1;
+ return K / AK1Value;
}
}
- __host__ __device__ static auto CalculateBK0(index_t K)
+ __host__ static auto CalculateBK0(index_t K)
{
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
@@ -154,15 +153,232 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
- assert(CalculateKPadded(K) % BK1 == 0);
+ assert(CalculateKPadded(K) % BK1Value == 0);
- return CalculateKPadded(K) / BK1;
+ return CalculateKPadded(K) / BK1Value;
}
else
{
- assert(K % BK1 == 0);
+ assert(K % BK1Value == 0);
- return K / BK1;
+ return K / BK1Value;
}
}
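AK0 and BK0 are the K extent measured in vector loads: the (possibly padded) K divided by the per-load lengths AK1Value and BK1Value. A small worked sketch under assumed values AK1Value = 8 and KPerBlock = 32:

constexpr int ak1 = 8, k_per_block = 32;
constexpr int k_padded(int K) { return ((K + k_per_block - 1) / k_per_block) * k_per_block; }

// with K-padding enabled: K = 60 -> KPadded = 64 -> AK0 = 64 / 8 = 8
static_assert(k_padded(60) / ak1 == 8, "");
// without padding, K must already divide evenly: K = 64 -> AK0 = 8
static_assert(64 % ak1 == 0 && 64 / ak1 == 8, "");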
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both M and K
const auto a_grid_desc_m_k =
transform_tensor_descriptor(a_grid_desc_mraw_kraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(MPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad M, but not K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad K, but not M
const auto a_grid_desc_m_k = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
else
{
// not pad M or K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
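Every branch above ends with the same unmerge: the K axis splits into (AK0, AK1) with K = AK0 * AK1, producing the 3-D (AK0, M, AK1) view the blockwise copy expects. A standalone sketch of the resulting index math for a row-major A, illustrative only and much simpler than the CK descriptor machinery:

#include <cassert>
int main()
{
    const int K = 64, AK1 = 8, AK0 = K / AK1, StrideA = K; // row-major: a[m][k]
    (void)AK0;
    // element (k0, m, k1) of the 3-D view aliases a[m][k0 * AK1 + k1]
    auto offset = [&](int k0, int m, int k1) { return m * StrideA + k0 * AK1 + k1; };
    assert(offset(3, 2, 5) == 2 * 64 + 3 * 8 + 5); // = 157
    return 0;
}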
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
{
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad both N and K
const auto b_grid_desc_n_k =
transform_tensor_descriptor(b_grid_desc_nraw_kraw,
make_tuple(make_right_pad_transform(N, NPad - N),
make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(NPad)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::MNPadding)
{
// pad N, but not K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad K, but not N
const auto b_grid_desc_n_k = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
else
{
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return b_grid_desc_bk0_n_bk1;
}
}
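Note that the B descriptor is always built as an (N, K) view: a row-major B (stored K by N) gets strides (1, StrideB), while a column-major B gets (StrideB, 1), so the same (BK0, N, BK1) unmerge applies to both layouts. A small sketch of the two stride choices (illustrative values, not from the diff):

// row-major B, K x N storage: element (n, k) sits at k * StrideB + n
// column-major B: element (n, k) sits at n * StrideB + k
constexpr int offset_rowmajor(int n, int k, int StrideB) { return k * StrideB + n; }
constexpr int offset_colmajor(int n, int k, int StrideB) { return n * StrideB + k; }

static_assert(offset_rowmajor(3, 2, 128) == 259, "");
static_assert(offset_colmajor(3, 2, 128) == 386, "");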
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
GemmSpec == GemmSpecialization::MNKPadding)
{
// pad M and N
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
GemmSpec == GemmSpecialization::MKPadding)
{
// pad M, but not N
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
GemmSpec == GemmSpecialization::NKPadding)
{
// pad N, but not M
return transform_tensor_descriptor(
c_grid_desc_mraw_nraw,
make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
// not pad M or N
return c_grid_desc_mraw_nraw;
}
}
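Which dimensions get padded is resolved entirely at compile time from GemmSpec; for the C descriptor only the M and N members of the enum matter. A sketch of that selection as a plain constexpr mapping (hypothetical helper mirroring the if-constexpr chain above, not CK API):

enum class Spec { Default, MPadding, NPadding, KPadding, MNPadding, MKPadding, NKPadding, MNKPadding };

constexpr bool pads_m(Spec s)
{
    return s == Spec::MPadding || s == Spec::MNPadding || s == Spec::MKPadding ||
           s == Spec::MNKPadding;
}
constexpr bool pads_n(Spec s)
{
    return s == Spec::NPadding || s == Spec::MNPadding || s == Spec::NKPadding ||
           s == Spec::MNKPadding;
}

static_assert(pads_m(Spec::MKPadding) && !pads_n(Spec::MKPadding), "");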
@@ -174,16 +390,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
// A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
- make_tuple(AK0, Number<MPerBlock>{}, AK1),
- make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
+ make_tuple(AK0_, Number<MPerBlock>{}, AK1_),
+ make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1_, AK1_, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor(
- make_tuple(BK0, Number<NPerBlock>{}, BK1),
- make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
+ make_tuple(BK0_, Number<NPerBlock>{}, BK1_),
+ make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1_, BK1_, I1));
}
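Both LDS descriptors above widen the leading dimension by a small pad (ABlockLdsExtraM / BBlockLdsExtraN) so that successive K-slices do not map to the same LDS banks. A worked sketch of the A-side strides under assumed values MPerBlock = 128, AK1 = 8, ABlockLdsExtraM = 1:

// (AK0, MPerBlock, AK1) with strides ((MPerBlock + extra) * AK1, AK1, 1)
constexpr int m_per_block = 128, ak1_len = 8, extra_m = 1;
constexpr int stride_ak0 = (m_per_block + extra_m) * ak1_len; // 1032, not 1024
constexpr int lds_offset(int k0, int m, int k1) { return k0 * stride_ak0 + m * ak1_len + k1; }

static_assert(lds_offset(1, 0, 0) == 1032, "the pad skews the bank mapping per K-slice");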
__host__ __device__ static constexpr auto
@@ -209,7 +425,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
- constexpr auto max_lds_align = math::lcm(AK1, BK1);
+ constexpr auto max_lds_align = math::lcm(AK1_, BK1_);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
@@ -230,27 +446,82 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
- __host__ __device__ static constexpr bool
- CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
- const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
- const CGridDesc_M_N& c_grid_desc_m_n)
+ template <typename Argument>
+ __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
- const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
- const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
- const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.M % MPerBlock == 0))
{
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.N % NPerBlock == 0))
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
- if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
- if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
// check gridwise gemm pipeline
- const auto num_k_loop = K / KPerBlock;
+ const auto num_k_loop = (CalculateAK0(karg.K) * AK1Value) / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
@@ -268,8 +539,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
+ template <typename CGridDesc>
__host__ __device__ static constexpr auto
- MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
+ MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
@@ -288,28 +560,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// return block_id to C matrix tile idx (m0, n0) mapping
- __host__ __device__ static constexpr auto
- MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
- {
- return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
- c_grid_desc_m_n);
- }
- using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
- MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
- using Block2CTileMap = remove_cvref_t<decltype(MakeBlock2CTileMap(CGridDesc_M_N{}))>;
+ using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
+ __host__ __device__ static void print_bytes(const uint8_t* memory, std::size_t size)
+ {
+ (void)memory;
+ (void)size;
+ for(std::size_t idx = 0; idx < size; ++idx)
+ {
+ if(idx % 10 == 0)
+ {
+ printf("\n");
+ }
+ printf("0x%02X ", static_cast<unsigned>(memory[idx]));
+ }
+ printf("\n");
+ }
+ template <typename T>
+ __host__ __device__ static void print_bytes(const T& obj)
+ {
+ uint8_t memory[sizeof(T)];
+ memcpy(memory, &obj, sizeof(T));
+ print_bytes(memory, sizeof(T));
+ }
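A quick standalone sketch of the byte-dump helper's intended use, with a hypothetical argument struct standing in for the real kernel argument:

#include <cstdio>
#include <cstring>
#include <cstdint>

struct DemoArg { int M = 3840; int N = 4096; };

int main()
{
    DemoArg arg{};
    uint8_t memory[sizeof(arg)];
    std::memcpy(memory, &arg, sizeof(arg)); // same memcpy-then-walk idiom as above
    for(std::size_t idx = 0; idx < sizeof(arg); ++idx)
    {
        if(idx % 10 == 0) { std::printf("\n"); }
        std::printf("0x%02X ", static_cast<unsigned>(memory[idx]));
    }
    std::printf("\n");
    return 0;
}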
template <bool HasMainKBlockLoop, typename Argument>
- __device__ static void Run(const Argument karg, void* __restrict__ p_shared)
+ __device__ static void Run(const Argument& karg, void* __restrict__ p_shared)
{
const FloatAB* p_a_grid = karg.p_a_grid;
const FloatAB* p_b_grid = karg.p_b_grid;
FloatC* p_c_grid = karg.p_c_grid;
- const auto& a_grid_desc_ak0_m_ak1 = karg.a_grid_desc_ak0_m_ak1;
- const auto& b_grid_desc_bk0_n_bk1 = karg.b_grid_desc_bk0_n_bk1;
- const auto& c_grid_desc_m_n = karg.c_grid_desc_m_n;
+ const auto a_grid_desc_ak0_m_ak1 = karg.a_grid_desc_ak0_m_ak1;
+ const auto b_grid_desc_bk0_n_bk1 = karg.b_grid_desc_bk0_n_bk1;
+ const auto c_grid_desc_m_n = karg.c_grid_desc_m_n;
// const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(karg.M,
// karg.MPadded,
// karg.K,
// karg.KPadded,
// karg.StrideA,
// karg.AK0);
// const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(karg.K,
// karg.KPadded,
// karg.N,
// karg.NPadded,
// karg.StrideB,
// karg.BK0);
// const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M,
// karg.MPadded,
// karg.N,
// karg.NPadded,
// karg.StrideC);
// if (blockIdx.x == 0 && threadIdx.x == 0) {
// print_bytes(a_grid_desc_ak0_m_ak1);
// }
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
@@ -326,7 +636,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]
- const auto block_2_ctile_map = MakeBlock2CTileMap(c_grid_desc_m_n);
+ const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
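Block2CTileMap turns the 1-D hardware block id into an (m0, n0) tile coordinate over the ceil(M/MPerBlock) x ceil(N/NPerBlock) tile grid. A simplified illustrative mapping (the real BlockToCTileMap_M00_N0_M01Adapt additionally applies a swizzle for cache locality):

struct SimpleTileMap
{
    int n_blocks; // ceil(N / NPerBlock)
    // plain row-major walk over tiles: bid -> (m0, n0)
    constexpr int m0(int bid) const { return bid / n_blocks; }
    constexpr int n0(int bid) const { return bid % n_blocks; }
};

static_assert(SimpleTileMap{4}.m0(6) == 1 && SimpleTileMap{4}.n0(6) == 2, "");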
@@ -347,7 +657,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
- constexpr auto max_lds_align = math::lcm(AK1, BK1);
+ constexpr auto max_lds_align = math::lcm(AK1_, BK1_);
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
@@ -361,7 +671,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
- Sequence<AK0, MPerBlock, AK1>,
+ Sequence<AK0_, MPerBlock, AK1_>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
@@ -392,7 +702,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
- Sequence<BK0, NPerBlock, BK1>,
+ Sequence<BK0_, NPerBlock, BK1_>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
@@ -424,8 +734,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
- constexpr index_t KPack = math::max(
- math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+ constexpr index_t KPack =
+ math::max(math::lcm(AK1_, BK1_),
+ MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
@@ -453,8 +764,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
- constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
- constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
+ constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1_, 0, 0);
+ constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1_, 0, 0);
// gridwise GEMM pipeline
static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
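The copy steps above advance both LDS producers one KPerBlock slab per iteration, expressed in each side's own K0 units. A worked check under assumed values KPerBlock = 32, AK1 = 8, BK1 = 4:

constexpr int k_per_block = 32;
static_assert(k_per_block / 8 == 4, "A window moves 4 steps along AK0 per iteration");
static_assert(k_per_block / 4 == 8, "B window moves 8 steps along BK0 per iteration");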