"git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "a4de8c62f89271f5618d09e99a3193c171514d92"
Commit 510d6464 authored by Jing Zhang

format

parent fc985a76
@@ -112,7 +112,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         // A matrix in LDS memory, dst of blockwise copy
         return make_naive_tensor_descriptor(
             make_tuple(I1, AK0PerBlock, Number<MPerBlock>{}, AK1),
-            make_tuple(AK0PerBlock * Number<MPerBlock + ABlockLdsExtraM>{} * AK1, Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
+            make_tuple(AK0PerBlock * Number<MPerBlock + ABlockLdsExtraM>{} * AK1,
+                       Number<MPerBlock + ABlockLdsExtraM>{} * AK1,
+                       AK1,
+                       I1));
     }
 
     __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1()
@@ -120,7 +123,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         // B matrix in LDS memory, dst of blockwise copy
         return make_naive_tensor_descriptor(
             make_tuple(I1, BK0PerBlock, Number<NPerBlock>{}, BK1),
-            make_tuple(BK0PerBlock * Number<NPerBlock + BBlockLdsExtraN>{} * BK1, Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
+            make_tuple(BK0PerBlock * Number<NPerBlock + BBlockLdsExtraN>{} * BK1,
+                       Number<NPerBlock + BBlockLdsExtraN>{} * BK1,
+                       BK1,
+                       I1));
     }
 
     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
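Note on the two wrapped descriptors above: the strides describe a (KBatch, K0, M, K1) tile in LDS whose M extent is padded by ABlockLdsExtraM / BBlockLdsExtraN; padding of this kind is commonly used to stagger rows and reduce LDS bank conflicts. Below is a minimal standalone sketch of the same stride arithmetic; the tile sizes (K0PerBlock, MPerBlock, K1, LdsExtraM) are hypothetical examples, not this kernel's configuration.

// Standalone sketch (not CK code): reproduces the stride arithmetic of the
// wrapped make_tuple above for a (KBatch, K0, M, K1) LDS tile with extra
// padding on the M dimension.
#include <cstddef>
#include <cstdio>

int main()
{
    constexpr std::size_t K0PerBlock = 4;   // assumed
    constexpr std::size_t MPerBlock  = 128; // assumed
    constexpr std::size_t K1         = 8;   // assumed
    constexpr std::size_t LdsExtraM  = 4;   // assumed padding, analogous to ABlockLdsExtraM

    // Strides in the same order as the descriptor lengths: (KBatch, K0, M, K1).
    constexpr std::size_t stride_kbatch = K0PerBlock * (MPerBlock + LdsExtraM) * K1;
    constexpr std::size_t stride_k0     = (MPerBlock + LdsExtraM) * K1;
    constexpr std::size_t stride_m      = K1;
    constexpr std::size_t stride_k1     = 1;

    // Linear LDS offset of element (kbatch, k0, m, k1).
    auto offset = [&](std::size_t kb, std::size_t k0, std::size_t m, std::size_t k1) {
        return kb * stride_kbatch + k0 * stride_k0 + m * stride_m + k1 * stride_k1;
    };

    // Consecutive m at fixed (kb, k0, k1) stay K1 elements apart; the padding only
    // widens the distance between successive k0 (and kbatch) slices.
    std::printf("k0 slice pitch: %zu\n", offset(0, 1, 0, 0) - offset(0, 0, 0, 0));
    std::printf("m pitch:        %zu\n", offset(0, 0, 1, 0) - offset(0, 0, 0, 0));
    return 0;
}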
@@ -245,10 +251,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
     }
 
     template <typename ALayout, GemmSpecialization GemmSpec>
-    __host__ __device__ static auto MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M,
-                                                                         index_t K,
-                                                                         index_t StrideA,
-                                                                         index_t KBatch)
+    __host__ __device__ static auto
+    MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch)
     {
         const auto a_grid_desc_m_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
@@ -297,10 +301,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
     }
 
     template <typename BLayout, GemmSpecialization GemmSpec>
-    __host__ __device__ static auto MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K,
-                                                                         index_t N,
-                                                                         index_t StrideB,
-                                                                         index_t KBatch)
+    __host__ __device__ static auto
+    MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch)
     {
         const auto b_grid_desc_k_n = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
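The two factory functions above reshape the K dimension of the global A and B tensors into (KBatch, K0, K1) so that each split-K partition reads its own slice. Below is a minimal standalone sketch of one such factorization, assuming the common kbatch-major ordering; the actual index transform is defined by the descriptor code in this header and is not reproduced here.

// Standalone sketch (not CK code): one way K can be factored for split-K,
// K = KBatch * K0 * K1, so that each kbatch partition owns a contiguous
// K0 * K1 chunk. The ordering kbatch -> k0 -> k1 is an assumption for
// illustration only.
#include <cassert>
#include <cstdio>
#include <vector>

int main()
{
    const int K      = 256; // assumed problem size
    const int KBatch = 4;   // assumed number of split-K partitions
    const int K1     = 8;   // assumed packing factor, analogous to AK1
    assert(K % (KBatch * K1) == 0);
    const int K0 = K / (KBatch * K1); // analogous to AK0 in the descriptor

    // Check that (kbatch, k0, k1) -> k covers [0, K) exactly once.
    std::vector<int> hit(K, 0);
    for(int kb = 0; kb < KBatch; ++kb)
        for(int k0 = 0; k0 < K0; ++k0)
            for(int k1 = 0; k1 < K1; ++k1)
                ++hit[(kb * K0 + k0) * K1 + k1];
    for(int k = 0; k < K; ++k)
        assert(hit[k] == 1);

    std::printf("K=%d split as KBatch=%d x K0=%d x K1=%d\n", K, KBatch, K0, K1);
    return 0;
}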
@@ -348,7 +350,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         }
     }
 
-
     // E desc for destination in blockwise copy
     template <typename EGridDesc_M_N>
     __host__ __device__ static constexpr auto
@@ -624,10 +625,12 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         constexpr auto max_lds_align = math::lcm(AK1, BK1);
 
         // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_kbatch_ak0_m_ak1 = GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1();
+        constexpr auto a_block_desc_kbatch_ak0_m_ak1 =
+            GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1();
 
         // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_kbatch_bk0_n_bk1 = GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1();
+        constexpr auto b_block_desc_kbatch_bk0_n_bk1 =
+            GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1();
 
         const index_t kbatch_id = 0;
@@ -693,7 +696,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             make_multi_index(0, 0, 0, 0),
            ck::tensor_operation::element_wise::PassThrough{});
 
-
         // A matrix in LDS memory, dst of blockwise copy
         constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
@@ -745,9 +747,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         const auto gridwise_gemm_pipeline =
             GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
 
-        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
-            (a_grid_desc_kbatch_ak0_m_ak1.GetLength(I0) * a_grid_desc_kbatch_ak0_m_ak1.GetLength(I2)) /
-            KPerBlock);
+        const index_t num_k_block_main_loop =
+            __builtin_amdgcn_readfirstlane((a_grid_desc_kbatch_ak0_m_ak1.GetLength(I0) *
+                                            a_grid_desc_kbatch_ak0_m_ak1.GetLength(I2)) /
+                                           KPerBlock);
 
         gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_kbatch_ak0_m_ak1,
                                                                a_block_desc_kbatch_ak0_m_ak1,
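The rewrapped expression above computes how many KPerBlock-sized chunks the pipeline's main loop walks through, and __builtin_amdgcn_readfirstlane broadcasts that count from the first active lane so it stays wavefront-uniform. Below is a minimal arithmetic sketch under assumed sizes; which GetLength indices are multiplied depends on this descriptor's dimension order and is not restated here.

// Standalone sketch (not CK code): the main-loop trip count is the K extent
// seen by one workgroup divided by KPerBlock. Sizes are assumptions chosen
// so the division is exact.
#include <cassert>
#include <cstdio>

int main()
{
    const int K         = 4096; // assumed total reduction length
    const int KBatch    = 4;    // assumed split-K factor
    const int KPerBlock = 32;   // assumed per-iteration K tile

    assert(K % (KBatch * KPerBlock) == 0);
    const int k_per_split           = K / KBatch;
    const int num_k_block_main_loop = k_per_split / KPerBlock;

    std::printf("num_k_block_main_loop = %d\n", num_k_block_main_loop);
    return 0;
}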
@@ -1030,8 +1033,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
 
         // tensor descriptors for problem definiton
-        //const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
-        //const auto b_grid_desc_n_k = MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
+        // const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
+        // const auto b_grid_desc_n_k = MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
 
         using DsGridDesc_M_N =
             remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
@@ -1047,9 +1050,11 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
 
         // tensor descriptors for block/thread-wise copy
-        const auto a_grid_desc_kbatch_ak0_m_ak1 = MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, 1);
+        const auto a_grid_desc_kbatch_ak0_m_ak1 =
+            MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, 1);
 
-        const auto b_grid_desc_kbatch_bk0_n_bk1 = MakeBGridDescriptor_KBatch_BK0_N_BK1<BLayout, GemmSpec>(K, N, StrideB, 1);
+        const auto b_grid_desc_kbatch_bk0_n_bk1 =
+            MakeBGridDescriptor_KBatch_BK0_N_BK1<BLayout, GemmSpec>(K, N, StrideB, 1);
 
         using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
             MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
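The descriptor types in this last hunk are derived from their factory functions with decltype plus remove_cvref_t rather than being written out by hand. Below is a standalone sketch of that idiom with a hypothetical factory function; std::remove_cvref_t from C++20 stands in for CK's own trait.

// Standalone sketch (not CK code) of the "type from factory via decltype" idiom
// used above, e.g. for DsGridDesc_M_N.
#include <tuple>
#include <type_traits>

// Hypothetical stand-in for a descriptor factory such as MakeDsGridDescriptor_M_N.
template <typename Layout>
constexpr auto MakeGridDescriptor_M_N(int M, int N, int Stride)
{
    return std::make_tuple(M, N, Stride); // placeholder for a real descriptor object
}

struct RowMajor {};

// Mirrors "using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<...>(...))>;"
using GridDesc_M_N = std::remove_cvref_t<decltype(MakeGridDescriptor_M_N<RowMajor>(0, 0, 0))>;

static_assert(std::is_same_v<GridDesc_M_N, std::tuple<int, int, int>>);

int main() { return 0; }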