Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
75ada2a6
Commit
75ada2a6
authored
May 10, 2023
by
Po-Yen, Chen
Browse files
Remove integer divisions in device function
parent
9acad4f4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
21 deletions
+32
-21
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+32
-21
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
75ada2a6
...
...
@@ -116,11 +116,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return
math
::
integer_divide_ceil
(
N
,
NPerBlock
)
*
NPerBlock
;
}
__device__
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
M
,
index_t
MPad
,
index_t
K
,
index_t
StrideA
)
__host__
static
auto
CalculateK0
(
index_t
K
)
{
return
math
::
integer_divide_floor
(
K
,
K1Value
);
}
__host__
static
auto
CalculateNumKBlockLoop
(
index_t
K
)
{
const
index_t
K0
=
K
/
K1
;
return
math
::
integer_divide_floor
(
CalculateK0
(
K
),
K0PerBlock
);
}
__device__
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
M
,
index_t
MPad
,
index_t
K
,
index_t
K0
,
index_t
StrideA
)
{
const
auto
a_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
...
...
@@ -153,10 +158,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
__device__
static
auto
MakeBGridDescriptor_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
NPad
,
index_t
StrideB
)
MakeBGridDescriptor_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
NPad
,
index_t
K0
,
index_t
StrideB
)
{
const
index_t
K0
=
K
/
K1
;
const
auto
b_grid_desc_k_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
...
...
@@ -243,7 +246,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
StrideB
{
StrideB_
},
StrideC
{
StrideC_
},
MPadded
{
CalculateMPadded
(
M_
)},
NPadded
{
CalculateNPadded
(
N_
)}
NPadded
{
CalculateNPadded
(
N_
)},
K0
{
CalculateK0
(
K
)},
NumKBlockLoop
{
CalculateNumKBlockLoop
(
K
)}
{
}
...
...
@@ -257,7 +262,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<<
"SB:"
<<
StrideB
<<
", "
<<
"SC:"
<<
StrideC
<<
", "
<<
"MP:"
<<
MPadded
<<
", "
<<
"NP:"
<<
NPadded
<<
"}"
<<
std
::
endl
;
<<
"NP:"
<<
NPadded
<<
", "
<<
"K0:"
<<
K0
<<
", "
<<
"NumKBlockLoop: "
<<
NumKBlockLoop
<<
"}"
<<
std
::
endl
;
}
const
FloatAB
*
p_a_grid
;
...
...
@@ -271,6 +278,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
index_t
StrideC
;
index_t
MPadded
;
index_t
NPadded
;
index_t
K0
;
index_t
NumKBlockLoop
;
};
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
...
...
@@ -349,8 +358,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Argument
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
__host__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
...
...
@@ -424,7 +432,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return
true
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
__host__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
...
...
@@ -485,7 +493,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// return block_id to C matrix tile idx (m0, n0) mapping
using
Block2CTileMap
=
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
>
;
template
<
bool
HasMainKBlockLoop
,
typename
Argument
>
template
<
bool
HasMainKBlockLoop
>
__device__
static
void
Run
(
const
FloatAB
*
p_a_grid
,
const
FloatAB
*
p_b_grid
,
FloatC
*
p_c_grid
,
...
...
@@ -498,9 +506,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__builtin_amdgcn_sched_barrier
(
0
);
const
auto
a_grid_desc_k0_m_k1
=
MakeAGridDescriptor_K0_M_K1
(
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
StrideA
);
MakeAGridDescriptor_K0_M_K1
(
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
K0
,
karg
.
StrideA
);
const
auto
b_grid_desc_k0_n_k1
=
MakeBGridDescriptor_K0_N_K1
(
karg
.
K
,
karg
.
N
,
karg
.
NPadded
,
karg
.
StrideB
);
MakeBGridDescriptor_K0_N_K1
(
karg
.
K
,
karg
.
N
,
karg
.
NPadded
,
karg
.
K0
,
karg
.
StrideB
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
MPadded
,
karg
.
N
,
karg
.
NPadded
,
karg
.
StrideC
);
...
...
@@ -518,8 +526,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const
BElementwiseOperation
b_element_op
{};
const
CElementwiseOperation
c_element_op
{};
const
index_t
K0
=
karg
.
K
/
K1
;
const
auto
block_2_ctile_map
=
Block2CTileMap
{
karg
.
M
,
karg
.
N
};
// divide block work by [M, N]
...
...
@@ -649,7 +655,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
);
// gridwise GEMM pipeline
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
K0
/
K0Per
Block
);
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
karg
.
NumK
Block
Loop
);
long
loop_start
=
0
,
loop_end
=
0
;
GridwiseGemmPipe
::
template
Run
<
HasMainKBlockLoop
>(
a_grid_desc_k0_m_k1
,
...
...
@@ -666,7 +672,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
b_block_slice_copy_step
,
blockwise_gemm
,
c_thread_buf
,
num_k_block_main_loop
,
loop_start
,
loop_end
);
num_k_block_main_loop
,
loop_start
,
loop_end
);
// output: register to global memory
{
...
...
@@ -750,9 +758,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
asm
volatile
(
"; [POYENC] kernel end"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
if
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
)
{
printf
(
"[POYENC] prolog: %ld, hot-loop: %ld, epilog: %ld
\n
"
,
loop_start
-
kernel_start
,
loop_end
-
loop_start
,
kernel_end
-
loop_end
);
if
(
blockIdx
.
x
==
0
&&
threadIdx
.
x
==
0
)
{
printf
(
"[POYENC] prolog: %ld, hot-loop: %ld, epilog: %ld
\n
"
,
loop_start
-
kernel_start
,
loop_end
-
loop_start
,
kernel_end
-
loop_end
);
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment