Unverified Commit 68dbf40a authored by Haocong WANG, committed by GitHub

[Navi3x Bug Fix] fix typo to accept MNKPadding flag correctly. (#597)

* Fix a bug blocking wmma_gemm_multipleD

* Utilize the matrix padder in device_wmma_op

* Cosmetic change to the GEMM padding format

* clang-format

* Change the gridwise GEMM from FIFO to KMN loop fashion
parent 8f455615
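The change set below switches the WMMA GEMM example and device ops from the `Default`/`MNPadding` code paths to the `MNKPadding` specialization, which right-pads M, N, and K up to multiples of the block tile so arbitrary problem sizes map onto whole tiles. A minimal sketch of that padding arithmetic follows; it is illustrative only, not the Composable Kernel implementation, reading the example instance's tile sizes as MPerBlock = 128, NPerBlock = 256 and K0PerBlock * K1 = 64:

```cpp
#include <cstdint>
#include <iostream>

// Round a raw GEMM dimension up to the next multiple of its block-tile size.
// The pad amount matches the formula the old code used inline:
// pad = (tile - dim % tile) % tile.
std::int64_t pad_to_multiple(std::int64_t dim, std::int64_t tile)
{
    return dim + (tile - dim % tile) % tile;
}

int main()
{
    // Assumed tile sizes, taken from the example instance in this commit.
    const std::int64_t MPerBlock = 128, NPerBlock = 256, KPerBlock = 64;

    // A problem whose sizes are not tile multiples now pads instead of failing.
    std::cout << pad_to_multiple(1000, MPerBlock) << '\n'  // 1024
              << pad_to_multiple(1000, NPerBlock) << '\n'  // 1024
              << pad_to_multiple(100, KPerBlock) << '\n';  // 128
}
```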
@@ -19,7 +19,7 @@ using AElementOp = PassThrough;
 using BElementOp = PassThrough;
 using CElementOp = PassThrough;
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
@@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
 // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
 // ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
 // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>;
+        < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmMNKPadding, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>;
 // clang-format on
 using ReferenceGemmInstance = ck::tensor_operation::host::
......
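With `GemmMNKPadding` in place, this instance no longer requires M, N, and K to be divisible by the tile sizes; the launched work-group count follows from the padded extents. A rough sketch of that relationship, as illustrative arithmetic only and not the library's grid-size computation:

```cpp
#include <cstdint>
#include <iostream>

// One work-group per (MPerBlock x NPerBlock) tile of the padded C matrix.
std::int64_t grid_size(std::int64_t M, std::int64_t N,
                       std::int64_t MPerBlock, std::int64_t NPerBlock)
{
    const std::int64_t m_pad = (M + MPerBlock - 1) / MPerBlock * MPerBlock;
    const std::int64_t n_pad = (N + NPerBlock - 1) / NPerBlock * NPerBlock;
    return (m_pad / MPerBlock) * (n_pad / NPerBlock);
}

int main()
{
    // With a 128 x 256 tile, M = N = 1000 pads to 1024 x 1024 and
    // launches an 8 x 4 = 32 work-group grid.
    std::cout << grid_size(1000, 1000, 128, 256) << '\n'; // 32
}
```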
@@ -86,38 +86,30 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
     // K1 = Max Vector Access Pixels
     static constexpr auto K1Number = Number<K1>{};
 
-    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
-    {
-        assert(K % K1 == 0);
-
-        const index_t K0 = K / K1;
-
-        const auto a_grid_desc_m_k = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
-            }
-#ifdef ENABLE_COLMAJOR
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
-            }
-#endif
-        }();
-
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-            return transform_tensor_descriptor(
-                a_grid_desc_m_k,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_right_pad_transform(M, PadM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                a_grid_desc_m_k,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+    static constexpr auto matrix_padder =
+        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock * K1};
+
+    static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA)
+    {
+        const auto a_grid_desc_mraw_kraw = [&]() {
+            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
+                                                    make_tuple(StrideA, I1));
+            }
+            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
+                                                    make_tuple(I1, StrideA));
+            }
+        }();
+
+        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
+
+        const auto M = a_grid_desc_m_k.GetLength(I0);
+        const auto K = a_grid_desc_m_k.GetLength(I1);
+
+        assert(K % K1 == 0);
+
+        const index_t K0 = K / K1;
+
+        return transform_tensor_descriptor(
+            a_grid_desc_m_k,
+            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
@@ -125,81 +117,53 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
             make_tuple(Sequence<1>{}, Sequence<0>{}),
             make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
     }
 
-    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
-    {
-        assert(K % K1 == 0);
-
-        const index_t K0 = K / K1;
-
-        const auto b_grid_desc_k_n = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
-            }
-        }();
-
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-            return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
+    static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
+    {
+        const auto b_grid_desc_nraw_kraw = [&]() {
+            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(StrideB, I1));
+            }
+            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(I1, StrideB));
+            }
+        }();
+
+        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
+
+        const auto N = b_grid_desc_n_k.GetLength(I0);
+        const auto K = b_grid_desc_n_k.GetLength(I1);
+
+        assert(K % K1 == 0);
+
+        const index_t K0 = K / K1;
+
+        return transform_tensor_descriptor(
+            b_grid_desc_n_k,
+            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                       make_pass_through_transform(N)),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }
 
     template <typename ELayout_>
-    static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
+    static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
     {
-        const auto e_grid_desc_m_n = [&]() {
+        const auto e_grid_desc_mraw_nraw = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout_>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
+                                                    make_tuple(StrideE, I1));
             }
             else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout_>::value)
            {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
+                                                    make_tuple(I1, StrideE));
            }
         }();
 
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-            return transform_tensor_descriptor(
-                e_grid_desc_m_n,
-                make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                e_grid_desc_m_n,
-                make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
+        return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
     }
 
     static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& Ms,
......
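The three descriptor builders above now delegate padding to one shared `matrix_padder` object instead of open-coding a `GemmSpec == GemmSpecialization::MNPadding` branch in each function. The sketch below shows the shape of that pattern with plain integer extents; it is a simplified stand-in, not the `MatrixPadder` from `matrix_padder.hpp`, which operates on tensor descriptors and supports the full set of specializations:

```cpp
#include <cstdint>

enum class GemmSpecialization { Default, MNPadding, MNKPadding };

// One object owns the tile sizes and decides, per specialization, which
// extents get rounded up; the descriptor builders then just call it.
template <GemmSpecialization Spec>
struct SimplePadder
{
    std::int64_t MPerTile, NPerTile, KPerTile;

    static constexpr std::int64_t pad(std::int64_t x, std::int64_t tile)
    {
        return x + (tile - x % tile) % tile;
    }

    constexpr std::int64_t PadM(std::int64_t M) const
    {
        return Spec == GemmSpecialization::Default ? M : pad(M, MPerTile);
    }
    constexpr std::int64_t PadN(std::int64_t N) const
    {
        return Spec == GemmSpecialization::Default ? N : pad(N, NPerTile);
    }
    constexpr std::int64_t PadK(std::int64_t K) const
    {
        // Only MNKPadding touches the reduction dimension in this sketch.
        return Spec == GemmSpecialization::MNKPadding ? pad(K, KPerTile) : K;
    }
};

// Usage mirroring the device op: one padder shared by the A, B and C builders.
constexpr auto padder = SimplePadder<GemmSpecialization::MNKPadding>{128, 256, 64};
static_assert(padder.PadM(1000) == 1024, "M is rounded up to the tile size");
```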
@@ -12,6 +12,7 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_gemm.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -78,38 +79,30 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
     // K1 = Max Vector Access Pixels
     static constexpr auto K1Number = Number<K1>{};
 
-    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
-    {
-        assert(K % K1 == 0);
-
-        const index_t K0 = K / K1;
-
-        const auto a_grid_desc_m_k = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
-            }
-#ifdef ENABLE_COLMAJOR
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
-            }
-#endif
-        }();
-
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-            return transform_tensor_descriptor(
-                a_grid_desc_m_k,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_right_pad_transform(M, PadM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                a_grid_desc_m_k,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+    static constexpr auto matrix_padder =
+        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock * K1};
+
+    static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA)
+    {
+        const auto a_grid_desc_mraw_kraw = [&]() {
+            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
+                                                    make_tuple(StrideA, I1));
+            }
+            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
+                                                    make_tuple(I1, StrideA));
+            }
+        }();
+
+        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
+
+        const auto M = a_grid_desc_m_k.GetLength(I0);
+        const auto K = a_grid_desc_m_k.GetLength(I1);
+
+        assert(K % K1 == 0);
+
+        const index_t K0 = K / K1;
+
+        return transform_tensor_descriptor(
+            a_grid_desc_m_k,
+            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
@@ -117,80 +110,52 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
             make_tuple(Sequence<1>{}, Sequence<0>{}),
             make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
     }
 
-    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
-    {
-        assert(K % K1 == 0);
-
-        const index_t K0 = K / K1;
-
-        const auto b_grid_desc_k_n = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
-            }
-        }();
-
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-            return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
+    static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
+    {
+        const auto b_grid_desc_nraw_kraw = [&]() {
+            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(StrideB, I1));
+            }
+            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(I1, StrideB));
+            }
+        }();
+
+        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
+
+        const auto N = b_grid_desc_n_k.GetLength(I0);
+        const auto K = b_grid_desc_n_k.GetLength(I1);
+
+        assert(K % K1 == 0);
+
+        const index_t K0 = K / K1;
+
+        return transform_tensor_descriptor(
+            b_grid_desc_n_k,
+            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                       make_pass_through_transform(N)),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }
 
-    static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
+    static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
     {
-        const auto c_grid_desc_m_n = [&]() {
+        const auto c_grid_desc_mraw_nraw = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
+                                                    make_tuple(StrideC, I1));
             }
             else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
+                                                    make_tuple(I1, StrideC));
             }
         }();
 
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-            return transform_tensor_descriptor(
-                c_grid_desc_m_n,
-                make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                c_grid_desc_m_n,
-                make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
+        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
     }
 
     // Gridwise descriptor, mapping to whole given provblem.
......
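`DeviceGemmWmma_CShuffle` gets the same treatment as the multiple-D variant above: raw extents go in, the shared padder decides per `GemmSpec` which of M, N, and K are rounded up, and the K0/K1 split is derived from the padded K. For reference, here is a small table of which dimensions each GEMM specialization pads, written as code; the enumerator names follow `gemm_specialization.hpp`, but treat the exact list as an assumption:

```cpp
// Which raw GEMM dimensions a given specialization right-pads to a tile
// multiple. Default pads nothing; MNKPadding pads all three, which is why
// the example instance switches to it.
enum class GemmSpecialization
{
    Default, MPadding, NPadding, KPadding,
    MNPadding, MKPadding, NKPadding, MNKPadding
};

struct PadFlags
{
    bool PadM, PadN, PadK;
};

constexpr PadFlags padded_dims(GemmSpecialization spec)
{
    switch(spec)
    {
    case GemmSpecialization::MPadding: return {true, false, false};
    case GemmSpecialization::NPadding: return {false, true, false};
    case GemmSpecialization::KPadding: return {false, false, true};
    case GemmSpecialization::MNPadding: return {true, true, false};
    case GemmSpecialization::MKPadding: return {true, false, true};
    case GemmSpecialization::NKPadding: return {false, true, true};
    case GemmSpecialization::MNKPadding: return {true, true, true};
    default: return {false, false, false};
    }
}

static_assert(padded_dims(GemmSpecialization::MNKPadding).PadK,
              "MNKPadding also pads the reduction dimension");
```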
@@ -414,7 +414,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
 
         auto blockwise_gemm =
-            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
+            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
                                                                        FloatA,
                                                                        FloatB,
                                                                        FloatAcc,
......
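The final hunk swaps the `_FIFO` blockwise GEMM for the plain CShuffle variant, matching the commit note about moving the gridwise GEMM to a "KMN loop fashion". As a generic illustration of what a K-outermost loop order over the per-wave accumulator tiles means (a scalar stand-in, not the WMMA blockwise implementation):

```cpp
#include <array>
#include <cstddef>

// "KMN" ordering: the K slice loop is outermost, and every (m, n) accumulator
// sub-tile is updated before the next K slice is consumed. Tile extents are
// illustrative; the real kernel issues WMMA instructions rather than scalar FMAs.
template <std::size_t MRepeat, std::size_t NRepeat, std::size_t KRepeat>
void kmn_loop(const std::array<std::array<float, MRepeat>, KRepeat>& a,
              const std::array<std::array<float, NRepeat>, KRepeat>& b,
              std::array<std::array<float, NRepeat>, MRepeat>& c)
{
    for(std::size_t k = 0; k < KRepeat; ++k)         // K outermost
        for(std::size_t m = 0; m < MRepeat; ++m)     // then M
            for(std::size_t n = 0; n < NRepeat; ++n) // then N
                c[m][n] += a[k][m] * b[k][n];        // accumulate one sub-tile
}
```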