Unverified Commit 9690e8e1 authored by moto's avatar moto Committed by GitHub
Browse files

Clean up transducer build (#1159)

parent 84966ae1
......@@ -61,29 +61,31 @@ def _get_ela(debug):
def _get_srcs():
return [str(p) for p in _CSRC_DIR.glob('**/*.cpp')]
srcs = [_CSRC_DIR / 'pybind.cpp']
srcs += list(_CSRC_DIR.glob('sox/**/*.cpp'))
if _BUILD_TRANSDUCER:
srcs += [_CSRC_DIR / 'transducer.cpp']
return [str(path) for path in srcs]
def _get_include_dirs():
dirs = [
str(_ROOT_DIR),
]
if _BUILD_SOX:
if _BUILD_SOX or _BUILD_TRANSDUCER:
dirs.append(str(_TP_INSTALL_DIR / 'include'))
if _BUILD_TRANSDUCER:
dirs.append(str(_TP_BASE_DIR / 'transducer' / 'submodule' / 'include'))
return dirs
def _get_extra_objects():
objs = []
libs = []
if _BUILD_SOX:
# NOTE: The order of the library listed bellow matters.
#
# (the most important thing is that dependencies come after a library
# e.g., sox comes first, flac/vorbis comes before ogg, and
# vorbisenc/vorbisfile comes before vorbis
libs = [
libs += [
'libsox.a',
'libmad.a',
'libFLAC.a',
......@@ -97,27 +99,34 @@ def _get_extra_objects():
'libopencore-amrnb.a',
'libopencore-amrwb.a',
]
for lib in libs:
objs.append(str(_TP_INSTALL_DIR / 'lib' / lib))
if _BUILD_TRANSDUCER:
objs.append(str(_TP_BASE_DIR / 'build' / 'transducer' / 'libwarprnnt.a'))
return objs
libs += ['libwarprnnt.a']
return [str(_TP_INSTALL_DIR / 'lib' / lib) for lib in libs]
def _get_libraries():
return [] if _BUILD_SOX else ['sox']
def _build_third_party():
build_dir = str(_TP_BASE_DIR / 'build')
def _build_third_party(base_build_dir):
build_dir = os.path.join(base_build_dir, 'third_party')
os.makedirs(build_dir, exist_ok=True)
subprocess.run(
args=['cmake', '..'],
args=[
'cmake',
f'-DCMAKE_INSTALL_PREFIX={_TP_INSTALL_DIR}',
f'-DBUILD_SOX={"ON" if _BUILD_SOX else "OFF"}',
f'-DBUILD_TRANSDUCER={"ON" if _BUILD_TRANSDUCER else "OFF"}',
f'{_TP_BASE_DIR}'],
cwd=build_dir,
check=True,
)
command = ['cmake', '--build', '.']
if _BUILD_TRANSDUCER:
command += ['--target', 'install']
subprocess.run(
args=['cmake', '--build', '.'],
args=command,
cwd=build_dir,
check=True,
)
......@@ -145,5 +154,5 @@ def get_ext_modules(debug=False):
class BuildExtension(TorchBuildExtension):
def build_extension(self, ext):
if ext.name == _EXT_NAME and _BUILD_SOX:
_build_third_party()
_build_third_party(self.build_temp)
super().build_extension(ext)
......@@ -50,7 +50,6 @@ class clean(distutils.command.clean.clean):
# Remove build directory
build_dirs = [
ROOT_DIR / 'build',
ROOT_DIR / 'third_party' / 'build',
]
for path in build_dirs:
if path.exists():
......
cmake_minimum_required(VERSION 3.1)
cmake_minimum_required(VERSION 3.5)
project(torchaudio_third_parties)
include(ExternalProject)
set(INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/install)
option(BUILD_SOX "Build libsox statically")
option(BUILD_TRANSDUCER "Build transducer statically")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
if (BUILD_SOX)
include(ExternalProject)
set(INSTALL_DIR ${CMAKE_INSTALL_PREFIX})
set(ARCHIVE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/archives)
set(COMMON_ARGS --quiet --disable-shared --enable-static --prefix=${INSTALL_DIR} --with-pic --disable-dependency-tracking --disable-debug --disable-examples --disable-doc)
......@@ -88,5 +94,8 @@ ExternalProject_Add(libsox
# See https://github.com/pytorch/audio/pull/1026
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp --without-sndio --without-pulseaudio
)
endif()
if(BUILD_TRANSDUCER)
add_subdirectory(transducer)
endif()
CMAKE_MINIMUM_REQUIRED(VERSION 3.5)
PROJECT(rnnt_release)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2")
IF(APPLE)
ADD_DEFINITIONS(-DAPPLE)
ENDIF()
INCLUDE_DIRECTORIES(submodule/include)
SET(CMAKE_POSITION_INDEPENDENT_CODE ON)
ADD_DEFINITIONS(-DRNNT_DISABLE_OMP)
IF(APPLE)
ADD_DEFINITIONS(-DAPPLE)
EXEC_PROGRAM(uname ARGS -v OUTPUT_VARIABLE DARWIN_VERSION)
STRING(REGEX MATCH "[0-9]+" DARWIN_VERSION ${DARWIN_VERSION})
MESSAGE(STATUS "DARWIN_VERSION=${DARWIN_VERSION}")
......@@ -30,9 +17,11 @@ ELSE()
ENDIF()
ADD_LIBRARY(warprnnt STATIC submodule/src/rnnt_entrypoint.cpp)
target_include_directories(warprnnt PUBLIC submodule/include)
set_target_properties(warprnnt PROPERTIES PUBLIC_HEADER submodule/include/rnnt.h)
INSTALL(TARGETS warprnnt
LIBRARY DESTINATION "lib"
ARCHIVE DESTINATION "lib")
INSTALL(FILES submodule/include/rnnt.h DESTINATION "submodule/include")
INSTALL(
TARGETS warprnnt
ARCHIVE DESTINATION "lib"
PUBLIC_HEADER DESTINATION "include")
......@@ -2,7 +2,7 @@
#include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/sox/utils.h>
TORCH_LIBRARY(torchaudio, m) {
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
//////////////////////////////////////////////////////////////////////////////
// sox_utils.h
//////////////////////////////////////////////////////////////////////////////
......@@ -74,18 +74,4 @@ TORCH_LIBRARY(torchaudio, m) {
m.def(
"torchaudio::sox_effects_apply_effects_file",
&torchaudio::sox_effects::apply_effects_file);
//////////////////////////////////////////////////////////////////////////////
// transducer.cpp
//////////////////////////////////////////////////////////////////////////////
#ifdef BUILD_TRANSDUCER
m.def("rnnt_loss(Tensor acts,"
"Tensor labels,"
"Tensor input_lengths,"
"Tensor label_lengths,"
"Tensor costs,"
"Tensor grads,"
"int blank_label,"
"int num_threads) -> int");
#endif
}
#ifdef BUILD_TRANSDUCER
#include <iostream>
#include <numeric>
#include <string>
......@@ -8,6 +6,8 @@
#include <torch/script.h>
#include "rnnt.h"
namespace {
int64_t cpu_rnnt_loss(torch::Tensor acts,
torch::Tensor labels,
torch::Tensor input_lengths,
......@@ -75,8 +75,19 @@ int64_t cpu_rnnt_loss(torch::Tensor acts,
return -1;
}
} // namespace
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
m.impl("rnnt_loss", &cpu_rnnt_loss);
m.impl("rnnt_loss", &cpu_rnnt_loss);
}
#endif
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("rnnt_loss(Tensor acts,"
"Tensor labels,"
"Tensor input_lengths,"
"Tensor label_lengths,"
"Tensor costs,"
"Tensor grads,"
"int blank_label,"
"int num_threads) -> int");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment