Commit 24baa243 authored by Caroline Chen, committed by Facebook GitHub Bot

Add C++ files for CTC decoder (#2075)

Summary:
part of https://github.com/pytorch/audio/issues/2072 -- splitting up the PR for easier review

Add C++ files from [flashlight](https://github.com/flashlight/flashlight) that are needed for building CTC decoder w/ Lexicon and KenLM support

Note: the code here will not be compiled until the build process is changed (future PR)

Pull Request resolved: https://github.com/pytorch/audio/pull/2075

Reviewed By: mthrok

Differential Revision: D33186825

Pulled By: carolineechen

fbshipit-source-id: 5b69eea7634f3fae686471d988422942bb784cd9
parent adc559a8
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include "torchaudio/csrc/decoder/src/decoder/Utils.h"
namespace torchaudio {
namespace lib {
namespace text {
enum class CriterionType { ASG = 0, CTC = 1, S2S = 2 };
/**
* Decoder supports two typical use cases:
* Offline manner:
*  decoder.decode(someData) [returns all hypotheses (transcriptions)]
*
* Online manner:
* decoder.decodeBegin() [called only at the beginning of the stream]
* while (stream)
* decoder.decodeStep(someData) [one or more calls]
* decoder.getBestHypothesis() [returns the best hypothesis (transcription)]
* decoder.prune() [prunes the hypothesis space]
* decoder.decodeEnd() [called only at the end of the stream]
*
* Note: function decoder.prune() deletes hypotheses up until the time it is
* called, to support online decoding. It also adds an offset to the scores in
* the beam to avoid underflow/overflow.
*
*/
class Decoder {
public:
Decoder() = default;
virtual ~Decoder() = default;
/* Initialize the decoder before starting to consume emissions */
virtual void decodeBegin() {}
/* Consume emissions in T x N chunks and increase the hypothesis space */
virtual void decodeStep(const float* emissions, int T, int N) = 0;
/* Finish up decoding after consuming all emissions */
virtual void decodeEnd() {}
/* Offline decode function, which consumes all emissions at once */
virtual std::vector<DecodeResult> decode(
const float* emissions,
int T,
int N) {
decodeBegin();
decodeStep(emissions, T, N);
decodeEnd();
return getAllFinalHypothesis();
}
/* Prune the hypothesis space */
virtual void prune(int lookBack = 0) = 0;
/* Get the number of decoded frames in the buffer */
virtual int nDecodedFramesInBuffer() const = 0;
/*
* Get the best completed hypothesis which is `lookBack` frames ahead of the
* last one in the buffer. For LMs that require a lexicon, a completed
* hypothesis means no partial word appears at the end.
*/
virtual DecodeResult getBestHypothesis(int lookBack = 0) const = 0;
/* Get all the final hypotheses */
virtual std::vector<DecodeResult> getAllFinalHypothesis() const = 0;
};
} // namespace text
} // namespace lib
} // namespace torchaudio
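
The streaming contract documented above can be exercised with a short driver. The sketch below is illustrative only (it is not part of this commit); it assumes the caller already has a concrete Decoder subclass and a contiguous buffer of numChunks row-major T x N emission chunks.

// Illustrative sketch: drive any Decoder subclass in online mode.
#include "torchaudio/csrc/decoder/src/decoder/Decoder.h"

namespace {
void streamingDecode(
    torchaudio::lib::text::Decoder& decoder,
    const float* emissions, // numChunks contiguous chunks, each T x N
    int numChunks,
    int T,
    int N) {
  decoder.decodeBegin();
  for (int c = 0; c < numChunks; ++c) {
    // Consume one chunk and extend the hypothesis space.
    decoder.decodeStep(emissions + static_cast<size_t>(c) * T * N, T, N);
    // Optionally peek at the current best hypothesis, then prune the
    // hypothesis space to bound memory on long streams.
    torchaudio::lib::text::DecodeResult partial = decoder.getBestHypothesis();
    (void)partial;
    decoder.prune();
  }
  decoder.decodeEnd();
  std::vector<torchaudio::lib::text::DecodeResult> finals =
      decoder.getAllFinalHypothesis();
  (void)finals;
}
} // namespace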
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include <stdlib.h>
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <unordered_map>
#include "torchaudio/csrc/decoder/src/decoder/LexiconDecoder.h"
namespace torchaudio {
namespace lib {
namespace text {
void LexiconDecoder::decodeBegin() {
hyp_.clear();
hyp_.emplace(0, std::vector<LexiconDecoderState>());
/* note: the LM resets itself with start() */
hyp_[0].emplace_back(
0.0, lm_->start(0), lexicon_->getRoot(), nullptr, sil_, -1);
nDecodedFrames_ = 0;
nPrunedFrames_ = 0;
}
void LexiconDecoder::decodeStep(const float* emissions, int T, int N) {
int startFrame = nDecodedFrames_ - nPrunedFrames_;
// Extend hyp_ buffer
if (hyp_.size() < startFrame + T + 2) {
for (int i = hyp_.size(); i < startFrame + T + 2; i++) {
hyp_.emplace(i, std::vector<LexiconDecoderState>());
}
}
std::vector<size_t> idx(N);
for (int t = 0; t < T; t++) {
std::iota(idx.begin(), idx.end(), 0);
if (N > opt_.beamSizeToken) {
std::partial_sort(
idx.begin(),
idx.begin() + opt_.beamSizeToken,
idx.end(),
[&t, &N, &emissions](const size_t& l, const size_t& r) {
return emissions[t * N + l] > emissions[t * N + r];
});
}
candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
for (const LexiconDecoderState& prevHyp : hyp_[startFrame + t]) {
const TrieNode* prevLex = prevHyp.lex;
const int prevIdx = prevHyp.token;
const float lexMaxScore =
prevLex == lexicon_->getRoot() ? 0 : prevLex->maxScore;
/* (1) Try children */
for (int r = 0; r < std::min(opt_.beamSizeToken, N); ++r) {
int n = idx[r];
auto iter = prevLex->children.find(n);
if (iter == prevLex->children.end()) {
continue;
}
const TrieNodePtr& lex = iter->second;
double amScore = emissions[t * N + n];
if (nDecodedFrames_ + t > 0 &&
opt_.criterionType == CriterionType::ASG) {
amScore += transitions_[n * N + prevIdx];
}
double score = prevHyp.score + amScore;
if (n == sil_) {
score += opt_.silScore;
}
LMStatePtr lmState;
double lmScore = 0.;
if (isLmToken_) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, n);
lmState = lmStateScorePair.first;
lmScore = lmStateScorePair.second;
}
// We consume a new token
if (opt_.criterionType != CriterionType::CTC || prevHyp.prevBlank ||
n != prevIdx) {
if (!lex->children.empty()) {
if (!isLmToken_) {
lmState = prevHyp.lmState;
lmScore = lex->maxScore - lexMaxScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score + opt_.lmWeight * lmScore,
lmState,
lex.get(),
&prevHyp,
n,
-1,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
}
}
// If we got a true word
for (auto label : lex->labels) {
if (prevLex == lexicon_->getRoot() && prevHyp.token == n) {
// This is to avoid a situation where, when there is a word with a
// single-token spelling (e.g. X -> x) in the lexicon and token `x`
// is predicted in several consecutive frames, the word `X` would be
// emitted multiple times. This violates the CTC property that
// there must be a blank token in between to predict 2 identical
// tokens consecutively.
continue;
}
if (!isLmToken_) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, label);
lmState = lmStateScorePair.first;
lmScore = lmStateScorePair.second - lexMaxScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score + opt_.lmWeight * lmScore + opt_.wordScore,
lmState,
lexicon_->getRoot(),
&prevHyp,
n,
label,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
}
// If we got an unknown word
if (lex->labels.empty() && (opt_.unkScore > kNegativeInfinity)) {
if (!isLmToken_) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, unk_);
lmState = lmStateScorePair.first;
lmScore = lmStateScorePair.second - lexMaxScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score + opt_.lmWeight * lmScore + opt_.unkScore,
lmState,
lexicon_->getRoot(),
&prevHyp,
n,
unk_,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
}
}
/* (2) Try same lexicon node */
if (opt_.criterionType != CriterionType::CTC || !prevHyp.prevBlank ||
prevLex == lexicon_->getRoot()) {
int n = prevLex == lexicon_->getRoot() ? sil_ : prevIdx;
double amScore = emissions[t * N + n];
if (nDecodedFrames_ + t > 0 &&
opt_.criterionType == CriterionType::ASG) {
amScore += transitions_[n * N + prevIdx];
}
double score = prevHyp.score + amScore;
if (n == sil_) {
score += opt_.silScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score,
prevHyp.lmState,
prevLex,
&prevHyp,
n,
-1,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore);
}
/* (3) CTC only, try blank */
if (opt_.criterionType == CriterionType::CTC) {
int n = blank_;
double amScore = emissions[t * N + n];
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + amScore,
prevHyp.lmState,
prevLex,
&prevHyp,
n,
-1,
true, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore);
}
// finish proposing
}
candidatesStore(
candidates_,
candidatePtrs_,
hyp_[startFrame + t + 1],
opt_.beamSize,
candidatesBestScore_ - opt_.beamThreshold,
opt_.logAdd,
false);
updateLMCache(lm_, hyp_[startFrame + t + 1]);
}
nDecodedFrames_ += T;
}
void LexiconDecoder::decodeEnd() {
candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
bool hasNiceEnding = false;
for (const LexiconDecoderState& prevHyp :
hyp_[nDecodedFrames_ - nPrunedFrames_]) {
if (prevHyp.lex == lexicon_->getRoot()) {
hasNiceEnding = true;
break;
}
}
for (const LexiconDecoderState& prevHyp :
hyp_[nDecodedFrames_ - nPrunedFrames_]) {
const TrieNode* prevLex = prevHyp.lex;
const LMStatePtr& prevLmState = prevHyp.lmState;
if (!hasNiceEnding || prevHyp.lex == lexicon_->getRoot()) {
auto lmStateScorePair = lm_->finish(prevLmState);
auto lmScore = lmStateScorePair.second;
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + opt_.lmWeight * lmScore,
lmStateScorePair.first,
prevLex,
&prevHyp,
sil_,
-1,
false, // prevBlank
prevHyp.amScore,
prevHyp.lmScore + lmScore);
}
}
candidatesStore(
candidates_,
candidatePtrs_,
hyp_[nDecodedFrames_ - nPrunedFrames_ + 1],
opt_.beamSize,
candidatesBestScore_ - opt_.beamThreshold,
opt_.logAdd,
true);
++nDecodedFrames_;
}
std::vector<DecodeResult> LexiconDecoder::getAllFinalHypothesis() const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
if (finalFrame < 1) {
return std::vector<DecodeResult>{};
}
return getAllHypothesis(hyp_.find(finalFrame)->second, finalFrame);
}
DecodeResult LexiconDecoder::getBestHypothesis(int lookBack) const {
if (nDecodedFrames_ - nPrunedFrames_ - lookBack < 1) {
return DecodeResult();
}
const LexiconDecoderState* bestNode = findBestAncestor(
hyp_.find(nDecodedFrames_ - nPrunedFrames_)->second, lookBack);
return getHypothesis(bestNode, nDecodedFrames_ - nPrunedFrames_ - lookBack);
}
int LexiconDecoder::nHypothesis() const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
return hyp_.find(finalFrame)->second.size();
}
int LexiconDecoder::nDecodedFramesInBuffer() const {
return nDecodedFrames_ - nPrunedFrames_ + 1;
}
void LexiconDecoder::prune(int lookBack) {
if (nDecodedFrames_ - nPrunedFrames_ - lookBack < 1) {
return; // Not enough decoded frames to prune
}
/* (1) Find the last emitted word in the best path */
const LexiconDecoderState* bestNode = findBestAncestor(
hyp_.find(nDecodedFrames_ - nPrunedFrames_)->second, lookBack);
if (!bestNode) {
return; // Not enough decoded frames to prune
}
int startFrame = nDecodedFrames_ - nPrunedFrames_ - lookBack;
if (startFrame < 1) {
return; // Not enough decoded frames to prune
}
/* (2) Move things from back of hyp_ to front and normalize scores */
pruneAndNormalize(hyp_, startFrame, lookBack);
nPrunedFrames_ = nDecodedFrames_ - lookBack;
}
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <unordered_map>
#include "torchaudio/csrc/decoder/src/decoder/Decoder.h"
#include "torchaudio/csrc/decoder/src/decoder/Trie.h"
#include "torchaudio/csrc/decoder/src/decoder/lm/LM.h"
namespace torchaudio {
namespace lib {
namespace text {
struct LexiconDecoderOptions {
int beamSize; // Maximum number of hypotheses we hold after each step
int beamSizeToken; // Maximum number of tokens we consider at each step
double beamThreshold; // Threshold to prune hypotheses
double lmWeight; // Weight of lm
double wordScore; // Word insertion score
double unkScore; // Unknown word insertion score
double silScore; // Silence insertion score
bool logAdd; // Whether or not to use logadd when merging hypotheses
CriterionType criterionType; // CTC or ASG
};
/**
* LexiconDecoderState stores information for each hypothesis in the beam.
*/
struct LexiconDecoderState {
double score; // Accumulated total score so far
LMStatePtr lmState; // Language model state
const TrieNode* lex; // Trie node in the lexicon
const LexiconDecoderState* parent; // Parent hypothesis
int token; // Label of token
int word; // Label of word (-1 if incomplete)
bool prevBlank; // If previous hypothesis is blank (for CTC only)
double amScore; // Accumulated AM score so far
double lmScore; // Accumulated LM score so far
LexiconDecoderState(
const double score,
const LMStatePtr& lmState,
const TrieNode* lex,
const LexiconDecoderState* parent,
const int token,
const int word,
const bool prevBlank = false,
const double amScore = 0,
const double lmScore = 0)
: score(score),
lmState(lmState),
lex(lex),
parent(parent),
token(token),
word(word),
prevBlank(prevBlank),
amScore(amScore),
lmScore(lmScore) {}
LexiconDecoderState()
: score(0.),
lmState(nullptr),
lex(nullptr),
parent(nullptr),
token(-1),
word(-1),
prevBlank(false),
amScore(0.),
lmScore(0.) {}
int compareNoScoreStates(const LexiconDecoderState* node) const {
int lmCmp = lmState->compare(node->lmState);
if (lmCmp != 0) {
return lmCmp > 0 ? 1 : -1;
} else if (lex != node->lex) {
return lex > node->lex ? 1 : -1;
} else if (token != node->token) {
return token > node->token ? 1 : -1;
} else if (prevBlank != node->prevBlank) {
return prevBlank > node->prevBlank ? 1 : -1;
}
return 0;
}
int getWord() const {
return word;
}
bool isComplete() const {
return !parent || parent->word >= 0;
}
};
/**
* Decoder implements a beam search decoder that finds the word transcription
* W maximizing:
*
* AM(W) + lmWeight_ * log(P_{lm}(W)) + wordScore_ * |W_known| + unkScore_ *
* |W_unknown| + silScore_ * |{i| pi_i = <sil>}|
*
* where P_{lm}(W) is the language model score, pi_i is the value for the i-th
* frame in the path leading to W and AM(W) is the (unnormalized) acoustic model
* score of the transcription W. Note that the lexicon is used to limit the
* search space: if unkScore is -inf, all candidate words are generated from
* the lexicon; otherwise <UNK> will be emitted for OOVs.
*/
class LexiconDecoder : public Decoder {
public:
LexiconDecoder(
LexiconDecoderOptions opt,
const TriePtr& lexicon,
const LMPtr& lm,
const int sil,
const int blank,
const int unk,
const std::vector<float>& transitions,
const bool isLmToken)
: opt_(std::move(opt)),
lexicon_(lexicon),
lm_(lm),
sil_(sil),
blank_(blank),
unk_(unk),
transitions_(transitions),
isLmToken_(isLmToken) {}
void decodeBegin() override;
void decodeStep(const float* emissions, int T, int N) override;
void decodeEnd() override;
int nHypothesis() const;
void prune(int lookBack = 0) override;
int nDecodedFramesInBuffer() const override;
DecodeResult getBestHypothesis(int lookBack = 0) const override;
std::vector<DecodeResult> getAllFinalHypothesis() const override;
protected:
LexiconDecoderOptions opt_;
// Lexicon trie to restrict beam-search decoder
TriePtr lexicon_;
LMPtr lm_;
// Index of silence label
int sil_;
// Index of blank label (for CTC)
int blank_;
// Index of unknown word
int unk_;
// matrix of transitions (for ASG criterion)
std::vector<float> transitions_;
// if LM is token-level (operates on the same level as acoustic model)
// or it is word-level (in case of false)
bool isLmToken_;
// All the new hypothesis candidates (may be more than beamSize) proposed
// based on the ones from the previous frame
std::vector<LexiconDecoderState> candidates_;
// This vector is designed for efficiently sorting and merging candidates_:
// instead of moving objects around, we only need to sort pointers
std::vector<LexiconDecoderState*> candidatePtrs_;
// Best candidate score of current frame
double candidatesBestScore_;
// Vector of hypotheses for all the frames so far
std::unordered_map<int, std::vector<LexiconDecoderState>> hyp_;
// These 2 variables are used for online decoding, for hypothesis pruning
int nDecodedFrames_; // Total number of decoded frames.
int nPrunedFrames_; // Total number of pruned frames from hyp_.
};
} // namespace text
} // namespace lib
} // namespace torchaudio
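
As a hedged construction sketch (illustrative, not from this commit): the pieces declared above can be wired together roughly as follows for CTC decoding. The option values and the index arguments are placeholders for the example, not recommended settings.

#include <limits>
#include <memory>
#include "torchaudio/csrc/decoder/src/decoder/LexiconDecoder.h"

namespace {
std::unique_ptr<torchaudio::lib::text::LexiconDecoder> makeToyDecoder(
    const torchaudio::lib::text::TriePtr& lexiconTrie,
    const torchaudio::lib::text::LMPtr& lm,
    int silIdx,     // index of the silence / word-boundary token
    int blankIdx,   // index of the CTC blank token
    int unkWordIdx) // index of <unk> in the word dictionary
{
  using namespace torchaudio::lib::text;
  LexiconDecoderOptions opt;
  opt.beamSize = 50;
  opt.beamSizeToken = 10;
  opt.beamThreshold = 100.0;
  opt.lmWeight = 2.0;
  opt.wordScore = -0.3;
  opt.unkScore = -std::numeric_limits<double>::infinity(); // never emit <unk>
  opt.silScore = 0.0;
  opt.logAdd = false;
  opt.criterionType = CriterionType::CTC;
  // Empty transition matrix: transitions are only used for the ASG criterion.
  return std::make_unique<LexiconDecoder>(
      opt, lexiconTrie, lm, silIdx, blankIdx, unkWordIdx,
      std::vector<float>{}, /* isLmToken */ false);
}
} // namespace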
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include <math.h>
#include <stdlib.h>
#include <iostream>
#include <limits>
#include "torchaudio/csrc/decoder/src/decoder/Trie.h"
namespace torchaudio {
namespace lib {
namespace text {
const double kMinusLogThreshold = -39.14;
const TrieNode* Trie::getRoot() const {
return root_.get();
}
TrieNodePtr Trie::insert(
const std::vector<int>& indices,
int label,
float score) {
TrieNodePtr node = root_;
for (int i = 0; i < indices.size(); i++) {
int idx = indices[i];
if (idx < 0 || idx >= maxChildren_) {
throw std::out_of_range(
"[Trie] Invalid letter index: " + std::to_string(idx));
}
if (node->children.find(idx) == node->children.end()) {
node->children[idx] = std::make_shared<TrieNode>(idx);
}
node = node->children[idx];
}
if (node->labels.size() < kTrieMaxLabel) {
node->labels.push_back(label);
node->scores.push_back(score);
} else {
std::cerr << "[Trie] Trie label number reached limit: " << kTrieMaxLabel
<< "\n";
}
return node;
}
TrieNodePtr Trie::search(const std::vector<int>& indices) {
TrieNodePtr node = root_;
for (auto idx : indices) {
if (idx < 0 || idx >= maxChildren_) {
throw std::out_of_range(
"[Trie] Invalid letter index: " + std::to_string(idx));
}
if (node->children.find(idx) == node->children.end()) {
return nullptr;
}
node = node->children[idx];
}
return node;
}
/* logadd */
double TrieLogAdd(double log_a, double log_b) {
double minusdif;
if (log_a < log_b) {
std::swap(log_a, log_b);
}
minusdif = log_b - log_a;
if (minusdif < kMinusLogThreshold) {
return log_a;
} else {
return log_a + log1p(exp(minusdif));
}
}
void smearNode(TrieNodePtr node, SmearingMode smearMode) {
node->maxScore = -std::numeric_limits<float>::infinity();
for (auto score : node->scores) {
node->maxScore = TrieLogAdd(node->maxScore, score);
}
for (auto child : node->children) {
auto childNode = child.second;
smearNode(childNode, smearMode);
if (smearMode == SmearingMode::LOGADD) {
node->maxScore = TrieLogAdd(node->maxScore, childNode->maxScore);
} else if (
smearMode == SmearingMode::MAX &&
childNode->maxScore > node->maxScore) {
node->maxScore = childNode->maxScore;
}
}
}
void Trie::smear(SmearingMode smearMode) {
if (smearMode != SmearingMode::NONE) {
smearNode(root_, smearMode);
}
}
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <memory>
#include <unordered_map>
#include <vector>
namespace torchaudio {
namespace lib {
namespace text {
constexpr int kTrieMaxLabel = 6;
enum class SmearingMode {
NONE = 0,
MAX = 1,
LOGADD = 2,
};
/**
* TrieNode is the trie node structure in Trie.
*/
struct TrieNode {
explicit TrieNode(int idx)
: children(std::unordered_map<int, std::shared_ptr<TrieNode>>()),
idx(idx),
maxScore(0) {
labels.reserve(kTrieMaxLabel);
scores.reserve(kTrieMaxLabel);
}
// Pointers to the children of a node
std::unordered_map<int, std::shared_ptr<TrieNode>> children;
// Node index
int idx;
// Labels of words that are constructed from the given path. Note that
// `labels` is nonempty only if the current node represents a completed token.
std::vector<int> labels;
// Scores (`scores` should have the same size as `labels`)
std::vector<float> scores;
// Maximum score of all the labels if this node is a leaf,
// otherwise it will be the value after trie smearing.
float maxScore;
};
using TrieNodePtr = std::shared_ptr<TrieNode>;
/**
* Trie is used to store the lexicon of the language model. We use it to limit
* the search space in the decoder and quickly look up scores for a given token
* (completed word) or make predictions for incomplete ones based on smearing.
*/
class Trie {
public:
Trie(int maxChildren, int rootIdx)
: root_(std::make_shared<TrieNode>(rootIdx)), maxChildren_(maxChildren) {}
/* Return the root node pointer */
const TrieNode* getRoot() const;
/* Insert a token into trie with label */
TrieNodePtr insert(const std::vector<int>& indices, int label, float score);
/* Get the labels for a given token */
TrieNodePtr search(const std::vector<int>& indices);
/**
* Smear the trie using the valid labels inserted in the trie so as to get a
* score on each node (incomplete token).
* For example, if smear_mode is MAX, then for node "a" in path "c"->"a", we
* will select the maximum score from all its children like "c"->"a"->"t",
* "c"->"a"->"n", "c"->"a"->"r"->"e" and so on.
* This process is carried out recursively on all the nodes.
*/
void smear(const SmearingMode smear_mode);
private:
TrieNodePtr root_;
int maxChildren_; // The maximum number of children for each node. It is
// usually the number of letters or phonemes.
};
using TriePtr = std::shared_ptr<Trie>;
} // namespace text
} // namespace lib
} // namespace torchaudio
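
A minimal sketch (illustrative, not from this commit) of building and smearing a lexicon trie with the API above; the token-set size, indices, and scores are made up for the example.

#include "torchaudio/csrc/decoder/src/decoder/Trie.h"

namespace {
torchaudio::lib::text::TriePtr buildToyTrie() {
  using namespace torchaudio::lib::text;
  const int nTokens = 28; // assumed token inventory size
  const int silIdx = 27;  // assumed silence index, used as the root
  auto trie = std::make_shared<Trie>(nTokens, silIdx);
  // Insert "cat" and "car" as sequences of token indices (a=0 ... z=25),
  // with word labels 0 and 1 and log-domain LM scores.
  trie->insert({2, 0, 19}, /* label */ 0, /* score */ -1.5f);
  trie->insert({2, 0, 17}, /* label */ 1, /* score */ -2.0f);
  // Propagate scores to prefix nodes so partial words can be scored.
  trie->smear(SmearingMode::MAX);
  return trie;
}
} // namespace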
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
namespace torchaudio {
namespace lib {
namespace text {
// Place holder
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <algorithm>
#include <cmath>
#include <unordered_map>
#include <vector>
#include "torchaudio/csrc/decoder/src/decoder/lm/LM.h"
namespace torchaudio {
namespace lib {
namespace text {
/* ===================== Definitions ===================== */
const double kNegativeInfinity = -std::numeric_limits<double>::infinity();
const int kLookBackLimit = 100;
struct DecodeResult {
double score;
double amScore;
double lmScore;
std::vector<int> words;
std::vector<int> tokens;
explicit DecodeResult(int length = 0)
: score(0), words(length, -1), tokens(length, -1) {}
};
/* ===================== Candidate-related operations ===================== */
template <class DecoderState>
void candidatesReset(
double& candidatesBestScore,
std::vector<DecoderState>& candidates,
std::vector<DecoderState*>& candidatePtrs) {
candidatesBestScore = kNegativeInfinity;
candidates.clear();
candidatePtrs.clear();
}
template <class DecoderState, class... Args>
void candidatesAdd(
std::vector<DecoderState>& candidates,
double& candidatesBestScore,
const double beamThreshold,
const double score,
const Args&... args) {
if (score >= candidatesBestScore) {
candidatesBestScore = score;
}
if (score >= candidatesBestScore - beamThreshold) {
candidates.emplace_back(score, args...);
}
}
template <class DecoderState>
void candidatesStore(
std::vector<DecoderState>& candidates,
std::vector<DecoderState*>& candidatePtrs,
std::vector<DecoderState>& outputs,
const int beamSize,
const double threshold,
const bool logAdd,
const bool returnSorted) {
outputs.clear();
if (candidates.empty()) {
return;
}
/* 1. Select valid candidates */
for (auto& candidate : candidates) {
if (candidate.score >= threshold) {
candidatePtrs.emplace_back(&candidate);
}
}
/* 2. Merge candidates */
std::sort(
candidatePtrs.begin(),
candidatePtrs.end(),
[](const DecoderState* node1, const DecoderState* node2) {
int cmp = node1->compareNoScoreStates(node2);
return cmp == 0 ? node1->score > node2->score : cmp > 0;
});
int nHypAfterMerging = 1;
for (int i = 1; i < candidatePtrs.size(); i++) {
if (candidatePtrs[i]->compareNoScoreStates(
candidatePtrs[nHypAfterMerging - 1]) != 0) {
// Distinct candidate
candidatePtrs[nHypAfterMerging] = candidatePtrs[i];
nHypAfterMerging++;
} else {
// Same candidate
double maxScore = std::max(
candidatePtrs[nHypAfterMerging - 1]->score, candidatePtrs[i]->score);
if (logAdd) {
double minScore = std::min(
candidatePtrs[nHypAfterMerging - 1]->score,
candidatePtrs[i]->score);
candidatePtrs[nHypAfterMerging - 1]->score =
maxScore + std::log1p(std::exp(minScore - maxScore));
} else {
candidatePtrs[nHypAfterMerging - 1]->score = maxScore;
}
}
}
candidatePtrs.resize(nHypAfterMerging);
/* 3. Sort and prune */
auto compareNodeScore = [](const DecoderState* node1,
const DecoderState* node2) {
return node1->score > node2->score;
};
int nValidHyp = candidatePtrs.size();
int finalSize = std::min(nValidHyp, beamSize);
if (!returnSorted && nValidHyp > beamSize) {
std::nth_element(
candidatePtrs.begin(),
candidatePtrs.begin() + finalSize,
candidatePtrs.begin() + nValidHyp,
compareNodeScore);
} else if (returnSorted) {
std::partial_sort(
candidatePtrs.begin(),
candidatePtrs.begin() + finalSize,
candidatePtrs.begin() + nValidHyp,
compareNodeScore);
}
for (int i = 0; i < finalSize; i++) {
outputs.emplace_back(std::move(*candidatePtrs[i]));
}
}
/* ===================== Result-related operations ===================== */
template <class DecoderState>
DecodeResult getHypothesis(const DecoderState* node, const int finalFrame) {
const DecoderState* node_ = node;
if (!node_) {
return DecodeResult();
}
DecodeResult res(finalFrame + 1);
res.score = node_->score;
res.amScore = node_->amScore;
res.lmScore = node_->lmScore;
int i = 0;
while (node_) {
res.words[finalFrame - i] = node_->getWord();
res.tokens[finalFrame - i] = node_->token;
node_ = node_->parent;
i++;
}
return res;
}
template <class DecoderState>
std::vector<DecodeResult> getAllHypothesis(
const std::vector<DecoderState>& finalHyps,
const int finalFrame) {
int nHyp = finalHyps.size();
std::vector<DecodeResult> res(nHyp);
for (int r = 0; r < nHyp; r++) {
const DecoderState* node = &finalHyps[r];
res[r] = getHypothesis(node, finalFrame);
}
return res;
}
template <class DecoderState>
const DecoderState* findBestAncestor(
const std::vector<DecoderState>& finalHyps,
int& lookBack) {
int nHyp = finalHyps.size();
if (nHyp == 0) {
return nullptr;
}
double bestScore = finalHyps.front().score;
const DecoderState* bestNode = finalHyps.data();
for (int r = 1; r < nHyp; r++) {
const DecoderState* node = &finalHyps[r];
if (node->score > bestScore) {
bestScore = node->score;
bestNode = node;
}
}
int n = 0;
while (bestNode && n < lookBack) {
n++;
bestNode = bestNode->parent;
}
const int maxLookBack = lookBack + kLookBackLimit;
while (bestNode) {
// Check for first emitted word.
if (bestNode->isComplete()) {
break;
}
n++;
bestNode = bestNode->parent;
if (n == maxLookBack) {
break;
}
}
lookBack = n;
return bestNode;
}
template <class DecoderState>
void pruneAndNormalize(
std::unordered_map<int, std::vector<DecoderState>>& hypothesis,
const int startFrame,
const int lookBack) {
/* 1. Move things from back of hypothesis to front. */
for (int i = 0; i < hypothesis.size(); i++) {
if (i <= lookBack) {
hypothesis[i].swap(hypothesis[i + startFrame]);
} else {
hypothesis[i].clear();
}
}
/* 2. Avoid further back-tracking */
for (DecoderState& hyp : hypothesis[0]) {
hyp.parent = nullptr;
}
/* 3. Avoid score underflow/overflow. */
double largestScore = hypothesis[lookBack].front().score;
for (int i = 1; i < hypothesis[lookBack].size(); i++) {
if (largestScore < hypothesis[lookBack][i].score) {
largestScore = hypothesis[lookBack][i].score;
}
}
for (int i = 0; i < hypothesis[lookBack].size(); i++) {
hypothesis[lookBack][i].score -= largestScore;
}
}
/* ===================== LM-related operations ===================== */
template <class DecoderState>
void updateLMCache(const LMPtr& lm, std::vector<DecoderState>& hypothesis) {
// For ConvLM update cache
std::vector<LMStatePtr> states;
for (const auto& hyp : hypothesis) {
states.emplace_back(hyp.lmState);
}
lm->updateCache(states);
}
} // namespace text
} // namespace lib
} // namespace torchaudio
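
For reference (an editorial note, not part of this commit): the logAdd branch in candidatesStore, like TrieLogAdd in Trie.cpp, merges two hypothesis scores with the numerically stable identity log(exp(a) + exp(b)) = max(a, b) + log1p(exp(min(a, b) - max(a, b))). A standalone check of that identity:

#include <algorithm>
#include <cassert>
#include <cmath>

static double logAddStable(double a, double b) {
  double hi = std::max(a, b);
  double lo = std::min(a, b);
  return hi + std::log1p(std::exp(lo - hi));
}

int main() {
  // Direct evaluation vs. the stable form used when merging hypotheses.
  double direct = std::log(std::exp(-1.0) + std::exp(-2.0));
  assert(std::abs(logAddStable(-1.0, -2.0) - direct) < 1e-12);
  return 0;
}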
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include "torchaudio/csrc/decoder/src/decoder/lm/KenLM.h"
#include <stdexcept>
#include "kenlm/lm/model.hh"
namespace torchaudio {
namespace lib {
namespace text {
KenLMState::KenLMState() : ken_(std::make_unique<lm::ngram::State>()) {}
KenLM::KenLM(const std::string& path, const Dictionary& usrTknDict) {
// Load LM
model_.reset(lm::ngram::LoadVirtual(path.c_str()));
if (!model_) {
throw std::runtime_error("[KenLM] LM loading failed.");
}
vocab_ = &model_->BaseVocabulary();
if (!vocab_) {
throw std::runtime_error("[KenLM] LM vocabulary loading failed.");
}
// Create index map
usrToLmIdxMap_.resize(usrTknDict.indexSize());
for (int i = 0; i < usrTknDict.indexSize(); i++) {
auto token = usrTknDict.getEntry(i);
int lmIdx = vocab_->Index(token.c_str());
usrToLmIdxMap_[i] = lmIdx;
}
}
LMStatePtr KenLM::start(bool startWithNothing) {
auto outState = std::make_shared<KenLMState>();
if (startWithNothing) {
model_->NullContextWrite(outState->ken());
} else {
model_->BeginSentenceWrite(outState->ken());
}
return outState;
}
std::pair<LMStatePtr, float> KenLM::score(
const LMStatePtr& state,
const int usrTokenIdx) {
if (usrTokenIdx < 0 || usrTokenIdx >= usrToLmIdxMap_.size()) {
throw std::runtime_error(
"[KenLM] Invalid user token index: " + std::to_string(usrTokenIdx));
}
auto inState = std::static_pointer_cast<KenLMState>(state);
auto outState = inState->child<KenLMState>(usrTokenIdx);
float score = model_->BaseScore(
inState->ken(), usrToLmIdxMap_[usrTokenIdx], outState->ken());
return std::make_pair(std::move(outState), score);
}
std::pair<LMStatePtr, float> KenLM::finish(const LMStatePtr& state) {
auto inState = std::static_pointer_cast<KenLMState>(state);
auto outState = inState->child<KenLMState>(-1);
float score =
model_->BaseScore(inState->ken(), vocab_->EndSentence(), outState->ken());
return std::make_pair(std::move(outState), score);
}
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <memory>
#include "torchaudio/csrc/decoder/src/decoder/lm/LM.h"
#include "torchaudio/csrc/decoder/src/dictionary/Dictionary.h"
// Forward declarations to avoid including KenLM headers
namespace lm {
namespace base {
struct Vocabulary;
struct Model;
} // namespace base
namespace ngram {
struct State;
} // namespace ngram
} // namespace lm
namespace torchaudio {
namespace lib {
namespace text {
/**
* KenLMState is a state object from KenLM, which contains context length,
* indices and compare functions
* https://github.com/kpu/kenlm/blob/master/lm/state.hh.
*/
struct KenLMState : LMState {
KenLMState();
std::unique_ptr<lm::ngram::State> ken_;
lm::ngram::State* ken() {
return ken_.get();
}
};
/**
* KenLM extends LM by using the toolkit https://kheafield.com/code/kenlm/.
*/
class KenLM : public LM {
public:
KenLM(const std::string& path, const Dictionary& usrTknDict);
LMStatePtr start(bool startWithNothing) override;
std::pair<LMStatePtr, float> score(
const LMStatePtr& state,
const int usrTokenIdx) override;
std::pair<LMStatePtr, float> finish(const LMStatePtr& state) override;
private:
std::shared_ptr<lm::base::Model> model_;
const lm::base::Vocabulary* vocab_;
};
using KenLMPtr = std::shared_ptr<KenLM>;
} // namespace text
} // namespace lib
} // namespace torchaudio
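
A usage sketch (illustrative, not from this commit) of the KenLM wrapper above: score a word sequence by walking states from start() through score() to finish(). The model path is hypothetical and the word dictionary is built ad hoc.

#include <string>
#include <vector>
#include "torchaudio/csrc/decoder/src/decoder/lm/KenLM.h"
#include "torchaudio/csrc/decoder/src/dictionary/Dictionary.h"

namespace {
double scoreSentence(const std::vector<std::string>& words) {
  using namespace torchaudio::lib::text;
  Dictionary wordDict; // user-token dictionary: word -> index
  for (const auto& w : words) {
    if (!wordDict.contains(w)) {
      wordDict.addEntry(w);
    }
  }
  KenLM lm("/path/to/lm.bin", wordDict); // hypothetical KenLM binary
  LMStatePtr state = lm.start(/* startWithNothing */ false);
  double total = 0.0;
  for (const auto& w : words) {
    auto next = lm.score(state, wordDict.getIndex(w));
    state = next.first;
    total += next.second;
  }
  total += lm.finish(state).second; // add the end-of-sentence score
  return total;
}
} // namespace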
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <cstring>
#include <memory>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torchaudio {
namespace lib {
namespace text {
struct LMState {
std::unordered_map<int, std::shared_ptr<LMState>> children;
template <typename T>
std::shared_ptr<T> child(int usrIdx) {
auto s = children.find(usrIdx);
if (s == children.end()) {
auto state = std::make_shared<T>();
children[usrIdx] = state;
return state;
} else {
return std::static_pointer_cast<T>(s->second);
}
}
/* Compare two language model states. */
int compare(const std::shared_ptr<LMState>& state) const {
LMState* inState = state.get();
if (!state) {
throw std::runtime_error("a state is null");
}
if (this == inState) {
return 0;
} else if (this < inState) {
return -1;
} else {
return 1;
}
};
};
/**
* LMStatePtr is a shared LMState* tracking LM states generated during decoding.
*/
using LMStatePtr = std::shared_ptr<LMState>;
/**
* LM is a thin wrapper for language models. We abstract several common methods
* here which can be shared by KenLM, ConvLM, RNNLM, etc.
*/
class LM {
public:
/* Initialize or reset language model */
virtual LMStatePtr start(bool startWithNothing) = 0;
/**
* Query the language model given input language model state and a specific
* token, return a new language model state and score.
*/
virtual std::pair<LMStatePtr, float> score(
const LMStatePtr& state,
const int usrTokenIdx) = 0;
/* Query the language model and finish decoding. */
virtual std::pair<LMStatePtr, float> finish(const LMStatePtr& state) = 0;
/* Update LM caches (optional) given a bunch of new states generated */
virtual void updateCache(std::vector<LMStatePtr> stateIdices) {}
virtual ~LM() = default;
protected:
/* Map indices from acoustic model to LM for each valid token. */
std::vector<int> usrToLmIdxMap_;
};
using LMPtr = std::shared_ptr<LM>;
} // namespace text
} // namespace lib
} // namespace torchaudio
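
For context (an illustrative example, not part of this commit): any language model can be plugged in by subclassing LM. A trivial zero-scoring LM, sometimes handy for lexicon-only decoding, might look like the sketch below; the class name and behavior are assumptions of this example.

#include "torchaudio/csrc/decoder/src/decoder/lm/LM.h"

namespace {
using torchaudio::lib::text::LM;
using torchaudio::lib::text::LMState;
using torchaudio::lib::text::LMStatePtr;

struct ZeroLM : LM {
  LMStatePtr start(bool /* startWithNothing */) override {
    return std::make_shared<LMState>();
  }
  std::pair<LMStatePtr, float> score(
      const LMStatePtr& state,
      const int usrTokenIdx) override {
    // Reuse the child-state cache on LMState so that equal histories
    // share the same state object and compare() as equal.
    return {state->child<LMState>(usrTokenIdx), 0.0f};
  }
  std::pair<LMStatePtr, float> finish(const LMStatePtr& state) override {
    return {state, 0.0f};
  }
};
} // namespace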
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
namespace torchaudio {
namespace lib {
namespace text {
constexpr const char* kUnkToken = "<unk>";
constexpr const char* kEosToken = "</s>";
constexpr const char* kPadToken = "<pad>";
constexpr const char* kMaskToken = "<mask>";
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include <iostream>
#include <stdexcept>
#include "torchaudio/csrc/decoder/src/dictionary/Dictionary.h"
#include "torchaudio/csrc/decoder/src/dictionary/String.h"
#include "torchaudio/csrc/decoder/src/dictionary/System.h"
#include "torchaudio/csrc/decoder/src/dictionary/Utils.h"
namespace torchaudio {
namespace lib {
namespace text {
Dictionary::Dictionary(std::istream& stream) {
createFromStream(stream);
}
Dictionary::Dictionary(const std::string& filename) {
std::ifstream stream = createInputStream(filename);
createFromStream(stream);
}
void Dictionary::createFromStream(std::istream& stream) {
if (!stream) {
throw std::runtime_error("Unable to open dictionary input stream.");
}
std::string line;
while (std::getline(stream, line)) {
if (line.empty()) {
continue;
}
auto tkns = splitOnWhitespace(line, true);
auto idx = idx2entry_.size();
// All entries on the same line map to the same index
for (const auto& tkn : tkns) {
addEntry(tkn, idx);
}
}
if (!isContiguous()) {
throw std::runtime_error("Invalid dictionary format - not contiguous");
}
}
void Dictionary::addEntry(const std::string& entry, int idx) {
if (entry2idx_.find(entry) != entry2idx_.end()) {
throw std::invalid_argument(
"Duplicate entry name in dictionary '" + entry + "'");
}
entry2idx_[entry] = idx;
if (idx2entry_.find(idx) == idx2entry_.end()) {
idx2entry_[idx] = entry;
}
}
void Dictionary::addEntry(const std::string& entry) {
// Check if the entry already exists in the dictionary
if (entry2idx_.find(entry) != entry2idx_.end()) {
throw std::invalid_argument(
"Duplicate entry in dictionary '" + entry + "'");
}
int idx = idx2entry_.size();
// Find first available index.
while (idx2entry_.find(idx) != idx2entry_.end()) {
++idx;
}
addEntry(entry, idx);
}
std::string Dictionary::getEntry(int idx) const {
auto iter = idx2entry_.find(idx);
if (iter == idx2entry_.end()) {
throw std::invalid_argument(
"Unknown index in dictionary '" + std::to_string(idx) + "'");
}
return iter->second;
}
void Dictionary::setDefaultIndex(int idx) {
defaultIndex_ = idx;
}
int Dictionary::getIndex(const std::string& entry) const {
auto iter = entry2idx_.find(entry);
if (iter == entry2idx_.end()) {
if (defaultIndex_ < 0) {
throw std::invalid_argument(
"Unknown entry in dictionary: '" + entry + "'");
} else {
return defaultIndex_;
}
}
return iter->second;
}
bool Dictionary::contains(const std::string& entry) const {
auto iter = entry2idx_.find(entry);
if (iter == entry2idx_.end()) {
return false;
}
return true;
}
size_t Dictionary::entrySize() const {
return entry2idx_.size();
}
bool Dictionary::isContiguous() const {
for (size_t i = 0; i < indexSize(); ++i) {
if (idx2entry_.find(i) == idx2entry_.end()) {
return false;
}
}
for (const auto& tknidx : entry2idx_) {
if (idx2entry_.find(tknidx.second) == idx2entry_.end()) {
return false;
}
}
return true;
}
std::vector<int> Dictionary::mapEntriesToIndices(
const std::vector<std::string>& entries) const {
std::vector<int> indices;
indices.reserve(entries.size());
for (const auto& tkn : entries) {
indices.emplace_back(getIndex(tkn));
}
return indices;
}
std::vector<std::string> Dictionary::mapIndicesToEntries(
const std::vector<int>& indices) const {
std::vector<std::string> entries;
entries.reserve(indices.size());
for (const auto& idx : indices) {
entries.emplace_back(getEntry(idx));
}
return entries;
}
size_t Dictionary::indexSize() const {
return idx2entry_.size();
}
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <istream>
#include <string>
#include <unordered_map>
#include <vector>
namespace torchaudio {
namespace lib {
namespace text {
// A simple dictionary class which holds a bidirectional map
// entry (strings) <--> integer indices. Not thread-safe !
class Dictionary {
public:
// Creates an empty dictionary
Dictionary() {}
explicit Dictionary(std::istream& stream);
explicit Dictionary(const std::string& filename);
size_t entrySize() const;
size_t indexSize() const;
void addEntry(const std::string& entry, int idx);
void addEntry(const std::string& entry);
std::string getEntry(int idx) const;
void setDefaultIndex(int idx);
int getIndex(const std::string& entry) const;
bool contains(const std::string& entry) const;
// checks if all the indices are contiguous
bool isContiguous() const;
std::vector<int> mapEntriesToIndices(
const std::vector<std::string>& entries) const;
std::vector<std::string> mapIndicesToEntries(
const std::vector<int>& indices) const;
private:
// Creates a dictionary from an input stream
void createFromStream(std::istream& stream);
std::unordered_map<std::string, int> entry2idx_;
std::unordered_map<int, std::string> idx2entry_;
int defaultIndex_ = -1;
};
typedef std::unordered_map<int, Dictionary> DictionaryMap;
} // namespace text
} // namespace lib
} // namespace torchaudio
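
A small usage sketch (illustrative, not from this commit) of the Dictionary class above, built from an in-memory stream; the token inventory is made up.

#include <sstream>
#include "torchaudio/csrc/decoder/src/dictionary/Dictionary.h"

namespace {
torchaudio::lib::text::Dictionary makeToyTokenDict() {
  using torchaudio::lib::text::Dictionary;
  // One entry per line; entries placed on the same line share one index.
  std::istringstream tokens("a\nb\nc\n|\n<unk>");
  Dictionary dict(tokens);
  dict.setDefaultIndex(dict.getIndex("<unk>")); // map OOV entries to <unk>
  // e.g. {"c", "a", "b"} -> {2, 0, 1}
  std::vector<int> indices = dict.mapEntriesToIndices({"c", "a", "b"});
  (void)indices;
  return dict;
}
} // namespace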
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include "torchaudio/csrc/decoder/src/dictionary/String.h"
#include <sys/types.h>
#include <array>
#include <cstdlib>
#include <ctime>
#include <functional>
static constexpr const char* kSpaceChars = "\t\n\v\f\r ";
namespace torchaudio {
namespace lib {
std::string trim(const std::string& str) {
auto i = str.find_first_not_of(kSpaceChars);
if (i == std::string::npos) {
return "";
}
auto j = str.find_last_not_of(kSpaceChars);
if (j == std::string::npos || i > j) {
return "";
}
return str.substr(i, j - i + 1);
}
void replaceAll(
std::string& str,
const std::string& from,
const std::string& repl) {
if (from.empty()) {
return;
}
size_t pos = 0;
while ((pos = str.find(from, pos)) != std::string::npos) {
str.replace(pos, from.length(), repl);
pos += repl.length();
}
}
bool startsWith(const std::string& input, const std::string& pattern) {
return (input.find(pattern) == 0);
}
bool endsWith(const std::string& input, const std::string& pattern) {
if (pattern.size() > input.size()) {
return false;
}
return std::equal(pattern.rbegin(), pattern.rend(), input.rbegin());
}
template <bool Any, typename Delim>
static std::vector<std::string> splitImpl(
const Delim& delim,
std::string::size_type delimSize,
const std::string& input,
bool ignoreEmpty = false) {
std::vector<std::string> result;
std::string::size_type i = 0;
while (true) {
auto j = Any ? input.find_first_of(delim, i) : input.find(delim, i);
if (j == std::string::npos) {
break;
}
if (!(ignoreEmpty && i == j)) {
result.emplace_back(input.begin() + i, input.begin() + j);
}
i = j + delimSize;
}
if (!(ignoreEmpty && i == input.size())) {
result.emplace_back(input.begin() + i, input.end());
}
return result;
}
std::vector<std::string> split(
char delim,
const std::string& input,
bool ignoreEmpty) {
return splitImpl<false>(delim, 1, input, ignoreEmpty);
}
std::vector<std::string> split(
const std::string& delim,
const std::string& input,
bool ignoreEmpty) {
if (delim.empty()) {
throw std::invalid_argument("delimiter is empty string");
}
return splitImpl<false>(delim, delim.size(), input, ignoreEmpty);
}
std::vector<std::string> splitOnAnyOf(
const std::string& delim,
const std::string& input,
bool ignoreEmpty) {
return splitImpl<true>(delim, 1, input, ignoreEmpty);
}
std::vector<std::string> splitOnWhitespace(
const std::string& input,
bool ignoreEmpty) {
return splitOnAnyOf(kSpaceChars, input, ignoreEmpty);
}
std::string join(
const std::string& delim,
const std::vector<std::string>& vec) {
return join(delim, vec.begin(), vec.end());
}
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <errno.h>
#include <algorithm>
#include <chrono>
#include <cstring>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
namespace torchaudio {
namespace lib {
// ============================ Types and Templates ============================
template <typename It>
using DecayDereference =
typename std::decay<decltype(*std::declval<It>())>::type;
template <typename S, typename T>
using EnableIfSame = typename std::enable_if<std::is_same<S, T>::value>::type;
// ================================== Functions
// ==================================
std::string trim(const std::string& str);
void replaceAll(
std::string& str,
const std::string& from,
const std::string& repl);
bool startsWith(const std::string& input, const std::string& pattern);
bool endsWith(const std::string& input, const std::string& pattern);
std::vector<std::string> split(
char delim,
const std::string& input,
bool ignoreEmpty = false);
std::vector<std::string> split(
const std::string& delim,
const std::string& input,
bool ignoreEmpty = false);
std::vector<std::string> splitOnAnyOf(
const std::string& delim,
const std::string& input,
bool ignoreEmpty = false);
std::vector<std::string> splitOnWhitespace(
const std::string& input,
bool ignoreEmpty = false);
/**
* Join a vector of `std::string` inserting `delim` in between.
*/
std::string join(const std::string& delim, const std::vector<std::string>& vec);
/**
* Join a range of `std::string` specified by iterators.
*/
template <
typename FwdIt,
typename = EnableIfSame<DecayDereference<FwdIt>, std::string>>
std::string join(const std::string& delim, FwdIt begin, FwdIt end) {
if (begin == end) {
return "";
}
size_t totalSize = begin->size();
for (auto it = std::next(begin); it != end; ++it) {
totalSize += delim.size() + it->size();
}
std::string result;
result.reserve(totalSize);
result.append(*begin);
for (auto it = std::next(begin); it != end; ++it) {
result.append(delim);
result.append(*it);
}
return result;
}
/**
* Create an output string using a `printf`-style format string and arguments.
* Safer than `sprintf` which is vulnerable to buffer overflow.
*/
template <class... Args>
std::string format(const char* fmt, Args&&... args) {
auto res = std::snprintf(nullptr, 0, fmt, std::forward<Args>(args)...);
if (res < 0) {
throw std::runtime_error(std::strerror(errno));
}
std::string buf(res, '\0');
// the size here is fine -- it's legal to write '\0' to buf[res]
auto res2 = std::snprintf(&buf[0], res + 1, fmt, std::forward<Args>(args)...);
if (res2 < 0) {
throw std::runtime_error(std::strerror(errno));
}
if (res2 != res) {
throw std::runtime_error(
"The size of the formated string is not equal to what it is expected.");
}
return buf;
}
/**
* Dedup the elements in a vector.
*/
template <class T>
void dedup(std::vector<T>& in) {
if (in.empty()) {
return;
}
auto it = std::unique(in.begin(), in.end());
in.resize(std::distance(in.begin(), it));
}
} // namespace lib
} // namespace torchaudio
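
A brief usage sketch (illustrative, not from this commit) of the string helpers declared above, applied to a toy lexicon line:

#include <string>
#include <vector>
#include "torchaudio/csrc/decoder/src/dictionary/String.h"

namespace {
void stringHelpersDemo() {
  using namespace torchaudio::lib;
  std::string line = "  hello h e l l o \n";
  // splitOnWhitespace with ignoreEmpty=true drops the empty fields produced
  // by leading/trailing/repeated whitespace.
  std::vector<std::string> fields = splitOnWhitespace(trim(line), true);
  // fields == {"hello", "h", "e", "l", "l", "o"}
  std::string spelling = join(
      " ", std::vector<std::string>(fields.begin() + 1, fields.end()));
  // spelling == "h e l l o"
  (void)spelling;
}
} // namespace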
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include "torchaudio/csrc/decoder/src/dictionary/System.h"
#include <glob.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <array>
#include <cstdlib>
#include <ctime>
#include <functional>
#ifdef _WIN32
#include <windows.h>
#else
#include <unistd.h>
#endif
#include "torchaudio/csrc/decoder/src/dictionary/String.h"
namespace torchaudio {
namespace lib {
size_t getProcessId() {
#ifdef _WIN32
return GetCurrentProcessId();
#else
return ::getpid();
#endif
}
size_t getThreadId() {
#ifdef _WIN32
return GetCurrentThreadId();
#else
return std::hash<std::thread::id>()(std::this_thread::get_id());
#endif
}
std::string pathSeperator() {
#ifdef _WIN32
return "\\";
#else
return "/";
#endif
}
std::string pathsConcat(const std::string& p1, const std::string& p2) {
if (!p1.empty() && p1[p1.length() - 1] != pathSeperator()[0]) {
return (
trim(p1) + pathSeperator() + trim(p2)); // Need to add a path separator
} else {
return (trim(p1) + trim(p2));
}
}
namespace {
/**
* @path contains directories separated by the path separator.
* Returns a vector with the directories in the original order.
* Special case: a vector with a single entry containing the input is returned
* when path is one of the following: empty, ".", ".." and "/".
*/
std::vector<std::string> getDirsOnPath(const std::string& path) {
const std::string trimPath = trim(path);
if (trimPath.empty() || trimPath == pathSeperator() || trimPath == "." ||
trimPath == "..") {
return {trimPath};
}
const std::vector<std::string> tokens = split(pathSeperator(), trimPath);
std::vector<std::string> dirs;
for (const std::string& token : tokens) {
const std::string dir = trim(token);
if (!dir.empty()) {
dirs.push_back(dir);
}
}
return dirs;
}
} // namespace
std::string dirname(const std::string& path) {
std::vector<std::string> dirsOnPath = getDirsOnPath(path);
if (dirsOnPath.size() < 2) {
return ".";
} else {
dirsOnPath.pop_back();
const std::string root =
((trim(path))[0] == pathSeperator()[0]) ? pathSeperator() : "";
return root + join(pathSeperator(), dirsOnPath);
}
}
std::string basename(const std::string& path) {
std::vector<std::string> dirsOnPath = getDirsOnPath(path);
if (dirsOnPath.empty()) {
return "";
} else {
return dirsOnPath.back();
}
}
bool dirExists(const std::string& path) {
struct stat info;
if (stat(path.c_str(), &info) != 0) {
return false;
} else if (info.st_mode & S_IFDIR) {
return true;
} else {
return false;
}
}
void dirCreate(const std::string& path) {
if (dirExists(path)) {
return;
}
mode_t nMode = 0755;
int nError = 0;
#ifdef _WIN32
nError = _mkdir(path.c_str());
#else
nError = mkdir(path.c_str(), nMode);
#endif
if (nError != 0) {
throw std::runtime_error(
std::string() + "Unable to create directory - " + path);
}
}
void dirCreateRecursive(const std::string& path) {
if (dirExists(path)) {
return;
}
std::vector<std::string> dirsOnPath = getDirsOnPath(path);
std::string pathFromStart;
if (path[0] == pathSeperator()[0]) {
pathFromStart = pathSeperator();
}
for (std::string& dir : dirsOnPath) {
if (pathFromStart.empty()) {
pathFromStart = dir;
} else {
pathFromStart = pathsConcat(pathFromStart, dir);
}
if (!dirExists(pathFromStart)) {
dirCreate(pathFromStart);
}
}
}
bool fileExists(const std::string& path) {
std::ifstream fs(path, std::ifstream::in);
return fs.good();
}
std::string getEnvVar(
const std::string& key,
const std::string& dflt /*= "" */) {
char* val = getenv(key.c_str());
return val ? std::string(val) : dflt;
}
std::string getCurrentDate() {
time_t now = time(nullptr);
struct tm tmbuf;
struct tm* tstruct;
tstruct = localtime_r(&now, &tmbuf);
std::array<char, 80> buf;
strftime(buf.data(), buf.size(), "%Y-%m-%d", tstruct);
return std::string(buf.data());
}
std::string getCurrentTime() {
time_t now = time(nullptr);
struct tm tmbuf;
struct tm* tstruct;
tstruct = localtime_r(&now, &tmbuf);
std::array<char, 80> buf;
strftime(buf.data(), buf.size(), "%X", tstruct);
return std::string(buf.data());
}
std::string getTmpPath(const std::string& filename) {
std::string tmpDir = "/tmp";
auto getTmpDir = [&tmpDir](const std::string& env) {
char* dir = std::getenv(env.c_str());
if (dir != nullptr) {
tmpDir = std::string(dir);
}
};
getTmpDir("TMPDIR");
getTmpDir("TEMP");
getTmpDir("TMP");
return tmpDir + "/fl_tmp_" + getEnvVar("USER", "unknown") + "_" + filename;
}
std::vector<std::string> getFileContent(const std::string& file) {
std::vector<std::string> data;
std::ifstream in = createInputStream(file);
std::string str;
while (std::getline(in, str)) {
data.emplace_back(str);
}
in.close();
return data;
}
std::vector<std::string> fileGlob(const std::string& pat) {
glob_t result;
glob(pat.c_str(), GLOB_TILDE, nullptr, &result);
std::vector<std::string> ret;
for (unsigned int i = 0; i < result.gl_pathc; ++i) {
ret.push_back(std::string(result.gl_pathv[i]));
}
globfree(&result);
return ret;
}
std::ifstream createInputStream(const std::string& filename) {
std::ifstream file(filename);
if (!file.is_open()) {
throw std::runtime_error("Failed to open file for reading: " + filename);
}
return file;
}
std::ofstream createOutputStream(
const std::string& filename,
std::ios_base::openmode mode) {
std::ofstream file(filename, mode);
if (!file.is_open()) {
throw std::runtime_error("Failed to open file for writing: " + filename);
}
return file;
}
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <chrono>
#include <fstream>
#include <string>
#include <thread>
#include <type_traits>
#include <vector>
namespace torchaudio {
namespace lib {
size_t getProcessId();
size_t getThreadId();
std::string pathsConcat(const std::string& p1, const std::string& p2);
std::string pathSeperator();
std::string dirname(const std::string& path);
std::string basename(const std::string& path);
bool dirExists(const std::string& path);
void dirCreate(const std::string& path);
void dirCreateRecursive(const std::string& path);
bool fileExists(const std::string& path);
std::string getEnvVar(const std::string& key, const std::string& dflt = "");
std::string getCurrentDate();
std::string getCurrentTime();
std::string getTmpPath(const std::string& filename);
std::vector<std::string> getFileContent(const std::string& file);
std::vector<std::string> fileGlob(const std::string& pat);
std::ifstream createInputStream(const std::string& filename);
std::ofstream createOutputStream(
const std::string& filename,
std::ios_base::openmode mode = std::ios_base::out);
/**
* Calls `f(args...)` repeatedly, retrying if an exception is thrown.
* Supports sleeps between retries, with duration starting at `initial` and
* multiplying by `factor` each retry. At most `maxIters` calls are made.
*/
template <class Fn, class... Args>
typename std::result_of<Fn(Args...)>::type retryWithBackoff(
std::chrono::duration<double> initial,
double factor,
int64_t maxIters,
Fn&& f,
Args&&... args) {
if (!(initial.count() >= 0.0)) {
throw std::invalid_argument("retryWithBackoff: bad initial");
} else if (!(factor >= 0.0)) {
throw std::invalid_argument("retryWithBackoff: bad factor");
} else if (maxIters <= 0) {
throw std::invalid_argument("retryWithBackoff: bad maxIters");
}
auto sleepSecs = initial.count();
for (int64_t i = 0; i < maxIters; ++i) {
try {
return f(std::forward<Args>(args)...);
} catch (...) {
if (i >= maxIters - 1) {
throw;
}
}
if (sleepSecs > 0.0) {
/* sleep override */
std::this_thread::sleep_for(
std::chrono::duration<double>(std::min(1e7, sleepSecs)));
}
sleepSecs *= factor;
}
throw std::logic_error("retryWithBackoff: hit unreachable");
}
} // namespace lib
} // namespace torchaudio
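
A usage sketch (illustrative, not from this commit) of retryWithBackoff declared above, retrying a file read a few times with a growing sleep; the path and retry settings are placeholders.

#include <chrono>
#include <string>
#include <vector>
#include "torchaudio/csrc/decoder/src/dictionary/System.h"

namespace {
std::vector<std::string> readWithRetries(const std::string& path) {
  using namespace torchaudio::lib;
  // Up to 3 attempts, sleeping ~100 ms, then ~200 ms, between failures.
  return retryWithBackoff(
      std::chrono::milliseconds(100),
      /* factor */ 2.0,
      /* maxIters */ 3,
      [](const std::string& p) { return getFileContent(p); },
      path);
}
} // namespace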
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include "torchaudio/csrc/decoder/src/dictionary/Utils.h"
#include "torchaudio/csrc/decoder/src/dictionary/Defines.h"
#include "torchaudio/csrc/decoder/src/dictionary/String.h"
#include "torchaudio/csrc/decoder/src/dictionary/System.h"
namespace torchaudio {
namespace lib {
namespace text {
Dictionary createWordDict(const LexiconMap& lexicon) {
Dictionary dict;
for (const auto& it : lexicon) {
dict.addEntry(it.first);
}
dict.setDefaultIndex(dict.getIndex(kUnkToken));
return dict;
}
LexiconMap loadWords(const std::string& filename, int maxWords) {
LexiconMap lexicon;
std::string line;
std::ifstream infile = createInputStream(filename);
// Add at most `maxWords` words into the lexicon.
// If `maxWords` is negative then no limit is applied.
while (maxWords != lexicon.size() && std::getline(infile, line)) {
// Parse the line into two strings: word and spelling.
auto fields = splitOnWhitespace(line, true);
if (fields.size() < 2) {
throw std::runtime_error("[loadWords] Invalid line: " + line);
}
const std::string& word = fields[0];
std::vector<std::string> spelling(fields.size() - 1);
std::copy(fields.begin() + 1, fields.end(), spelling.begin());
// Add the word into the dictionary.
if (lexicon.find(word) == lexicon.end()) {
lexicon[word] = {};
}
// Add the current spelling of the word to the list of spellings.
lexicon[word].push_back(spelling);
}
// Insert unknown word.
lexicon[kUnkToken] = {};
return lexicon;
}
std::vector<std::string> splitWrd(const std::string& word) {
std::vector<std::string> tokens;
tokens.reserve(word.size());
int len = word.length();
for (int i = 0; i < len;) {
auto c = static_cast<unsigned char>(word[i]);
int curTknBytes = -1;
// UTF-8 checks, works for ASCII automatically
if ((c & 0x80) == 0) {
curTknBytes = 1;
} else if ((c & 0xE0) == 0xC0) {
curTknBytes = 2;
} else if ((c & 0xF0) == 0xE0) {
curTknBytes = 3;
} else if ((c & 0xF8) == 0xF0) {
curTknBytes = 4;
}
if (curTknBytes == -1 || i + curTknBytes > len) {
throw std::runtime_error("splitWrd: invalid UTF-8 : " + word);
}
tokens.emplace_back(word.begin() + i, word.begin() + i + curTknBytes);
i += curTknBytes;
}
return tokens;
}
std::vector<int> packReplabels(
const std::vector<int>& tokens,
const Dictionary& dict,
int maxReps) {
if (tokens.empty() || maxReps <= 0) {
return tokens;
}
std::vector<int> replabelValueToIdx(maxReps + 1);
for (int i = 1; i <= maxReps; ++i) {
replabelValueToIdx[i] = dict.getIndex("<" + std::to_string(i) + ">");
}
std::vector<int> result;
int prevToken = -1;
int numReps = 0;
for (int token : tokens) {
if (token == prevToken && numReps < maxReps) {
numReps++;
} else {
if (numReps > 0) {
result.push_back(replabelValueToIdx[numReps]);
numReps = 0;
}
result.push_back(token);
prevToken = token;
}
}
if (numReps > 0) {
result.push_back(replabelValueToIdx[numReps]);
}
return result;
}
std::vector<int> unpackReplabels(
const std::vector<int>& tokens,
const Dictionary& dict,
int maxReps) {
if (tokens.empty() || maxReps <= 0) {
return tokens;
}
std::unordered_map<int, int> replabelIdxToValue;
for (int i = 1; i <= maxReps; ++i) {
replabelIdxToValue.emplace(dict.getIndex("<" + std::to_string(i) + ">"), i);
}
std::vector<int> result;
int prevToken = -1;
for (int token : tokens) {
auto it = replabelIdxToValue.find(token);
if (it == replabelIdxToValue.end()) {
result.push_back(token);
prevToken = token;
} else if (prevToken != -1) {
result.insert(result.end(), it->second, prevToken);
prevToken = -1;
}
}
return result;
}
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "torchaudio/csrc/decoder/src/dictionary/Dictionary.h"
namespace torchaudio {
namespace lib {
namespace text {
using LexiconMap =
std::unordered_map<std::string, std::vector<std::vector<std::string>>>;
Dictionary createWordDict(const LexiconMap& lexicon);
LexiconMap loadWords(const std::string& filename, int maxWords = -1);
// Split a word into tokens, e.g. abc -> {"a", "b", "c"}.
// Works with ASCII and UTF-8 encodings.
std::vector<std::string> splitWrd(const std::string& word);
/**
* Pack a token sequence by replacing consecutive repeats with replabels,
* e.g. "abbccc" -> "ab1c2". The tokens "1", "2", ..., `to_string(maxReps)`
* must already be in `dict`.
*/
std::vector<int> packReplabels(
const std::vector<int>& tokens,
const Dictionary& dict,
int maxReps);
/**
* Unpack a token sequence by replacing replabels with repeated tokens,
* e.g. "ab1c2" -> "abbccc". The tokens "1", "2", ..., `to_string(maxReps)`
* must already be in `dict`.
*/
std::vector<int> unpackReplabels(
const std::vector<int>& tokens,
const Dictionary& dict,
int maxReps);
} // namespace text
} // namespace lib
} // namespace torchaudio
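
Finally, a hedged sketch (illustrative, not from this commit) of the lexicon helpers above: load a lexicon file whose lines have the form "<word> <token> <token> ..." and build the word dictionary from it. The file path is hypothetical.

#include "torchaudio/csrc/decoder/src/dictionary/Utils.h"

namespace {
torchaudio::lib::text::Dictionary loadToyWordDict() {
  using namespace torchaudio::lib::text;
  // e.g. a lexicon line: "hello h e l l o |"
  LexiconMap lexicon = loadWords("/path/to/lexicon.txt"); // hypothetical path
  // createWordDict adds every lexicon word (plus <unk>) and maps OOVs to it.
  return createWordDict(lexicon);
}
} // namespace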