"configs/git@developer.sourcefind.cn:OpenDAS/opencompass.git" did not exist on "b3f5d9e421ba5021f607c67a48a599221148e6a7"
Commit c4f15eb9 authored by Adam Osewski
Browse files

Use of PadTensorDescriptor for grid desc creation.

parent 7d9fff97
...@@ -256,10 +256,7 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut ...@@ -256,10 +256,7 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut
const index_t stride_b = gemm_descs[i].stride_B_; const index_t stride_b = gemm_descs[i].stride_B_;
const index_t stride_c = gemm_descs[i].stride_C_; const index_t stride_c = gemm_descs[i].stride_C_;
const index_t m_padded = GridwiseGemm::CalculateMPadded(M); const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH);
const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH);
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c); const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c);
...@@ -285,9 +282,6 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut ...@@ -285,9 +282,6 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut
stride_a, stride_a,
stride_b, stride_b,
stride_c, stride_c,
m_padded,
n_padded,
k_padded,
k0, k0,
K_BATCH}; K_BATCH};
...@@ -311,8 +305,7 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut ...@@ -311,8 +305,7 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut
auto& karg = gemm_kernel_args_[i].karg_; auto& karg = gemm_kernel_args_[i].karg_;
const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH); const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH);
const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH);
const auto c_grid_desc_m_n = const auto c_grid_desc_m_n =
GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
...@@ -330,7 +323,6 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut ...@@ -330,7 +323,6 @@ struct DeviceGroupedGemmXdlSplitKDirectCWriteOut
auto grouped_block_2_ctile_map = auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
karg.KPadded = k_padded;
karg.K0 = k0; karg.K0 = k0;
karg.k_batch = K_BATCH; karg.k_batch = K_BATCH;
gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
......
...@@ -97,10 +97,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out ...@@ -97,10 +97,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
static constexpr auto K1 = Number<K1Value>{}; static constexpr auto K1 = Number<K1Value>{};
static constexpr auto KPerBlock = K1Value * K0PerBlock; static constexpr auto KPerBlock = K1Value * K0PerBlock;
static constexpr auto gemm_padder =
tensor_operation::device::GemmPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, K1* K0PerBlock};
using ThisThreadBlock = ThisThreadBlock<BlockSize>; using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
...@@ -116,9 +112,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out ...@@ -116,9 +112,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
index_t StrideA; index_t StrideA;
index_t StrideB; index_t StrideB;
index_t StrideC; index_t StrideC;
index_t MPadded;
index_t NPadded;
index_t KPadded;
index_t K0; index_t K0;
index_t k_batch; index_t k_batch;
...@@ -131,9 +124,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out ...@@ -131,9 +124,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
index_t StrideA_, index_t StrideA_,
index_t StrideB_, index_t StrideB_,
index_t StrideC_, index_t StrideC_,
index_t MPadded_,
index_t NPadded_,
index_t KPadded_,
index_t K0_, index_t K0_,
index_t k_batch_) index_t k_batch_)
: p_a_grid(p_a_grid_), : p_a_grid(p_a_grid_),
...@@ -145,9 +135,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out ...@@ -145,9 +135,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
StrideA(StrideA_), StrideA(StrideA_),
StrideB(StrideB_), StrideB(StrideB_),
StrideC(StrideC_), StrideC(StrideC_),
MPadded(MPadded_),
NPadded(NPadded_),
KPadded(KPadded_),
K0(K0_), K0(K0_),
k_batch(k_batch_) k_batch(k_batch_)
{ {
...@@ -162,9 +149,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out ...@@ -162,9 +149,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
<< "SA:" << StrideA << ", " << "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", " << "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", " << "SC:" << StrideC << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KP:" << KPadded << ", "
<< "K0:" << K0 << ", " << "K0:" << K0 << ", "
<< "KB:" << k_batch << "}" << std::endl; << "KB:" << k_batch << "}" << std::endl;
} }
...@@ -300,13 +284,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out ...@@ -300,13 +284,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
} }
} }
__host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(
index_t MPad, index_t M, index_t K, index_t StrideA, index_t KBatch, index_t K0)
index_t K,
index_t StrideA,
index_t KBatch,
index_t K0,
index_t KPad)
{ {
const auto a_grid_desc_m_k = [&]() { const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
...@@ -319,43 +298,20 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out ...@@ -319,43 +298,20 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
} }
}(); }();
const auto a_grid_desc_m_kpad = transform_tensor_descriptor( using DoPads = Sequence<tensor_operation::device::GemmPadM<GemmSpec>::PadM, true>;
a_grid_desc_m_k, const auto a_grid_desc_mpad_kpad = tensor_operation::device::PadTensorDescriptor(
make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), a_grid_desc_m_k, make_tuple(MPerBlock, K0 * K1), DoPads{});
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || return transform_tensor_descriptor(
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || a_grid_desc_mpad_kpad,
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) make_pass_through_transform(a_grid_desc_mpad_kpad.GetLength(I0))),
{ make_tuple(Sequence<1>{}, Sequence<0>{}),
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
make_right_pad_transform(M, MPad - M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
a_grid_desc_m_kpad,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
} }
__host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(
index_t NPad, index_t K, index_t N, index_t StrideB, index_t KBatch, index_t K0)
index_t N,
index_t StrideB,
index_t KBatch,
index_t K0,
index_t KPad)
{ {
const auto b_grid_desc_k_n = [&]() { const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
...@@ -368,34 +324,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out ...@@ -368,34 +324,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
} }
}(); }();
const auto b_grid_desc_kpad_n = transform_tensor_descriptor( using DoPads = Sequence<true, tensor_operation::device::GemmPadN<GemmSpec>::PadN>;
b_grid_desc_k_n, const auto b_grid_desc_kpad_npad = tensor_operation::device::PadTensorDescriptor(
make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), b_grid_desc_k_n, make_tuple(K0 * K1, NPerBlock), DoPads{});
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || return transform_tensor_descriptor(
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || b_grid_desc_kpad_npad,
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) make_pass_through_transform(b_grid_desc_kpad_npad.GetLength(I1))),
{ make_tuple(Sequence<0>{}, Sequence<1>{}),
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
make_right_pad_transform(N, NPad - N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
else
{
return transform_tensor_descriptor(
b_grid_desc_kpad_n,
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
}
} }
__host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
...@@ -411,7 +349,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out ...@@ -411,7 +349,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
} }
}(); }();
return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n); using DoPads = Sequence<tensor_operation::device::GemmPadM<GemmSpec>::PadM,
tensor_operation::device::GemmPadN<GemmSpec>::PadN>;
return tensor_operation::device::PadTensorDescriptor(
c_grid_desc_m_n, make_tuple(MPerBlock, NPerBlock), DoPads{});
} }
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
...@@ -612,13 +553,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out ...@@ -612,13 +553,13 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_splitk_direct_c_write_out
void* __restrict__ p_shared_block, void* __restrict__ p_shared_block,
const Block2CTileMap& block_2_ctile_map) const Block2CTileMap& block_2_ctile_map)
{ {
const FloatAB* p_a_grid = karg.p_a_grid; const FloatAB* p_a_grid = karg.p_a_grid;
const FloatAB* p_b_grid = karg.p_b_grid; const FloatAB* p_b_grid = karg.p_b_grid;
FloatC* p_c_grid = karg.p_c_grid; FloatC* p_c_grid = karg.p_c_grid;
const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1( const auto a_b_k0_m_k1_grid_desc =
karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded); MakeAGridDescriptor_KBatch_K0_M_K1(karg.M, karg.K, karg.StrideA, karg.k_batch, karg.K0);
const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1( const auto b_b_k0_n_k1_grid_desc =
karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded); MakeBGridDescriptor_KBatch_K0_N_K1(karg.K, karg.N, karg.StrideB, karg.k_batch, karg.K0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4 = const auto c_grid_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_grid_desc_m_n); MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_grid_desc_m_n);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment