Commit b71ea424 authored by yuguo's avatar yuguo
Browse files
parents dfd264c3 12fc1b14
...@@ -58,9 +58,34 @@ def setup_pytorch_extension( ...@@ -58,9 +58,34 @@ def setup_pytorch_extension(
"-U__HIP_NO_BFLOAT16_CONVERSIONS__", "-U__HIP_NO_BFLOAT16_CONVERSIONS__",
"-U__HIP_NO_BFLOAT162_OPERATORS__", "-U__HIP_NO_BFLOAT162_OPERATORS__",
"-U__HIP_NO_BFLOAT162_CONVERSIONS__", "-U__HIP_NO_BFLOAT162_CONVERSIONS__",
"-w",
"-DUSE_ROCM", "-DUSE_ROCM",
] ]
if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_UNUSED_WARNING", "1"))):
nvcc_flags.extend(
[
"-Wno-unused-result",
"-Wno-unused-function",
"-Wno-unused-private-field",
"-Wno-unused-variable",
]
)
cxx_flags.extend(
[
"-Wno-unused-result",
"-Wno-unused-function",
"-Wno-unused-private-field",
"-Wno-unused-variable",
]
)
if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING", "0"))):
nvcc_flags.append("-Wno-return-type")
cxx_flags.append("-Wno-return-type")
if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_SIGN_COMPARE", "0"))):
nvcc_flags.append("-Wno-sign-compare")
cxx_flags.append("-Wno-sign-compare")
else: else:
nvcc_flags = [ nvcc_flags = [
"-O3", "-O3",
......
...@@ -64,6 +64,13 @@ def setup_common_extension() -> CMakeExtension: ...@@ -64,6 +64,13 @@ def setup_common_extension() -> CMakeExtension:
"""Setup CMake extension for common library""" """Setup CMake extension for common library"""
if rocm_build(): if rocm_build():
cmake_flags = [] cmake_flags = []
if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_UNUSED_WARNING", "1"))):
cmake_flags.append("-DNVTE_BUILD_SUPPRESS_UNUSED_WARNING=ON")
if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING", "0"))):
cmake_flags.append("-DNVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING=ON")
if bool(int(os.getenv("NVTE_BUILD_SUPPRESS_SIGN_COMPARE", "0"))):
cmake_flags.append("-DNVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING=ON")
else: else:
cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)] cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
......
...@@ -340,6 +340,18 @@ if(USE_CUDA) ...@@ -340,6 +340,18 @@ if(USE_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
else() else()
option(NVTE_BUILD_SUPPRESS_UNUSED_WARNING "Suppress unused* wanings while build" ON)
option(NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING "Suppress return type waning while build" OFF)
option(NVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING "Suppress sign compare waning while build" OFF)
if(NVTE_BUILD_SUPPRESS_UNUSED_WARNING)
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Wno-unused-result -Wno-unused-function -Wno-unused-private-field -Wno-unused-variable")
endif()
if(NVTE_BUILD_SUPPRESS_RETURN_TYPE_WARNING)
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Wno-return-type")
endif()
if(NVTE_BUILD_SUPPRESS_SIGN_COMPARE_WARNING)
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Wno-sign-compare")
endif()
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -O3") set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -O3")
set(HIP_HCC_FLAGS "${CMAKE_HIP_FLAGS} -mavx2 -mf16c -mfma -std=c++17") set(HIP_HCC_FLAGS "${CMAKE_HIP_FLAGS} -mavx2 -mf16c -mfma -std=c++17")
# Ask hcc to generate device code during compilation so we can use # Ask hcc to generate device code during compilation so we can use
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment