CMakeLists.txt 3.63 KB
Newer Older
zhuwenwen's avatar
zhuwenwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright (c) 2021-2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# GoogleTest Preparation - Code block copied from
#   https://google.github.io/googletest/quickstart-cmake.html
# Declare where the GoogleTest sources come from. FetchContent_Declare() only
# records the download details; the content is populated later by the
# FetchContent_MakeAvailable(googletest) call in this file.
include(FetchContent)
FetchContent_Declare(
  googletest
  GIT_REPOSITORY https://github.com/google/googletest.git
  # Release tag to check out.
  GIT_TAG release-1.12.1
  # NOTE(review): the commented lines below look like an offline alternative
  # (local archive via URL/URL_HASH instead of the git checkout, plus shell
  # exports for the header search paths) - confirm before relying on them.
  # The two export lines point at different sub-directories
  # (googletest/include vs. include); verify which one is intended.
  # URL /path/to/local/googletest-release-1.12.1.zip
  # URL_HASH SHA256=24564e3b712d3eb30ac9a85d92f7d720f60cc0173730ac166f27dda7fed76cb2
  # export C_INCLUDE_PATH=/path/to/local/googletest-release-1.12.1/googletest/include${C_INCLUDE_PATH:+:${C_INCLUDE_PATH}}
  # export CPLUS_INCLUDE_PATH=/path/to/local/googletest-release-1.12.1/include${CPLUS_INCLUDE_PATH:+:${CPLUS_INCLUDE_PATH}}
)
# For Windows: Prevent overriding the parent project's compiler/linker settings
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(googletest)

# One GoogleTest binary built from all of the kernel, layer and tensor test
# sources listed here.
add_executable(unittest
    test_attention_kernels.cu
    test_logprob_kernels.cu
    test_penalty_kernels.cu
    test_sampling_kernels.cu
    test_sampling_layer.cu
    test_tensor.cu)

# Define TORCH_CUDA=1 on the unittest target itself. A directory-scoped
# add_definitions() applies to every target of this directory no matter where
# it is called (and to sub-directories added afterwards, such as googletest),
# so it cannot be limited to a single target by a later remove_definitions().
target_compile_definitions(unittest PRIVATE TORCH_CUDA=1)

# Link the Torch libraries and gtest_main, which supplies main() for the
# test binary.
target_link_libraries(unittest PUBLIC "${TORCH_LIBRARIES}" gtest_main)
target_compile_features(unittest PRIVATE cxx_std_14)

# Link dependencies of the unittest binary, one target_link_libraries() call
# per test source, listed in alphabetical order of test name.

# Libs for test_attention_kernels
target_link_libraries(unittest
  PUBLIC
    -lcudart
    -lcurand
    gen_relative_pos_bias
    gpt_kernels
    gtest
    memory_utils
    tensor
    unfused_attention_kernels
    cuda_utils
    logger
)

# Libs for test_logprob_kernels
target_link_libraries(unittest
  PUBLIC
    -lcudart
    logprob_kernels
    memory_utils
    cuda_utils
    logger
)

# Libs for test_penalty_kernels
target_link_libraries(unittest
  PUBLIC
    -lcublas
    -lcublasLt
    -lcudart
    sampling_penalty_kernels
    beam_search_penalty_kernels
    memory_utils
    cuda_utils
    logger
)

# Libs for test_sampling_kernels
target_link_libraries(unittest
  PUBLIC
    -lcudart
    sampling_topk_kernels
    sampling_topp_kernels
    memory_utils
    tensor
    cuda_utils
    logger
)

# Libs for test_sampling_layer
target_link_libraries(unittest
  PUBLIC
    -lcublas
    -lcublasLt
    -lcudart
    cublasMMWrapper
    memory_utils
    DynamicDecodeLayer
    TopKSamplingLayer
    TopPSamplingLayer
    tensor
    cuda_utils
    logger
)

# Libs for test_tensor
target_link_libraries(unittest
  PUBLIC
    tensor
    cuda_utils
    logger
)

# Make sure no directory-scoped TORCH_CUDA=1 definition remains in effect
# for this directory.
remove_definitions(-DTORCH_CUDA=1)

# Separate executable built from test_gemm.cu.
add_executable(test_gemm test_gemm.cu)
target_link_libraries(test_gemm
  PUBLIC
    -lcublas
    -lcudart
    -lcurand
    gemm
    cublasMMWrapper
    tensor
    cuda_utils
    logger
)

# Separate executable built from test_gpt_kernels.cu.
add_executable(test_gpt_kernels test_gpt_kernels.cu)
target_link_libraries(test_gpt_kernels
  PUBLIC
    gpt_kernels
    memory_utils
    tensor
    cuda_utils
    logger
)

# Separate executable built from test_activation.cu.
add_executable(test_activation test_activation.cu)
target_link_libraries(test_activation
  PUBLIC
    -lcublas
    -lcublasLt
    -lcudart
    activation_kernels
    memory_utils
    cuda_utils
    logger
)

# Separate executable built from test_context_decoder_layer.cu.
add_executable(test_context_decoder_layer test_context_decoder_layer.cu)
target_link_libraries(test_context_decoder_layer
  PUBLIC
    ParallelGpt
    -lcublas
    -lcublasLt
    -lcudart
    memory_utils
    tensor
    cuda_utils
    logger
)