Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0c8b4bbf
Commit
0c8b4bbf
authored
Apr 26, 2023
by
Adam Osewski
Browse files
Add functional tests for grouped_gemm with different kbatch value.
parent
1945c26b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
168 additions
and
2 deletions
+168
-2
test/grouped_gemm/CMakeLists.txt
test/grouped_gemm/CMakeLists.txt
+8
-2
test/grouped_gemm/test_grouped_gemm_splitk.cpp
test/grouped_gemm/test_grouped_gemm_splitk.cpp
+25
-0
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
+75
-0
test/grouped_gemm/test_grouped_gemm_util.hpp
test/grouped_gemm/test_grouped_gemm_util.hpp
+60
-0
No files found.
test/grouped_gemm/CMakeLists.txt
View file @
0c8b4bbf
add_custom_target
(
test_grouped_gemm
)
add_test_executable
(
test_grouped_gemm_fp16 grouped_gemm_fp16.cpp
)
target_link_libraries
(
test_grouped_gemm_fp16 PRIVATE utility
)
target_link_libraries
(
test_grouped_gemm_fp16 PRIVATE device_grouped_gemm_instance
)
add_gtest_executable
(
test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp
)
target_link_libraries
(
test_grouped_gemm_fp16 PRIVATE utility device_grouped_gemm_instance
)
target_link_libraries
(
test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance
)
add_dependencies
(
test_grouped_gemm test_grouped_gemm_fp16 test_grouped_gemm_splitk
)
test/grouped_gemm/test_grouped_gemm_splitk.cpp
0 → 100644
View file @
0c8b4bbf
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <vector>
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/data_type.hpp"
#include "gtest/gtest.h"
#include "test_grouped_gemm_util.hpp"
using
F16
=
ck
::
half_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
RRR_F16_F16_F16
=
ck
::
test
::
TestGroupedGemm
<
std
::
tuple
<
Row
,
Row
,
Row
,
F16
,
F16
,
F16
>>
;
using
RCR_F16_F16_F16
=
ck
::
test
::
TestGroupedGemm
<
std
::
tuple
<
Row
,
Col
,
Row
,
F16
,
F16
,
F16
>>
;
const
std
::
vector
<
int
>
KBATCH
{
1
,
2
,
4
,
6
,
8
,
10
,
12
,
14
,
16
,
32
,
64
,
128
};
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemm_splitk_MK_KN
,
RRR_F16_F16_F16
,
testing
::
ValuesIn
(
KBATCH
));
INSTANTIATE_TEST_SUITE_P
(
TestGroupedGemm_splitk_MK_NK
,
RCR_F16_F16_F16
,
testing
::
ValuesIn
(
KBATCH
));
#include "test_grouped_gemm_ut_cases.inc"
test/grouped_gemm/test_grouped_gemm_ut_cases.inc
0 → 100644
View file @
0c8b4bbf
#pragma once
TEST_P
(
RCR_F16_F16_F16
,
TinyCases
)
{
const
std
::
vector
<
int
>
Ms
{
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
};
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
768
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
4068
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
4608
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
4608
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
768
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RRR_F16_F16_F16
,
TinyCases
)
{
const
std
::
vector
<
int
>
Ms
{
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
0
};
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
4608
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
384
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
384
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
4608
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
4608
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RCR_F16_F16_F16
,
SmallCases
)
{
const
std
::
vector
<
int
>
Ms
{
2
,
1
,
1
,
1
,
1
,
1
,
3
,
4
,
3
,
5
,
2
,
4
,
2
,
1
,
0
,
1
};
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
768
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
4068
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
4608
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
4608
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
768
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RRR_F16_F16_F16
,
SmallCases
)
{
const
std
::
vector
<
int
>
Ms
{
2
,
1
,
1
,
1
,
1
,
1
,
3
,
4
,
3
,
5
,
2
,
4
,
2
,
1
,
0
,
1
};
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
4608
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
384
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
384
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
4608
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
4608
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RCR_F16_F16_F16
,
MidCases
)
{
const
std
::
vector
<
int
>
Ms
{
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
148
};
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
768
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
4068
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
4608
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
4608
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
768
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
TEST_P
(
RRR_F16_F16_F16
,
MidCases
)
{
const
std
::
vector
<
int
>
Ms
{
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
148
};
const
std
::
vector
<
int
>
Ns
(
Ms
.
size
(),
4608
);
const
std
::
vector
<
int
>
Ks
(
Ms
.
size
(),
384
);
const
std
::
vector
<
int
>
StrideAs
(
Ms
.
size
(),
384
);
const
std
::
vector
<
int
>
StrideBs
(
Ms
.
size
(),
4608
);
const
std
::
vector
<
int
>
StrideCs
(
Ms
.
size
(),
4608
);
this
->
Run
(
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
this
->
GetParam
());
}
test/grouped_gemm/test_grouped_gemm_util.hpp
0 → 100644
View file @
0c8b4bbf
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <sstream>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "include/ck/utility/data_type.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
namespace
ck
{
namespace
test
{
template
<
typename
Tuple
>
class
TestGroupedGemm
:
public
testing
::
TestWithParam
<
int
>
{
protected:
using
ALayout
=
std
::
tuple_element_t
<
0
,
Tuple
>
;
using
BLayout
=
std
::
tuple_element_t
<
1
,
Tuple
>
;
using
ELayout
=
std
::
tuple_element_t
<
2
,
Tuple
>
;
using
ADataType
=
std
::
tuple_element_t
<
3
,
Tuple
>
;
using
BDataType
=
std
::
tuple_element_t
<
4
,
Tuple
>
;
using
EDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
public:
bool
verify_
=
true
;
int
init_method_
=
2
;
// decimal value initialization
bool
log_
=
false
;
bool
bench_
=
false
;
// measure kernel performance
void
SetUp
()
override
{}
void
Run
(
const
std
::
vector
<
int
>&
Ms
,
const
std
::
vector
<
int
>&
Ns
,
const
std
::
vector
<
int
>&
Ks
,
const
std
::
vector
<
int
>&
StrideAs
,
const
std
::
vector
<
int
>&
StrideBs
,
const
std
::
vector
<
int
>&
StrideCs
,
int
kbatch
=
1
)
{
bool
pass
=
ck
::
profiler
::
profile_grouped_gemm_impl
<
ADataType
,
BDataType
,
EDataType
,
float
,
ALayout
,
BLayout
,
ELayout
>
(
verify_
,
init_method_
,
log_
,
bench_
,
Ms
,
Ns
,
Ks
,
StrideAs
,
StrideBs
,
StrideCs
,
kbatch
);
EXPECT_TRUE
(
pass
);
}
};
}
// namespace test
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment