Commit 510d6464 authored by Jing Zhang
Browse files

format

parent fc985a76
...@@ -112,7 +112,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -112,7 +112,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(I1, AK0PerBlock, Number<MPerBlock>{}, AK1), make_tuple(I1, AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(AK0PerBlock * Number<MPerBlock + ABlockLdsExtraM>{} * AK1, Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1)); make_tuple(AK0PerBlock * Number<MPerBlock + ABlockLdsExtraM>{} * AK1,
Number<MPerBlock + ABlockLdsExtraM>{} * AK1,
AK1,
I1));
} }
__host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1() __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1()
...@@ -120,7 +123,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -120,7 +123,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor( return make_naive_tensor_descriptor(
make_tuple(I1, BK0PerBlock, Number<NPerBlock>{}, BK1), make_tuple(I1, BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(BK0PerBlock * Number<NPerBlock + BBlockLdsExtraN>{} * BK1, Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1)); make_tuple(BK0PerBlock * Number<NPerBlock + BBlockLdsExtraN>{} * BK1,
Number<NPerBlock + BBlockLdsExtraN>{} * BK1,
BK1,
I1));
} }
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
...@@ -245,10 +251,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -245,10 +251,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
} }
template <typename ALayout, GemmSpecialization GemmSpec> template <typename ALayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, __host__ __device__ static auto
index_t K, MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch)
index_t StrideA,
index_t KBatch)
{ {
const auto a_grid_desc_m_k = [&]() { const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
...@@ -297,10 +301,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -297,10 +301,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
} }
template <typename BLayout, GemmSpecialization GemmSpec> template <typename BLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, __host__ __device__ static auto
index_t N, MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch)
index_t StrideB,
index_t KBatch)
{ {
const auto b_grid_desc_k_n = [&]() { const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
...@@ -348,7 +350,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -348,7 +350,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
} }
} }
// E desc for destination in blockwise copy // E desc for destination in blockwise copy
template <typename EGridDesc_M_N> template <typename EGridDesc_M_N>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -624,10 +625,12 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -624,10 +625,12 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
constexpr auto max_lds_align = math::lcm(AK1, BK1); constexpr auto max_lds_align = math::lcm(AK1, BK1);
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_kbatch_ak0_m_ak1 = GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1(); constexpr auto a_block_desc_kbatch_ak0_m_ak1 =
GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, dst of blockwise copy // B matrix in LDS memory, dst of blockwise copy
constexpr auto b_block_desc_kbatch_bk0_n_bk1 = GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1(); constexpr auto b_block_desc_kbatch_bk0_n_bk1 =
GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1();
const index_t kbatch_id = 0; const index_t kbatch_id = 0;
...@@ -693,7 +696,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -693,7 +696,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
make_multi_index(0, 0, 0, 0), make_multi_index(0, 0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{}); ck::tensor_operation::element_wise::PassThrough{});
// A matrix in LDS memory, dst of blockwise copy // A matrix in LDS memory, dst of blockwise copy
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
...@@ -745,8 +747,9 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -745,8 +747,9 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
const auto gridwise_gemm_pipeline = const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>(); GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( const index_t num_k_block_main_loop =
(a_grid_desc_kbatch_ak0_m_ak1.GetLength(I0) * a_grid_desc_kbatch_ak0_m_ak1.GetLength(I2)) / __builtin_amdgcn_readfirstlane((a_grid_desc_kbatch_ak0_m_ak1.GetLength(I0) *
a_grid_desc_kbatch_ak0_m_ak1.GetLength(I2)) /
KPerBlock); KPerBlock);
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_kbatch_ak0_m_ak1, gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_kbatch_ak0_m_ak1,
...@@ -1030,8 +1033,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -1030,8 +1033,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_); const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
// tensor descriptors for problem definition // tensor descriptors for problem definition
//const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA); // const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
//const auto b_grid_desc_n_k = MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB); // const auto b_grid_desc_n_k = MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
using DsGridDesc_M_N = using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>; remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
...@@ -1047,9 +1050,11 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ...@@ -1047,9 +1050,11 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE); const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
const auto a_grid_desc_kbatch_ak0_m_ak1 = MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, 1); const auto a_grid_desc_kbatch_ak0_m_ak1 =
MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, 1);
const auto b_grid_desc_kbatch_bk0_n_bk1 = MakeBGridDescriptor_KBatch_BK0_N_BK1<BLayout, GemmSpec>(K, N, StrideB, 1); const auto b_grid_desc_kbatch_bk0_n_bk1 =
MakeBGridDescriptor_KBatch_BK0_N_BK1<BLayout, GemmSpec>(K, N, StrideB, 1);
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>; MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment