Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b5ada11b
Commit
b5ada11b
authored
Jun 01, 2022
by
Jing Zhang
Browse files
merge develop
parents
cee92951
b6eaf3eb
Changes
95
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
502 additions
and
48 deletions
+502
-48
profiler/include/profile_gemm_reduce_impl.hpp
profiler/include/profile_gemm_reduce_impl.hpp
+31
-19
profiler/include/profile_grouped_gemm_impl.hpp
profiler/include/profile_grouped_gemm_impl.hpp
+2
-0
profiler/src/profile_gemm.cpp
profiler/src/profile_gemm.cpp
+16
-0
profiler/src/profile_grouped_gemm.cpp
profiler/src/profile_grouped_gemm.cpp
+4
-0
profiler/src/profiler.cpp
profiler/src/profiler.cpp
+2
-3
test/block_to_ctile_map/test_block_to_ctile_map.cpp
test/block_to_ctile_map/test_block_to_ctile_map.cpp
+224
-6
test/gemm/gemm_dl_fp16.cpp
test/gemm/gemm_dl_fp16.cpp
+8
-3
test/gemm/gemm_dl_fp32.cpp
test/gemm/gemm_dl_fp32.cpp
+8
-3
test/gemm/gemm_dl_int8.cpp
test/gemm/gemm_dl_int8.cpp
+8
-3
test/gemm/gemm_util.hpp
test/gemm/gemm_util.hpp
+8
-0
test/gemm/gemm_xdl_fp16.cpp
test/gemm/gemm_xdl_fp16.cpp
+8
-3
test/gemm/gemm_xdl_fp32.cpp
test/gemm/gemm_xdl_fp32.cpp
+8
-3
test/gemm/gemm_xdl_fp64.cpp
test/gemm/gemm_xdl_fp64.cpp
+156
-0
test/gemm/gemm_xdl_int8.cpp
test/gemm/gemm_xdl_int8.cpp
+8
-3
test/grouped_gemm/grouped_gemm_fp16.cpp
test/grouped_gemm/grouped_gemm_fp16.cpp
+11
-2
No files found.
profiler/include/profile_gemm_reduce_impl.hpp
View file @
b5ada11b
...
...
@@ -19,10 +19,11 @@ namespace device_gemm_instance {
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
DPtrsGlobal
=
ck
::
Tuple
<
F32
*
,
F32
*>
;
using
Div
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
F32
,
F32
,
true
>
;
using
Identity
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
F32
,
F32
,
false
>
;
using
Square
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
F32
,
F32
,
false
>
;
using
DInElementOps
=
ck
::
Tuple
<
Identity
,
Square
>
;
using
DOutElementOps
=
ck
::
Tuple
<
Identity
,
Identity
>
;
using
DOutElementOps
=
ck
::
Tuple
<
Div
,
Div
>
;
using
DeviceGemmReduceNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceGemmReducePtr
<
DPtrsGlobal
,
...
...
@@ -122,30 +123,37 @@ bool profile_gemm_reduce_impl(int do_verification,
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
},
num_thread
);
}
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
D0ReduceOp
=
ck
::
reduce
::
Add
<
float
>
;
using
D1ReduceOp
=
ck
::
reduce
::
Add
<
float
>
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
D0ReduceOp
=
ck
::
reduce
::
Add
<
float
>
;
using
D1ReduceOp
=
ck
::
reduce
::
Add
<
float
>
;
using
UnaryDivElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
float
,
float
,
true
>
;
using
UnaryIdenticElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnaryIdentic
<
float
,
float
,
false
>
;
using
UnarySquareElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnarySquare
<
float
,
float
,
false
>
;
using
DxsInElementOps
=
ck
::
Tuple
<
UnaryIdenticElementOp
,
UnarySquareElementOp
>
;
using
DxsOutElementOps
=
ck
::
Tuple
<
Unary
Identic
ElementOp
,
Unary
Identic
ElementOp
>
;
using
DxsOutElementOps
=
ck
::
Tuple
<
Unary
Div
ElementOp
,
Unary
Div
ElementOp
>
;
const
auto
a_element_op
=
AElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
const
auto
dxs_in_element_op
=
DxsInElementOps
{};
const
auto
dxs_out_element_op
=
DxsOutElementOps
{};
const
auto
d0_reduce_op
=
D0ReduceOp
{};
const
auto
d1_reduce_op
=
D1ReduceOp
{};
const
auto
a_element_op
=
AElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
const
auto
d0_reduce_op
=
D0ReduceOp
{};
const
auto
d1_reduce_op
=
D1ReduceOp
{};
auto
dxs_in_element_op
=
DxsInElementOps
{};
auto
dxs_out_element_op
=
DxsOutElementOps
{
M
,
M
};
if
(
do_verification
)
{
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
DDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
...
@@ -162,14 +170,18 @@ bool profile_gemm_reduce_impl(int do_verification,
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
float
d0_val
=
ck
::
type_convert
<
float
>
(
c_m_n_host_result
(
m
,
n
));
float
d1_val
;
float
c_val
=
ck
::
type_convert
<
float
>
(
c_m_n_host_result
(
m
,
n
));
float
d0_val
=
0
;
float
d1_val
=
0
;
UnarySquareElementOp
{}(
d1_val
,
d0_val
);
dxs_in_element_op
(
ck
::
Number
<
0
>
{})(
d0_val
,
c_val
);
dxs_in_element_op
(
ck
::
Number
<
1
>
{})(
d1_val
,
c_val
);
d0_reduce_op
(
d0_acc
,
d0_val
);
d1_reduce_op
(
d1_acc
,
d1_val
);
}
dxs_out_element_op
(
ck
::
Number
<
0
>
{})(
d0_acc
,
d0_acc
);
dxs_out_element_op
(
ck
::
Number
<
1
>
{})(
d1_acc
,
d1_acc
);
d0_m_host_result
(
m
)
=
ck
::
type_convert
<
DDataType
>
(
d0_acc
);
d1_m_host_result
(
m
)
=
ck
::
type_convert
<
DDataType
>
(
d1_acc
);
}
...
...
profiler/include/profile_grouped_gemm_impl.hpp
View file @
b5ada11b
...
...
@@ -43,6 +43,7 @@ namespace profiler {
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
...
...
@@ -271,6 +272,7 @@ void profile_grouped_gemm_impl(int do_verification,
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
...
...
profiler/src/profile_gemm.cpp
View file @
b5ada11b
...
...
@@ -68,6 +68,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
...
@@ -88,6 +89,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
...
@@ -108,6 +110,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
...
@@ -128,6 +131,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
...
@@ -146,6 +150,7 @@ int profile_gemm(int argc, char* argv[])
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
...
...
@@ -166,6 +171,7 @@ int profile_gemm(int argc, char* argv[])
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
...
...
@@ -186,6 +192,7 @@ int profile_gemm(int argc, char* argv[])
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
...
...
@@ -206,6 +213,7 @@ int profile_gemm(int argc, char* argv[])
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
ck
::
profiler
::
profile_gemm_impl
<
float
,
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
...
...
@@ -228,6 +236,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
...
@@ -248,6 +257,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
...
@@ -268,6 +278,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
...
@@ -288,6 +299,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
int8_t
,
int8_t
,
int8_t
,
int32_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
...
@@ -308,6 +320,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
...
@@ -328,6 +341,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
...
@@ -348,6 +362,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
...
@@ -368,6 +383,7 @@ int profile_gemm(int argc, char* argv[])
ck
::
profiler
::
profile_gemm_impl
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
...
...
profiler/src/profile_grouped_gemm.cpp
View file @
b5ada11b
...
...
@@ -79,6 +79,7 @@ int profile_grouped_gemm(int argc, char* argv[])
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_grouped_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
...
...
@@ -97,6 +98,7 @@ int profile_grouped_gemm(int argc, char* argv[])
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_NK_MN
)
{
ck
::
profiler
::
profile_grouped_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
...
...
@@ -115,6 +117,7 @@ int profile_grouped_gemm(int argc, char* argv[])
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_KN_MN
)
{
ck
::
profiler
::
profile_grouped_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
...
...
@@ -133,6 +136,7 @@ int profile_grouped_gemm(int argc, char* argv[])
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
KM_NK_MN
)
{
ck
::
profiler
::
profile_grouped_gemm_impl
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
...
...
profiler/src/profiler.cpp
View file @
b5ada11b
...
...
@@ -26,8 +26,7 @@ int main(int argc, char* argv[])
{
if
(
strcmp
(
argv
[
1
],
"gemm"
)
==
0
)
{
int
stat
=
profile_gemm
(
argc
,
argv
);
return
stat
;
return
profile_gemm
(
argc
,
argv
);
}
else
if
(
strcmp
(
argv
[
1
],
"gemm_bias_2d"
)
==
0
)
{
...
...
@@ -55,7 +54,7 @@ int main(int argc, char* argv[])
}
else
if
(
strcmp
(
argv
[
1
],
"grouped_gemm"
)
==
0
)
{
profile_grouped_gemm
(
argc
,
argv
);
return
profile_grouped_gemm
(
argc
,
argv
);
}
else
if
(
strcmp
(
argv
[
1
],
"conv_fwd"
)
==
0
)
{
...
...
test/block_to_ctile_map/test_block_to_ctile_map.cpp
View file @
b5ada11b
...
...
@@ -8,6 +8,7 @@ using namespace ck;
static
auto
I0
=
Number
<
0
>
{};
static
auto
I1
=
Number
<
1
>
{};
static
auto
I2
=
Number
<
2
>
{};
TEST
(
BlockToCTileMap
,
TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1
)
{
...
...
@@ -20,7 +21,7 @@ TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1
const
index_t
M01
=
4
;
const
index_t
N01
=
4
;
auto
c_grid_desc_m_n
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
I1
));
auto
c_grid_desc_m_n
=
make_naive_tensor_descriptor
_packed
(
make_tuple
(
M
,
N
));
printf
(
"(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)
\n
"
,
M
,
...
...
@@ -37,7 +38,7 @@ TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1
EXPECT_TRUE
(
tile_map
.
CalculateGridSize
(
c_grid_desc_m_n
)
==
16
);
// clang-format off
std
::
vector
<
std
::
vector
<
int
>>
expected
=
{
std
::
vector
<
std
::
vector
<
int
>>
expected
_m0idx_n0idx_valid
=
{
{
0
,
0
,
1
},
{
0
,
1
,
1
},
{
0
,
2
,
1
},
...
...
@@ -64,7 +65,7 @@ TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1
std
::
cout
<<
", valid = "
<<
tile_map
.
ValidCTileIndex
(
m0n0_idx
,
make_tuple
(
MBlock
,
NBlock
))
<<
std
::
endl
;
bool
equal
=
expected
[
i
]
==
expected
_m0idx_n0idx_valid
[
i
]
==
std
::
vector
<
int
>
{
m0n0_idx
[
I0
],
m0n0_idx
[
I1
],
tile_map
.
ValidCTileIndex
(
m0n0_idx
,
make_tuple
(
MBlock
,
NBlock
))};
...
...
@@ -78,12 +79,11 @@ TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck0
const
index_t
N
=
384
;
const
index_t
MPerBlock
=
128
;
const
index_t
NPerBlock
=
128
;
// const index_t MBlock = M / MPerBlock;
// const index_t NBlock = N / NPerBlock;
const
index_t
M01
=
4
;
const
index_t
N01
=
4
;
auto
c_grid_desc_m_n
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
I1
));
auto
c_grid_desc_m_n
=
make_naive_tensor_descriptor
_packed
(
make_tuple
(
M
,
N
));
printf
(
"(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)
\n
"
,
M
,
...
...
@@ -98,3 +98,221 @@ TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck0
EXPECT_TRUE
(
tile_map
.
CheckValidity
(
c_grid_desc_m_n
)
==
false
);
}
// Exercises BlockToCTileMap_M00_N0_M01 (last template argument `true`) on a
// 384 x 512 problem split into 128 x 128 tiles with M01 = 4.
// The map must report itself valid, give a grid size of 16, and translate every
// 1-D block id into the (m0, n0, valid) triple listed in the table below.
TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01_DeviceCTileIndexCheck1)
{
    const index_t M         = 384;
    const index_t N         = 512;
    const index_t MPerBlock = 128;
    const index_t NPerBlock = 128;
    const index_t MBlock    = M / MPerBlock;
    const index_t NBlock    = N / NPerBlock;
    const index_t M01       = 4;

    auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));

    printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n",
           M,
           N,
           MPerBlock,
           NPerBlock,
           M01);

    BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, decltype(c_grid_desc_m_n), true> tile_map(
        c_grid_desc_m_n, M01);

    EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true);
    EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 16);

    // One row per 1-D block id: {m0 index, n0 index, valid flag}.
    // clang-format off
    const std::vector<std::vector<int>> expected_rows = {
        {0, 0, 1}, {1, 0, 1}, {2, 0, 1}, {3, 0, 0},
        {0, 1, 1}, {1, 1, 1}, {2, 1, 1}, {3, 1, 0},
        {0, 2, 1}, {1, 2, 1}, {2, 2, 1}, {3, 2, 0},
        {0, 3, 1}, {1, 3, 1}, {2, 3, 1}, {3, 3, 0}
    };
    // clang-format on

    for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++)
    {
        auto tile_idx = tile_map.CalculateBottomIndex(make_multi_index(i));

        // Log the mapping first so a failing row can be identified from the output.
        std::cout << "block_1d_id = " << i << ", m0, n0 = " << tile_idx[I0] << ", "
                  << tile_idx[I1];
        std::cout << ", valid = "
                  << tile_map.ValidCTileIndex(tile_idx, make_tuple(MBlock, NBlock)) << std::endl;

        const std::vector<int> actual_row{
            tile_idx[I0],
            tile_idx[I1],
            tile_map.ValidCTileIndex(tile_idx, make_tuple(MBlock, NBlock))};

        bool equal = expected_rows[i] == actual_row;
        EXPECT_TRUE(equal);
    }
}
// Exercises BlockToCTileMap_M00_N0_M01 (last template argument `false`) on a
// 512 x 384 problem split into 128 x 128 tiles.
// For each M01 value in the table, the map's grid size and validity must match
// the expected pair stored alongside it.
TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01_DeviceCTileIndexCheck0)
{
    const index_t M         = 512;
    const index_t N         = 384;
    const index_t MPerBlock = 128;
    const index_t NPerBlock = 128;

    auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));

    // Each entry: {M01, expected grid size, expected validity}.
    // clang-format off
    const std::vector<std::tuple<int, int, bool>> cases = {
        {5, 15, false},
        {4, 12, true},
        {3, 18, false},
        {2, 12, true},
        {1, 12, true}
    };
    // clang-format on

    for(const auto& e : cases)
    {
        const index_t M01 = std::get<0>(e);

        printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n",
               M,
               N,
               MPerBlock,
               NPerBlock,
               M01);

        BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, decltype(c_grid_desc_m_n), false>
            tile_map(c_grid_desc_m_n, M01);

        EXPECT_EQ(tile_map.CalculateGridSize(c_grid_desc_m_n), std::get<1>(e));
        EXPECT_EQ(tile_map.CheckValidity(c_grid_desc_m_n), std::get<2>(e));
    }
}
// Exercises BlockToCTileMap_M00_N0_M01Adapt on a 768 x 384 problem split into
// 128 x 128 tiles with M01 = 4.
// The map must be valid, report a grid size of 18, and translate every 1-D
// block id into the (m0, n0, valid) triple listed in the table below.
TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01Adapt)
{
    const index_t M         = 768;
    const index_t N         = 384;
    const index_t MPerBlock = 128;
    const index_t NPerBlock = 128;
    const index_t MBlock    = M / MPerBlock;
    const index_t NBlock    = N / NPerBlock;
    constexpr index_t M01   = 4;

    auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));

    printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n",
           M,
           N,
           MPerBlock,
           NPerBlock,
           M01);

    BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, decltype(c_grid_desc_m_n)> tile_map(
        c_grid_desc_m_n, M01);

    EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true);
    EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 18);

    // One row per 1-D block id: {m0 index, n0 index, valid flag}.
    // clang-format off
    const std::vector<std::vector<int>> expected_rows = {
        {0, 0, 1}, {1, 0, 1}, {2, 0, 1}, {3, 0, 1},
        {0, 1, 1}, {1, 1, 1}, {2, 1, 1}, {3, 1, 1},
        {0, 2, 1}, {1, 2, 1}, {2, 2, 1}, {3, 2, 1},
        {4, 0, 1}, {5, 0, 1},
        {4, 1, 1}, {5, 1, 1},
        {4, 2, 1}, {5, 2, 1},
    };
    // clang-format on

    for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++)
    {
        auto tile_idx = tile_map.CalculateBottomIndex(make_multi_index(i));

        // Log the mapping first so a failing row can be identified from the output.
        std::cout << "block_1d_id = " << i << ", m0, n0 = " << tile_idx[I0] << ", "
                  << tile_idx[I1];
        std::cout << ", valid = "
                  << tile_map.ValidCTileIndex(tile_idx, make_tuple(MBlock, NBlock)) << std::endl;

        const std::vector<int> actual_row{
            tile_idx[I0],
            tile_idx[I1],
            tile_map.ValidCTileIndex(tile_idx, make_tuple(MBlock, NBlock))};

        bool equal = expected_rows[i] == actual_row;
        EXPECT_TRUE(equal);
    }
}
// Exercises BlockToCTileMap_KSplit_M00_N0_M01Adapt on a 768 x 384 problem split
// into 128 x 128 tiles with M01 = 4 and KSplit = 3.
// The map must be valid, report a grid size of 18 * KSplit, and translate every
// 1-D block id into the (ksplit, m0, n0, valid) tuple listed in the table below.
TEST(BlockToCTileMap, TestBlockToCTileMap_KSplit_M00_N0_M01Adapt)
{
    const index_t M         = 768;
    const index_t N         = 384;
    const index_t MPerBlock = 128;
    const index_t NPerBlock = 128;
    const index_t MBlock    = M / MPerBlock;
    const index_t NBlock    = N / NPerBlock;
    constexpr index_t M01   = 4;
    const index_t KSplit    = 3;

    auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));

    printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n",
           M,
           N,
           MPerBlock,
           NPerBlock,
           M01);

    BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, decltype(c_grid_desc_m_n)>
        tile_map(c_grid_desc_m_n, M01, KSplit);

    EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true);
    EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 18 * KSplit);

    // One row per 1-D block id: {ksplit index, m0 index, n0 index, valid flag}.
    // The same 18-row (m0, n0) pattern repeats for ksplit index 0, 1 and 2.
    // clang-format off
    const std::vector<std::vector<int>> expected_rows = {
        // ksplit index 0
        {0, 0, 0, 1}, {0, 1, 0, 1}, {0, 2, 0, 1}, {0, 3, 0, 1},
        {0, 0, 1, 1}, {0, 1, 1, 1}, {0, 2, 1, 1}, {0, 3, 1, 1},
        {0, 0, 2, 1}, {0, 1, 2, 1}, {0, 2, 2, 1}, {0, 3, 2, 1},
        {0, 4, 0, 1}, {0, 5, 0, 1},
        {0, 4, 1, 1}, {0, 5, 1, 1},
        {0, 4, 2, 1}, {0, 5, 2, 1},
        // ksplit index 1
        {1, 0, 0, 1}, {1, 1, 0, 1}, {1, 2, 0, 1}, {1, 3, 0, 1},
        {1, 0, 1, 1}, {1, 1, 1, 1}, {1, 2, 1, 1}, {1, 3, 1, 1},
        {1, 0, 2, 1}, {1, 1, 2, 1}, {1, 2, 2, 1}, {1, 3, 2, 1},
        {1, 4, 0, 1}, {1, 5, 0, 1},
        {1, 4, 1, 1}, {1, 5, 1, 1},
        {1, 4, 2, 1}, {1, 5, 2, 1},
        // ksplit index 2
        {2, 0, 0, 1}, {2, 1, 0, 1}, {2, 2, 0, 1}, {2, 3, 0, 1},
        {2, 0, 1, 1}, {2, 1, 1, 1}, {2, 2, 1, 1}, {2, 3, 1, 1},
        {2, 0, 2, 1}, {2, 1, 2, 1}, {2, 2, 2, 1}, {2, 3, 2, 1},
        {2, 4, 0, 1}, {2, 5, 0, 1},
        {2, 4, 1, 1}, {2, 5, 1, 1},
        {2, 4, 2, 1}, {2, 5, 2, 1},
    };
    // clang-format on

    for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++)
    {
        auto tile_idx = tile_map.CalculateBottomIndex(make_multi_index(i));

        // Log the mapping first so a failing row can be identified from the output.
        std::cout << "block_1d_id = " << i << ", ksplit, m0, n0 = " << tile_idx[I0] << ", "
                  << tile_idx[I1] << ", " << tile_idx[I2];
        std::cout << ", valid = "
                  << tile_map.ValidCTileIndex(tile_idx, make_tuple(MBlock, NBlock)) << std::endl;

        const std::vector<int> actual_row{
            tile_idx[I0],
            tile_idx[I1],
            tile_idx[I2],
            tile_map.ValidCTileIndex(tile_idx, make_tuple(MBlock, NBlock))};

        bool equal = expected_rows[i] == actual_row;
        EXPECT_TRUE(equal);
    }
}
test/gemm/gemm_dl_fp16.cpp
View file @
b5ada11b
...
...
@@ -43,9 +43,10 @@ void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoO
int
main
()
{
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -63,6 +64,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
...
...
@@ -81,6 +83,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
...
...
@@ -99,6 +102,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
...
...
@@ -117,6 +121,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
...
...
test/gemm/gemm_dl_fp32.cpp
View file @
b5ada11b
...
...
@@ -43,9 +43,10 @@ void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoO
int
main
()
{
using
ADataType
=
float
;
using
BDataType
=
float
;
using
CDataType
=
float
;
using
ADataType
=
float
;
using
BDataType
=
float
;
using
CDataType
=
float
;
using
AccDataType
=
float
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -61,6 +62,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
...
...
@@ -79,6 +81,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
...
...
@@ -97,6 +100,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
...
...
@@ -115,6 +119,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
...
...
test/gemm/gemm_dl_int8.cpp
View file @
b5ada11b
...
...
@@ -43,9 +43,10 @@ void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPt
int
main
()
{
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
AccDataType
=
int
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -61,6 +62,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
...
...
@@ -79,6 +81,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
...
...
@@ -97,6 +100,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
...
...
@@ -115,6 +119,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
...
...
test/gemm/gemm_util.hpp
View file @
b5ada11b
...
...
@@ -111,6 +111,7 @@ template <typename DeviceGemmPtr_,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
...
...
@@ -186,6 +187,7 @@ struct TestGemm
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
;
...
...
@@ -215,6 +217,11 @@ struct TestGemm
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
else
if
(
std
::
is_same
<
CDataType
,
double
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
return
res
;
}
...
...
@@ -311,6 +318,7 @@ struct TestGemmBF16
// use fp32 host kernel to verify bf16 device kernel
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
float
,
float
,
float
,
float
,
AElementwiseOperation
,
...
...
test/gemm/gemm_xdl_fp16.cpp
View file @
b5ada11b
...
...
@@ -52,9 +52,10 @@ void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
int
main
()
{
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -74,6 +75,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
...
...
@@ -96,6 +98,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
...
...
@@ -118,6 +121,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
...
...
@@ -142,6 +146,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
...
...
test/gemm/gemm_xdl_fp32.cpp
View file @
b5ada11b
...
...
@@ -53,9 +53,10 @@ void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector<De
int
main
()
{
using
ADataType
=
float
;
using
BDataType
=
float
;
using
CDataType
=
float
;
using
ADataType
=
float
;
using
BDataType
=
float
;
using
CDataType
=
float
;
using
AccDataType
=
float
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -75,6 +76,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
...
...
@@ -97,6 +99,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
...
...
@@ -119,6 +122,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
...
...
@@ -141,6 +145,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
...
...
test/gemm/gemm_xdl_fp64.cpp
0 → 100644
View file @
b5ada11b
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
// Short name for the ck PassThrough element-wise operation type.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Device GEMM pointer type whose three element-wise operation slots are all
// PassThrough.
using DeviceGemmNoOpPtr =
    ck::tensor_operation::device::DeviceGemmPtr<PassThrough, PassThrough, PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

// Forward declarations of the f64 XDL GEMM instance registration functions, one
// per A/B layout combination (km/mk for A, kn/nk for B). Their definitions are
// not in this file; presumably each appends its instances to the given vector.
void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(
    std::vector<DeviceGemmNoOpPtr>& instances);
void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(
    std::vector<DeviceGemmNoOpPtr>& instances);
void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(
    std::vector<DeviceGemmNoOpPtr>& instances);
void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(
    std::vector<DeviceGemmNoOpPtr>& instances);

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
inline
std
::
string
get_device_name
()
{
hipDeviceProp_t
props
{};
int
device
;
auto
status
=
hipGetDevice
(
&
device
);
if
(
status
!=
hipSuccess
)
{
return
std
::
string
();
}
status
=
hipGetDeviceProperties
(
&
props
,
device
);
if
(
status
!=
hipSuccess
)
{
return
std
::
string
();
}
const
std
::
string
name
(
props
.
gcnArchName
);
return
name
;
}
int
main
()
{
if
(
get_device_name
().
find
(
"gfx90a"
)
==
std
::
string
::
npos
)
{
std
::
cout
<<
"TestGemm ..... SUCCESS"
<<
std
::
endl
;
return
0
;
}
using
ADataType
=
double
;
using
BDataType
=
double
;
using
CDataType
=
double
;
using
AccDataType
=
double
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
bool
res
=
true
;
std
::
vector
<
DeviceGemmNoOpPtr
>
gemmPtrs
;
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
gemmPtrs
.
clear
();
ck
::
tensor_operation
::
device
::
device_gemm_instance
::
add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances
(
gemmPtrs
);
for
(
auto
&
gemmPtr
:
gemmPtrs
)
{
res
&=
ck
::
gemm_util
::
TestGemm
<
DeviceGemmNoOpPtr
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
PassThrough
,
PassThrough
,
PassThrough
>
{}(
gemmPtr
);
}
std
::
cout
<<
"TestGemm ..... "
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
res
?
0
:
1
;
}
test/gemm/gemm_xdl_int8.cpp
View file @
b5ada11b
...
...
@@ -42,9 +42,10 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector<Devic
int
main
()
{
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
ADataType
=
int8_t
;
using
BDataType
=
int8_t
;
using
CDataType
=
int8_t
;
using
AccDataType
=
int32_t
;
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
ColumnMajor
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -61,6 +62,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
RowMajor
,
RowMajor
,
...
...
@@ -79,6 +81,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ColumnMajor
,
ColumnMajor
,
RowMajor
,
...
...
@@ -97,6 +100,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
RowMajor
,
RowMajor
,
...
...
@@ -115,6 +119,7 @@ int main()
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
RowMajor
,
ColumnMajor
,
RowMajor
,
...
...
test/grouped_gemm/grouped_gemm_fp16.cpp
View file @
b5ada11b
...
...
@@ -150,14 +150,23 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
groupedGemmPtr
->
SetWorkSpacePointer
(
argument_ptr
.
get
(),
gemm_desc_workspace
.
GetDeviceBuffer
());
DeviceMem
gemm_desc_workspace
(
groupedGemmPtr
->
GetWorkSpaceSize
(
argument_ptr
.
get
()));
groupedGemmPtr
->
SetWorkSpacePointer
(
argument_ptr
.
get
(),
gemm_desc_workspace
.
GetDeviceBuffer
());
invoker_ptr
->
Run
(
argument_ptr
.
get
());
for
(
std
::
size_t
i
=
0
;
i
<
gemm_descs
.
size
();
i
++
)
{
c_tensors_device
[
i
]
->
FromDevice
(
c_device_tensors
[
i
].
mData
.
data
());
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
PassThrough
,
PassThrough
,
PassThrough
>
;
auto
ref_gemm
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment