Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ebe8b7d1
"vscode:/vscode.git/clone" did not exist on "462a79d39ad278090fbe5fc723d5a2c4d22185b9"
Commit
ebe8b7d1
authored
Oct 19, 2022
by
Anthony Chang
Browse files
simplify gemm test
parent
37d83d7d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
71 additions
and
255 deletions
+71
-255
test/gemm/gemm_bf16.cpp
test/gemm/gemm_bf16.cpp
+6
-51
test/gemm/gemm_fp16.cpp
test/gemm/gemm_fp16.cpp
+6
-51
test/gemm/gemm_fp32.cpp
test/gemm/gemm_fp32.cpp
+6
-51
test/gemm/gemm_fp64.cpp
test/gemm/gemm_fp64.cpp
+6
-51
test/gemm/gemm_int8.cpp
test/gemm/gemm_int8.cpp
+6
-51
test/gemm/run_gemm_test.inc
test/gemm/run_gemm_test.inc
+41
-0
No files found.
test/gemm/gemm_bf16.cpp
View file @
ebe8b7d1
...
@@ -24,56 +24,11 @@
...
@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp"
#include "test/gemm/gemm_util.hpp"
int
main
()
using
ADataType
=
ck
::
bhalf_t
;
{
using
BDataType
=
ck
::
bhalf_t
;
using
ADataType
=
ck
::
bhalf_t
;
using
CDataType
=
ck
::
bhalf_t
;
using
BDataType
=
ck
::
bhalf_t
;
using
AccDataType
=
float
;
using
CDataType
=
ck
::
bhalf_t
;
using
AccDataType
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
#include "run_gemm_test.inc"
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
int
main
()
{
return
run_gemm_test
();
}
auto
test
=
[
&
](
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
bool
pass
=
true
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemm
<
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
const
auto
gemmPtrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
pass
&=
ck
::
gemm_util
::
TestGemm
<
std
::
unique_ptr
<
DeviceOp
>
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
return
pass
;
};
bool
pass
=
test
(
Row
{},
Row
{},
Row
{})
&&
test
(
Row
{},
Col
{},
Row
{})
&&
test
(
Col
{},
Row
{},
Row
{})
&&
test
(
Col
{},
Col
{},
Row
{});
std
::
cout
<<
"TestGemm ..... "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/gemm/gemm_fp16.cpp
View file @
ebe8b7d1
...
@@ -24,56 +24,11 @@
...
@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp"
#include "test/gemm/gemm_util.hpp"
int
main
()
using
ADataType
=
ck
::
half_t
;
{
using
BDataType
=
ck
::
half_t
;
using
ADataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
#include "run_gemm_test.inc"
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
int
main
()
{
return
run_gemm_test
();
}
auto
test
=
[
&
](
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
bool
pass
=
true
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemm
<
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
const
auto
gemmPtrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
pass
&=
ck
::
gemm_util
::
TestGemm
<
std
::
unique_ptr
<
DeviceOp
>
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
return
pass
;
};
bool
pass
=
test
(
Row
{},
Row
{},
Row
{})
&&
test
(
Row
{},
Col
{},
Row
{})
&&
test
(
Col
{},
Row
{},
Row
{})
&&
test
(
Col
{},
Col
{},
Row
{});
std
::
cout
<<
"TestGemm ..... "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/gemm/gemm_fp32.cpp
View file @
ebe8b7d1
...
@@ -24,56 +24,11 @@
...
@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp"
#include "test/gemm/gemm_util.hpp"
int
main
()
using
ADataType
=
float
;
{
using
BDataType
=
float
;
using
ADataType
=
float
;
using
CDataType
=
float
;
using
BDataType
=
float
;
using
AccDataType
=
float
;
using
CDataType
=
float
;
using
AccDataType
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
#include "run_gemm_test.inc"
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
int
main
()
{
return
run_gemm_test
();
}
auto
test
=
[
&
](
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
bool
pass
=
true
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemm
<
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
const
auto
gemmPtrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
pass
&=
ck
::
gemm_util
::
TestGemm
<
std
::
unique_ptr
<
DeviceOp
>
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
return
pass
;
};
bool
pass
=
test
(
Row
{},
Row
{},
Row
{})
&&
test
(
Row
{},
Col
{},
Row
{})
&&
test
(
Col
{},
Row
{},
Row
{})
&&
test
(
Col
{},
Col
{},
Row
{});
std
::
cout
<<
"TestGemm ..... "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/gemm/gemm_fp64.cpp
View file @
ebe8b7d1
...
@@ -24,56 +24,11 @@
...
@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp"
#include "test/gemm/gemm_util.hpp"
int
main
()
using
ADataType
=
double
;
{
using
BDataType
=
double
;
using
ADataType
=
double
;
using
CDataType
=
double
;
using
BDataType
=
double
;
using
AccDataType
=
double
;
using
CDataType
=
double
;
using
AccDataType
=
double
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
#include "run_gemm_test.inc"
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
int
main
()
{
return
run_gemm_test
();
}
auto
test
=
[
&
](
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
bool
pass
=
true
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemm
<
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
const
auto
gemmPtrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
pass
&=
ck
::
gemm_util
::
TestGemm
<
std
::
unique_ptr
<
DeviceOp
>
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
return
pass
;
};
bool
pass
=
test
(
Row
{},
Row
{},
Row
{})
&&
test
(
Row
{},
Col
{},
Row
{})
&&
test
(
Col
{},
Row
{},
Row
{})
&&
test
(
Col
{},
Col
{},
Row
{});
std
::
cout
<<
"TestGemm ..... "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/gemm/gemm_int8.cpp
View file @
ebe8b7d1
...
@@ -24,56 +24,11 @@
...
@@ -24,56 +24,11 @@
#include "test/gemm/gemm_util.hpp"
#include "test/gemm/gemm_util.hpp"
int
main
()
using
ADataType
=
int8_t
;
{
using
BDataType
=
int8_t
;
using
ADataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
AccDataType
=
int32_t
;
using
CDataType
=
int8_t
;
using
AccDataType
=
int32_t
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
#include "run_gemm_test.inc"
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
int
main
()
{
return
run_gemm_test
();
}
auto
test
=
[
&
](
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
bool
pass
=
true
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemm
<
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
const
auto
gemmPtrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
pass
&=
ck
::
gemm_util
::
TestGemm
<
std
::
unique_ptr
<
DeviceOp
>
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
return
pass
;
};
bool
pass
=
test
(
Row
{},
Row
{},
Row
{})
&&
test
(
Row
{},
Col
{},
Row
{})
&&
test
(
Col
{},
Row
{},
Row
{})
&&
test
(
Col
{},
Col
{},
Row
{});
std
::
cout
<<
"TestGemm ..... "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/gemm/run_gemm_test.inc
0 → 100644
View file @
ebe8b7d1
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
int
run_gemm_test
()
{
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
auto
test
=
[
&
](
auto
a_layout
,
auto
b_layout
,
auto
c_layout
)
{
bool
pass
=
true
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGemm
<
decltype
(
a_layout
),
decltype
(
b_layout
),
decltype
(
c_layout
),
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
const
auto
gemmPtrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
pass
&=
ck
::
gemm_util
::
TestGemm
<
AccDataType
>
{}(
gemmPtr
.
get
());
}
return
pass
;
};
bool
pass
=
test
(
Row
{},
Row
{},
Row
{})
&&
test
(
Row
{},
Col
{},
Row
{})
&&
test
(
Col
{},
Row
{},
Row
{})
&&
test
(
Col
{},
Col
{},
Row
{});
std
::
cout
<<
"TestGemm ..... "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment