Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9eed0992
Commit
9eed0992
authored
May 10, 2023
by
Adam Osewski
Browse files
Unit tests for multiple KBatch values.
parent
045bf6b6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
184 additions
and
87 deletions
+184
-87
test/grouped_gemm/test_grouped_gemm_splitk.cpp
test/grouped_gemm/test_grouped_gemm_splitk.cpp
+12
-4
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
+172
-74
test/grouped_gemm/test_grouped_gemm_util.hpp
test/grouped_gemm/test_grouped_gemm_util.hpp
+0
-9
No files found.
test/grouped_gemm/test_grouped_gemm_splitk.cpp
View file @
9eed0992
...
...
@@ -14,13 +14,21 @@ using F16 = ck::half_t;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
//
using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
using
RRR_F16_F16_F16
=
ck
::
test
::
TestGroupedGemm
<
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
>>
;
using
RCR_F16_F16_F16
=
ck
::
test
::
TestGroupedGemm
<
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
>>
;
// const
st
d
::
vector<int> KBATCH{1, 2, 8, 32}
;
const
st
d
::
vector
<
int
>
KBATCH
{
4
}
;
using
RRR_F16_F16_F16_LargeK
=
ck
::
te
st
::
TestGroupedGemm
<
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
>>
;
using
RCR_F16_F16_F16_LargeK
=
ck
::
te
st
::
TestGroupedGemm
<
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
>>
;
// INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH));
const
std
::
vector
<
int
>
KBATCH
{
1
,
2
,
3
,
5
,
8
};
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemm_splitk_MK_KN
,
RRR_F16_F16_F16
,
testing
::
ValuesIn
(
KBATCH
));
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemm_splitk_MK_NK
,
RCR_F16_F16_F16
,
testing
::
ValuesIn
(
KBATCH
));
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemm_splitk_LargeK_MK_KN
,
RRR_F16_F16_F16_LargeK
,
testing
::
Values
(
32
,
64
));
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemm_splitk_LargeK_MK_NK
,
RCR_F16_F16_F16_LargeK
,
testing
::
Values
(
32
,
64
));
#include "test_grouped_gemm_ut_cases.inc"
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
View file @
9eed0992
#pragma once
// TEST_P(RRR_F16_F16_F16, TinyCases)
// {
// const std::vector<int> Ms{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
// const std::vector<int> Ns(Ms.size(), 4608);
// const std::vector<int> Ks(Ms.size(), 384);
// const std::vector<int> StrideAs(Ms.size(), 384);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 4608);
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// }
// TEST_P(RRR_F16_F16_F16, SmallCases)
// {
// const std::vector<int> Ms{2, 1, 1, 1, 1, 1, 3, 4, 3, 5, 2, 4, 2, 1, 0, 1};
// const std::vector<int> Ns(Ms.size(), 4608);
// const std::vector<int> Ks(Ms.size(), 384);
// const std::vector<int> StrideAs(Ms.size(), 384);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 4608);
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// }
// TEST_P(RRR_F16_F16_F16, MidCases)
// {
// const std::vector<int> Ms{
// 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
// const std::vector<int> Ns(Ms.size(), 4608);
// const std::vector<int> Ks(Ms.size(), 384);
// const std::vector<int> StrideAs(Ms.size(), 384);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 4608);
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// }
// TEST_P(RCR_F16_F16_F16, TinyCases)
// {
// const std::vector<int> Ms{0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
// const std::vector<int> Ns(Ms.size(), 768);
// const std::vector<int> Ks(Ms.size(), 4608);
// const std::vector<int> StrideAs(Ms.size(), 4608);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 768);
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// }
// TEST_P(RCR_F16_F16_F16, SmallCases)
// {
// const std::vector<int> Ms{2, 1, 1, 1, 1, 1, 3, 4, 3, 5, 2, 4, 2, 1, 0, 1};
// const std::vector<int> Ns(Ms.size(), 768);
// const std::vector<int> Ks(Ms.size(), 4608);
// const std::vector<int> StrideAs(Ms.size(), 4608);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 768);
// this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
// }
TEST_P
(
RRR_F16_F16_F16
,
TinyCases
)
{
const
std
::
vector
<
int
>
Ms
{
0
,
1
};
const
int
N
=
768
;
const
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RRR_F16_F16_F16
,
SmallCases
)
{
const
std
::
vector
<
int
>
Ms
{
2
,
1
,
3
,
4
,
5
,
0
};
const
int
N
=
768
;
const
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RRR_F16_F16_F16
,
MidCases
)
{
const
std
::
vector
<
int
>
Ms
{
167
,
183
,
177
,
153
,
139
,
204
};
const
int
N
=
768
;
const
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RRR_F16_F16_F16
,
Regular
)
{
const
std
::
vector
<
int
>
Ms
{
64
,
128
,
256
};
const
int
N
=
768
;
const
int
K
=
320
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RRR_F16_F16_F16
,
MNKPadded
)
{
const
std
::
vector
<
int
>
Ms
{
127
,
150
,
188
,
210
};
const
int
N
=
136
;
const
int
K
=
280
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RCR_F16_F16_F16
,
TinyCases
)
{
const
std
::
vector
<
int
>
Ms
{
0
,
1
};
const
int
N
=
768
;
const
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RCR_F16_F16_F16
,
SmallCases
)
{
const
std
::
vector
<
int
>
Ms
{
2
,
1
,
3
,
4
,
5
,
0
};
const
int
N
=
768
;
const
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RCR_F16_F16_F16
,
MidCases
)
{
const
std
::
vector
<
int
>
Ms
{
// 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
167
};
// const std::vector<int> Ns(Ms.size(), 768);
// const std::vector<int> Ks(Ms.size(), 4608);
// const std::vector<int> StrideAs(Ms.size(), 4608);
// const std::vector<int> StrideBs(Ms.size(), 4608);
// const std::vector<int> StrideCs(Ms.size(), 768);
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
256
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
128
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
128
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
128
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
256
);
const
std
::
vector
<
int
>
Ms
{
167
,
183
,
177
,
153
,
139
,
204
};
const
int
N
=
768
;
const
int
K
=
544
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RCR_F16_F16_F16
,
Regular
)
{
const
std
::
vector
<
int
>
Ms
{
32
,
64
,
128
,
256
};
const
int
N
=
768
;
const
int
K
=
320
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RCR_F16_F16_F16
,
MNKPadded
)
{
const
std
::
vector
<
int
>
Ms
{
127
,
150
,
188
,
210
};
const
int
N
=
136
;
const
int
K
=
280
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RRR_F16_F16_F16_LargeK
,
TestLargeKBatch
)
{
const
std
::
vector
<
int
>
Ms
{
188
,
210
};
const
int
N
=
768
;
const
int
K
=
4096
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RCR_F16_F16_F16_LargeK
,
TestLargeKBatch
)
{
const
std
::
vector
<
int
>
Ms
{
188
,
210
};
const
int
N
=
768
;
const
int
K
=
4096
;
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
N
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
K
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
N
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
test/grouped_gemm/test_grouped_gemm_util.hpp
View file @
9eed0992
...
...
@@ -56,15 +56,6 @@ class TestGroupedGemm : public testing::TestWithParam<int>
const
std
::
vector
<
int
>&
StrideCs
,
int
kbatch
=
1
)
{
std
::
cout
<<
"Ms: ["
<<
serialize_range
(
Ms
)
<<
"] "
<<
"Ns: ["
<<
serialize_range
(
Ns
)
<<
"] "
<<
"Ks: ["
<<
serialize_range
(
Ks
)
<<
"] "
<<
"StrideAs: ["
<<
serialize_range
(
StrideAs
)
<<
"] "
<<
"StrideBs: ["
<<
serialize_range
(
StrideBs
)
<<
"] "
<<
"StrideCs: ["
<<
serialize_range
(
StrideCs
)
<<
"] "
<<
"kbatch: "
<<
kbatch
<<
std
::
endl
;
bool
pass
=
ck
::
profiler
::
profile_grouped_gemm_impl
<
ADataType
,
BDataType
,
EDataType
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment