CMakeLists.txt 6.76 KB
Newer Older
1
# Directory-wide include path. It applies to every target created afterwards in
# this directory and in the subdirectories added further down. BEFORE puts the
# project root and the profiler headers ahead of any previously configured
# include directories.
include_directories(BEFORE
    ${PROJECT_SOURCE_DIR}/
    ${PROJECT_SOURCE_DIR}/profiler/include
)

# Project-local CMake module. Presumably it makes the GoogleTest targets
# (e.g. gtest_main, linked by add_gtest_executable) available; TODO confirm.
include(gtest)

# Umbrella target. Each test executable created by the helper functions in this
# file is attached to it via add_dependencies(), so building "tests" builds
# every configured test.
add_custom_target(tests)
# filter_test_sources(<OUT_VAR> <LABEL> [<source>...])
#
# Shared source filter for add_test_executable / add_gtest_executable.
# Drops every <source> that cannot be built with the current configuration
# and stores the surviving sources (a ;-list, possibly empty) in <OUT_VAR>
# in the caller's scope. The passes run in this order over the whole list:
#   1. If DTYPES is defined: drop a file whose name carries a data-type tag
#      (regex on the file name) for a data type that is not in DTYPES.
#   2. If DL_KERNELS is not defined: drop files whose name contains "_dl".
#   3. If GPU_TARGETS does not match "gfx9": drop files containing "xdl".
#   4. If GPU_TARGETS does not match "gfx11": drop files containing "wmma".
# Every removal is reported with message(). <LABEL> (e.g. "test" or "gtest")
# is used only in the data-type removal message.
function(filter_test_sources OUT_VAR LABEL)
    set(remaining ${ARGN})

    if(DEFINED DTYPES)
        foreach(source IN LISTS remaining)
            # A file is unsupported if ANY of its tags names a data type that
            # is missing from DTYPES. "|" joins the long and short spellings
            # of the same data type (regex alternation = the old OR pairs).
            set(unsupported FALSE)
            if(source MATCHES "_fp16|_f16" AND NOT "fp16" IN_LIST DTYPES)
                set(unsupported TRUE)
            elseif(source MATCHES "_fp32|_f32" AND NOT "fp32" IN_LIST DTYPES)
                set(unsupported TRUE)
            elseif(source MATCHES "_fp64|_f64" AND NOT "fp64" IN_LIST DTYPES)
                set(unsupported TRUE)
            elseif(source MATCHES "_fp8|_f8" AND NOT "fp8" IN_LIST DTYPES)
                set(unsupported TRUE)
            # NOTE(review): the original tested "_bf8" twice ("_bf8" OR "_bf8");
            # the second spelling was probably meant to be a short form —
            # confirm and add it to this regex if so.
            elseif(source MATCHES "_bf8" AND NOT "bf8" IN_LIST DTYPES)
                set(unsupported TRUE)
            elseif(source MATCHES "_bf16|_b16" AND NOT "bf16" IN_LIST DTYPES)
                set(unsupported TRUE)
            elseif(source MATCHES "_int8|_i8" AND NOT "int8" IN_LIST DTYPES)
                set(unsupported TRUE)
            endif()
            if(unsupported)
                message("removing ${LABEL} ${source} ")
                list(REMOVE_ITEM remaining "${source}")
            endif()
        endforeach()
    endif()

    # The configuration checks below do not depend on the source, so they are
    # evaluated once per pass instead of once per file.
    if(NOT DEFINED DL_KERNELS)
        foreach(source IN LISTS remaining)
            if(source MATCHES "_dl")
                message("removing dl test ${source} ")
                list(REMOVE_ITEM remaining "${source}")
            endif()
        endforeach()
    endif()

    if(NOT GPU_TARGETS MATCHES "gfx9")
        foreach(source IN LISTS remaining)
            if(source MATCHES "xdl")
                message("removing xdl test ${source} ")
                list(REMOVE_ITEM remaining "${source}")
            endif()
        endforeach()
    endif()

    if(NOT GPU_TARGETS MATCHES "gfx11")
        foreach(source IN LISTS remaining)
            if(source MATCHES "wmma")
                message("removing wmma test ${source} ")
                list(REMOVE_ITEM remaining "${source}")
            endif()
        endforeach()
    endif()

    set(${OUT_VAR} "${remaining}" PARENT_SCOPE)
endfunction()

# add_test_executable(<TEST_NAME> [<source>...])
#
# Creates a plain (non-gtest) test executable from the sources that survive
# filter_test_sources(). The executable:
#   * links getopt::getopt,
#   * is registered with CTest (add_test),
#   * becomes a dependency of the "tests" and "check" targets,
#   * is installed into the "tests" component via rocm_install.
# The "check" target and rocm_install are assumed to be provided by the
# parent project (they are not defined in this file).
# Sets `result` in the caller's scope: 0 if the executable was created, or 1
# if every source was filtered out (nothing is created in that case).
function(add_test_executable TEST_NAME)
    message("adding test ${TEST_NAME}")
    set(result 1)
    filter_test_sources(sources "test" ${ARGN})

    # Only continue if some source files are left on the list.
    if(sources)
        add_executable(${TEST_NAME} ${sources})
        target_link_libraries(${TEST_NAME} PRIVATE getopt::getopt)
        add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
        add_dependencies(tests ${TEST_NAME})
        add_dependencies(check ${TEST_NAME})
        rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
        set(result 0)
    endif()
    set(result ${result} PARENT_SCOPE)
endfunction()

# add_gtest_executable(<TEST_NAME> [<source>...])
#
# Same as add_test_executable, but for GoogleTest-based tests. It also links
# gtest_main and disables the -Wglobal-constructors / -Wundef warnings that
# the gtest macros trigger.
# Sets `result` in the caller's scope: 0 if the executable was created, or 1
# if every source was filtered out (nothing is created in that case).
function(add_gtest_executable TEST_NAME)
    message("adding gtest ${TEST_NAME}")
    set(result 1)
    filter_test_sources(sources "gtest" ${ARGN})

    # Only continue if some source files are left on the list.
    if(sources)
        add_executable(${TEST_NAME} ${sources})
        add_dependencies(tests ${TEST_NAME})
        add_dependencies(check ${TEST_NAME})

        # suppress gtest warnings
        target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef)
        target_link_libraries(${TEST_NAME} PRIVATE gtest_main getopt::getopt)
        add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
        rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
        set(result 0)
    endif()
    set(result ${result} PARENT_SCOPE)
endfunction()
# -Wno-c++20-extensions (a Clang warning flag) is added at directory scope on
# purpose. It must reach every target created after this point, including the
# test executables defined inside the subdirectories below.
add_compile_options(-Wno-c++20-extensions)

# Test suites, one per subdirectory. The order is kept exactly as before.
add_subdirectory(magic_number_division)
add_subdirectory(space_filling_curve)
add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm)
add_subdirectory(gemm_add)
add_subdirectory(gemm_layernorm)
add_subdirectory(gemm_split_k)
add_subdirectory(gemm_universal)
add_subdirectory(gemm_reduce)
add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_softmax_gemm)
add_subdirectory(batched_gemm_softmax_gemm_permute)
add_subdirectory(grouped_gemm)
add_subdirectory(reduce)
add_subdirectory(convnd_fwd)
add_subdirectory(convnd_bwd_data)
add_subdirectory(grouped_convnd_fwd)
add_subdirectory(grouped_convnd_bwd_weight)
add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax)
add_subdirectory(normalization_fwd)
add_subdirectory(normalization_bwd_data)
add_subdirectory(normalization_bwd_gamma_beta)
add_subdirectory(data_type)
add_subdirectory(elementwise_normalization)
add_subdirectory(batchnorm)
add_subdirectory(contraction)
add_subdirectory(pool)
add_subdirectory(batched_gemm_multi_d)
add_subdirectory(grouped_convnd_bwd_data)
add_subdirectory(conv_tensor_rearrange)
add_subdirectory(transpose)
add_subdirectory(permute_scale)
add_subdirectory(wrapper)

# The wmma_op suite is only configured when GPU_TARGETS matches the regex
# "gfx11".
if(GPU_TARGETS MATCHES "gfx11")
    add_subdirectory(wmma_op)
endif()

add_subdirectory(position_embedding)