Commit 39b6343d authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Migrate CTC decoder code (#2580)

Summary:
This commit gets rid of our copy of the CTC decoder code and
replaces it with the upstream Flashlight-Text repo.

Pull Request resolved: https://github.com/pytorch/audio/pull/2580

Reviewed By: carolineechen

Differential Revision: D38244906

Pulled By: mthrok

fbshipit-source-id: d274240fc67675552d19ff35e9a363b9b9048721
parent 919fd0c4
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <cstring>
#include <functional>
#include <memory>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torchaudio {
namespace lib {
namespace text {
/**
 * A node in a tree of language model states. Each node owns its children,
 * keyed by the user token index that leads to them.
 */
struct LMState {
  /* Child states, keyed by user token index. */
  std::unordered_map<int, std::shared_ptr<LMState>> children;

  /**
   * Return the child reached through `usrIdx`, creating a default-constructed
   * `T` (stored as an LMState) when no such child exists yet.
   *
   * An existing child is returned through a static cast to `T`; no runtime
   * type check is performed.
   */
  template <typename T>
  std::shared_ptr<T> child(int usrIdx) {
    const auto found = children.find(usrIdx);
    if (found != children.end()) {
      return std::static_pointer_cast<T>(found->second);
    }
    auto created = std::make_shared<T>();
    children[usrIdx] = created;
    return created;
  }

  /**
   * Compare two language model states by identity (object address).
   *
   * Returns 0 when `state` is this object, otherwise -1 or 1 according to a
   * consistent ordering of the two addresses.
   * Throws std::runtime_error when `state` is null.
   */
  int compare(const std::shared_ptr<LMState>& state) const {
    if (!state) {
      throw std::runtime_error("a state is null");
    }
    const LMState* other = state.get();
    if (this == other) {
      return 0;
    }
    // The built-in `<` yields an unspecified result for pointers to unrelated
    // objects; std::less guarantees a strict total order over all pointers.
    return std::less<const LMState*>()(this, other) ? -1 : 1;
  }
};
/**
* LMStatePtr is a shared LMState* tracking LM states generated during decoding.
*/
using LMStatePtr = std::shared_ptr<LMState>;
/**
* LM is a thin wrapper for language models. We abstract several common methods
* here which can be shared for KenLM, ConvLM, RNNLM, etc.
*/
class LM {
public:
/*
 * Initialize or reset the language model and return its starting state.
 * NOTE(review): the meaning of `startWithNothing` is not visible in this
 * header (ZeroLM ignores it) - confirm against the concrete implementations.
 */
virtual LMStatePtr start(bool startWithNothing) = 0;
/**
* Query the language model given input language model state and a specific
* token, return a new language model state and score.
*
* `usrTokenIdx` is the token index in the caller's numbering; see
* `usrToLmIdxMap_` below for the mapping into the LM's own numbering.
*/
virtual std::pair<LMStatePtr, float> score(
const LMStatePtr& state,
const int usrTokenIdx) = 0;
/* Query the language model and finish decoding; returns the final state and score. */
virtual std::pair<LMStatePtr, float> finish(const LMStatePtr& state) = 0;
/*
 * Update LM caches (optional) given a bunch of new states generated.
 * The default implementation does nothing.
 */
virtual void updateCache(std::vector<LMStatePtr> stateIdices) {}
/* Virtual destructor so implementations can be destroyed through an LM pointer. */
virtual ~LM() = default;
protected:
/* Map indices from acoustic model to LM for each valid token. */
std::vector<int> usrToLmIdxMap_;
};
/* Shared-ownership handle to an LM implementation. */
using LMPtr = std::shared_ptr<LM>;
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include "torchaudio/csrc/decoder/src/decoder/lm/ZeroLM.h"
#include <stdexcept>
namespace torchaudio {
namespace lib {
namespace text {
/* Returns a freshly allocated root state; the flag is ignored. */
LMStatePtr ZeroLM::start(bool /* startWithNothing */) {
  LMStatePtr root = std::make_shared<LMState>();
  return root;
}
std::pair<LMStatePtr, float> ZeroLM::score(
const LMStatePtr& state /* unused */,
const int usrTokenIdx) {
return std::make_pair(state->child<LMState>(usrTokenIdx), 0.0);
}
/* Finishing adds nothing: the incoming state is returned with a zero score. */
std::pair<LMStatePtr, float> ZeroLM::finish(const LMStatePtr& state) {
  return std::pair<LMStatePtr, float>(state, 0.0f);
}
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include "torchaudio/csrc/decoder/src/decoder/lm/LM.h"
namespace torchaudio {
namespace lib {
namespace text {
/**
* ZeroLM is a dummy language model class, which mimics the behavior of a
* uni-gram language model but always returns 0 as score.
*/
class ZeroLM : public LM {
public:
/* Returns a new, empty root state; the argument is ignored. */
LMStatePtr start(bool startWithNothing) override;
/* Returns the child of `state` for `usrTokenIdx` together with a score of 0. */
std::pair<LMStatePtr, float> score(
const LMStatePtr& state,
const int usrTokenIdx) override;
/* Returns `state` unchanged together with a score of 0. */
std::pair<LMStatePtr, float> finish(const LMStatePtr& state) override;
};
/* Shared-ownership handle to a ZeroLM. */
using ZeroLMPtr = std::shared_ptr<ZeroLM>;
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
namespace torchaudio {
namespace lib {
namespace text {
// Special token strings.
// Unknown-word token: loadWords() always inserts it into the lexicon and
// createWordDict() uses its index as the dictionary's default index.
constexpr const char* kUnkToken = "<unk>";
// NOTE(review): presumably the end-of-sentence, padding and mask markers;
// their users are not visible in this file - confirm against callers.
constexpr const char* kEosToken = "</s>";
constexpr const char* kPadToken = "<pad>";
constexpr const char* kMaskToken = "<mask>";
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include <iostream>
#include <stdexcept>
#include "torchaudio/csrc/decoder/src/dictionary/Dictionary.h"
#include "torchaudio/csrc/decoder/src/dictionary/String.h"
#include "torchaudio/csrc/decoder/src/dictionary/System.h"
#include "torchaudio/csrc/decoder/src/dictionary/Utils.h"
namespace torchaudio {
namespace lib {
namespace text {
/* Builds the dictionary from an already opened stream, one index per line. */
Dictionary::Dictionary(std::istream& stream) {
  this->createFromStream(stream);
}
/*
 * Builds the dictionary from the file at `filename`.
 * createInputStream() throws std::runtime_error if the file cannot be opened.
 */
Dictionary::Dictionary(const std::string& filename) {
  std::ifstream input = createInputStream(filename);
  this->createFromStream(input);
}
/*
 * Builds the dictionary from a token list; every token gets the next free
 * index. Throws std::invalid_argument on a duplicate token (from addEntry)
 * and std::runtime_error if the resulting indices are not contiguous.
 */
Dictionary::Dictionary(const std::vector<std::string>& tkns) {
  for (auto it = tkns.begin(); it != tkns.end(); ++it) {
    addEntry(*it);
  }
  if (isContiguous()) {
    return;
  }
  throw std::runtime_error("Invalid dictionary format - not contiguous");
}
/*
 * Populates the dictionary from `stream`. Each non-empty line contributes one
 * index; every whitespace-separated token on that line maps to it.
 * Throws std::runtime_error if the stream is in a failed state or if the
 * resulting indices are not contiguous.
 */
void Dictionary::createFromStream(std::istream& stream) {
  if (stream.fail()) {
    throw std::runtime_error("Unable to open dictionary input stream.");
  }
  for (std::string line; std::getline(stream, line);) {
    if (line.empty()) {
      continue;
    }
    // Taken once per line so all tokens of the line share the same index.
    const auto lineIdx = idx2entry_.size();
    const auto lineTokens = splitOnWhitespace(line, true);
    for (const auto& token : lineTokens) {
      addEntry(token, lineIdx);
    }
  }
  if (isContiguous()) {
    return;
  }
  throw std::runtime_error("Invalid dictionary format - not contiguous");
}
/*
 * Maps `entry` to `idx`. The first entry registered for an index becomes the
 * string returned by getEntry(idx); later entries with the same index are
 * aliases. Throws std::invalid_argument if `entry` is already present.
 */
void Dictionary::addEntry(const std::string& entry, int idx) {
  if (entry2idx_.count(entry) != 0) {
    throw std::invalid_argument(
        "Duplicate entry name in dictionary '" + entry + "'");
  }
  entry2idx_[entry] = idx;
  // emplace leaves an existing idx -> entry mapping untouched.
  idx2entry_.emplace(idx, entry);
}
/*
 * Adds `entry` under the first unused index at or above the current number
 * of indices. Throws std::invalid_argument if `entry` is already present.
 */
void Dictionary::addEntry(const std::string& entry) {
  if (entry2idx_.count(entry) != 0) {
    throw std::invalid_argument(
        "Duplicate entry in dictionary '" + entry + "'");
  }
  // Start at the index count and skip forward past any index already taken.
  int candidate = idx2entry_.size();
  for (; idx2entry_.count(candidate) != 0; ++candidate) {
  }
  addEntry(entry, candidate);
}
/*
 * Returns the entry registered for `idx`.
 * Throws std::invalid_argument if the index is unknown.
 */
std::string Dictionary::getEntry(int idx) const {
  const auto hit = idx2entry_.find(idx);
  if (hit != idx2entry_.end()) {
    return hit->second;
  }
  throw std::invalid_argument(
      "Unknown index in dictionary '" + std::to_string(idx) + "'");
}
void Dictionary::setDefaultIndex(int idx) {
defaultIndex_ = idx;
}
/*
 * Returns the index of `entry`. For an unknown entry the default index is
 * returned when it is non-negative; otherwise std::invalid_argument is thrown.
 */
int Dictionary::getIndex(const std::string& entry) const {
  const auto hit = entry2idx_.find(entry);
  if (hit != entry2idx_.end()) {
    return hit->second;
  }
  if (defaultIndex_ >= 0) {
    return defaultIndex_;
  }
  throw std::invalid_argument("Unknown entry in dictionary: '" + entry + "'");
}
/* Reports whether `entry` has been added (the default index plays no role). */
bool Dictionary::contains(const std::string& entry) const {
  return entry2idx_.find(entry) != entry2idx_.end();
}
/* Number of distinct entry strings (aliases sharing an index count separately). */
size_t Dictionary::entrySize() const {
  const size_t numEntries = entry2idx_.size();
  return numEntries;
}
/*
 * Checks that the indices in use are exactly 0 .. indexSize()-1 and that
 * every entry's index has a reverse mapping.
 */
bool Dictionary::isContiguous() const {
  const size_t total = indexSize();
  for (size_t i = 0; i < total; ++i) {
    if (idx2entry_.count(i) == 0) {
      return false;
    }
  }
  for (auto it = entry2idx_.begin(); it != entry2idx_.end(); ++it) {
    if (idx2entry_.count(it->second) == 0) {
      return false;
    }
  }
  return true;
}
/*
 * Translates each entry to its index, in order, using getIndex(); unknown
 * entries therefore yield the default index or throw std::invalid_argument.
 */
std::vector<int> Dictionary::mapEntriesToIndices(
    const std::vector<std::string>& entries) const {
  std::vector<int> out;
  out.reserve(entries.size());
  for (size_t i = 0; i < entries.size(); ++i) {
    out.push_back(getIndex(entries[i]));
  }
  return out;
}
/*
 * Translates each index to its entry, in order, using getEntry(); an unknown
 * index throws std::invalid_argument.
 */
std::vector<std::string> Dictionary::mapIndicesToEntries(
    const std::vector<int>& indices) const {
  std::vector<std::string> out;
  out.reserve(indices.size());
  for (size_t i = 0; i < indices.size(); ++i) {
    out.push_back(getEntry(indices[i]));
  }
  return out;
}
/* Number of distinct indices in use. */
size_t Dictionary::indexSize() const {
  const size_t numIndices = idx2entry_.size();
  return numIndices;
}
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <istream>
#include <string>
#include <unordered_map>
#include <vector>
namespace torchaudio {
namespace lib {
namespace text {
// A simple dictionary class which holds a bidirectional map
// entry (strings) <--> integer indices. Not thread-safe !
class Dictionary {
public:
// Creates an empty dictionary
Dictionary() {}
// Reads one index per non-empty line; all whitespace-separated tokens on a
// line share that index. Throws std::runtime_error on a bad stream or a
// non-contiguous result.
explicit Dictionary(std::istream& stream);
// Same as the stream constructor, reading from the file at `filename`.
explicit Dictionary(const std::string& filename);
// Assigns consecutive free indices to the given tokens.
explicit Dictionary(const std::vector<std::string>& tkns);
// Number of distinct entry strings.
size_t entrySize() const;
// Number of distinct indices.
size_t indexSize() const;
// Maps `entry` to `idx`; throws std::invalid_argument on a duplicate entry.
// The first entry added for an index is the one getEntry() reports.
void addEntry(const std::string& entry, int idx);
// Maps `entry` to the first available index; throws std::invalid_argument
// on a duplicate entry.
void addEntry(const std::string& entry);
// Throws std::invalid_argument for an unknown index.
std::string getEntry(int idx) const;
// Index returned by getIndex() for unknown entries; negative disables it.
void setDefaultIndex(int idx);
// Throws std::invalid_argument for an unknown entry when no non-negative
// default index is set.
int getIndex(const std::string& entry) const;
bool contains(const std::string& entry) const;
// checks if all the indices are contiguous
bool isContiguous() const;
// Element-wise getIndex().
std::vector<int> mapEntriesToIndices(
const std::vector<std::string>& entries) const;
// Element-wise getEntry().
std::vector<std::string> mapIndicesToEntries(
const std::vector<int>& indices) const;
private:
// Creates a dictionary from an input stream
void createFromStream(std::istream& stream);
// entry string -> index (several entries may share one index)
std::unordered_map<std::string, int> entry2idx_;
// index -> first entry registered for that index
std::unordered_map<int, std::string> idx2entry_;
// Negative means "no default": unknown entries make getIndex() throw.
int defaultIndex_ = -1;
};
// Collection of dictionaries keyed by an integer id.
// NOTE(review): the meaning of the key is not visible here - confirm against users.
typedef std::unordered_map<int, Dictionary> DictionaryMap;
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include "torchaudio/csrc/decoder/src/dictionary/String.h"
#include <sys/types.h>
#include <array>
#include <cstdlib>
#include <ctime>
#include <functional>
// Characters treated as whitespace by trim() and splitOnWhitespace().
static constexpr const char* kSpaceChars = "\t\n\v\f\r ";
namespace torchaudio {
namespace lib {
/*
 * Returns `str` without leading and trailing whitespace (any character in
 * kSpaceChars). A string made only of whitespace yields an empty string.
 */
std::string trim(const std::string& str) {
  const auto first = str.find_first_not_of(kSpaceChars);
  if (first == std::string::npos) {
    return std::string();
  }
  const auto last = str.find_last_not_of(kSpaceChars);
  if (last == std::string::npos || last < first) {
    return std::string();
  }
  const auto length = last - first + 1;
  return str.substr(first, length);
}
/*
 * Replaces, in place and left to right, every non-overlapping occurrence of
 * `from` in `str` with `repl`. Inserted text is never rescanned. An empty
 * `from` leaves `str` untouched.
 */
void replaceAll(
    std::string& str,
    const std::string& from,
    const std::string& repl) {
  if (from.empty()) {
    return;
  }
  // After each replacement, resume the search just past the inserted text.
  for (auto at = str.find(from); at != std::string::npos;
       at = str.find(from, at + repl.length())) {
    str.replace(at, from.length(), repl);
  }
}
/*
 * Returns true when `input` begins with `pattern`. An empty pattern matches
 * every input.
 *
 * Only the first pattern.size() characters are examined, rather than
 * searching the whole input for the pattern.
 */
bool startsWith(const std::string& input, const std::string& pattern) {
  return pattern.size() <= input.size() &&
      input.compare(0, pattern.size(), pattern) == 0;
}
/*
 * Returns true when `input` ends with `pattern`. An empty pattern matches
 * every input; a pattern longer than the input never matches.
 */
bool endsWith(const std::string& input, const std::string& pattern) {
  const auto patternLen = pattern.size();
  const auto inputLen = input.size();
  if (inputLen < patternLen) {
    return false;
  }
  return input.compare(inputLen - patternLen, patternLen, pattern) == 0;
}
/*
 * Shared implementation of the split functions.
 *
 * Any       - true: split at any single character contained in `delim`;
 *             false: split at occurrences of `delim` as a whole.
 * delimSize - number of characters to skip after each match.
 * ignoreEmpty - drop empty pieces (including a trailing empty piece).
 *
 * Without ignoreEmpty the result always has at least one element.
 */
template <bool Any, typename Delim>
static std::vector<std::string> splitImpl(
    const Delim& delim,
    std::string::size_type delimSize,
    const std::string& input,
    bool ignoreEmpty = false) {
  std::vector<std::string> pieces;
  std::string::size_type start = 0;
  for (;;) {
    const auto hit =
        Any ? input.find_first_of(delim, start) : input.find(delim, start);
    if (hit == std::string::npos) {
      break;
    }
    if (!ignoreEmpty || hit != start) {
      pieces.push_back(input.substr(start, hit - start));
    }
    start = hit + delimSize;
  }
  // Whatever follows the last delimiter (possibly nothing) is the last piece.
  if (!ignoreEmpty || start != input.size()) {
    pieces.push_back(input.substr(start));
  }
  return pieces;
}
/* Splits `input` at every occurrence of the character `delim`. */
std::vector<std::string> split(
    char delim,
    const std::string& input,
    bool ignoreEmpty) {
  const std::string::size_type delimLen = 1;
  return splitImpl<false>(delim, delimLen, input, ignoreEmpty);
}
/*
 * Splits `input` at every occurrence of the whole string `delim`.
 * Throws std::invalid_argument when `delim` is empty.
 */
std::vector<std::string> split(
    const std::string& delim,
    const std::string& input,
    bool ignoreEmpty) {
  const auto delimLen = delim.size();
  if (delimLen == 0) {
    throw std::invalid_argument("delimiter is empty string");
  }
  return splitImpl<false>(delim, delimLen, input, ignoreEmpty);
}
/* Splits `input` at every character that appears anywhere in `delim`. */
std::vector<std::string> splitOnAnyOf(
    const std::string& delim,
    const std::string& input,
    bool ignoreEmpty) {
  // Each match is a single character, hence a skip length of one.
  const std::string::size_type matchLen = 1;
  return splitImpl<true>(delim, matchLen, input, ignoreEmpty);
}
/* Splits `input` at every whitespace character listed in kSpaceChars. */
std::vector<std::string> splitOnWhitespace(
    const std::string& input,
    bool ignoreEmpty) {
  const std::string whitespace(kSpaceChars);
  return splitOnAnyOf(whitespace, input, ignoreEmpty);
}
std::string join(
const std::string& delim,
const std::vector<std::string>& vec) {
return join(delim, vec.begin(), vec.end());
}
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <errno.h>
#include <algorithm>
#include <chrono>
#include <cstring>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
namespace torchaudio {
namespace lib {
// ============================ Types and Templates ============================
// The decayed type obtained by dereferencing an iterator of type `It`.
template <typename It>
using DecayDereference =
typename std::decay<decltype(*std::declval<It>())>::type;
// SFINAE helper: names `void` when `S` and `T` are the same type and is
// ill-formed otherwise.
template <typename S, typename T>
using EnableIfSame = typename std::enable_if<std::is_same<S, T>::value>::type;
// ================================= Functions =================================
// Returns `str` with leading and trailing whitespace (\t \n \v \f \r and
// space) removed.
std::string trim(const std::string& str);
// Replaces, in place, every non-overlapping occurrence of `from` with `repl`;
// does nothing when `from` is empty.
void replaceAll(
std::string& str,
const std::string& from,
const std::string& repl);
// True when `input` begins with `pattern`.
bool startsWith(const std::string& input, const std::string& pattern);
// True when `input` ends with `pattern`.
bool endsWith(const std::string& input, const std::string& pattern);
// Splits `input` at each occurrence of the character `delim`. With
// `ignoreEmpty`, empty pieces are dropped.
std::vector<std::string> split(
char delim,
const std::string& input,
bool ignoreEmpty = false);
// Splits `input` at each occurrence of the whole string `delim`.
// Throws std::invalid_argument when `delim` is empty.
std::vector<std::string> split(
const std::string& delim,
const std::string& input,
bool ignoreEmpty = false);
// Splits `input` at each character that appears in `delim`.
std::vector<std::string> splitOnAnyOf(
const std::string& delim,
const std::string& input,
bool ignoreEmpty = false);
// Splits `input` at each whitespace character (same set as trim()).
std::vector<std::string> splitOnWhitespace(
const std::string& input,
bool ignoreEmpty = false);
/**
* Join a vector of `std::string` inserting `delim` in between.
*/
std::string join(const std::string& delim, const std::vector<std::string>& vec);
/**
* Join a range of `std::string` specified by iterators.
*/
template <
    typename FwdIt,
    typename = EnableIfSame<DecayDereference<FwdIt>, std::string>>
std::string join(const std::string& delim, FwdIt begin, FwdIt end) {
  std::string joined;
  if (begin == end) {
    return joined;
  }
  // First pass: compute the exact output length so one allocation suffices.
  size_t charCount = 0;
  size_t elemCount = 0;
  for (FwdIt it = begin; it != end; ++it) {
    charCount += it->size();
    ++elemCount;
  }
  joined.reserve(charCount + delim.size() * (elemCount - 1));
  // Second pass: first element bare, every later one preceded by `delim`.
  FwdIt cursor = begin;
  joined += *cursor;
  for (++cursor; cursor != end; ++cursor) {
    joined += delim;
    joined += *cursor;
  }
  return joined;
}
/**
* Create an output string using a `printf`-style format string and arguments.
* Safer than `sprintf` which is vulnerable to buffer overflow.
*/
template <class... Args>
std::string format(const char* fmt, Args&&... args) {
  // Dry run with a null buffer: yields the formatted length (without '\0').
  const int needed =
      std::snprintf(nullptr, 0, fmt, std::forward<Args>(args)...);
  if (needed < 0) {
    throw std::runtime_error(std::strerror(errno));
  }
  std::string out(needed, '\0');
  // The buffer size includes the terminator: std::string guarantees that
  // out[out.size()] is addressable and it is legal to store '\0' there.
  const int written =
      std::snprintf(&out[0], needed + 1, fmt, std::forward<Args>(args)...);
  if (written < 0) {
    throw std::runtime_error(std::strerror(errno));
  }
  if (written != needed) {
    throw std::runtime_error(
        "The size of the formated string is not equal to what it is expected.");
  }
  return out;
}
/**
* Dedup the elements in a vector.
*/
/*
 * Collapses every run of adjacent equal elements in `in` to a single element.
 * Elements that are equal but not adjacent are kept, so sort the vector first
 * if a globally unique result is wanted. Order is preserved.
 *
 * Uses the erase-unique idiom so that T only needs to be movable and
 * equality-comparable (no default constructor required).
 */
template <class T>
void dedup(std::vector<T>& in) {
  in.erase(std::unique(in.begin(), in.end()), in.end());
}
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include "torchaudio/csrc/decoder/src/dictionary/System.h"
#include <glob.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <array>
#include <cstdlib>
#include <ctime>
#include <functional>
#ifdef _WIN32
#include <windows.h>
#else
#include <unistd.h>
#endif
#include "torchaudio/csrc/decoder/src/dictionary/String.h"
namespace torchaudio {
namespace lib {
/* Returns the id of the calling process as reported by the OS. */
size_t getProcessId() {
#ifdef _WIN32
  const auto pid = GetCurrentProcessId();
#else
  const auto pid = ::getpid();
#endif
  return pid;
}
/*
 * Returns a numeric id for the calling thread: the OS thread id on Windows,
 * otherwise the std::hash of std::this_thread::get_id().
 */
size_t getThreadId() {
#ifdef _WIN32
  return GetCurrentThreadId();
#else
  std::hash<std::thread::id> hasher{};
  return hasher(std::this_thread::get_id());
#endif
}
/* Returns the platform's path separator: "\\" on Windows, "/" elsewhere. */
std::string pathSeperator() {
#ifdef _WIN32
  return std::string("\\");
#else
  return std::string("/");
#endif
}
/*
 * Joins two path fragments after trimming surrounding whitespace from each.
 * A separator is inserted unless `p1` is empty or already ends with one
 * (the check looks at the last character of the untrimmed `p1`).
 */
std::string pathsConcat(const std::string& p1, const std::string& p2) {
  const std::string sep = pathSeperator();
  const bool needsSep = !p1.empty() && p1.back() != sep[0];
  if (needsSep) {
    return trim(p1) + sep + trim(p2);
  }
  return trim(p1) + trim(p2);
}
namespace {
/**
 * Splits `path` into its components, in their original order.
 *
 * The path is trimmed and split at the path separator; each component is
 * trimmed and empty components are dropped.
 *
 * Special cases: when the trimmed path is empty, ".", ".." or exactly the
 * path separator, a vector with that trimmed path as its single entry is
 * returned.
 */
std::vector<std::string> getDirsOnPath(const std::string& path) {
  const std::string cleaned = trim(path);
  const bool isSpecial = cleaned.empty() || cleaned == pathSeperator() ||
      cleaned == "." || cleaned == "..";
  if (isSpecial) {
    return {cleaned};
  }
  std::vector<std::string> dirs;
  const std::vector<std::string> pieces = split(pathSeperator(), cleaned);
  for (size_t i = 0; i < pieces.size(); ++i) {
    const std::string dir = trim(pieces[i]);
    if (dir.empty()) {
      continue;
    }
    dirs.push_back(dir);
  }
  return dirs;
}
} // namespace
std::string dirname(const std::string& path) {
std::vector<std::string> dirsOnPath = getDirsOnPath(path);
if (dirsOnPath.size() < 2) {
return ".";
} else {
dirsOnPath.pop_back();
const std::string root =
((trim(path))[0] == pathSeperator()[0]) ? pathSeperator() : "";
return root + join(pathSeperator(), dirsOnPath);
}
}
std::string basename(const std::string& path) {
std::vector<std::string> dirsOnPath = getDirsOnPath(path);
if (dirsOnPath.empty()) {
return "";
} else {
return dirsOnPath.back();
}
}
/*
 * Returns true when `path` exists and is a directory (symbolic links are
 * followed by stat()). Returns false when stat() fails for any reason.
 */
bool dirExists(const std::string& path) {
  struct stat info;
  if (stat(path.c_str(), &info) != 0) {
    return false;
  }
  // The file type is an enumerated field inside st_mode, not a set of
  // independent flags: it must be masked with S_IFMT and compared. A plain
  // bit test against S_IFDIR also matches other types that share that bit
  // (e.g. block devices).
  return (info.st_mode & S_IFMT) == S_IFDIR;
}
/*
 * Creates the single directory `path` (mode 0755 on non-Windows builds) if it
 * does not already exist. Parent directories are not created.
 * Throws std::runtime_error when the directory cannot be created.
 */
void dirCreate(const std::string& path) {
  if (dirExists(path)) {
    return;
  }
  mode_t nMode = 0755;
#ifdef _WIN32
  const int status = _mkdir(path.c_str());
#else
  const int status = mkdir(path.c_str(), nMode);
#endif
  if (status == 0) {
    return;
  }
  throw std::runtime_error(
      std::string("Unable to create directory - ") + path);
}
void dirCreateRecursive(const std::string& path) {
if (dirExists(path)) {
return;
}
std::vector<std::string> dirsOnPath = getDirsOnPath(path);
std::string pathFromStart;
if (path[0] == pathSeperator()[0]) {
pathFromStart = pathSeperator();
}
for (std::string& dir : dirsOnPath) {
if (pathFromStart.empty()) {
pathFromStart = dir;
} else {
pathFromStart = pathsConcat(pathFromStart, dir);
}
if (!dirExists(pathFromStart)) {
dirCreate(pathFromStart);
}
}
}
/* Returns true when `path` can be opened for reading. */
bool fileExists(const std::string& path) {
  return std::ifstream(path, std::ifstream::in).good();
}
/*
 * Returns the value of the environment variable `key`, or `dflt` when the
 * variable is not set.
 */
std::string getEnvVar(
    const std::string& key,
    const std::string& dflt /*= "" */) {
  const char* raw = getenv(key.c_str());
  if (raw == nullptr) {
    return dflt;
  }
  return std::string(raw);
}
/* Returns today's local date formatted as "YYYY-MM-DD". */
std::string getCurrentDate() {
  const time_t now = time(nullptr);
  struct tm storage;
  const struct tm* local = localtime_r(&now, &storage);
  std::array<char, 80> text;
  strftime(text.data(), text.size(), "%Y-%m-%d", local);
  return std::string(text.data());
}
/*
 * Returns the current local time formatted with strftime's "%X"
 * (the locale's time representation).
 */
std::string getCurrentTime() {
  const time_t now = time(nullptr);
  struct tm storage;
  const struct tm* local = localtime_r(&now, &storage);
  std::array<char, 80> text;
  strftime(text.data(), text.size(), "%X", local);
  return std::string(text.data());
}
/*
 * Builds "<tmpdir>/fl_tmp_<user>_<filename>".
 *
 * <tmpdir> defaults to "/tmp" and is overridden by the environment variables
 * TMPDIR, TEMP and TMP, checked in that order; each one that is set replaces
 * the previous choice, so the last one set wins. <user> comes from USER,
 * falling back to "unknown".
 */
std::string getTmpPath(const std::string& filename) {
  std::string tmpDir = "/tmp";
  for (const char* name : {"TMPDIR", "TEMP", "TMP"}) {
    const char* value = std::getenv(name);
    if (value != nullptr) {
      tmpDir = std::string(value);
    }
  }
  const std::string user = getEnvVar("USER", "unknown");
  return tmpDir + "/fl_tmp_" + user + "_" + filename;
}
/*
 * Reads `file` and returns its lines (without line terminators), in order.
 * createInputStream() throws std::runtime_error if the file cannot be opened.
 */
std::vector<std::string> getFileContent(const std::string& file) {
  std::ifstream in = createInputStream(file);
  std::vector<std::string> lines;
  for (std::string line; std::getline(in, line);) {
    lines.push_back(line);
  }
  in.close();
  return lines;
}
/*
 * Returns the paths matching the glob pattern `pat` (with tilde expansion),
 * in the order reported by glob(). The return value of glob() is not
 * inspected; only the match list it fills in is used.
 */
std::vector<std::string> fileGlob(const std::string& pat) {
  glob_t matches;
  glob(pat.c_str(), GLOB_TILDE, nullptr, &matches);
  std::vector<std::string> paths;
  for (unsigned int i = 0; i < matches.gl_pathc; ++i) {
    paths.emplace_back(matches.gl_pathv[i]);
  }
  globfree(&matches);
  return paths;
}
/*
 * Opens `filename` for reading and returns the stream.
 * Throws std::runtime_error if the file cannot be opened.
 */
std::ifstream createInputStream(const std::string& filename) {
  std::ifstream file(filename);
  if (file.is_open()) {
    return file;
  }
  throw std::runtime_error("Failed to open file for reading: " + filename);
}
/*
 * Opens `filename` for writing with the given open mode and returns the
 * stream. Throws std::runtime_error if the file cannot be opened.
 */
std::ofstream createOutputStream(
    const std::string& filename,
    std::ios_base::openmode mode) {
  std::ofstream file(filename, mode);
  if (file.is_open()) {
    return file;
  }
  throw std::runtime_error("Failed to open file for writing: " + filename);
}
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <chrono>
#include <fstream>
#include <string>
#include <thread>
#include <type_traits>
#include <vector>
namespace torchaudio {
namespace lib {
// Id of the calling process.
size_t getProcessId();
// Numeric id of the calling thread.
size_t getThreadId();
// Joins two whitespace-trimmed path fragments, inserting a separator unless
// `p1` is empty or already ends with one.
std::string pathsConcat(const std::string& p1, const std::string& p2);
// "\\" on Windows, "/" elsewhere.
std::string pathSeperator();
// `path` without its last component; "." when it has fewer than two components.
std::string dirname(const std::string& path);
// Last component of `path`; "" when it has none.
std::string basename(const std::string& path);
// True when `path` exists and is a directory.
bool dirExists(const std::string& path);
// Creates one directory; throws std::runtime_error on failure.
void dirCreate(const std::string& path);
// Creates `path` along with any missing parent directories.
void dirCreateRecursive(const std::string& path);
// True when `path` can be opened for reading.
bool fileExists(const std::string& path);
// Value of environment variable `key`, or `dflt` when it is not set.
std::string getEnvVar(const std::string& key, const std::string& dflt = "");
// Local date as "YYYY-MM-DD".
std::string getCurrentDate();
// Local time in strftime "%X" format.
std::string getCurrentTime();
// "<tmpdir>/fl_tmp_<user>_<filename>"; tmpdir comes from TMPDIR/TEMP/TMP or "/tmp".
std::string getTmpPath(const std::string& filename);
// All lines of `file`; throws std::runtime_error if it cannot be opened.
std::vector<std::string> getFileContent(const std::string& file);
// Paths matching the glob pattern `pat`.
std::vector<std::string> fileGlob(const std::string& pat);
// Opens a file for reading; throws std::runtime_error on failure.
std::ifstream createInputStream(const std::string& filename);
// Opens a file for writing; throws std::runtime_error on failure.
std::ofstream createOutputStream(
const std::string& filename,
std::ios_base::openmode mode = std::ios_base::out);
/**
* Calls `f(args...)` repeatedly, retrying if an exception is thrown.
* Supports sleeps between retries, with duration starting at `initial` and
* multiplying by `factor` each retry. At most `maxIters` calls are made.
*/
template <class Fn, class... Args>
typename std::result_of<Fn(Args...)>::type retryWithBackoff(
    std::chrono::duration<double> initial,
    double factor,
    int64_t maxIters,
    Fn&& f,
    Args&&... args) {
  // The checks are written as negated ">=" so that NaN is rejected as well.
  if (!(initial.count() >= 0.0)) {
    throw std::invalid_argument("retryWithBackoff: bad initial");
  }
  if (!(factor >= 0.0)) {
    throw std::invalid_argument("retryWithBackoff: bad factor");
  }
  if (maxIters <= 0) {
    throw std::invalid_argument("retryWithBackoff: bad maxIters");
  }
  double delaySecs = initial.count();
  const int64_t lastAttempt = maxIters - 1;
  for (int64_t attempt = 0; attempt < maxIters; ++attempt) {
    try {
      // NOTE(review): the arguments are forwarded again on every retry; an
      // rvalue argument consumed by a failed attempt is reused as-is.
      return f(std::forward<Args>(args)...);
    } catch (...) {
      // The final attempt lets its exception escape to the caller.
      if (attempt >= lastAttempt) {
        throw;
      }
    }
    if (delaySecs > 0.0) {
      /* sleep override */
      // Each individual sleep is capped at 1e7 seconds.
      const double cappedSecs = std::min(1e7, delaySecs);
      std::this_thread::sleep_for(std::chrono::duration<double>(cappedSecs));
    }
    delaySecs *= factor;
  }
  throw std::logic_error("retryWithBackoff: hit unreachable");
}
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include "torchaudio/csrc/decoder/src/dictionary/Utils.h"
#include "torchaudio/csrc/decoder/src/dictionary/Defines.h"
#include "torchaudio/csrc/decoder/src/dictionary/String.h"
#include "torchaudio/csrc/decoder/src/dictionary/System.h"
namespace torchaudio {
namespace lib {
namespace text {
/*
 * Builds a word dictionary containing every key of `lexicon` and makes the
 * index of kUnkToken the default index. getIndex() throws
 * std::invalid_argument if the lexicon has no kUnkToken key.
 */
Dictionary createWordDict(const LexiconMap& lexicon) {
  Dictionary dict;
  for (auto it = lexicon.begin(); it != lexicon.end(); ++it) {
    dict.addEntry(it->first);
  }
  const int unkIdx = dict.getIndex(kUnkToken);
  dict.setDefaultIndex(unkIdx);
  return dict;
}
/*
 * Loads a lexicon file. Each line holds a word followed by the whitespace-
 * separated tokens of one spelling; a word may appear on several lines to
 * provide several spellings.
 *
 * At most `maxWords` distinct words are read; a negative value means no
 * limit. The kUnkToken entry is always present in the result, with an empty
 * spelling list (any spellings read for it from the file are discarded).
 *
 * Throws std::runtime_error if the file cannot be opened or a line has fewer
 * than two fields (this includes blank lines).
 */
LexiconMap loadWords(const std::string& filename, int maxWords) {
  LexiconMap lexicon;
  std::string line;
  std::ifstream infile = createInputStream(filename);
  const bool unlimited = maxWords < 0;
  const size_t limit = static_cast<size_t>(maxWords);
  while ((unlimited || lexicon.size() != limit) &&
         std::getline(infile, line)) {
    // First field is the word, the remaining fields are its spelling.
    const auto fields = splitOnWhitespace(line, true);
    if (fields.size() < 2) {
      throw std::runtime_error("[loadWords] Invalid line: " + line);
    }
    const std::vector<std::string> spelling(fields.begin() + 1, fields.end());
    // operator[] creates the (empty) spelling list on first sight of a word.
    lexicon[fields[0]].push_back(spelling);
  }
  // Insert unknown word.
  lexicon[kUnkToken] = {};
  return lexicon;
}
/*
 * Splits `word` into its UTF-8 encoded characters (plain ASCII included).
 * The length of each character is taken from its lead byte; continuation
 * bytes are not validated.
 *
 * Throws std::runtime_error when a lead byte is invalid or a character is
 * cut off by the end of the string.
 */
std::vector<std::string> splitWrd(const std::string& word) {
  std::vector<std::string> tokens;
  tokens.reserve(word.size());
  const int total = word.length();
  int pos = 0;
  while (pos < total) {
    const auto lead = static_cast<unsigned char>(word[pos]);
    int width = -1;
    if (lead < 0x80) {
      width = 1; // 0xxxxxxx
    } else if ((lead >> 5) == 0x06) {
      width = 2; // 110xxxxx
    } else if ((lead >> 4) == 0x0E) {
      width = 3; // 1110xxxx
    } else if ((lead >> 3) == 0x1E) {
      width = 4; // 11110xxx
    }
    if (width == -1 || pos + width > total) {
      throw std::runtime_error("splitWrd: invalid UTF-8 : " + word);
    }
    tokens.push_back(word.substr(pos, width));
    pos += width;
  }
  return tokens;
}
/*
 * Replaces runs of repeated tokens with the token followed by a replabel
 * that encodes how many extra repeats followed (at most `maxReps` per
 * replabel). The replabel for count i is the dictionary index of "<i>".
 *
 * Returns `tokens` unchanged when it is empty or `maxReps` <= 0.
 */
std::vector<int> packReplabels(
    const std::vector<int>& tokens,
    const Dictionary& dict,
    int maxReps) {
  if (tokens.empty() || maxReps <= 0) {
    return tokens;
  }
  // replabelValueToIdx[i] is the dictionary index of "<i>"; slot 0 is unused.
  std::vector<int> replabelValueToIdx(maxReps + 1);
  for (int reps = 1; reps <= maxReps; ++reps) {
    replabelValueToIdx[reps] = dict.getIndex("<" + std::to_string(reps) + ">");
  }
  std::vector<int> packed;
  int lastToken = -1;
  int pendingReps = 0;
  for (const int token : tokens) {
    if (token == lastToken && pendingReps < maxReps) {
      ++pendingReps;
      continue;
    }
    // Run ended (or hit the cap): flush its replabel before the new token.
    if (pendingReps > 0) {
      packed.push_back(replabelValueToIdx[pendingReps]);
      pendingReps = 0;
    }
    packed.push_back(token);
    lastToken = token;
  }
  if (pendingReps > 0) {
    packed.push_back(replabelValueToIdx[pendingReps]);
  }
  return packed;
}
/*
 * Expands replabels back into repeated tokens: a replabel for count i that
 * follows a regular token appends i more copies of that token. A replabel
 * with no regular token directly before it is dropped.
 *
 * Returns `tokens` unchanged when it is empty or `maxReps` <= 0.
 */
std::vector<int> unpackReplabels(
    const std::vector<int>& tokens,
    const Dictionary& dict,
    int maxReps) {
  if (tokens.empty() || maxReps <= 0) {
    return tokens;
  }
  // Dictionary index of "<i>" -> i.
  std::unordered_map<int, int> replabelIdxToValue;
  for (int reps = 1; reps <= maxReps; ++reps) {
    replabelIdxToValue.emplace(
        dict.getIndex("<" + std::to_string(reps) + ">"), reps);
  }
  std::vector<int> unpacked;
  int lastToken = -1; // -1: no regular token available to repeat
  for (const int token : tokens) {
    const auto rep = replabelIdxToValue.find(token);
    if (rep == replabelIdxToValue.end()) {
      unpacked.push_back(token);
      lastToken = token;
      continue;
    }
    if (lastToken == -1) {
      continue;
    }
    unpacked.insert(unpacked.end(), rep->second, lastToken);
    lastToken = -1;
  }
  return unpacked;
}
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "torchaudio/csrc/decoder/src/dictionary/Dictionary.h"
namespace torchaudio {
namespace lib {
namespace text {
// Maps a word to its list of spellings; each spelling is a token sequence.
using LexiconMap =
std::unordered_map<std::string, std::vector<std::vector<std::string>>>;
// Builds a dictionary holding every word of `lexicon`, with the index of
// "<unk>" as the default index.
Dictionary createWordDict(const LexiconMap& lexicon);
// Loads a lexicon file (one "word token token ..." spelling per line), reading
// at most `maxWords` distinct words; a negative value means no limit.
LexiconMap loadWords(const std::string& filename, int maxWords = -1);
// split word into tokens abc -> {"a", "b", "c"}
// Works with ASCII, UTF-8 encodings
std::vector<std::string> splitWrd(const std::string& word);
/**
* Pack a token sequence by replacing consecutive repeats with replabels,
* e.g. "abbccc" -> "ab<1>c<2>". The tokens "<1>", "<2>", ...,
* "<" + `to_string(maxReps)` + ">" must already be in `dict`.
*/
std::vector<int> packReplabels(
const std::vector<int>& tokens,
const Dictionary& dict,
int maxReps);
/**
* Unpack a token sequence by replacing replabels with repeated tokens,
* e.g. "ab<1>c<2>" -> "abbccc". The tokens "<1>", "<2>", ...,
* "<" + `to_string(maxReps)` + ">" must already be in `dict`.
*/
std::vector<int> unpackReplabels(
const std::vector<int>& tokens,
const Dictionary& dict,
int maxReps);
} // namespace text
} // namespace lib
} // namespace torchaudio
...@@ -28,21 +28,23 @@ try: ...@@ -28,21 +28,23 @@ try:
load_words as _load_words, load_words as _load_words,
) )
except Exception: except Exception:
torchaudio._extension._load_lib("libtorchaudio_decoder") torchaudio._extension._load_lib("libflashlight-text")
from torchaudio._torchaudio_decoder import ( from torchaudio.flashlight_lib_text_decoder import (
_create_word_dict, CriterionType as _CriterionType,
_CriterionType, KenLM as _KenLM,
_Dictionary, LexiconDecoder as _LexiconDecoder,
_KenLM, LexiconDecoderOptions as _LexiconDecoderOptions,
_LexiconDecoder, LexiconFreeDecoder as _LexiconFreeDecoder,
_LexiconDecoderOptions, LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
_LexiconFreeDecoder, LM as _LM,
_LexiconFreeDecoderOptions, SmearingMode as _SmearingMode,
_LM, Trie as _Trie,
_load_words, ZeroLM as _ZeroLM,
_SmearingMode, )
_Trie, from torchaudio.flashlight_lib_text_dictionary import (
_ZeroLM, create_word_dict as _create_word_dict,
Dictionary as _Dictionary,
load_words as _load_words,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment