# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# Minimum CMake release accepted for this build; this call also selects the
# policy defaults that the rest of the file is evaluated under.
cmake_minimum_required(VERSION 3.18)

# Default GPU architectures, applied only when the caller has not already
# defined CMAKE_CUDA_ARCHITECTURES (for example on the command line with
# -DCMAKE_CUDA_ARCHITECTURES=...). Placed ahead of project(), which is where
# the CUDA language gets enabled.
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
  set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
endif()

# Compile both host (C++) and device (CUDA) code as C++17. The *_REQUIRED
# switches turn an unsupported standard into a configure-time error instead of
# letting CMake silently fall back to an older one.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

# Declare the project and enable the two languages it is built with.
project(
  transformer_engine
  LANGUAGES
    CUDA
    CXX
)

# Extra nvcc options for every configuration. CMAKE_CUDA_FLAGS is a single
# space-separated string, not a list: appending with list(APPEND) would insert
# a ';' whenever the variable already holds flags (e.g. from -DCMAKE_CUDA_FLAGS
# or the CUDAFLAGS environment variable) and corrupt the compiler command line.
string(APPEND CMAKE_CUDA_FLAGS " --threads 4")

# Debug builds additionally pass -G to nvcc. CMAKE_CUDA_FLAGS_DEBUG is only
# consumed by the Debug configuration, so no CMAKE_BUILD_TYPE check is needed;
# dropping it also keeps the flag working with multi-config generators, where
# CMAKE_BUILD_TYPE is empty at configure time.
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -G")

# Let find_package() search the cmake/ folder next to this file for Find
# modules (presumably where the CUDNN module used below lives - confirm).
# The path is resolved relative to this file rather than the top of the whole
# build, so it stays valid when this project is pulled in via add_subdirectory.
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")

# External dependencies. Every lookup is REQUIRED, so configuration stops as
# soon as one of them cannot be found.
find_package(CUDAToolkit REQUIRED cublas nvToolsExt)
find_package(CUDNN REQUIRED cudnn)
find_package(Python COMPONENTS Interpreter Development REQUIRED)

# Put the project root on the include path of every target defined from here
# on, then process the common/ subdirectory. The expansion is quoted so a
# source path containing ';' is still passed as a single directory.
include_directories("${PROJECT_SOURCE_DIR}")

add_subdirectory(common)
# Optional userbuffers sources; skipped unless the caller turns the switch on
# (e.g. -DNVTE_WITH_USERBUFFERS=ON). Declaring it with option() makes the
# switch visible in the cache; a value the caller already set is left untouched.
option(NVTE_WITH_USERBUFFERS "Build the sources in pytorch/csrc/userbuffers" OFF)
if(NVTE_WITH_USERBUFFERS)
  message(STATUS "userbuffers support enabled")
  add_subdirectory(pytorch/csrc/userbuffers)
endif()

# The jax/ subdirectory is opt-in: it is only processed when ENABLE_JAX is ON.
# The chosen value is echoed so it shows up in the configure log.
option(ENABLE_JAX "Enable JAX in the building workflow." OFF)
message(STATUS "JAX support: ${ENABLE_JAX}")

if(ENABLE_JAX)
  # pybind11 is looked up through its CMake package config only (CONFIG mode);
  # configuration fails if it is not installed.
  find_package(pybind11 CONFIG REQUIRED)
  add_subdirectory(jax)
endif()