Commit 75ada2a6 authored by Po-Yen, Chen
Browse files

Remove integer divisions in device function

parent 9acad4f4
......@@ -116,11 +116,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
}
__device__ static auto
MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t StrideA)
// Host-only helper that derives K0 from the K extent by calling
// math::integer_divide_floor(K, K1Value). Elsewhere in this struct the result
// is stored in the kernel argument (K0{CalculateK0(K)}) and handed to the
// device-side descriptor builders as a plain parameter.
__host__ static auto CalculateK0(index_t K)
{
    const auto k0 = math::integer_divide_floor(K, K1Value);
    return k0;
}
// Host-only helper that returns the number of K-block loop iterations for a
// given K extent, computed as
//   math::integer_divide_floor(CalculateK0(K), K0PerBlock).
// K0 is obtained from CalculateK0 rather than recomputed here, so this function
// carries no integer division of its own beyond the two helper calls.
// NOTE(review): presumably K is expected to be a multiple of K1 * K0PerBlock so
// that the floor division is exact -- confirm against CheckValidity.
__host__ static auto CalculateNumKBlockLoop(index_t K)
{
    return math::integer_divide_floor(CalculateK0(K), K0PerBlock);
}
__device__ static auto
MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t K0, index_t StrideA)
{
const auto a_grid_desc_m_k = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
......@@ -153,10 +158,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
__device__ static auto
MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t NPad, index_t StrideB)
MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t NPad, index_t K0, index_t StrideB)
{
const index_t K0 = K / K1;
const auto b_grid_desc_k_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
......@@ -243,7 +246,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
StrideB{StrideB_},
StrideC{StrideC_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)}
NPadded{CalculateNPadded(N_)},
K0{CalculateK0(K)},
NumKBlockLoop{CalculateNumKBlockLoop(K)}
{
}
......@@ -257,7 +262,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << "}" << std::endl;
<< "NP:" << NPadded << ", "
<< "K0:" << K0 << ", "
<< "NumKBlockLoop: " << NumKBlockLoop << "}" << std::endl;
}
const FloatAB* p_a_grid;
......@@ -271,6 +278,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
index_t StrideC;
index_t MPadded;
index_t NPadded;
index_t K0;
index_t NumKBlockLoop;
};
using GridwiseGemmPipe = remove_cvref_t<decltype(
......@@ -349,8 +358,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Argument>
__host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
__host__ static constexpr bool CheckValidity(const Argument& karg)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
......@@ -424,7 +432,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / (K0PerBlock * K1);
......@@ -485,7 +493,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// return block_id to C matrix tile idx (m0, n0) mapping
using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
template <bool HasMainKBlockLoop, typename Argument>
template <bool HasMainKBlockLoop>
__device__ static void Run(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
......@@ -498,9 +506,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__builtin_amdgcn_sched_barrier(0);
const auto a_grid_desc_k0_m_k1 =
MakeAGridDescriptor_K0_M_K1(karg.M, karg.MPadded, karg.K, karg.StrideA);
MakeAGridDescriptor_K0_M_K1(karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA);
const auto b_grid_desc_k0_n_k1 =
MakeBGridDescriptor_K0_N_K1(karg.K, karg.N, karg.NPadded, karg.StrideB);
MakeBGridDescriptor_K0_N_K1(karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB);
const auto c_grid_desc_m_n =
MakeCGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC);
......@@ -518,8 +526,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
const index_t K0 = karg.K / K1;
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N};
// divide block work by [M, N]
......@@ -649,7 +655,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
// gridwise GEMM pipeline
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(karg.NumKBlockLoop);
long loop_start = 0, loop_end = 0;
GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
......@@ -666,7 +672,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop, loop_start, loop_end);
num_k_block_main_loop,
loop_start,
loop_end);
// output: register to global memory
{
......@@ -750,9 +758,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
asm volatile("; [POYENC] kernel end" ::);
__builtin_amdgcn_sched_barrier(0);
if (blockIdx.x == 0 && threadIdx.x == 0) {
printf("[POYENC] prolog: %ld, hot-loop: %ld, epilog: %ld\n",
loop_start - kernel_start, loop_end - loop_start, kernel_end - loop_end);
if(blockIdx.x == 0 && threadIdx.x == 0)
{
printf("[POYENC] prolog: %ld, hot-loop: %ld, epilog: %ld\n",
loop_start - kernel_start,
loop_end - loop_start,
kernel_end - loop_end);
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment