# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

include_directories(BEFORE
    ${PROJECT_SOURCE_DIR}/
    ${PROJECT_SOURCE_DIR}/profiler/include
)

include(GTest)

add_custom_target(tests)

# ck_filter_test_sources(<label> <out_var> <source>...)
#
# Stores in <out_var> (in the caller's scope) the given source files, minus:
#   * sources whose name carries a data-type tag (e.g. "_fp16" / "_f16") for a
#     data type that is not listed in DTYPES -- applied only when DTYPES is defined;
#   * sources whose name contains "_dl" when DL_KERNELS is not defined.
# <label> is used in the "removing <label> <source> " message printed for each
# data-type removal (callers pass "test" or "gtest").
function(ck_filter_test_sources label out_var)
    set(kept_sources ${ARGN})

    if(DEFINED DTYPES)
        foreach(source IN LISTS ARGN)
            set(drop FALSE)
            if(source MATCHES "_(fp16|f16)" AND NOT "fp16" IN_LIST DTYPES)
                set(drop TRUE)
            endif()
            if(source MATCHES "_(fp32|f32)" AND NOT "fp32" IN_LIST DTYPES)
                set(drop TRUE)
            endif()
            if(source MATCHES "_(fp64|f64)" AND NOT "fp64" IN_LIST DTYPES)
                set(drop TRUE)
            endif()
            if(source MATCHES "_(fp8|f8)" AND NOT "fp8" IN_LIST DTYPES)
                set(drop TRUE)
            endif()
            if(source MATCHES "_bf8" AND NOT "bf8" IN_LIST DTYPES)
                set(drop TRUE)
            endif()
            if(source MATCHES "_(bf16|b16)" AND NOT "bf16" IN_LIST DTYPES)
                set(drop TRUE)
            endif()
            if(source MATCHES "_(int8|i8)" AND NOT "int8" IN_LIST DTYPES)
                set(drop TRUE)
            endif()
            if(drop)
                message("removing ${label} ${source} ")
                list(REMOVE_ITEM kept_sources "${source}")
            endif()
        endforeach()
    endif()

    # DL-kernel tests are only built when DL_KERNELS is defined.
    if(NOT DEFINED DL_KERNELS)
        foreach(source IN LISTS kept_sources)
            if(source MATCHES "_dl")
                message("removing dl test ${source} ")
                list(REMOVE_ITEM kept_sources "${source}")
            endif()
        endforeach()
    endif()

    set(${out_var} "${kept_sources}" PARENT_SCOPE)
endfunction()

# add_test_executable(<TEST_NAME> <source>...)
#
# Filters the sources with ck_filter_test_sources(). If any remain, it:
#   * builds them into executable <TEST_NAME>, linked against getopt::getopt;
#   * registers the executable with CTest;
#   * makes the `tests` and `check` targets depend on it;
#   * installs it via rocm_install (COMPONENT tests).
# Sets `result` in the caller's scope to 0 when the test was created, or to 1
# when every source was filtered out.
function(add_test_executable TEST_NAME)
    message("adding test ${TEST_NAME}")
    set(result 1)
    ck_filter_test_sources(test test_sources ${ARGN})
    # Only continue if some source files are left after filtering.
    if(test_sources)
        add_executable(${TEST_NAME} ${test_sources})
        target_link_libraries(${TEST_NAME} PRIVATE getopt::getopt)
        # With the NAME signature, a target name in COMMAND expands to the built executable.
        add_test(NAME ${TEST_NAME} COMMAND ${TEST_NAME})
        add_dependencies(tests ${TEST_NAME})
        add_dependencies(check ${TEST_NAME})
        rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
        set(result 0)
    endif()
    set(result ${result} PARENT_SCOPE)
endfunction()

# add_gtest_executable(<TEST_NAME> <source>...)
#
# Same as add_test_executable(), except that it also:
#   * links gtest_main;
#   * adds -Wno-global-constructors and -Wno-undef to silence warnings
#     triggered by gtest code.
# Sets `result` in the caller's scope to 0 when the test was created, or to 1
# when every source was filtered out.
function(add_gtest_executable TEST_NAME)
    message("adding gtest ${TEST_NAME}")
    set(result 1)
    ck_filter_test_sources(gtest test_sources ${ARGN})
    # Only continue if some source files are left after filtering.
    if(test_sources)
        add_executable(${TEST_NAME} ${test_sources})
        add_dependencies(tests ${TEST_NAME})
        add_dependencies(check ${TEST_NAME})
        # suppress gtest warnings
        target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef)
        target_link_libraries(${TEST_NAME} PRIVATE gtest_main getopt::getopt)
        # With the NAME signature, a target name in COMMAND expands to the built executable.
        add_test(NAME ${TEST_NAME} COMMAND ${TEST_NAME})
        rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
        set(result 0)
    endif()
    set(result ${result} PARENT_SCOPE)
endfunction()

add_subdirectory(magic_number_division)
add_subdirectory(space_filling_curve)
add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm)
add_subdirectory(gemm_layernorm)
add_subdirectory(gemm_split_k)
add_subdirectory(gemm_reduce)
add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_softmax_gemm)
# Remaining test subdirectories, added in a fixed order.
add_subdirectory(batched_gemm_softmax_gemm_permute)
add_subdirectory(grouped_gemm)
add_subdirectory(reduce)
add_subdirectory(convnd_fwd)
add_subdirectory(convnd_bwd_data)
add_subdirectory(grouped_convnd_fwd)
add_subdirectory(grouped_convnd_bwd_weight)
add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax)
add_subdirectory(normalization)
add_subdirectory(data_type)
add_subdirectory(elementwise_normalization)
add_subdirectory(batchnorm)
add_subdirectory(contraction)
add_subdirectory(pool)
add_subdirectory(batched_gemm_multi_d)
add_subdirectory(grouped_convnd_bwd_data)
add_subdirectory(conv_tensor_rearrange)

# wmma_op is only added when GPU_TARGETS contains "gfx11".
if("${GPU_TARGETS}" MATCHES "gfx11")
    add_subdirectory(wmma_op)
endif()