Commit 97648ccd authored by Adam Osewski's avatar Adam Osewski

Use PadTensorDescriptor for grid descriptor creation.

parent 0eff71a4
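
This commit drops the explicit MPadded/NPadded/KPadded kernel arguments and the per-GemmSpecialization if-constexpr padding branches, and instead builds the padded grid descriptors with the generic PadTensorDescriptor helper: a compile-time DoPads sequence selects which dimensions to pad, and each selected dimension is right-padded up to the next multiple of its tile length. As a minimal standalone sketch of that padding rule (pad_lengths is a hypothetical illustration, not the CK API):

```cpp
// Hypothetical illustration of the PadTensorDescriptor padding rule: each
// dimension flagged in do_pads is right-padded up to the next multiple of
// its tile length; unflagged dimensions pass through unchanged.
#include <array>
#include <cstddef>
#include <cstdint>
#include <iostream>

using index_t = int32_t;

template <std::size_t N>
std::array<index_t, N> pad_lengths(const std::array<index_t, N>& lengths,
                                   const std::array<index_t, N>& tiles,
                                   const std::array<bool, N>& do_pads)
{
    std::array<index_t, N> padded{};
    for(std::size_t i = 0; i < N; ++i)
    {
        // Right-pad: round the length up to a multiple of the tile size.
        const index_t pad = (tiles[i] - lengths[i] % tiles[i]) % tiles[i];
        padded[i]         = do_pads[i] ? lengths[i] + pad : lengths[i];
    }
    return padded;
}

int main()
{
    // E.g. an M x K = 1000 x 129 A-matrix with MPerBlock = 128 and
    // K0 * K1 = 64 is padded to 1024 x 192 when both flags are set.
    const auto padded = pad_lengths<2>({1000, 129}, {128, 64}, {true, true});
    std::cout << padded[0] << " x " << padded[1] << std::endl; // 1024 x 192
}
```
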
@@ -256,9 +256,6 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut
             const index_t stride_b = gemm_descs[i].stride_B_;
             const index_t stride_c = gemm_descs[i].stride_C_;
-            const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
-            const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
-            const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
             const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH);
             const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
@@ -285,9 +282,6 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut
                                       stride_a,
                                       stride_b,
                                       stride_c,
-                                      m_padded,
-                                      n_padded,
-                                      k_padded,
                                       k0,
                                       K_BATCH};
@@ -311,7 +305,6 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut
             auto& karg = gemm_kernel_args_[i].karg_;
-            const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
             const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH);
             const auto c_grid_desc_m_n =
@@ -330,7 +323,6 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut
             auto grouped_block_2_ctile_map =
                 GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
-            karg.KPadded = k_padded;
             karg.K0 = k0;
             karg.k_batch = K_BATCH;
             gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
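
On the host side only k0 and the k-batch count are still precomputed; the padded extents are now derived inside the descriptor builders. The body of GridwiseGemm::CalculateK0 is not part of this diff, but under the usual split-K convention K is rounded up so that KPadded = KBatch * K0 * K1, with K0 a multiple of K0PerBlock. A hedged stand-in (CalculateK0_sketch is hypothetical, not the CK implementation):

```cpp
#include <cstdint>
using index_t = int32_t;

// Hypothetical stand-in for GridwiseGemm::CalculateK0 (real body not shown in
// this diff): split-K divides the K loop across KBatch partial-sum groups, so
// K is rounded up to a multiple of K1 * K0PerBlock * KBatch, and K0 is the
// per-batch count of K1-wide slices.
index_t CalculateK0_sketch(index_t K, index_t KBatch, index_t K0PerBlock, index_t K1)
{
    const index_t k_tile   = K1 * K0PerBlock * KBatch;
    const index_t k_padded = ((K + k_tile - 1) / k_tile) * k_tile; // round up
    return k_padded / (K1 * KBatch); // multiple of K0PerBlock by construction
}
```
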
@@ -97,10 +97,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
     static constexpr auto K1 = Number<K1Value>{};
     static constexpr auto KPerBlock = K1Value * K0PerBlock;
-    static constexpr auto gemm_padder =
-        tensor_operation::device::GemmPadder<GemmSpec, index_t, index_t, index_t>{
-            MPerBlock, NPerBlock, K1* K0PerBlock};
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
     using GridwiseGemmPipe = remove_cvref_t<decltype(
         GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
@@ -116,9 +112,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
         index_t StrideA;
         index_t StrideB;
         index_t StrideC;
-        index_t MPadded;
-        index_t NPadded;
-        index_t KPadded;
         index_t K0;
         index_t k_batch;
@@ -131,9 +124,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
                  index_t StrideA_,
                  index_t StrideB_,
                  index_t StrideC_,
-                 index_t MPadded_,
-                 index_t NPadded_,
-                 index_t KPadded_,
                  index_t K0_,
                  index_t k_batch_)
             : p_a_grid(p_a_grid_),
@@ -145,9 +135,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
               StrideA(StrideA_),
               StrideB(StrideB_),
               StrideC(StrideC_),
-              MPadded(MPadded_),
-              NPadded(NPadded_),
-              KPadded(KPadded_),
               K0(K0_),
               k_batch(k_batch_)
         {
@@ -162,9 +149,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
                       << "SA:" << StrideA << ", "
                       << "SB:" << StrideB << ", "
                       << "SC:" << StrideC << ", "
-                      << "MP:" << MPadded << ", "
-                      << "NP:" << NPadded << ", "
-                      << "KP:" << KPadded << ", "
                       << "K0:" << K0 << ", "
                       << "KB:" << k_batch << "}" << std::endl;
         }
@@ -300,13 +284,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
         }
     }

-    __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M,
-                                                                       index_t MPad,
-                                                                       index_t K,
-                                                                       index_t StrideA,
-                                                                       index_t KBatch,
-                                                                       index_t K0,
-                                                                       index_t KPad)
+    __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(
+        index_t M, index_t K, index_t StrideA, index_t KBatch, index_t K0)
     {
         const auto a_grid_desc_m_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
@@ -319,43 +298,20 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
             }
         }();

-        const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
-            a_grid_desc_m_k,
-            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
+        using DoPads = Sequence<tensor_operation::device::GemmPadM<GemmSpec>::PadM, true>;
+        const auto a_grid_desc_mpad_kpad = tensor_operation::device::PadTensorDescriptor(
+            a_grid_desc_m_k, make_tuple(MPerBlock, K0 * K1), DoPads{});

-        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
-                     GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
-                     GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
-                     GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
-        {
-            // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-            return transform_tensor_descriptor(
-                a_grid_desc_m_kpad,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
-                           make_right_pad_transform(M, MPad - M)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                a_grid_desc_m_kpad,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
-                           make_pass_through_transform(M)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
-        }
+        return transform_tensor_descriptor(
+            a_grid_desc_mpad_kpad,
+            make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
+                       make_pass_through_transform(a_grid_desc_mpad_kpad.GetLength(I0))),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
     }

-    __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K,
-                                                                       index_t NPad,
-                                                                       index_t N,
-                                                                       index_t StrideB,
-                                                                       index_t KBatch,
-                                                                       index_t K0,
-                                                                       index_t KPad)
+    __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(
+        index_t K, index_t N, index_t StrideB, index_t KBatch, index_t K0)
     {
         const auto b_grid_desc_k_n = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
@@ -368,35 +324,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
             }
         }();

-        const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
-            b_grid_desc_k_n,
-            make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
+        using DoPads = Sequence<true, tensor_operation::device::GemmPadN<GemmSpec>::PadN>;
+        const auto b_grid_desc_kpad_npad = tensor_operation::device::PadTensorDescriptor(
+            b_grid_desc_k_n, make_tuple(K0 * K1, NPerBlock), DoPads{});

-        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
-                     GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
-                     GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
-                     GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
-        {
-            // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-            return transform_tensor_descriptor(
-                b_grid_desc_kpad_n,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
-                           make_right_pad_transform(N, NPad - N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                b_grid_desc_kpad_n,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
-                           make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
-        }
+        return transform_tensor_descriptor(
+            b_grid_desc_kpad_npad,
+            make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
+                       make_pass_through_transform(b_grid_desc_kpad_npad.GetLength(I1))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
     }

     __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
     {
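
For a concrete shape, assume K = 129 with K1 = 8, K0PerBlock = 4 and KBatch = 2 (illustrative values, consistent with the CalculateK0 sketch above): K0 = 12, K is padded to KBatch * K0 * K1 = 192, and the unmerge transforms above expose A as a (KBatch, K0, M, K1) = (2, 12, M, 8) descriptor and B as (2, 12, N, 8), with M and N right-padded only when the corresponding GemmSpec flag requests it.
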
@@ -411,7 +349,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
             }
         }();

-        return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n);
+        using DoPads = Sequence<tensor_operation::device::GemmPadM<GemmSpec>::PadM,
+                                tensor_operation::device::GemmPadN<GemmSpec>::PadN>;
+        return tensor_operation::device::PadTensorDescriptor(
+            c_grid_desc_m_n, make_tuple(MPerBlock, NPerBlock), DoPads{});
     }

     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
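
The DoPads flags come from GemmPadM/GemmPadN traits evaluated on the GemmSpecialization. Their definitions are outside this diff, but the removed if-constexpr chains suggest the mapping; a hedged sketch (GemmPadM_sketch and the local enum are illustrative stand-ins, not the CK definitions):

```cpp
// Hedged sketch of a GemmPadM-style trait: it condenses the removed
// if-constexpr chain into a compile-time flag saying whether the M
// dimension gets padded for a given specialization.
enum class GemmSpecialization { Default, MPadding, NPadding, KPadding,
                                MNPadding, MKPadding, NKPadding, MNKPadding };

template <GemmSpecialization Spec>
struct GemmPadM_sketch
{
    static constexpr bool PadM = (Spec == GemmSpecialization::MPadding ||
                                  Spec == GemmSpecialization::MNPadding ||
                                  Spec == GemmSpecialization::MKPadding ||
                                  Spec == GemmSpecialization::MNKPadding);
};

static_assert(GemmPadM_sketch<GemmSpecialization::MNKPadding>::PadM, "M is padded");
static_assert(!GemmPadM_sketch<GemmSpecialization::NPadding>::PadM, "M untouched");
```
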
@@ -615,10 +556,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
         const FloatAB* p_a_grid = karg.p_a_grid;
         const FloatAB* p_b_grid = karg.p_b_grid;
         FloatC* p_c_grid = karg.p_c_grid;

-        const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
-            karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded);
-        const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
-            karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded);
+        const auto a_b_k0_m_k1_grid_desc =
+            MakeAGridDescriptor_KBatch_K0_M_K1(karg.M, karg.K, karg.StrideA, karg.k_batch, karg.K0);
+        const auto b_b_k0_n_k1_grid_desc =
+            MakeBGridDescriptor_KBatch_K0_N_K1(karg.K, karg.N, karg.StrideB, karg.k_batch, karg.K0);
         const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
         const auto c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
             MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_grid_desc_m_n);