Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
41fc6a24
Commit
41fc6a24
authored
Oct 09, 2024
by
Adam Osewski
Browse files
Few small changes & formatting.
parent
7ffb0921
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
22 deletions
+17
-22
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+11
-7
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
...ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
+1
-1
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
.../ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
+5
-14
No files found.
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
41fc6a24
...
...
@@ -26,7 +26,7 @@ struct GemmKernel
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using
CDataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
C
DataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
O
DataType
>
;
__host__
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
KBatch
)
{
...
...
@@ -118,8 +118,10 @@ struct GemmKernel
auto
a_pad_view
=
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadA
?
true
:
false
>
{});
// somehow clang-format is splitting below line into multiple.
// clang-format off
sequence
<
false
,
GemmPipeline
::
kPadA
?
true
:
false
>
{});
// clang-format on
auto
a_block_window
=
make_tile_window
(
a_pad_view
,
...
...
@@ -129,8 +131,9 @@ struct GemmKernel
auto
b_pad_view
=
pad_tensor_view
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadB
?
true
:
false
>
{});
// clang-format off
sequence
<
false
,
GemmPipeline
::
kPadB
?
true
:
false
>
{});
// clang-format on
auto
b_block_window
=
make_tile_window
(
b_pad_view
,
...
...
@@ -171,8 +174,9 @@ struct GemmKernel
auto
c_pad_view
=
pad_tensor_view
(
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadC
?
true
:
false
>
{});
// clang-format off
sequence
<
false
,
GemmPipeline
::
kPadC
?
true
:
false
>
{});
// clang-format on
auto
CBlockWindow
=
make_tile_window
(
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
...
...
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp
View file @
41fc6a24
...
...
@@ -51,7 +51,7 @@ template <typename ADataType_,
bool
kPadB_
=
false
,
bool
kPadC_
=
false
,
GemmPipelineScheduler
Scheduler_
=
GemmPipelineScheduler
::
Intrawave
,
bool
HasHotLoop_
=
fals
e
,
bool
HasHotLoop_
=
tru
e
,
TailNumber
TailNum_
=
TailNumber
::
Full
>
struct
UniversalGemmPipelineProblem
{
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
View file @
41fc6a24
...
...
@@ -27,6 +27,7 @@ struct BaseGemmPipelineAgBgCrMem
static
constexpr
index_t
WgpPerCU
=
(
4
*
get_warp_size
()
/
BlockSize
)
>=
1
?
4
*
get_warp_size
()
/
BlockSize
:
1
;
// TODO: Is this 32K value gfx9 arch specific?
static
constexpr
index_t
FullMemBandPrefetchStages
=
integer_divide_ceil
(
32768
/
WgpPerCU
,
(
MPerBlock
*
sizeof
(
ADataType
)
+
NPerBlock
*
sizeof
(
BDataType
))
*
KPerBlock
);
...
...
@@ -206,6 +207,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
integer_divide_ceil
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
)
*
16
;
// B tile in LDS
BDataType
*
p_b_lds
=
static_cast
<
BDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
));
...
...
@@ -251,20 +253,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
// Block GEMM
constexpr
auto
block_gemm
=
BlockGemm
();
auto
c_block_tile
=
block_gemm
.
MakeCBlockTile
();
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
if
(
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
threadIdx
.
x
==
0
)
{
printf
(
"Pipeline >>> bid: %d, tid: %d: HasHotLoop: %s, TailNumber: %d,"
" PrefetchStages: %d, num_loop: %d
\n
"
,
blockIdx
.
x
,
threadIdx
.
x
,
(
HasHotLoop
?
"True"
:
"False"
),
static_cast
<
index_t
>
(
TailNum
),
PrefetchStages
,
num_loop
);
}
using
ABlockTileDistr
=
decltype
(
a_copy_dram_window
.
get_tile_distribution
());
using
BBlockTileDistr
=
decltype
(
b_copy_dram_window
.
get_tile_distribution
());
...
...
@@ -277,6 +265,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
tuple_array
<
ABlockTile
,
PrefetchStages
>
a_block_tiles
;
tuple_array
<
BBlockTile
,
PrefetchStages
>
b_block_tiles
;
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
GlobalPrefetch
(
a_block_tiles
.
get
(
I0
{}),
a_copy_dram_window
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment