Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1b2bf88d
Commit
1b2bf88d
authored
Oct 10, 2024
by
Adam Osewski
Browse files
Refactoring and review comment.s
parent
ba676917
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
27 additions
and
26 deletions
+27
-26
example/ck_tile/03_gemm/CMakeLists.txt
example/ck_tile/03_gemm/CMakeLists.txt
+0
-1
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
...ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
+15
-15
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
.../ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
+3
-2
test/ck_tile/CMakeLists.txt
test/ck_tile/CMakeLists.txt
+1
-0
test/ck_tile/gemm/test_gemm_mem_pipeline.cpp
test/ck_tile/gemm/test_gemm_mem_pipeline.cpp
+4
-4
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
+4
-4
No files found.
example/ck_tile/03_gemm/CMakeLists.txt
View file @
1b2bf88d
# set(CMAKE_BUILD_TYPE Debug)
add_executable
(
tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp
)
add_executable
(
tile_example_gemm_basic_mem_pipeline EXCLUDE_FROM_ALL gemm_basic_mem_pipeline.cpp
)
include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp
View file @
1b2bf88d
...
...
@@ -24,19 +24,19 @@ struct BlockGemmASmemBSmemCRegV1
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
// C += A * B
template
<
typename
CBlockTensor
,
typename
ABlockWindow
Tmp
,
typename
BBlockWindow
Tmp
>
template
<
typename
CBlockTensor
,
typename
ABlockWindow
,
typename
BBlockWindow
>
CK_TILE_DEVICE
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ABlockWindow
Tmp
&
a_block_window
_tmp
,
const
BBlockWindow
Tmp
&
b_block_window
_tmp
)
const
const
ABlockWindow
&
a_block_window
,
const
BBlockWindow
&
b_block_window
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
typename
ABlockWindow
Tmp
::
DataType
>
&&
std
::
is_same_v
<
BDataType
,
typename
BBlockWindow
Tmp
::
DataType
>
&&
static_assert
(
std
::
is_same_v
<
ADataType
,
typename
ABlockWindow
::
DataType
>
&&
std
::
is_same_v
<
BDataType
,
typename
BBlockWindow
::
DataType
>
&&
std
::
is_same_v
<
AccDataType
,
typename
CBlockTensor
::
DataType
>
,
"wrong!"
);
constexpr
index_t
MPerBlock
=
ABlockWindow
Tmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockWindow
Tmp
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockWindow
Tmp
{}.
get_window_lengths
()[
number
<
1
>
{}];
constexpr
index_t
MPerBlock
=
ABlockWindow
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockWindow
{}.
get_window_lengths
()[
number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockWindow
{}.
get_window_lengths
()[
number
<
1
>
{}];
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
KPerBlock
==
BlockGemmShape
::
kK
,
...
...
@@ -62,9 +62,9 @@ struct BlockGemmASmemBSmemCRegV1
// construct A-warp-window
auto
a_warp_window_tmp
=
make_tile_window
(
a_block_window
_tmp
.
get_bottom_tensor_view
(),
a_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WG
::
kM
>
{},
number
<
WG
::
kK
>
{}),
a_block_window
_tmp
.
get_window_origin
()
+
multi_index
<
2
>
{
iMWarp
*
WG
::
kM
,
0
},
a_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iMWarp
*
WG
::
kM
,
0
},
make_static_tile_distribution
(
typename
WG
::
AWarpDstrEncoding
{}));
#if 0 // FIXME: using array will cause register spill
...
...
@@ -97,9 +97,9 @@ struct BlockGemmASmemBSmemCRegV1
// construct B-warp-window
auto
b_warp_window_tmp
=
make_tile_window
(
b_block_window
_tmp
.
get_bottom_tensor_view
(),
b_block_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
WG
::
kN
>
{},
number
<
WG
::
kK
>
{}),
b_block_window
_tmp
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
WG
::
kN
,
0
},
b_block_window
.
get_window_origin
()
+
multi_index
<
2
>
{
iNWarp
*
WG
::
kN
,
0
},
make_static_tile_distribution
(
typename
WG
::
BWarpDstrEncoding
{}));
#if 0 // FIXME: using array will cause register spill
...
...
@@ -200,12 +200,12 @@ struct BlockGemmASmemBSmemCRegV1
}
// C = A * B
template
<
typename
ABlockTensorTmp
,
typename
BBlockWindow
Tmp
>
template
<
typename
ABlockTensorTmp
,
typename
BBlockWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
ABlockTensorTmp
&
a_block_tensor_tmp
,
const
BBlockWindow
Tmp
&
b_block_window
_tmp
)
const
const
BBlockWindow
&
b_block_window
)
const
{
auto
c_block_tensor
=
MakeCBlockTile
();
operator
()(
c_block_tensor
,
a_block_tensor_tmp
,
b_block_window
_tmp
);
operator
()(
c_block_tensor
,
a_block_tensor_tmp
,
b_block_window
);
return
c_block_tensor
;
}
};
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
View file @
1b2bf88d
...
...
@@ -12,7 +12,6 @@ namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template
<
typename
Problem
>
struct
BaseGemmPipelineAgBgCrMem
{
...
...
@@ -25,11 +24,13 @@ struct BaseGemmPipelineAgBgCrMem
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
MinMemInFlyBytes
=
32768
;
static
constexpr
index_t
WgpPerCU
=
(
4
*
get_warp_size
()
/
BlockSize
)
>=
1
?
4
*
get_warp_size
()
/
BlockSize
:
1
;
// TODO: Is this 32K value gfx9 arch specific?
static
constexpr
index_t
FullMemBandPrefetchStages
=
integer_divide_ceil
(
32768
/
WgpPerCU
,
MinMemInFlyBytes
/
WgpPerCU
,
(
MPerBlock
*
sizeof
(
ADataType
)
+
NPerBlock
*
sizeof
(
BDataType
))
*
KPerBlock
);
static
constexpr
index_t
PrefetchStages
=
FullMemBandPrefetchStages
>=
2
...
...
test/ck_tile/CMakeLists.txt
View file @
1b2bf88d
add_subdirectory
(
image_to_column
)
add_subdirectory
(
gemm
)
test/ck_tile/gemm/test_gemm_mem_pipeline.cpp
View file @
1b2bf88d
...
...
@@ -17,11 +17,11 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F
32
>
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F
16
>
// TODO: fixme!
// std::tuple< Col, Row, Row, F16, F16, F32, F
32
>,
// std::tuple< Row, Row, Row, F16, F16, F32, F
32
>,
// std::tuple< Col, Col, Row, F16, F16, F32, F
32
>
// std::tuple< Col, Row, Row, F16, F16, F32, F
16
>,
// std::tuple< Row, Row, Row, F16, F16, F32, F
16
>,
// std::tuple< Col, Col, Row, F16, F16, F32, F
16
>
>
;
// clang-format on
...
...
test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp
View file @
1b2bf88d
...
...
@@ -14,10 +14,9 @@ template <typename Tuple>
class
TestCkTileGemmMemPipeline
:
public
::
testing
::
Test
{
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
;
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
CLayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
...
...
@@ -90,6 +89,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using
GemmPipeline
=
ck_tile
::
GemmPipelineAgBgCrMem
<
ck_tile
::
UniversalGemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
GemmShape
,
ALayout
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment