Unverified Commit 9690e8e1 authored by moto's avatar moto Committed by GitHub
Browse files

Clean up transducer build (#1159)

parent 84966ae1
...@@ -61,29 +61,31 @@ def _get_ela(debug): ...@@ -61,29 +61,31 @@ def _get_ela(debug):
def _get_srcs(): def _get_srcs():
return [str(p) for p in _CSRC_DIR.glob('**/*.cpp')] srcs = [_CSRC_DIR / 'pybind.cpp']
srcs += list(_CSRC_DIR.glob('sox/**/*.cpp'))
if _BUILD_TRANSDUCER:
srcs += [_CSRC_DIR / 'transducer.cpp']
return [str(path) for path in srcs]
def _get_include_dirs(): def _get_include_dirs():
dirs = [ dirs = [
str(_ROOT_DIR), str(_ROOT_DIR),
] ]
if _BUILD_SOX: if _BUILD_SOX or _BUILD_TRANSDUCER:
dirs.append(str(_TP_INSTALL_DIR / 'include')) dirs.append(str(_TP_INSTALL_DIR / 'include'))
if _BUILD_TRANSDUCER:
dirs.append(str(_TP_BASE_DIR / 'transducer' / 'submodule' / 'include'))
return dirs return dirs
def _get_extra_objects(): def _get_extra_objects():
objs = [] libs = []
if _BUILD_SOX: if _BUILD_SOX:
# NOTE: The order of the library listed bellow matters. # NOTE: The order of the library listed bellow matters.
# #
# (the most important thing is that dependencies come after a library # (the most important thing is that dependencies come after a library
# e.g., sox comes first, flac/vorbis comes before ogg, and # e.g., sox comes first, flac/vorbis comes before ogg, and
# vorbisenc/vorbisfile comes before vorbis # vorbisenc/vorbisfile comes before vorbis
libs = [ libs += [
'libsox.a', 'libsox.a',
'libmad.a', 'libmad.a',
'libFLAC.a', 'libFLAC.a',
...@@ -97,27 +99,34 @@ def _get_extra_objects(): ...@@ -97,27 +99,34 @@ def _get_extra_objects():
'libopencore-amrnb.a', 'libopencore-amrnb.a',
'libopencore-amrwb.a', 'libopencore-amrwb.a',
] ]
for lib in libs:
objs.append(str(_TP_INSTALL_DIR / 'lib' / lib))
if _BUILD_TRANSDUCER: if _BUILD_TRANSDUCER:
objs.append(str(_TP_BASE_DIR / 'build' / 'transducer' / 'libwarprnnt.a')) libs += ['libwarprnnt.a']
return objs
return [str(_TP_INSTALL_DIR / 'lib' / lib) for lib in libs]
def _get_libraries(): def _get_libraries():
return [] if _BUILD_SOX else ['sox'] return [] if _BUILD_SOX else ['sox']
def _build_third_party(): def _build_third_party(base_build_dir):
build_dir = str(_TP_BASE_DIR / 'build') build_dir = os.path.join(base_build_dir, 'third_party')
os.makedirs(build_dir, exist_ok=True) os.makedirs(build_dir, exist_ok=True)
subprocess.run( subprocess.run(
args=['cmake', '..'], args=[
'cmake',
f'-DCMAKE_INSTALL_PREFIX={_TP_INSTALL_DIR}',
f'-DBUILD_SOX={"ON" if _BUILD_SOX else "OFF"}',
f'-DBUILD_TRANSDUCER={"ON" if _BUILD_TRANSDUCER else "OFF"}',
f'{_TP_BASE_DIR}'],
cwd=build_dir, cwd=build_dir,
check=True, check=True,
) )
command = ['cmake', '--build', '.']
if _BUILD_TRANSDUCER:
command += ['--target', 'install']
subprocess.run( subprocess.run(
args=['cmake', '--build', '.'], args=command,
cwd=build_dir, cwd=build_dir,
check=True, check=True,
) )
...@@ -145,5 +154,5 @@ def get_ext_modules(debug=False): ...@@ -145,5 +154,5 @@ def get_ext_modules(debug=False):
class BuildExtension(TorchBuildExtension): class BuildExtension(TorchBuildExtension):
def build_extension(self, ext): def build_extension(self, ext):
if ext.name == _EXT_NAME and _BUILD_SOX: if ext.name == _EXT_NAME and _BUILD_SOX:
_build_third_party() _build_third_party(self.build_temp)
super().build_extension(ext) super().build_extension(ext)
...@@ -50,7 +50,6 @@ class clean(distutils.command.clean.clean): ...@@ -50,7 +50,6 @@ class clean(distutils.command.clean.clean):
# Remove build directory # Remove build directory
build_dirs = [ build_dirs = [
ROOT_DIR / 'build', ROOT_DIR / 'build',
ROOT_DIR / 'third_party' / 'build',
] ]
for path in build_dirs: for path in build_dirs:
if path.exists(): if path.exists():
......
cmake_minimum_required(VERSION 3.1) cmake_minimum_required(VERSION 3.5)
project(torchaudio_third_parties) project(torchaudio_third_parties)
include(ExternalProject)
set(INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/install) option(BUILD_SOX "Build libsox statically")
option(BUILD_TRANSDUCER "Build transducer statically")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
if (BUILD_SOX)
include(ExternalProject)
set(INSTALL_DIR ${CMAKE_INSTALL_PREFIX})
set(ARCHIVE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/archives) set(ARCHIVE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/archives)
set(COMMON_ARGS --quiet --disable-shared --enable-static --prefix=${INSTALL_DIR} --with-pic --disable-dependency-tracking --disable-debug --disable-examples --disable-doc) set(COMMON_ARGS --quiet --disable-shared --enable-static --prefix=${INSTALL_DIR} --with-pic --disable-dependency-tracking --disable-debug --disable-examples --disable-doc)
...@@ -88,5 +94,8 @@ ExternalProject_Add(libsox ...@@ -88,5 +94,8 @@ ExternalProject_Add(libsox
# See https://github.com/pytorch/audio/pull/1026 # See https://github.com/pytorch/audio/pull/1026
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp --without-sndio --without-pulseaudio CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp --without-sndio --without-pulseaudio
) )
endif()
if(BUILD_TRANSDUCER)
add_subdirectory(transducer) add_subdirectory(transducer)
endif()
CMAKE_MINIMUM_REQUIRED(VERSION 3.5)
PROJECT(rnnt_release)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2")
IF(APPLE)
ADD_DEFINITIONS(-DAPPLE)
ENDIF()
INCLUDE_DIRECTORIES(submodule/include)
SET(CMAKE_POSITION_INDEPENDENT_CODE ON)
ADD_DEFINITIONS(-DRNNT_DISABLE_OMP) ADD_DEFINITIONS(-DRNNT_DISABLE_OMP)
IF(APPLE) IF(APPLE)
ADD_DEFINITIONS(-DAPPLE)
EXEC_PROGRAM(uname ARGS -v OUTPUT_VARIABLE DARWIN_VERSION) EXEC_PROGRAM(uname ARGS -v OUTPUT_VARIABLE DARWIN_VERSION)
STRING(REGEX MATCH "[0-9]+" DARWIN_VERSION ${DARWIN_VERSION}) STRING(REGEX MATCH "[0-9]+" DARWIN_VERSION ${DARWIN_VERSION})
MESSAGE(STATUS "DARWIN_VERSION=${DARWIN_VERSION}") MESSAGE(STATUS "DARWIN_VERSION=${DARWIN_VERSION}")
...@@ -30,9 +17,11 @@ ELSE() ...@@ -30,9 +17,11 @@ ELSE()
ENDIF() ENDIF()
ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cpp) ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cpp)
target_include_directories(warprnnt PUBLIC submodule/include)
set_target_properties(warprnnt PROPERTIES PUBLIC_HEADER submodule/include/rnnt.h)
INSTALL(TARGETS warprnnt
LIBRARY DESTINATION "lib"
ARCHIVE DESTINATION "lib")
INSTALL(FILES submodule/include/rnnt.h DESTINATION "submodule/include") INSTALL(
TARGETS warprnnt
ARCHIVE DESTINATION "lib"
PUBLIC_HEADER DESTINATION "include")
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include <torchaudio/csrc/sox/io.h> #include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/sox/utils.h> #include <torchaudio/csrc/sox/utils.h>
TORCH_LIBRARY(torchaudio, m) { TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// sox_utils.h // sox_utils.h
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
...@@ -74,18 +74,4 @@ TORCH_LIBRARY(torchaudio, m) { ...@@ -74,18 +74,4 @@ TORCH_LIBRARY(torchaudio, m) {
m.def( m.def(
"torchaudio::sox_effects_apply_effects_file", "torchaudio::sox_effects_apply_effects_file",
&torchaudio::sox_effects::apply_effects_file); &torchaudio::sox_effects::apply_effects_file);
//////////////////////////////////////////////////////////////////////////////
// transducer.cpp
//////////////////////////////////////////////////////////////////////////////
#ifdef BUILD_TRANSDUCER
m.def("rnnt_loss(Tensor acts,"
"Tensor labels,"
"Tensor input_lengths,"
"Tensor label_lengths,"
"Tensor costs,"
"Tensor grads,"
"int blank_label,"
"int num_threads) -> int");
#endif
} }
#ifdef BUILD_TRANSDUCER
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <string> #include <string>
...@@ -8,6 +6,8 @@ ...@@ -8,6 +6,8 @@
#include <torch/script.h> #include <torch/script.h>
#include "rnnt.h" #include "rnnt.h"
namespace {
int64_t cpu_rnnt_loss(torch::Tensor acts, int64_t cpu_rnnt_loss(torch::Tensor acts,
torch::Tensor labels, torch::Tensor labels,
torch::Tensor input_lengths, torch::Tensor input_lengths,
...@@ -75,8 +75,19 @@ int64_t cpu_rnnt_loss(torch::Tensor acts, ...@@ -75,8 +75,19 @@ int64_t cpu_rnnt_loss(torch::Tensor acts,
return -1; return -1;
} }
} // namespace
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("rnnt_loss", &cpu_rnnt_loss); m.impl("rnnt_loss", &cpu_rnnt_loss);
} }
#endif TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("rnnt_loss(Tensor acts,"
"Tensor labels,"
"Tensor input_lengths,"
"Tensor label_lengths,"
"Tensor costs,"
"Tensor grads,"
"int blank_label,"
"int num_threads) -> int");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment