Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0c5f9438
Commit
0c5f9438
authored
May 22, 2023
by
Adam Osewski
Browse files
Constepxr everything!
parent
01f0831b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
115 additions
and
115 deletions
+115
-115
test/gemm_split_k/test_gemm_splitk_ut_cases.inc
test/gemm_split_k/test_gemm_splitk_ut_cases.inc
+72
-72
test/gemm_split_k/test_gemm_splitk_util.hpp
test/gemm_split_k/test_gemm_splitk_util.hpp
+4
-4
test/grouped_gemm/test_grouped_gemm_interface.cpp
test/grouped_gemm/test_grouped_gemm_interface.cpp
+11
-11
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
+24
-24
test/grouped_gemm/test_grouped_gemm_util.hpp
test/grouped_gemm/test_grouped_gemm_util.hpp
+4
-4
No files found.
test/gemm_split_k/test_gemm_splitk_ut_cases.inc
View file @
0c5f9438
...
@@ -3,12 +3,12 @@
...
@@ -3,12 +3,12 @@
TYPED_TEST
(
TestGemmSplitK_MK_KN
,
SmallM
)
TYPED_TEST
(
TestGemmSplitK_MK_KN
,
SmallM
)
{
{
std
::
vector
<
int
>
Ms
{
0
,
1
,
2
,
3
,
4
,
5
,
6
};
std
::
vector
<
int
>
Ms
{
0
,
1
,
2
,
3
,
4
,
5
,
6
};
int
N
=
512
;
constexpr
int
N
=
512
;
int
K
=
320
;
constexpr
int
K
=
320
;
int
StrideA
=
K
;
constexpr
int
StrideA
=
K
;
int
StrideB
=
N
;
constexpr
int
StrideB
=
N
;
int
StrideC
=
N
;
constexpr
int
StrideC
=
N
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
);
this
->
Run
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
);
...
@@ -17,12 +17,12 @@ TYPED_TEST(TestGemmSplitK_MK_KN, SmallM)
...
@@ -17,12 +17,12 @@ TYPED_TEST(TestGemmSplitK_MK_KN, SmallM)
TYPED_TEST
(
TestGemmSplitK_MK_NK
,
SmallM
)
TYPED_TEST
(
TestGemmSplitK_MK_NK
,
SmallM
)
{
{
std
::
vector
<
int
>
Ms
{
0
,
1
,
2
,
3
,
4
,
5
,
6
};
std
::
vector
<
int
>
Ms
{
0
,
1
,
2
,
3
,
4
,
5
,
6
};
int
N
=
512
;
constexpr
int
N
=
512
;
int
K
=
320
;
constexpr
int
K
=
320
;
int
StrideA
=
K
;
constexpr
int
StrideA
=
K
;
int
StrideB
=
K
;
constexpr
int
StrideB
=
K
;
int
StrideC
=
N
;
constexpr
int
StrideC
=
N
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
);
this
->
Run
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
);
...
@@ -31,11 +31,11 @@ TYPED_TEST(TestGemmSplitK_MK_NK, SmallM)
...
@@ -31,11 +31,11 @@ TYPED_TEST(TestGemmSplitK_MK_NK, SmallM)
TYPED_TEST
(
TestGemmSplitK_KM_KN
,
SmallM
)
TYPED_TEST
(
TestGemmSplitK_KM_KN
,
SmallM
)
{
{
std
::
vector
<
int
>
Ms
{
0
,
1
,
2
,
3
,
4
,
5
,
6
};
std
::
vector
<
int
>
Ms
{
0
,
1
,
2
,
3
,
4
,
5
,
6
};
int
N
=
512
;
constexpr
int
N
=
512
;
int
K
=
320
;
constexpr
int
K
=
320
;
int
StrideB
=
N
;
constexpr
int
StrideB
=
N
;
int
StrideC
=
N
;
constexpr
int
StrideC
=
N
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
,
M
,
StrideB
,
StrideC
);
this
->
Run
(
M
,
N
,
K
,
M
,
StrideB
,
StrideC
);
...
@@ -44,11 +44,11 @@ TYPED_TEST(TestGemmSplitK_KM_KN, SmallM)
...
@@ -44,11 +44,11 @@ TYPED_TEST(TestGemmSplitK_KM_KN, SmallM)
TYPED_TEST
(
TestGemmSplitK_KM_NK
,
SmallM
)
TYPED_TEST
(
TestGemmSplitK_KM_NK
,
SmallM
)
{
{
std
::
vector
<
int
>
Ms
{
0
,
1
,
2
,
3
,
4
,
5
,
6
};
std
::
vector
<
int
>
Ms
{
0
,
1
,
2
,
3
,
4
,
5
,
6
};
int
N
=
512
;
constexpr
int
N
=
512
;
int
K
=
320
;
constexpr
int
K
=
320
;
int
StrideB
=
K
;
constexpr
int
StrideB
=
K
;
int
StrideC
=
N
;
constexpr
int
StrideC
=
N
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
,
M
,
StrideB
,
StrideC
);
this
->
Run
(
M
,
N
,
K
,
M
,
StrideB
,
StrideC
);
...
@@ -57,12 +57,12 @@ TYPED_TEST(TestGemmSplitK_KM_NK, SmallM)
...
@@ -57,12 +57,12 @@ TYPED_TEST(TestGemmSplitK_KM_NK, SmallM)
TYPED_TEST
(
TestGemmSplitK_MK_KN
,
MidLargeM
)
TYPED_TEST
(
TestGemmSplitK_MK_KN
,
MidLargeM
)
{
{
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
int
N
=
512
;
constexpr
int
N
=
512
;
int
K
=
320
;
constexpr
int
K
=
320
;
int
StrideA
=
K
;
constexpr
int
StrideA
=
K
;
int
StrideB
=
N
;
constexpr
int
StrideB
=
N
;
int
StrideC
=
N
;
constexpr
int
StrideC
=
N
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
);
this
->
Run
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
);
...
@@ -71,12 +71,12 @@ TYPED_TEST(TestGemmSplitK_MK_KN, MidLargeM)
...
@@ -71,12 +71,12 @@ TYPED_TEST(TestGemmSplitK_MK_KN, MidLargeM)
TYPED_TEST
(
TestGemmSplitK_MK_NK
,
MidLargeM
)
TYPED_TEST
(
TestGemmSplitK_MK_NK
,
MidLargeM
)
{
{
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
int
N
=
512
;
constexpr
int
N
=
512
;
int
K
=
320
;
constexpr
int
K
=
320
;
int
StrideA
=
K
;
constexpr
int
StrideA
=
K
;
int
StrideB
=
K
;
constexpr
int
StrideB
=
K
;
int
StrideC
=
N
;
constexpr
int
StrideC
=
N
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
);
this
->
Run
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
);
...
@@ -85,11 +85,11 @@ TYPED_TEST(TestGemmSplitK_MK_NK, MidLargeM)
...
@@ -85,11 +85,11 @@ TYPED_TEST(TestGemmSplitK_MK_NK, MidLargeM)
TYPED_TEST
(
TestGemmSplitK_KM_KN
,
MidLargeM
)
TYPED_TEST
(
TestGemmSplitK_KM_KN
,
MidLargeM
)
{
{
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
int
N
=
512
;
constexpr
int
N
=
512
;
int
K
=
320
;
constexpr
int
K
=
320
;
int
StrideB
=
N
;
constexpr
int
StrideB
=
N
;
int
StrideC
=
N
;
constexpr
int
StrideC
=
N
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
,
M
,
StrideB
,
StrideC
);
this
->
Run
(
M
,
N
,
K
,
M
,
StrideB
,
StrideC
);
...
@@ -98,11 +98,11 @@ TYPED_TEST(TestGemmSplitK_KM_KN, MidLargeM)
...
@@ -98,11 +98,11 @@ TYPED_TEST(TestGemmSplitK_KM_KN, MidLargeM)
TYPED_TEST
(
TestGemmSplitK_KM_NK
,
MidLargeM
)
TYPED_TEST
(
TestGemmSplitK_KM_NK
,
MidLargeM
)
{
{
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
std
::
vector
<
int
>
Ms
{
127
,
255
,
312
,
799
,
1573
};
int
N
=
512
;
constexpr
int
N
=
512
;
int
K
=
320
;
constexpr
int
K
=
320
;
int
StrideB
=
K
;
constexpr
int
StrideB
=
K
;
int
StrideC
=
N
;
constexpr
int
StrideC
=
N
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
,
M
,
StrideB
,
StrideC
);
this
->
Run
(
M
,
N
,
K
,
M
,
StrideB
,
StrideC
);
...
@@ -111,12 +111,12 @@ TYPED_TEST(TestGemmSplitK_KM_NK, MidLargeM)
...
@@ -111,12 +111,12 @@ TYPED_TEST(TestGemmSplitK_KM_NK, MidLargeM)
TYPED_TEST
(
TestGemmSplitK_MK_KN
,
PaddK
)
TYPED_TEST
(
TestGemmSplitK_MK_KN
,
PaddK
)
{
{
std
::
vector
<
int
>
Ms
{
127
};
std
::
vector
<
int
>
Ms
{
127
};
int
N
=
512
;
constexpr
int
N
=
512
;
int
K
=
437
;
constexpr
int
K
=
437
;
int
StrideA
=
K
;
constexpr
int
StrideA
=
K
;
int
StrideB
=
N
;
constexpr
int
StrideB
=
N
;
int
StrideC
=
N
;
constexpr
int
StrideC
=
N
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
);
this
->
Run
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
);
...
@@ -125,12 +125,12 @@ TYPED_TEST(TestGemmSplitK_MK_KN, PaddK)
...
@@ -125,12 +125,12 @@ TYPED_TEST(TestGemmSplitK_MK_KN, PaddK)
TYPED_TEST
(
TestGemmSplitK_MK_NK
,
PaddK
)
TYPED_TEST
(
TestGemmSplitK_MK_NK
,
PaddK
)
{
{
std
::
vector
<
int
>
Ms
{
127
};
std
::
vector
<
int
>
Ms
{
127
};
int
N
=
512
;
constexpr
int
N
=
512
;
int
K
=
437
;
constexpr
int
K
=
437
;
int
StrideA
=
K
;
constexpr
int
StrideA
=
K
;
int
StrideB
=
K
;
constexpr
int
StrideB
=
K
;
int
StrideC
=
N
;
constexpr
int
StrideC
=
N
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
);
this
->
Run
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
);
...
@@ -139,11 +139,11 @@ TYPED_TEST(TestGemmSplitK_MK_NK, PaddK)
...
@@ -139,11 +139,11 @@ TYPED_TEST(TestGemmSplitK_MK_NK, PaddK)
TYPED_TEST
(
TestGemmSplitK_KM_KN
,
PaddK
)
TYPED_TEST
(
TestGemmSplitK_KM_KN
,
PaddK
)
{
{
std
::
vector
<
int
>
Ms
{
127
};
std
::
vector
<
int
>
Ms
{
127
};
int
N
=
512
;
constexpr
int
N
=
512
;
int
K
=
437
;
constexpr
int
K
=
437
;
int
StrideB
=
N
;
constexpr
int
StrideB
=
N
;
int
StrideC
=
N
;
constexpr
int
StrideC
=
N
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
,
M
,
StrideB
,
StrideC
);
this
->
Run
(
M
,
N
,
K
,
M
,
StrideB
,
StrideC
);
...
@@ -152,11 +152,11 @@ TYPED_TEST(TestGemmSplitK_KM_KN, PaddK)
...
@@ -152,11 +152,11 @@ TYPED_TEST(TestGemmSplitK_KM_KN, PaddK)
TYPED_TEST
(
TestGemmSplitK_KM_NK
,
PaddK
)
TYPED_TEST
(
TestGemmSplitK_KM_NK
,
PaddK
)
{
{
std
::
vector
<
int
>
Ms
{
127
};
std
::
vector
<
int
>
Ms
{
127
};
int
N
=
512
;
constexpr
int
N
=
512
;
int
K
=
437
;
constexpr
int
K
=
437
;
int
StrideB
=
K
;
constexpr
int
StrideB
=
K
;
int
StrideC
=
N
;
constexpr
int
StrideC
=
N
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
,
M
,
StrideB
,
StrideC
);
this
->
Run
(
M
,
N
,
K
,
M
,
StrideB
,
StrideC
);
...
@@ -165,12 +165,12 @@ TYPED_TEST(TestGemmSplitK_KM_NK, PaddK)
...
@@ -165,12 +165,12 @@ TYPED_TEST(TestGemmSplitK_KM_NK, PaddK)
TYPED_TEST
(
TestGemmSplitK_MK_KN
,
Regular
)
TYPED_TEST
(
TestGemmSplitK_MK_KN
,
Regular
)
{
{
std
::
vector
<
int
>
Ms
{
512
};
std
::
vector
<
int
>
Ms
{
512
};
int
N
=
512
;
constexpr
int
N
=
512
;
int
K
=
512
;
constexpr
int
K
=
512
;
int
StrideA
=
K
;
constexpr
int
StrideA
=
K
;
int
StrideB
=
N
;
constexpr
int
StrideB
=
N
;
int
StrideC
=
N
;
constexpr
int
StrideC
=
N
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
);
this
->
Run
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
);
...
@@ -179,12 +179,12 @@ TYPED_TEST(TestGemmSplitK_MK_KN, Regular)
...
@@ -179,12 +179,12 @@ TYPED_TEST(TestGemmSplitK_MK_KN, Regular)
TYPED_TEST
(
TestGemmSplitK_MK_NK
,
Regular
)
TYPED_TEST
(
TestGemmSplitK_MK_NK
,
Regular
)
{
{
std
::
vector
<
int
>
Ms
{
512
};
std
::
vector
<
int
>
Ms
{
512
};
int
N
=
512
;
constexpr
int
N
=
512
;
int
K
=
512
;
constexpr
int
K
=
512
;
int
StrideA
=
K
;
constexpr
int
StrideA
=
K
;
int
StrideB
=
K
;
constexpr
int
StrideB
=
K
;
int
StrideC
=
N
;
constexpr
int
StrideC
=
N
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
);
this
->
Run
(
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
);
...
@@ -193,11 +193,11 @@ TYPED_TEST(TestGemmSplitK_MK_NK, Regular)
...
@@ -193,11 +193,11 @@ TYPED_TEST(TestGemmSplitK_MK_NK, Regular)
TYPED_TEST
(
TestGemmSplitK_KM_KN
,
Regular
)
TYPED_TEST
(
TestGemmSplitK_KM_KN
,
Regular
)
{
{
std
::
vector
<
int
>
Ms
{
512
};
std
::
vector
<
int
>
Ms
{
512
};
int
N
=
512
;
constexpr
int
N
=
512
;
int
K
=
512
;
constexpr
int
K
=
512
;
int
StrideB
=
N
;
constexpr
int
StrideB
=
N
;
int
StrideC
=
N
;
constexpr
int
StrideC
=
N
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
,
M
,
StrideB
,
StrideC
);
this
->
Run
(
M
,
N
,
K
,
M
,
StrideB
,
StrideC
);
...
@@ -206,11 +206,11 @@ TYPED_TEST(TestGemmSplitK_KM_KN, Regular)
...
@@ -206,11 +206,11 @@ TYPED_TEST(TestGemmSplitK_KM_KN, Regular)
TYPED_TEST
(
TestGemmSplitK_KM_NK
,
Regular
)
TYPED_TEST
(
TestGemmSplitK_KM_NK
,
Regular
)
{
{
std
::
vector
<
int
>
Ms
{
512
};
std
::
vector
<
int
>
Ms
{
512
};
int
N
=
512
;
constexpr
int
N
=
512
;
int
K
=
512
;
constexpr
int
K
=
512
;
int
StrideB
=
K
;
constexpr
int
StrideB
=
K
;
int
StrideC
=
N
;
constexpr
int
StrideC
=
N
;
for
(
int
M
:
Ms
)
for
(
int
M
:
Ms
)
this
->
Run
(
M
,
N
,
K
,
M
,
StrideB
,
StrideC
);
this
->
Run
(
M
,
N
,
K
,
M
,
StrideB
,
StrideC
);
...
...
test/gemm_split_k/test_gemm_splitk_util.hpp
View file @
0c5f9438
...
@@ -33,10 +33,10 @@ class TestGemmSplitK : public testing::Test
...
@@ -33,10 +33,10 @@ class TestGemmSplitK : public testing::Test
using
CDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
public:
public:
bool
verify_
=
true
;
static
constexpr
bool
verify_
=
true
;
int
init_method_
=
1
;
// decimal value initialization
static
constexpr
int
init_method_
=
1
;
// decimal value initialization
bool
log_
=
false
;
static
constexpr
bool
log_
=
false
;
bool
bench_
=
false
;
// measure kernel performance
static
constexpr
bool
bench_
=
false
;
// measure kernel performance
std
::
vector
<
int
>
k_batches_
;
std
::
vector
<
int
>
k_batches_
;
void
SetUp
()
override
{
k_batches_
=
{
1
,
2
,
3
,
5
,
8
};
}
void
SetUp
()
override
{
k_batches_
=
{
1
,
2
,
3
,
5
,
8
};
}
...
...
test/grouped_gemm/test_grouped_gemm_interface.cpp
View file @
0c5f9438
...
@@ -45,8 +45,8 @@ class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test
...
@@ -45,8 +45,8 @@ class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test
TEST_F
(
TestGGemmSplitKInterface_MKNKMN
,
TileSize
)
TEST_F
(
TestGGemmSplitKInterface_MKNKMN
,
TileSize
)
{
{
std
::
vector
<
int
>
Ms
{
128
,
256
,
188
,
512
};
std
::
vector
<
int
>
Ms
{
128
,
256
,
188
,
512
};
int
N
=
256
;
constexpr
int
N
=
256
;
int
K
=
128
;
constexpr
int
K
=
128
;
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
@@ -70,8 +70,8 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth)
...
@@ -70,8 +70,8 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth)
using
PaddedGGemmInstance
=
GGemmInstance
<
GemmMNKPadding
,
32
,
8
,
4
,
8
,
8
>
;
using
PaddedGGemmInstance
=
GGemmInstance
<
GemmMNKPadding
,
32
,
8
,
4
,
8
,
8
>
;
std
::
vector
<
int
>
Ms
{
128
,
256
,
256
,
512
};
std
::
vector
<
int
>
Ms
{
128
,
256
,
256
,
512
};
int
N
=
256
;
constexpr
int
N
=
256
;
int
K
=
512
;
constexpr
int
K
=
512
;
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
@@ -96,9 +96,9 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth)
...
@@ -96,9 +96,9 @@ TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth)
TEST_F
(
TestGGemmSplitKInterface_MKNKMN
,
KLoops
)
TEST_F
(
TestGGemmSplitKInterface_MKNKMN
,
KLoops
)
{
{
std
::
vector
<
int
>
Ms
{
128
,
256
,
256
,
512
};
std
::
vector
<
int
>
Ms
{
128
,
256
,
256
,
512
};
int
N
=
256
;
constexpr
int
N
=
256
;
int
K
=
128
;
constexpr
int
K
=
128
;
int
kbatch
=
4
;
constexpr
int
kbatch
=
4
;
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
@@ -152,8 +152,8 @@ class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test
...
@@ -152,8 +152,8 @@ class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test
TEST_F
(
TestGGemmSplitKInterface_KMKNNM
,
TileSize
)
TEST_F
(
TestGGemmSplitKInterface_KMKNNM
,
TileSize
)
{
{
std
::
vector
<
int
>
Ms
{
128
,
256
,
188
,
512
};
std
::
vector
<
int
>
Ms
{
128
,
256
,
188
,
512
};
int
N
=
256
;
constexpr
int
N
=
256
;
int
K
=
128
;
constexpr
int
K
=
128
;
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
@@ -177,8 +177,8 @@ TEST_F(TestGGemmSplitKInterface_KMKNNM, VectorLoadWidth)
...
@@ -177,8 +177,8 @@ TEST_F(TestGGemmSplitKInterface_KMKNNM, VectorLoadWidth)
using
PaddedGGemmInstance
=
GGemmInstance
<
GemmMNKPadding
,
32
,
8
,
2
,
8
,
4
>
;
using
PaddedGGemmInstance
=
GGemmInstance
<
GemmMNKPadding
,
32
,
8
,
2
,
8
,
4
>
;
std
::
vector
<
int
>
Ms
{
128
,
256
,
256
,
512
};
std
::
vector
<
int
>
Ms
{
128
,
256
,
256
,
512
};
int
N
=
256
;
constexpr
int
N
=
256
;
int
K
=
512
;
constexpr
int
K
=
512
;
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
...
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
View file @
0c5f9438
...
@@ -3,8 +3,8 @@
...
@@ -3,8 +3,8 @@
TEST_P
(
RRR_F16_F16_F16
,
TinyCases
)
TEST_P
(
RRR_F16_F16_F16
,
TinyCases
)
{
{
const
std
::
vector
<
int
>
Ms
{
0
,
1
};
const
std
::
vector
<
int
>
Ms
{
0
,
1
};
const
int
N
=
768
;
const
expr
int
N
=
768
;
const
int
K
=
544
;
const
expr
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
@@ -18,8 +18,8 @@ TEST_P(RRR_F16_F16_F16, TinyCases)
...
@@ -18,8 +18,8 @@ TEST_P(RRR_F16_F16_F16, TinyCases)
TEST_P
(
RRR_F16_F16_F16
,
SmallCases
)
TEST_P
(
RRR_F16_F16_F16
,
SmallCases
)
{
{
const
std
::
vector
<
int
>
Ms
{
2
,
1
,
3
,
4
,
5
,
0
};
const
std
::
vector
<
int
>
Ms
{
2
,
1
,
3
,
4
,
5
,
0
};
const
int
N
=
768
;
const
expr
int
N
=
768
;
const
int
K
=
544
;
const
expr
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
@@ -33,8 +33,8 @@ TEST_P(RRR_F16_F16_F16, SmallCases)
...
@@ -33,8 +33,8 @@ TEST_P(RRR_F16_F16_F16, SmallCases)
TEST_P
(
RRR_F16_F16_F16
,
MidCases
)
TEST_P
(
RRR_F16_F16_F16
,
MidCases
)
{
{
const
std
::
vector
<
int
>
Ms
{
167
,
183
,
177
,
153
,
139
,
204
};
const
std
::
vector
<
int
>
Ms
{
167
,
183
,
177
,
153
,
139
,
204
};
const
int
N
=
768
;
const
expr
int
N
=
768
;
const
int
K
=
544
;
const
expr
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
@@ -48,8 +48,8 @@ TEST_P(RRR_F16_F16_F16, MidCases)
...
@@ -48,8 +48,8 @@ TEST_P(RRR_F16_F16_F16, MidCases)
TEST_P
(
RRR_F16_F16_F16
,
Regular
)
TEST_P
(
RRR_F16_F16_F16
,
Regular
)
{
{
const
std
::
vector
<
int
>
Ms
{
64
,
128
,
256
};
const
std
::
vector
<
int
>
Ms
{
64
,
128
,
256
};
const
int
N
=
768
;
const
expr
int
N
=
768
;
const
int
K
=
320
;
const
expr
int
K
=
320
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
@@ -63,8 +63,8 @@ TEST_P(RRR_F16_F16_F16, Regular)
...
@@ -63,8 +63,8 @@ TEST_P(RRR_F16_F16_F16, Regular)
TEST_P
(
RRR_F16_F16_F16
,
MNKPadded
)
TEST_P
(
RRR_F16_F16_F16
,
MNKPadded
)
{
{
const
std
::
vector
<
int
>
Ms
{
127
,
150
,
188
,
210
};
const
std
::
vector
<
int
>
Ms
{
127
,
150
,
188
,
210
};
const
int
N
=
136
;
const
expr
int
N
=
136
;
const
int
K
=
280
;
const
expr
int
K
=
280
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
@@ -78,8 +78,8 @@ TEST_P(RRR_F16_F16_F16, MNKPadded)
...
@@ -78,8 +78,8 @@ TEST_P(RRR_F16_F16_F16, MNKPadded)
TEST_P
(
RCR_F16_F16_F16
,
TinyCases
)
TEST_P
(
RCR_F16_F16_F16
,
TinyCases
)
{
{
const
std
::
vector
<
int
>
Ms
{
0
,
1
};
const
std
::
vector
<
int
>
Ms
{
0
,
1
};
const
int
N
=
768
;
const
expr
int
N
=
768
;
const
int
K
=
544
;
const
expr
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
@@ -92,8 +92,8 @@ TEST_P(RCR_F16_F16_F16, TinyCases)
...
@@ -92,8 +92,8 @@ TEST_P(RCR_F16_F16_F16, TinyCases)
TEST_P
(
RCR_F16_F16_F16
,
SmallCases
)
TEST_P
(
RCR_F16_F16_F16
,
SmallCases
)
{
{
const
std
::
vector
<
int
>
Ms
{
2
,
1
,
3
,
4
,
5
,
0
};
const
std
::
vector
<
int
>
Ms
{
2
,
1
,
3
,
4
,
5
,
0
};
const
int
N
=
768
;
const
expr
int
N
=
768
;
const
int
K
=
544
;
const
expr
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
@@ -107,8 +107,8 @@ TEST_P(RCR_F16_F16_F16, SmallCases)
...
@@ -107,8 +107,8 @@ TEST_P(RCR_F16_F16_F16, SmallCases)
TEST_P
(
RCR_F16_F16_F16
,
MidCases
)
TEST_P
(
RCR_F16_F16_F16
,
MidCases
)
{
{
const
std
::
vector
<
int
>
Ms
{
167
,
183
,
177
,
153
,
139
,
204
};
const
std
::
vector
<
int
>
Ms
{
167
,
183
,
177
,
153
,
139
,
204
};
const
int
N
=
768
;
const
expr
int
N
=
768
;
const
int
K
=
544
;
const
expr
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
@@ -122,8 +122,8 @@ TEST_P(RCR_F16_F16_F16, MidCases)
...
@@ -122,8 +122,8 @@ TEST_P(RCR_F16_F16_F16, MidCases)
TEST_P
(
RCR_F16_F16_F16
,
Regular
)
TEST_P
(
RCR_F16_F16_F16
,
Regular
)
{
{
const
std
::
vector
<
int
>
Ms
{
32
,
64
,
128
,
256
};
const
std
::
vector
<
int
>
Ms
{
32
,
64
,
128
,
256
};
const
int
N
=
768
;
const
expr
int
N
=
768
;
const
int
K
=
320
;
const
expr
int
K
=
320
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
@@ -137,8 +137,8 @@ TEST_P(RCR_F16_F16_F16, Regular)
...
@@ -137,8 +137,8 @@ TEST_P(RCR_F16_F16_F16, Regular)
TEST_P
(
RCR_F16_F16_F16
,
MNKPadded
)
TEST_P
(
RCR_F16_F16_F16
,
MNKPadded
)
{
{
const
std
::
vector
<
int
>
Ms
{
127
,
150
,
188
,
210
};
const
std
::
vector
<
int
>
Ms
{
127
,
150
,
188
,
210
};
const
int
N
=
136
;
const
expr
int
N
=
136
;
const
int
K
=
280
;
const
expr
int
K
=
280
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
@@ -152,8 +152,8 @@ TEST_P(RCR_F16_F16_F16, MNKPadded)
...
@@ -152,8 +152,8 @@ TEST_P(RCR_F16_F16_F16, MNKPadded)
TEST_P
(
RRR_F16_F16_F16_LargeK
,
TestLargeKBatch
)
TEST_P
(
RRR_F16_F16_F16_LargeK
,
TestLargeKBatch
)
{
{
const
std
::
vector
<
int
>
Ms
{
188
,
210
};
const
std
::
vector
<
int
>
Ms
{
188
,
210
};
const
int
N
=
768
;
const
expr
int
N
=
768
;
const
int
K
=
4096
;
const
expr
int
K
=
4096
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
@@ -167,8 +167,8 @@ TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
...
@@ -167,8 +167,8 @@ TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
TEST_P
(
RCR_F16_F16_F16_LargeK
,
TestLargeKBatch
)
TEST_P
(
RCR_F16_F16_F16_LargeK
,
TestLargeKBatch
)
{
{
const
std
::
vector
<
int
>
Ms
{
188
,
210
};
const
std
::
vector
<
int
>
Ms
{
188
,
210
};
const
int
N
=
768
;
const
expr
int
N
=
768
;
const
int
K
=
4096
;
const
expr
int
K
=
4096
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
...
...
test/grouped_gemm/test_grouped_gemm_util.hpp
View file @
0c5f9438
...
@@ -50,10 +50,10 @@ class TestGroupedGemm : public testing::TestWithParam<int>
...
@@ -50,10 +50,10 @@ class TestGroupedGemm : public testing::TestWithParam<int>
using
EDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
EDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
public:
public:
bool
verify_
=
true
;
static
constexpr
bool
verify_
=
true
;
int
init_method_
=
0
;
// decimal value initialization
static
constexpr
int
init_method_
=
1
;
// decimal value initialization
bool
log_
=
false
;
static
constexpr
bool
log_
=
false
;
bool
bench_
=
false
;
// measure kernel performance
static
constexpr
bool
bench_
=
false
;
// measure kernel performance
void
SetUp
()
override
{}
void
SetUp
()
override
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment