Commit 396d16cb authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Add C++ files for CTC decoder bindings (#2079)

Summary:
part of https://github.com/pytorch/audio/issues/2072 -- splitting up the PR for easier review

Add C++ files for binding CTC decoder functionality for Python

Note: the code here will not be compiled until the build process is changed

Pull Request resolved: https://github.com/pytorch/audio/pull/2079

Reviewed By: mthrok

Differential Revision: D33196286

Pulled By: carolineechen

fbshipit-source-id: 9fe4a8635b60ebfb594918bab00f5c3dccf96bd2
parent 24baa243
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "torchaudio/csrc/decoder/src/decoder/LexiconDecoder.h"
#include "torchaudio/csrc/decoder/src/decoder/lm/KenLM.h"
namespace py = pybind11;
using namespace torchaudio::lib::text;
using namespace py::literals;
/**
* Some hackery that lets pybind11 handle shared_ptr<void> (for old LMStatePtr).
* See: https://github.com/pybind/pybind11/issues/820
* PYBIND11_MAKE_OPAQUE(std::shared_ptr<void>);
* and inside PYBIND11_MODULE
* py::class_<std::shared_ptr<void>>(m, "encapsulated_data");
*/
namespace {
/**
* A pybind11 "alias type" for abstract class LM, allowing one to subclass LM
* with a custom LM defined purely in Python. For those who don't want to build
* with KenLM, or have their own custom LM implementation.
* See: https://pybind11.readthedocs.io/en/stable/advanced/classes.html
*
* TODO: ensure this works. Last time Jeff tried this there were slicing issues,
* see https://github.com/pybind/pybind11/issues/1546 for workarounds.
* This is low-pri since we assume most people can just build with KenLM.
*/
class PyLM : public LM {
using LM::LM;
// needed for pybind11 or else it won't compile
using LMOutput = std::pair<LMStatePtr, float>;
LMStatePtr start(bool startWithNothing) override {
PYBIND11_OVERLOAD_PURE(LMStatePtr, LM, start, startWithNothing);
}
LMOutput score(const LMStatePtr& state, const int usrTokenIdx) override {
PYBIND11_OVERLOAD_PURE(LMOutput, LM, score, state, usrTokenIdx);
}
LMOutput finish(const LMStatePtr& state) override {
PYBIND11_OVERLOAD_PURE(LMOutput, LM, finish, state);
}
};
/**
* Using custom python LMState derived from LMState is not working with
* custom python LM (derived from PyLM) because we need to to custing of LMState
* in score and finish functions to the derived class
* (for example vie obj.__class__ = CustomPyLMSTate) which cause the error
* "TypeError: __class__ assignment: 'CustomPyLMState' deallocator differs
* from 'flashlight.text.decoder._decoder.LMState'"
* details see in https://github.com/pybind/pybind11/issues/1640
* To define custom LM you can introduce map inside LM which maps LMstate to
* additional state info (shared pointers pointing to the same underlying object
* will have the same id in python in functions score and finish)
*
* ```python
* from flashlight.lib.text.decoder import LM
* class MyPyLM(LM):
* mapping_states = dict() # store simple additional int for each state
*
* def __init__(self):
* LM.__init__(self)
*
* def start(self, start_with_nothing):
* state = LMState()
* self.mapping_states[state] = 0
* return state
*
* def score(self, state, index):
* outstate = state.child(index)
* if outstate not in self.mapping_states:
* self.mapping_states[outstate] = self.mapping_states[state] + 1
* return (outstate, -numpy.random.random())
*
* def finish(self, state):
* outstate = state.child(-1)
* if outstate not in self.mapping_states:
* self.mapping_states[outstate] = self.mapping_states[state] + 1
* return (outstate, -1)
*```
*/
void LexiconDecoder_decodeStep(
LexiconDecoder& decoder,
uintptr_t emissions,
int T,
int N) {
decoder.decodeStep(reinterpret_cast<const float*>(emissions), T, N);
}
std::vector<DecodeResult> LexiconDecoder_decode(
LexiconDecoder& decoder,
uintptr_t emissions,
int T,
int N) {
return decoder.decode(reinterpret_cast<const float*>(emissions), T, N);
}
} // namespace
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "torchaudio/csrc/decoder/src/dictionary/Dictionary.h"
#include "torchaudio/csrc/decoder/src/dictionary/Utils.h"
namespace py = pybind11;
using namespace torchaudio::lib::text;
using namespace py::literals;
namespace {
void Dictionary_addEntry_0(
Dictionary& dict,
const std::string& entry,
int idx) {
dict.addEntry(entry, idx);
}
void Dictionary_addEntry_1(Dictionary& dict, const std::string& entry) {
dict.addEntry(entry);
}
} // namespace
include(ExternalProject)
set(pybind11_URL https://github.com/pybind/pybind11.git)
set(pybind11_TAG 9a19306fbf30642ca331d0ec88e7da54a96860f9) # release 2.2.4
# Download pybind11
ExternalProject_Add(
pybind11
PREFIX pybind11
GIT_REPOSITORY ${pybind11_URL}
GIT_TAG ${pybind11_TAG}
BUILD_IN_SOURCE 0
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
)
ExternalProject_Get_Property(pybind11 SOURCE_DIR)
set(pybind11_INCLUDE_DIR "${SOURCE_DIR}/include")
\ No newline at end of file
# - Find python libraries
# This module finds the libraries corresponding to the Python interpreter
# FindPythonInterp provides.
# This code sets the following variables:
#
# PYTHONLIBS_FOUND - have the Python libs been found
# PYTHON_PREFIX - path to the Python installation
# PYTHON_LIBRARIES - path to the python library
# PYTHON_INCLUDE_DIRS - path to where Python.h is found
# PYTHON_MODULE_EXTENSION - lib extension, e.g. '.so' or '.pyd'
# PYTHON_MODULE_PREFIX - lib name prefix: usually an empty string
# PYTHON_SITE_PACKAGES - path to installation site-packages
# PYTHON_IS_DEBUG - whether the Python interpreter is a debug build
#
# Thanks to talljimbo for the patch adding the 'LDVERSION' config
# variable usage.
#=============================================================================
# Copyright 2001-2009 Kitware, Inc.
# Copyright 2012 Continuum Analytics, Inc.
#
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# * Neither the names of Kitware, Inc., the Insight Software Consortium,
# nor the names of their contributors may be used to endorse or promote
# products derived from this software without specific prior written
# permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#=============================================================================
# Checking for the extension makes sure that `LibsNew` was found and not just `Libs`.
if(PYTHONLIBS_FOUND AND PYTHON_MODULE_EXTENSION)
return()
endif()
# Use the Python interpreter to find the libs.
if(PythonLibsNew_FIND_REQUIRED)
find_package(PythonInterp ${PythonLibsNew_FIND_VERSION} REQUIRED)
else()
find_package(PythonInterp ${PythonLibsNew_FIND_VERSION})
endif()
if(NOT PYTHONINTERP_FOUND)
set(PYTHONLIBS_FOUND FALSE)
return()
endif()
# According to http://stackoverflow.com/questions/646518/python-how-to-detect-debug-interpreter
# testing whether sys has the gettotalrefcount function is a reliable, cross-platform
# way to detect a CPython debug interpreter.
#
# The library suffix is from the config var LDVERSION sometimes, otherwise
# VERSION. VERSION will typically be like "2.7" on unix, and "27" on windows.
execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c"
"from distutils import sysconfig as s;import sys;import struct;
print('.'.join(str(v) for v in sys.version_info));
print(sys.prefix);
print(s.get_python_inc(plat_specific=True));
print(s.get_python_lib(plat_specific=True));
print(s.get_config_var('SO'));
print(hasattr(sys, 'gettotalrefcount')+0);
print(struct.calcsize('@P'));
print(s.get_config_var('LDVERSION') or s.get_config_var('VERSION'));
print(s.get_config_var('LIBDIR') or '');
print(s.get_config_var('MULTIARCH') or '');
"
RESULT_VARIABLE _PYTHON_SUCCESS
OUTPUT_VARIABLE _PYTHON_VALUES
ERROR_VARIABLE _PYTHON_ERROR_VALUE)
if(NOT _PYTHON_SUCCESS MATCHES 0)
if(PythonLibsNew_FIND_REQUIRED)
message(FATAL_ERROR
"Python config failure:\n${_PYTHON_ERROR_VALUE}")
endif()
set(PYTHONLIBS_FOUND FALSE)
return()
endif()
# Convert the process output into a list
if(WIN32)
string(REGEX REPLACE "\\\\" "/" _PYTHON_VALUES ${_PYTHON_VALUES})
endif()
string(REGEX REPLACE ";" "\\\\;" _PYTHON_VALUES ${_PYTHON_VALUES})
string(REGEX REPLACE "\n" ";" _PYTHON_VALUES ${_PYTHON_VALUES})
list(GET _PYTHON_VALUES 0 _PYTHON_VERSION_LIST)
list(GET _PYTHON_VALUES 1 PYTHON_PREFIX)
list(GET _PYTHON_VALUES 2 PYTHON_INCLUDE_DIR)
list(GET _PYTHON_VALUES 3 PYTHON_SITE_PACKAGES)
list(GET _PYTHON_VALUES 4 PYTHON_MODULE_EXTENSION)
list(GET _PYTHON_VALUES 5 PYTHON_IS_DEBUG)
list(GET _PYTHON_VALUES 6 PYTHON_SIZEOF_VOID_P)
list(GET _PYTHON_VALUES 7 PYTHON_LIBRARY_SUFFIX)
list(GET _PYTHON_VALUES 8 PYTHON_LIBDIR)
list(GET _PYTHON_VALUES 9 PYTHON_MULTIARCH)
# Make sure the Python has the same pointer-size as the chosen compiler
# Skip if CMAKE_SIZEOF_VOID_P is not defined
if(CMAKE_SIZEOF_VOID_P AND (NOT "${PYTHON_SIZEOF_VOID_P}" STREQUAL "${CMAKE_SIZEOF_VOID_P}"))
if(PythonLibsNew_FIND_REQUIRED)
math(EXPR _PYTHON_BITS "${PYTHON_SIZEOF_VOID_P} * 8")
math(EXPR _CMAKE_BITS "${CMAKE_SIZEOF_VOID_P} * 8")
message(FATAL_ERROR
"Python config failure: Python is ${_PYTHON_BITS}-bit, "
"chosen compiler is ${_CMAKE_BITS}-bit")
endif()
set(PYTHONLIBS_FOUND FALSE)
return()
endif()
# The built-in FindPython didn't always give the version numbers
string(REGEX REPLACE "\\." ";" _PYTHON_VERSION_LIST ${_PYTHON_VERSION_LIST})
list(GET _PYTHON_VERSION_LIST 0 PYTHON_VERSION_MAJOR)
list(GET _PYTHON_VERSION_LIST 1 PYTHON_VERSION_MINOR)
list(GET _PYTHON_VERSION_LIST 2 PYTHON_VERSION_PATCH)
# Make sure all directory separators are '/'
string(REGEX REPLACE "\\\\" "/" PYTHON_PREFIX ${PYTHON_PREFIX})
string(REGEX REPLACE "\\\\" "/" PYTHON_INCLUDE_DIR ${PYTHON_INCLUDE_DIR})
string(REGEX REPLACE "\\\\" "/" PYTHON_SITE_PACKAGES ${PYTHON_SITE_PACKAGES})
if(CMAKE_HOST_WIN32)
set(PYTHON_LIBRARY
"${PYTHON_PREFIX}/libs/Python${PYTHON_LIBRARY_SUFFIX}.lib")
# when run in a venv, PYTHON_PREFIX points to it. But the libraries remain in the
# original python installation. They may be found relative to PYTHON_INCLUDE_DIR.
if(NOT EXISTS "${PYTHON_LIBRARY}")
get_filename_component(_PYTHON_ROOT ${PYTHON_INCLUDE_DIR} DIRECTORY)
set(PYTHON_LIBRARY
"${_PYTHON_ROOT}/libs/Python${PYTHON_LIBRARY_SUFFIX}.lib")
endif()
# raise an error if the python libs are still not found.
if(NOT EXISTS "${PYTHON_LIBRARY}")
message(FATAL_ERROR "Python libraries not found")
endif()
else()
if(PYTHON_MULTIARCH)
set(_PYTHON_LIBS_SEARCH "${PYTHON_LIBDIR}/${PYTHON_MULTIARCH}" "${PYTHON_LIBDIR}")
else()
set(_PYTHON_LIBS_SEARCH "${PYTHON_LIBDIR}")
endif()
#message(STATUS "Searching for Python libs in ${_PYTHON_LIBS_SEARCH}")
# Probably this needs to be more involved. It would be nice if the config
# information the python interpreter itself gave us were more complete.
find_library(PYTHON_LIBRARY
NAMES "python${PYTHON_LIBRARY_SUFFIX}"
PATHS ${_PYTHON_LIBS_SEARCH}
NO_DEFAULT_PATH)
# If all else fails, just set the name/version and let the linker figure out the path.
if(NOT PYTHON_LIBRARY)
set(PYTHON_LIBRARY python${PYTHON_LIBRARY_SUFFIX})
endif()
endif()
MARK_AS_ADVANCED(
PYTHON_LIBRARY
PYTHON_INCLUDE_DIR
)
# We use PYTHON_INCLUDE_DIR, PYTHON_LIBRARY and PYTHON_DEBUG_LIBRARY for the
# cache entries because they are meant to specify the location of a single
# library. We now set the variables listed by the documentation for this
# module.
SET(PYTHON_INCLUDE_DIRS "${PYTHON_INCLUDE_DIR}")
SET(PYTHON_LIBRARIES "${PYTHON_LIBRARY}")
SET(PYTHON_DEBUG_LIBRARIES "${PYTHON_DEBUG_LIBRARY}")
find_package_message(PYTHON
"Found PythonLibs: ${PYTHON_LIBRARY}"
"${PYTHON_EXECUTABLE}${PYTHON_VERSION}")
set(PYTHONLIBS_FOUND TRUE)
# tools/pybind11Tools.cmake -- Build system for the pybind11 modules
#
# Copyright (c) 2015 Wenzel Jakob <wenzel@inf.ethz.ch>
#
# All rights reserved. Use of this source code is governed by a
# BSD-style license that can be found in the LICENSE file.
cmake_minimum_required(VERSION 2.8.12)
# Add a CMake parameter for choosing a desired Python version
if(NOT PYBIND11_PYTHON_VERSION)
set(PYBIND11_PYTHON_VERSION "" CACHE STRING "Python version to use for compiling modules")
endif()
set(Python_ADDITIONAL_VERSIONS 3.7 3.6 3.5 3.4)
find_package(PythonLibsNew ${PYBIND11_PYTHON_VERSION} REQUIRED)
include(CheckCXXCompilerFlag)
include(CMakeParseArguments)
if(NOT PYBIND11_CPP_STANDARD AND NOT CMAKE_CXX_STANDARD)
if(NOT MSVC)
check_cxx_compiler_flag("-std=c++14" HAS_CPP14_FLAG)
if (HAS_CPP14_FLAG)
set(PYBIND11_CPP_STANDARD -std=c++14)
else()
check_cxx_compiler_flag("-std=c++11" HAS_CPP11_FLAG)
if (HAS_CPP11_FLAG)
set(PYBIND11_CPP_STANDARD -std=c++11)
else()
message(FATAL_ERROR "Unsupported compiler -- pybind11 requires C++11 support!")
endif()
endif()
elseif(MSVC)
set(PYBIND11_CPP_STANDARD /std:c++14)
endif()
set(PYBIND11_CPP_STANDARD ${PYBIND11_CPP_STANDARD} CACHE STRING
"C++ standard flag, e.g. -std=c++11, -std=c++14, /std:c++14. Defaults to C++14 mode." FORCE)
endif()
# Checks whether the given CXX/linker flags can compile and link a cxx file. cxxflags and
# linkerflags are lists of flags to use. The result variable is a unique variable name for each set
# of flags: the compilation result will be cached base on the result variable. If the flags work,
# sets them in cxxflags_out/linkerflags_out internal cache variables (in addition to ${result}).
function(_pybind11_return_if_cxx_and_linker_flags_work result cxxflags linkerflags cxxflags_out linkerflags_out)
set(CMAKE_REQUIRED_LIBRARIES ${linkerflags})
check_cxx_compiler_flag("${cxxflags}" ${result})
if (${result})
set(${cxxflags_out} "${cxxflags}" CACHE INTERNAL "" FORCE)
set(${linkerflags_out} "${linkerflags}" CACHE INTERNAL "" FORCE)
endif()
endfunction()
# Internal: find the appropriate link time optimization flags for this compiler
function(_pybind11_add_lto_flags target_name prefer_thin_lto)
if (NOT DEFINED PYBIND11_LTO_CXX_FLAGS)
set(PYBIND11_LTO_CXX_FLAGS "" CACHE INTERNAL "")
set(PYBIND11_LTO_LINKER_FLAGS "" CACHE INTERNAL "")
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
set(cxx_append "")
set(linker_append "")
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND NOT APPLE)
# Clang Gold plugin does not support -Os; append -O3 to MinSizeRel builds to override it
set(linker_append ";$<$<CONFIG:MinSizeRel>:-O3>")
elseif(CMAKE_CXX_COMPILER_ID MATCHES "GNU")
set(cxx_append ";-fno-fat-lto-objects")
endif()
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND prefer_thin_lto)
_pybind11_return_if_cxx_and_linker_flags_work(HAS_FLTO_THIN
"-flto=thin${cxx_append}" "-flto=thin${linker_append}"
PYBIND11_LTO_CXX_FLAGS PYBIND11_LTO_LINKER_FLAGS)
endif()
if (NOT HAS_FLTO_THIN)
_pybind11_return_if_cxx_and_linker_flags_work(HAS_FLTO
"-flto${cxx_append}" "-flto${linker_append}"
PYBIND11_LTO_CXX_FLAGS PYBIND11_LTO_LINKER_FLAGS)
endif()
elseif (CMAKE_CXX_COMPILER_ID MATCHES "Intel")
# Intel equivalent to LTO is called IPO
_pybind11_return_if_cxx_and_linker_flags_work(HAS_INTEL_IPO
"-ipo" "-ipo" PYBIND11_LTO_CXX_FLAGS PYBIND11_LTO_LINKER_FLAGS)
elseif(MSVC)
# cmake only interprets libraries as linker flags when they start with a - (otherwise it
# converts /LTCG to \LTCG as if it was a Windows path). Luckily MSVC supports passing flags
# with - instead of /, even if it is a bit non-standard:
_pybind11_return_if_cxx_and_linker_flags_work(HAS_MSVC_GL_LTCG
"/GL" "-LTCG" PYBIND11_LTO_CXX_FLAGS PYBIND11_LTO_LINKER_FLAGS)
endif()
if (PYBIND11_LTO_CXX_FLAGS)
message(STATUS "LTO enabled")
else()
message(STATUS "LTO disabled (not supported by the compiler and/or linker)")
endif()
endif()
# Enable LTO flags if found, except for Debug builds
if (PYBIND11_LTO_CXX_FLAGS)
target_compile_options(${target_name} PRIVATE "$<$<NOT:$<CONFIG:Debug>>:${PYBIND11_LTO_CXX_FLAGS}>")
endif()
if (PYBIND11_LTO_LINKER_FLAGS)
target_link_libraries(${target_name} PRIVATE "$<$<NOT:$<CONFIG:Debug>>:${PYBIND11_LTO_LINKER_FLAGS}>")
endif()
endfunction()
# Build a Python extension module:
# pybind11_add_module(<name> [MODULE | SHARED] [EXCLUDE_FROM_ALL]
# [NO_EXTRAS] [SYSTEM] [THIN_LTO] source1 [source2 ...])
#
function(pybind11_add_module target_name)
set(options MODULE SHARED EXCLUDE_FROM_ALL NO_EXTRAS SYSTEM THIN_LTO)
cmake_parse_arguments(ARG "${options}" "" "" ${ARGN})
if(ARG_MODULE AND ARG_SHARED)
message(FATAL_ERROR "Can't be both MODULE and SHARED")
elseif(ARG_SHARED)
set(lib_type SHARED)
else()
set(lib_type MODULE)
endif()
if(ARG_EXCLUDE_FROM_ALL)
set(exclude_from_all EXCLUDE_FROM_ALL)
endif()
add_library(${target_name} ${lib_type} ${exclude_from_all} ${ARG_UNPARSED_ARGUMENTS})
message("ADDING LIBRARY ${target_name}")
if(ARG_SYSTEM)
set(inc_isystem SYSTEM)
endif()
target_include_directories(${target_name} ${inc_isystem}
PRIVATE ${PYBIND11_INCLUDE_DIR} # from project CMakeLists.txt
PRIVATE ${pybind11_INCLUDE_DIR} # from pybind11Config
PRIVATE ${PYTHON_INCLUDE_DIRS})
# Python debug libraries expose slightly different objects
# https://docs.python.org/3.6/c-api/intro.html#debugging-builds
# https://stackoverflow.com/questions/39161202/how-to-work-around-missing-pymodule-create2-in-amd64-win-python35-d-lib
if(PYTHON_IS_DEBUG)
target_compile_definitions(${target_name} PRIVATE Py_DEBUG)
endif()
# The prefix and extension are provided by FindPythonLibsNew.cmake
set_target_properties(${target_name} PROPERTIES PREFIX "${PYTHON_MODULE_PREFIX}")
set_target_properties(${target_name} PROPERTIES SUFFIX "${PYTHON_MODULE_EXTENSION}")
# -fvisibility=hidden is required to allow multiple modules compiled against
# different pybind versions to work properly, and for some features (e.g.
# py::module_local). We force it on everything inside the `pybind11`
# namespace; also turning it on for a pybind module compilation here avoids
# potential warnings or issues from having mixed hidden/non-hidden types.
set_target_properties(${target_name} PROPERTIES CXX_VISIBILITY_PRESET "hidden")
set_target_properties(${target_name} PROPERTIES CUDA_VISIBILITY_PRESET "hidden")
if(WIN32 OR CYGWIN)
# Link against the Python shared library on Windows
target_link_libraries(${target_name} PRIVATE ${PYTHON_LIBRARIES})
elseif(APPLE)
# It's quite common to have multiple copies of the same Python version
# installed on one's system. E.g.: one copy from the OS and another copy
# that's statically linked into an application like Blender or Maya.
# If we link our plugin library against the OS Python here and import it
# into Blender or Maya later on, this will cause segfaults when multiple
# conflicting Python instances are active at the same time (even when they
# are of the same version).
# Windows is not affected by this issue since it handles DLL imports
# differently. The solution for Linux and Mac OS is simple: we just don't
# link against the Python library. The resulting shared library will have
# missing symbols, but that's perfectly fine -- they will be resolved at
# import time.
target_link_libraries(${target_name} PRIVATE "-undefined dynamic_lookup")
if(ARG_SHARED)
# Suppress CMake >= 3.0 warning for shared libraries
set_target_properties(${target_name} PROPERTIES MACOSX_RPATH ON)
endif()
endif()
# Make sure C++11/14 are enabled
target_compile_options(${target_name} PUBLIC ${PYBIND11_CPP_STANDARD})
if(ARG_NO_EXTRAS)
return()
endif()
_pybind11_add_lto_flags(${target_name} ${ARG_THIN_LTO})
if (NOT MSVC AND NOT ${CMAKE_BUILD_TYPE} MATCHES Debug)
# Strip unnecessary sections of the binary on Linux/Mac OS
if(CMAKE_STRIP)
if(APPLE)
add_custom_command(TARGET ${target_name} POST_BUILD
COMMAND ${CMAKE_STRIP} -x $<TARGET_FILE:${target_name}>)
else()
add_custom_command(TARGET ${target_name} POST_BUILD
COMMAND ${CMAKE_STRIP} $<TARGET_FILE:${target_name}>)
endif()
endif()
endif()
if(MSVC)
# /MP enables multithreaded builds (relevant when there are many files), /bigobj is
# needed for bigger binding projects due to the limit to 64k addressable sections
target_compile_options(${target_name} PRIVATE /bigobj)
if(CMAKE_VERSION VERSION_LESS 3.11)
target_compile_options(${target_name} PRIVATE $<$<NOT:$<CONFIG:Debug>>:/MP>)
else()
# Only set these options for C++ files. This is important so that, for
# instance, projects that include other types of source files like CUDA
# .cu files don't get these options propagated to nvcc since that would
# cause the build to fail.
target_compile_options(${target_name} PRIVATE $<$<NOT:$<CONFIG:Debug>>:$<$<COMPILE_LANGUAGE:CXX>:/MP>>)
endif()
endif()
endfunction()
\ No newline at end of file
#include <torch/extension.h>
#include <torchaudio/csrc/decoder/bindings/_decoder.cpp>
#include <torchaudio/csrc/decoder/bindings/_dictionary.cpp>
PYBIND11_MODULE(_torchaudio_decoder, m) {
#ifdef BUILD_CTC_DECODER
py::enum_<SmearingMode>(m, "_SmearingMode")
.value("NONE", SmearingMode::NONE)
.value("MAX", SmearingMode::MAX)
.value("LOGADD", SmearingMode::LOGADD);
py::class_<TrieNode, TrieNodePtr>(m, "_TrieNode")
.def(py::init<int>(), "idx"_a)
.def_readwrite("children", &TrieNode::children)
.def_readwrite("idx", &TrieNode::idx)
.def_readwrite("labels", &TrieNode::labels)
.def_readwrite("scores", &TrieNode::scores)
.def_readwrite("max_score", &TrieNode::maxScore);
py::class_<Trie, TriePtr>(m, "_Trie")
.def(py::init<int, int>(), "max_children"_a, "root_idx"_a)
.def("get_root", &Trie::getRoot)
.def("insert", &Trie::insert, "indices"_a, "label"_a, "score"_a)
.def("search", &Trie::search, "indices"_a)
.def("smear", &Trie::smear, "smear_mode"_a);
py::class_<LM, LMPtr, PyLM>(m, "_LM")
.def(py::init<>())
.def("start", &LM::start, "start_with_nothing"_a)
.def("score", &LM::score, "state"_a, "usr_token_idx"_a)
.def("finish", &LM::finish, "state"_a);
py::class_<LMState, LMStatePtr>(m, "_LMState")
.def(py::init<>())
.def_readwrite("children", &LMState::children)
.def("compare", &LMState::compare, "state"_a)
.def("child", &LMState::child<LMState>, "usr_index"_a);
py::class_<KenLM, KenLMPtr, LM>(m, "_KenLM")
.def(
py::init<const std::string&, const Dictionary&>(),
"path"_a,
"usr_token_dict"_a);
py::enum_<CriterionType>(m, "_CriterionType")
.value("ASG", CriterionType::ASG)
.value("CTC", CriterionType::CTC);
py::class_<LexiconDecoderOptions>(m, "_LexiconDecoderOptions")
.def(
py::init<
const int,
const int,
const double,
const double,
const double,
const double,
const double,
const bool,
const CriterionType>(),
"beam_size"_a,
"beam_size_token"_a,
"beam_threshold"_a,
"lm_weight"_a,
"word_score"_a,
"unk_score"_a,
"sil_score"_a,
"log_add"_a,
"criterion_type"_a)
.def_readwrite("beam_size", &LexiconDecoderOptions::beamSize)
.def_readwrite("beam_size_token", &LexiconDecoderOptions::beamSizeToken)
.def_readwrite("beam_threshold", &LexiconDecoderOptions::beamThreshold)
.def_readwrite("lm_weight", &LexiconDecoderOptions::lmWeight)
.def_readwrite("word_score", &LexiconDecoderOptions::wordScore)
.def_readwrite("unk_score", &LexiconDecoderOptions::unkScore)
.def_readwrite("sil_score", &LexiconDecoderOptions::silScore)
.def_readwrite("log_add", &LexiconDecoderOptions::logAdd)
.def_readwrite("criterion_type", &LexiconDecoderOptions::criterionType);
py::class_<DecodeResult>(m, "_DecodeResult")
.def(py::init<int>(), "length"_a)
.def_readwrite("score", &DecodeResult::score)
.def_readwrite("amScore", &DecodeResult::amScore)
.def_readwrite("lmScore", &DecodeResult::lmScore)
.def_readwrite("words", &DecodeResult::words)
.def_readwrite("tokens", &DecodeResult::tokens);
// NB: `decode` and `decodeStep` expect raw emissions pointers.
py::class_<LexiconDecoder>(m, "_LexiconDecoder")
.def(py::init<
LexiconDecoderOptions,
const TriePtr,
const LMPtr,
const int,
const int,
const int,
const std::vector<float>&,
const bool>())
.def("decode_begin", &LexiconDecoder::decodeBegin)
.def(
"decode_step",
&LexiconDecoder_decodeStep,
"emissions"_a,
"T"_a,
"N"_a)
.def("decode_end", &LexiconDecoder::decodeEnd)
.def("decode", &LexiconDecoder_decode, "emissions"_a, "T"_a, "N"_a)
.def("prune", &LexiconDecoder::prune, "look_back"_a = 0)
.def(
"get_best_hypothesis",
&LexiconDecoder::getBestHypothesis,
"look_back"_a = 0)
.def("get_all_final_hypothesis", &LexiconDecoder::getAllFinalHypothesis);
py::class_<Dictionary>(m, "_Dictionary")
.def(py::init<>())
.def(py::init<const std::string&>(), "filename"_a)
.def("entry_size", &Dictionary::entrySize)
.def("index_size", &Dictionary::indexSize)
.def("add_entry", &Dictionary_addEntry_0, "entry"_a, "idx"_a)
.def("add_entry", &Dictionary_addEntry_1, "entry"_a)
.def("get_entry", &Dictionary::getEntry, "idx"_a)
.def("set_default_index", &Dictionary::setDefaultIndex, "idx"_a)
.def("get_index", &Dictionary::getIndex, "entry"_a)
.def("contains", &Dictionary::contains, "entry"_a)
.def("is_contiguous", &Dictionary::isContiguous)
.def(
"map_entries_to_indices",
&Dictionary::mapEntriesToIndices,
"entries"_a)
.def(
"map_indices_to_entries",
&Dictionary::mapIndicesToEntries,
"indices"_a);
m.def("_create_word_dict", &createWordDict, "lexicon"_a);
m.def("_load_words", &loadWords, "filename"_a, "max_words"_a = -1);
#endif
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment