Commit 65018876 authored by aska-0096

Utilize matrix padder in device_wmma_op

parent e5fb9fa6
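
This commit replaces the per-function, GemmSpec-gated padding branches in the WMMA device ops with the shared MatrixPadder helper: each Make*GridDescriptor now builds a raw descriptor from MRaw/NRaw/KRaw, lets matrix_padder right-pad it to block-tile multiples, and re-reads the padded lengths via GetLength before splitting K into (K0, K1). The underlying arithmetic just rounds each raw size up to the next tile multiple. A minimal standalone sketch of that arithmetic, using plain integers and made-up tile sizes rather than CK's tensor descriptors:

#include <cassert>
#include <cstdio>

// Round a raw dimension up to the next multiple of the per-block tile size;
// the outer "% perBlock" keeps already-aligned sizes unchanged.
int PadUp(int raw, int perBlock) { return raw + (perBlock - raw % perBlock) % perBlock; }

int main()
{
    // Hypothetical tile sizes: a 128x128 macro-tile with K0PerBlock * K1 = 32.
    const int MPerBlock = 128, NPerBlock = 128, KPerBlock = 32;

    // A problem shape that none of the tile sizes divide evenly.
    const int M = PadUp(1000, MPerBlock); // 1024
    const int N = PadUp(777, NPerBlock);  // 896
    const int K = PadUp(100, KPerBlock);  // 128

    assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0);
    std::printf("padded M=%d N=%d K=%d\n", M, N, K);
}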
@@ -19,7 +19,7 @@ using AElementOp = PassThrough;
 using BElementOp = PassThrough;
 using CElementOp = PassThrough;
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
...
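
The example instance switches from Default to MNKPadding so that problem sizes which do not divide the block tiles are handled by the padder rather than rejected. For reference, GemmSpecialization enumerates which dimensions may be padded; an abbreviated sketch (the full enum in ck/tensor_operation/gpu/device/gemm_specialization.hpp carries further variants):

// Abbreviated sketch of the specialization enum.
enum struct GemmSpecialization
{
    Default,    // no padding: M, N, K must already divide the block tile sizes
    MPadding,   // pad M only
    NPadding,   // pad N only
    KPadding,   // pad K only
    MNPadding,  // pad M and N
    MKPadding,  // pad M and K
    NKPadding,  // pad N and K
    MNKPadding, // pad all three dimensions
};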
@@ -86,38 +86,30 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
     // K1 = Max Vector Access Pixels
     static constexpr auto K1Number = Number<K1>{};
 
-    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
-    {
-        assert(K % K1 == 0);
-
-        const index_t K0 = K / K1;
-
-        const auto a_grid_desc_m_k = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
-            }
-#ifdef ENABLE_COLMAJOR
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
-            }
-#endif
-        }();
-
-        if constexpr(GemmSpec == GemmSpecialization::MNKPadding)
-        {
-            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-
-            return transform_tensor_descriptor(
-                a_grid_desc_m_k,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_right_pad_transform(M, PadM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                a_grid_desc_m_k,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+    static constexpr auto matrix_padder =
+        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock * K1};
+
+    static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA)
+    {
+        const auto a_grid_desc_mraw_kraw = [&]() {
+            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
+                                                    make_tuple(StrideA, I1));
+            }
+            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
+                                                    make_tuple(I1, StrideA));
+            }
+        }();
+
+        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
+
+        const auto M = a_grid_desc_m_k.GetLength(I0);
+        const auto K = a_grid_desc_m_k.GetLength(I1);
+
+        assert(K % K1 == 0);
+
+        const index_t K0 = K / K1;
+
+        return transform_tensor_descriptor(
+            a_grid_desc_m_k,
+            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
@@ -125,81 +117,53 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }
 
-    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
-    {
-        assert(K % K1 == 0);
-
-        const index_t K0 = K / K1;
-
-        const auto b_grid_desc_k_n = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
-            }
-        }();
-
-        if constexpr(GemmSpec == GemmSpecialization::MNKPadding)
-        {
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-
-            return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-    }
+    static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
+    {
+        const auto b_grid_desc_nraw_kraw = [&]() {
+            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(I1, StrideB));
+            }
+            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(StrideB, I1));
+            }
+        }();
+
+        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
+
+        const auto N = b_grid_desc_n_k.GetLength(I0);
+        const auto K = b_grid_desc_n_k.GetLength(I1);
+
+        assert(K % K1 == 0);
+
+        const index_t K0 = K / K1;
+
+        return transform_tensor_descriptor(
+            b_grid_desc_n_k,
+            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                       make_pass_through_transform(N)),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    }
 
     template <typename ELayout_>
-    static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
-    {
-        const auto e_grid_desc_m_n = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout_>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout_>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
-            }
-        }();
-
-        if constexpr(GemmSpec == GemmSpecialization::MNKPadding)
-        {
-            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-
-            return transform_tensor_descriptor(
-                e_grid_desc_m_n,
-                make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                e_grid_desc_m_n,
-                make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-    }
+    static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
+    {
+        const auto e_grid_desc_mraw_nraw = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout_>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
+                                                    make_tuple(StrideE, I1));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout_>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
+                                                    make_tuple(I1, StrideE));
+            }
+        }();
+
+        return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
+    }
 
     static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& Ms,
...
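
The A and B descriptors above end in (K0, M, K1) and (K0, N, K1) form: the padded K axis is unmerged into a slow K0 axis and a fast K1 axis, with K1 being the maximum vector access width. A small standalone check, illustrative only and with made-up sizes, that the K1 axis walks unit-stride memory for a dense row-major A, which is what makes K1-wide vector loads legal:

#include <cassert>

int main()
{
    const int K1 = 8;        // max vector access width
    const int M = 4, K = 32; // padded sizes, so K % K1 == 0
    const int K0 = K / K1;
    const int StrideA = K;   // dense row-major A: offset(m, k) = m * StrideA + k

    for(int k0 = 0; k0 < K0; ++k0)
        for(int m = 0; m < M; ++m)
        {
            int prev = -1;
            for(int k1 = 0; k1 < K1; ++k1)
            {
                // Coordinate (k0, m, k1) addresses element A(m, k0 * K1 + k1).
                const int offset = m * StrideA + k0 * K1 + k1;
                if(k1 > 0)
                    assert(offset == prev + 1); // unit stride along K1
                prev = offset;
            }
        }
}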
@@ -12,6 +12,7 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_gemm.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
@@ -78,38 +79,30 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
     // K1 = Max Vector Access Pixels
     static constexpr auto K1Number = Number<K1>{};
 
-    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
-    {
-        assert(K % K1 == 0);
-
-        const index_t K0 = K / K1;
-
-        const auto a_grid_desc_m_k = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
-            }
-#ifdef ENABLE_COLMAJOR
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
-            }
-#endif
-        }();
-
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-
-            return transform_tensor_descriptor(
-                a_grid_desc_m_k,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_right_pad_transform(M, PadM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                a_grid_desc_m_k,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+    static constexpr auto matrix_padder =
+        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock * K1};
+
+    static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA)
+    {
+        const auto a_grid_desc_mraw_kraw = [&]() {
+            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
+                                                    make_tuple(StrideA, I1));
+            }
+            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
+                                                    make_tuple(I1, StrideA));
+            }
+        }();
+
+        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
+
+        const auto M = a_grid_desc_m_k.GetLength(I0);
+        const auto K = a_grid_desc_m_k.GetLength(I1);
+
+        assert(K % K1 == 0);
+
+        const index_t K0 = K / K1;
+
+        return transform_tensor_descriptor(
+            a_grid_desc_m_k,
+            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
@@ -117,80 +110,52 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }
 
-    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
-    {
-        assert(K % K1 == 0);
-
-        const index_t K0 = K / K1;
-
-        const auto b_grid_desc_k_n = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
-            }
-        }();
-
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-
-            return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-    }
+    static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
+    {
+        const auto b_grid_desc_nraw_kraw = [&]() {
+            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(I1, StrideB));
+            }
+            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(StrideB, I1));
+            }
+        }();
+
+        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
+
+        const auto N = b_grid_desc_n_k.GetLength(I0);
+        const auto K = b_grid_desc_n_k.GetLength(I1);
+
+        assert(K % K1 == 0);
+
+        const index_t K0 = K / K1;
+
+        return transform_tensor_descriptor(
+            b_grid_desc_n_k,
+            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                       make_pass_through_transform(N)),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+    }
 
-    static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
-    {
-        const auto c_grid_desc_m_n = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
-            }
-        }();
-
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-
-            return transform_tensor_descriptor(
-                c_grid_desc_m_n,
-                make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                c_grid_desc_m_n,
-                make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-    }
+    static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
+    {
+        const auto c_grid_desc_mraw_nraw = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
+                                                    make_tuple(StrideC, I1));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
+                                                    make_tuple(I1, StrideC));
+            }
+        }();
+
+        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
+    }
 
     // Gridwise descriptor, mapping to whole given problem.
...
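
One detail worth spelling out: MakeBGridDescriptor_K0_N_K1 now views B through an (N, K) descriptor instead of the old (K, N) one, so the stride tuple flips relative to the layout tag. A quick standalone check, illustrative only, that an (N, K) view with strides (1, StrideB) reproduces the offsets of a row-major K x N matrix:

#include <cassert>

int main()
{
    const int K = 3, N = 5, StrideB = N; // dense row-major B: offset(k, n) = k * StrideB + n

    for(int n = 0; n < N; ++n)
        for(int k = 0; k < K; ++k)
        {
            const int row_major = k * StrideB + n;     // native B(k, n) offset
            const int n_k_view  = n * 1 + k * StrideB; // (N, K) view, strides (1, StrideB)
            assert(row_major == n_k_view);
        }
}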