Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d1715c0f
"magic_pdf/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "92f9df5dd7f1d7cbcd089406e297a59a1c29b4f9"
Commit
d1715c0f
authored
Feb 05, 2025
by
ThomasNing
Browse files
Fix the gtest compilation error
parent
987cc54d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
49 additions
and
16 deletions
+49
-16
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
+3
-1
test/ck_tile/gemm/test_gemm_pipeline.cpp
test/ck_tile/gemm/test_gemm_pipeline.cpp
+6
-5
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+40
-10
No files found.
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
View file @
d1715c0f
...
...
@@ -42,7 +42,9 @@ struct BaseGemmPipelineAgBgCrCompV4
// the ping-pong buffer to grab memory from the global memory. While one LDS is grabbing the data
// from global memory, the other will call the warps on running the MFMA matrix multiplication. When
// the matrix is in bigger shape, it will keep the Warp always busy and cover the memory loading
// time.
// time. It will have better performance comparing to the Compute Version 3 when they have the same
// block tile and better performance when you have M, N, K all > 8K even when the compute V3 block
// size is 2 times of the compute V4.
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
>
struct
GemmPipelineAgBgCrCompV4
:
public
BaseGemmPipelineAgBgCrCompV4
<
Problem
>
{
...
...
test/ck_tile/gemm/test_gemm_pipeline.cpp
View file @
d1715c0f
...
...
@@ -17,7 +17,8 @@ using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
// using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
// ck_tile::GemmPipelineScheduler::Interwave>;
// using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using
Comp
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
Comp
>
;
// using CompV4 = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::CompV4>;
using
CompV3
=
ck_tile
::
integral_constant
<
GemmPipelineType
,
GemmPipelineType
::
CompV3
>
;
// TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors.
...
...
@@ -25,16 +26,16 @@ using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp
using
KernelTypes
=
::
testing
::
Types
<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
// std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
V3
>
,
// std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
// std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
V3
>
,
// std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
// std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
,
std
::
tuple
<
Col
,
Row
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
V3
>
,
// std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
// std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
>
std
::
tuple
<
Col
,
Col
,
Row
,
F16
,
F16
,
F32
,
F16
,
Intrawave
,
Comp
V3
>
// std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
>
;
// clang-format on
...
...
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
View file @
d1715c0f
...
...
@@ -14,7 +14,8 @@
enum
struct
GemmPipelineType
{
Mem
,
Comp
CompV3
,
CompV4
};
template
<
typename
Tuple
>
...
...
@@ -52,6 +53,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr
bool
kPadN
=
PadN
;
constexpr
bool
kPadK
=
PadK
;
constexpr
bool
DoubleSmemBuffer
=
(
PipelineType
==
GemmPipelineType
::
CompV4
)
?
true
:
false
;
// TODO: For now - but this should also be a test parameter
constexpr
bool
TransposeC
=
false
;
...
...
@@ -69,16 +72,24 @@ class TestCkTileGemmPipeline : public ::testing::Test
GemmSpatiallyLocalTilePartitioner
<
GemmShape
,
TileParitionerGroupNum
,
TileParitionerM01
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
kPadM
,
kPadN
,
kPadK
,
DoubleSmemBuffer
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
using
GemmPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>
;
using
BaseGemmPipeline
=
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
Mem
,
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
GemmPipelineProblem
>
,
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
GemmPipelineProblem
>>
;
using
BaseGemmPipeline
=
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
Mem
,
ck_tile
::
BaseGemmPipelineAgBgCrMem
<
GemmPipelineProblem
>
,
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
CompV3
,
ck_tile
::
BaseGemmPipelineAgBgCrCompV3
<
GemmPipelineProblem
>
,
ck_tile
::
BaseGemmPipelineAgBgCrCompV4
<
GemmPipelineProblem
>>>
;
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
K_Tile
;
...
...
@@ -103,8 +114,11 @@ class TestCkTileGemmPipeline : public ::testing::Test
PipelineType
==
GemmPipelineType
::
Mem
,
ck_tile
::
GemmPipelineAgBgCrMem
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
,
ck_tile
::
GemmPipelineAgBgCrCompV3
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>>
;
std
::
conditional_t
<
PipelineType
==
GemmPipelineType
::
CompV3
,
ck_tile
::
GemmPipelineAgBgCrCompV3
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
,
ck_tile
::
GemmPipelineAgBgCrCompV4
<
UniversalGemmProblem
>>>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
...
...
@@ -145,7 +159,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
if
(
has_hot_loop
)
{
if
constexpr
(
PipelineType
==
GemmPipelineType
::
Comp
)
if
constexpr
(
PipelineType
==
GemmPipelineType
::
Comp
V3
)
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Full
)
{
...
...
@@ -235,6 +249,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
}
}
}
if
constexpr
(
PipelineType
==
GemmPipelineType
::
CompV4
)
{
if
(
tail_num
==
ck_tile
::
TailNumber
::
Three
)
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Three
>
{});
}
else
{
Run
(
ck_tile
::
bool_constant
<
true
>
{},
ck_tile
::
integral_constant
<
ck_tile
::
TailNumber
,
ck_tile
::
TailNumber
::
Two
>
{});
}
}
}
else
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment