Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
0ef27d53
Commit
0ef27d53
authored
Jan 21, 2025
by
Andriy Roshchenko
Browse files
Merge remote-tracking branch 'origin/gfx950' into andriy/lwpck-2682
parents
6778c318
74a743e2
Changes
348
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
158 additions
and
56 deletions
+158
-56
script/process_qa_data.sh
script/process_qa_data.sh
+16
-0
test/CMakeLists.txt
test/CMakeLists.txt
+52
-0
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
+21
-22
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+12
-32
test/data_type/CMakeLists.txt
test/data_type/CMakeLists.txt
+1
-0
test/data_type/test_bhalf.cpp
test/data_type/test_bhalf.cpp
+48
-0
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
...uped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
+7
-2
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
+1
-0
No files found.
script/process_qa_data.sh
View file @
0ef27d53
...
@@ -52,3 +52,19 @@ file=./perf_fmha_bwd_gfx90a.log
...
@@ -52,3 +52,19 @@ file=./perf_fmha_bwd_gfx90a.log
if
[
-e
"
$file
"
]
;
then
if
[
-e
"
$file
"
]
;
then
python3 process_perf_data.py perf_fmha_bwd_gfx90a.log
python3 process_perf_data.py perf_fmha_bwd_gfx90a.log
fi
fi
file
=
./perf_gemm_basic_gfx942.log
if
[
-e
"
$file
"
]
;
then
python3 process_perf_data.py perf_gemm_basic_gfx942.log
fi
file
=
./perf_gemm_basic_gfx90a.log
if
[
-e
"
$file
"
]
;
then
python3 process_perf_data.py perf_gemm_basic_gfx90a.log
fi
file
=
./perf_gemm_mem_pipeline_gfx942.log
if
[
-e
"
$file
"
]
;
then
python3 process_perf_data.py perf_gemm_mem_pipeline_gfx942.log
fi
file
=
./perf_gemm_mem_pipeline_gfx90a.log
if
[
-e
"
$file
"
]
;
then
python3 process_perf_data.py perf_gemm_mem_pipeline_gfx90a.log
fi
test/CMakeLists.txt
View file @
0ef27d53
...
@@ -7,6 +7,34 @@ include(gtest)
...
@@ -7,6 +7,34 @@ include(gtest)
add_custom_target
(
tests
)
add_custom_target
(
tests
)
# list of tests that are labelled as REGRESSION_TEST for make regression (runtime more than 30 seconds)
# all other tests are labelled as SMOKE_TEST
set
(
REGRESSION_TESTS
test_gemm_standalone_xdl_fp16
test_gemm_fp16
test_gemm_splitk
test_batched_gemm
test_gemm_universal
test_batched_gemm_softmax_gemm_fp16
test_batched_gemm_softmax_gemm_permute_fp16
test_batched_gemm_bias_softmax_gemm_permute_fp16
test_batched_gemm_softmax_gemm_permute_bf16
test_batched_gemm_bias_softmax_gemm_permute_bf16
test_grouped_gemm_splitk
test_reduce_no_index
test_reduce_with_index
test_convnd_fwd
test_convnd_bwd_data
test_grouped_convnd_fwd
test_grouped_convnd_bwd_weight
test_softmax_rank3
test_softmax_rank4
test_batchnorm_fwd_rank_4
test_batchnorm_bwd_rank_4
test_grouped_convnd_bwd_data_xdl
test_conv_tensor_rearrange
)
function
(
add_test_executable TEST_NAME
)
function
(
add_test_executable TEST_NAME
)
message
(
"adding test
${
TEST_NAME
}
"
)
message
(
"adding test
${
TEST_NAME
}
"
)
set
(
result 1
)
set
(
result 1
)
...
@@ -43,6 +71,12 @@ function(add_test_executable TEST_NAME)
...
@@ -43,6 +71,12 @@ function(add_test_executable TEST_NAME)
set
(
TEST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
set
(
TEST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT DEFINED DPP_KERNELS AND source MATCHES
"_dpp"
)
message
(
"removing dpp test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
message
(
"removing dl test
${
source
}
"
)
message
(
"removing dl test
${
source
}
"
)
...
@@ -82,6 +116,15 @@ function(add_test_executable TEST_NAME)
...
@@ -82,6 +116,15 @@ function(add_test_executable TEST_NAME)
endif
()
endif
()
#message("add_test returns ${result}")
#message("add_test returns ${result}")
set
(
result
${
result
}
PARENT_SCOPE
)
set
(
result
${
result
}
PARENT_SCOPE
)
if
(
result EQUAL 0 AND NOT
"
${
TEST_NAME
}
"
IN_LIST REGRESSION_TESTS
)
message
(
"adding to SMOKE TEST FILTER
${
TEST_NAME
}
"
)
set_tests_properties
(
${
TEST_NAME
}
PROPERTIES LABELS
"SMOKE_TEST"
)
add_dependencies
(
smoke
${
TEST_NAME
}
)
elseif
(
result EQUAL 0 AND
"
${
TEST_NAME
}
"
IN_LIST REGRESSION_TESTS
)
message
(
"Adding to REGRESSION TEST FILTER
${
TEST_NAME
}
"
)
set_tests_properties
(
${
TEST_NAME
}
PROPERTIES LABELS
"REGRESSION_TEST"
)
add_dependencies
(
regression
${
TEST_NAME
}
)
endif
()
endfunction
()
endfunction
()
function
(
add_gtest_executable TEST_NAME
)
function
(
add_gtest_executable TEST_NAME
)
...
@@ -172,6 +215,15 @@ function(add_gtest_executable TEST_NAME)
...
@@ -172,6 +215,15 @@ function(add_gtest_executable TEST_NAME)
endif
()
endif
()
#message("add_gtest returns ${result}")
#message("add_gtest returns ${result}")
set
(
result
${
result
}
PARENT_SCOPE
)
set
(
result
${
result
}
PARENT_SCOPE
)
if
(
result EQUAL 0 AND NOT
"
${
TEST_NAME
}
"
IN_LIST REGRESSION_TESTS
)
#message("adding to smoke test FILTER ${TEST_NAME}")
set_tests_properties
(
${
TEST_NAME
}
PROPERTIES LABELS
"SMOKE_TEST"
)
add_dependencies
(
smoke
${
TEST_NAME
}
)
elseif
(
result EQUAL 0 AND
"
${
TEST_NAME
}
"
IN_LIST REGRESSION_TESTS
)
#message("Adding to REGRESSION TEST FILTER ${TEST_NAME}")
set_tests_properties
(
${
TEST_NAME
}
PROPERTIES LABELS
"REGRESSION_TEST"
)
add_dependencies
(
regression
${
TEST_NAME
}
)
endif
()
endfunction
()
endfunction
()
add_compile_options
(
-Wno-c++20-extensions
)
add_compile_options
(
-Wno-c++20-extensions
)
...
...
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
View file @
0ef27d53
...
@@ -24,12 +24,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
...
@@ -24,12 +24,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
struct
batched_gemm_kargs
:
public
ck_tile
::
BatchedGemmHostArgs
{
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
void
invoke_batched_gemm
(
const
batched_gemm_kargs
&
args
,
const
ck_tile
::
stream_config
&
s
)
void
invoke_batched_gemm
(
const
ck_tile
::
BatchedGemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadM
=
false
;
...
@@ -94,9 +91,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
...
@@ -94,9 +91,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
using
Kernel
=
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeK
a
rgs
(
args
);
auto
kargs
=
Kernel
::
MakeK
ernelA
rgs
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k_batch
,
args
.
batch_count
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
s
.
log_level_
>
0
)
if
(
s
.
log_level_
>
0
)
...
@@ -185,21 +182,23 @@ class TestCkTileBatchedGemm : public ::testing::Test
...
@@ -185,21 +182,23 @@ class TestCkTileBatchedGemm : public ::testing::Test
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
batched_gemm_kargs
kargs
{
a_m_k_dev_buf
.
GetDeviceBuffer
(),
ck_tile
::
BatchedGemmHostArgs
args
;
b_k_n_dev_buf
.
GetDeviceBuffer
(),
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
c_m_n_dev_buf
.
GetDeviceBuffer
(),
args
.
b_ptr
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
M
,
args
.
c_ptr
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
N
,
args
.
k_batch
=
1
;
K
,
args
.
M
=
M
;
StrideA
,
args
.
N
=
N
;
StrideB
,
args
.
K
=
K
;
StrideC
,
args
.
stride_A
=
StrideA
;
BatchStrideA
,
args
.
stride_B
=
StrideB
;
BatchStrideB
,
args
.
stride_C
=
StrideC
;
BatchStrideC
,
args
.
batch_stride_A
=
BatchStrideA
;
BatchCount
};
args
.
batch_stride_B
=
BatchStrideB
;
args
.
batch_stride_C
=
BatchStrideC
;
invoke_batched_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
kargs
,
args
.
batch_count
=
BatchCount
;
invoke_batched_gemm
<
ALayout
,
BLayout
,
CLayout
>
(
args
,
ck_tile
::
stream_config
{
nullptr
,
false
});
ck_tile
::
stream_config
{
nullptr
,
false
});
std
::
cout
<<
"Run kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
std
::
cout
<<
"Run kernel with M ="
<<
M
<<
" N ="
<<
N
<<
" K ="
<<
K
...
...
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
View file @
0ef27d53
...
@@ -31,22 +31,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -31,22 +31,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
static
constexpr
auto
PipelineType
=
std
::
tuple_element_t
<
8
,
Tuple
>::
value
;
static
constexpr
auto
PipelineType
=
std
::
tuple_element_t
<
8
,
Tuple
>::
value
;
// TODO: expose tile size through test t-param ?
// TODO: expose tile size through test t-param ?
struct
gemm_args
{
const
void
*
p_a
;
const
void
*
p_b
;
void
*
p_c
;
ck_tile
::
index_t
kbatch
;
ck_tile
::
index_t
M
;
ck_tile
::
index_t
N
;
ck_tile
::
index_t
K
;
ck_tile
::
index_t
stride_A
;
ck_tile
::
index_t
stride_B
;
ck_tile
::
index_t
stride_C
;
};
template
<
bool
PadM
,
bool
PadN
,
bool
PadK
>
template
<
bool
PadM
,
bool
PadN
,
bool
PadK
>
void
invoke_gemm
(
const
gemm_a
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
void
invoke_gemm
(
const
ck_tile
::
GemmHostA
rgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
// TODO: This should be parameterized in tests
// TODO: This should be parameterized in tests
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
constexpr
ck_tile
::
index_t
M_Tile
=
128
;
...
@@ -88,7 +74,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -88,7 +74,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile
::
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>>
;
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
GemmShape
,
Traits
>>>
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
args
.
K
);
const
ck_tile
::
index_t
k_grain
=
args
.
k_batch
*
K_Tile
;
const
ck_tile
::
index_t
K_split
=
(
args
.
K
+
k_grain
-
1
)
/
k_grain
*
K_Tile
;
const
ck_tile
::
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
K_split
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
const
bool
has_hot_loop
=
BaseGemmPipeline
::
BlockHasHotloop
(
num_loop
);
const
ck_tile
::
TailNumber
tail_num
=
BaseGemmPipeline
::
GetBlockLoopTailNum
(
num_loop
);
const
ck_tile
::
TailNumber
tail_num
=
BaseGemmPipeline
::
GetBlockLoopTailNum
(
num_loop
);
...
@@ -117,17 +105,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -117,17 +105,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
has_hot_loop_v
,
has_hot_loop_v
,
tail_number_v
>>>
;
tail_number_v
>>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_a
,
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
args
.
p_b
,
args
.
p_c
,
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k_batch
);
args
.
M
,
args
.
N
,
args
.
K
,
args
.
stride_A
,
args
.
stride_B
,
args
.
stride_C
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
kbatch
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
!
Kernel
::
IsSupportedArgument
(
kargs
))
if
(
!
Kernel
::
IsSupportedArgument
(
kargs
))
...
@@ -319,11 +299,11 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -319,11 +299,11 @@ class TestCkTileGemmPipeline : public ::testing::Test
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_buf
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
c_m_n_dev_result
.
SetZero
();
gemm_a
rgs
args
;
ck_tile
::
GemmHostA
rgs
args
;
args
.
p_a
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
a_ptr
=
a_m_k_dev_buf
.
GetDeviceBuffer
();
args
.
p_b
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
b_ptr
=
b_k_n_dev_buf
.
GetDeviceBuffer
();
args
.
p_c
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
c_ptr
=
c_m_n_dev_buf
.
GetDeviceBuffer
();
args
.
kbatch
=
kbatch
;
args
.
k
_
batch
=
kbatch
;
args
.
M
=
M
;
args
.
M
=
M
;
args
.
N
=
N
;
args
.
N
=
N
;
args
.
K
=
K
;
args
.
K
=
K
;
...
...
test/data_type/CMakeLists.txt
View file @
0ef27d53
...
@@ -87,3 +87,4 @@ if(result EQUAL 0)
...
@@ -87,3 +87,4 @@ if(result EQUAL 0)
endif
()
endif
()
add_gtest_executable
(
test_type_convert_const type_convert_const.cpp
)
add_gtest_executable
(
test_type_convert_const type_convert_const.cpp
)
add_gtest_executable
(
test_bhalf test_bhalf.cpp
)
test/data_type/test_bhalf.cpp
0 → 100644
View file @
0ef27d53
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
bhalf_t
;
using
ck
::
type_convert
;
TEST
(
BHALF_T
,
Nan
)
{
const
uint16_t
binary_bhalf_nan
=
0x7FC0
;
const
bhalf_t
bhalf_nan
=
ck
::
bit_cast
<
bhalf_t
>
(
binary_bhalf_nan
);
EXPECT_EQ
(
bhalf_nan
,
type_convert
<
bhalf_t
>
(
ck
::
NumericLimits
<
float
>::
QuietNaN
()));
}
TEST
(
BHALF_T
,
Inf
)
{
const
uint16_t
binary_bhalf_inf
=
0x7F80
;
const
bhalf_t
bhalf_inf
=
ck
::
bit_cast
<
bhalf_t
>
(
binary_bhalf_inf
);
EXPECT_EQ
(
bhalf_inf
,
type_convert
<
bhalf_t
>
(
ck
::
NumericLimits
<
float
>::
Infinity
()));
}
TEST
(
BHALF_T
,
MantisaOverflow
)
{
const
float
abs_tol
=
std
::
pow
(
2
,
-
7
);
const
uint32_t
val
=
0x81FFFFFF
;
const
float
float_val
=
ck
::
bit_cast
<
float
>
(
val
);
ASSERT_NEAR
(
float_val
,
type_convert
<
float
>
(
type_convert
<
bhalf_t
>
(
float_val
)),
abs_tol
);
}
TEST
(
BHALF_T
,
ExpOverflow
)
{
const
uint32_t
val
=
0xFF800000
;
const
float
float_val
=
ck
::
bit_cast
<
float
>
(
val
);
ASSERT_EQ
(
type_convert
<
float
>
(
type_convert
<
bhalf_t
>
(
float_val
)),
float_val
);
}
TEST
(
BHALF_T
,
MantisaExpOverflow
)
{
const
uint32_t
val
=
0xFFFFFFFF
;
const
float
float_val
=
ck
::
bit_cast
<
float
>
(
val
);
ASSERT_TRUE
(
std
::
isnan
(
float_val
));
ASSERT_TRUE
(
std
::
isnan
(
type_convert
<
float
>
(
type_convert
<
bhalf_t
>
(
float_val
))));
}
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
View file @
0ef27d53
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <cstdlib>
#include <iostream>
#include <iostream>
...
@@ -43,7 +43,6 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
...
@@ -43,7 +43,6 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
return
true
;
return
true
;
}
}
}
}
if
(
ck
::
is_gfx11_supported
()
||
ck
::
is_gfx12_supported
())
if
(
ck
::
is_gfx11_supported
()
||
ck
::
is_gfx12_supported
())
{
{
// on gfx11 only support for 3d is implemented
// on gfx11 only support for 3d is implemented
...
@@ -143,19 +142,23 @@ using KernelTypes2d = ::testing::Types<
...
@@ -143,19 +142,23 @@ using KernelTypes2d = ::testing::Types<
std
::
tuple
<
float
,
float
,
float
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
float
,
float
,
float
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
GNHWC
,
GKYXC
,
GNHWK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
float
,
float
,
float
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
float
,
float
,
float
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
NHWGC
,
GKYXC
,
NHWGK
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
NGCHW
,
GKYXC
,
NGKHW
,
ck
::
Number
<
2
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NGCHW
,
GKYXC
,
NGKHW
,
ck
::
Number
<
2
>>>
;
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NGCHW
,
GKYXC
,
NGKHW
,
ck
::
Number
<
2
>>>
;
using
KernelTypes3d
=
::
testing
::
Types
<
using
KernelTypes3d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
float
,
float
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
float
,
float
,
float
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
int8_t
,
int8_t
,
int8_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
int8_t
,
int8_t
,
int8_t
,
GNDHWC
,
GKZYXC
,
GNDHWK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
float
,
float
,
float
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
float
,
float
,
float
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
int8_t
,
int8_t
,
int8_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
int8_t
,
int8_t
,
int8_t
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
NGCDHW
,
GKZYXC
,
NGKDHW
,
ck
::
Number
<
3
>>
,
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NGCDHW
,
GKZYXC
,
NGKDHW
,
ck
::
Number
<
3
>>>
;
std
::
tuple
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
NGCDHW
,
GKZYXC
,
NGKDHW
,
ck
::
Number
<
3
>>>
;
TYPED_TEST_SUITE
(
TestGroupedConvndBwdWeight1d
,
KernelTypes1d
);
TYPED_TEST_SUITE
(
TestGroupedConvndBwdWeight1d
,
KernelTypes1d
);
...
@@ -179,6 +182,8 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
...
@@ -179,6 +182,8 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
this
->
conv_params
.
clear
();
this
->
conv_params
.
clear
();
this
->
conv_params
.
push_back
(
this
->
conv_params
.
push_back
(
{
2
,
2
,
64
,
128
,
256
,
{
1
,
1
},
{
7
,
7
},
{
2
,
2
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
{
2
,
2
,
64
,
128
,
256
,
{
1
,
1
},
{
7
,
7
},
{
2
,
2
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
conv_params
.
push_back
({
2
,
2
,
64
,
3
,
3
,
{
1
,
1
},
{
7
,
7
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
conv_params
.
push_back
({
2
,
2
,
64
,
5
,
5
,
{
1
,
1
},
{
7
,
7
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
conv_params
.
push_back
(
this
->
conv_params
.
push_back
(
{
2
,
2
,
4
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
{
2
,
2
,
4
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
(
this
->
conv_params
.
push_back
(
...
...
test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp
View file @
0ef27d53
...
@@ -64,6 +64,7 @@ using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK>,
...
@@ -64,6 +64,7 @@ using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWC, GKYXC, GNHWK>,
std
::
tuple
<
int8_t
,
NHWGC
,
GKYXC
,
NHWGK
>
,
std
::
tuple
<
int8_t
,
NHWGC
,
GKYXC
,
NHWGK
>
,
std
::
tuple
<
float
,
NGCHW
,
GKYXC
,
NGKHW
>
,
std
::
tuple
<
float
,
NGCHW
,
GKYXC
,
NGKHW
>
,
std
::
tuple
<
ck
::
half_t
,
NGCHW
,
GKYXC
,
NGKHW
>
,
std
::
tuple
<
ck
::
half_t
,
NGCHW
,
GKYXC
,
NGKHW
>
,
std
::
tuple
<
ck
::
bhalf_t
,
NGCHW
,
GKYXC
,
NGKHW
>
,
std
::
tuple
<
int8_t
,
NGCHW
,
GKYXC
,
NGKHW
>>
;
std
::
tuple
<
int8_t
,
NGCHW
,
GKYXC
,
NGKHW
>>
;
using
KernelTypes3d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
GNDHWC
,
GKZYXC
,
GNDHWK
>
,
using
KernelTypes3d
=
::
testing
::
Types
<
std
::
tuple
<
float
,
GNDHWC
,
GKZYXC
,
GNDHWK
>
,
...
...
Prev
1
…
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment