"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "1b4455a55a81175ef0eff0bf6a6db12aa235cc4f"
Unverified Commit e3734fef authored by moto's avatar moto Committed by GitHub
Browse files

Add OpenMP support (#1761)

parent c9e4c75d
...@@ -62,6 +62,7 @@ option(BUILD_RNNT "Enable RNN transducer" ON) ...@@ -62,6 +62,7 @@ option(BUILD_RNNT "Enable RNN transducer" ON)
option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF) option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF)
option(USE_CUDA "Enable CUDA support" OFF) option(USE_CUDA "Enable CUDA support" OFF)
option(USE_ROCM "Enable ROCM support" OFF) option(USE_ROCM "Enable ROCM support" OFF)
option(USE_OPENMP "Enable OpenMP support" OFF)
# check that USE_CUDA and USE_ROCM are not set at the same time # check that USE_CUDA and USE_ROCM are not set at the same time
...@@ -122,6 +123,10 @@ if(MSVC) ...@@ -122,6 +123,10 @@ if(MSVC)
endif() endif()
endif() endif()
if(USE_OPENMP)
find_package(OpenMP REQUIRED)
endif()
# TORCH_CXX_FLAGS contains the same -D_GLIBCXX_USE_CXX11_ABI value as PyTorch # TORCH_CXX_FLAGS contains the same -D_GLIBCXX_USE_CXX11_ABI value as PyTorch
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall ${TORCH_CXX_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall ${TORCH_CXX_FLAGS}")
......
...@@ -39,6 +39,8 @@ _BUILD_KALDI = False if platform.system() == 'Windows' else _get_build("BUILD_KA ...@@ -39,6 +39,8 @@ _BUILD_KALDI = False if platform.system() == 'Windows' else _get_build("BUILD_KA
_BUILD_RNNT = _get_build("BUILD_RNNT", True) _BUILD_RNNT = _get_build("BUILD_RNNT", True)
_USE_ROCM = _get_build("USE_ROCM", torch.cuda.is_available() and torch.version.hip is not None) _USE_ROCM = _get_build("USE_ROCM", torch.cuda.is_available() and torch.version.hip is not None)
_USE_CUDA = _get_build("USE_CUDA", torch.cuda.is_available() and torch.version.hip is None) _USE_CUDA = _get_build("USE_CUDA", torch.cuda.is_available() and torch.version.hip is None)
_USE_OPENMP = _get_build("USE_OPENMP", True) and \
'ATen parallel backend: OpenMP' in torch.__config__.parallel_info()
_TORCH_CUDA_ARCH_LIST = os.environ.get('TORCH_CUDA_ARCH_LIST', None) _TORCH_CUDA_ARCH_LIST = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
...@@ -90,6 +92,7 @@ class CMakeBuild(build_ext): ...@@ -90,6 +92,7 @@ class CMakeBuild(build_ext):
"-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON", "-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON",
f"-DUSE_ROCM:BOOL={'ON' if _USE_ROCM else 'OFF'}", f"-DUSE_ROCM:BOOL={'ON' if _USE_ROCM else 'OFF'}",
f"-DUSE_CUDA:BOOL={'ON' if _USE_CUDA else 'OFF'}", f"-DUSE_CUDA:BOOL={'ON' if _USE_CUDA else 'OFF'}",
f"-DUSE_OPENMP:BOOL={'ON' if _USE_OPENMP else 'OFF'}",
] ]
build_args = [ build_args = [
'--target', 'install' '--target', 'install'
......
...@@ -94,6 +94,10 @@ if (MSVC) ...@@ -94,6 +94,10 @@ if (MSVC)
set_target_properties(libtorchaudio PROPERTIES SUFFIX ".pyd") set_target_properties(libtorchaudio PROPERTIES SUFFIX ".pyd")
endif(MSVC) endif(MSVC)
if(OpenMP_CXX_FOUND)
target_link_libraries(libtorchaudio OpenMP::OpenMP_CXX)
endif()
install( install(
TARGETS libtorchaudio TARGETS libtorchaudio
LIBRARY DESTINATION lib LIBRARY DESTINATION lib
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment