Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c7c47fd7
Commit
c7c47fd7
authored
Aug 10, 2023
by
Bartlomiej Wroblewski
Browse files
Merge branch 'develop' into bwroblew/dpp8
parents
f8eb91d7
578142db
Changes
183
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
156 additions
and
95 deletions
+156
-95
profiler/src/profile_gemm.cpp
profiler/src/profile_gemm.cpp
+7
-1
profiler/src/profile_grouped_gemm.cpp
profiler/src/profile_grouped_gemm.cpp
+2
-2
test/CMakeLists.txt
test/CMakeLists.txt
+1
-1
test/batched_gemm/CMakeLists.txt
test/batched_gemm/CMakeLists.txt
+20
-15
test/batched_gemm_gemm/CMakeLists.txt
test/batched_gemm_gemm/CMakeLists.txt
+7
-5
test/batched_gemm_multi_d/CMakeLists.txt
test/batched_gemm_multi_d/CMakeLists.txt
+2
-3
test/batched_gemm_reduce/CMakeLists.txt
test/batched_gemm_reduce/CMakeLists.txt
+6
-4
test/batched_gemm_softmax_gemm/CMakeLists.txt
test/batched_gemm_softmax_gemm/CMakeLists.txt
+7
-5
test/batched_gemm_softmax_gemm_permute/CMakeLists.txt
test/batched_gemm_softmax_gemm_permute/CMakeLists.txt
+19
-15
test/elementwise_normalization/CMakeLists.txt
test/elementwise_normalization/CMakeLists.txt
+6
-7
test/gemm_layernorm/CMakeLists.txt
test/gemm_layernorm/CMakeLists.txt
+2
-0
test/gemm_reduce/CMakeLists.txt
test/gemm_reduce/CMakeLists.txt
+5
-3
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
...uped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
+12
-0
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp
...d_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp
+12
-13
test/grouped_convnd_fwd/grouped_convnd_fwd.cpp
test/grouped_convnd_fwd/grouped_convnd_fwd.cpp
+15
-1
test/grouped_gemm/CMakeLists.txt
test/grouped_gemm/CMakeLists.txt
+2
-0
test/normalization/CMakeLists.txt
test/normalization/CMakeLists.txt
+19
-16
test/pool_fwd/test_avg_pool2d_fwd.cpp
test/pool_fwd/test_avg_pool2d_fwd.cpp
+4
-0
test/pool_fwd/test_avg_pool3d_fwd.cpp
test/pool_fwd/test_avg_pool3d_fwd.cpp
+4
-2
test/pool_fwd/test_max_pool2d_fwd.cpp
test/pool_fwd/test_max_pool2d_fwd.cpp
+4
-2
No files found.
profiler/src/profile_gemm.cpp
View file @
c7c47fd7
...
@@ -121,7 +121,10 @@ int profile_gemm(int argc, char* argv[])
...
@@ -121,7 +121,10 @@ int profile_gemm(int argc, char* argv[])
return
pass
?
0
:
1
;
return
pass
?
0
:
1
;
};
};
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
if
(
false
)
;
#ifdef __fp32__
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
return
profile
(
Row
{},
Row
{},
Row
{},
F32
{},
F32
{},
F32
{},
F32
{});
return
profile
(
Row
{},
Row
{},
Row
{},
F32
{},
F32
{},
F32
{},
F32
{});
}
}
...
@@ -137,6 +140,8 @@ int profile_gemm(int argc, char* argv[])
...
@@ -137,6 +140,8 @@ int profile_gemm(int argc, char* argv[])
{
{
return
profile
(
Col
{},
Col
{},
Row
{},
F32
{},
F32
{},
F32
{},
F32
{});
return
profile
(
Col
{},
Col
{},
Row
{},
F32
{},
F32
{},
F32
{},
F32
{});
}
}
#endif
#ifdef __fp16__
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
return
profile
(
Row
{},
Row
{},
Row
{},
F16
{},
F16
{},
F32
{},
F16
{});
return
profile
(
Row
{},
Row
{},
Row
{},
F16
{},
F16
{},
F32
{},
F16
{});
...
@@ -153,6 +158,7 @@ int profile_gemm(int argc, char* argv[])
...
@@ -153,6 +158,7 @@ int profile_gemm(int argc, char* argv[])
{
{
return
profile
(
Col
{},
Col
{},
Row
{},
F16
{},
F16
{},
F32
{},
F16
{});
return
profile
(
Col
{},
Col
{},
Row
{},
F16
{},
F16
{},
F32
{},
F16
{});
}
}
#endif
#ifdef __bf16__
#ifdef __bf16__
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
...
...
profiler/src/profile_grouped_gemm.cpp
View file @
c7c47fd7
...
@@ -88,7 +88,7 @@ int profile_grouped_gemm(int argc, char* argv[])
...
@@ -88,7 +88,7 @@ int profile_grouped_gemm(int argc, char* argv[])
const
auto
StrideBs
=
argToIntArray
(
argv
[
12
]);
const
auto
StrideBs
=
argToIntArray
(
argv
[
12
]);
const
auto
StrideCs
=
argToIntArray
(
argv
[
13
]);
const
auto
StrideCs
=
argToIntArray
(
argv
[
13
]);
const
int
kbatch
=
argc
==
15
?
std
::
stoi
(
argv
[
14
])
:
1
;
const
int
kbatch
=
argc
==
15
?
std
::
stoi
(
argv
[
14
])
:
1
;
#ifdef __fp16__
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
ck
::
profiler
::
profile_grouped_gemm_impl
<
ck
::
half_t
,
ck
::
profiler
::
profile_grouped_gemm_impl
<
ck
::
half_t
,
...
@@ -173,7 +173,7 @@ int profile_grouped_gemm(int argc, char* argv[])
...
@@ -173,7 +173,7 @@ int profile_grouped_gemm(int argc, char* argv[])
{
{
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
}
}
#endif
return
0
;
return
0
;
}
}
...
...
test/CMakeLists.txt
View file @
c7c47fd7
...
@@ -60,6 +60,6 @@ add_subdirectory(contraction)
...
@@ -60,6 +60,6 @@ add_subdirectory(contraction)
add_subdirectory
(
pool_fwd
)
add_subdirectory
(
pool_fwd
)
add_subdirectory
(
batched_gemm_multi_d
)
add_subdirectory
(
batched_gemm_multi_d
)
add_subdirectory
(
grouped_convnd_bwd_data
)
add_subdirectory
(
grouped_convnd_bwd_data
)
if
(
GPU_TARGETS MATCHES
"gfx11
00
"
)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
add_subdirectory
(
wmma_op
)
add_subdirectory
(
wmma_op
)
endif
()
endif
()
test/batched_gemm/CMakeLists.txt
View file @
c7c47fd7
...
@@ -2,21 +2,26 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
...
@@ -2,21 +2,26 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_test_executable
(
test_batched_gemm_fp16 batched_gemm_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
test_batched_gemm_fp16 PRIVATE utility
)
add_test_executable
(
test_batched_gemm_fp16 batched_gemm_fp16.cpp
)
target_link_libraries
(
test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance
)
target_link_libraries
(
test_batched_gemm_fp16 PRIVATE utility
)
target_link_libraries
(
test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance
)
add_test_executable
(
test_batched_gemm_fp32 batched_gemm_fp32.cpp
)
endif
()
target_link_libraries
(
test_batched_gemm_fp32 PRIVATE utility
)
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance
)
add_test_executable
(
test_batched_gemm_fp32 batched_gemm_fp32.cpp
)
target_link_libraries
(
test_batched_gemm_fp32 PRIVATE utility
)
add_test_executable
(
test_batched_gemm_bf16 batched_gemm_bf16.cpp
)
target_link_libraries
(
test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance
)
target_link_libraries
(
test_batched_gemm_bf16 PRIVATE utility
)
endif
()
target_link_libraries
(
test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance
)
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_test_executable
(
test_batched_gemm_bf16 batched_gemm_bf16.cpp
)
add_test_executable
(
test_batched_gemm_int8 batched_gemm_int8.cpp
)
target_link_libraries
(
test_batched_gemm_bf16 PRIVATE utility
)
target_link_libraries
(
test_batched_gemm_int8 PRIVATE utility
)
target_link_libraries
(
test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance
)
target_link_libraries
(
test_batched_gemm_int8 PRIVATE device_batched_gemm_instance
)
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_test_executable
(
test_batched_gemm_int8 batched_gemm_int8.cpp
)
target_link_libraries
(
test_batched_gemm_int8 PRIVATE utility
)
target_link_libraries
(
test_batched_gemm_int8 PRIVATE device_batched_gemm_instance
)
endif
()
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
\ No newline at end of file
test/batched_gemm_gemm/CMakeLists.txt
View file @
c7c47fd7
...
@@ -2,10 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
...
@@ -2,10 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
test_batched_gemm_gemm
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_gtest_executable
(
test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp
)
add_custom_target
(
test_batched_gemm_gemm
)
target_link_libraries
(
test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance
)
add_gtest_executable
(
test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp
)
add_dependencies
(
test_batched_gemm_gemm test_batched_gemm_gemm_fp16
)
target_link_libraries
(
test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance
)
set
(
target 1
)
add_dependencies
(
test_batched_gemm_gemm test_batched_gemm_gemm_fp16
)
set
(
target 1
)
endif
()
endif
()
endif
()
endforeach
()
endforeach
()
\ No newline at end of file
test/batched_gemm_multi_d/CMakeLists.txt
View file @
c7c47fd7
# TODO: Enable for gfx90a after complier fix
if
(
DL_KERNELS
)
if
(
DL_KERNELS
)
add_gtest_executable
(
test_batched_gemm_multi_d test_batched_gemm_multi_d.cpp
)
add_gtest_executable
(
test_batched_gemm_multi_d test_batched_gemm_multi_d.cpp
)
target_link_libraries
(
test_batched_gemm_multi_d PRIVATE utility device_batched_gemm_multi_d_instance
)
target_link_libraries
(
test_batched_gemm_multi_d PRIVATE utility device_batched_gemm_multi_d_instance
)
endif
()
endif
()
test/batched_gemm_reduce/CMakeLists.txt
View file @
c7c47fd7
...
@@ -2,9 +2,11 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
...
@@ -2,9 +2,11 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_test_executable
(
test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
test_batched_gemm_reduce_fp16 PRIVATE utility
)
add_test_executable
(
test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp
)
target_link_libraries
(
test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance
)
target_link_libraries
(
test_batched_gemm_reduce_fp16 PRIVATE utility
)
set
(
target 1
)
target_link_libraries
(
test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance
)
set
(
target 1
)
endif
()
endif
()
endif
()
endforeach
()
endforeach
()
test/batched_gemm_softmax_gemm/CMakeLists.txt
View file @
c7c47fd7
...
@@ -2,10 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
...
@@ -2,10 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
test_batched_gemm_softmax_gemm
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp
)
add_custom_target
(
test_batched_gemm_softmax_gemm
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp
)
add_dependencies
(
test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance
)
set
(
target 1
)
add_dependencies
(
test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16
)
set
(
target 1
)
endif
()
endif
()
endif
()
endforeach
()
endforeach
()
\ No newline at end of file
test/batched_gemm_softmax_gemm_permute/CMakeLists.txt
View file @
c7c47fd7
...
@@ -2,21 +2,25 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
...
@@ -2,21 +2,25 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
test_batched_gemm_softmax_gemm_permute
)
if
(
DTYPES MATCHES
"fp16"
OR DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_custom_target
(
test_batched_gemm_softmax_gemm_permute
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp
)
endif
()
add_gtest_executable
(
test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_gtest_executable
(
test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16
)
target_link_libraries
(
test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16
)
add_gtest_executable
(
test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16
)
add_gtest_executable
(
test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp
)
endif
()
target_link_libraries
(
test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16
)
add_gtest_executable
(
test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
target_link_libraries
(
test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16
)
endif
()
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
\ No newline at end of file
test/elementwise_normalization/CMakeLists.txt
View file @
c7c47fd7
add_custom_target
(
test_elementwise_normalization
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_custom_target
(
test_elementwise_normalization
)
add_gtest_executable
(
test_elementwise_layernorm_fp16 test_elementwise_layernorm_fp16.cpp
)
add_gtest_executable
(
test_elementwise_layernorm_fp16 test_elementwise_layernorm_fp16.cpp
)
target_link_libraries
(
test_elementwise_layernorm_fp16 PRIVATE utility device_elementwise_normalization_instance
)
target_link_libraries
(
test_elementwise_layernorm_fp16 PRIVATE utility device_elementwise_normalization_instance
)
add_dependencies
(
test_elementwise_normalization test_elementwise_layernorm_fp16
)
endif
()
add_dependencies
(
test_elementwise_normalization test_elementwise_layernorm_fp16
)
\ No newline at end of file
test/gemm_layernorm/CMakeLists.txt
View file @
c7c47fd7
...
@@ -2,10 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
...
@@ -2,10 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_custom_target
(
test_gemm_layernorm
)
add_custom_target
(
test_gemm_layernorm
)
add_gtest_executable
(
test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp
)
add_gtest_executable
(
test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp
)
target_link_libraries
(
test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance
)
target_link_libraries
(
test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance
)
add_dependencies
(
test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16
)
add_dependencies
(
test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16
)
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endif
()
endforeach
()
endforeach
()
test/gemm_reduce/CMakeLists.txt
View file @
c7c47fd7
add_test_executable
(
test_gemm_reduce_fp16 gemm_reduce_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
test_gemm_reduce_fp16 PRIVATE utility
)
add_test_executable
(
test_gemm_reduce_fp16 gemm_reduce_fp16.cpp
)
target_link_libraries
(
test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance
)
target_link_libraries
(
test_gemm_reduce_fp16 PRIVATE utility
)
target_link_libraries
(
test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance
)
endif
()
\ No newline at end of file
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
View file @
c7c47fd7
...
@@ -100,6 +100,9 @@ TYPED_TEST(TestGroupedConvndBwdWeight1d, Test1D)
...
@@ -100,6 +100,9 @@ TYPED_TEST(TestGroupedConvndBwdWeight1d, Test1D)
this
->
conv_params
.
push_back
({
1
,
2
,
128
,
128
,
256
,
{
1
},
{
14
},
{
2
},
{
1
},
{
0
},
{
0
}});
this
->
conv_params
.
push_back
({
1
,
2
,
128
,
128
,
256
,
{
1
},
{
14
},
{
2
},
{
1
},
{
0
},
{
0
}});
this
->
conv_params
.
push_back
({
1
,
2
,
32
,
128
,
256
,
{
3
},
{
28
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
conv_params
.
push_back
({
1
,
2
,
32
,
128
,
256
,
{
3
},
{
28
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
conv_params
.
push_back
({
1
,
2
,
128
,
128
,
256
,
{
1
},
{
3
},
{
1
},
{
1
},
{
0
},
{
0
}});
this
->
conv_params
.
push_back
({
1
,
2
,
128
,
128
,
256
,
{
1
},
{
3
},
{
1
},
{
1
},
{
0
},
{
0
}});
this
->
conv_params
.
push_back
({
1
,
1
,
1
,
1
,
32
,
{
3
},
{
32
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
conv_params
.
push_back
({
1
,
1
,
1
,
64
,
3
,
{
3
},
{
32
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
conv_params
.
push_back
({
1
,
1
,
1
,
1
,
1
,
{
3
},
{
32
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
Run
();
this
->
Run
();
}
}
...
@@ -112,6 +115,9 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
...
@@ -112,6 +115,9 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
{
2
,
2
,
4
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
{
2
,
2
,
4
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
(
this
->
conv_params
.
push_back
(
{
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
3
,
3
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
{
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
3
,
3
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
32
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
64
,
3
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
1
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
Run
();
this
->
Run
();
}
}
...
@@ -124,5 +130,11 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
...
@@ -124,5 +130,11 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
{
3
,
2
,
2
,
128
,
256
,
{
3
,
3
,
3
},
{
14
,
14
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
{
3
,
2
,
2
,
128
,
256
,
{
3
,
3
,
3
},
{
14
,
14
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
this
->
conv_params
.
push_back
(
{
3
,
2
,
32
,
128
,
256
,
{
1
,
1
,
1
},
{
3
,
3
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
0
,
0
,
0
},
{
0
,
0
,
0
}});
{
3
,
2
,
32
,
128
,
256
,
{
1
,
1
,
1
},
{
3
,
3
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
0
,
0
,
0
},
{
0
,
0
,
0
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
1
,
32
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
64
,
3
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
1
,
1
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
Run
();
this
->
Run
();
}
}
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp
View file @
c7c47fd7
...
@@ -70,10 +70,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
...
@@ -70,10 +70,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
ck
::
utils
::
conv
::
make_output_host_tensor_descriptor_g_n_k_wos_packed
<
OutLayout
>
(
ck
::
utils
::
conv
::
make_output_host_tensor_descriptor_g_n_k_wos_packed
<
OutLayout
>
(
conv_param
);
conv_param
);
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_
spatial_
lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
input_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_
spatial_
lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
filter_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_
spatial_
lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
output_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
input_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
input_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
weights_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
output_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
output_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
{};
...
@@ -82,10 +83,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
...
@@ -82,10 +83,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
auto
range_copy
=
[](
const
auto
&
from
,
auto
to
)
{
std
::
copy
(
begin
(
from
),
end
(
from
),
to
);
};
auto
range_copy
=
[](
const
auto
&
from
,
auto
to
)
{
std
::
copy
(
begin
(
from
),
end
(
from
),
to
);
};
range_copy
(
conv_param
.
input_spatial_lengths_
,
begin
(
input_spatial_lengths
));
range_copy
(
in_g_n_c_wis_desc
.
GetLengths
(),
begin
(
input_lengths
));
range_copy
(
conv_param
.
filter_spatial_lengths_
,
begin
(
filter_spatial_lengths
));
range_copy
(
conv_param
.
output_spatial_lengths_
,
begin
(
output_spatial_lengths
));
range_copy
(
in_g_n_c_wis_desc
.
GetStrides
(),
begin
(
input_strides
));
range_copy
(
in_g_n_c_wis_desc
.
GetStrides
(),
begin
(
input_strides
));
range_copy
(
wei_g_k_c_xs_desc
.
GetLengths
(),
begin
(
filter_lengths
));
range_copy
(
wei_g_k_c_xs_desc
.
GetStrides
(),
begin
(
weights_strides
));
range_copy
(
out_g_n_k_wos_desc
.
GetLengths
(),
begin
(
output_lengths
));
range_copy
(
out_g_n_k_wos_desc
.
GetStrides
(),
begin
(
output_strides
));
range_copy
(
out_g_n_k_wos_desc
.
GetStrides
(),
begin
(
output_strides
));
range_copy
(
conv_param
.
conv_filter_strides_
,
begin
(
conv_filter_strides
));
range_copy
(
conv_param
.
conv_filter_strides_
,
begin
(
conv_filter_strides
));
range_copy
(
conv_param
.
conv_filter_dilations_
,
begin
(
conv_filter_dilations
));
range_copy
(
conv_param
.
conv_filter_dilations_
,
begin
(
conv_filter_dilations
));
...
@@ -97,14 +99,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
...
@@ -97,14 +99,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
auto
argument
=
conv
.
MakeArgument
(
nullptr
,
auto
argument
=
conv
.
MakeArgument
(
nullptr
,
nullptr
,
nullptr
,
nullptr
,
nullptr
,
conv_param
.
G_
,
input_lengths
,
conv_param
.
N_
,
conv_param
.
K_
,
conv_param
.
C_
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
input_strides
,
input_strides
,
filter_lengths
,
weights_strides
,
output_lengths
,
output_strides
,
output_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
...
...
test/grouped_convnd_fwd/grouped_convnd_fwd.cpp
View file @
c7c47fd7
...
@@ -22,6 +22,8 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv1dFwdGNWC)
...
@@ -22,6 +22,8 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv1dFwdGNWC)
conv_params
.
push_back
({
1
,
2
,
128
,
128
,
256
,
{
1
},
{
14
},
{
2
},
{
1
},
{
0
},
{
0
}});
conv_params
.
push_back
({
1
,
2
,
128
,
128
,
256
,
{
1
},
{
14
},
{
2
},
{
1
},
{
0
},
{
0
}});
conv_params
.
push_back
({
1
,
2
,
128
,
128
,
256
,
{
3
},
{
28
},
{
1
},
{
1
},
{
1
},
{
1
}});
conv_params
.
push_back
({
1
,
2
,
128
,
128
,
256
,
{
3
},
{
28
},
{
1
},
{
1
},
{
1
},
{
1
}});
conv_params
.
push_back
({
1
,
2
,
128
,
128
,
256
,
{
1
},
{
3
},
{
1
},
{
1
},
{
0
},
{
0
}});
conv_params
.
push_back
({
1
,
2
,
128
,
128
,
256
,
{
1
},
{
3
},
{
1
},
{
1
},
{
0
},
{
0
}});
conv_params
.
push_back
({
1
,
1
,
1
,
1
,
32
,
{
3
},
{
32
},
{
1
},
{
1
},
{
1
},
{
1
}});
conv_params
.
push_back
({
1
,
1
,
1
,
64
,
3
,
{
3
},
{
32
},
{
1
},
{
1
},
{
1
},
{
1
}});
for
(
auto
&
param
:
conv_params
)
for
(
auto
&
param
:
conv_params
)
{
{
...
@@ -96,6 +98,9 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv2dFwdGNHWC)
...
@@ -96,6 +98,9 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv2dFwdGNHWC)
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
7
,
7
},
{
2
,
2
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
7
,
7
},
{
2
,
2
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
3
,
3
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
3
,
3
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
32
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
conv_params
.
push_back
({
2
,
1
,
1
,
64
,
3
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
1
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
for
(
auto
&
param
:
conv_params
)
for
(
auto
&
param
:
conv_params
)
{
{
...
@@ -173,6 +178,12 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv3dFwdGNDHWC)
...
@@ -173,6 +178,12 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv3dFwdGNDHWC)
{
3
,
2
,
128
,
128
,
256
,
{
3
,
3
,
3
},
{
14
,
14
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
{
3
,
2
,
128
,
128
,
256
,
{
3
,
3
,
3
},
{
14
,
14
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
conv_params
.
push_back
(
conv_params
.
push_back
(
{
3
,
2
,
128
,
128
,
256
,
{
1
,
1
,
1
},
{
3
,
3
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
0
,
0
,
0
},
{
0
,
0
,
0
}});
{
3
,
2
,
128
,
128
,
256
,
{
1
,
1
,
1
},
{
3
,
3
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
0
,
0
,
0
},
{
0
,
0
,
0
}});
conv_params
.
push_back
(
{
3
,
1
,
1
,
1
,
32
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
64
,
3
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
conv_params
.
push_back
(
{
3
,
1
,
1
,
1
,
1
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
for
(
auto
&
param
:
conv_params
)
for
(
auto
&
param
:
conv_params
)
{
{
...
@@ -247,6 +258,9 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv2dFwdNHWGC)
...
@@ -247,6 +258,9 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv2dFwdNHWGC)
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
7
,
7
},
{
2
,
2
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
7
,
7
},
{
2
,
2
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
3
,
3
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
3
,
3
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
32
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
conv_params
.
push_back
({
2
,
1
,
1
,
64
,
3
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
1
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
for
(
auto
&
param
:
conv_params
)
for
(
auto
&
param
:
conv_params
)
{
{
...
@@ -255,7 +269,7 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv2dFwdNHWGC)
...
@@ -255,7 +269,7 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv2dFwdNHWGC)
// fp16
// fp16
pass
=
ck
::
profiler
::
profile_grouped_conv_fwd_impl
<
2
,
pass
=
ck
::
profiler
::
profile_grouped_conv_fwd_impl
<
2
,
ck
::
tensor_layout
::
convolution
::
NHWGC
,
ck
::
tensor_layout
::
convolution
::
NHWGC
,
ck
::
tensor_layout
::
convolution
::
KYX
G
C
,
ck
::
tensor_layout
::
convolution
::
G
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWGK
,
ck
::
tensor_layout
::
convolution
::
NHWGK
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
...
...
test/grouped_gemm/CMakeLists.txt
View file @
c7c47fd7
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
...
@@ -12,3 +13,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -12,3 +13,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
endif
()
test/normalization/CMakeLists.txt
View file @
c7c47fd7
add_custom_target
(
test_normalization
)
if
(
DTYPES MATCHES
"fp16"
OR DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_custom_target
(
test_normalization
)
add_gtest_executable
(
test_layernorm2d_fp32 test_layernorm2d_fp32.cpp
)
endif
()
add_gtest_executable
(
test_layernorm2d_fp16 test_layernorm2d_fp16.cpp
)
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_gtest_executable
(
test_groupnorm_fp16 test_groupnorm_fp16.cpp
)
add_gtest_executable
(
test_layernorm2d_fp32 test_layernorm2d_fp32.cpp
)
add_gtest_executable
(
test_groupnorm_fp32 test_groupnorm_fp32.cpp
)
add_gtest_executable
(
test_groupnorm_fp32 test_groupnorm_fp32.cpp
)
target_link_libraries
(
test_layernorm2d_fp32 PRIVATE utility device_normalization_instance
)
target_link_libraries
(
test_layernorm2d_fp32 PRIVATE utility device_normalization_instance
)
target_link_libraries
(
test_groupnorm_fp32 PRIVATE utility device_normalization_instance
)
target_link_libraries
(
test_layernorm2d_fp16 PRIVATE utility device_normalization_instance
)
add_dependencies
(
test_normalization test_layernorm2d_fp32
)
target_link_libraries
(
test_groupnorm_fp16 PRIVATE utility device_normalization_instance
)
add_dependencies
(
test_normalization test_groupnorm_fp32
)
target_link_libraries
(
test_groupnorm_fp32 PRIVATE utility device_normalization_instance
)
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_dependencies
(
test_normalization test_layernorm2d_fp32
)
add_gtest_executable
(
test_layernorm2d_fp16 test_layernorm2d_fp16.cpp
)
add_dependencies
(
test_normalization test_layernorm2d_fp16
)
add_gtest_executable
(
test_groupnorm_fp16 test_groupnorm_fp16.cpp
)
add_dependencies
(
test_normalization test_groupnorm_fp16
)
target_link_libraries
(
test_layernorm2d_fp16 PRIVATE utility device_normalization_instance
)
add_dependencies
(
test_normalization test_groupnorm_fp32
)
target_link_libraries
(
test_groupnorm_fp16 PRIVATE utility device_normalization_instance
)
add_dependencies
(
test_normalization test_layernorm2d_fp16
)
add_dependencies
(
test_normalization test_groupnorm_fp16
)
endif
()
test/pool_fwd/test_avg_pool2d_fwd.cpp
View file @
c7c47fd7
...
@@ -41,8 +41,12 @@ class TestAvgPool2dFwd : public ::testing::Test
...
@@ -41,8 +41,12 @@ class TestAvgPool2dFwd : public ::testing::Test
}
}
};
};
#ifdef __fp16__
using
KernelTypes
=
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
F16
,
F16
,
F32
,
I32
>
,
std
::
tuple
<
F32
,
F32
,
F32
,
I32
>>
;
::
testing
::
Types
<
std
::
tuple
<
F16
,
F16
,
F32
,
I32
>
,
std
::
tuple
<
F32
,
F32
,
F32
,
I32
>>
;
#else
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
F32
,
F32
,
F32
,
I32
>>
;
#endif
TYPED_TEST_SUITE
(
TestAvgPool2dFwd
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestAvgPool2dFwd
,
KernelTypes
);
TYPED_TEST
(
TestAvgPool2dFwd
,
Test_Pool
)
TYPED_TEST
(
TestAvgPool2dFwd
,
Test_Pool
)
...
...
test/pool_fwd/test_avg_pool3d_fwd.cpp
View file @
c7c47fd7
...
@@ -40,10 +40,12 @@ class TestAvgPool3dFwd : public ::testing::Test
...
@@ -40,10 +40,12 @@ class TestAvgPool3dFwd : public ::testing::Test
}
}
}
}
};
};
#ifdef __fp16__
using
KernelTypes
=
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
F16
,
F16
,
F32
,
I32
>
,
std
::
tuple
<
F32
,
F32
,
F32
,
I32
>>
;
::
testing
::
Types
<
std
::
tuple
<
F16
,
F16
,
F32
,
I32
>
,
std
::
tuple
<
F32
,
F32
,
F32
,
I32
>>
;
#else
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
F32
,
F32
,
F32
,
I32
>>
;
#endif
TYPED_TEST_SUITE
(
TestAvgPool3dFwd
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestAvgPool3dFwd
,
KernelTypes
);
TYPED_TEST
(
TestAvgPool3dFwd
,
Test_Pool
)
TYPED_TEST
(
TestAvgPool3dFwd
,
Test_Pool
)
{
{
...
...
test/pool_fwd/test_max_pool2d_fwd.cpp
View file @
c7c47fd7
...
@@ -59,10 +59,12 @@ class TestMaxPool2dFwd : public ::testing::Test
...
@@ -59,10 +59,12 @@ class TestMaxPool2dFwd : public ::testing::Test
}
}
}
}
};
};
#ifdef __fp16__
using
KernelTypes
=
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
F16
,
F16
,
F16
,
I32
>
,
std
::
tuple
<
F32
,
F32
,
F32
,
I32
>>
;
::
testing
::
Types
<
std
::
tuple
<
F16
,
F16
,
F16
,
I32
>
,
std
::
tuple
<
F32
,
F32
,
F32
,
I32
>>
;
#else
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
F32
,
F32
,
F32
,
I32
>>
;
#endif
TYPED_TEST_SUITE
(
TestMaxPool2dFwd
,
KernelTypes
);
TYPED_TEST_SUITE
(
TestMaxPool2dFwd
,
KernelTypes
);
TYPED_TEST
(
TestMaxPool2dFwd
,
Test_Pool
)
TYPED_TEST
(
TestMaxPool2dFwd
,
Test_Pool
)
{
{
...
...
Prev
1
…
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment