Commit d23a7617 authored by Po-Yen, Chen

Extract descriptor creation logic from original gridwise gemm

parent 677778ee
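The change follows a simple split: descriptor construction, which depends on ALayout/BLayout/CLayout and GemmSpec, moves out of the base gridwise GEMM and into a derived "_ext" struct, so the base struct keeps only layout-agnostic logic. Below is a minimal, self-contained C++ sketch of that pattern, using toy names rather than the composable_kernel types:

#include <cstdint>

enum class Layout { RowMajor, ColumnMajor };

// base struct: tuning parameters and layout-agnostic logic only
template <std::int64_t MPerBlock, std::int64_t NPerBlock>
struct GridwiseGemmBase
{
    static constexpr bool ValidTile(std::int64_t M, std::int64_t N)
    {
        return M % MPerBlock == 0 && N % NPerBlock == 0;
    }
};

// derived "_ext" struct: re-introduces the layout parameter and owns the
// layout-dependent addressing helpers
template <std::int64_t MPerBlock, std::int64_t NPerBlock, Layout ALayout>
struct GridwiseGemmExt : GridwiseGemmBase<MPerBlock, NPerBlock>
{
    static constexpr std::int64_t AOffset(std::int64_t m, std::int64_t k, std::int64_t stride)
    {
        return ALayout == Layout::RowMajor ? m * stride + k : k * stride + m;
    }
};

static_assert(GridwiseGemmExt<128, 128, Layout::RowMajor>::AOffset(2, 3, 64) == 131, "");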
@@ -76,7 +76,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
    static constexpr auto K1Number = Number<K1>{};

    // GridwiseGemm
-   using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
+   using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext<
        BlockSize,
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
...
@@ -56,13 +56,9 @@ template <index_t BlockSize,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-         typename ALayout,
-         typename BLayout,
-         typename CLayout,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
-         tensor_operation::device::GemmSpecialization GemmSpec,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
@@ -131,107 +127,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        return math::integer_divide_floor(CalculateK0(K), K0PerBlock);
    }
__device__ static auto
MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t K0, index_t StrideA)
{
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
__device__ static auto
MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t NPad, index_t K0, index_t StrideB)
{
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
__device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
    // Argument
    struct Argument : public tensor_operation::device::BaseArgument
    {
@@ -365,70 +260,57 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
    }
-    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
-    __host__ static constexpr bool CheckValidity(const Argument& karg)
-    {
-        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
-                      "wrong! K1 need to be known at compile-time");
-        static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
-                          (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
-                      "Invalid tuning param!");
-        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
-                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
-                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
-                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
-        {
-            if(!(karg.M % MPerBlock == 0))
-            {
-                return false;
-            }
-        }
-        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
-                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
-                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
-                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
-        {
-            if(!(karg.N % NPerBlock == 0))
-            {
-                return false;
-            }
-        }
-        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
-        {
-            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
-            {
-                return false;
-            }
-        }
-        else
-        {
-            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
-            {
-                return false;
-            }
-        }
-        if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
-        {
-            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
-            {
-                return false;
-            }
-        }
-        else
-        {
-            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
-            {
-                return false;
-            }
-        }
+    template <typename AGridDesc_K0_M_K1,
+              typename BGridDesc_K0_N_K1,
+              typename CGridDesc_M_N,
+              typename Block2CTileMap>
+    __host__ __device__ static constexpr bool
+    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
+                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
+                  const CGridDesc_M_N& c_grid_desc_m_n,
+                  const Block2CTileMap& block_2_ctile_map)
+    {
+        const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
+        const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
+        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
+        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
+             K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
+             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
+            return false;
+        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
+            return false;
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K0 / K0PerBlock;
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
+        {
+            return false;
+        }
+        if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
+        {
+            return false;
+        }
+        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
+        return true;
+    }
+    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
+    __host__ static constexpr bool CheckValidity(const Argument& karg)
+    {
+        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
+                      "wrong! K1 need to be known at compile-time");
+        static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
+                          (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
+                      "Invalid tuning param!");
        // check gridwise gemm pipeline
-        const index_t K0 = karg.K / K1;
+        const index_t K0 = karg.K / K1Value;
        const auto num_k_loop = K0 / K0PerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
@@ -786,4 +668,310 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
    }
};
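The new descriptor-based CheckValidity above works directly from the A/B/C grid descriptors and the block-to-C-tile map: the M, N, K0 and K1 lengths of the three descriptors must agree, each dimension must be divisible by its block tile size, the pipeline must support the resulting k-loop count, and the tile map must accept the C descriptor. A minimal sketch of the same divisibility and loop-count logic on plain integers (the function name is illustrative, and the final check only stands in for GridwiseGemmPipe::IsSupported):

#include <cassert>
#include <cstdint>

bool CheckGemmDims(std::int64_t M, std::int64_t N, std::int64_t K0,
                   std::int64_t MPerBlock, std::int64_t NPerBlock, std::int64_t K0PerBlock)
{
    // block tiles must evenly cover the (possibly padded) problem sizes
    if(M % MPerBlock != 0 || N % NPerBlock != 0 || K0 % K0PerBlock != 0)
        return false;

    // stand-in for GridwiseGemmPipe::IsSupported(num_k_loop)
    const std::int64_t num_k_loop = K0 / K0PerBlock;
    return num_k_loop >= 1;
}

int main()
{
    assert(CheckGemmDims(1024, 512, 16, 128, 128, 4));  // everything divisible -> valid
    assert(!CheckGemmDims(1000, 512, 16, 128, 128, 4)); // 1000 % 128 != 0 -> rejected
    return 0;
}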
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t MPerBlock,
index_t NPerBlock,
index_t K0PerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t K1Value,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
bool BBlockLdsExtraN,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
index_t NumGemmKPrefetchStage = 1,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
: GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXDL,
NPerXDL,
K1Value,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraN,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
NumGemmKPrefetchStage,
LoopSched,
PipelineVer>
{
using Parent =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
MPerBlock,
NPerBlock,
K0PerBlock,
MPerXDL,
NPerXDL,
K1Value,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraN,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
NumGemmKPrefetchStage,
LoopSched,
PipelineVer>;
using typename Parent::Argument;
using typename Parent::GridwiseGemmPipe;
using Parent::I1;
using Parent::K1;
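// Note: the using-declarations above are needed because Parent is a dependent base class:
// inside a class template, names inherited from a dependent base are not found by
// unqualified lookup, so Argument, GridwiseGemmPipe, I1 and K1 must be re-exposed before
// the members below can use them unqualified. A self-contained toy illustration of the
// rule (not CK code):
//
//     template <typename T>
//     struct Base
//     {
//         using Argument = T;
//         static constexpr int K1 = 8;
//     };
//
//     template <typename T>
//     struct Derived : Base<T>
//     {
//         using Parent = Base<T>;
//         using typename Parent::Argument; // re-expose the dependent type name
//         using Parent::K1;                // re-expose the dependent constant
//
//         static constexpr int TwoK1() { return K1 * 2; } // needs the using above
//     };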
__device__ static auto
MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t K0, index_t StrideA)
{
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
__device__ static auto
MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t NPad, index_t K0, index_t StrideB)
{
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_k_n,
make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
}
__device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
}();
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding)
{
return transform_tensor_descriptor(c_grid_desc_m_n,
make_tuple(make_right_pad_transform(M, MPad - M),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
else
{
return transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
}
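// Note: the three Make*GridDescriptor functions above view the user's 2-D matrices
// through transformed descriptors: K is unmerged into (K0, K1) with k = k0 * K1 + k1,
// and under GemmSpecialization::MNPadding the M/N dimensions are right-padded up to
// MPad/NPad. As a hedged illustration (coordinate math only, not the library's
// descriptor machinery), the flat offset of element (k0, m, k1) in the K0 x M x K1
// view of a row-major M x K matrix with leading stride StrideA is:
//
//     std::int64_t AOffset_K0_M_K1(std::int64_t k0, std::int64_t m, std::int64_t k1,
//                                  std::int64_t K1, std::int64_t StrideA)
//     {
//         const std::int64_t k = k0 * K1 + k1; // unmerge transform: (K0, K1) -> K
//         return m * StrideA + k;              // naive row-major (M, K) descriptor
//     }
//
//     // e.g. K = 64, K1 = 8 (so K0 = 8): element (k0 = 2, m = 3, k1 = 5) is A[3][21]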
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ static constexpr bool CheckValidity(const Argument& karg)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
"Invalid tuning param!");
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.M % MPerBlock == 0))
{
return false;
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
{
if(!(karg.N % NPerBlock == 0))
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
// check gridwise gemm pipeline
const index_t K0 = karg.K / K1;
const auto num_k_loop = K0 / K0PerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}
};
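For reference, the integer quantities these helpers consume are straightforward to compute on the host: K0 = K / K1 (as in CheckValidity above), and under MNPadding MPad/NPad must be at least M/N. Rounding M and N up to the block tile sizes is an assumption here, since this diff only shows that make_right_pad_transform consumes MPad - M and NPad - N. A small sketch with concrete numbers (variable names are illustrative):

#include <cstdint>
#include <iostream>

int main()
{
    const std::int64_t M = 1000, N = 500, K = 72;
    const std::int64_t MPerBlock = 128, NPerBlock = 128, K1 = 8;

    const std::int64_t K0   = K / K1;                                      // 9
    const std::int64_t MPad = (M + MPerBlock - 1) / MPerBlock * MPerBlock; // 1024 (assumed policy)
    const std::int64_t NPad = (N + NPerBlock - 1) / NPerBlock * NPerBlock; // 512  (assumed policy)

    std::cout << "K0=" << K0 << " MPad=" << MPad << " NPad=" << NPad << '\n';
    // these values would feed, e.g., MakeAGridDescriptor_K0_M_K1(M, MPad, K, K0, StrideA)
    return 0;
}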
} // namespace ck