Commit 75ada2a6 authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Remove integer divisions in device function

parent 9acad4f4
...@@ -116,11 +116,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -116,11 +116,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
} }
__device__ static auto __host__ static auto CalculateK0(index_t K) { return math::integer_divide_floor(K, K1Value); }
MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t StrideA)
__host__ static auto CalculateNumKBlockLoop(index_t K)
{ {
const index_t K0 = K / K1; return math::integer_divide_floor(CalculateK0(K), K0PerBlock);
}
__device__ static auto
MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t K0, index_t StrideA)
{
const auto a_grid_desc_m_k = [&]() { const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{ {
...@@ -153,10 +158,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -153,10 +158,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
} }
__device__ static auto __device__ static auto
MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t NPad, index_t StrideB) MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t NPad, index_t K0, index_t StrideB)
{ {
const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() { const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value) if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{ {
...@@ -243,7 +246,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -243,7 +246,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
StrideB{StrideB_}, StrideB{StrideB_},
StrideC{StrideC_}, StrideC{StrideC_},
MPadded{CalculateMPadded(M_)}, MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)} NPadded{CalculateNPadded(N_)},
K0{CalculateK0(K)},
NumKBlockLoop{CalculateNumKBlockLoop(K)}
{ {
} }
...@@ -257,7 +262,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -257,7 +262,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<< "SB:" << StrideB << ", " << "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", " << "SC:" << StrideC << ", "
<< "MP:" << MPadded << ", " << "MP:" << MPadded << ", "
<< "NP:" << NPadded << "}" << std::endl; << "NP:" << NPadded << ", "
<< "K0:" << K0 << ", "
<< "NumKBlockLoop: " << NumKBlockLoop << "}" << std::endl;
} }
const FloatAB* p_a_grid; const FloatAB* p_a_grid;
...@@ -271,6 +278,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -271,6 +278,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
index_t StrideC; index_t StrideC;
index_t MPadded; index_t MPadded;
index_t NPadded; index_t NPadded;
index_t K0;
index_t NumKBlockLoop;
}; };
using GridwiseGemmPipe = remove_cvref_t<decltype( using GridwiseGemmPipe = remove_cvref_t<decltype(
...@@ -349,8 +358,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -349,8 +358,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Argument> __host__ static constexpr bool CheckValidity(const Argument& karg)
__host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
{ {
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value, static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time"); "wrong! K1 need to be known at compile-time");
...@@ -424,7 +432,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -424,7 +432,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return true; return true;
} }
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{ {
const index_t num_loop = K / (K0PerBlock * K1); const index_t num_loop = K / (K0PerBlock * K1);
...@@ -485,7 +493,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -485,7 +493,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>; using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
template <bool HasMainKBlockLoop, typename Argument> template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* p_a_grid, __device__ static void Run(const FloatAB* p_a_grid,
const FloatAB* p_b_grid, const FloatAB* p_b_grid,
FloatC* p_c_grid, FloatC* p_c_grid,
...@@ -498,9 +506,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -498,9 +506,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
const auto a_grid_desc_k0_m_k1 = const auto a_grid_desc_k0_m_k1 =
MakeAGridDescriptor_K0_M_K1(karg.M, karg.MPadded, karg.K, karg.StrideA); MakeAGridDescriptor_K0_M_K1(karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA);
const auto b_grid_desc_k0_n_k1 = const auto b_grid_desc_k0_n_k1 =
MakeBGridDescriptor_K0_N_K1(karg.K, karg.N, karg.NPadded, karg.StrideB); MakeBGridDescriptor_K0_N_K1(karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB);
const auto c_grid_desc_m_n = const auto c_grid_desc_m_n =
MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC); MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC);
...@@ -518,8 +526,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -518,8 +526,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const BElementwiseOperation b_element_op{}; const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{}; const CElementwiseOperation c_element_op{};
const index_t K0 = karg.K / K1;
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N}; const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N};
// divide block work by [M, N] // divide block work by [M, N]
...@@ -649,7 +655,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -649,7 +655,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
// gridwise GEMM pipeline // gridwise GEMM pipeline
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(karg.NumKBlockLoop);
long loop_start = 0, loop_end = 0; long loop_start = 0, loop_end = 0;
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1, GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
...@@ -666,7 +672,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -666,7 +672,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
b_block_slice_copy_step, b_block_slice_copy_step,
blockwise_gemm, blockwise_gemm,
c_thread_buf, c_thread_buf,
num_k_block_main_loop, loop_start, loop_end); num_k_block_main_loop,
loop_start,
loop_end);
// output: register to global memory // output: register to global memory
{ {
...@@ -750,9 +758,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -750,9 +758,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
asm volatile("; [POYENC] kernel end" ::); asm volatile("; [POYENC] kernel end" ::);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
if (blockIdx.x == 0 && threadIdx.x == 0) { if(blockIdx.x == 0 && threadIdx.x == 0)
printf("[POYENC] prolog: %ld, hot-loop: %ld, epilog: %ld\n", {
loop_start - kernel_start, loop_end - loop_start, kernel_end - loop_end); printf("[POYENC] prolog: %ld, hot-loop: %ld, epilog: %ld\n",
loop_start - kernel_start,
loop_end - loop_start,
kernel_end - loop_end);
} }
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment