Commit 97648ccd authored by Adam Osewski

Use of PadTensorDescriptor for grid desc creation.

parent 0eff71a4
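
This commit swaps the hand-rolled GemmSpecialization branches in the A/B/C grid-descriptor builders for tensor_operation::device::PadTensorDescriptor, which takes the per-dimension tile lengths (e.g. MPerBlock and K0 * K1 for the A descriptor) together with a DoPads boolean sequence selecting which dimensions to pad. The sketch below is only an illustrative model of that kind of per-dimension round-up, assuming each flagged dimension is padded to the next multiple of its tile length; pad_dims and the surrounding code are hypothetical plain C++, not the library's API.

#include <array>
#include <cstdint>
#include <iostream>

// Illustrative model (assumption): a dimension is rounded up to the next
// multiple of its tile length when its DoPads flag is true, and left
// unchanged otherwise.
template <std::size_t N>
std::array<std::int64_t, N> pad_dims(const std::array<std::int64_t, N>& lengths,
                                     const std::array<std::int64_t, N>& tile_lengths,
                                     const std::array<bool, N>& do_pads)
{
    std::array<std::int64_t, N> padded{};
    for(std::size_t i = 0; i < N; ++i)
    {
        const std::int64_t tile = tile_lengths[i];
        padded[i] = do_pads[i] ? ((lengths[i] + tile - 1) / tile) * tile : lengths[i];
    }
    return padded;
}

int main()
{
    // Loosely analogous to the A descriptor: {M, K} padded with tile lengths
    // {MPerBlock, K0 * K1}; here both pad flags are on.
    const auto padded = pad_dims<2>({1000, 130}, {128, 64}, {true, true});
    std::cout << padded[0] << " x " << padded[1] << '\n'; // prints 1024 x 192
    return 0;
}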
@@ -256,9 +256,6 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut
             const index_t stride_b = gemm_descs[i].stride_B_;
             const index_t stride_c = gemm_descs[i].stride_C_;
-            const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
-            const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
-            const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
             const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH);
             const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
@@ -285,9 +282,6 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut
                 stride_a,
                 stride_b,
                 stride_c,
-                m_padded,
-                n_padded,
-                k_padded,
                 k0,
                 K_BATCH};
@@ -311,7 +305,6 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut
             auto& karg = gemm_kernel_args_[i].karg_;
-            const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
             const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH);
             const auto c_grid_desc_m_n =
@@ -330,7 +323,6 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut
             auto grouped_block_2_ctile_map =
                 GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
-            karg.KPadded = k_padded;
             karg.K0      = k0;
             karg.k_batch = K_BATCH;
             gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
......
@@ -97,10 +97,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
    static constexpr auto K1        = Number<K1Value>{};
    static constexpr auto KPerBlock = K1Value * K0PerBlock;
-   static constexpr auto gemm_padder =
-       tensor_operation::device::GemmPadder<GemmSpec, index_t, index_t, index_t>{
-           MPerBlock, NPerBlock, K1* K0PerBlock};
    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
    using GridwiseGemmPipe = remove_cvref_t<decltype(
        GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
@@ -116,9 +112,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
        index_t StrideA;
        index_t StrideB;
        index_t StrideC;
-       index_t MPadded;
-       index_t NPadded;
-       index_t KPadded;
        index_t K0;
        index_t k_batch;
@@ -131,9 +124,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
                index_t StrideA_,
                index_t StrideB_,
                index_t StrideC_,
-               index_t MPadded_,
-               index_t NPadded_,
-               index_t KPadded_,
                index_t K0_,
                index_t k_batch_)
            : p_a_grid(p_a_grid_),
@@ -145,9 +135,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
              StrideA(StrideA_),
              StrideB(StrideB_),
              StrideC(StrideC_),
-             MPadded(MPadded_),
-             NPadded(NPadded_),
-             KPadded(KPadded_),
              K0(K0_),
              k_batch(k_batch_)
        {
@@ -162,9 +149,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
                      << "SA:" << StrideA << ", "
                      << "SB:" << StrideB << ", "
                      << "SC:" << StrideC << ", "
-                     << "MP:" << MPadded << ", "
-                     << "NP:" << NPadded << ", "
-                     << "KP:" << KPadded << ", "
                      << "K0:" << K0 << ", "
                      << "KB:" << k_batch << "}" << std::endl;
        }
@@ -300,13 +284,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
            }
        }
-   __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M,
-                                                                       index_t MPad,
-                                                                       index_t K,
-                                                                       index_t StrideA,
-                                                                       index_t KBatch,
-                                                                       index_t K0,
-                                                                       index_t KPad)
+   __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(
+       index_t M, index_t K, index_t StrideA, index_t KBatch, index_t K0)
    {
        const auto a_grid_desc_m_k = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
@@ -319,43 +298,20 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
            }
        }();
-       const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
-           a_grid_desc_m_k,
-           make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
-           make_tuple(Sequence<0>{}, Sequence<1>{}),
-           make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-       if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
-                    GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
-                    GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
-                    GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
-       {
-           // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-           return transform_tensor_descriptor(
-               a_grid_desc_m_kpad,
-               make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
-                          make_right_pad_transform(M, MPad - M)),
-               make_tuple(Sequence<1>{}, Sequence<0>{}),
-               make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
-       }
-       else
-       {
-           return transform_tensor_descriptor(
-               a_grid_desc_m_kpad,
-               make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
-                          make_pass_through_transform(M)),
-               make_tuple(Sequence<1>{}, Sequence<0>{}),
-               make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
-       }
-   }
+       using DoPads = Sequence<tensor_operation::device::GemmPadM<GemmSpec>::PadM, true>;
+       const auto a_grid_desc_mpad_kpad = tensor_operation::device::PadTensorDescriptor(
+           a_grid_desc_m_k, make_tuple(MPerBlock, K0 * K1), DoPads{});
+
+       return transform_tensor_descriptor(
+           a_grid_desc_mpad_kpad,
+           make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
+                      make_pass_through_transform(a_grid_desc_mpad_kpad.GetLength(I0))),
+           make_tuple(Sequence<1>{}, Sequence<0>{}),
+           make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+   }
-   __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K,
-                                                                       index_t NPad,
-                                                                       index_t N,
-                                                                       index_t StrideB,
-                                                                       index_t KBatch,
-                                                                       index_t K0,
-                                                                       index_t KPad)
+   __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(
+       index_t K, index_t N, index_t StrideB, index_t KBatch, index_t K0)
    {
        const auto b_grid_desc_k_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
@@ -368,35 +324,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
            }
        }();
-       const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
-           b_grid_desc_k_n,
-           make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
-           make_tuple(Sequence<0>{}, Sequence<1>{}),
-           make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-       if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
-                    GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
-                    GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
-                    GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
-       {
-           // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-           return transform_tensor_descriptor(
-               b_grid_desc_kpad_n,
-               make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
-                          make_right_pad_transform(N, NPad - N)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
-       }
-       else
-       {
-           return transform_tensor_descriptor(
-               b_grid_desc_kpad_n,
-               make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
-                          make_pass_through_transform(N)),
-               make_tuple(Sequence<0>{}, Sequence<1>{}),
-               make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
-       }
-   }
+       using DoPads = Sequence<true, tensor_operation::device::GemmPadN<GemmSpec>::PadN>;
+       const auto b_grid_desc_kpad_npad = tensor_operation::device::PadTensorDescriptor(
+           b_grid_desc_k_n, make_tuple(K0 * K1, NPerBlock), DoPads{});
+
+       return transform_tensor_descriptor(
+           b_grid_desc_kpad_npad,
+           make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
+                      make_pass_through_transform(b_grid_desc_kpad_npad.GetLength(I1))),
+           make_tuple(Sequence<0>{}, Sequence<1>{}),
+           make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+   }
    __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
    {
@@ -411,7 +349,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
            }
        }();
-       return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n);
+       using DoPads = Sequence<tensor_operation::device::GemmPadM<GemmSpec>::PadM,
+                               tensor_operation::device::GemmPadN<GemmSpec>::PadN>;
+       return tensor_operation::device::PadTensorDescriptor(
+           c_grid_desc_m_n, make_tuple(MPerBlock, NPerBlock), DoPads{});
    }
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
@@ -615,10 +556,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
        const FloatAB* p_a_grid = karg.p_a_grid;
        const FloatAB* p_b_grid = karg.p_b_grid;
        FloatC* p_c_grid        = karg.p_c_grid;
-       const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
-           karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded);
-       const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
-           karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded);
+       const auto a_b_k0_m_k1_grid_desc =
+           MakeAGridDescriptor_KBatch_K0_M_K1(karg.M, karg.K, karg.StrideA, karg.k_batch, karg.K0);
+       const auto b_b_k0_n_k1_grid_desc =
+           MakeBGridDescriptor_KBatch_K0_N_K1(karg.K, karg.N, karg.StrideB, karg.k_batch, karg.K0);
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
        const auto c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
            MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_grid_desc_m_n);
......