"vscode:/vscode.git/clone" did not exist on "fa7e2c3049fd0ec38502f5a847dc624a96363b34"
Commit 764b3a75 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add new model

parents
// fstext/pre-determinize.h
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_PRE_DETERMINIZE_H_
#define KALDI_FSTEXT_PRE_DETERMINIZE_H_
#include <fst/fst-decl.h>
#include <fst/fstlib.h>
#include <algorithm>
#include <map>
#include <set>
#include <string>
#include <vector>
#include "base/kaldi-common.h"
namespace fst {
/* PreDeterminize inserts extra symbols on the input side of an FST as necessary
to ensure that, after epsilon removal, it will be compactly determinizable by
the determinize* algorithm. By compactly determinizable we mean that no
original FST state is represented in more than one determinized state).
Caution: this code is now only used in testing.
The new symbols start from the value "first_new_symbol", which should be
higher than the largest-numbered symbol currently in the FST. The new
symbols added are put in the array syms_out, which should be empty at start.
*/
// See the comment block above for the full contract; "fst" is modified
// in place and the ids of the inserted symbols are appended to *syms_out.
template <class Arc, class Int>
void PreDeterminize(MutableFst<Arc> *fst, typename Arc::Label first_new_symbol,
                    std::vector<Int> *syms_out);

/* CreateNewSymbols is a helper function used inside PreDeterminize, and is also
   useful when you need to add a number of extra symbols to a different
   vocabulary from the one modified by PreDeterminize. */
// Adds "nSym" new symbols to "inputSymTable" and outputs their ids to
// *syms_out; names presumably derive from "prefix" -- see the definition
// in pre-determinize-inl.h for the exact naming scheme.
template <class Label>
void CreateNewSymbols(SymbolTable *inputSymTable, int nSym, std::string prefix,
                      std::vector<Label> *syms_out);
/** AddSelfLoops is a function you will probably want to use alongside
PreDeterminize, to add self-loops to any FSTs that you compose on the left
hand side of the one modified by PreDeterminize.
This function inserts loops with "special symbols" [e.g. \#0, \#1] into an
FST. This is done at each final state and each state with non-epsilon output
symbols on at least one arc out of it. This is to ensure that these symbols,
when inserted into the input side of an FST we will compose with on the
right, can "pass through" this FST.
At input, isyms and osyms must be vectors of the same size n, corresponding
to symbols that currently do not exist in 'fst'. For each state in n that
has non-epsilon symbols on the output side of arcs leaving it, or which is a
final state, this function inserts n self-loops with unit weight and one of
the n pairs of symbols on its input and output.
*/
// See the comment block above: inserts self-loops carrying the
// (isyms[i], osyms[i]) pairs at final states and at states with
// non-epsilon output symbols on outgoing arcs.
template <class Arc>
void AddSelfLoops(MutableFst<Arc> *fst,
                  const std::vector<typename Arc::Label> &isyms,
                  const std::vector<typename Arc::Label> &osyms);

/* DeleteSymbols replaces any instances of symbols in the vector symsIn,
   appearing on the input side, with epsilon. */
/* It returns the number of instances of symbols deleted. */
// Note: symsIn is taken by value here; the definition in
// pre-determinize-inl.h must keep the matching signature.
template <class Arc>
int64 DeleteISymbols(MutableFst<Arc> *fst,
                     std::vector<typename Arc::Label> symsIn);

/* CreateSuperFinal takes an FST, and creates an equivalent FST with a single
   final state with no transitions out and unit final weight, by inserting
   epsilon transitions as necessary. */
// Returns a StateId -- presumably the id of the new unique final state;
// confirm against the definition in pre-determinize-inl.h.
template <class Arc>
typename Arc::StateId CreateSuperFinal(MutableFst<Arc> *fst);
} // end namespace fst
#include "fstext/pre-determinize-inl.h"
#endif // KALDI_FSTEXT_PRE_DETERMINIZE_H_
// fstext/remove-eps-local-inl.h
// Copyright 2009-2011 Microsoft Corporation
// 2014 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
#define KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
#include <vector>
namespace fst {
// Default reweighting combiner: just the semiring's own Plus()
// operation on Weight (found via argument-dependent lookup).
template <class Weight>
struct ReweightPlusDefault {
  inline Weight operator()(const Weight &a, const Weight &b) {
    const Weight sum = Plus(a, b);
    return sum;
  }
};
// Reweighting combiner that adds tropical weights as if they were
// log-semiring weights: convert each value to LogWeight, log-add,
// then convert the result back to TropicalWeight.
struct ReweightPlusLogArc {
  inline TropicalWeight operator()(const TropicalWeight &a,
                                   const TropicalWeight &b) {
    const LogWeight log_a(a.Value());
    const LogWeight log_b(b.Value());
    const LogWeight log_sum = Plus(log_a, log_b);
    return TropicalWeight(log_sum.Value());
  }
};
// RemoveEpsLocalClass does the work of RemoveEpsLocal() (and, via the
// ReweightPlus template argument, of RemoveEpsLocalSpecial()); all the
// processing happens in the constructor.  Arcs scheduled for deletion
// are redirected to a dummy non-coaccessible state (non_coacc_state_)
// and removed, together with that state, by the final Connect() call.
// "ReweightPlus" is the addition operation used when reweighting to
// preserve stochasticity; by default it is Plus() in the Arc semiring.
template <class Arc,
          class ReweightPlus = ReweightPlusDefault<typename Arc::Weight> >
class RemoveEpsLocalClass {
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;

 public:
  explicit RemoveEpsLocalClass(MutableFst<Arc> *fst) : fst_(fst) {
    if (fst_->Start() == kNoStateId) return;  // empty.
    non_coacc_state_ = fst_->AddState();
    InitNumArcs();
    StateId num_states = fst_->NumStates();
    // NumArcs(s) is deliberately re-read on each iteration: RemoveEps()
    // may append new arcs to state s, and those must be processed too.
    for (StateId s = 0; s < num_states; s++)
      for (size_t pos = 0; pos < fst_->NumArcs(s); pos++) RemoveEps(s, pos);
    assert(CheckNumArcs());
    Connect(fst);  // remove inaccessible states.
  }

 private:
  MutableFst<Arc> *fst_;     // not owned.
  StateId non_coacc_state_;  // use this to delete arcs: make it nextstate
  std::vector<StateId> num_arcs_in_;   // The number of arcs into the state,
                                       // plus one if it's the start state.
  std::vector<StateId> num_arcs_out_;  // The number of arcs out of the state,
                                       // plus one if it's a final state.
  ReweightPlus reweight_plus_;

  // Returns true (and sets *c to the combined arc) if arcs a and b,
  // taken in sequence, can be merged into one arc.  This is possible
  // iff at most one of the two has a nonzero ilabel and at most one
  // has a nonzero olabel.
  bool CanCombineArcs(const Arc &a, const Arc &b, Arc *c) {
    if (a.ilabel != 0 && b.ilabel != 0) return false;
    if (a.olabel != 0 && b.olabel != 0) return false;
    c->weight = Times(a.weight, b.weight);
    c->ilabel = (a.ilabel != 0 ? a.ilabel : b.ilabel);
    c->olabel = (a.olabel != 0 ? a.olabel : b.olabel);
    c->nextstate = b.nextstate;
    return true;
  }

  // Returns true (and sets *final_prob_out) if arc a followed by
  // final-probability final_prob can be collapsed into a
  // final-probability alone: possible iff a is epsilon on both sides.
  static bool CanCombineFinal(const Arc &a, Weight final_prob,
                              Weight *final_prob_out) {
    if (a.ilabel != 0 || a.olabel != 0) {
      return false;
    } else {
      *final_prob_out = Times(a.weight, final_prob);
      return true;
    }
  }

  void InitNumArcs() {  // init num transitions in/out of each state.
    StateId num_states = fst_->NumStates();
    num_arcs_in_.resize(num_states);
    num_arcs_out_.resize(num_states);
    num_arcs_in_[fst_->Start()]++;  // count start as trans in.
    for (StateId s = 0; s < num_states; s++) {
      if (fst_->Final(s) != Weight::Zero())
        num_arcs_out_[s]++;  // count final as transition.
      for (ArcIterator<MutableFst<Arc> > aiter(*fst_, s); !aiter.Done();
           aiter.Next()) {
        num_arcs_in_[aiter.Value().nextstate]++;
        num_arcs_out_[s]++;
      }
    }
  }

  // Verifies that num_arcs_in_/num_arcs_out_ are consistent with the
  // final FST, skipping arcs redirected to non_coacc_state_.
  bool CheckNumArcs() {  // check num arcs in/out of each state, at end. Debug.
    num_arcs_in_[fst_->Start()]--;  // count start as trans in.
    StateId num_states = fst_->NumStates();
    for (StateId s = 0; s < num_states; s++) {
      if (s == non_coacc_state_) continue;
      if (fst_->Final(s) != Weight::Zero())
        num_arcs_out_[s]--;  // count final as transition.
      for (ArcIterator<MutableFst<Arc> > aiter(*fst_, s); !aiter.Done();
           aiter.Next()) {
        if (aiter.Value().nextstate == non_coacc_state_) continue;
        num_arcs_in_[aiter.Value().nextstate]--;
        num_arcs_out_[s]--;
      }
    }
    for (StateId s = 0; s < num_states; s++) {
      assert(num_arcs_in_[s] == 0);
      assert(num_arcs_out_[s] == 0);
    }
    return true;  // always does this. so we can assert it w/o warnings.
  }

  // Copies the arc at position "pos" out of state "s" into *arc.
  inline void GetArc(StateId s, size_t pos, Arc *arc) const {
    ArcIterator<MutableFst<Arc> > aiter(*fst_, s);
    aiter.Seek(pos);
    *arc = aiter.Value();
  }

  // Overwrites the arc at position "pos" out of state "s" with "arc".
  inline void SetArc(StateId s, size_t pos, const Arc &arc) {
    MutableArcIterator<MutableFst<Arc> > aiter(fst_, s);
    aiter.Seek(pos);
    aiter.SetValue(arc);
  }

  void Reweight(StateId s, size_t pos, Weight reweight) {
    // Reweight is called from RemoveEpsPattern1; it is a step we
    // do to preserve stochasticity.  This function multiplies the
    // arc at (s, pos) by reweight and divides all the arcs [+final-prob]
    // out of the next state by the same.  This is only valid if
    // the next state has only one arc in and is not the start state.
    assert(reweight != Weight::Zero());
    MutableArcIterator<MutableFst<Arc> > aiter(fst_, s);
    aiter.Seek(pos);
    Arc arc = aiter.Value();
    assert(num_arcs_in_[arc.nextstate] == 1);
    arc.weight = Times(arc.weight, reweight);
    aiter.SetValue(arc);
    // Divide every surviving arc out of nextstate by "reweight";
    // arcs already redirected to non_coacc_state_ are logically deleted.
    for (MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, arc.nextstate);
         !aiter_next.Done(); aiter_next.Next()) {
      Arc nextarc = aiter_next.Value();
      if (nextarc.nextstate != non_coacc_state_) {
        nextarc.weight = Divide(nextarc.weight, reweight, DIVIDE_LEFT);
        aiter_next.SetValue(nextarc);
      }
    }
    // Also divide the final-probability of nextstate, if any.
    Weight final = fst_->Final(arc.nextstate);
    if (final != Weight::Zero()) {
      fst_->SetFinal(arc.nextstate, Divide(final, reweight, DIVIDE_LEFT));
    }
  }

  // RemoveEpsPattern1 applies where this arc, which is not a
  // self-loop, enters a state which has only one input transition
  // [and is not the start state], and has multiple output
  // transitions [counting being the final-state as a final-transition].
  void RemoveEpsPattern1(StateId s, size_t pos, Arc arc) {
    const StateId nextstate = arc.nextstate;
    Weight total_removed = Weight::Zero(),
           total_kept = Weight::Zero();  // totals out of nextstate.
    std::vector<Arc> arcs_to_add;        // to add to state s.
    // Try to fold each arc out of nextstate into "arc"; successes are
    // queued in arcs_to_add (added to s later, to avoid invalidating
    // the iterator) and the originals redirected to non_coacc_state_.
    for (MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, nextstate);
         !aiter_next.Done(); aiter_next.Next()) {
      Arc nextarc = aiter_next.Value();
      if (nextarc.nextstate == non_coacc_state_) continue;  // deleted.
      Arc combined;
      if (CanCombineArcs(arc, nextarc, &combined)) {
        total_removed = reweight_plus_(total_removed, nextarc.weight);
        num_arcs_out_[nextstate]--;
        num_arcs_in_[nextarc.nextstate]--;
        nextarc.nextstate = non_coacc_state_;
        aiter_next.SetValue(nextarc);
        arcs_to_add.push_back(combined);
      } else {
        total_kept = reweight_plus_(total_kept, nextarc.weight);
      }
    }
    {  // now final-state.
      Weight next_final = fst_->Final(nextstate);
      if (next_final != Weight::Zero()) {
        Weight new_final;
        if (CanCombineFinal(arc, next_final, &new_final)) {
          total_removed = reweight_plus_(total_removed, next_final);
          if (fst_->Final(s) == Weight::Zero())
            num_arcs_out_[s]++;  // final is counted as arc.
          fst_->SetFinal(s, Plus(fst_->Final(s), new_final));
          num_arcs_out_[nextstate]--;
          fst_->SetFinal(nextstate, Weight::Zero());
        } else {
          total_kept = reweight_plus_(total_kept, next_final);
        }
      }
    }
    if (total_removed != Weight::Zero()) {  // did something...
      if (total_kept == Weight::Zero()) {   // removed everything: remove arc.
        num_arcs_out_[s]--;
        num_arcs_in_[arc.nextstate]--;
        arc.nextstate = non_coacc_state_;
        SetArc(s, pos, arc);
      } else {
        // Have to reweight: scale "arc" down by the fraction of mass
        // kept at nextstate, and scale what remains at nextstate up.
        Weight total = reweight_plus_(total_removed, total_kept);
        Weight reweight = Divide(total_kept, total, DIVIDE_LEFT);  // <=1
        Reweight(s, pos, reweight);
      }
    }
    // Now add the arcs we were going to add.
    for (size_t i = 0; i < arcs_to_add.size(); i++) {
      num_arcs_out_[s]++;
      num_arcs_in_[arcs_to_add[i].nextstate]++;
      fst_->AddArc(s, arcs_to_add[i]);
    }
  }

  void RemoveEpsPattern2(StateId s, size_t pos, Arc arc) {
    // Pattern 2 is where "nextstate" has only one arc out, counting
    // being-the-final-state as an arc, but possibly multiple arcs in.
    // Also, nextstate != s.
    const StateId nextstate = arc.nextstate;
    bool can_delete_next = (num_arcs_in_[nextstate] == 1);  // if
    // we combine, can delete the corresponding out-arc/final-prob
    // of nextstate.
    bool delete_arc = false;  // set to true if this arc to be deleted.
    Weight next_final = fst_->Final(arc.nextstate);
    if (next_final !=
        Weight::Zero()) {  // nextstate has no actual arcs out, only final-prob.
      Weight new_final;
      if (CanCombineFinal(arc, next_final, &new_final)) {
        if (fst_->Final(s) == Weight::Zero())
          num_arcs_out_[s]++;  // final is counted as arc.
        fst_->SetFinal(s, Plus(fst_->Final(s), new_final));
        delete_arc = true;  // will delete "arc".
        if (can_delete_next) {
          num_arcs_out_[nextstate]--;
          fst_->SetFinal(nextstate, Weight::Zero());
        }
      }
    } else {  // has an arc but no final prob.
      // Skip over arcs already marked deleted; at least one real arc
      // must remain since num_arcs_out_[nextstate] == 1.
      MutableArcIterator<MutableFst<Arc> > aiter_next(fst_, nextstate);
      assert(!aiter_next.Done());
      while (aiter_next.Value().nextstate == non_coacc_state_) {
        aiter_next.Next();
        assert(!aiter_next.Done());
      }
      // now aiter_next points to a real arc out of nextstate.
      Arc nextarc = aiter_next.Value();
      Arc combined;
      if (CanCombineArcs(arc, nextarc, &combined)) {
        delete_arc = true;
        if (can_delete_next) {  // do it before we invalidate iterators
          num_arcs_out_[nextstate]--;
          num_arcs_in_[nextarc.nextstate]--;
          nextarc.nextstate = non_coacc_state_;
          aiter_next.SetValue(nextarc);
        }
        num_arcs_out_[s]++;
        num_arcs_in_[combined.nextstate]++;
        fst_->AddArc(s, combined);
      }
    }
    if (delete_arc) {
      num_arcs_out_[s]--;
      num_arcs_in_[nextstate]--;
      arc.nextstate = non_coacc_state_;
      SetArc(s, pos, arc);
    }
  }

  void RemoveEps(StateId s, size_t pos) {
    // Tries to do local epsilon-removal for arc sequences starting with this
    // arc, dispatching on the in/out degree of the arc's destination.
    Arc arc;
    GetArc(s, pos, &arc);
    StateId nextstate = arc.nextstate;
    if (nextstate == non_coacc_state_) return;  // deleted arc.
    if (nextstate == s) return;  // don't handle self-loops: too complex.
    if (num_arcs_in_[nextstate] == 1 && num_arcs_out_[nextstate] > 1) {
      RemoveEpsPattern1(s, pos, arc);
    } else if (num_arcs_out_[nextstate] == 1) {
      RemoveEpsPattern2(s, pos, arc);
    }
  }
};
// Public entry point: all the processing happens inside the
// constructor of the helper class.
template <class Arc>
void RemoveEpsLocal(MutableFst<Arc> *fst) {
  RemoveEpsLocalClass<Arc> remover(fst);
}
// As RemoveEpsLocal, but when reweighting it combines weights with
// log-semiring addition so stochasticity is preserved in the LogArc
// sense (see the comments in remove-eps-local.h).
// Must be "inline": this is a non-template function defined in a
// header, and the matching declaration in remove-eps-local.h is
// declared inline; without it, including this -inl.h directly from
// more than one translation unit causes multiple-definition errors.
inline void RemoveEpsLocalSpecial(MutableFst<StdArc> *fst) {
  // work gets done in initializer.
  RemoveEpsLocalClass<StdArc, ReweightPlusLogArc> c(fst);
}
} // end namespace fst.
#endif // KALDI_FSTEXT_REMOVE_EPS_LOCAL_INL_H_
// fstext/remove-eps-local.h
// Copyright 2009-2011 Microsoft Corporation
// 2014 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_REMOVE_EPS_LOCAL_H_
#define KALDI_FSTEXT_REMOVE_EPS_LOCAL_H_
#include <fst/fst-decl.h>
#include <fst/fstlib.h>
namespace fst {
/// RemoveEpsLocal remove some (but not necessarily all) epsilons in an FST,
/// using an algorithm that is guaranteed to never increase the number of arcs
/// in the FST (and will also never increase the number of states). The
/// algorithm is not optimal but is reasonably clever. It does not just remove
/// epsilon arcs; it also combines pairs of input-epsilon and output-epsilon
/// arcs
/// into one.
/// The algorithm preserves equivalence and stochasticity in the given semiring.
/// If you want to preserve stochasticity in a different semiring (e.g. log),
/// then use RemoveEpsLocalSpecial, which only works for StdArc but which
/// preserves stochasticity, where possible (*) in the LogArc sense. The reason
/// that we can't just cast to a different semiring is that in that case we
/// would no longer be able to guarantee equivalence in the original semiring
/// (this arises from what happens when we combine identical arcs).
/// (*) by "where possible".. there are situations where we wouldn't be able to
/// preserve stochasticity in the LogArc sense while maintaining equivalence in
/// the StdArc sense, so in these situations we maintain equivalence.
// Definitions for both functions are in remove-eps-local-inl.h,
// included at the bottom of this header.
template <class Arc>
void RemoveEpsLocal(MutableFst<Arc> *fst);

/// As RemoveEpsLocal but takes care to preserve stochasticity
/// when cast to LogArc.
inline void RemoveEpsLocalSpecial(MutableFst<StdArc> *fst);
} // namespace fst
#include "fstext/remove-eps-local-inl.h"
#endif // KALDI_FSTEXT_REMOVE_EPS_LOCAL_H_
// fstext/table-matcher.h
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_FSTEXT_TABLE_MATCHER_H_
#define KALDI_FSTEXT_TABLE_MATCHER_H_
#include <fst/fst-decl.h>
#include <fst/fstlib.h>
#include <memory>
#include <vector>
namespace fst {
/// TableMatcher is a matcher specialized for the case where the output
/// side of the left FST always has either all-epsilons coming out of
/// a state, or a majority of the symbol table. Therefore we can
/// either store nothing (for the all-epsilon case) or store a lookup
/// table from Labels to arc offsets. Since the TableMatcher has to
/// iterate over all arcs in each left-hand state the first time it sees
/// it, this matcher type is not efficient if you compose with
/// something very small on the right-- unless you do it multiple
/// times and keep the matcher around. To do this requires using the
/// most advanced form of ComposeFst in Compose.h, that initializes
/// with ComposeFstImplOptions.
// Configuration for TableMatcher / TableCompose.
struct TableMatcherOptions {
  // We construct the per-state table only if it would be at least this full.
  float table_ratio = 0.25;
  // States with fewer arcs than this never get a table.
  int min_table_size = 4;
  TableMatcherOptions() {}
};
// Introducing an "impl" class for TableMatcher because
// we need to do a shallow copy of the Matcher for when
// we want to cache tables for multiple compositions.
template <class F, class BackoffMatcher = SortedMatcher<F> >
class TableMatcherImpl : public MatcherBase<typename F::Arc> {
 public:
  typedef F FST;
  typedef typename F::Arc Arc;
  typedef typename Arc::Label Label;
  typedef typename Arc::StateId StateId;
  typedef StateId
      ArcId;  // Use this type to store arc offsets [it's actually size_t
  // in the Seek function of ArcIterator, but StateId should be big enough].
  typedef typename Arc::Weight Weight;

 public:
  // Constructs a matcher on the input (MATCH_INPUT) or output
  // (MATCH_OUTPUT) side of "fst"; the FST must be label-sorted on that
  // side (asserted below).  Per-state label->arc-offset tables are
  // built lazily in SetState(); states whose table would be too small
  // or too sparse fall back to backoff_matcher_.
  TableMatcherImpl(const FST &fst, MatchType match_type,
                   const TableMatcherOptions &opts = TableMatcherOptions())
      : match_type_(match_type),
        fst_(fst.Copy()),
        loop_(match_type == MATCH_INPUT
                  ? Arc(kNoLabel, 0, Weight::One(), kNoStateId)
                  : Arc(0, kNoLabel, Weight::One(), kNoStateId)),
        aiter_(NULL),
        s_(kNoStateId),
        opts_(opts),
        backoff_matcher_(fst, match_type) {
    assert(opts_.min_table_size > 0);
    if (match_type == MATCH_INPUT)
      assert(fst_->Properties(kILabelSorted, true) == kILabelSorted);
    else if (match_type == MATCH_OUTPUT)
      assert(fst_->Properties(kOLabelSorted, true) == kOLabelSorted);
    else
      assert(0 && "Invalid FST properties");
  }

  virtual const FST &GetFst() const { return *fst_; }

  virtual ~TableMatcherImpl() {
    // "empty" is a sentinel address marking states for which we decided
    // not to build a table; entries equal to it must not be deleted.
    // NOTE(review): forming NULL + 1 is technically undefined behavior
    // in standard C++ (it works on common platforms); a dedicated
    // static sentinel object would be cleaner.
    std::vector<ArcId> *const empty =
        ((std::vector<ArcId> *)(NULL)) + 1;  // special marker.
    for (size_t i = 0; i < tables_.size(); i++) {
      if (tables_[i] != NULL && tables_[i] != empty) delete tables_[i];
    }
    delete aiter_;
    delete fst_;
  }

  virtual MatchType Type(bool test) const { return match_type_; }

  // Positions the matcher at state s.  First visit to a state decides,
  // once and for all, whether it gets a lookup table (dense enough per
  // opts_.table_ratio, at least opts_.min_table_size arcs) or uses the
  // backoff matcher (recorded via the "empty" sentinel).
  void SetState(StateId s) {
    if (aiter_) {
      delete aiter_;
      aiter_ = NULL;
    }
    if (match_type_ == MATCH_NONE) LOG(FATAL) << "TableMatcher: bad match type";
    s_ = s;
    std::vector<ArcId> *const empty =
        ((std::vector<ArcId> *)(NULL)) + 1;  // special marker.
    if (static_cast<size_t>(s) >= tables_.size()) {
      assert(s >= 0);
      tables_.resize(s + 1, NULL);
    }
    std::vector<ArcId> *&this_table_ = tables_[s];  // note: ref to ptr.
    if (this_table_ == empty) {
      backoff_matcher_.SetState(s);
      return;
    } else if (this_table_ == NULL) {  // NULL means has not been set.
      ArcId num_arcs = fst_->NumArcs(s);
      if (num_arcs == 0 || num_arcs < opts_.min_table_size) {
        this_table_ = empty;
        backoff_matcher_.SetState(s);
        return;
      }
      ArcIterator<FST> aiter(*fst_, s);
      aiter.SetFlags(
          kArcNoCache |
              (match_type_ == MATCH_OUTPUT ? kArcOLabelValue : kArcILabelValue),
          kArcNoCache | kArcValueFlags);
      // the statement above says: "Don't cache stuff; and I only need the
      // ilabel/olabel to be computed".
      // Since the FST is label-sorted on the matching side, the last
      // arc carries the highest label.
      aiter.Seek(num_arcs - 1);
      Label highest_label =
          (match_type_ == MATCH_OUTPUT ? aiter.Value().olabel
                                       : aiter.Value().ilabel);
      if ((highest_label + 1) * opts_.table_ratio > num_arcs) {
        this_table_ = empty;
        backoff_matcher_.SetState(s);
        return;  // table would be too sparse.
      }
      // OK, now we are creating the table.
      this_table_ = new std::vector<ArcId>(highest_label + 1, kNoStateId);
      ArcId pos = 0;
      for (aiter.Seek(0); !aiter.Done(); aiter.Next(), pos++) {
        Label label = (match_type_ == MATCH_OUTPUT ? aiter.Value().olabel
                                                   : aiter.Value().ilabel);
        assert(static_cast<size_t>(label) <=
               static_cast<size_t>(highest_label));  // also checks >= 0.
        if ((*this_table_)[label] == kNoStateId) (*this_table_)[label] = pos;
        // set this_table_[label] to first position where arc has this
        // label.
      }
    }
    // At this point in the code, this_table_ != NULL and != empty.
    aiter_ = new ArcIterator<FST>(*fst_, s);
    aiter_->SetFlags(kArcNoCache,
                     kArcNoCache);  // don't need to cache arcs as may only
    // need a small subset.
    loop_.nextstate = s;
    // aiter_ = NULL;
    // backoff_matcher_.SetState(s);
  }

  // Seeks to the first arc whose label (on the matching side) equals
  // match_label; returns true if one exists, or if label 0 / kNoLabel
  // matched the implicit self-loop.
  bool Find(Label match_label) {
    if (!aiter_) {
      return backoff_matcher_.Find(match_label);
    } else {
      match_label_ = match_label;
      current_loop_ = (match_label == 0);
      // kNoLabel means the implicit loop on the other FST --
      // matches real epsilons but not the self-loop.
      match_label_ = (match_label_ == kNoLabel ? 0 : match_label_);
      if (static_cast<size_t>(match_label_) < tables_[s_]->size() &&
          (*(tables_[s_]))[match_label_] != kNoStateId) {
        aiter_->Seek((*(tables_[s_]))[match_label_]);  // label exists.
        return true;
      }
      return current_loop_;
    }
  }

  const Arc &Value() const {
    if (aiter_)
      return current_loop_ ? loop_ : aiter_->Value();
    else
      return backoff_matcher_.Value();
  }

  void Next() {
    if (aiter_) {
      if (current_loop_)
        current_loop_ = false;
      else
        aiter_->Next();
    } else {
      backoff_matcher_.Next();
    }
  }

  bool Done() const {
    if (aiter_ != NULL) {
      if (current_loop_) return false;
      if (aiter_->Done()) return true;
      Label label = (match_type_ == MATCH_OUTPUT ? aiter_->Value().olabel
                                                 : aiter_->Value().ilabel);
      return (label != match_label_);
    } else {
      return backoff_matcher_.Done();
    }
  }

  // Non-const overload of Value(); same behavior as the const version.
  const Arc &Value() {
    if (aiter_ != NULL) {
      return (current_loop_ ? loop_ : aiter_->Value());
    } else {
      return backoff_matcher_.Value();
    }
  }

  virtual TableMatcherImpl<FST> *Copy(bool safe = false) const {
    assert(0);  // shouldn't be called.  This is not a "real" matcher,
    // although we derive from MatcherBase for convenience.
    return NULL;
  }

  virtual uint64 Properties(uint64 props) const {
    return props;
  }  // simple matcher that does
  // not change its FST, so properties are properties of FST it is applied to

 private:
  // MatcherBase virtual hooks; they forward to the public methods above.
  virtual void SetState_(StateId s) { SetState(s); }
  virtual bool Find_(Label label) { return Find(label); }
  virtual bool Done_() const { return Done(); }
  virtual const Arc &Value_() const { return Value(); }
  virtual void Next_() { Next(); }

  MatchType match_type_;
  FST *fst_;                 // owned copy of the FST we match on.
  bool current_loop_;        // true while positioned on the implicit self-loop.
  Label match_label_;        // label requested by the most recent Find().
  Arc loop_;                 // the implicit self-loop arc.
  ArcIterator<FST> *aiter_;  // owned; NULL when backing off for current state.
  StateId s_;                // current state.
  std::vector<std::vector<ArcId> *> tables_;  // per-state tables; NULL means
                                              // unvisited, NULL+1 means
                                              // "deliberately no table".
  TableMatcherOptions opts_;
  BackoffMatcher backoff_matcher_;
};
// TableMatcher is the public wrapper around TableMatcherImpl; copies of
// a TableMatcher share a single impl through a shared_ptr, so the
// lazily-built per-state tables are shared between compositions.
template <class F, class BackoffMatcher = SortedMatcher<F> >
class TableMatcher : public MatcherBase<typename F::Arc> {
 public:
  typedef F FST;
  typedef typename F::Arc Arc;
  typedef typename Arc::Label Label;
  typedef typename Arc::StateId StateId;
  typedef StateId
      ArcId;  // Use this type to store arc offsets [it's actually size_t
  // in the Seek function of ArcIterator, but StateId should be big enough].
  typedef typename Arc::Weight Weight;
  typedef TableMatcherImpl<F, BackoffMatcher> Impl;

  TableMatcher(const FST &fst, MatchType match_type,
               const TableMatcherOptions &opts = TableMatcherOptions())
      : impl_(std::make_shared<Impl>(fst, match_type, opts)) {}

  // Copy constructor: shallow copy sharing the impl.  "Safe" (i.e.
  // thread-safe) copies are not supported.
  TableMatcher(const TableMatcher<FST, BackoffMatcher> &matcher,
               bool safe = false)
      : impl_(matcher.impl_) {
    if (safe == true) {
      LOG(FATAL) << "TableMatcher: Safe copy not supported";
    }
  }

  // All of the following simply forward to the shared impl.
  virtual const FST &GetFst() const { return impl_->GetFst(); }
  virtual MatchType Type(bool test) const { return impl_->Type(test); }
  void SetState(StateId s) { return impl_->SetState(s); }
  bool Find(Label match_label) { return impl_->Find(match_label); }
  const Arc &Value() const { return impl_->Value(); }
  void Next() { return impl_->Next(); }
  bool Done() const { return impl_->Done(); }
  const Arc &Value() { return impl_->Value(); }

  virtual TableMatcher<FST, BackoffMatcher> *Copy(bool safe = false) const {
    return new TableMatcher<FST, BackoffMatcher>(*this, safe);
  }

  virtual uint64 Properties(uint64 props) const {
    return impl_->Properties(props);
  }  // simple matcher that does
  // not change its FST, so properties are properties of FST it is applied to

 private:
  std::shared_ptr<Impl> impl_;

  // MatcherBase virtual hooks; they forward to the shared impl.
  virtual void SetState_(StateId s) { impl_->SetState(s); }
  virtual bool Find_(Label label) { return impl_->Find(label); }
  virtual bool Done_() const { return impl_->Done(); }
  virtual const Arc &Value_() const { return impl_->Value(); }
  virtual void Next_() { impl_->Next(); }

  TableMatcher &operator=(const TableMatcher &) = delete;
};
struct TableComposeOptions : public TableMatcherOptions {
bool connect; // Connect output
ComposeFilter filter_type; // Which pre-defined filter to use
MatchType table_match_type;
explicit TableComposeOptions(const TableMatcherOptions &mo, bool c = true,
ComposeFilter ft = SEQUENCE_FILTER,
MatchType tms = MATCH_OUTPUT)
: TableMatcherOptions(mo),
connect(c),
filter_type(ft),
table_match_type(tms) {}
TableComposeOptions()
: connect(true),
filter_type(SEQUENCE_FILTER),
table_match_type(MATCH_OUTPUT) {}
};
template <class Arc>
void TableCompose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
MutableFst<Arc> *ofst,
const TableComposeOptions &opts = TableComposeOptions()) {
typedef Fst<Arc> F;
CacheOptions nopts;
nopts.gc_limit = 0; // Cache only the last state for fastest copy.
if (opts.table_match_type == MATCH_OUTPUT) {
// ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
ComposeFstImplOptions<TableMatcher<F>, SortedMatcher<F> > impl_opts(nopts);
impl_opts.matcher1 = new TableMatcher<F>(ifst1, MATCH_OUTPUT, opts);
*ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
} else {
assert(opts.table_match_type == MATCH_INPUT);
// ComposeFstImplOptions templated on matcher for fst1, matcher for fst2.
ComposeFstImplOptions<SortedMatcher<F>, TableMatcher<F> > impl_opts(nopts);
impl_opts.matcher2 = new TableMatcher<F>(ifst2, MATCH_INPUT, opts);
*ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
}
if (opts.connect) Connect(ofst);
}
/// TableComposeCache lets us do multiple compositions while caching the same
/// matcher.
/// TableComposeCache lets us do multiple compositions while caching the
/// same matcher (and therefore its lazily-built per-state tables).
template <class F>
struct TableComposeCache {
  TableMatcher<F> *matcher;  // Owned; allocated lazily by TableCompose().
  TableComposeOptions opts;
  explicit TableComposeCache(
      const TableComposeOptions &opts = TableComposeOptions())
      : matcher(NULL), opts(opts) {}
  // This struct owns "matcher" and deletes it in the destructor, so
  // copying would double-delete; forbid copy and assignment.
  TableComposeCache(const TableComposeCache &) = delete;
  TableComposeCache &operator=(const TableComposeCache &) = delete;
  ~TableComposeCache() { delete (matcher); }
};
template <class Arc>
void TableCompose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
MutableFst<Arc> *ofst, TableComposeCache<Fst<Arc> > *cache) {
typedef Fst<Arc> F;
assert(cache != NULL);
CacheOptions nopts;
nopts.gc_limit = 0; // Cache only the last state for fastest copy.
if (cache->opts.table_match_type == MATCH_OUTPUT) {
ComposeFstImplOptions<TableMatcher<F>, SortedMatcher<F> > impl_opts(nopts);
if (cache->matcher == NULL)
cache->matcher = new TableMatcher<F>(ifst1, MATCH_OUTPUT, cache->opts);
impl_opts.matcher1 = cache->matcher->Copy(); // not passing "safe": may not
// be thread-safe-- anway I don't understand this part.
*ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
} else {
assert(cache->opts.table_match_type == MATCH_INPUT);
ComposeFstImplOptions<SortedMatcher<F>, TableMatcher<F> > impl_opts(nopts);
if (cache->matcher == NULL)
cache->matcher = new TableMatcher<F>(ifst2, MATCH_INPUT, cache->opts);
impl_opts.matcher2 = cache->matcher->Copy();
*ofst = ComposeFst<Arc>(ifst1, ifst2, impl_opts);
}
if (cache->opts.connect) Connect(ofst);
}
} // namespace fst
#endif // KALDI_FSTEXT_TABLE_MATCHER_H_
// itf/decodable-itf.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
// Mirko Hannemann; Go Vivace Inc.;
// 2013 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_ITF_DECODABLE_ITF_H_
#define KALDI_ITF_DECODABLE_ITF_H_ 1
#include "base/kaldi-common.h"
namespace kaldi {
/// @ingroup Interfaces
/// @{
/**
DecodableInterface provides a link between the (acoustic-modeling and
feature-processing) code and the decoder. The idea is to make this
interface as small as possible, and to make it as agnostic as possible about
the form of the acoustic model (e.g. don't assume the probabilities are a
function of just a vector of floats), and about the decoder (e.g. don't
assume it accesses frames in strict left-to-right order). For normal
models, without on-line operation, the "decodable" sub-class will just be a
wrapper around a matrix of features and an acoustic model, and it will
answer the question 'what is the acoustic likelihood for this index and this
frame?'.
For online decoding, where the features are coming in in real time, it is
important to understand the IsLastFrame() and NumFramesReady() functions.
There are two ways these are used: the old online-decoding code, in ../online/,
and the new online-decoding code, in ../online2/. In the old online-decoding
code, the decoder would do:
\code{.cc}
for (int frame = 0; !decodable.IsLastFrame(frame); frame++) {
// Process this frame
}
\endcode
and the call to IsLastFrame would block if the features had not arrived yet.
The decodable object would have to know when to terminate the decoding. This
online-decoding mode is still supported, it is what happens when you call, for
example, LatticeFasterDecoder::Decode().
We realized that this "blocking" mode of decoding is not very convenient
because it forces the program to be multi-threaded and makes it complex to
control endpointing. In the "new" decoding code, you don't call (for example)
LatticeFasterDecoder::Decode(), you call LatticeFasterDecoder::InitDecoding(),
and then each time you get more features, you provide them to the decodable
object, and you call LatticeFasterDecoder::AdvanceDecoding(), which does
something like this:
\code{.cc}
while (num_frames_decoded_ < decodable.NumFramesReady()) {
// Decode one more frame [increments num_frames_decoded_]
}
\endcode
So the decodable object never has IsLastFrame() called. For decoding where
you are starting with a matrix of features, the NumFramesReady() function will
always just return the number of frames in the file, and IsLastFrame() will
return true for the last frame.
For truly online decoding, the "old" online decodable objects in ../online/
have a "blocking" IsLastFrame() and will crash if you call NumFramesReady().
The "new" online decodable objects in ../online2/ return the number of frames
currently accessible if you call NumFramesReady(). You will likely not need
to call IsLastFrame(), but we implement it to only return true for the last
frame of the file once we've decided to terminate decoding.
*/
class DecodableInterface {
 public:
  /// Returns the log likelihood, which will be negated in the decoder.
  /// The "frame" starts from zero.  You should verify that
  /// NumFramesReady() > frame before calling this.
  virtual BaseFloat LogLikelihood(int32 frame, int32 index) = 0;

  /// Returns true if this is the last frame.  Frames are zero-based, so the
  /// first frame is zero.  IsLastFrame(-1) will return false, unless the file
  /// is empty (which is a case that I'm not sure all the code will handle, so
  /// be careful).  Caution: the behavior of this function in an online setting
  /// is being changed somewhat.  In future it may return false in cases where
  /// we haven't yet decided to terminate decoding, but later true if we decide
  /// to terminate decoding.  The plan in future is to rely more on
  /// NumFramesReady(), and in future, IsLastFrame() would always return false
  /// in an online-decoding setting, and would only return true in a
  /// decoding-from-matrix setting where we want to allow the last delta or LDA
  /// features to be flushed out for compatibility with the baseline setup.
  virtual bool IsLastFrame(int32 frame) const = 0;

  /// The call NumFramesReady() will return the number of frames currently
  /// available for this decodable object.  This is for use in setups where
  /// you don't want the decoder to block while waiting for input.  This is
  /// newly added as of Jan 2014, and I hope, going forward, to rely on this
  /// mechanism more than IsLastFrame to know when to stop decoding.
  /// The default implementation dies via KALDI_ERR; decodable objects used
  /// with the "new" decoding code must override it.
  virtual int32 NumFramesReady() const {
    KALDI_ERR << "NumFramesReady() not implemented for this decodable type.";
    return -1;  // never reached; KALDI_ERR throws/aborts.
  }

  /// Returns the number of states in the acoustic model
  /// (they will be indexed one-based, i.e. from 1 to NumIndices();
  /// this is for compatibility with OpenFst).
  virtual int32 NumIndices() const = 0;

  /// Virtual destructor, so that derived classes are destroyed correctly
  /// when deleted through a DecodableInterface pointer.
  virtual ~DecodableInterface() {}
};
/// @}
} // namespace kaldi
#endif // KALDI_ITF_DECODABLE_ITF_H_
// itf/options-itf.h
// Copyright 2013 Tanel Alumae, Tallinn University of Technology
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_ITF_OPTIONS_ITF_H_
#define KALDI_ITF_OPTIONS_ITF_H_ 1
#include <string>
#include "base/kaldi-common.h"
namespace kaldi {
class OptionsItf {
 public:
  // Abstract interface for option registration.  Each Register() overload
  // associates a named option (with documentation string "doc") with a
  // pointer to a variable of the corresponding type, which the implementing
  // class is expected to fill in (e.g. from the command line).

  virtual void Register(const std::string &name,
                        bool *ptr, const std::string &doc) = 0;
  virtual void Register(const std::string &name,
                        int32 *ptr, const std::string &doc) = 0;
  virtual void Register(const std::string &name,
                        uint32 *ptr, const std::string &doc) = 0;
  virtual void Register(const std::string &name,
                        float *ptr, const std::string &doc) = 0;
  virtual void Register(const std::string &name,
                        double *ptr, const std::string &doc) = 0;
  virtual void Register(const std::string &name,
                        std::string *ptr, const std::string &doc) = 0;

  // Virtual destructor so implementations are destroyed correctly via
  // base-class pointers.
  virtual ~OptionsItf() {}
};
} // namespace kaldi
#endif // KALDI_ITF_OPTIONS_ITF_H_
# There are currently so many lint errors that we simply ignore them all.
# We will try to fix them in the future.
exclude_files=.*
// lat/determinize-lattice-pruned.cc
// Copyright 2009-2012 Microsoft Corporation
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include <climits>
#include "fstext/determinize-lattice.h" // for LatticeStringRepository
#include "fstext/fstext-utils.h"
#include "lat/lattice-functions.h" // for PruneLattice
// #include "lat/minimize-lattice.h" // for minimization
// #include "lat/push-lattice.h" // for minimization
#include "lat/determinize-lattice-pruned.h"
namespace fst {
using std::vector;
using std::pair;
using std::greater;
// class LatticeDeterminizerPruned is templated on the same types that
// CompactLatticeWeight is templated on: the base weight (Weight), typically
// LatticeWeightTpl<float> etc. but could also be e.g. TropicalWeight, and the
// IntType, typically int32, used for the output symbols in the compact
// representation of strings [note: the output symbols would usually be
// p.d.f. id's in the anticipated use of this code] It has a special requirement
// on the Weight type: that there should be a Compare function on the weights
// such that Compare(w1, w2) returns -1 if w1 < w2, 0 if w1 == w2, and +1 if w1 >
// w2. This requires that there be a total order on the weights.
template<class Weight, class IntType> class LatticeDeterminizerPruned {
public:
// Output to Gallic acceptor (so the strings go on weights, and there is a 1-1 correspondence
// between our states and the states in ofst. If destroy == true, release memory as we go
// (but we cannot output again).
typedef CompactLatticeWeightTpl<Weight, IntType> CompactWeight;
typedef ArcTpl<CompactWeight> CompactArc; // arc in compact, acceptor form of lattice
typedef ArcTpl<Weight> Arc; // arc in non-compact version of lattice
// Output to standard FST with CompactWeightTpl<Weight> as its weight type (the
// weight stores the original output-symbol strings). If destroy == true,
// release memory as we go (but we cannot output again).
  void Output(MutableFst<CompactArc> *ofst, bool destroy = true) {
    KALDI_ASSERT(determinized_);
    typedef typename Arc::StateId StateId;
    StateId nStates = static_cast<StateId>(output_states_.size());
    if (destroy)
      FreeMostMemory();  // frees everything except output_states_[..]->arcs.
    ofst->DeleteStates();
    ofst->SetStart(kNoStateId);
    if (nStates == 0) {
      return;  // empty result: no states, no start state.
    }
    // Add states first, so that the state-ids in ofst coincide with our
    // own OutputStateIds (asserted below).
    for (StateId s = 0;s < nStates;s++) {
      OutputStateId news = ofst->AddState();
      KALDI_ASSERT(news == s);
    }
    ofst->SetStart(0);
    // now process transitions.
    for (StateId this_state_id = 0; this_state_id < nStates; this_state_id++) {
      OutputState &this_state = *(output_states_[this_state_id]);
      vector<TempArc> &this_vec(this_state.arcs);
      typename vector<TempArc>::const_iterator iter = this_vec.begin(), end = this_vec.end();
      for (;iter != end; ++iter) {
        const TempArc &temp_arc(*iter);
        CompactArc new_arc;
        vector<Label> olabel_seq;
        // Expand the interned string-id into an explicit label sequence;
        // the sequence becomes part of the CompactWeight.
        repository_.ConvertToVector(temp_arc.string, &olabel_seq);
        CompactWeight weight(temp_arc.weight, olabel_seq);
        if (temp_arc.nextstate == kNoStateId) {  // is really final weight.
          ofst->SetFinal(this_state_id, weight);
        } else {  // is really an arc.
          new_arc.nextstate = temp_arc.nextstate;
          new_arc.ilabel = temp_arc.ilabel;
          new_arc.olabel = temp_arc.ilabel;  // acceptor.  input == output.
          new_arc.weight = weight;  // includes string and weight.
          ofst->AddArc(this_state_id, new_arc);
        }
      }
      // Free up memory.  Do this inside the loop as ofst is also allocating memory,
      // and we want to reduce the maximum amount ever allocated.
      if (destroy) { vector<TempArc> temp; temp.swap(this_vec); }
    }
    if (destroy) {
      FreeOutputStates();
      repository_.Destroy();
    }
  }
// Output to standard FST with Weight as its weight type. We will create extra
// states to handle sequences of symbols on the output. If destroy == true,
// release memory as we go (but we cannot output again).
  void Output(MutableFst<Arc> *ofst, bool destroy = true) {
    // Outputs to standard fst.
    // NOTE(review): unlike the CompactArc overload, this one does not
    // KALDI_ASSERT(determinized_) — confirm whether that is intentional.
    OutputStateId nStates = static_cast<OutputStateId>(output_states_.size());
    ofst->DeleteStates();
    if (nStates == 0) {
      ofst->SetStart(kNoStateId);
      return;
    }
    if (destroy)
      FreeMostMemory();  // frees everything except output_states_[..]->arcs.
    // Add basic states-- but we will add extra ones to account for strings on output.
    for (OutputStateId s = 0; s< nStates;s++) {
      OutputStateId news = ofst->AddState();
      KALDI_ASSERT(news == s);
    }
    ofst->SetStart(0);
    for (OutputStateId this_state_id = 0; this_state_id < nStates; this_state_id++) {
      OutputState &this_state = *(output_states_[this_state_id]);
      vector<TempArc> &this_vec(this_state.arcs);

      typename vector<TempArc>::const_iterator iter = this_vec.begin(), end = this_vec.end();
      for (; iter != end; ++iter) {
        const TempArc &temp_arc(*iter);
        vector<Label> seq;
        // Expand the interned output-label string into an explicit sequence;
        // each label will go on its own (possibly newly-created) arc below.
        repository_.ConvertToVector(temp_arc.string, &seq);

        if (temp_arc.nextstate == kNoStateId) {  // Really a final weight.
          // Make a sequence of states going to a final state, with the strings
          // as labels.  Put the weight on the first arc.
          OutputStateId cur_state = this_state_id;
          for (size_t i = 0; i < seq.size(); i++) {
            OutputStateId next_state = ofst->AddState();
            Arc arc;
            arc.nextstate = next_state;
            arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
            arc.ilabel = 0;  // epsilon.
            arc.olabel = seq[i];
            ofst->AddArc(cur_state, arc);
            cur_state = next_state;
          }
          // Empty sequence: the weight goes directly on the final state.
          ofst->SetFinal(cur_state, (seq.size() == 0 ? temp_arc.weight : Weight::One()));
        } else {  // Really an arc.
          OutputStateId cur_state = this_state_id;
          // Have to be careful with this integer comparison (i+1 < seq.size()) because unsigned.
          // i < seq.size()-1 could fail for zero-length sequences.
          for (size_t i = 0; i+1 < seq.size();i++) {
            // for all but the last element of seq, create new state.
            OutputStateId next_state = ofst->AddState();
            Arc arc;
            arc.nextstate = next_state;
            arc.weight = (i == 0 ? temp_arc.weight : Weight::One());
            arc.ilabel = (i == 0 ? temp_arc.ilabel : 0);  // put ilabel on first element of seq.
            arc.olabel = seq[i];
            ofst->AddArc(cur_state, arc);
            cur_state = next_state;
          }
          // Add the final arc in the sequence.  If seq had <= 1 elements, the
          // weight/ilabel have not been emitted yet and go here.
          Arc arc;
          arc.nextstate = temp_arc.nextstate;
          arc.weight = (seq.size() <= 1 ? temp_arc.weight : Weight::One());
          arc.ilabel = (seq.size() <= 1 ? temp_arc.ilabel : 0);
          arc.olabel = (seq.size() > 0 ? seq.back() : 0);
          ofst->AddArc(cur_state, arc);
        }
      }
      // Free up memory.  Do this inside the loop as ofst is also allocating memory
      if (destroy) { vector<TempArc> temp; temp.swap(this_vec); }
    }
    if (destroy) {
      FreeOutputStates();
      repository_.Destroy();
    }
  }
// Initializer. After initializing the object you will typically
// call Determinize() and then call one of the Output functions.
// Note: ifst.Copy() will generally do a
// shallow copy. We do it like this for memory safety, rather than
// keeping a reference or pointer to ifst_.
  // "beam" is the pruning beam: together with the best-path cost it bounds
  // which determinization tasks get expanded (see Determinize() and
  // CheckMemoryUsage()).  "opts" carries the delta used for approximate
  // weight comparison plus the max-mem/max-arcs/max-states limits.
  LatticeDeterminizerPruned(const ExpandedFst<Arc> &ifst,
                            double beam,
                            DeterminizeLatticePrunedOptions opts):
      num_arcs_(0), num_elems_(0), ifst_(ifst.Copy()), beam_(beam), opts_(opts),
      equal_(opts_.delta), determinized_(false),
      minimal_hash_(3, hasher_, equal_), initial_hash_(3, hasher_, equal_) {
    KALDI_ASSERT(Weight::Properties() & kIdempotent);  // this algorithm won't
    // work correctly otherwise.
  }
void FreeOutputStates() {
for (size_t i = 0; i < output_states_.size(); i++)
delete output_states_[i];
vector<OutputState*> temp;
temp.swap(output_states_);
}
// frees all memory except the info (in output_states_[ ]->arcs)
// that we need to output the FST.
void FreeMostMemory() {
if (ifst_) {
delete ifst_;
ifst_ = NULL;
}
{ MinimalSubsetHash tmp; tmp.swap(minimal_hash_); }
for (size_t i = 0; i < output_states_.size(); i++) {
vector<Element> empty_subset;
empty_subset.swap(output_states_[i]->minimal_subset);
}
for (typename InitialSubsetHash::iterator iter = initial_hash_.begin();
iter != initial_hash_.end(); ++iter)
delete iter->first;
{ InitialSubsetHash tmp; tmp.swap(initial_hash_); }
for (size_t i = 0; i < output_states_.size(); i++) {
vector<Element> tmp;
tmp.swap(output_states_[i]->minimal_subset);
}
{ vector<char> tmp; tmp.swap(isymbol_or_final_); }
{ // Free up the queue. I'm not sure how to make sure all
// the memory is really freed (no swap() function)... doesn't really
// matter much though.
while (!queue_.empty()) {
Task *t = queue_.top();
delete t;
queue_.pop();
}
}
{ vector<pair<Label, Element> > tmp; tmp.swap(all_elems_tmp_); }
}
  // Frees the copied input FST, hashes and queued tasks (FreeMostMemory())
  // and then the output states themselves (FreeOutputStates()).
  ~LatticeDeterminizerPruned() {
    FreeMostMemory();
    FreeOutputStates();
    // rest is deleted by destructors.
  }
  void RebuildRepository() { // rebuild the string repository,
    // freeing stuff we don't need.. we call this when memory usage
    // passes a supplied threshold.  We need to accumulate all the
    // strings we need the repository to "remember", then tell it
    // to clean the repository.
    std::vector<StringId> needed_strings;
    // Strings reachable from the output states: their minimal subsets
    // and the strings on their arcs.
    for (size_t i = 0; i < output_states_.size(); i++) {
      AddStrings(output_states_[i]->minimal_subset, &needed_strings);
      for (size_t j = 0; j < output_states_[i]->arcs.size(); j++)
        needed_strings.push_back(output_states_[i]->arcs[j].string);
    }

    { // the queue doesn't allow us access to the underlying vector,
      // so we have to resort to a temporary collection.
      std::vector<Task*> tasks;
      while (!queue_.empty()) {
        Task *task = queue_.top();
        queue_.pop();
        tasks.push_back(task);
        AddStrings(task->subset, &needed_strings);
      }
      // Re-insert the tasks; the queue is unchanged apart from having been
      // drained and refilled.
      for (size_t i = 0; i < tasks.size(); i++)
        queue_.push(tasks[i]);
    }

    // the following loop covers strings present in initial_hash_.
    for (typename InitialSubsetHash::const_iterator
             iter = initial_hash_.begin();
         iter != initial_hash_.end(); ++iter) {
      const vector<Element> &vec = *(iter->first);
      Element elem = iter->second;
      AddStrings(vec, &needed_strings);
      needed_strings.push_back(elem.string);
    }
    std::sort(needed_strings.begin(), needed_strings.end());
    needed_strings.erase(std::unique(needed_strings.begin(),
                                     needed_strings.end()),
                         needed_strings.end()); // uniq the strings.
    KALDI_LOG << "Rebuilding repository.";

    repository_.Rebuild(needed_strings);
  }
  // Estimates memory usage (repository + arcs + subset elements) against
  // opts_.max_mem.  If over the limit, rebuilds the string repository to
  // reclaim space.  Returns true if we are (now) within limits; returns
  // false if even after rebuilding we are above 80% of the limit, in which
  // case the caller should terminate determinization early.
  bool CheckMemoryUsage() {
    int32 repo_size = repository_.MemSize(),
        arcs_size = num_arcs_ * sizeof(TempArc),
        elems_size = num_elems_ * sizeof(Element),
        total_size = repo_size + arcs_size + elems_size;
    if (opts_.max_mem > 0 && total_size > opts_.max_mem) { // We passed the memory threshold.
      // This is usually due to the repository getting large, so we
      // clean this out.
      RebuildRepository();
      int32 new_repo_size = repository_.MemSize(),
          new_total_size = new_repo_size + arcs_size + elems_size;

      KALDI_VLOG(2) << "Rebuilt repository in determinize-lattice: repository shrank from "
                    << repo_size << " to " << new_repo_size << " bytes (approximately)";

      if (new_total_size > static_cast<int32>(opts_.max_mem * 0.8)) {
        // Rebuilding didn't help enough-- we need a margin to stop
        // having to rebuild too often.  We'll just return to the user at
        // this point, with a partial lattice that's pruned tighter than
        // the specified beam.  Here we figure out what the effective
        // beam was.
        double effective_beam = beam_;
        if (!queue_.empty()) { // Note: queue should probably not be empty; we're
          // just being paranoid here.
          Task *task = queue_.top();
          double total_weight = backward_costs_[ifst_->Start()]; // best weight of FST.
          effective_beam = task->priority_cost - total_weight;
        }
        KALDI_WARN << "Did not reach requested beam in determinize-lattice: "
                   << "size exceeds maximum " << opts_.max_mem
                   << " bytes; (repo,arcs,elems) = (" << repo_size << ","
                   << arcs_size << "," << elems_size
                   << "), after rebuilding, repo size was " << new_repo_size
                   << ", effective beam was " << effective_beam
                   << " vs. requested beam " << beam_;
        return false;
      }
    }
    return true;
  }
  // Runs pruned determinization by repeatedly expanding the best-priority
  // Task from queue_.  Returns true if the queue was fully processed (beam
  // was reached), false if we stopped early due to a state/arc/memory limit.
  // If effective_beam != NULL, outputs the beam we effectively achieved.
  bool Determinize(double *effective_beam) {
    KALDI_ASSERT(!determinized_);
    // This determinizes the input fst but leaves it in the "special format"
    // in "output_arcs_".  Must be called after Initialize().  To get the
    // output, call one of the Output routines.

    InitializeDeterminization(); // some start-up tasks.
    while (!queue_.empty()) {
      Task *task = queue_.top();
      // Note: the queue contains only tasks that are "within the beam".
      // We also have to check whether we have reached one of the user-specified
      // maximums, of estimated memory, arcs, or states.  The condition for
      // ending is:
      // num-states is more than user specified, OR
      // num-arcs is more than user specified, OR
      // memory passed a user-specified threshold and cleanup failed
      // to get it below that threshold.
      size_t num_states = output_states_.size();
      if ((opts_.max_states > 0 && num_states > opts_.max_states) ||
          (opts_.max_arcs > 0 && num_arcs_ > opts_.max_arcs) ||
          (num_states % 10 == 0 && !CheckMemoryUsage())) { // note: at some point
        // it was num_states % 100, not num_states % 10, but I encountered an example
        // where memory was exhausted before we reached state #100.
        KALDI_VLOG(1) << "Lattice determinization terminated but not "
                      << " because of lattice-beam.  (#states, #arcs) is ( "
                      << output_states_.size() << ", " << num_arcs_
                      << " ), versus limits ( " << opts_.max_states << ", "
                      << opts_.max_arcs << " ) (else, may be memory limit).";
        // Note: "task" is still owned by the queue here; the remaining tasks
        // are deleted later in FreeMostMemory().
        break;
        // we terminate the determinization here-- whatever we already expanded is
        // what we'll return...  because we expanded stuff in order of total
        // (forward-backward) weight, the stuff we returned first is the most
        // important.
      }
      queue_.pop();
      ProcessTransition(task->state, task->label, &(task->subset));
      delete task;
    }
    determinized_ = true;
    if (effective_beam != NULL) {
      if (queue_.empty()) *effective_beam = beam_;
      else
        *effective_beam = queue_.top()->priority_cost -
            backward_costs_[ifst_->Start()];
    }
    return (queue_.empty()); // return success if queue was empty, i.e. we processed
    // all tasks and did not break out of the loop early due to reaching a memory,
    // arc or state limit.
  }
private:
typedef typename Arc::Label Label;
typedef typename Arc::StateId StateId; // use this when we don't know if it's input or output.
typedef typename Arc::StateId InputStateId; // state in the input FST.
typedef typename Arc::StateId OutputStateId; // same as above but distinguish
// states in output Fst.
typedef LatticeStringRepository<IntType> StringRepositoryType;
typedef const typename StringRepositoryType::Entry* StringId;
// Element of a subset [of original states]
  struct Element {
    StateId state; // use StateId as this is usually InputStateId but in one case
    // OutputStateId.
    StringId string;  // interned sequence of output labels accumulated so far
                      // (see LatticeStringRepository).
    Weight weight;  // weight accumulated so far for this state.
    bool operator != (const Element &other) const {
      return (state != other.state || string != other.string ||
              weight != other.weight);
    }
    // This operator is only intended for the priority_queue in the function
    // EpsilonClosure().
    bool operator > (const Element &other) const {
      return state > other.state;
    }
    // This operator is only intended to support sorting in EpsilonClosure()
    bool operator < (const Element &other) const {
      return state < other.state;
    }
  };
// Arcs in the format we temporarily create in this class (a representation, essentially of
// a Gallic Fst).
  struct TempArc {
    Label ilabel;           // input label of the arc.
    StringId string;  // Look it up in the StringRepository, it's a sequence of Labels.
    OutputStateId nextstate;  // or kNoState for final weights.
    Weight weight;          // weight of the arc (or final weight).
  };
// Hashing function used in hash of subsets.
// A subset is a pointer to vector<Element>.
// The Elements are in sorted order on state id, and without repeated states.
// Because the order of Elements is fixed, we can use a hashing function that is
// order-dependent. However the weights are not included in the hashing function--
// we hash subsets that differ only in weight to the same key. This is not optimal
// in terms of the O(N) performance but typically if we have a lot of determinized
// states that differ only in weight then the input probably was pathological in some way,
// or even non-determinizable.
// We don't quantize the weights, in order to avoid inexactness in simple cases.
// Instead we apply the delta when comparing subsets for equality, and allow a small
// difference.
class SubsetKey {
public:
size_t operator ()(const vector<Element> * subset) const { // hashes only the state and string.
size_t hash = 0, factor = 1;
for (typename vector<Element>::const_iterator iter= subset->begin(); iter != subset->end(); ++iter) {
hash *= factor;
hash += iter->state + reinterpret_cast<size_t>(iter->string);
factor *= 23531; // these numbers are primes.
}
return hash;
}
};
// This is the equality operator on subsets. It checks for exact match on state-id
// and string, and approximate match on weights.
class SubsetEqual {
public:
bool operator ()(const vector<Element> * s1, const vector<Element> * s2) const {
size_t sz = s1->size();
KALDI_ASSERT(sz>=0);
if (sz != s2->size()) return false;
typename vector<Element>::const_iterator iter1 = s1->begin(),
iter1_end = s1->end(), iter2=s2->begin();
for (; iter1 < iter1_end; ++iter1, ++iter2) {
if (iter1->state != iter2->state ||
iter1->string != iter2->string ||
! ApproxEqual(iter1->weight, iter2->weight, delta_)) return false;
}
return true;
}
float delta_;
SubsetEqual(float delta): delta_(delta) {}
SubsetEqual(): delta_(kDelta) {}
};
// Operator that says whether two Elements have the same states.
// Used only for debug.
class SubsetEqualStates {
public:
bool operator ()(const vector<Element> * s1, const vector<Element> * s2) const {
size_t sz = s1->size();
KALDI_ASSERT(sz>=0);
if (sz != s2->size()) return false;
typename vector<Element>::const_iterator iter1 = s1->begin(),
iter1_end = s1->end(), iter2=s2->begin();
for (; iter1 < iter1_end; ++iter1, ++iter2) {
if (iter1->state != iter2->state) return false;
}
return true;
}
};
// Define the hash type we use to map subsets (in minimal
// representation) to OutputStateId.
typedef unordered_map<const vector<Element>*, OutputStateId,
SubsetKey, SubsetEqual> MinimalSubsetHash;
// Define the hash type we use to map subsets (in initial
// representation) to OutputStateId, together with an
// extra weight. [note: we interpret the Element.state in here
// as an OutputStateId even though it's declared as InputStateId;
// these types are the same anyway].
typedef unordered_map<const vector<Element>*, Element,
SubsetKey, SubsetEqual> InitialSubsetHash;
// converts the representation of the subset from canonical (all states) to
// minimal (only states with output symbols on arcs leaving them, and final
// states). Output is not necessarily normalized, even if input_subset was.
void ConvertToMinimal(vector<Element> *subset) {
KALDI_ASSERT(!subset->empty());
typename vector<Element>::iterator cur_in = subset->begin(),
cur_out = subset->begin(), end = subset->end();
while (cur_in != end) {
if(IsIsymbolOrFinal(cur_in->state)) { // keep it...
*cur_out = *cur_in;
cur_out++;
}
cur_in++;
}
subset->resize(cur_out - subset->begin());
}
// Takes a minimal, normalized subset, and converts it to an OutputStateId.
// Involves a hash lookup, and possibly adding a new OutputStateId.
// If it creates a new OutputStateId, it creates a new record for it, works
// out its final-weight, and puts stuff on the queue relating to its
// transitions.
  OutputStateId MinimalToStateId(const vector<Element> &subset,
                                 const double forward_cost) {
    typename MinimalSubsetHash::const_iterator iter
        = minimal_hash_.find(&subset);
    if (iter != minimal_hash_.end()) { // Found a matching subset.
      OutputStateId state_id = iter->second;
      const OutputState &state = *(output_states_[state_id]);
      // Below is just a check that the algorithm is working...
      if (forward_cost < state.forward_cost - 0.1) {
        // for large weights, this check could fail due to roundoff.
        KALDI_WARN << "New cost is less (check the difference is small) "
                   << forward_cost << ", "
                   << state.forward_cost;
      }
      return state_id;
    }
    // Not found: create a new output state for this subset.
    OutputStateId state_id = static_cast<OutputStateId>(output_states_.size());
    OutputState *new_state = new OutputState(subset, forward_cost);
    // The hash key points at the subset *inside* the new OutputState, so the
    // key stays valid for as long as the state exists.
    minimal_hash_[&(new_state->minimal_subset)] = state_id;
    output_states_.push_back(new_state);
    num_elems_ += subset.size();  // keep track of memory usage.
    // Note: in the previous algorithm, we pushed the new state-id onto the queue
    // at this point.  Here, the queue happens elsewhere, and we directly process
    // the state (which result in stuff getting added to the queue).
    ProcessFinal(state_id); // will work out the final-prob.
    ProcessTransitions(state_id); // will process transitions and add stuff to the queue.
    return state_id;
  }
// Given a normalized initial subset of elements (i.e. before epsilon closure),
// compute the corresponding output-state.
  OutputStateId InitialToStateId(const vector<Element> &subset_in,
                                 double forward_cost,
                                 Weight *remaining_weight,
                                 StringId *common_prefix) {
    // Fast path: we have processed this exact initial subset before; the
    // cached Element stores the resulting state, weight and common prefix.
    typename InitialSubsetHash::const_iterator iter
        = initial_hash_.find(&subset_in);
    if (iter != initial_hash_.end()) { // Found a matching subset.
      const Element &elem = iter->second;
      *remaining_weight = elem.weight;
      *common_prefix = elem.string;
      if (elem.weight == Weight::Zero())
        KALDI_WARN << "Zero weight!";
      return elem.state;
    }
    // else no matching subset-- have to work it out.
    vector<Element> subset(subset_in);
    // Follow through epsilons.  Will add no duplicate states.  note: after
    // EpsilonClosure, it is the same as "canonical" subset, except not
    // normalized (actually we never compute the normalized canonical subset,
    // only the normalized minimal one).
    EpsilonClosure(&subset); // follow epsilons.
    ConvertToMinimal(&subset); // remove all but emitting and final states.

    Element elem; // will be used to store remaining weight and string, and
                  // OutputStateId, in initial_hash_;
    NormalizeSubset(&subset, &elem.weight, &elem.string); // normalize subset; put
    // common string and weight in "elem".  The subset is now a minimal,
    // normalized subset.

    forward_cost += ConvertToCost(elem.weight);
    OutputStateId ans = MinimalToStateId(subset, forward_cost);
    *remaining_weight = elem.weight;
    *common_prefix = elem.string;
    if (elem.weight == Weight::Zero())
      KALDI_WARN << "Zero weight!";

    // Before returning "ans", add the initial subset to the hash,
    // so that we can bypass the epsilon-closure etc., next time
    // we process the same initial subset.
    vector<Element> *initial_subset_ptr = new vector<Element>(subset_in);
    elem.state = ans;
    initial_hash_[initial_subset_ptr] = elem;
    num_elems_ += initial_subset_ptr->size(); // keep track of memory usage.
    return ans;
  }
// returns the Compare value (-1 if a < b, 0 if a == b, 1 if a > b) according
// to the ordering we defined on strings for the CompactLatticeWeightTpl.
// see function
// inline int Compare (const CompactLatticeWeightTpl<WeightType,IntType> &w1,
// const CompactLatticeWeightTpl<WeightType,IntType> &w2)
// in lattice-weight.h.
// this is the same as that, but optimized for our data structures.
inline int Compare(const Weight &a_w, StringId a_str,
const Weight &b_w, StringId b_str) const {
int weight_comp = fst::Compare(a_w, b_w);
if (weight_comp != 0) return weight_comp;
// now comparing strings.
if (a_str == b_str) return 0;
vector<IntType> a_vec, b_vec;
repository_.ConvertToVector(a_str, &a_vec);
repository_.ConvertToVector(b_str, &b_vec);
// First compare their lengths.
int a_len = a_vec.size(), b_len = b_vec.size();
// use opposite order on the string lengths (c.f. Compare in
// lattice-weight.h)
if (a_len > b_len) return -1;
else if (a_len < b_len) return 1;
for(int i = 0; i < a_len; i++) {
if (a_vec[i] < b_vec[i]) return -1;
else if (a_vec[i] > b_vec[i]) return 1;
}
KALDI_ASSERT(0); // because we checked if a_str == b_str above, shouldn't reach here
return 0;
}
// This function computes epsilon closure of subset of states by following epsilon links.
// Called by InitialToStateId and Initialize.
// Has no side effects except on the string repository. The "output_subset" is not
// necessarily normalized (in the sense of there being no common substring), unless
// input_subset was.
  void EpsilonClosure(vector<Element> *subset) {
    // at input, subset must have only one example of each StateId.  [will still
    // be so at output].  This function follows input-epsilons, and augments the
    // subset accordingly.

    std::priority_queue<Element, vector<Element>, greater<Element> > queue;
    unordered_map<InputStateId, Element> cur_subset;
    typedef typename unordered_map<InputStateId, Element>::iterator MapIter;
    typedef typename vector<Element>::const_iterator VecIter;

    // Seed the queue and the state->Element map with the input subset.
    for (VecIter iter = subset->begin(); iter != subset->end(); ++iter) {
      queue.push(*iter);
      cur_subset[iter->state] = *iter;
    }

    // find whether input fst is known to be sorted on input label.
    bool sorted = ((ifst_->Properties(kILabelSorted, false) & kILabelSorted) != 0);
    bool replaced_elems = false; // relates to an optimization, see below.
    int counter = 0; // stops infinite loops here for non-lattice-determinizable input
    // (e.g. input with negative-cost epsilon loops); useful in testing.
    while (queue.size() != 0) {
      Element elem = queue.top();
      queue.pop();

      // The next if-statement is a kind of optimization.  It's to prevent us
      // unnecessarily repeating the processing of a state.  "cur_subset" always
      // contains only one Element with a particular state.  The issue is that
      // whenever we modify the Element corresponding to that state in "cur_subset",
      // both the new (optimal) and old (less-optimal) Element will still be in
      // "queue".  The next if-statement stops us from wasting compute by
      // processing the old Element.
      if (replaced_elems && cur_subset[elem.state] != elem)
        continue;
      if (opts_.max_loop > 0 && counter++ > opts_.max_loop) {
        KALDI_ERR << "Lattice determinization aborted since looped more than "
                  << opts_.max_loop << " times during epsilon closure.";
      }

      for (ArcIterator<ExpandedFst<Arc> > aiter(*ifst_, elem.state); !aiter.Done(); aiter.Next()) {
        const Arc &arc = aiter.Value();
        if (sorted && arc.ilabel != 0) break; // Break from the loop: due to sorting there will be no
        // more transitions with epsilons as input labels.
        if (arc.ilabel == 0
            && arc.weight != Weight::Zero()) {  // Epsilon transition.
          Element next_elem;
          next_elem.state = arc.nextstate;
          next_elem.weight = Times(elem.weight, arc.weight);
          // next_elem.string is not set up yet... create it only
          // when we know we need it (this is an optimization)

          MapIter iter = cur_subset.find(next_elem.state);
          if (iter == cur_subset.end()) {
            // was no such StateId: insert and add to queue.
            next_elem.string = (arc.olabel == 0 ? elem.string :
                                repository_.Successor(elem.string, arc.olabel));
            cur_subset[next_elem.state] = next_elem;
            queue.push(next_elem);
          } else {
            // was not inserted because one already there.  In normal
            // determinization we'd add the weights.  Here, we find which one
            // has the better weight, and keep its corresponding string.
            int comp = fst::Compare(next_elem.weight, iter->second.weight);
            if (comp == 0) { // A tie on weights.  This should be a rare case;
              // we don't optimize for it.
              next_elem.string = (arc.olabel == 0 ? elem.string :
                                  repository_.Successor(elem.string,
                                                        arc.olabel));
              comp = Compare(next_elem.weight, next_elem.string,
                             iter->second.weight, iter->second.string);
            }
            if(comp == 1) { // next_elem is better, so use its (weight, string)
              next_elem.string = (arc.olabel == 0 ? elem.string :
                                  repository_.Successor(elem.string, arc.olabel));
              iter->second.string = next_elem.string;
              iter->second.weight = next_elem.weight;
              queue.push(next_elem);
              replaced_elems = true;
            }
            // else it is the same or worse, so use original one.
          }
        }
      }
    }

    { // copy cur_subset to subset.
      subset->clear();
      subset->reserve(cur_subset.size());
      MapIter iter = cur_subset.begin(), end = cur_subset.end();
      for (; iter != end; ++iter) subset->push_back(iter->second);
      // sort by state ID, because the subset hash function is order-dependent(see SubsetKey)
      std::sort(subset->begin(), subset->end());
    }
  }
// Works out the final-weight (and associated string) of the determinized
// state "output_state_id".  Called by ProcessSubset.
// Has no side effects except on the variable repository_, and
// output_states_[output_state_id].arcs.
void ProcessFinal(OutputStateId output_state_id) {
  OutputState &state = *(output_states_[output_state_id]);
  const vector<Element> &minimal_subset = state.minimal_subset;
  // Scan the subset for the best final (weight, string) pair in the
  // semiring.  Note: minimal_subset may legitimately be empty if the
  // graph is not connected/trimmed, so we don't assert it's nonempty.
  StringId best_string = repository_.EmptyString();  // initialized to keep the
  // compiler happy; the value is only used if have_final becomes true.
  Weight best_weight = Weight::Zero();
  bool have_final = false;
  for (size_t i = 0; i < minimal_subset.size(); i++) {
    const Element &elem = minimal_subset[i];
    Weight w = Times(elem.weight, ifst_->Final(elem.state));
    if (w == Weight::Zero()) continue;  // elem.state is not a final state.
    // Keep whichever (weight, string) pair is "better" in the semiring.
    if (!have_final ||
        Compare(w, elem.string, best_weight, best_string) == 1) {
      have_final = true;
      best_weight = w;
      best_string = elem.string;
    }
  }
  if (have_final &&
      ConvertToCost(best_weight) + state.forward_cost <= cutoff_) {
    // Store the final weight in a TempArc structure, just like a transition,
    // with kNoStateId as a special marker meaning "final weight".
    // Note: we only store the final-weight if it's inside the pruning beam.
    TempArc temp_arc;
    temp_arc.ilabel = 0;
    temp_arc.nextstate = kNoStateId;
    temp_arc.string = best_string;
    temp_arc.weight = best_weight;
    state.arcs.push_back(temp_arc);
    num_arcs_++;
  }
}
// Normalizes the subset "elems" in place: removes any common string
// prefix (returned in *common_str) and divides out the total weight
// (returned in *tot_weight).
void NormalizeSubset(vector<Element> *elems,
                     Weight *tot_weight,
                     StringId *common_str) {
  if (elems->empty()) {
    // Nothing to normalize: report defaults and return.
    KALDI_WARN << "empty subset";
    *common_str = repository_.EmptyString();
    *tot_weight = Weight::Zero();
    return;
  }
  vector<Element> &e = *elems;
  // Seed the common prefix and the weight total from the first element,
  // then fold in the remaining ones.
  vector<IntType> prefix;
  repository_.ConvertToVector(e[0].string, &prefix);
  Weight sum = e[0].weight;
  for (size_t i = 1; i < e.size(); i++) {
    sum = Plus(sum, e[i].weight);
    repository_.ReduceToCommonPrefix(e[i].string, &prefix);
  }
  // We made sure to ignore arcs with zero weights on them, so the total
  // should never be zero here.
  KALDI_ASSERT(sum != Weight::Zero());
  // Divide out the total weight and strip the shared prefix from each element.
  size_t prefix_len = prefix.size();
  for (size_t i = 0; i < e.size(); i++) {
    e[i].weight = Divide(e[i].weight, sum, DIVIDE_LEFT);
    e[i].string = repository_.RemovePrefix(e[i].string, prefix_len);
  }
  *common_str = repository_.ConvertFromVector(prefix);
  *tot_weight = sum;
}
// Take a subset of Elements that is sorted on state, and merge any
// Elements that have the same state (taking the best (weight, string)
// pair in the semiring).  The merge is done in place, compacting the
// surviving elements towards the front of the vector.
void MakeSubsetUnique(vector<Element> *subset) {
  typedef typename vector<Element>::iterator IterType;
  // This KALDI_ASSERT is designed to fail (usually) if the subset is not sorted on
  // state.
  KALDI_ASSERT(subset->size() < 2 || (*subset)[0].state <= (*subset)[1].state);
  IterType cur_in = subset->begin(), cur_out = cur_in, end = subset->end();
  size_t num_out = 0;
  // Merge elements with same state-id.
  while (cur_in != end) {  // while we have more elements to process.
    // At this point, cur_out points to location of next place we want to put an element,
    // cur_in points to location of next element we want to process.
    if (cur_in != cur_out) *cur_out = *cur_in;
    cur_in++;
    // Fold any following elements with the same state into *cur_out,
    // keeping whichever (weight, string) pair is better in the semiring.
    while (cur_in != end && cur_in->state == cur_out->state) {
      if (Compare(cur_in->weight, cur_in->string,
                  cur_out->weight, cur_out->string) == 1) {
        // if *cur_in > *cur_out in semiring, then take *cur_in.
        cur_out->string = cur_in->string;
        cur_out->weight = cur_in->weight;
      }
      cur_in++;
    }
    cur_out++;
    num_out++;
  }
  // Drop the stale tail left over after in-place compaction.
  subset->resize(num_out);
}
// ProcessTransition was called from "ProcessTransitions" in the non-pruned
// code; here the calls are in effect queued (as Tasks) and it gets called
// directly from Determinize().  This function processes one transition,
// with input label "ilabel", out of determinized state "ostate_id".  The
// set "subset" of Elements represents a set of next-states with associated
// weights and strings, each one arising from an arc from some state in a
// determinized-state; the next-states are unique (there is only one Entry
// associated with each).
void ProcessTransition(OutputStateId ostate_id, Label ilabel, vector<Element> *subset) {
  double forward_cost = output_states_[ostate_id]->forward_cost;
  // Normalize the subset, pulling out the shared weight and string; they
  // become the weight and string of the arc we will add.
  Weight arc_weight;
  StringId arc_string;
  NormalizeSubset(subset, &arc_weight, &arc_string);
  forward_cost += ConvertToCost(arc_weight);
  // Map the normalized subset to a determinized state (InitialToStateId
  // creates one if necessary); any leftover weight/string from that
  // mapping gets folded into the arc as well.
  Weight leftover_weight;
  StringId leftover_string;
  OutputStateId dest_state = InitialToStateId(*subset,
                                              forward_cost,
                                              &leftover_weight,
                                              &leftover_string);
  arc_string = repository_.Concatenate(arc_string, leftover_string);
  arc_weight = Times(arc_weight, leftover_weight);
  // Record the arc to the destination state.
  TempArc temp_arc;
  temp_arc.ilabel = ilabel;
  temp_arc.nextstate = dest_state;
  temp_arc.string = arc_string;
  temp_arc.weight = arc_weight;
  output_states_[ostate_id]->arcs.push_back(temp_arc);
  num_arcs_++;
}
// "less than" operator for pair<Label, Element>, used in
// ProcessTransitions.  Lexicographical order: compares the Label first,
// and for equal labels compares only the state of the Element member.
class PairComparator {
 public:
  inline bool operator () (const pair<Label, Element> &p1, const pair<Label, Element> &p2) {
    if (p1.first != p2.first)
      return p1.first < p2.first;
    return p1.second.state < p2.second.state;
  }
};
// ProcessTransitions processes emitting transitions (transitions with
// ilabels) out of this subset of states.  It actually only creates records
// ("Task") that get added to the queue.  The transitions will be processed in
// priority order from Determinize().  This function does not consider final
// states.  Partitions the emitting transitions up by ilabel (by sorting on
// ilabel), and for each unique ilabel, it creates a Task record that contains
// the information we need to process the transition.
void ProcessTransitions(OutputStateId output_state_id) {
  const vector<Element> &minimal_subset = output_states_[output_state_id]->minimal_subset;
  // it's possible that minimal_subset could be empty if there are
  // unreachable parts of the graph, so don't check that it's nonempty.
  vector<pair<Label, Element> > &all_elems(all_elems_tmp_);  // use class member
  // to avoid memory allocation/deallocation.
  {
    // Push back into "all_elems", elements corresponding to all
    // non-epsilon-input transitions out of all states in "minimal_subset".
    typename vector<Element>::const_iterator iter = minimal_subset.begin(), end = minimal_subset.end();
    for (;iter != end; ++iter) {
      const Element &elem = *iter;
      for (ArcIterator<ExpandedFst<Arc> > aiter(*ifst_, elem.state); ! aiter.Done(); aiter.Next()) {
        const Arc &arc = aiter.Value();
        if (arc.ilabel != 0
            && arc.weight != Weight::Zero()) {  // Non-epsilon transition -- ignore epsilons here.
          pair<Label, Element> this_pr;
          this_pr.first = arc.ilabel;
          Element &next_elem(this_pr.second);
          next_elem.state = arc.nextstate;
          next_elem.weight = Times(elem.weight, arc.weight);
          if (arc.olabel == 0)  // output epsilon: string is unchanged.
            next_elem.string = elem.string;
          else
            next_elem.string = repository_.Successor(elem.string, arc.olabel);
          all_elems.push_back(this_pr);
        }
      }
    }
  }
  PairComparator pc;
  std::sort(all_elems.begin(), all_elems.end(), pc);
  // now sorted first on input label, then on state.
  typedef typename vector<pair<Label, Element> >::const_iterator PairIter;
  PairIter cur = all_elems.begin(), end = all_elems.end();
  while (cur != end) {
    // The old code (non-pruned) called ProcessTransition; here, instead,
    // we'll put the calls into a priority queue.
    Task *task = new Task;  // owned by queue_ once pushed; deleted below if pruned.
    // Process the contiguous range sharing the same input symbol.
    Label ilabel = cur->first;
    task->state = output_state_id;
    task->priority_cost = std::numeric_limits<double>::infinity();
    task->label = ilabel;
    while (cur != end && cur->first == ilabel) {
      task->subset.push_back(cur->second);
      const Element &element = cur->second;
      // Note: we'll later include the term "forward_cost" in the
      // priority_cost.
      task->priority_cost = std::min(task->priority_cost,
                                     ConvertToCost(element.weight) +
                                     backward_costs_[element.state]);
      cur++;
    }
    // After the command below, the "priority_cost" is a value comparable to
    // the total-weight of the input FST, like a total-path weight... of
    // course, it will typically be less (in the semiring) than that.
    // note: we represent it just as a double.
    task->priority_cost += output_states_[output_state_id]->forward_cost;
    if (task->priority_cost > cutoff_) {
      // This task would never get done as it's past the pruning cutoff.
      delete task;
    } else {
      MakeSubsetUnique(&(task->subset));  // remove duplicate Elements with the same state.
      queue_.push(task);  // Push the task onto the queue.  The queue keeps it
      // in prioritized order, so we always process the one with the "best"
      // weight (highest in the semiring).
      {  // this is a check: no task should be better than the best path
         // through the whole input FST (modulo rounding tolerance).
        double best_cost = backward_costs_[ifst_->Start()],
            tolerance = 0.01 + 1.0e-04 * std::abs(best_cost);
        if (task->priority_cost < best_cost - tolerance) {
          KALDI_WARN << "Cost below best cost was encountered:"
                     << task->priority_cost << " < " << best_cost;
        }
      }
    }
  }
  all_elems.clear();  // as it's a reference to a class variable; we want it to stay
  // empty.
}
// Returns true if this state of the input FST either is final or has a
// non-epsilon input label on at least one (non-zero-weight) arc out of it.
// Uses the vector isymbol_or_final_ as a cache for this info.
bool IsIsymbolOrFinal(InputStateId state) {
  KALDI_ASSERT(state >= 0);
  // Grow the cache on demand; the cast avoids a signed/unsigned
  // comparison warning (state >= 0 was asserted above).
  if (isymbol_or_final_.size() <= static_cast<size_t>(state))
    isymbol_or_final_.resize(state+1, static_cast<char>(OSF_UNKNOWN));
  if (isymbol_or_final_[state] == static_cast<char>(OSF_NO))
    return false;
  else if (isymbol_or_final_[state] == static_cast<char>(OSF_YES))
    return true;
  // Not cached yet: work it out and store the answer.
  isymbol_or_final_[state] = static_cast<char>(OSF_NO);
  if (ifst_->Final(state) != Weight::Zero())
    isymbol_or_final_[state] = static_cast<char>(OSF_YES);
  for (ArcIterator<ExpandedFst<Arc> > aiter(*ifst_, state);
       !aiter.Done();
       aiter.Next()) {
    const Arc &arc = aiter.Value();
    if (arc.ilabel != 0 && arc.weight != Weight::Zero()) {
      isymbol_or_final_[state] = static_cast<char>(OSF_YES);
      return true;
    }
  }
  // Previously this returned via a (single-level) recursive call to
  // itself; just read the cached value we set above directly instead.
  return isymbol_or_final_[state] == static_cast<char>(OSF_YES);
}
// Sets up the backward_costs_ array (the minimal cost from each state of
// ifst_ to a final state) and the cutoff_ variable.  Only handles the
// topologically sorted case, which lets us fill the array in a single
// reverse sweep over the states.
void ComputeBackwardWeight() {
  KALDI_ASSERT(beam_ > 0);
  backward_costs_.resize(ifst_->NumStates());
  for (StateId s = ifst_->NumStates() - 1; s >= 0; s--) {
    // Best cost to a final state is the min over "stop here" (the
    // final-weight) and "take some arc then continue".
    double best = ConvertToCost(ifst_->Final(s));
    for (ArcIterator<ExpandedFst<Arc> > aiter(*ifst_, s);
         !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      double via_arc = ConvertToCost(arc.weight) +
          backward_costs_[arc.nextstate];
      if (via_arc < best) best = via_arc;
    }
    backward_costs_[s] = best;
  }
  if (ifst_->Start() == kNoStateId) return;  // we'll be returning
  // an empty FST.
  double best_cost = backward_costs_[ifst_->Start()];
  if (best_cost == std::numeric_limits<double>::infinity())
    KALDI_WARN << "Total weight of input lattice is zero.";
  cutoff_ = best_cost + beam_;
}
// One-time setup before Determinize(): computes the backward_costs_
// array and cutoff_, pre-sizes the hashes, and creates the determinized
// start state (queueing its outgoing transitions as Tasks).
void InitializeDeterminization() {
  // We insist that the input lattice be topologically sorted.  This is not a
  // fundamental limitation of the algorithm (which in principle should be
  // applicable to even cyclic FSTs), but it helps us more efficiently
  // compute the backward_costs_ array.  There may be some other reason we
  // require this, that escapes me at the moment.
  KALDI_ASSERT(ifst_->Properties(kTopSorted, true) != 0);
  ComputeBackwardWeight();
#if !(__GNUC__ == 4 && __GNUC_MINOR__ == 0)
  if(ifst_->Properties(kExpanded, false) != 0) {  // if we know the number of
    // states in ifst_, it might be a bit more efficient
    // to pre-size the hashes so we're not constantly rebuilding them.
    StateId num_states =
        down_cast<const ExpandedFst<Arc>*, const Fst<Arc> >(ifst_)->NumStates();
    minimal_hash_.rehash(num_states/2 + 3);
    initial_hash_.rehash(num_states/2 + 3);
  }
#endif
  InputStateId start_id = ifst_->Start();
  if (start_id != kNoStateId) {
    /* Create determinized-state corresponding to the start state....
       Unlike all the other states, we don't "normalize" the representation
       of this determinized-state before we put it into minimal_hash_.  This is
       actually what we want, as otherwise we'd have problems dealing with any
       extra weight and string and might have to create a "super-initial" state
       which would make the output nondeterministic.  Normalization is only
       needed to make the determinized output more minimal anyway, it's not
       needed for correctness.
       Note, we don't put anything in the initial_hash_.  The initial_hash_ is
       only a lookaside buffer anyway, so this isn't a problem-- it will get
       populated later if it needs to be.
    */
    vector<Element> subset(1);
    subset[0].state = start_id;
    subset[0].weight = Weight::One();
    subset[0].string = repository_.EmptyString();  // Id of empty sequence.
    EpsilonClosure(&subset);  // follow through epsilon-input links
    ConvertToMinimal(&subset);  // remove all but final states and
    // states with input-labels on arcs out of them.
    // Weight::One() is the "forward-weight" of this determinized state...
    // i.e. the minimal cost from the start of the determinized FST to this
    // state [One() because it's the start state].
    OutputState *initial_state = new OutputState(subset, 0);
    KALDI_ASSERT(output_states_.empty());
    output_states_.push_back(initial_state);
    num_elems_ += subset.size();
    OutputStateId initial_state_id = 0;
    minimal_hash_[&(initial_state->minimal_subset)] = initial_state_id;
    ProcessFinal(initial_state_id);
    ProcessTransitions(initial_state_id);  // this will add tasks to
    // the queue, which we'll start processing in Determinize().
  }
}
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeDeterminizerPruned);

// One determinized output state, together with the information still
// needed to finish processing it.
struct OutputState {
  vector<Element> minimal_subset;
  vector<TempArc> arcs;  // arcs out of the state-- those that have been processed.
  // Note: the final-weight is included here with kNoStateId as the state id.  We
  // always process the final-weight regardless of the beam; when producing the
  // output we may have to ignore some of these.
  double forward_cost;  // Represents minimal cost from start-state
  // to this state.  Used in prioritization of tasks, and pruning.
  // Note: we know this minimal cost from when we first create the OutputState;
  // this is because of the priority-queue we use, that ensures that the
  // "best" path into the state will be expanded first.
  OutputState(const vector<Element> &minimal_subset,
              double forward_cost): minimal_subset(minimal_subset),
                                    forward_cost(forward_cost) { }
};
vector<OutputState*> output_states_;  // All the info about the output states.
int num_arcs_;  // keep track of memory usage: number of arcs in output_states_[ ]->arcs
int num_elems_;  // keep track of memory usage: number of elems in output_states_ and
// the keys of initial_hash_
const ExpandedFst<Arc> *ifst_;  // the input FST being determinized.
std::vector<double> backward_costs_;  // This vector stores, for every state in ifst_,
// the minimal cost to the end-state (i.e. the sum of weights; they are guaranteed to
// have "take-the-minimum" semantics).  We get the double from the ConvertToCost()
// function on the lattice weights.
double beam_;  // the pruning beam; asserted > 0 in ComputeBackwardWeight().
double cutoff_;  // beam plus total-weight of input (and note, the weight is
// guaranteed to be "tropical-like" so the sum does represent a min-cost.
DeterminizeLatticePrunedOptions opts_;
SubsetKey hasher_;  // object that computes keys-- has no data members.
SubsetEqual equal_;  // object that compares subsets-- only data member is delta_.
bool determinized_;  // set to true when user called Determinize(); used to make
// sure this object is used correctly.
MinimalSubsetHash minimal_hash_;  // hash from Subset to OutputStateId.  Subset is the
// "minimal representation" (only includes final states, and states with a
// nonzero ilabel on some arc out of them).  Owns the pointers
// in its keys.
InitialSubsetHash initial_hash_;  // hash from Subset to Element, which
// represents the OutputStateId together
// with an extra weight and string.  Subset
// is "initial representation".  The extra
// weight and string is needed because after
// we convert to minimal representation and
// normalize, there may be an extra weight
// and string.  Owns the pointers
// in its keys.
// A Task represents a pending call to ProcessTransition: one
// (state, label) pair together with the weighted subset of next-states
// that the corresponding arcs lead to.
struct Task {
  OutputStateId state;  // State from which we're processing the transition.
  Label label;  // Label on the transition we're processing out of this state.
  vector<Element> subset;  // Weighted subset of states (with strings)-- not normalized.
  double priority_cost;  // Cost used in deciding priority of tasks.  Note:
  // we assume there is a ConvertToCost() function that converts the semiring to double.
};
// Comparison object for Task pointers, used by the priority queue.
// View this like operator <, which is the default template parameter
// to std::priority_queue: returns true if t1 is worse (higher cost)
// than t2, so the queue pops the lowest-cost Task first.
// (The call operator used to be declared "int" and non-const; a
// strict-weak-ordering comparator should return bool and be const.)
struct TaskCompare {
  inline bool operator() (const Task *t1, const Task *t2) const {
    return (t1->priority_cost > t2->priority_cost);
  }
};
// This priority queue contains "Task"s to be processed; these correspond
// to transitions out of determinized states.  We process these in priority
// order according to the best weight of any path passing through these
// determinized states... it's possible to work this out.
std::priority_queue<Task*, vector<Task*>, TaskCompare> queue_;
vector<pair<Label, Element> > all_elems_tmp_;  // temporary vector used in ProcessTransitions.
// Cache states used by IsIsymbolOrFinal(), stored as chars in isymbol_or_final_.
enum IsymbolOrFinal { OSF_UNKNOWN = 0, OSF_NO = 1, OSF_YES = 2 };
vector<char> isymbol_or_final_;  // A kind of cache; it says whether
// each state is (emitting or final) where emitting means it has at least one
// non-epsilon output arc.  Only accessed by IsIsymbolOrFinal()
LatticeStringRepository<IntType> repository_;  // defines a compact and fast way of
// storing sequences of labels.
// Appends the string member of every Element in "vec" onto
// *needed_strings.
void AddStrings(const vector<Element> &vec,
                vector<StringId> *needed_strings) {
  for (size_t i = 0; i < vec.size(); i++)
    needed_strings->push_back(vec[i].string);
}
};
// normally Weight would be LatticeWeight<float> (which has two floats),
// or possibly TropicalWeightTpl<float>, and IntType would be int32.
// Caution: there are two versions of the function DeterminizeLatticePruned,
// with identical code but different output FST types.
template<class Weight, class IntType>
bool DeterminizeLatticePruned(
const ExpandedFst<ArcTpl<Weight> >&ifst,
double beam,
MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > >*ofst,
DeterminizeLatticePrunedOptions opts) {
ofst->SetInputSymbols(ifst.InputSymbols());
ofst->SetOutputSymbols(ifst.OutputSymbols());
if (ifst.NumStates() == 0) {
ofst->DeleteStates();
return true;
}
KALDI_ASSERT(opts.retry_cutoff >= 0.0 && opts.retry_cutoff < 1.0);
int32 max_num_iters = 10; // avoid the potential for infinite loops if
// retrying.
VectorFst<ArcTpl<Weight> > temp_fst;
for (int32 iter = 0; iter < max_num_iters; iter++) {
LatticeDeterminizerPruned<Weight, IntType> det(iter == 0 ? ifst : temp_fst,
beam, opts);
double effective_beam;
bool ans = det.Determinize(&effective_beam);
// if it returns false it will typically still produce reasonable output,
// just with a narrower beam than "beam". If the user specifies an infinite
// beam we don't do this beam-narrowing.
if (effective_beam >= beam * opts.retry_cutoff ||
beam == std::numeric_limits<double>::infinity() ||
iter + 1 == max_num_iters) {
det.Output(ofst);
return ans;
} else {
// The code below to set "beam" is a heuristic.
// If effective_beam is very small, we want to reduce by a lot.
// But never change the beam by more than a factor of two.
if (effective_beam < 0.0) effective_beam = 0.0;
double new_beam = beam * sqrt(effective_beam / beam);
if (new_beam < 0.5 * beam) new_beam = 0.5 * beam;
beam = new_beam;
if (iter == 0) temp_fst = ifst;
kaldi::PruneLattice(beam, &temp_fst);
KALDI_LOG << "Pruned state-level lattice with beam " << beam
<< " and retrying determinization with that beam.";
}
}
return false; // Suppress compiler warning; this code is unreachable.
}
// normally Weight would be LatticeWeight<float> (which has two floats),
// or possibly TropicalWeightTpl<float>, and IntType would be int32.
// Caution: there are two versions of the function DeterminizeLatticePruned,
// with identical code but different output FST types.
template<class Weight>
bool DeterminizeLatticePruned(const ExpandedFst<ArcTpl<Weight> > &ifst,
                              double beam,
                              MutableFst<ArcTpl<Weight> > *ofst,
                              DeterminizeLatticePrunedOptions opts) {
  typedef int32 IntType;
  ofst->SetInputSymbols(ifst.InputSymbols());
  ofst->SetOutputSymbols(ifst.OutputSymbols());
  KALDI_ASSERT(opts.retry_cutoff >= 0.0 && opts.retry_cutoff < 1.0);
  if (ifst.NumStates() == 0) {  // empty input: produce an empty output.
    ofst->DeleteStates();
    return true;
  }
  const int32 max_num_iters = 10;  // bound the retries, to avoid any
                                   // possibility of an infinite loop.
  VectorFst<ArcTpl<Weight> > temp_fst;  // pruned copy of the input, used on
                                        // retry iterations.
  for (int32 iter = 0; iter < max_num_iters; iter++) {
    LatticeDeterminizerPruned<Weight, IntType> det(iter == 0 ? ifst : temp_fst,
                                                   beam, opts);
    double effective_beam;
    bool ans = det.Determinize(&effective_beam);
    // Even when Determinize() returns false it will typically still produce
    // reasonable output, just with a narrower beam than "beam".
    if (effective_beam >= beam * opts.retry_cutoff ||
        iter + 1 == max_num_iters) {
      det.Output(ofst);
      return ans;
    }
    // The code below to set "beam" is a heuristic: if effective_beam is very
    // small we want to reduce by a lot, but we never change the beam by more
    // than a factor of two.
    if (effective_beam < 0)
      effective_beam = 0;
    double new_beam = beam * sqrt(effective_beam / beam);
    if (new_beam < 0.5 * beam) new_beam = 0.5 * beam;
    KALDI_WARN << "Effective beam " << effective_beam << " was less than beam "
               << beam << " * cutoff " << opts.retry_cutoff << ", pruning raw "
               << "lattice with new beam " << new_beam << " and retrying.";
    beam = new_beam;
    if (iter == 0) temp_fst = ifst;  // make the prunable copy on first retry.
    kaldi::PruneLattice(beam, &temp_fst);
  }
  return false;  // Suppress compiler warning; this code is unreachable.
}
// template<class Weight>
// typename ArcTpl<Weight>::Label DeterminizeLatticeInsertPhones(
// const kaldi::TransitionModel &trans_model,
// MutableFst<ArcTpl<Weight> > *fst) {
// // Define some types.
// typedef ArcTpl<Weight> Arc;
// typedef typename Arc::StateId StateId;
// typedef typename Arc::Label Label;
//
// // Work out the first phone symbol. This is more related to the phone
// // insertion function, so we put it here and make it the returning value of
// // DeterminizeLatticeInsertPhones().
// Label first_phone_label = HighestNumberedInputSymbol(*fst) + 1;
//
// // Insert phones here.
// for (StateIterator<MutableFst<Arc> > siter(*fst);
// !siter.Done(); siter.Next()) {
// StateId state = siter.Value();
// if (state == fst->Start())
// continue;
// for (MutableArcIterator<MutableFst<Arc> > aiter(fst, state);
// !aiter.Done(); aiter.Next()) {
// Arc arc = aiter.Value();
//
// // Note: the words are on the input symbol side and transition-id's are on
// // the output symbol side.
// if ((arc.olabel != 0)
// && (trans_model.TransitionIdToHmmState(arc.olabel) == 0)
// && (!trans_model.IsSelfLoop(arc.olabel))) {
// Label phone =
// static_cast<Label>(trans_model.TransitionIdToPhone(arc.olabel));
//
// // Skips <eps>.
// KALDI_ASSERT(phone != 0);
//
// if (arc.ilabel == 0) {
// // If there is no word on the arc, insert the phone directly.
// arc.ilabel = first_phone_label + phone;
// } else {
// // Otherwise, add an additional arc.
// StateId additional_state = fst->AddState();
// StateId next_state = arc.nextstate;
// arc.nextstate = additional_state;
// fst->AddArc(additional_state,
// Arc(first_phone_label + phone, 0,
// Weight::One(), next_state));
// }
// }
//
// aiter.SetValue(arc);
// }
// }
//
// return first_phone_label;
// }
//
// template<class Weight>
// void DeterminizeLatticeDeletePhones(
// typename ArcTpl<Weight>::Label first_phone_label,
// MutableFst<ArcTpl<Weight> > *fst) {
// // Define some types.
// typedef ArcTpl<Weight> Arc;
// typedef typename Arc::StateId StateId;
// typedef typename Arc::Label Label;
//
// // Delete phones here.
// for (StateIterator<MutableFst<Arc> > siter(*fst);
// !siter.Done(); siter.Next()) {
// StateId state = siter.Value();
// for (MutableArcIterator<MutableFst<Arc> > aiter(fst, state);
// !aiter.Done(); aiter.Next()) {
// Arc arc = aiter.Value();
//
// if (arc.ilabel >= first_phone_label)
// arc.ilabel = 0;
//
// aiter.SetValue(arc);
// }
// }
// }
// instantiate for type LatticeWeight
// template
// void DeterminizeLatticeDeletePhones(
// ArcTpl<kaldi::LatticeWeight>::Label first_phone_label,
// MutableFst<ArcTpl<kaldi::LatticeWeight> > *fst);
//
// /** This function does a first pass determinization with phone symbols inserted
// at phone boundary. It uses a transition model to work out the transition-id
// to phone map. First, phones will be inserted into the word level lattice.
// Second, determinization will be applied on top of the phone + word lattice.
// Finally, the inserted phones will be removed, converting the lattice back to
// a word level lattice. The output lattice of this pass is not deterministic,
// since we remove the phone symbols as a last step. It is supposed to be
// followed by another pass of determinization at the word level. It could also
// be useful for some other applications such as fMLLR estimation, confidence
// estimation, discriminative training, etc.
// */
// template<class Weight, class IntType>
// bool DeterminizeLatticePhonePrunedFirstPass(
// const kaldi::TransitionModel &trans_model,
// double beam,
// MutableFst<ArcTpl<Weight> > *fst,
// const DeterminizeLatticePrunedOptions &opts) {
// // First, insert the phones.
// typename ArcTpl<Weight>::Label first_phone_label =
// DeterminizeLatticeInsertPhones(trans_model, fst);
// TopSort(fst);
//
// // Second, do determinization with phone inserted.
// bool ans = DeterminizeLatticePruned<Weight>(*fst, beam, fst, opts);
//
// // Finally, remove the inserted phones.
// DeterminizeLatticeDeletePhones(first_phone_label, fst);
// TopSort(fst);
//
// return ans;
// }
//
// // "Destructive" version of DeterminizeLatticePhonePruned() where the input
// // lattice might be modified.
// template<class Weight, class IntType>
// bool DeterminizeLatticePhonePruned(
// const kaldi::TransitionModel &trans_model,
// MutableFst<ArcTpl<Weight> > *ifst,
// double beam,
// MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
// DeterminizeLatticePhonePrunedOptions opts) {
// // Returning status.
// bool ans = true;
//
// // Make sure at least one of opts.phone_determinize and opts.word_determinize
// // is not false, otherwise calling this function doesn't make any sense.
// if ((opts.phone_determinize || opts.word_determinize) == false) {
// KALDI_WARN << "Both --phone-determinize and --word-determinize are set to "
// << "false, copying lattice without determinization.";
// // We are expecting the words on the input side.
// ConvertLattice<Weight, IntType>(*ifst, ofst, false);
// return ans;
// }
//
// // Determinization options.
// DeterminizeLatticePrunedOptions det_opts;
// det_opts.delta = opts.delta;
// det_opts.max_mem = opts.max_mem;
//
// // If --phone-determinize is true, do the determinization on phone + word
// // lattices.
// if (opts.phone_determinize) {
// KALDI_VLOG(3) << "Doing first pass of determinization on phone + word "
// << "lattices.";
// ans = DeterminizeLatticePhonePrunedFirstPass<Weight, IntType>(
// trans_model, beam, ifst, det_opts) && ans;
//
// // If --word-determinize is false, we've finished the job and return here.
// if (!opts.word_determinize) {
// // We are expecting the words on the input side.
// ConvertLattice<Weight, IntType>(*ifst, ofst, false);
// return ans;
// }
// }
//
// // If --word-determinize is true, do the determinization on word lattices.
// if (opts.word_determinize) {
// KALDI_VLOG(3) << "Doing second pass of determinization on word lattices.";
// ans = DeterminizeLatticePruned<Weight, IntType>(
// *ifst, beam, ofst, det_opts) && ans;
// }
//
// // If --minimize is true, push and minimize after determinization.
// if (opts.minimize) {
// KALDI_VLOG(3) << "Pushing and minimizing on word lattices.";
// ans = PushCompactLatticeStrings<Weight, IntType>(ofst) && ans;
// ans = PushCompactLatticeWeights<Weight, IntType>(ofst) && ans;
// ans = MinimizeCompactLattice<Weight, IntType>(ofst) && ans;
// }
//
// return ans;
// }
//
// // Normal version of DeterminizeLatticePhonePruned(), where the input lattice
// // will be kept as unchanged.
// template<class Weight, class IntType>
// bool DeterminizeLatticePhonePruned(
// const kaldi::TransitionModel &trans_model,
// const ExpandedFst<ArcTpl<Weight> > &ifst,
// double beam,
// MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
// DeterminizeLatticePhonePrunedOptions opts) {
// VectorFst<ArcTpl<Weight> > temp_fst(ifst);
// return DeterminizeLatticePhonePruned(trans_model, &temp_fst,
// beam, ofst, opts);
// }
//
// bool DeterminizeLatticePhonePrunedWrapper(
// const kaldi::TransitionModel &trans_model,
// MutableFst<kaldi::LatticeArc> *ifst,
// double beam,
// MutableFst<kaldi::CompactLatticeArc> *ofst,
// DeterminizeLatticePhonePrunedOptions opts) {
// bool ans = true;
// Invert(ifst);
// if (ifst->Properties(fst::kTopSorted, true) == 0) {
// if (!TopSort(ifst)) {
// // Cannot topologically sort the lattice -- determinization will fail.
// KALDI_ERR << "Topological sorting of state-level lattice failed (probably"
// << " your lexicon has empty words or your LM has epsilon cycles"
// << ").";
// }
// }
// ILabelCompare<kaldi::LatticeArc> ilabel_comp;
// ArcSort(ifst, ilabel_comp);
// ans = DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>(
// trans_model, ifst, beam, ofst, opts);
// Connect(ofst);
// return ans;
// }
// Instantiate the templates for the types we might need.
// Note: the original Kaldi source instantiates four templates here (one per
// output FST type); in this copy the two phone-pruned instantiations are
// commented out below, leaving the two DeterminizeLatticePruned
// instantiations that follow.
template
bool DeterminizeLatticePruned<kaldi::LatticeWeight>(
    const ExpandedFst<kaldi::LatticeArc> &ifst,
    double prune,
    MutableFst<kaldi::CompactLatticeArc> *ofst,
    DeterminizeLatticePrunedOptions opts);
template
bool DeterminizeLatticePruned<kaldi::LatticeWeight>(
    const ExpandedFst<kaldi::LatticeArc> &ifst,
    double prune,
    MutableFst<kaldi::LatticeArc> *ofst,
    DeterminizeLatticePrunedOptions opts);
// template
// bool DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>(
// const kaldi::TransitionModel &trans_model,
// const ExpandedFst<kaldi::LatticeArc> &ifst,
// double prune,
// MutableFst<kaldi::CompactLatticeArc> *ofst,
// DeterminizeLatticePhonePrunedOptions opts);
//
// template
// bool DeterminizeLatticePhonePruned<kaldi::LatticeWeight, kaldi::int32>(
// const kaldi::TransitionModel &trans_model,
// MutableFst<kaldi::LatticeArc> *ifst,
// double prune,
// MutableFst<kaldi::CompactLatticeArc> *ofst,
// DeterminizeLatticePhonePrunedOptions opts);
}
// lat/determinize-lattice-pruned.h
// Copyright 2009-2012 Microsoft Corporation
// 2012-2013 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_DETERMINIZE_LATTICE_PRUNED_H_
#define KALDI_LAT_DETERMINIZE_LATTICE_PRUNED_H_
#include <fst/fstlib.h>
#include <fst/fst-decl.h>
#include <algorithm>
#include <map>
#include <set>
#include <vector>
#include "fstext/lattice-weight.h"
// #include "hmm/transition-model.h"
#include "itf/options-itf.h"
#include "lat/kaldi-lattice.h"
namespace fst {
/// \addtogroup fst_extensions
/// @{
// For example of usage, see test-determinize-lattice-pruned.cc
/*
DeterminizeLatticePruned implements a special form of determinization with
epsilon removal, optimized for a phase of lattice generation. This algorithm
also does pruning at the same time-- the combination is more efficient as it
sometimes prevents us from creating a lot of states that would later be pruned
away. This allows us to increase the lattice-beam and not have the algorithm
blow up. Also, because our algorithm processes states in order from those
that appear on high-scoring paths down to those that appear on low-scoring
paths, we can easily terminate the algorithm after a certain specified number
of states or arcs.
The input is an FST with weight-type BaseWeightType (usually a pair of floats,
with a lexicographical type of order, such as LatticeWeightTpl<float>).
Typically this would be a state-level lattice, with input symbols equal to
words, and output-symbols equal to p.d.f's (so like the inverse of HCLG). Imagine representing this as an
acceptor of type CompactLatticeWeightTpl<float>, in which the input/output
symbols are words, and the weights contain the original weights together with
strings (with zero or one symbol in them) containing the original output labels
(the p.d.f.'s). We determinize this using acceptor determinization with
epsilon removal. Remember (from lattice-weight.h) that
CompactLatticeWeightTpl has a special kind of semiring where we always take
the string corresponding to the best cost (of type BaseWeightType), and
discard the other. This corresponds to taking the best output-label sequence
(of p.d.f.'s) for each input-label sequence (of words). We couldn't use the
Gallic weight for this, or it would die as soon as it detected that the input
FST was non-functional. In our case, any acyclic FST (and many cyclic ones)
can be determinized.
We assume that there is a function
Compare(const BaseWeightType &a, const BaseWeightType &b)
that returns (-1, 0, 1) according to whether (a < b, a == b, a > b) in the
total order on the BaseWeightType... this information should be the
same as NaturalLess would give, but it's more efficient to do it this way.
You can define this for things like TropicalWeight if you need to instantiate
this class for that weight type.
We implement this determinization in a special way to make it efficient for
the types of FSTs that we will apply it to. One issue is that if we
explicitly represent the strings (in CompactLatticeWeightTpl) as vectors of
type vector<IntType>, the algorithm takes time quadratic in the length of
words (in states), because propagating each arc involves copying a whole
vector (of integers representing p.d.f.'s). Instead we use a hash structure
where each string is a pointer (Entry*), and uses a hash from (Entry*,
IntType), to the successor string (and a way to get the latest IntType and the
ancestor Entry*). [this is the class LatticeStringRepository].
Another issue is that rather than representing a determinized-state as a
collection of (state, weight), we represent it in a couple of reduced forms.
Suppose a determinized-state is a collection of (state, weight) pairs; call
this the "canonical representation". Note: these collections are always
normalized to remove any common weight and string part. Define end-states as
the subset of states that have an arc out of them with a label on, or are
final. If we represent a determinized-state a the set of just its (end-state,
weight) pairs, this will be a valid and more compact representation, and will
lead to a smaller set of determinized states (like early minimization). Call
this collection of (end-state, weight) pairs the "minimal representation". As
a mechanism to reduce compute, we can also consider another representation.
In the determinization algorithm, we start off with a set of (begin-state,
weight) pairs (where the "begin-states" are initial or have a label on the
transition into them), and the "canonical representation" consists of the
epsilon-closure of this set (i.e. follow epsilons). Call this set of
(begin-state, weight) pairs, appropriately normalized, the "initial
representation". If two initial representations are the same, the "canonical
representation" and hence the "minimal representation" will be the same. We
can use this to reduce compute. Note that if two initial representations are
different, this does not preclude the other representations from being the same.
*/
// Options for DeterminizeLatticePruned(); all members can be registered
// as command-line options via Register().
struct DeterminizeLatticePrunedOptions {
  float delta; // A small offset used to measure equality of weights.
  int max_mem; // If >0, determinization will fail and return false
  // when the algorithm's (approximate) memory consumption crosses this
  // threshold.
  int max_loop; // If >0, can be used to detect non-determinizable input
  // (a case that wouldn't be caught by max_mem).
  int max_states;     // If >0, maximum number of states in the output FST.
  int max_arcs;       // If >0, maximum number of arcs in the output FST.
  float retry_cutoff; // If effective-beam < retry_cutoff * beam, prune the
                      // raw lattice and retry determinization.
  DeterminizeLatticePrunedOptions(): delta(kDelta),
                                     max_mem(-1),
                                     max_loop(-1),
                                     max_states(-1),
                                     max_arcs(-1),
                                     retry_cutoff(0.5) { }
  void Register(kaldi::OptionsItf *opts) {
    opts->Register("delta", &delta, "Tolerance used in determinization");
    opts->Register("max-mem", &max_mem, "Maximum approximate memory usage in "
                   "determinization (real usage might be many times this)");
    // Note: the help strings below previously said "arcs" for both options
    // and were missing a closing parenthesis; fixed.
    opts->Register("max-arcs", &max_arcs, "Maximum number of arcs in "
                   "output FST (total, not per state)");
    opts->Register("max-states", &max_states, "Maximum number of states in "
                   "output FST (total, not per state)");
    opts->Register("max-loop", &max_loop, "Option used to detect a particular "
                   "type of determinization failure, typically due to invalid input "
                   "(e.g., negative-cost loops)");
    opts->Register("retry-cutoff", &retry_cutoff, "Controls pruning un-determinized "
                   "lattice and retrying determinization: if effective-beam < "
                   "retry-cutoff * beam, we prune the raw lattice and retry. Avoids "
                   "ever getting empty output for long segments.");
  }
};
// Options controlling the two-pass (phone then word) pruned determinization.
struct DeterminizeLatticePhonePrunedOptions {
  float delta;            // Small tolerance used when comparing weights.
  int max_mem;            // If > 0, determinization fails (returns false) once
                          // the algorithm's approximate memory use crosses this.
  bool phone_determinize; // Run the first pass, on both phones and words.
  bool word_determinize;  // Run the second pass, on words only.
  bool minimize;          // Push and minimize after determinization.

  DeterminizeLatticePhonePrunedOptions()
      : delta(kDelta), max_mem(50000000), phone_determinize(true),
        word_determinize(true), minimize(false) {}

  void Register(kaldi::OptionsItf *opts) {
    opts->Register("delta", &delta, "Tolerance used in determinization");
    opts->Register("max-mem", &max_mem,
                   "Maximum approximate memory usage in determinization "
                   "(real usage might be many times this).");
    opts->Register("phone-determinize", &phone_determinize,
                   "If true, do an initial pass of determinization on both "
                   "phones and words (see also --word-determinize)");
    opts->Register("word-determinize", &word_determinize,
                   "If true, do a second pass of determinization on words "
                   "only (see also --phone-determinize)");
    opts->Register("minimize", &minimize,
                   "If true, push and minimize after determinization.");
  }
};
/**
This function implements the normal version of DeterminizeLattice, in which the
output strings are represented using sequences of arcs, where all but the
first one has an epsilon on the input side. It also prunes using the beam
in the "prune" parameter. The input FST must be topologically sorted in order
for the algorithm to work. For efficiency it is recommended to sort ilabel as well.
Returns true on success, and false if it had to terminate the determinization
earlier than specified by the "prune" beam-- that is, if it terminated because
of the max_mem, max_loop or max_arcs constraints in the options.
CAUTION: you may want to use the version below which outputs to CompactLattice.
*/
// Declaration of the Lattice-to-Lattice version of pruned determinization;
// see the comment block above for the full contract.
template<class Weight>
bool DeterminizeLatticePruned(
    const ExpandedFst<ArcTpl<Weight> > &ifst,
    double prune,
    MutableFst<ArcTpl<Weight> > *ofst,
    DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions());
/* This is a version of DeterminizeLattice with a slightly more "natural" output format,
where the output sequences are encoded using the CompactLatticeArcTpl template
(i.e. the sequences of output symbols are represented directly as strings). The input
FST must be topologically sorted in order for the algorithm to work. For efficiency
it is recommended to sort the ilabel for the input FST as well.
Returns true on normal success, and false if it had to terminate the determinization
earlier than specified by the "prune" beam-- that is, if it terminated because
of the max_mem, max_loop or max_arcs constraints in the options.
CAUTION: if Lattice is the input, you need to Invert() before calling this,
so words are on the input side.
*/
// Declaration of the CompactLattice-output version; see the comment block
// above for the full contract (output-symbol sequences go into the weights).
template<class Weight, class IntType>
bool DeterminizeLatticePruned(
    const ExpandedFst<ArcTpl<Weight> > &ifst,
    double prune,
    MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
    DeterminizeLatticePrunedOptions opts = DeterminizeLatticePrunedOptions());
// /** This function takes in lattices and inserts phones at phone boundaries. It
// uses the transition model to work out the transition_id to phone map. The
// returning value is the starting index of the phone label. Typically we pick
// (maximum_output_label_index + 1) as this value. The inserted phones are then
// mapped to (returning_value + original_phone_label) in the new lattice. The
// returning value will be used by DeterminizeLatticeDeletePhones() where it
// works out the phones according to this value.
// */
// template<class Weight>
// typename ArcTpl<Weight>::Label DeterminizeLatticeInsertPhones(
// const kaldi::TransitionModel &trans_model,
// MutableFst<ArcTpl<Weight> > *fst);
//
// /** This function takes in lattices and deletes "phones" from them. The "phones"
// here are actually any label that is larger than first_phone_label because
// when we insert phones into the lattice, we map the original phone label to
// (first_phone_label + original_phone_label). It is supposed to be used
// together with DeterminizeLatticeInsertPhones()
// */
// template<class Weight>
// void DeterminizeLatticeDeletePhones(
// typename ArcTpl<Weight>::Label first_phone_label,
// MutableFst<ArcTpl<Weight> > *fst);
//
// /** This function is a wrapper of DeterminizeLatticePhonePrunedFirstPass() and
// DeterminizeLatticePruned(). If --phone-determinize is set to true, it first
// calls DeterminizeLatticePhonePrunedFirstPass() to do the initial pass of
// determinization on the phone + word lattices. If --word-determinize is set
// true, it then does a second pass of determinization on the word lattices by
// calling DeterminizeLatticePruned(). If both are set to false, then it gives
// a warning and copies the lattices without determinization.
//
// Note: the point of doing first a phone-level determinization pass and then
// a word-level determinization pass is that it allows us to determinize
// deeper lattices without "failing early" and returning a too-small lattice
// due to the max-mem constraint. The result should be the same as word-level
// determinization in general, but for deeper lattices it is a bit faster,
// despite the fact that we now have two passes of determinization by default.
// */
// template<class Weight, class IntType>
// bool DeterminizeLatticePhonePruned(
// const kaldi::TransitionModel &trans_model,
// const ExpandedFst<ArcTpl<Weight> > &ifst,
// double prune,
// MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
// DeterminizeLatticePhonePrunedOptions opts
// = DeterminizeLatticePhonePrunedOptions());
//
// /** "Destructive" version of DeterminizeLatticePhonePruned() where the input
// lattice might be changed.
// */
// template<class Weight, class IntType>
// bool DeterminizeLatticePhonePruned(
// const kaldi::TransitionModel &trans_model,
// MutableFst<ArcTpl<Weight> > *ifst,
// double prune,
// MutableFst<ArcTpl<CompactLatticeWeightTpl<Weight, IntType> > > *ofst,
// DeterminizeLatticePhonePrunedOptions opts
// = DeterminizeLatticePhonePrunedOptions());
//
// /** This function is a wrapper of DeterminizeLatticePhonePruned() that works for
// Lattice type FSTs. It simplifies the calling process by calling
// TopSort() Invert() and ArcSort() for you.
// Unlike other determinization routines, the function
// requires "ifst" to have transition-id's on the input side and words on the
// output side.
// This function can be used as the top-level interface to all the determinization
// code.
// */
// bool DeterminizeLatticePhonePrunedWrapper(
// const kaldi::TransitionModel &trans_model,
// MutableFst<kaldi::LatticeArc> *ifst,
// double prune,
// MutableFst<kaldi::CompactLatticeArc> *ofst,
// DeterminizeLatticePhonePrunedOptions opts
// = DeterminizeLatticePhonePrunedOptions());
/// @} end "addtogroup fst_extensions"
} // end namespace fst
#endif
// lat/kaldi-lattice.cc
// Copyright 2009-2011 Microsoft Corporation
// 2013 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "lat/kaldi-lattice.h"
#include "fst/script/print-impl.h"
namespace kaldi {
/// Converts the given FST to CompactLattice type if necessary, taking
/// ownership of (and deleting) the input.  Returns NULL if given NULL.
template<class OrigWeightType>
CompactLattice* ConvertToCompactLattice(fst::VectorFst<OrigWeightType> *ifst) {
  if (ifst == NULL)
    return NULL;
  CompactLattice *converted = new CompactLattice();
  ConvertLattice(*ifst, converted);
  delete ifst;  // we own the input; the caller gets only the output.
  return converted;
}
// This overrides the template if there is no type conversion going on
// (for efficiency): the input is already a CompactLattice, so ownership
// is passed straight through without any copy.
template<>
CompactLattice* ConvertToCompactLattice(CompactLattice *ifst) {
  return ifst;
}
/// Converts the given FST to Lattice type if necessary, taking ownership
/// of (and deleting) the input.  Returns NULL if given NULL.
template<class OrigWeightType>
Lattice* ConvertToLattice(fst::VectorFst<OrigWeightType> *ifst) {
  if (ifst == NULL)
    return NULL;
  Lattice *converted = new Lattice();
  ConvertLattice(*ifst, converted);
  delete ifst;  // we own the input; the caller gets only the output.
  return converted;
}
// This overrides the template if there is no type conversion going on
// (for efficiency): the input is already a Lattice, so ownership is
// passed straight through without any copy.
template<>
Lattice* ConvertToLattice(Lattice *ifst) {
  return ifst;
}
// Writes a CompactLattice in binary or text form.  Returns false on
// stream failure rather than throwing.
bool WriteCompactLattice(std::ostream &os, bool binary,
                         const CompactLattice &t) {
  if (binary) {
    // Default write options.  These lattices normally carry no
    // isymbols/osymbols, so there is nothing to suppress.
    fst::FstWriteOptions opts;
    return t.Write(os, opts);
  }
  // Text-mode output.  We expect t.InputSymbols() and t.OutputSymbols()
  // to be NULL; the corresponding input routine would not work if the
  // FST actually had symbols attached.
  // A leading newline puts the first FST line on its own line after the key.
  os << '\n';
  const bool acceptor = true, write_one = false;
  fst::FstPrinter<CompactLatticeArc> printer(t, t.InputSymbols(),
                                             t.OutputSymbols(),
                                             NULL, acceptor, write_one, "\t");
  printer.Print(&os, "<unknown>");
  if (os.fail())
    KALDI_WARN << "Stream failure detected.";
  // A trailing newline acts as a terminating character that the read
  // routine detects [a Kaldi mechanism, not something in original OpenFst].
  os << '\n';
  return os.good();
}
/// LatticeReader provides (static) functions for reading both Lattice
/// and CompactLattice, in text form.
class LatticeReader {
  typedef LatticeArc Arc;
  typedef LatticeWeight Weight;
  typedef CompactLatticeArc CArc;
  typedef CompactLatticeWeight CWeight;
  typedef Arc::Label Label;
  typedef Arc::StateId StateId;
 public:
  // everything is static in this class.
  /** This function reads from the FST text format; it does not know in advance
      whether it's a Lattice or CompactLattice in the stream so it tries to
      read both formats until it becomes clear which is the correct one.
      At most one element of the returned pair is non-NULL; both are NULL on
      error.  The caller takes ownership of whatever is returned.
  */
  static std::pair<Lattice*, CompactLattice*> ReadText(
      std::istream &is) {
    typedef std::pair<Lattice*, CompactLattice*> PairT;
    using std::string;
    using std::vector;
    // Parse optimistically as both types at once; a pointer is freed and
    // set to NULL as soon as a line rules its format out.
    Lattice *fst = new Lattice();
    CompactLattice *cfst = new CompactLattice();
    string line;
    size_t nline = 0;
    string separator = FLAGS_fst_field_separator + "\r\n";
    while (std::getline(is, line)) {
      nline++;
      vector<string> col;
      // on Windows we'll write in text and read in binary mode.
      SplitStringToVector(line, separator.c_str(), true, &col);
      if (col.size() == 0) break; // Empty line is a signal to stop, in our
      // archive format.
      if (col.size() > 5) {
        KALDI_WARN << "Reading lattice: bad line in FST: " << line;
        delete fst;
        delete cfst;
        return PairT(static_cast<Lattice*>(NULL),
                     static_cast<CompactLattice*>(NULL));
      }
      StateId s;
      if (!ConvertStringToInteger(col[0], &s)) {
        KALDI_WARN << "FstCompiler: bad line in FST: " << line;
        delete fst;
        delete cfst;
        return PairT(static_cast<Lattice*>(NULL),
                     static_cast<CompactLattice*>(NULL));
      }
      // Ensure the source state exists in whichever FSTs we still track.
      if (fst)
        while (s >= fst->NumStates())
          fst->AddState();
      if (cfst)
        while (s >= cfst->NumStates())
          cfst->AddState();
      if (nline == 1) {
        // By convention, the first state listed is the start state.
        if (fst) fst->SetStart(s);
        if (cfst) cfst->SetStart(s);
      }
      if (fst) { // we still have fst; try to read that arc.
        bool ok = true;
        Arc arc;
        Weight w;
        StateId d = s;
        // Lattice text lines: 1 col = final (weight One); 2 cols = final
        // with weight; 4 cols = arc with weight One; 5 cols = arc+weight.
        switch (col.size()) {
          case 1 :
            fst->SetFinal(s, Weight::One());
            break;
          case 2:
            if (!StrToWeight(col[1], true, &w)) ok = false;
            else fst->SetFinal(s, w);
            break;
          case 3: // 3 columns not ok for Lattice format; it's not an acceptor.
            ok = false;
            break;
          case 4:
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                ConvertStringToInteger(col[2], &arc.ilabel) &&
                ConvertStringToInteger(col[3], &arc.olabel);
            if (ok) {
              d = arc.nextstate;
              arc.weight = Weight::One();
              fst->AddArc(s, arc);
            }
            break;
          case 5:
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                ConvertStringToInteger(col[2], &arc.ilabel) &&
                ConvertStringToInteger(col[3], &arc.olabel) &&
                StrToWeight(col[4], false, &arc.weight);
            if (ok) {
              d = arc.nextstate;
              fst->AddArc(s, arc);
            }
            break;
          default:
            ok = false;
        }
        // Ensure the destination state exists.
        while (d >= fst->NumStates())
          fst->AddState();
        if (!ok) {
          // This line is not valid Lattice format; abandon that reading.
          delete fst;
          fst = NULL;
        }
      }
      if (cfst) {
        bool ok = true;
        CArc arc;
        CWeight w;
        StateId d = s;
        // CompactLattice is an acceptor, so arcs need only one label.
        switch (col.size()) {
          case 1 :
            cfst->SetFinal(s, CWeight::One());
            break;
          case 2:
            if (!StrToCWeight(col[1], true, &w)) ok = false;
            else cfst->SetFinal(s, w);
            break;
          case 3: // compact-lattice is acceptor format: state, next-state, label.
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                ConvertStringToInteger(col[2], &arc.ilabel);
            if (ok) {
              d = arc.nextstate;
              arc.olabel = arc.ilabel;
              arc.weight = CWeight::One();
              cfst->AddArc(s, arc);
            }
            break;
          case 4:
            ok = ConvertStringToInteger(col[1], &arc.nextstate) &&
                ConvertStringToInteger(col[2], &arc.ilabel) &&
                StrToCWeight(col[3], false, &arc.weight);
            if (ok) {
              d = arc.nextstate;
              arc.olabel = arc.ilabel;
              cfst->AddArc(s, arc);
            }
            break;
          case 5: default:
            ok = false;
        }
        // Ensure the destination state exists.
        while (d >= cfst->NumStates())
          cfst->AddState();
        if (!ok) {
          // This line is not valid CompactLattice format; abandon it.
          delete cfst;
          cfst = NULL;
        }
      }
      if (!fst && !cfst) {
        KALDI_WARN << "Bad line in lattice text format: " << line;
        // read until we get an empty line, so at least we
        // have a chance to read the next one (although this might
        // be a bit futile since the calling code will get unhappy
        // about failing to read this one).
        while (std::getline(is, line)) {
          SplitStringToVector(line, separator.c_str(), true, &col);
          if (col.empty()) break;
        }
        return PairT(static_cast<Lattice*>(NULL),
                     static_cast<CompactLattice*>(NULL));
      }
    }
    return PairT(fst, cfst);
  }
  // Parses a LatticeWeight from text.  Returns false on parse failure,
  // or (when !allow_zero) if the parsed weight equals Weight::Zero().
  static bool StrToWeight(const std::string &s, bool allow_zero, Weight *w) {
    std::istringstream strm(s);
    strm >> *w;
    if (!strm || (!allow_zero && *w == Weight::Zero())) {
      return false;
    }
    return true;
  }
  // As StrToWeight, but for CompactLatticeWeight.
  static bool StrToCWeight(const std::string &s, bool allow_zero, CWeight *w) {
    std::istringstream strm(s);
    strm >> *w;
    if (!strm || (!allow_zero && *w == CWeight::Zero())) {
      return false;
    }
    return true;
  }
};
// Reads a CompactLattice in text form, preferring the CompactLattice
// interpretation when the text parses as both types.  Returns NULL on error.
CompactLattice *ReadCompactLatticeText(std::istream &is) {
  std::pair<Lattice*, CompactLattice*> result = LatticeReader::ReadText(is);
  Lattice *lat = result.first;
  CompactLattice *clat = result.second;
  if (clat != NULL) {
    delete lat;  // may be NULL, in which case this is a no-op.
    return clat;
  }
  if (lat != NULL)  // only the Lattice interpretation succeeded.
    return ConvertToCompactLattice(lat);  // note: frees its input.
  return NULL;  // neither interpretation succeeded.
}
// Reads a Lattice in text form, preferring the Lattice interpretation
// when the text parses as both types.  Returns NULL on error.
Lattice *ReadLatticeText(std::istream &is) {
  std::pair<Lattice*, CompactLattice*> result = LatticeReader::ReadText(is);
  Lattice *lat = result.first;
  CompactLattice *clat = result.second;
  if (lat != NULL) {
    delete clat;  // may be NULL, in which case this is a no-op.
    return lat;
  }
  if (clat != NULL)  // only the CompactLattice interpretation succeeded.
    return ConvertToLattice(clat);  // note: frees its input.
  return NULL;  // neither interpretation succeeded.
}
// Reads a CompactLattice from the stream in binary or text mode.
// Requires *clat == NULL on entry; on success sets *clat to a newly
// allocated lattice (owned by the caller) and returns true.  Returns
// false, with a warning, on failure rather than throwing.
bool ReadCompactLattice(std::istream &is, bool binary,
                        CompactLattice **clat) {
  KALDI_ASSERT(*clat == NULL);
  if (binary) {
    fst::FstHeader hdr;
    if (!hdr.Read(is, "<unknown>")) {
      KALDI_WARN << "Reading compact lattice: error reading FST header.";
      return false;
    }
    if (hdr.FstType() != "vector") {
      KALDI_WARN << "Reading compact lattice: unsupported FST type: "
                 << hdr.FstType();
      return false;
    }
    fst::FstReadOptions ropts("<unspecified>",
                              &hdr);
    // Any of these four arc types is accepted; the result is converted
    // to CompactLattice if needed.
    typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<float>, int32> T1;
    typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<double>, int32> T2;
    typedef fst::LatticeWeightTpl<float> T3;
    typedef fst::LatticeWeightTpl<double> T4;
    typedef fst::VectorFst<fst::ArcTpl<T1> > F1;
    typedef fst::VectorFst<fst::ArcTpl<T2> > F2;
    typedef fst::VectorFst<fst::ArcTpl<T3> > F3;
    typedef fst::VectorFst<fst::ArcTpl<T4> > F4;
    CompactLattice *ans = NULL;
    // Note: ConvertToCompactLattice() frees its input (and is a pass-
    // through when no conversion is needed).
    if (hdr.ArcType() == T1::Type()) {
      ans = ConvertToCompactLattice(F1::Read(is, ropts));
    } else if (hdr.ArcType() == T2::Type()) {
      ans = ConvertToCompactLattice(F2::Read(is, ropts));
    } else if (hdr.ArcType() == T3::Type()) {
      ans = ConvertToCompactLattice(F3::Read(is, ropts));
    } else if (hdr.ArcType() == T4::Type()) {
      ans = ConvertToCompactLattice(F4::Read(is, ropts));
    } else {
      KALDI_WARN << "FST with arc type " << hdr.ArcType()
                 << " cannot be converted to CompactLattice.\n";
      return false;
    }
    if (ans == NULL) {
      KALDI_WARN << "Error reading compact lattice (after reading header).";
      return false;
    }
    *clat = ans;
    return true;
  } else {
    // The next line would normally consume the \r on Windows, plus any
    // extra spaces that might have got in there somehow.
    while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
    if (is.peek() == '\n') is.get(); // consume the newline.
    else { // saw spaces but no newline.. this is not expected.
      KALDI_WARN << "Reading compact lattice: unexpected sequence of spaces "
                 << " at file position " << is.tellg();
      return false;
    }
    *clat = ReadCompactLatticeText(is); // that routine will warn on error.
    return (*clat != NULL);
  }
}
// Reads a CompactLattice into the holder, auto-detecting binary vs.
// text form from the first byte of the stream (no Kaldi binary header
// is written for this holder type).
bool CompactLatticeHolder::Read(std::istream &is) {
  Clear(); // in case anything currently stored.
  int c = is.peek();
  if (c == -1) {
    KALDI_WARN << "End of stream detected reading CompactLattice.";
    return false;
  } else if (isspace(c)) { // The text form of the lattice begins
    // with space (normally, '\n'), so this means it's text (the binary form
    // cannot begin with space because it starts with the FST Type() which is not
    // space).
    return ReadCompactLattice(is, false, &t_);
  } else if (c != 214) { // 214 is first char of FST magic number,
    // on little-endian machines which is all we support (\326 octal)
    KALDI_WARN << "Reading compact lattice: does not appear to be an FST "
               << " [non-space but no magic number detected], file pos is "
               << is.tellg();
    return false;
  } else {
    return ReadCompactLattice(is, true, &t_);
  }
}
// Writes a Lattice in binary or text form.  Returns false on stream
// failure rather than throwing.
bool WriteLattice(std::ostream &os, bool binary, const Lattice &t) {
  if (binary) {
    // Default write options.  These lattices normally carry no
    // isymbols/osymbols, so there is nothing to suppress.
    fst::FstWriteOptions opts;
    return t.Write(os, opts);
  }
  // Text-mode output.  We expect t.InputSymbols() and t.OutputSymbols()
  // to be NULL; the corresponding input routine would not work if the
  // FST actually had symbols attached.
  // A leading newline puts the first FST line on its own line after the key.
  os << '\n';
  const bool acceptor = false, write_one = false;
  fst::FstPrinter<LatticeArc> printer(t, t.InputSymbols(),
                                      t.OutputSymbols(),
                                      NULL, acceptor, write_one, "\t");
  printer.Print(&os, "<unknown>");
  if (os.fail())
    KALDI_WARN << "Stream failure detected.";
  // A trailing newline acts as a terminating character that the read
  // routine detects [a Kaldi mechanism, not something in original OpenFst].
  os << '\n';
  return os.good();
}
// Reads a Lattice from the stream in binary or text mode.  Requires
// *lat == NULL on entry; on success sets *lat to a newly allocated
// lattice (owned by the caller) and returns true.  Returns false, with
// a warning, on failure rather than throwing.
bool ReadLattice(std::istream &is, bool binary,
                 Lattice **lat) {
  KALDI_ASSERT(*lat == NULL);
  if (binary) {
    fst::FstHeader hdr;
    if (!hdr.Read(is, "<unknown>")) {
      KALDI_WARN << "Reading lattice: error reading FST header.";
      return false;
    }
    if (hdr.FstType() != "vector") {
      KALDI_WARN << "Reading lattice: unsupported FST type: "
                 << hdr.FstType();
      return false;
    }
    fst::FstReadOptions ropts("<unspecified>",
                              &hdr);
    // Any of these four arc types is accepted; the result is converted
    // to Lattice if needed (ConvertToLattice() frees its input).
    typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<float>, int32> T1;
    typedef fst::CompactLatticeWeightTpl<fst::LatticeWeightTpl<double>, int32> T2;
    typedef fst::LatticeWeightTpl<float> T3;
    typedef fst::LatticeWeightTpl<double> T4;
    typedef fst::VectorFst<fst::ArcTpl<T1> > F1;
    typedef fst::VectorFst<fst::ArcTpl<T2> > F2;
    typedef fst::VectorFst<fst::ArcTpl<T3> > F3;
    typedef fst::VectorFst<fst::ArcTpl<T4> > F4;
    Lattice *ans = NULL;
    if (hdr.ArcType() == T1::Type()) {
      ans = ConvertToLattice(F1::Read(is, ropts));
    } else if (hdr.ArcType() == T2::Type()) {
      ans = ConvertToLattice(F2::Read(is, ropts));
    } else if (hdr.ArcType() == T3::Type()) {
      ans = ConvertToLattice(F3::Read(is, ropts));
    } else if (hdr.ArcType() == T4::Type()) {
      ans = ConvertToLattice(F4::Read(is, ropts));
    } else {
      KALDI_WARN << "FST with arc type " << hdr.ArcType()
                 << " cannot be converted to Lattice.\n";
      return false;
    }
    if (ans == NULL) {
      KALDI_WARN << "Error reading lattice (after reading header).";
      return false;
    }
    *lat = ans;
    return true;
  } else {
    // The next line would normally consume the \r on Windows, plus any
    // extra spaces that might have got in there somehow.
    while (std::isspace(is.peek()) && is.peek() != '\n') is.get();
    if (is.peek() == '\n') is.get(); // consume the newline.
    else { // saw spaces but no newline.. this is not expected.
      // Fixed copy-paste error: this message previously said
      // "Reading compact lattice" although we are reading a Lattice.
      KALDI_WARN << "Reading lattice: unexpected sequence of spaces "
                 << " at file position " << is.tellg();
      return false;
    }
    *lat = ReadLatticeText(is); // that routine will warn on error.
    return (*lat != NULL);
  }
}
/* Since we don't write the binary headers for this type of holder,
   we use a different method to work out whether we're in binary mode:
   we peek at the first byte of the stream (text output starts with
   whitespace; binary output starts with the FST magic number). */
bool LatticeHolder::Read(std::istream &is) {
  Clear(); // in case anything currently stored.
  int c = is.peek();
  if (c == -1) {
    KALDI_WARN << "End of stream detected reading Lattice.";
    return false;
  } else if (isspace(c)) { // The text form of the lattice begins
    // with space (normally, '\n'), so this means it's text (the binary form
    // cannot begin with space because it starts with the FST Type() which is not
    // space).
    return ReadLattice(is, false, &t_);
  } else if (c != 214) { // 214 is first char of FST magic number,
    // on little-endian machines which is all we support (\326 octal)
    // Fixed copy-paste error: this message previously said
    // "Reading compact lattice" although this holder reads a Lattice.
    KALDI_WARN << "Reading lattice: does not appear to be an FST "
               << " [non-space but no magic number detected], file pos is "
               << is.tellg();
    return false;
  } else {
    return ReadLattice(is, true, &t_);
  }
}
} // end namespace kaldi
// lat/kaldi-lattice.h
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_KALDI_LATTICE_H_
#define KALDI_LAT_KALDI_LATTICE_H_
#include "fstext/fstext-lib.h"
#include "base/kaldi-common.h"
// #include "util/common-utils.h"
namespace kaldi {
// will import some things above...
// Weight and FST typedefs used throughout Kaldi for lattices.
typedef fst::LatticeWeightTpl<BaseFloat> LatticeWeight;
// careful: kaldi::int32 is not always the same C type as fst::int32
typedef fst::CompactLatticeWeightTpl<LatticeWeight, int32> CompactLatticeWeight;
typedef fst::CompactLatticeWeightCommonDivisorTpl<LatticeWeight, int32>
    CompactLatticeWeightCommonDivisor;
// A Lattice is an FST over LatticeWeight; a CompactLattice is in acceptor
// form over CompactLatticeWeight (which carries a string part).
typedef fst::ArcTpl<LatticeWeight> LatticeArc;
typedef fst::ArcTpl<CompactLatticeWeight> CompactLatticeArc;
typedef fst::VectorFst<LatticeArc> Lattice;
typedef fst::VectorFst<CompactLatticeArc> CompactLattice;
// The following functions for writing and reading lattices in binary or text
// form are provided here in case you need to include lattices in larger,
// Kaldi-type objects with their own Read and Write functions. Caution: these
// functions return false on stream failure rather than throwing an exception as
// most similar Kaldi functions would do.
bool WriteCompactLattice(std::ostream &os, bool binary,
                         const CompactLattice &clat);
bool WriteLattice(std::ostream &os, bool binary,
                  const Lattice &lat);
// the following function requires that *clat be
// NULL when called; on success it sets *clat to a newly allocated
// lattice owned by the caller.
bool ReadCompactLattice(std::istream &is, bool binary,
                        CompactLattice **clat);
// the following function requires that *lat be
// NULL when called; on success it sets *lat to a newly allocated
// lattice owned by the caller.
bool ReadLattice(std::istream &is, bool binary,
                 Lattice **lat);
// Holder for CompactLattice objects, for use with Kaldi's Table (archive)
// code.  Owns the lattice it holds and deletes it in Clear()/the destructor.
class CompactLatticeHolder {
 public:
  typedef CompactLattice T;

  CompactLatticeHolder() { t_ = NULL; }

  static bool Write(std::ostream &os, bool binary, const T &t) {
    // Note: we don't include the binary-mode header when writing
    // this object to disk; this ensures that if we write to single
    // files, the result can be read by OpenFst.
    return WriteCompactLattice(os, binary, t);
  }

  // Defined out-of-line elsewhere; returns false on read failure.
  bool Read(std::istream &is);

  static bool IsReadInBinary() { return true; }

  // Returns a reference to the held lattice; it is an error to call this
  // when the holder is empty.
  T &Value() {
    KALDI_ASSERT(t_ != NULL && "Called Value() on empty CompactLatticeHolder");
    return *t_;
  }

  void Clear() { delete t_; t_ = NULL; }

  void Swap(CompactLatticeHolder *other) {
    std::swap(t_, other->t_);
  }

  bool ExtractRange(const CompactLatticeHolder &other, const std::string &range) {
    KALDI_ERR << "ExtractRange is not defined for this type of holder.";
    return false;
  }

  ~CompactLatticeHolder() { Clear(); }

 private:
  // Disallow copy and assignment (declared, not defined): this class owns
  // t_ via a raw pointer, so the implicitly generated copy operations would
  // lead to a double delete.
  CompactLatticeHolder(const CompactLatticeHolder &other);
  CompactLatticeHolder &operator=(const CompactLatticeHolder &other);

  T *t_;  // Owned; NULL when the holder is empty.
};
// Holder for Lattice objects, for use with Kaldi's Table (archive) code.
// Owns the lattice it holds and deletes it in Clear()/the destructor.
class LatticeHolder {
 public:
  typedef Lattice T;

  LatticeHolder() { t_ = NULL; }

  static bool Write(std::ostream &os, bool binary, const T &t) {
    // Note: we don't include the binary-mode header when writing
    // this object to disk; this ensures that if we write to single
    // files, the result can be read by OpenFst.
    return WriteLattice(os, binary, t);
  }

  // Defined out-of-line elsewhere; returns false on read failure.
  bool Read(std::istream &is);

  static bool IsReadInBinary() { return true; }

  // Returns a reference to the held lattice; it is an error to call this
  // when the holder is empty.
  T &Value() {
    KALDI_ASSERT(t_ != NULL && "Called Value() on empty LatticeHolder");
    return *t_;
  }

  void Clear() { delete t_; t_ = NULL; }

  void Swap(LatticeHolder *other) {
    std::swap(t_, other->t_);
  }

  bool ExtractRange(const LatticeHolder &other, const std::string &range) {
    KALDI_ERR << "ExtractRange is not defined for this type of holder.";
    return false;
  }

  ~LatticeHolder() { Clear(); }

 private:
  // Disallow copy and assignment (declared, not defined): this class owns
  // t_ via a raw pointer, so the implicitly generated copy operations would
  // lead to a double delete.
  LatticeHolder(const LatticeHolder &other);
  LatticeHolder &operator=(const LatticeHolder &other);

  T *t_;  // Owned; NULL when the holder is empty.
};
// typedef TableWriter<LatticeHolder> LatticeWriter;
// typedef SequentialTableReader<LatticeHolder> SequentialLatticeReader;
// typedef RandomAccessTableReader<LatticeHolder> RandomAccessLatticeReader;
//
// typedef TableWriter<CompactLatticeHolder> CompactLatticeWriter;
// typedef SequentialTableReader<CompactLatticeHolder> SequentialCompactLatticeReader;
// typedef RandomAccessTableReader<CompactLatticeHolder> RandomAccessCompactLatticeReader;
} // namespace kaldi
#endif // KALDI_LAT_KALDI_LATTICE_H_
// lat/lattice-functions.cc
// Copyright 2009-2011 Saarland University (Author: Arnab Ghoshal)
// 2012-2013 Johns Hopkins University (Author: Daniel Povey); Chao Weng;
// Bagher BabaAli
// 2013 Cisco Systems (author: Neha Agrawal) [code modified
// from original code in ../gmmbin/gmm-rescore-lattice.cc]
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "lat/lattice-functions.h"
// #include "hmm/transition-model.h"
// #include "util/stl-utils.h"
#include "base/kaldi-math.h"
// #include "hmm/hmm-utils.h"
namespace kaldi {
using std::map;
using std::vector;
// void GetPerFrameAcousticCosts(const Lattice &nbest,
// Vector<BaseFloat> *per_frame_loglikes) {
// using namespace fst;
// typedef Lattice::Arc::Weight Weight;
// vector<BaseFloat> loglikes;
//
// int32 cur_state = nbest.Start();
// int32 prev_frame = -1;
// BaseFloat eps_acwt = 0.0;
// while(1) {
// Weight w = nbest.Final(cur_state);
// if (w != Weight::Zero()) {
// KALDI_ASSERT(nbest.NumArcs(cur_state) == 0);
// if (per_frame_loglikes != NULL) {
// SubVector<BaseFloat> subvec(&(loglikes[0]), loglikes.size());
// Vector<BaseFloat> vec(subvec);
// *per_frame_loglikes = vec;
// }
// break;
// } else {
// KALDI_ASSERT(nbest.NumArcs(cur_state) == 1);
// fst::ArcIterator<Lattice> iter(nbest, cur_state);
// const Lattice::Arc &arc = iter.Value();
// BaseFloat acwt = arc.weight.Value2();
// if (arc.ilabel != 0) {
// if (eps_acwt > 0) {
// acwt += eps_acwt;
// eps_acwt = 0.0;
// }
// loglikes.push_back(acwt);
// prev_frame++;
// } else if (acwt == acwt){
// if (prev_frame > -1) {
// loglikes[prev_frame] += acwt;
// } else {
// eps_acwt += acwt;
// }
// }
// cur_state = arc.nextstate;
// }
// }
// }
//
// int32 LatticeStateTimes(const Lattice &lat, vector<int32> *times) {
// if (!lat.Properties(fst::kTopSorted, true))
// KALDI_ERR << "Input lattice must be topologically sorted.";
// KALDI_ASSERT(lat.Start() == 0);
// int32 num_states = lat.NumStates();
// times->clear();
// times->resize(num_states, -1);
// (*times)[0] = 0;
// for (int32 state = 0; state < num_states; state++) {
// int32 cur_time = (*times)[state];
// for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
// aiter.Next()) {
// const LatticeArc &arc = aiter.Value();
//
// if (arc.ilabel != 0) { // Non-epsilon input label on arc
// // next time instance
// if ((*times)[arc.nextstate] == -1) {
// (*times)[arc.nextstate] = cur_time + 1;
// } else {
// KALDI_ASSERT((*times)[arc.nextstate] == cur_time + 1);
// }
// } else { // epsilon input label on arc
// // Same time instance
// if ((*times)[arc.nextstate] == -1)
// (*times)[arc.nextstate] = cur_time;
// else
// KALDI_ASSERT((*times)[arc.nextstate] == cur_time);
// }
// }
// }
// return (*std::max_element(times->begin(), times->end()));
// }
//
// int32 CompactLatticeStateTimes(const CompactLattice &lat,
// vector<int32> *times) {
// if (!lat.Properties(fst::kTopSorted, true))
// KALDI_ERR << "Input lattice must be topologically sorted.";
// KALDI_ASSERT(lat.Start() == 0);
// int32 num_states = lat.NumStates();
// times->clear();
// times->resize(num_states, -1);
// (*times)[0] = 0;
// int32 utt_len = -1;
// for (int32 state = 0; state < num_states; state++) {
// int32 cur_time = (*times)[state];
// for (fst::ArcIterator<CompactLattice> aiter(lat, state); !aiter.Done();
// aiter.Next()) {
// const CompactLatticeArc &arc = aiter.Value();
// int32 arc_len = static_cast<int32>(arc.weight.String().size());
// if ((*times)[arc.nextstate] == -1)
// (*times)[arc.nextstate] = cur_time + arc_len;
// else
// KALDI_ASSERT((*times)[arc.nextstate] == cur_time + arc_len);
// }
// if (lat.Final(state) != CompactLatticeWeight::Zero()) {
// int32 this_utt_len = (*times)[state] + lat.Final(state).String().size();
// if (utt_len == -1) utt_len = this_utt_len;
// else {
// if (this_utt_len != utt_len) {
// KALDI_WARN << "Utterance does not "
// "seem to have a consistent length.";
// utt_len = std::max(utt_len, this_utt_len);
// }
// }
// }
// }
// if (utt_len == -1) {
// KALDI_WARN << "Utterance does not have a final-state.";
// return 0;
// }
// return utt_len;
// }
//
// bool ComputeCompactLatticeAlphas(const CompactLattice &clat,
// vector<double> *alpha) {
// using namespace fst;
//
// // typedef the arc, weight types
// typedef CompactLattice::Arc Arc;
// typedef Arc::Weight Weight;
// typedef Arc::StateId StateId;
//
// //Make sure the lattice is topologically sorted.
// if (clat.Properties(fst::kTopSorted, true) == 0) {
// KALDI_WARN << "Input lattice must be topologically sorted.";
// return false;
// }
// if (clat.Start() != 0) {
// KALDI_WARN << "Input lattice must start from state 0.";
// return false;
// }
//
// int32 num_states = clat.NumStates();
// (*alpha).resize(0);
// (*alpha).resize(num_states, kLogZeroDouble);
//
// //  Now propagate alphas forward. Note that we don't account for the weight of
// //  the final state in alpha[final_state] -- we account for it in beta[final_state];
// (*alpha)[0] = 0.0;
// for (StateId s = 0; s < num_states; s++) {
// double this_alpha = (*alpha)[s];
// for (ArcIterator<CompactLattice> aiter(clat, s);
// !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -(arc.weight.Weight().Value1() +
// arc.weight.Weight().Value2());
// (*alpha)[arc.nextstate] = LogAdd((*alpha)[arc.nextstate],
// this_alpha + arc_like);
// }
// }
//
// return true;
// }
//
// bool ComputeCompactLatticeBetas(const CompactLattice &clat,
// vector<double> *beta) {
// using namespace fst;
//
// // typedef the arc, weight types
// typedef CompactLattice::Arc Arc;
// typedef Arc::Weight Weight;
// typedef Arc::StateId StateId;
//
// // Make sure the lattice is topologically sorted.
// if (clat.Properties(fst::kTopSorted, true) == 0) {
// KALDI_WARN << "Input lattice must be topologically sorted.";
// return false;
// }
// if (clat.Start() != 0) {
// KALDI_WARN << "Input lattice must start from state 0.";
// return false;
// }
//
// int32 num_states = clat.NumStates();
// (*beta).resize(0);
// (*beta).resize(num_states, kLogZeroDouble);
//
// // Now propagate betas backward. Note that beta[final_state] contains the
// // weight of the final state in the lattice -- compare that with alpha.
// for (StateId s = num_states-1; s >= 0; s--) {
// Weight f = clat.Final(s);
// double this_beta = -(f.Weight().Value1()+f.Weight().Value2());
// for (ArcIterator<CompactLattice> aiter(clat, s);
// !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -(arc.weight.Weight().Value1() +
// arc.weight.Weight().Value2());
// double arc_beta = (*beta)[arc.nextstate] + arc_like;
// this_beta = LogAdd(this_beta, arc_beta);
// }
// (*beta)[s] = this_beta;
// }
//
// return true;
// }
// Prunes *lat in place with the given beam: after the call, only states and
// arcs that lie on some complete path whose total cost is within "beam" of
// the best path's cost remain.  Weights are mapped to scalar costs with
// ConvertToCost().  Top-sorts the lattice first if needed.  Returns false if
// the lattice is empty, has cycles, or is entirely pruned away; true otherwise.
template<class LatType> // could be Lattice or CompactLattice
bool PruneLattice(BaseFloat beam, LatType *lat) {
  typedef typename LatType::Arc Arc;
  typedef typename Arc::Weight Weight;
  typedef typename Arc::StateId StateId;
  KALDI_ASSERT(beam > 0.0);
  // The forward/backward passes below require topological order.
  if (!lat->Properties(fst::kTopSorted, true)) {
    if (fst::TopSort(lat) == false) {
      KALDI_WARN << "Cycles detected in lattice";
      return false;
    }
  }
  // We assume states before "start" are not reachable, since
  // the lattice is topologically sorted.
  int32 start = lat->Start();
  int32 num_states = lat->NumStates();
  if (num_states == 0) return false;
  std::vector<double> forward_cost(num_states,
                                   std::numeric_limits<double>::infinity()); // viterbi forward.
  forward_cost[start] = 0.0; // lattice can't have cycles so couldn't be
  // less than this.
  double best_final_cost = std::numeric_limits<double>::infinity();
  // Update the forward probs.
  // Thanks to Jing Zheng for finding a bug here.
  for (int32 state = 0; state < num_states; state++) {
    double this_forward_cost = forward_cost[state];
    for (fst::ArcIterator<LatType> aiter(*lat, state);
         !aiter.Done();
         aiter.Next()) {
      const Arc &arc(aiter.Value());
      StateId nextstate = arc.nextstate;
      // Top-sorted lattice: arcs only go to strictly higher-numbered states.
      KALDI_ASSERT(nextstate > state && nextstate < num_states);
      double next_forward_cost = this_forward_cost +
          ConvertToCost(arc.weight);
      // Keep the minimum (Viterbi) cost of reaching nextstate.
      if (forward_cost[nextstate] > next_forward_cost)
        forward_cost[nextstate] = next_forward_cost;
    }
    Weight final_weight = lat->Final(state);
    // Non-final states have infinite final cost, so this only updates
    // best_final_cost for genuinely final states.
    double this_final_cost = this_forward_cost +
        ConvertToCost(final_weight);
    if (this_final_cost < best_final_cost)
      best_final_cost = this_final_cost;
  }
  int32 bad_state = lat->AddState(); // this state is not final.
  double cutoff = best_final_cost + beam;
  // Go backwards updating the backward probs (which share memory with the
  // forward probs), and pruning arcs and deleting final-probs.  We prune arcs
  // by making them point to the non-final state "bad_state".  We'll then use
  // Trim() to remove unnecessary arcs and states.  [this is just easier than
  // doing it ourselves.]
  std::vector<double> &backward_cost(forward_cost);
  for (int32 state = num_states - 1; state >= 0; state--) {
    // forward_cost[state] is still valid here: states are processed in
    // decreasing order, and only backward_cost[s] for s > state (i.e. already
    // processed entries of the shared vector) are read below.
    double this_forward_cost = forward_cost[state];
    double this_backward_cost = ConvertToCost(lat->Final(state));
    // Remove the final-prob of a final state whose best complete path through
    // it exceeds the cutoff (the infinity check restricts this to final states).
    if (this_backward_cost + this_forward_cost > cutoff
        && this_backward_cost != std::numeric_limits<double>::infinity())
      lat->SetFinal(state, Weight::Zero());
    for (fst::MutableArcIterator<LatType> aiter(lat, state);
         !aiter.Done();
         aiter.Next()) {
      Arc arc(aiter.Value());
      StateId nextstate = arc.nextstate;
      KALDI_ASSERT(nextstate > state && nextstate < num_states);
      double arc_cost = ConvertToCost(arc.weight),
          arc_backward_cost = arc_cost + backward_cost[nextstate],
          this_fb_cost = this_forward_cost + arc_backward_cost;
      if (arc_backward_cost < this_backward_cost)
        this_backward_cost = arc_backward_cost;
      if (this_fb_cost > cutoff) { // Prune the arc.
        arc.nextstate = bad_state;
        aiter.SetValue(arc);
      }
    }
    backward_cost[state] = this_backward_cost;
  }
  // Connect() removes "bad_state" and anything else no longer on a
  // successful path.
  fst::Connect(lat);
  return (lat->NumStates() > 0);
}
// Explicitly instantiate the template for Lattice and CompactLattice, the
// two lattice types used by callers of PruneLattice.
template bool PruneLattice(BaseFloat beam, Lattice *lat);
template bool PruneLattice(BaseFloat beam, CompactLattice *lat);
// BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *post,
// double *acoustic_like_sum) {
// // Note, Posterior is defined as follows: Indexed [frame], then a list
// // of (transition-id, posterior-probability) pairs.
// // typedef std::vector<std::vector<std::pair<int32, BaseFloat> > > Posterior;
// using namespace fst;
// typedef Lattice::Arc Arc;
// typedef Arc::Weight Weight;
// typedef Arc::StateId StateId;
//
// if (acoustic_like_sum) *acoustic_like_sum = 0.0;
//
// // Make sure the lattice is topologically sorted.
// if (lat.Properties(fst::kTopSorted, true) == 0)
// KALDI_ERR << "Input lattice must be topologically sorted.";
// KALDI_ASSERT(lat.Start() == 0);
//
// int32 num_states = lat.NumStates();
// vector<int32> state_times;
// int32 max_time = LatticeStateTimes(lat, &state_times);
// std::vector<double> alpha(num_states, kLogZeroDouble);
// std::vector<double> &beta(alpha); // we re-use the same memory for
// // this, but it's semantically distinct so we name it differently.
// double tot_forward_prob = kLogZeroDouble;
//
// post->clear();
// post->resize(max_time);
//
// alpha[0] = 0.0;
// // Propagate alphas forward.
// for (StateId s = 0; s < num_states; s++) {
// double this_alpha = alpha[s];
// for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -ConvertToCost(arc.weight);
// alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate], this_alpha + arc_like);
// }
// Weight f = lat.Final(s);
// if (f != Weight::Zero()) {
// double final_like = this_alpha - (f.Value1() + f.Value2());
// tot_forward_prob = LogAdd(tot_forward_prob, final_like);
// KALDI_ASSERT(state_times[s] == max_time &&
// "Lattice is inconsistent (final-prob not at max_time)");
// }
// }
// for (StateId s = num_states-1; s >= 0; s--) {
// Weight f = lat.Final(s);
// double this_beta = -(f.Value1() + f.Value2());
// for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -ConvertToCost(arc.weight),
// arc_beta = beta[arc.nextstate] + arc_like;
// this_beta = LogAdd(this_beta, arc_beta);
// int32 transition_id = arc.ilabel;
//
// // The following "if" is an optimization to avoid un-needed exp().
// if (transition_id != 0 || acoustic_like_sum != NULL) {
// double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob);
//
// if (transition_id != 0) // Arc has a transition-id on it [not epsilon]
// (*post)[state_times[s]].push_back(std::make_pair(transition_id,
// static_cast<kaldi::BaseFloat>(posterior)));
// if (acoustic_like_sum != NULL)
// *acoustic_like_sum -= posterior * arc.weight.Value2();
// }
// }
// if (acoustic_like_sum != NULL && f != Weight::Zero()) {
// double final_logprob = - ConvertToCost(f),
// posterior = Exp(alpha[s] + final_logprob - tot_forward_prob);
// *acoustic_like_sum -= posterior * f.Value2();
// }
// beta[s] = this_beta;
// }
// double tot_backward_prob = beta[0];
// if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) {
// KALDI_WARN << "Total forward probability over lattice = " << tot_forward_prob
// << ", while total backward probability = " << tot_backward_prob;
// }
// // Now combine any posteriors with the same transition-id.
// for (int32 t = 0; t < max_time; t++)
// MergePairVectorSumming(&((*post)[t]));
// return tot_backward_prob;
// }
//
//
// void LatticeActivePhones(const Lattice &lat, const TransitionModel &trans,
// const vector<int32> &silence_phones,
// vector< std::set<int32> > *active_phones) {
// KALDI_ASSERT(IsSortedAndUniq(silence_phones));
// vector<int32> state_times;
// int32 num_states = lat.NumStates();
// int32 max_time = LatticeStateTimes(lat, &state_times);
// active_phones->clear();
// active_phones->resize(max_time);
// for (int32 state = 0; state < num_states; state++) {
// int32 cur_time = state_times[state];
// for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
// aiter.Next()) {
// const LatticeArc &arc = aiter.Value();
// if (arc.ilabel != 0) { // Non-epsilon arc
// int32 phone = trans.TransitionIdToPhone(arc.ilabel);
// if (!std::binary_search(silence_phones.begin(),
// silence_phones.end(), phone))
// (*active_phones)[cur_time].insert(phone);
// }
// } // end looping over arcs
// } // end looping over states
// }
//
// void ConvertLatticeToPhones(const TransitionModel &trans,
// Lattice *lat) {
// typedef LatticeArc Arc;
// int32 num_states = lat->NumStates();
// for (int32 state = 0; state < num_states; state++) {
// for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
// aiter.Next()) {
// Arc arc(aiter.Value());
// arc.olabel = 0; // remove any word.
// if ((arc.ilabel != 0) // has a transition-id on input..
// && (trans.TransitionIdToHmmState(arc.ilabel) == 0)
// && (!trans.IsSelfLoop(arc.ilabel))) {
// // && trans.IsFinal(arc.ilabel)) // there is one of these per phone...
// arc.olabel = trans.TransitionIdToPhone(arc.ilabel);
// }
// aiter.SetValue(arc);
// } // end looping over arcs
// } // end looping over states
// }
//
//
// static inline double LogAddOrMax(bool viterbi, double a, double b) {
// if (viterbi)
// return std::max(a, b);
// else
// return LogAdd(a, b);
// }
//
// template<typename LatticeType>
// double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
// bool viterbi,
// vector<double> *alpha,
// vector<double> *beta) {
// typedef typename LatticeType::Arc Arc;
// typedef typename Arc::Weight Weight;
// typedef typename Arc::StateId StateId;
//
// StateId num_states = lat.NumStates();
// KALDI_ASSERT(lat.Properties(fst::kTopSorted, true) == fst::kTopSorted);
// KALDI_ASSERT(lat.Start() == 0);
// alpha->clear();
// beta->clear();
// alpha->resize(num_states, kLogZeroDouble);
// beta->resize(num_states, kLogZeroDouble);
//
// double tot_forward_prob = kLogZeroDouble;
// (*alpha)[0] = 0.0;
// // Propagate alphas forward.
// for (StateId s = 0; s < num_states; s++) {
// double this_alpha = (*alpha)[s];
// for (fst::ArcIterator<LatticeType> aiter(lat, s); !aiter.Done();
// aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -ConvertToCost(arc.weight);
// (*alpha)[arc.nextstate] = LogAddOrMax(viterbi, (*alpha)[arc.nextstate],
// this_alpha + arc_like);
// }
// Weight f = lat.Final(s);
// if (f != Weight::Zero()) {
// double final_like = this_alpha - ConvertToCost(f);
// tot_forward_prob = LogAddOrMax(viterbi, tot_forward_prob, final_like);
// }
// }
// for (StateId s = num_states-1; s >= 0; s--) { // it's guaranteed signed.
// double this_beta = -ConvertToCost(lat.Final(s));
// for (fst::ArcIterator<LatticeType> aiter(lat, s); !aiter.Done();
// aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -ConvertToCost(arc.weight),
// arc_beta = (*beta)[arc.nextstate] + arc_like;
// this_beta = LogAddOrMax(viterbi, this_beta, arc_beta);
// }
// (*beta)[s] = this_beta;
// }
// double tot_backward_prob = (*beta)[lat.Start()];
// if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-8)) {
// KALDI_WARN << "Total forward probability over lattice = " << tot_forward_prob
// << ", while total backward probability = " << tot_backward_prob;
// }
// // Split the difference when returning... they should be the same.
// return 0.5 * (tot_backward_prob + tot_forward_prob);
// }
//
// // instantiate the template for Lattice and CompactLattice
// template
// double ComputeLatticeAlphasAndBetas(const Lattice &lat,
// bool viterbi,
// vector<double> *alpha,
// vector<double> *beta);
//
// template
// double ComputeLatticeAlphasAndBetas(const CompactLattice &lat,
// bool viterbi,
// vector<double> *alpha,
// vector<double> *beta);
//
//
//
// /// This is used in CompactLatticeLimitDepth.
// struct LatticeArcRecord {
// BaseFloat logprob; // logprob <= 0 is the best Viterbi logprob of this arc,
// // minus the overall best-cost of the lattice.
// CompactLatticeArc::StateId state; // state in the lattice.
// size_t arc; // arc index within the state.
// bool operator < (const LatticeArcRecord &other) const {
// return logprob < other.logprob;
// }
// };
//
// void CompactLatticeLimitDepth(int32 max_depth_per_frame,
// CompactLattice *clat) {
// typedef CompactLatticeArc Arc;
// typedef Arc::Weight Weight;
// typedef Arc::StateId StateId;
//
// if (clat->Start() == fst::kNoStateId) {
// KALDI_WARN << "Limiting depth of empty lattice.";
// return;
// }
// if (clat->Properties(fst::kTopSorted, true) == 0) {
// if (!TopSort(clat))
// KALDI_ERR << "Topological sorting of lattice failed.";
// }
//
// vector<int32> state_times;
// int32 T = CompactLatticeStateTimes(*clat, &state_times);
//
// // The alpha and beta quantities here are "viterbi" alphas and beta.
// std::vector<double> alpha;
// std::vector<double> beta;
// bool viterbi = true;
// double best_prob = ComputeLatticeAlphasAndBetas(*clat, viterbi,
// &alpha, &beta);
//
// std::vector<std::vector<LatticeArcRecord> > arc_records(T);
//
// StateId num_states = clat->NumStates();
// for (StateId s = 0; s < num_states; s++) {
// for (fst::ArcIterator<CompactLattice> aiter(*clat, s); !aiter.Done();
// aiter.Next()) {
// const Arc &arc = aiter.Value();
// LatticeArcRecord arc_record;
// arc_record.state = s;
// arc_record.arc = aiter.Position();
// arc_record.logprob =
// (alpha[s] + beta[arc.nextstate] - ConvertToCost(arc.weight))
// - best_prob;
// KALDI_ASSERT(arc_record.logprob < 0.1); // Should be zero or negative.
// int32 num_frames = arc.weight.String().size(), start_t = state_times[s];
// for (int32 t = start_t; t < start_t + num_frames; t++) {
// KALDI_ASSERT(t < T);
// arc_records[t].push_back(arc_record);
// }
// }
// }
// StateId dead_state = clat->AddState(); // A non-coaccesible state which we use
// // to remove arcs (make them end
// // there).
// size_t max_depth = max_depth_per_frame;
// for (int32 t = 0; t < T; t++) {
// size_t size = arc_records[t].size();
// if (size > max_depth) {
// // we sort from worst to best, so we keep the later-numbered ones,
// // and delete the lower-numbered ones.
// size_t cutoff = size - max_depth;
// std::nth_element(arc_records[t].begin(),
// arc_records[t].begin() + cutoff,
// arc_records[t].end());
// for (size_t index = 0; index < cutoff; index++) {
// LatticeArcRecord record(arc_records[t][index]);
// fst::MutableArcIterator<CompactLattice> aiter(clat, record.state);
// aiter.Seek(record.arc);
// Arc arc = aiter.Value();
// if (arc.nextstate != dead_state) { // not already killed.
// arc.nextstate = dead_state;
// aiter.SetValue(arc);
// }
// }
// }
// }
// Connect(clat);
// TopSortCompactLatticeIfNeeded(clat);
// }
//
//
// void TopSortCompactLatticeIfNeeded(CompactLattice *clat) {
// if (clat->Properties(fst::kTopSorted, true) == 0) {
// if (fst::TopSort(clat) == false) {
// KALDI_ERR << "Topological sorting failed";
// }
// }
// }
//
// void TopSortLatticeIfNeeded(Lattice *lat) {
// if (lat->Properties(fst::kTopSorted, true) == 0) {
// if (fst::TopSort(lat) == false) {
// KALDI_ERR << "Topological sorting failed";
// }
// }
// }
//
//
// /// Returns the depth of the lattice, defined as the average number of
// /// arcs crossing any given frame. Returns 1 for empty lattices.
// /// Requires that input is topologically sorted.
// BaseFloat CompactLatticeDepth(const CompactLattice &clat,
// int32 *num_frames) {
// typedef CompactLattice::Arc::StateId StateId;
// if (clat.Properties(fst::kTopSorted, true) == 0) {
// KALDI_ERR << "Lattice input to CompactLatticeDepth was not topologically "
// << "sorted.";
// }
// if (clat.Start() == fst::kNoStateId) {
// *num_frames = 0;
// return 1.0;
// }
// size_t num_arc_frames = 0;
// int32 t;
// {
// vector<int32> state_times;
// t = CompactLatticeStateTimes(clat, &state_times);
// }
// if (num_frames != NULL)
// *num_frames = t;
// for (StateId s = 0; s < clat.NumStates(); s++) {
// for (fst::ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done();
// aiter.Next()) {
// const CompactLatticeArc &arc = aiter.Value();
// num_arc_frames += arc.weight.String().size();
// }
// num_arc_frames += clat.Final(s).String().size();
// }
// return num_arc_frames / static_cast<BaseFloat>(t);
// }
//
//
// void CompactLatticeDepthPerFrame(const CompactLattice &clat,
// std::vector<int32> *depth_per_frame) {
// typedef CompactLattice::Arc::StateId StateId;
// if (clat.Properties(fst::kTopSorted, true) == 0) {
// KALDI_ERR << "Lattice input to CompactLatticeDepthPerFrame was not "
// << "topologically sorted.";
// }
// if (clat.Start() == fst::kNoStateId) {
// depth_per_frame->clear();
// return;
// }
// vector<int32> state_times;
// int32 T = CompactLatticeStateTimes(clat, &state_times);
//
// depth_per_frame->clear();
// if (T <= 0) {
// return;
// } else {
// depth_per_frame->resize(T, 0);
// for (StateId s = 0; s < clat.NumStates(); s++) {
// int32 start_time = state_times[s];
// for (fst::ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done();
// aiter.Next()) {
// const CompactLatticeArc &arc = aiter.Value();
// int32 len = arc.weight.String().size();
// for (int32 t = start_time; t < start_time + len; t++) {
// KALDI_ASSERT(t < T);
// (*depth_per_frame)[t]++;
// }
// }
// int32 final_len = clat.Final(s).String().size();
// for (int32 t = start_time; t < start_time + final_len; t++) {
// KALDI_ASSERT(t < T);
// (*depth_per_frame)[t]++;
// }
// }
// }
// }
//
//
//
// void ConvertCompactLatticeToPhones(const TransitionModel &trans,
// CompactLattice *clat) {
// typedef CompactLatticeArc Arc;
// typedef Arc::Weight Weight;
// int32 num_states = clat->NumStates();
// for (int32 state = 0; state < num_states; state++) {
// for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
// !aiter.Done();
// aiter.Next()) {
// Arc arc(aiter.Value());
// std::vector<int32> phone_seq;
// const std::vector<int32> &tid_seq = arc.weight.String();
// for (std::vector<int32>::const_iterator iter = tid_seq.begin();
// iter != tid_seq.end(); ++iter) {
// if (trans.IsFinal(*iter))// note: there is one of these per phone...
// phone_seq.push_back(trans.TransitionIdToPhone(*iter));
// }
// arc.weight.SetString(phone_seq);
// aiter.SetValue(arc);
// } // end looping over arcs
// Weight f = clat->Final(state);
// if (f != Weight::Zero()) {
// std::vector<int32> phone_seq;
// const std::vector<int32> &tid_seq = f.String();
// for (std::vector<int32>::const_iterator iter = tid_seq.begin();
// iter != tid_seq.end(); ++iter) {
// if (trans.IsFinal(*iter))// note: there is one of these per phone...
// phone_seq.push_back(trans.TransitionIdToPhone(*iter));
// }
// f.SetString(phone_seq);
// clat->SetFinal(state, f);
// }
// } // end looping over states
// }
//
// bool LatticeBoost(const TransitionModel &trans,
// const std::vector<int32> &alignment,
// const std::vector<int32> &silence_phones,
// BaseFloat b,
// BaseFloat max_silence_error,
// Lattice *lat) {
// TopSortLatticeIfNeeded(lat);
//
// // get all stored properties (test==false means don't test if not known).
// uint64 props = lat->Properties(fst::kFstProperties,
// false);
//
// KALDI_ASSERT(IsSortedAndUniq(silence_phones));
// KALDI_ASSERT(max_silence_error >= 0.0 && max_silence_error <= 1.0);
// vector<int32> state_times;
// int32 num_states = lat->NumStates();
// int32 num_frames = LatticeStateTimes(*lat, &state_times);
// KALDI_ASSERT(num_frames == static_cast<int32>(alignment.size()));
// for (int32 state = 0; state < num_states; state++) {
// int32 cur_time = state_times[state];
// for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
// aiter.Next()) {
// LatticeArc arc = aiter.Value();
// if (arc.ilabel != 0) { // Non-epsilon arc
// if (arc.ilabel < 0 || arc.ilabel > trans.NumTransitionIds()) {
// KALDI_WARN << "Lattice has out-of-range transition-ids: "
// << "lattice/model mismatch?";
// return false;
// }
// int32 phone = trans.TransitionIdToPhone(arc.ilabel),
// ref_phone = trans.TransitionIdToPhone(alignment[cur_time]);
// BaseFloat frame_error;
// if (phone == ref_phone) {
// frame_error = 0.0;
// } else { // an error...
// if (std::binary_search(silence_phones.begin(), silence_phones.end(), phone))
// frame_error = max_silence_error;
// else
// frame_error = 1.0;
// }
// BaseFloat delta_cost = -b * frame_error; // negative cost if
// // frame is wrong, to boost likelihood of arcs with errors on them.
// // Add this cost to the graph part.
// arc.weight.SetValue1(arc.weight.Value1() + delta_cost);
// aiter.SetValue(arc);
// }
// }
// }
// // All we changed is the weights, so any properties that were
// // known before, are still known, except for whether or not the
// // lattice was weighted.
// lat->SetProperties(props,
// ~(fst::kWeighted|fst::kUnweighted));
//
// return true;
// }
//
//
//
// BaseFloat LatticeForwardBackwardMpeVariants(
// const TransitionModel &trans,
// const std::vector<int32> &silence_phones,
// const Lattice &lat,
// const std::vector<int32> &num_ali,
// std::string criterion,
// bool one_silence_class,
// Posterior *post) {
// using namespace fst;
// typedef Lattice::Arc Arc;
// typedef Arc::Weight Weight;
// typedef Arc::StateId StateId;
//
// KALDI_ASSERT(criterion == "mpfe" || criterion == "smbr");
// bool is_mpfe = (criterion == "mpfe");
//
// if (lat.Properties(fst::kTopSorted, true) == 0)
// KALDI_ERR << "Input lattice must be topologically sorted.";
// KALDI_ASSERT(lat.Start() == 0);
//
// int32 num_states = lat.NumStates();
// vector<int32> state_times;
// int32 max_time = LatticeStateTimes(lat, &state_times);
// KALDI_ASSERT(max_time == static_cast<int32>(num_ali.size()));
// std::vector<double> alpha(num_states, kLogZeroDouble),
// alpha_smbr(num_states, 0), //forward variable for sMBR
// beta(num_states, kLogZeroDouble),
// beta_smbr(num_states, 0); //backward variable for sMBR
//
// double tot_forward_prob = kLogZeroDouble;
// double tot_forward_score = 0;
//
// post->clear();
// post->resize(max_time);
//
// alpha[0] = 0.0;
// // First Pass Forward,
// for (StateId s = 0; s < num_states; s++) {
// double this_alpha = alpha[s];
// for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -ConvertToCost(arc.weight);
// alpha[arc.nextstate] = LogAdd(alpha[arc.nextstate], this_alpha + arc_like);
// }
// Weight f = lat.Final(s);
// if (f != Weight::Zero()) {
// double final_like = this_alpha - (f.Value1() + f.Value2());
// tot_forward_prob = LogAdd(tot_forward_prob, final_like);
// KALDI_ASSERT(state_times[s] == max_time &&
// "Lattice is inconsistent (final-prob not at max_time)");
// }
// }
// // First Pass Backward,
// for (StateId s = num_states-1; s >= 0; s--) {
// Weight f = lat.Final(s);
// double this_beta = -(f.Value1() + f.Value2());
// for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -ConvertToCost(arc.weight),
// arc_beta = beta[arc.nextstate] + arc_like;
// this_beta = LogAdd(this_beta, arc_beta);
// }
// beta[s] = this_beta;
// }
// // First Pass Forward-Backward Check
// double tot_backward_prob = beta[0];
// //  We may loosen the tolerance somewhat here: 1e-6 (was 1e-8).
// if (!ApproxEqual(tot_forward_prob, tot_backward_prob, 1e-6)) {
// KALDI_ERR << "Total forward probability over lattice = " << tot_forward_prob
// << ", while total backward probability = " << tot_backward_prob;
// }
//
// alpha_smbr[0] = 0.0;
// // Second Pass Forward, calculate forward for MPFE/SMBR
// for (StateId s = 0; s < num_states; s++) {
// double this_alpha = alpha[s];
// for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -ConvertToCost(arc.weight);
// double frame_acc = 0.0;
// if (arc.ilabel != 0) {
// int32 cur_time = state_times[s];
// int32 phone = trans.TransitionIdToPhone(arc.ilabel),
// ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]);
// bool phone_is_sil = std::binary_search(silence_phones.begin(),
// silence_phones.end(),
// phone),
// ref_phone_is_sil = std::binary_search(silence_phones.begin(),
// silence_phones.end(),
// ref_phone),
// both_sil = phone_is_sil && ref_phone_is_sil;
// if (!is_mpfe) { // smbr.
// int32 pdf = trans.TransitionIdToPdf(arc.ilabel),
// ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]);
// if (!one_silence_class) // old behavior
// frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0;
// else
// frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0;
// } else {
// if (!one_silence_class) // old behavior
// frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0;
// else
// frame_acc = (phone == ref_phone || both_sil) ? 1.0 : 0.0;
// }
// }
// double arc_scale = Exp(alpha[s] + arc_like - alpha[arc.nextstate]);
// alpha_smbr[arc.nextstate] += arc_scale * (alpha_smbr[s] + frame_acc);
// }
// Weight f = lat.Final(s);
// if (f != Weight::Zero()) {
// double final_like = this_alpha - (f.Value1() + f.Value2());
// double arc_scale = Exp(final_like - tot_forward_prob);
// tot_forward_score += arc_scale * alpha_smbr[s];
// KALDI_ASSERT(state_times[s] == max_time &&
// "Lattice is inconsistent (final-prob not at max_time)");
// }
// }
// // Second Pass Backward, collect Mpe style posteriors
// for (StateId s = num_states-1; s >= 0; s--) {
// for (ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_like = -ConvertToCost(arc.weight),
// arc_beta = beta[arc.nextstate] + arc_like;
// double frame_acc = 0.0;
// int32 transition_id = arc.ilabel;
// if (arc.ilabel != 0) {
// int32 cur_time = state_times[s];
// int32 phone = trans.TransitionIdToPhone(arc.ilabel),
// ref_phone = trans.TransitionIdToPhone(num_ali[cur_time]);
// bool phone_is_sil = std::binary_search(silence_phones.begin(),
// silence_phones.end(), phone),
// ref_phone_is_sil = std::binary_search(silence_phones.begin(),
// silence_phones.end(),
// ref_phone),
// both_sil = phone_is_sil && ref_phone_is_sil;
// if (!is_mpfe) { // smbr.
// int32 pdf = trans.TransitionIdToPdf(arc.ilabel),
// ref_pdf = trans.TransitionIdToPdf(num_ali[cur_time]);
// if (!one_silence_class) // old behavior
// frame_acc = (pdf == ref_pdf && !phone_is_sil) ? 1.0 : 0.0;
// else
// frame_acc = (pdf == ref_pdf || both_sil) ? 1.0 : 0.0;
// } else {
// if (!one_silence_class) // old behavior
// frame_acc = (phone == ref_phone && !phone_is_sil) ? 1.0 : 0.0;
// else
// frame_acc = (phone == ref_phone || both_sil) ? 1.0 : 0.0;
// }
// }
// double arc_scale = Exp(beta[arc.nextstate] + arc_like - beta[s]);
// // check arc_scale NAN,
// // this is to prevent partial paths in Lattices
// // i.e., paths don't survive to the final state
// if (KALDI_ISNAN(arc_scale)) arc_scale = 0;
// beta_smbr[s] += arc_scale * (beta_smbr[arc.nextstate] + frame_acc);
//
// if (transition_id != 0) { // Arc has a transition-id on it [not epsilon]
// double posterior = Exp(alpha[s] + arc_beta - tot_forward_prob);
// double acc_diff = alpha_smbr[s] + frame_acc + beta_smbr[arc.nextstate]
// - tot_forward_score;
// double posterior_smbr = posterior * acc_diff;
// (*post)[state_times[s]].push_back(std::make_pair(transition_id,
// static_cast<BaseFloat>(posterior_smbr)));
// }
// }
// }
//
//   // Second Pass Forward-Backward check
// double tot_backward_score = beta_smbr[0]; // Initial state id == 0
//   // may need to loosen the tolerance somewhat here: 1e-5/1e-4
// if (!ApproxEqual(tot_forward_score, tot_backward_score, 1e-4)) {
// KALDI_ERR << "Total forward score over lattice = " << tot_forward_score
// << ", while total backward score = " << tot_backward_score;
// }
//
// // Output the computed posteriors
// for (int32 t = 0; t < max_time; t++)
// MergePairVectorSumming(&((*post)[t]));
// return tot_forward_score;
// }
//
// bool CompactLatticeToWordAlignment(const CompactLattice &clat,
// std::vector<int32> *words,
// std::vector<int32> *begin_times,
// std::vector<int32> *lengths) {
// words->clear();
// begin_times->clear();
// lengths->clear();
// typedef CompactLattice::Arc Arc;
// typedef Arc::Label Label;
// typedef CompactLattice::StateId StateId;
// typedef CompactLattice::Weight Weight;
// using namespace fst;
// StateId state = clat.Start();
// int32 cur_time = 0;
// if (state == kNoStateId) {
// KALDI_WARN << "Empty lattice.";
// return false;
// }
// while (1) {
// Weight final = clat.Final(state);
// size_t num_arcs = clat.NumArcs(state);
// if (final != Weight::Zero()) {
// if (num_arcs != 0) {
// KALDI_WARN << "Lattice is not linear.";
// return false;
// }
// if (! final.String().empty()) {
// KALDI_WARN << "Lattice has alignments on final-weight: probably "
// "was not word-aligned (alignments will be approximate)";
// }
// return true;
// } else {
// if (num_arcs != 1) {
// KALDI_WARN << "Lattice is not linear: num-arcs = " << num_arcs;
// return false;
// }
// fst::ArcIterator<CompactLattice> aiter(clat, state);
// const Arc &arc = aiter.Value();
// Label word_id = arc.ilabel; // Note: ilabel==olabel, since acceptor.
// // Also note: word_id may be zero; we output it anyway.
// int32 length = arc.weight.String().size();
// words->push_back(word_id);
// begin_times->push_back(cur_time);
// lengths->push_back(length);
// cur_time += length;
// state = arc.nextstate;
// }
// }
// }
//
//
// bool CompactLatticeToWordProns(
// const TransitionModel &tmodel,
// const CompactLattice &clat,
// std::vector<int32> *words,
// std::vector<int32> *begin_times,
// std::vector<int32> *lengths,
// std::vector<std::vector<int32> > *prons,
// std::vector<std::vector<int32> > *phone_lengths) {
// words->clear();
// begin_times->clear();
// lengths->clear();
// prons->clear();
// phone_lengths->clear();
// typedef CompactLattice::Arc Arc;
// typedef Arc::Label Label;
// typedef CompactLattice::StateId StateId;
// typedef CompactLattice::Weight Weight;
// using namespace fst;
// StateId state = clat.Start();
// int32 cur_time = 0;
// if (state == kNoStateId) {
// KALDI_WARN << "Empty lattice.";
// return false;
// }
// while (1) {
// Weight final = clat.Final(state);
// size_t num_arcs = clat.NumArcs(state);
// if (final != Weight::Zero()) {
// if (num_arcs != 0) {
// KALDI_WARN << "Lattice is not linear.";
// return false;
// }
// if (! final.String().empty()) {
// KALDI_WARN << "Lattice has alignments on final-weight: probably "
// "was not word-aligned (alignments will be approximate)";
// }
// return true;
// } else {
// if (num_arcs != 1) {
// KALDI_WARN << "Lattice is not linear: num-arcs = " << num_arcs;
// return false;
// }
// fst::ArcIterator<CompactLattice> aiter(clat, state);
// const Arc &arc = aiter.Value();
// Label word_id = arc.ilabel; // Note: ilabel==olabel, since acceptor.
// // Also note: word_id may be zero; we output it anyway.
// int32 length = arc.weight.String().size();
// words->push_back(word_id);
// begin_times->push_back(cur_time);
// lengths->push_back(length);
// const std::vector<int32> &arc_alignment = arc.weight.String();
// std::vector<std::vector<int32> > split_alignment;
// SplitToPhones(tmodel, arc_alignment, &split_alignment);
// std::vector<int32> phones(split_alignment.size());
// std::vector<int32> plengths(split_alignment.size());
// for (size_t i = 0; i < split_alignment.size(); i++) {
// KALDI_ASSERT(!split_alignment[i].empty());
// phones[i] = tmodel.TransitionIdToPhone(split_alignment[i][0]);
// plengths[i] = split_alignment[i].size();
// }
// prons->push_back(phones);
// phone_lengths->push_back(plengths);
//
// cur_time += length;
// state = arc.nextstate;
// }
// }
// }
//
//
//
// void CompactLatticeShortestPath(const CompactLattice &clat,
// CompactLattice *shortest_path) {
// using namespace fst;
// if (clat.Properties(fst::kTopSorted, true) == 0) {
// CompactLattice clat_copy(clat);
// if (!TopSort(&clat_copy))
// KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
// CompactLatticeShortestPath(clat_copy, shortest_path);
// return;
// }
// // Now we can assume it's topologically sorted.
// shortest_path->DeleteStates();
// if (clat.Start() == kNoStateId) return;
// typedef CompactLatticeArc Arc;
// typedef Arc::StateId StateId;
// typedef CompactLatticeWeight Weight;
// vector<std::pair<double, StateId> > best_cost_and_pred(clat.NumStates() + 1);
// StateId superfinal = clat.NumStates();
// for (StateId s = 0; s <= clat.NumStates(); s++) {
// best_cost_and_pred[s].first = std::numeric_limits<double>::infinity();
// best_cost_and_pred[s].second = fst::kNoStateId;
// }
// best_cost_and_pred[clat.Start()].first = 0;
// for (StateId s = 0; s < clat.NumStates(); s++) {
// double my_cost = best_cost_and_pred[s].first;
// for (ArcIterator<CompactLattice> aiter(clat, s);
// !aiter.Done();
// aiter.Next()) {
// const Arc &arc = aiter.Value();
// double arc_cost = ConvertToCost(arc.weight),
// next_cost = my_cost + arc_cost;
// if (next_cost < best_cost_and_pred[arc.nextstate].first) {
// best_cost_and_pred[arc.nextstate].first = next_cost;
// best_cost_and_pred[arc.nextstate].second = s;
// }
// }
// double final_cost = ConvertToCost(clat.Final(s)),
// tot_final = my_cost + final_cost;
// if (tot_final < best_cost_and_pred[superfinal].first) {
// best_cost_and_pred[superfinal].first = tot_final;
// best_cost_and_pred[superfinal].second = s;
// }
// }
// std::vector<StateId> states; // states on best path.
// StateId cur_state = superfinal, start_state = clat.Start();
// while (cur_state != start_state) {
// StateId prev_state = best_cost_and_pred[cur_state].second;
// if (prev_state == kNoStateId) {
// KALDI_WARN << "Failure in best-path algorithm for lattice (infinite costs?)";
// return; // return empty best-path.
// }
// states.push_back(prev_state);
// KALDI_ASSERT(cur_state != prev_state && "Lattice with cycles");
// cur_state = prev_state;
// }
// std::reverse(states.begin(), states.end());
// for (size_t i = 0; i < states.size(); i++)
// shortest_path->AddState();
// for (StateId s = 0; static_cast<size_t>(s) < states.size(); s++) {
// if (s == 0) shortest_path->SetStart(s);
// if (static_cast<size_t>(s + 1) < states.size()) { // transition to next state.
// bool have_arc = false;
// Arc cur_arc;
// for (ArcIterator<CompactLattice> aiter(clat, states[s]);
// !aiter.Done();
// aiter.Next()) {
// const Arc &arc = aiter.Value();
// if (arc.nextstate == states[s+1]) {
// if (!have_arc ||
// ConvertToCost(arc.weight) < ConvertToCost(cur_arc.weight)) {
// cur_arc = arc;
// have_arc = true;
// }
// }
// }
// KALDI_ASSERT(have_arc && "Code error.");
// shortest_path->AddArc(s, Arc(cur_arc.ilabel, cur_arc.olabel,
// cur_arc.weight, s+1));
// } else { // final-prob.
// shortest_path->SetFinal(s, clat.Final(states[s]));
// }
// }
// }
//
//
// void ExpandCompactLattice(const CompactLattice &clat,
// double epsilon,
// CompactLattice *expand_clat) {
// using namespace fst;
// typedef CompactLattice::Arc Arc;
// typedef Arc::Weight Weight;
// typedef Arc::StateId StateId;
// typedef std::pair<StateId, StateId> StatePair;
// typedef unordered_map<StatePair, StateId, PairHasher<StateId> > MapType;
// typedef MapType::iterator IterType;
//
// if (clat.Start() == kNoStateId) return;
// // Make sure the input lattice is topologically sorted.
// if (clat.Properties(kTopSorted, true) == 0) {
// CompactLattice clat_copy(clat);
// KALDI_LOG << "Topsort this lattice.";
// if (!TopSort(&clat_copy))
// KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
// ExpandCompactLattice(clat_copy, epsilon, expand_clat);
// return;
// }
//
// // Compute backward logprobs betas for the expanded lattice.
// // Note: the backward logprobs in the original lattice <clat> and the
// // expanded lattice <expand_clat> are the same.
// int32 num_states = clat.NumStates();
// std::vector<double> beta(num_states, kLogZeroDouble);
// ComputeCompactLatticeBetas(clat, &beta);
// double tot_backward_logprob = beta[0];
// std::vector<double> alpha;
// alpha.push_back(0.0);
// expand_clat->DeleteStates();
// MapType state_map; // Map from state pair (orig_state, copy_state) to
// // copy_state, where orig_state is a state in the original lattice, and
// // copy_state is its corresponding one in the expanded lattice.
// unordered_map<StateId, StateId> states; // Map from orig_state to its
// // copy_state for states with incoming arcs' posteriors <= epsilon.
// std::queue<StatePair> state_queue;
//
// // Set start state in the expanded lattice.
// StateId start_state = expand_clat->AddState();
// expand_clat->SetStart(start_state);
// StatePair start_pair(clat.Start(), start_state);
// state_queue.push(start_pair);
// std::pair<IterType, bool> result =
// state_map.insert(std::make_pair(start_pair, start_state));
// KALDI_ASSERT(result.second == true);
//
// // Expand <clat> and update forward logprobs alphas in <expand_clat>.
// while (!state_queue.empty()) {
// StatePair s = state_queue.front();
// StateId s1 = s.first,
// s2 = s.second;
// state_queue.pop();
//
// Weight f = clat.Final(s1);
// if (f != Weight::Zero()) {
// KALDI_ASSERT(state_map.find(s) != state_map.end());
// expand_clat->SetFinal(state_map[s], f);
// }
//
// for (ArcIterator<CompactLattice> aiter(clat, s1);
// !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// StateId orig_state = arc.nextstate;
// double arc_like = -ConvertToCost(arc.weight),
// this_alpha = alpha[s2] + arc_like,
// arc_post = Exp(this_alpha + beta[orig_state] -
// tot_backward_logprob);
// // Generate the expanded lattice.
// StateId copy_state;
// if (arc_post > epsilon) {
// copy_state = expand_clat->AddState();
// StatePair next_pair(orig_state, copy_state);
// std::pair<IterType, bool> result =
// state_map.insert(std::make_pair(next_pair, copy_state));
// KALDI_ASSERT(result.second == true);
// state_queue.push(next_pair);
// } else {
// unordered_map<StateId, StateId>::iterator iter = states.find(orig_state);
// if (iter == states.end() ) { // The counterpart state of orig_state
// // has not been created in <expand_clat> yet.
// copy_state = expand_clat->AddState();
// StatePair next_pair(orig_state, copy_state);
// std::pair<IterType, bool> result =
// state_map.insert(std::make_pair(next_pair, copy_state));
// KALDI_ASSERT(result.second == true);
// state_queue.push(next_pair);
// states[orig_state] = copy_state;
// } else {
// copy_state = iter->second;
// }
// }
// // Create an arc from state_map[s] to copy_state in the expanded lattice.
// expand_clat->AddArc(state_map[s], Arc(arc.ilabel, arc.olabel, arc.weight,
// copy_state));
// // Compute forward logprobs alpha for the expanded lattice.
// if ((alpha.size() - 1) < copy_state) { // The first time to compute alpha
// // for copy_state in <expand_clat>.
// alpha.push_back(this_alpha);
// } else { // Accumulate alpha.
// alpha[copy_state] = LogAdd(alpha[copy_state], this_alpha);
// }
// }
// } // end while
// }
//
//
// void CompactLatticeBestCostsAndTracebacks(
// const CompactLattice &clat,
// CostTraceType *forward_best_cost_and_pred,
// CostTraceType *backward_best_cost_and_pred) {
//
// // typedef the arc, weight types
// typedef CompactLatticeArc Arc;
// typedef Arc::Weight Weight;
// typedef Arc::StateId StateId;
//
// forward_best_cost_and_pred->clear();
// backward_best_cost_and_pred->clear();
// forward_best_cost_and_pred->resize(clat.NumStates());
// backward_best_cost_and_pred->resize(clat.NumStates());
// // Initialize the cost and predecessor state for each state.
// for (StateId s = 0; s < clat.NumStates(); s++) {
// (*forward_best_cost_and_pred)[s].first =
// std::numeric_limits<double>::infinity();
// (*backward_best_cost_and_pred)[s].first =
// std::numeric_limits<double>::infinity();
// (*forward_best_cost_and_pred)[s].second = fst::kNoStateId;
// (*backward_best_cost_and_pred)[s].second = fst::kNoStateId;
// }
//
// StateId start_state = clat.Start();
// (*forward_best_cost_and_pred)[start_state].first = 0;
//   // Traverse the lattice forward to compute the best cost from the start
// // state to each state and the best predecessor state of each state.
// for (StateId s = 0; s < clat.NumStates(); s++) {
// double cur_cost = (*forward_best_cost_and_pred)[s].first;
// for (fst::ArcIterator<CompactLattice> aiter(clat, s);
// !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double next_cost = cur_cost + ConvertToCost(arc.weight);
// if (next_cost < (*forward_best_cost_and_pred)[arc.nextstate].first) {
// (*forward_best_cost_and_pred)[arc.nextstate].first = next_cost;
// (*forward_best_cost_and_pred)[arc.nextstate].second = s;
// }
// }
// }
//   // Traverse the lattice backward to compute the best cost from a final
// // state to each state and the best predecessor state of each state.
// for (StateId s = clat.NumStates() - 1; s >= 0; s--) {
// double this_cost = ConvertToCost(clat.Final(s));
// for (fst::ArcIterator<CompactLattice> aiter(clat, s);
// !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// double next_cost = (*backward_best_cost_and_pred)[arc.nextstate].first +
// ConvertToCost(arc.weight);
// if (next_cost < this_cost) {
// this_cost = next_cost;
// (*backward_best_cost_and_pred)[s].second = arc.nextstate;
// }
// }
// (*backward_best_cost_and_pred)[s].first = this_cost;
// }
// }
//
//
// void AddNnlmScoreToCompactLattice(const MapT &nnlm_scores,
// CompactLattice *clat) {
// if (clat->Start() == fst::kNoStateId) return;
// // Make sure the input lattice is topologically sorted.
// if (clat->Properties(fst::kTopSorted, true) == 0) {
// KALDI_LOG << "Topsort this lattice.";
// if (!TopSort(clat))
// KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
// AddNnlmScoreToCompactLattice(nnlm_scores, clat);
// return;
// }
//
// // typedef the arc, weight types
// typedef CompactLatticeArc Arc;
// typedef Arc::Weight Weight;
// typedef Arc::StateId StateId;
// typedef std::pair<int32, int32> StatePair;
//
// int32 num_states = clat->NumStates();
// unordered_map<StatePair, bool, PairHasher<int32> > final_state_check;
// for (StateId s = 0; s < num_states; s++) {
// for (fst::MutableArcIterator<CompactLattice> aiter(clat, s);
// !aiter.Done(); aiter.Next()) {
// Arc arc(aiter.Value());
// StatePair arc_index = std::make_pair(static_cast<int32>(s),
// static_cast<int32>(arc.nextstate));
// MapT::const_iterator it = nnlm_scores.find(arc_index);
// double nnlm_score;
// if (it != nnlm_scores.end())
// nnlm_score = it->second;
// else
// KALDI_ERR << "Some arc does not have neural language model score.";
// if (arc.ilabel != 0) { // if there is a word on this arc
// LatticeWeight weight = arc.weight.Weight();
// // Add associated neural LM score to each arc.
// weight.SetValue1(weight.Value1() + nnlm_score);
// arc.weight.SetWeight(weight);
// aiter.SetValue(arc);
// }
// Weight clat_final = clat->Final(arc.nextstate);
// StatePair final_pair = std::make_pair(arc.nextstate, arc.nextstate);
// // Add neural LM scores to each final state only once.
// if (clat_final != CompactLatticeWeight::Zero() &&
// final_state_check.find(final_pair) == final_state_check.end()) {
// MapT::const_iterator final_it = nnlm_scores.find(final_pair);
// double final_nnlm_score = 0.0;
// if (final_it != nnlm_scores.end())
// final_nnlm_score = final_it->second;
// // Add neural LM scores to the final weight.
// Weight final_weight(LatticeWeight(clat_final.Weight().Value1() +
// final_nnlm_score,
// clat_final.Weight().Value2()),
// clat_final.String());
// clat->SetFinal(arc.nextstate, final_weight);
// final_state_check[final_pair] = true;
// }
// } // end looping over arcs
// } // end looping over states
// }
//
// void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty,
// CompactLattice *clat) {
// typedef CompactLatticeArc Arc;
// int32 num_states = clat->NumStates();
//
// //scan the lattice
// for (int32 state = 0; state < num_states; state++) {
// for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
// !aiter.Done(); aiter.Next()) {
//
// Arc arc(aiter.Value());
//
// if (arc.ilabel != 0) { // if there is a word on this arc
// LatticeWeight weight = arc.weight.Weight();
// // add word insertion penalty to lattice
// weight.SetValue1( weight.Value1() + word_ins_penalty);
// arc.weight.SetWeight(weight);
// aiter.SetValue(arc);
// }
// } // end looping over arcs
// } // end looping over states
// }
//
// struct ClatRescoreTuple {
// ClatRescoreTuple(int32 state, int32 arc, int32 tid):
// state_id(state), arc_id(arc), tid(tid) { }
// int32 state_id;
// int32 arc_id;
// int32 tid;
// };
//
// /** RescoreCompactLatticeInternal is the internal code for both
//      RescoreCompactLattice and RescoreCompactLatticeSpeedup. For
// RescoreCompactLattice, "tmodel" will be NULL and speedup_factor will be 1.0.
// */
// bool RescoreCompactLatticeInternal(
// const TransitionModel *tmodel,
// BaseFloat speedup_factor,
// DecodableInterface *decodable,
// CompactLattice *clat) {
// KALDI_ASSERT(speedup_factor >= 1.0);
// if (clat->NumStates() == 0) {
// KALDI_WARN << "Rescoring empty lattice";
// return false;
// }
// if (!clat->Properties(fst::kTopSorted, true)) {
// if (fst::TopSort(clat) == false) {
// KALDI_WARN << "Cycles detected in lattice.";
// return false;
// }
// }
// std::vector<int32> state_times;
// int32 utt_len = kaldi::CompactLatticeStateTimes(*clat, &state_times);
//
// std::vector<std::vector<ClatRescoreTuple> > time_to_state(utt_len);
//
// int32 num_states = clat->NumStates();
// KALDI_ASSERT(num_states == state_times.size());
// for (size_t state = 0; state < num_states; state++) {
// KALDI_ASSERT(state_times[state] >= 0);
// int32 t = state_times[state];
// int32 arc_id = 0;
// for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
// !aiter.Done(); aiter.Next(), arc_id++) {
// CompactLatticeArc arc = aiter.Value();
// std::vector<int32> arc_string = arc.weight.String();
//
// for (size_t offset = 0; offset < arc_string.size(); offset++) {
// if (t < utt_len) { // end state may be past this..
// int32 tid = arc_string[offset];
// time_to_state[t+offset].push_back(ClatRescoreTuple(state, arc_id, tid));
// } else {
// if (t != utt_len) {
// KALDI_WARN << "There appears to be lattice/feature mismatch, "
// << "aborting.";
// return false;
// }
// }
// }
// }
// if (clat->Final(state) != CompactLatticeWeight::Zero()) {
// arc_id = -1;
// std::vector<int32> arc_string = clat->Final(state).String();
// for (size_t offset = 0; offset < arc_string.size(); offset++) {
// KALDI_ASSERT(t + offset < utt_len); // already checked in
// // CompactLatticeStateTimes, so would be code error.
// time_to_state[t+offset].push_back(
// ClatRescoreTuple(state, arc_id, arc_string[offset]));
// }
// }
// }
//
// for (int32 t = 0; t < utt_len; t++) {
// if ((t < utt_len - 1) && decodable->IsLastFrame(t)) {
// KALDI_WARN << "Features are too short for lattice: utt-len is "
// << utt_len << ", " << t << " is last frame";
// return false;
// }
// // frame_scale is the scale we put on the computed acoustic probs for this
// // frame. It will always be 1.0 if tmodel == NULL (i.e. if we are not doing
//   // the "speedup" code). For frames with multiple pdf-ids it will be 1.0.
// // For frames with only one pdf-id, it will equal speedup_factor (>=1.0)
// // with probability 1.0 / speedup_factor, and zero otherwise. If it is zero,
// // we can avoid computing the probabilities.
// BaseFloat frame_scale = 1.0;
// KALDI_ASSERT(!time_to_state[t].empty());
// if (tmodel != NULL) {
// int32 pdf_id = tmodel->TransitionIdToPdf(time_to_state[t][0].tid);
// bool frame_has_multiple_pdfs = false;
// for (size_t i = 1; i < time_to_state[t].size(); i++) {
// if (tmodel->TransitionIdToPdf(time_to_state[t][i].tid) != pdf_id) {
// frame_has_multiple_pdfs = true;
// break;
// }
// }
// if (frame_has_multiple_pdfs) {
// frame_scale = 1.0;
// } else {
// if (WithProb(1.0 / speedup_factor)) {
// frame_scale = speedup_factor;
// } else {
// frame_scale = 0.0;
// }
// }
// if (frame_scale == 0.0)
// continue; // the code below would be pointless.
// }
//
// for (size_t i = 0; i < time_to_state[t].size(); i++) {
// int32 state = time_to_state[t][i].state_id;
// int32 arc_id = time_to_state[t][i].arc_id;
// int32 tid = time_to_state[t][i].tid;
//
// if (arc_id == -1) { // Final state
// // Access the trans_id
// CompactLatticeWeight curr_clat_weight = clat->Final(state);
//
// // Calculate likelihood
// BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale;
// // update weight
// CompactLatticeWeight new_clat_weight = curr_clat_weight;
// LatticeWeight new_lat_weight = new_clat_weight.Weight();
// new_lat_weight.SetValue2(-log_like + curr_clat_weight.Weight().Value2());
// new_clat_weight.SetWeight(new_lat_weight);
// clat->SetFinal(state, new_clat_weight);
// } else {
// fst::MutableArcIterator<CompactLattice> aiter(clat, state);
//
// aiter.Seek(arc_id);
// CompactLatticeArc arc = aiter.Value();
//
// // Calculate likelihood
// BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale;
// // update weight
// LatticeWeight new_weight = arc.weight.Weight();
// new_weight.SetValue2(-log_like + arc.weight.Weight().Value2());
// arc.weight.SetWeight(new_weight);
// aiter.SetValue(arc);
// }
// }
// }
// return true;
// }
//
//
// bool RescoreCompactLatticeSpeedup(
// const TransitionModel &tmodel,
// BaseFloat speedup_factor,
// DecodableInterface *decodable,
// CompactLattice *clat) {
// return RescoreCompactLatticeInternal(&tmodel, speedup_factor, decodable, clat);
// }
//
// bool RescoreCompactLattice(DecodableInterface *decodable,
// CompactLattice *clat) {
// return RescoreCompactLatticeInternal(NULL, 1.0, decodable, clat);
// }
//
//
// bool RescoreLattice(DecodableInterface *decodable,
// Lattice *lat) {
// if (lat->NumStates() == 0) {
// KALDI_WARN << "Rescoring empty lattice";
// return false;
// }
// if (!lat->Properties(fst::kTopSorted, true)) {
// if (fst::TopSort(lat) == false) {
// KALDI_WARN << "Cycles detected in lattice.";
// return false;
// }
// }
// std::vector<int32> state_times;
// int32 utt_len = kaldi::LatticeStateTimes(*lat, &state_times);
//
// std::vector<std::vector<int32> > time_to_state(utt_len );
//
// int32 num_states = lat->NumStates();
// KALDI_ASSERT(num_states == state_times.size());
// for (size_t state = 0; state < num_states; state++) {
// int32 t = state_times[state];
// // Don't check t >= 0 because non-accessible states could have t = -1.
// KALDI_ASSERT(t <= utt_len);
// if (t >= 0 && t < utt_len)
// time_to_state[t].push_back(state);
// }
//
// for (int32 t = 0; t < utt_len; t++) {
// if ((t < utt_len - 1) && decodable->IsLastFrame(t)) {
// KALDI_WARN << "Features are too short for lattice: utt-len is "
// << utt_len << ", " << t << " is last frame";
// return false;
// }
// for (size_t i = 0; i < time_to_state[t].size(); i++) {
// int32 state = time_to_state[t][i];
// for (fst::MutableArcIterator<Lattice> aiter(lat, state);
// !aiter.Done(); aiter.Next()) {
// LatticeArc arc = aiter.Value();
// if (arc.ilabel != 0) {
// int32 trans_id = arc.ilabel; // Note: it doesn't necessarily
// // have to be a transition-id, just whatever the Decodable
// // object is expecting, but it's normally a transition-id.
//
// BaseFloat log_like = decodable->LogLikelihood(t, trans_id);
// arc.weight.SetValue2(-log_like + arc.weight.Value2());
// aiter.SetValue(arc);
// }
// }
// }
// }
// return true;
// }
//
//
// BaseFloat LatticeForwardBackwardMmi(
// const TransitionModel &tmodel,
// const Lattice &lat,
// const std::vector<int32> &num_ali,
// bool drop_frames,
// bool convert_to_pdf_ids,
// bool cancel,
// Posterior *post) {
// // First compute the MMI posteriors.
//
// Posterior den_post;
// BaseFloat ans = LatticeForwardBackward(lat,
// &den_post,
// NULL);
//
// Posterior num_post;
// AlignmentToPosterior(num_ali, &num_post);
//
// // Now negate the MMI posteriors and add the numerator
// // posteriors.
// ScalePosterior(-1.0, &den_post);
//
// if (convert_to_pdf_ids) {
// Posterior num_tmp;
// ConvertPosteriorToPdfs(tmodel, num_post, &num_tmp);
// num_tmp.swap(num_post);
// Posterior den_tmp;
// ConvertPosteriorToPdfs(tmodel, den_post, &den_tmp);
// den_tmp.swap(den_post);
// }
//
// MergePosteriors(num_post, den_post,
// cancel, drop_frames, post);
//
// return ans;
// }
//
//
// int32 LongestSentenceLength(const Lattice &lat) {
// typedef Lattice::Arc Arc;
// typedef Arc::Label Label;
// typedef Arc::StateId StateId;
//
// if (lat.Properties(fst::kTopSorted, true) == 0) {
// Lattice lat_copy(lat);
// if (!TopSort(&lat_copy))
// KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
// return LongestSentenceLength(lat_copy);
// }
// std::vector<int32> max_length(lat.NumStates(), 0);
// int32 lattice_max_length = 0;
// for (StateId s = 0; s < lat.NumStates(); s++) {
// int32 this_max_length = max_length[s];
// for (fst::ArcIterator<Lattice> aiter(lat, s); !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// bool arc_has_word = (arc.olabel != 0);
// StateId nextstate = arc.nextstate;
// KALDI_ASSERT(static_cast<size_t>(nextstate) < max_length.size());
// if (arc_has_word) {
// // A lattice should ideally not have cycles anyway; a cycle with a word
// // on is something very bad.
// KALDI_ASSERT(nextstate > s && "Lattice has cycles with words on.");
// max_length[nextstate] = std::max(max_length[nextstate],
// this_max_length + 1);
// } else {
// max_length[nextstate] = std::max(max_length[nextstate],
// this_max_length);
// }
// }
// if (lat.Final(s) != LatticeWeight::Zero())
// lattice_max_length = std::max(lattice_max_length, max_length[s]);
// }
// return lattice_max_length;
// }
//
// int32 LongestSentenceLength(const CompactLattice &clat) {
// typedef CompactLattice::Arc Arc;
// typedef Arc::Label Label;
// typedef Arc::StateId StateId;
//
// if (clat.Properties(fst::kTopSorted, true) == 0) {
// CompactLattice clat_copy(clat);
// if (!TopSort(&clat_copy))
// KALDI_ERR << "Was not able to topologically sort lattice (cycles found?)";
// return LongestSentenceLength(clat_copy);
// }
// std::vector<int32> max_length(clat.NumStates(), 0);
// int32 lattice_max_length = 0;
// for (StateId s = 0; s < clat.NumStates(); s++) {
// int32 this_max_length = max_length[s];
// for (fst::ArcIterator<CompactLattice> aiter(clat, s);
// !aiter.Done(); aiter.Next()) {
// const Arc &arc = aiter.Value();
// bool arc_has_word = (arc.ilabel != 0); // note: olabel == ilabel.
// // also note: for normal CompactLattice, e.g. as produced by
// // determinization, all arcs will have nonzero labels, but the user might
//   // decide to replace some of the labels with zero for some reason, and we
// // want to support this.
// StateId nextstate = arc.nextstate;
// KALDI_ASSERT(static_cast<size_t>(nextstate) < max_length.size());
// KALDI_ASSERT(nextstate > s && "CompactLattice has cycles");
// if (arc_has_word)
// max_length[nextstate] = std::max(max_length[nextstate],
// this_max_length + 1);
// else
// max_length[nextstate] = std::max(max_length[nextstate],
// this_max_length);
// }
// if (clat.Final(s) != CompactLatticeWeight::Zero())
// lattice_max_length = std::max(lattice_max_length, max_length[s]);
// }
// return lattice_max_length;
// }
//
// void ComposeCompactLatticeDeterministic(
// const CompactLattice& clat,
// fst::DeterministicOnDemandFst<fst::StdArc>* det_fst,
// CompactLattice* composed_clat) {
// // StdFst::Arc and CompactLatticeArc has the same StateId type.
// typedef fst::StdArc::StateId StateId;
// typedef fst::StdArc::Weight Weight1;
// typedef CompactLatticeArc::Weight Weight2;
// typedef std::pair<StateId, StateId> StatePair;
// typedef unordered_map<StatePair, StateId, PairHasher<StateId> > MapType;
// typedef MapType::iterator IterType;
//
// // Empties the output FST.
// KALDI_ASSERT(composed_clat != NULL);
// composed_clat->DeleteStates();
//
// MapType state_map;
// std::queue<StatePair> state_queue;
//
// // Sets start state in <composed_clat>.
// StateId start_state = composed_clat->AddState();
// StatePair start_pair(clat.Start(), det_fst->Start());
// composed_clat->SetStart(start_state);
// state_queue.push(start_pair);
// std::pair<IterType, bool> result =
// state_map.insert(std::make_pair(start_pair, start_state));
// KALDI_ASSERT(result.second == true);
//
// // Starts composition here.
// while (!state_queue.empty()) {
// // Gets the first state in the queue.
// StatePair s = state_queue.front();
// StateId s1 = s.first;
// StateId s2 = s.second;
// state_queue.pop();
//
//
// Weight2 clat_final = clat.Final(s1);
// if (clat_final.Weight().Value1() !=
// std::numeric_limits<BaseFloat>::infinity()) {
// // Test for whether the final-prob of state s1 was zero.
// Weight1 det_fst_final = det_fst->Final(s2);
// if (det_fst_final.Value() !=
// std::numeric_limits<BaseFloat>::infinity()) {
// // Test for whether the final-prob of state s2 was zero. If neither
// // source-state final prob was zero, then we should create final state
// // in fst_composed. We compute the product manually since this is more
// // efficient.
// Weight2 final_weight(LatticeWeight(clat_final.Weight().Value1() +
// det_fst_final.Value(),
// clat_final.Weight().Value2()),
// clat_final.String());
// // we can assume final_weight is not Zero(), since neither of
// // the sources was zero.
// KALDI_ASSERT(state_map.find(s) != state_map.end());
// composed_clat->SetFinal(state_map[s], final_weight);
// }
// }
//
// // Loops over pair of edges at s1 and s2.
// for (fst::ArcIterator<CompactLattice> aiter(clat, s1);
// !aiter.Done(); aiter.Next()) {
// const CompactLatticeArc& arc1 = aiter.Value();
// fst::StdArc arc2;
// StateId next_state1 = arc1.nextstate, next_state2;
// bool matched = false;
//
// if (arc1.olabel == 0) {
// // If the symbol on <arc1> is <epsilon>, we transit to the next state
// // for <clat>, but keep <det_fst> at the current state.
// matched = true;
// next_state2 = s2;
// } else {
// // Otherwise try to find the matched arc in <det_fst>.
// matched = det_fst->GetArc(s2, arc1.olabel, &arc2);
// if (matched) {
// next_state2 = arc2.nextstate;
// }
// }
//
// // If matched arc is found in <det_fst>, then we have to add new arcs to
// // <composed_clat>.
// if (matched) {
// StatePair next_state_pair(next_state1, next_state2);
// IterType siter = state_map.find(next_state_pair);
// StateId next_state;
//
// // Adds composed state to <state_map>.
// if (siter == state_map.end()) {
// // If the composed state has not been created yet, create it.
// next_state = composed_clat->AddState();
// std::pair<const StatePair, StateId> next_state_map(next_state_pair,
// next_state);
// std::pair<IterType, bool> result = state_map.insert(next_state_map);
// KALDI_ASSERT(result.second);
// state_queue.push(next_state_pair);
// } else {
// // If the composed state is already in <state_map>, we can directly
// // use that.
// next_state = siter->second;
// }
//
// // Adds arc to <composed_clat>.
// if (arc1.olabel == 0) {
// composed_clat->AddArc(state_map[s],
// CompactLatticeArc(arc1.ilabel, 0,
// arc1.weight, next_state));
// } else {
// Weight2 composed_weight(
// LatticeWeight(arc1.weight.Weight().Value1() +
// arc2.weight.Value(),
// arc1.weight.Weight().Value2()),
// arc1.weight.String());
// composed_clat->AddArc(state_map[s],
// CompactLatticeArc(arc1.ilabel, arc2.olabel,
// composed_weight, next_state));
// }
// }
// }
// }
// fst::Connect(composed_clat);
// }
//
//
// void ComputeAcousticScoresMap(
// const Lattice &lat,
// unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
// PairHasher<int32> > *acoustic_scores) {
// // typedef the arc, weight types
// typedef Lattice::Arc Arc;
// typedef Arc::Weight LatticeWeight;
// typedef Arc::StateId StateId;
//
// acoustic_scores->clear();
//
// std::vector<int32> state_times;
// LatticeStateTimes(lat, &state_times); // Assumes the input is top sorted
//
// KALDI_ASSERT(lat.Start() == 0);
//
// for (StateId s = 0; s < lat.NumStates(); s++) {
// int32 t = state_times[s];
// for (fst::ArcIterator<Lattice> aiter(lat, s); !aiter.Done();
// aiter.Next()) {
// const Arc &arc = aiter.Value();
// const LatticeWeight &weight = arc.weight;
//
// int32 tid = arc.ilabel;
//
// if (tid != 0) {
// unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
// PairHasher<int32> >::iterator it = acoustic_scores->find(std::make_pair(t, tid));
// if (it == acoustic_scores->end()) {
// acoustic_scores->insert(std::make_pair(std::make_pair(t, tid),
// std::make_pair(weight.Value2(), 1)));
// } else {
// if (it->second.second == 2
// && it->second.first / it->second.second != weight.Value2()) {
// KALDI_VLOG(2) << "Transitions on the same frame have different "
// << "acoustic costs for tid " << tid << "; "
// << it->second.first / it->second.second
// << " vs " << weight.Value2();
// }
// it->second.first += weight.Value2();
// it->second.second++;
// }
// } else {
// // Arcs with epsilon input label (tid) must have 0 acoustic cost
// KALDI_ASSERT(weight.Value2() == 0);
// }
// }
//
// LatticeWeight f = lat.Final(s);
// if (f != LatticeWeight::Zero()) {
// // Final acoustic cost must be 0 as we are reading from
// // non-determinized, non-compact lattice
// KALDI_ASSERT(f.Value2() == 0.0);
// }
// }
// }
//
// void ReplaceAcousticScoresFromMap(
// const unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
// PairHasher<int32> > &acoustic_scores,
// Lattice *lat) {
// // typedef the arc, weight types
// typedef Lattice::Arc Arc;
// typedef Arc::Weight LatticeWeight;
// typedef Arc::StateId StateId;
//
// TopSortLatticeIfNeeded(lat);
//
// std::vector<int32> state_times;
// LatticeStateTimes(*lat, &state_times);
//
// KALDI_ASSERT(lat->Start() == 0);
//
// for (StateId s = 0; s < lat->NumStates(); s++) {
// int32 t = state_times[s];
// for (fst::MutableArcIterator<Lattice> aiter(lat, s);
// !aiter.Done(); aiter.Next()) {
// Arc arc(aiter.Value());
//
// int32 tid = arc.ilabel;
// if (tid != 0) {
// unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
// PairHasher<int32> >::const_iterator it = acoustic_scores.find(std::make_pair(t, tid));
// if (it == acoustic_scores.end()) {
// KALDI_ERR << "Could not find tid " << tid << " at time " << t
// << " in the acoustic scores map.";
// } else {
// arc.weight.SetValue2(it->second.first / it->second.second);
// }
// } else {
// // For epsilon arcs, set acoustic cost to 0.0
// arc.weight.SetValue2(0.0);
// }
// aiter.SetValue(arc);
// }
//
// LatticeWeight f = lat->Final(s);
// if (f != LatticeWeight::Zero()) {
// // Set final acoustic cost to 0.0
// f.SetValue2(0.0);
// lat->SetFinal(s, f);
// }
// }
// }
} // namespace kaldi
// lat/lattice-functions.h
// Copyright 2009-2012 Saarland University (author: Arnab Ghoshal)
// 2012-2013 Johns Hopkins University (Author: Daniel Povey);
// Bagher BabaAli
// 2014 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LAT_LATTICE_FUNCTIONS_H_
#define KALDI_LAT_LATTICE_FUNCTIONS_H_
#include <vector>
#include <map>
#include "base/kaldi-common.h"
// #include "hmm/posterior.h"
#include "fstext/fstext-lib.h"
// #include "hmm/transition-model.h"
#include "lat/kaldi-lattice.h"
// #include "itf/decodable-itf.h"
namespace kaldi {
// /**
// This function extracts the per-frame log likelihoods from a linear
// lattice (which we refer to as an 'nbest' lattice elsewhere in Kaldi code).
// The dimension of *per_frame_loglikes will be set to the
// number of input symbols in 'nbest'. The elements of
// '*per_frame_loglikes' will be set to the .Value2() elements of the lattice
// weights, which represent the acoustic costs; you may want to scale this
// vector afterward by -1/acoustic_scale to get the original loglikes.
// If there are acoustic costs on input-epsilon arcs or the final-prob in 'nbest'
// (and this should not normally be the case in situations where it makes
// sense to call this function), they will be included to the cost of the
// preceding input symbol, or the following input symbol for input-epsilons
// encountered prior to any input symbol. If 'nbest' has no input symbols,
// 'per_frame_loglikes' will be set to the empty vector.
// **/
// void GetPerFrameAcousticCosts(const Lattice &nbest,
// Vector<BaseFloat> *per_frame_loglikes);
//
// /// This function iterates over the states of a topologically sorted lattice and
// /// counts the time instance corresponding to each state. The times are returned
// /// in a vector of integers 'times' which is resized to have a size equal to the
// /// number of states in the lattice. The function also returns the maximum time
// /// in the lattice (this will equal the number of frames in the file).
// int32 LatticeStateTimes(const Lattice &lat, std::vector<int32> *times);
//
// /// As LatticeStateTimes, but in the CompactLattice format. Note: must
// /// be topologically sorted. Returns length of the utterance in frames, which
// /// might not be the same as the maximum time in the lattice, due to frames
// /// in the final-prob.
// int32 CompactLatticeStateTimes(const CompactLattice &clat,
// std::vector<int32> *times);
//
// /// This function does the forward-backward over lattices and computes the
// /// posterior probabilities of the arcs. It returns the total log-probability
// /// of the lattice. The Posterior quantities contain pairs of (transition-id, weight)
// /// on each frame.
// /// If the pointer "acoustic_like_sum" is provided, this value is set to
// /// the sum over the arcs, of the posterior of the arc times the
// /// acoustic likelihood [i.e. negated acoustic score] on that link.
// /// This is used in combination with other quantities to work out
// /// the objective function in MMI discriminative training.
// BaseFloat LatticeForwardBackward(const Lattice &lat,
// Posterior *arc_post,
// double *acoustic_like_sum = NULL);
//
// // This function is something similar to LatticeForwardBackward(), but it is on
// // the CompactLattice lattice format. Also we only need the alpha in the forward
// // path, not the posteriors.
// bool ComputeCompactLatticeAlphas(const CompactLattice &lat,
// std::vector<double> *alpha);
//
// // A sibling of the function CompactLatticeAlphas()... We compute the beta from
// // the backward path here.
// bool ComputeCompactLatticeBetas(const CompactLattice &lat,
// std::vector<double> *beta);
//
//
// // Computes (normal or Viterbi) alphas and betas; returns (total-prob, or
// // best-path negated cost) Note: in either case, the alphas and betas are
// // negated costs. Requires that lat be topologically sorted. This code
// // will work for either CompactLattice or Lattice.
// template<typename LatticeType>
// double ComputeLatticeAlphasAndBetas(const LatticeType &lat,
// bool viterbi,
// std::vector<double> *alpha,
// std::vector<double> *beta);
//
//
// /// Topologically sort the compact lattice if not already topologically sorted.
// /// Will crash if the lattice cannot be topologically sorted.
// void TopSortCompactLatticeIfNeeded(CompactLattice *clat);
//
//
// /// Topologically sort the lattice if not already topologically sorted.
// /// Will crash if lattice cannot be topologically sorted.
// void TopSortLatticeIfNeeded(Lattice *clat);
//
// /// Returns the depth of the lattice, defined as the average number of arcs (or
// /// final-prob strings) crossing any given frame. Returns 1 for empty lattices.
// /// Requires that clat is topologically sorted!
// BaseFloat CompactLatticeDepth(const CompactLattice &clat,
// int32 *num_frames = NULL);
//
// /// This function returns, for each frame, the number of arcs crossing that
// /// frame.
// void CompactLatticeDepthPerFrame(const CompactLattice &clat,
// std::vector<int32> *depth_per_frame);
//
//
// /// This function limits the depth of the lattice, per frame: that means, it
// /// does not allow more than a specified number of arcs active on any given
// /// frame. This can be used to reduce the size of the "very deep" portions of
// /// the lattice.
// void CompactLatticeLimitDepth(int32 max_arcs_per_frame,
// CompactLattice *clat);
//
//
// /// Given a lattice, and a transition model to map pdf-ids to phones,
// /// outputs for each frame the set of phones active on that frame. If
// /// sil_phones (which must be sorted and uniq) is nonempty, it excludes
// /// phones in this list.
// void LatticeActivePhones(const Lattice &lat, const TransitionModel &trans,
// const std::vector<int32> &sil_phones,
// std::vector<std::set<int32> > *active_phones);
//
// /// Given a lattice, and a transition model to map pdf-ids to phones,
// /// replace the output symbols (presumably words), with phones; we
// /// use the TransitionModel to work out the phone sequence. Note
// /// that the phone labels are not exactly aligned with the phone
// /// boundaries. We put a phone label to coincide with any transition
// /// to the final, nonemitting state of a phone (this state always exists,
// /// we ensure this in HmmTopology::Check()). This would be the last
// /// transition-id in the phone if reordering is not done (but typically
// /// we do reorder).
// /// Also see PhoneAlignLattice, in phone-align-lattice.h.
// void ConvertLatticeToPhones(const TransitionModel &trans_model,
// Lattice *lat);
/// Prunes a lattice or compact lattice. Returns true on success, false if
/// there was some kind of failure.
/// @param beam  pruning beam (presumably a cost margin relative to the best
///              path; implementation not visible in this header -- confirm).
/// @param lat   lattice to prune in place; LatticeType is expected to be
///              Lattice or CompactLattice.
template<class LatticeType>
bool PruneLattice(BaseFloat beam, LatticeType *lat);
//
// /// Given a lattice, and a transition model to map pdf-ids to phones,
// /// replace the sequences of transition-ids with sequences of phones.
// /// Note that this is different from ConvertLatticeToPhones, in that
// /// we replace the transition-ids not the words.
// void ConvertCompactLatticeToPhones(const TransitionModel &trans_model,
// CompactLattice *clat);
//
// /// Boosts LM probabilities by b * [number of frame errors]; equivalently, adds
// /// -b*[number of frame errors] to the graph-component of the cost of each arc/path.
// /// There is a frame error if a particular transition-id on a particular frame
// /// corresponds to a phone not matching transcription's alignment for that frame.
// /// This is used in "margin-inspired" discriminative training, esp. Boosted MMI.
// /// The TransitionModel is used to map transition-ids in the lattice
// /// input-side to phones; the phones appearing in
// /// "silence_phones" are treated specially in that we replace the frame error f
// /// (either zero or 1) for a frame, with the minimum of f or max_silence_error.
// /// For the normal recipe, max_silence_error would be zero.
// /// Returns true on success, false if there was some kind of mismatch.
// /// At input, silence_phones must be sorted and unique.
// bool LatticeBoost(const TransitionModel &trans,
// const std::vector<int32> &alignment,
// const std::vector<int32> &silence_phones,
// BaseFloat b,
// BaseFloat max_silence_error,
// Lattice *lat);
//
//
// /**
// This function implements either the MPFE (minimum phone frame error) or SMBR
// (state-level minimum bayes risk) forward-backward, depending on whether
// "criterion" is "mpfe" or "smbr". It returns the MPFE
// criterion or SMBR criterion for this utterance, and outputs the posteriors (which
// may be positive or negative) into "post".
//
// @param [in] trans The transition model. Used to map the
// transition-ids to phones or pdfs.
// @param [in] silence_phones A list of integer ids of silence phones. The
// silence frames i.e. the frames where num_ali
// corresponds to a silence phones are treated specially.
// The behavior is determined by 'one_silence_class'
// being false (traditional behavior) or true.
// Usually in our setup, several phones including
// the silence, vocalized noise, non-spoken noise
// and unk are treated as "silence phones"
// @param [in] lat The denominator lattice
// @param [in] num_ali The numerator alignment
// @param [in] criterion The objective function. Must be "mpfe" or "smbr"
// for MPFE (minimum phone frame error) or sMBR
// (state minimum bayes risk) training.
// @param [in] one_silence_class Determines how the silence frames are treated.
// Setting this to false gives the old traditional behavior,
// where the silence frames (according to num_ali) are
// treated as incorrect. However, this means that the
// insertions are not penalized by the objective.
// Setting this to true gives the new behaviour, where we
// treat silence as any other phone, except that all pdfs
// of silence phones are collapsed into a single class for
// the frame-error computation. This can possibly reduce
// the insertions in the trained model. This is closer to
// the WER metric that we actually care about, since WER is
// generally computed after filtering out noises, but
// does penalize insertions.
// @param [out] post The "MBR posteriors" i.e. derivatives w.r.t to the
// pseudo log-likelihoods of states at each frame.
// */
// BaseFloat LatticeForwardBackwardMpeVariants(
// const TransitionModel &trans,
// const std::vector<int32> &silence_phones,
// const Lattice &lat,
// const std::vector<int32> &num_ali,
// std::string criterion,
// bool one_silence_class,
// Posterior *post);
//
// /**
// This function can be used to compute posteriors for MMI, with a positive contribution
// for the numerator and a negative one for the denominator. This function is not actually
// used in our normal MMI training recipes, where it's instead done using various command
// line programs that each do a part of the job. This function was written for use in
// neural-net MMI training.
//
// @param [in] trans The transition model. Used to map the
// transition-ids to phones or pdfs.
// @param [in] lat The denominator lattice
// @param [in] num_ali The numerator alignment
// @param [in] drop_frames If "drop_frames" is true, it will not compute any
// posteriors on frames where the num and den have disjoint
// pdf-ids.
// @param [in] convert_to_pdf_ids If "convert_to_pdfs_ids" is true, it will
// convert the output to be at the level of pdf-ids, not
// transition-ids.
// @param [in] cancel If "cancel" is true, it will cancel out any positive and
// negative parts from the same transition-id (or pdf-id,
// if convert_to_pdf_ids == true).
// @param [out] arc_post The output MMI posteriors of transition-ids (or
// pdf-ids if convert_to_pdf_ids == true) at each frame
// i.e. the difference between the numerator
// and denominator posteriors.
//
// It returns the forward-backward likelihood of the lattice. */
// BaseFloat LatticeForwardBackwardMmi(
// const TransitionModel &trans,
// const Lattice &lat,
// const std::vector<int32> &num_ali,
// bool drop_frames,
// bool convert_to_pdf_ids,
// bool cancel,
// Posterior *arc_post);
//
//
// /// This function takes a CompactLattice that should only contain a single
// /// linear sequence (e.g. derived from lattice-1best), and that should have been
// /// processed so that the arcs in the CompactLattice align correctly with the
// /// word boundaries (e.g. by lattice-align-words). It outputs 3 vectors of the
// /// same size, which give, for each word in the lattice (in sequence), the word
// /// label and the begin time and length in frames. This is done even for zero
// /// (epsilon) words, generally corresponding to optional silence-- if you don't
// /// want them, just ignore them in the output.
// /// This function will print a warning and return false, if the lattice
// /// did not have the correct format (e.g. if it is empty or it is not
// /// linear).
// bool CompactLatticeToWordAlignment(const CompactLattice &clat,
// std::vector<int32> *words,
// std::vector<int32> *begin_times,
// std::vector<int32> *lengths);
//
// /// This function takes a CompactLattice that should only contain a single
// /// linear sequence (e.g. derived from lattice-1best), and that should have been
// /// processed so that the arcs in the CompactLattice align correctly with the
// /// word boundaries (e.g. by lattice-align-words). It outputs 4 vectors of the
// /// same size, which give, for each word in the lattice (in sequence), the word
// /// label, the begin time and length in frames, and the pronunciation (sequence
// /// of phones). This is done even for zero words, corresponding to optional
// /// silences -- if you don't want them, just ignore them in the output.
// /// This function will print a warning and return false, if the lattice
// /// did not have the correct format (e.g. if it is empty or it is not
// /// linear).
// bool CompactLatticeToWordProns(
// const TransitionModel &tmodel,
// const CompactLattice &clat,
// std::vector<int32> *words,
// std::vector<int32> *begin_times,
// std::vector<int32> *lengths,
// std::vector<std::vector<int32> > *prons,
// std::vector<std::vector<int32> > *phone_lengths);
//
//
// /// A form of the shortest-path/best-path algorithm that's specially coded for
// /// CompactLattice. Requires that clat be acyclic.
// void CompactLatticeShortestPath(const CompactLattice &clat,
// CompactLattice *shortest_path);
//
// /// This function expands a CompactLattice to ensure high-probability paths
// /// have unique histories. Arcs with posteriors larger than epsilon get split.
// void ExpandCompactLattice(const CompactLattice &clat,
// double epsilon,
// CompactLattice *expand_clat);
//
// /// For each state, compute forward and backward best (viterbi) costs and its
// /// traceback states (for generating best paths later). The forward best cost
// /// for a state is the cost of the best path from the start state to the state.
// /// The traceback state of this state is its predecessor state in the best path.
// /// The backward best cost for a state is the cost of the best path from the
// /// state to a final one. Its traceback state is the successor state in the best
// /// path in the forward direction.
// /// Note: final weights of states are in backward_best_cost_and_pred.
// /// Requires the input CompactLattice clat be acyclic.
// typedef std::vector<std::pair<double,
// CompactLatticeArc::StateId> > CostTraceType;
// void CompactLatticeBestCostsAndTracebacks(
// const CompactLattice &clat,
// CostTraceType *forward_best_cost_and_pred,
// CostTraceType *backward_best_cost_and_pred);
//
// /// This function adds estimated neural language model scores of words in a
// /// minimal list of hypotheses that covers a lattice, to the graph scores on the
// /// arcs. The list of hypotheses are generated by latbin/lattice-path-cover.
// typedef unordered_map<std::pair<int32, int32>, double, PairHasher<int32> > MapT;
// void AddNnlmScoreToCompactLattice(const MapT &nnlm_scores,
// CompactLattice *clat);
//
// /// This function add the word insertion penalty to graph score of each word
// /// in the compact lattice
// void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty,
// CompactLattice *clat);
//
// /// This function *adds* the negated scores obtained from the Decodable object,
// /// to the acoustic scores on the arcs. If you want to replace them, you should
// /// use ScaleCompactLattice to first set the acoustic scores to zero. Returns
// /// true on success, false on error (typically some kind of mismatched inputs).
// bool RescoreCompactLattice(DecodableInterface *decodable,
// CompactLattice *clat);
//
//
// /// This function returns the number of words in the longest sentence in a
// /// CompactLattice (i.e. the maximum of any path, of the count of
// /// olabels on that path).
// int32 LongestSentenceLength(const Lattice &lat);
//
// /// This function returns the number of words in the longest sentence in a
// /// CompactLattice, i.e. the maximum of any path, of the count of
// /// labels on that path... note, in CompactLattice, the ilabels and olabels
// /// are identical because it is an acceptor.
// int32 LongestSentenceLength(const CompactLattice &lat);
//
//
// /// This function is like RescoreCompactLattice, but it is modified to avoid
// /// computing probabilities on most frames where all the pdf-ids are the same.
// /// (it needs the transition-model to work out whether two transition-ids map to
// /// the same pdf-id, and it assumes that the lattice has transition-ids on it).
// /// The naive thing would be to just set all probabilities to zero on frames
// /// where all the pdf-ids are the same (because this value won't affect the
// /// lattice posterior). But this would become confusing when we compute
// /// corpus-level diagnostics such as the MMI objective function. Instead,
// /// imagine speedup_factor = 100 (it must be >= 1.0)... with probability (1.0 /
// /// speedup_factor) we compute those likelihoods and multiply them by
// /// speedup_factor; otherwise we set them to zero. This gives the right
// /// expected probability so our corpus-level diagnostics will be about right.
// bool RescoreCompactLatticeSpeedup(
// const TransitionModel &tmodel,
// BaseFloat speedup_factor,
// DecodableInterface *decodable,
// CompactLattice *clat);
//
//
// /// This function *adds* the negated scores obtained from the Decodable object,
// /// to the acoustic scores on the arcs. If you want to replace them, you should
// /// use ScaleCompactLattice to first set the acoustic scores to zero. Returns
// /// true on success, false on error (e.g. some kind of mismatched inputs).
// /// The input labels, if nonzero, are interpreted as transition-ids or whatever
// /// other index the Decodable object expects.
// bool RescoreLattice(DecodableInterface *decodable,
// Lattice *lat);
//
// /// This function Composes a CompactLattice format lattice with a
// /// DeterministicOnDemandFst<fst::StdFst> format fst, and outputs another
// /// CompactLattice format lattice. The first element (the one that corresponds
// /// to LM weight) in CompactLatticeWeight is used for composition.
// ///
// /// Note that the DeterministicOnDemandFst interface is not "const", therefore
// /// we cannot use "const" for <det_fst>.
// void ComposeCompactLatticeDeterministic(
// const CompactLattice& clat,
// fst::DeterministicOnDemandFst<fst::StdArc>* det_fst,
// CompactLattice* composed_clat);
//
// /// This function computes the mapping from the pair
// /// (frame-index, transition-id) to the pair
// /// (sum-of-acoustic-scores, num-of-occurrences) over all occurrences of the
// /// transition-id in that frame of the lattice.
// /// This function is useful for retaining the acoustic scores in a
// /// non-compact lattice after a process like determinization where the
// /// frame-level acoustic scores are typically lost.
// /// The function ReplaceAcousticScoresFromMap is used to restore the
// /// acoustic scores computed by this function.
// ///
// /// @param [in] lat Input lattice. Expected to be top-sorted. Otherwise the
// /// function will crash.
// /// @param [out] acoustic_scores
// /// Pointer to a map from the pair (frame-index,
// /// transition-id) to a pair (sum-of-acoustic-scores,
// /// num-of-occurences).
// /// Usually the acoustic scores for a pdf-id (and hence
// /// transition-id) on a frame will be the same for all the
// /// occurences of the pdf-id in that frame.
// /// But if not, we will take the average of the acoustic
// /// scores. Hence, we store both the sum-of-acoustic-scores
// /// and the num-of-occurences of the transition-id in that
// /// frame.
// void ComputeAcousticScoresMap(
// const Lattice &lat,
// unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
// PairHasher<int32> > *acoustic_scores);
//
// /// This function restores acoustic scores computed using the function
// /// ComputeAcousticScoresMap into the lattice.
// ///
// /// @param [in] acoustic_scores
// /// A map from the pair (frame-index, transition-id) to a
// /// pair (sum-of-acoustic-scores, num-of-occurences) of
// /// the occurences of the transition-id in that frame.
// /// See the comments for ComputeAcousticScoresMap for
// /// details.
// /// @param [out] lat Pointer to the output lattice.
// void ReplaceAcousticScoresFromMap(
// const unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
// PairHasher<int32> > &acoustic_scores,
// Lattice *lat);
} // namespace kaldi
#endif // KALDI_LAT_LATTICE_FUNCTIONS_H_
// lm/arpa-file-parser.cc
// Copyright 2014 Guoguo Chen
// Copyright 2016 Smart Action Company LLC (kkm)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <fst/fstlib.h>
#include <sstream>
#include "base/kaldi-error.h"
#include "base/kaldi-math.h"
#include "lm/arpa-file-parser.h"
#include "util/text-utils.h"
namespace kaldi {
// Constructs the parser with the given parse options and an optional symbol
// table.  <symbols> may be NULL (Read() checks symbols_ != NULL before every
// use).  Line and warning counters start at zero; actual parsing is done
// later by Read().
ArpaFileParser::ArpaFileParser(const ArpaParseOptions& options,
fst::SymbolTable* symbols)
: options_(options),
symbols_(symbols),
line_number_(0),
warning_count_(0) {}
// Out-of-line (empty) destructor.  The class is designed for subclassing:
// Read() invokes the ReadStarted()/HeaderAvailable() hooks on derived classes.
ArpaFileParser::~ArpaFileParser() {}
// Strips trailing spaces, tabs, carriage returns and newlines from *str,
// in place.  A string containing only such whitespace becomes empty.
void TrimTrailingWhitespace(std::string* str) {
  const std::string::size_type last_kept = str->find_last_not_of(" \n\r\t");
  if (last_kept == std::string::npos) {
    str->clear();  // The whole string was whitespace.
  } else {
    str->resize(last_kept + 1);
  }
}
// Reads an entire ARPA file from 'is': validates options, parses the
// "\data\" header, then each "\N-grams:" section in order, forwarding
// events to the virtual hooks (ReadStarted, HeaderAvailable, ConsumeNGram,
// ReadComplete). All parse failures are fatal via KALDI_ERR.
void ArpaFileParser::Read(std::istream& is) {
  // Argument sanity checks.
  if (options_.bos_symbol <= 0 || options_.eos_symbol <= 0 ||
      options_.bos_symbol == options_.eos_symbol)
    KALDI_ERR << "BOS and EOS symbols are required, must not be epsilons, and "
              << "differ from each other. Given:"
              << " BOS=" << options_.bos_symbol
              << " EOS=" << options_.eos_symbol;
  if (symbols_ != NULL &&
      options_.oov_handling == ArpaParseOptions::kReplaceWithUnk &&
      (options_.unk_symbol <= 0 || options_.unk_symbol == options_.bos_symbol ||
       options_.unk_symbol == options_.eos_symbol))
    KALDI_ERR << "When symbol table is given and OOV mode is kReplaceWithUnk, "
              << "UNK symbol is required, must not be epsilon, and "
              << "differ from both BOS and EOS symbols. Given:"
              << " UNK=" << options_.unk_symbol
              << " BOS=" << options_.bos_symbol
              << " EOS=" << options_.eos_symbol;
  if (symbols_ != NULL && symbols_->Find(options_.bos_symbol).empty())
    KALDI_ERR << "BOS symbol must exist in symbol table";
  if (symbols_ != NULL && symbols_->Find(options_.eos_symbol).empty())
    KALDI_ERR << "EOS symbol must exist in symbol table";
  if (symbols_ != NULL && options_.unk_symbol > 0 &&
      symbols_->Find(options_.unk_symbol).empty())
    KALDI_ERR << "UNK symbol must exist in symbol table";

  // Reset parser state so Read() can be invoked more than once.
  ngram_counts_.clear();
  line_number_ = 0;
  warning_count_ = 0;
  current_line_.clear();

#define PARSE_ERR KALDI_ERR << LineReference() << ": "

  // Give derived class an opportunity to prepare its state.
  ReadStarted();

  // Processes "\data\" section.
  bool keyword_found = false;
  // The comma operator bumps line_number_ on every getline() attempt.
  // NOTE(review): the "&& !is.eof()" clause discards a final line that lacks
  // a trailing newline -- confirm this is acceptable for ARPA inputs.
  while (++line_number_, getline(is, current_line_) && !is.eof()) {
    if (current_line_.find_first_not_of(" \t\n\r") == std::string::npos) {
      continue;  // Skip blank lines.
    }
    TrimTrailingWhitespace(&current_line_);
    // Continue skipping lines until the \data\ marker alone on a line is found.
    if (!keyword_found) {
      if (current_line_ == "\\data\\") {
        KALDI_LOG << "Reading \\data\\ section.";
        keyword_found = true;
      }
      continue;
    }
    if (current_line_[0] == '\\') break;
    // Enters "\data\" section, and looks for patterns like "ngram 1=1000",
    // which means there are 1000 unigrams.
    std::size_t equal_symbol_pos = current_line_.find("=");
    if (equal_symbol_pos != std::string::npos)
      // Guaranteed spaces around the "=".
      current_line_.replace(equal_symbol_pos, 1, " = ");
    std::vector<std::string> col;
    SplitStringToVector(current_line_, " \t", true, &col);
    if (col.size() == 4 && col[0] == "ngram" && col[2] == "=") {
      int32 order, ngram_count = 0;
      if (!ConvertStringToInteger(col[1], &order) ||
          !ConvertStringToInteger(col[3], &ngram_count)) {
        PARSE_ERR << "cannot parse ngram count";
      }
      // NOTE(review): a non-positive parsed 'order' would index out of range
      // below; well-formed ARPA headers always have order >= 1.
      if (ngram_counts_.size() <= order) {
        ngram_counts_.resize(order);
      }
      ngram_counts_[order - 1] = ngram_count;
    } else {
      KALDI_WARN << LineReference()
                 << ": uninterpretable line in \\data\\ section";
    }
  }
  if (ngram_counts_.size() == 0)
    PARSE_ERR << "\\data\\ section missing or empty.";

  // Signal that grammar order and n-gram counts are known.
  HeaderAvailable();

  NGram ngram;
  ngram.words.reserve(ngram_counts_.size());

  // Processes "\N-grams:" section.
  for (int32 cur_order = 1; cur_order <= ngram_counts_.size(); ++cur_order) {
    // Skips n-grams with zero count.
    if (ngram_counts_[cur_order - 1] == 0)
      KALDI_WARN << "Zero ngram count in ngram order " << cur_order
                 << "(look for 'ngram " << cur_order << "=0' in the \\data\\ "
                 << " section). There is possibly a problem with the file.";

    // Must be looking at a \k-grams: directive at this point.
    std::ostringstream keyword;
    keyword << "\\" << cur_order << "-grams:";
    if (current_line_ != keyword.str()) {
      PARSE_ERR << "invalid directive, expecting '" << keyword.str() << "'";
    }
    KALDI_LOG << "Reading " << current_line_ << " section.";

    int32 ngram_count = 0;
    while (++line_number_, getline(is, current_line_) && !is.eof()) {
      if (current_line_.find_first_not_of(" \n\t\r") == std::string::npos) {
        continue;  // Skip blank lines.
      }
      if (current_line_[0] == '\\') {
        // A directive ends the current section; accept only the next order's
        // "\k-grams:" or the final "\end\", and warn on anything else.
        TrimTrailingWhitespace(&current_line_);
        std::ostringstream next_keyword;
        next_keyword << "\\" << cur_order + 1 << "-grams:";
        if ((current_line_ != next_keyword.str()) &&
            (current_line_ != "\\end\\")) {
          if (ShouldWarn()) {
            KALDI_WARN << "ignoring possible directive '" << current_line_
                       << "' expecting '" << next_keyword.str() << "'";
            if (warning_count_ > 0 &&
                warning_count_ > static_cast<uint32>(options_.max_warnings)) {
              KALDI_WARN << "Of " << warning_count_ << " parse warnings, "
                         << options_.max_warnings << " were reported. "
                         << "Run program with --max-arpa-warnings=-1 "
                         << "to see all warnings";
            }
          }
        } else {
          break;
        }
      }
      std::vector<std::string> col;
      SplitStringToVector(current_line_, " \t", true, &col);
      // Expect: logprob, cur_order words, and (except at the highest order)
      // an optional trailing backoff weight.
      if (col.size() < 1 + cur_order || col.size() > 2 + cur_order ||
          (cur_order == ngram_counts_.size() && col.size() != 1 + cur_order)) {
        PARSE_ERR << "Invalid n-gram data line";
      }
      ++ngram_count;

      // Parse out n-gram logprob and, if present, backoff weight.
      if (!ConvertStringToReal(col[0], &ngram.logprob)) {
        PARSE_ERR << "invalid n-gram logprob '" << col[0] << "'";
      }
      ngram.backoff = 0.0;
      if (col.size() > cur_order + 1) {
        if (!ConvertStringToReal(col[cur_order + 1], &ngram.backoff))
          PARSE_ERR << "invalid backoff weight '" << col[cur_order + 1] << "'";
      }
      // Convert to natural log (ARPA files store base-10 log values).
      ngram.logprob *= M_LN10;
      ngram.backoff *= M_LN10;

      ngram.words.resize(cur_order);
      bool skip_ngram = false;
      for (int32 index = 0; !skip_ngram && index < cur_order; ++index) {
        int32 word;
        if (symbols_) {
          // Symbol table provided, so symbol labels are expected.
          if (options_.oov_handling == ArpaParseOptions::kAddToSymbols) {
            word = symbols_->AddSymbol(col[1 + index]);
          } else {
            word = symbols_->Find(col[1 + index]);
            if (word == -1) {  // fst::kNoSymbol
              switch (options_.oov_handling) {
                case ArpaParseOptions::kReplaceWithUnk:
                  word = options_.unk_symbol;
                  break;
                case ArpaParseOptions::kSkipNGram:
                  if (ShouldWarn())
                    KALDI_WARN << LineReference() << " skipped: word '"
                               << col[1 + index] << "' not in symbol table";
                  skip_ngram = true;
                  break;
                default:
                  PARSE_ERR << "word '" << col[1 + index]
                            << "' not in symbol table";
              }
            }
          }
        } else {
          // Symbols not provided, LM file should contain integers.
          if (!ConvertStringToInteger(col[1 + index], &word) || word < 0) {
            PARSE_ERR << "invalid symbol '" << col[1 + index] << "'";
          }
        }
        // Whichever way we got it, an epsilon is invalid.
        if (word == 0) {
          PARSE_ERR << "epsilon symbol '" << col[1 + index]
                    << "' is illegal in ARPA LM";
        }
        ngram.words[index] = word;
      }
      if (!skip_ngram) {
        ConsumeNGram(ngram);
      }
    }
    if (ngram_count > ngram_counts_[cur_order - 1]) {
      PARSE_ERR << "header said there would be " << ngram_counts_[cur_order - 1]
                << " n-grams of order " << cur_order
                << ", but we saw more already.";
    }
  }

  if (current_line_ != "\\end\\") {
    PARSE_ERR << "invalid or unexpected directive line, expecting \\end\\";
  }

  if (warning_count_ > 0 &&
      warning_count_ > static_cast<uint32>(options_.max_warnings)) {
    KALDI_WARN << "Of " << warning_count_ << " parse warnings, "
               << options_.max_warnings << " were reported. Run program with "
               << "--max_warnings=-1 to see all warnings";
  }

  current_line_.clear();
  ReadComplete();

#undef PARSE_ERR
}
// Formats "line <number> [<current line text>]" for diagnostics.
std::string ArpaFileParser::LineReference() const {
  std::string ref("line ");
  ref += std::to_string(line_number_);
  ref += " [";
  ref += current_line_;
  ref += "]";
  return ref;
}
// Bumps the warning counter and reports whether this warning should still be
// printed, i.e. the configured maximum has not yet been exceeded. A counter
// pegged at the maximum unsigned value suppresses warnings without counting.
bool ArpaFileParser::ShouldWarn() {
  if (warning_count_ == static_cast<uint32>(-1))
    return false;  // Pegged: do not count, do not print.
  ++warning_count_;
  return warning_count_ <= static_cast<uint32>(options_.max_warnings);
}
} // namespace kaldi
// lm/arpa-file-parser.h
// Copyright 2014 Guoguo Chen
// Copyright 2016 Smart Action Company LLC (kkm)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LM_ARPA_FILE_PARSER_H_
#define KALDI_LM_ARPA_FILE_PARSER_H_
#include <fst/fst-decl.h>
#include <string>
#include <vector>
#include "base/kaldi-types.h"
#include "itf/options-itf.h"
namespace kaldi {
/**
   Options that control ArpaFileParser
*/
struct ArpaParseOptions {
  // What to do with a word that is missing from the symbol table. Only
  // consulted when a symbol table is supplied to the parser.
  enum OovHandling {
    kRaiseError,      ///< Abort on OOV words
    kAddToSymbols,    ///< Add novel words to the symbol table.
    kReplaceWithUnk,  ///< Replace OOV words with <unk>.
    kSkipNGram        ///< Skip n-gram with OOV word and continue.
  };

  // Symbols default to -1 (unset); callers must fill in valid BOS/EOS ids
  // before Read() is called.
  ArpaParseOptions()
      : bos_symbol(-1),
        eos_symbol(-1),
        unk_symbol(-1),
        oov_handling(kRaiseError),
        max_warnings(30) {}

  void Register(OptionsItf* opts) {
    // Registering only the max_warnings count, since other options are
    // treated differently by client programs: some want integer symbols,
    // while other are passed words in their command line.
    opts->Register("max-arpa-warnings", &max_warnings,
                   "Maximum warnings to report on ARPA parsing, "
                   "0 to disable, -1 to show all");
  }

  int32 bos_symbol;  ///< Symbol for <s>, Required non-epsilon.
  int32 eos_symbol;  ///< Symbol for </s>, Required non-epsilon.
  int32 unk_symbol;  ///< Symbol for <unk>, Required for kReplaceWithUnk.

  OovHandling oov_handling;  ///< How to handle OOV words in the file.
  int32 max_warnings;        ///< Maximum warnings to report, <0 unlimited.
};
/**
   A parsed n-gram from ARPA LM file. Produced by ArpaFileParser::Read() and
   handed to ConsumeNGram(); logprob and backoff are converted to natural
   log (the file stores base-10 values).
*/
struct NGram {
  NGram() : logprob(0.0), backoff(0.0) {}
  std::vector<int32> words;  ///< Symbols in left to right order.
  float logprob;             ///< Log-prob of the n-gram.
  float backoff;             ///< log-backoff weight of the n-gram.
                             ///< Defaults to zero if not specified.
};
/**
    ArpaFileParser is an abstract base class for ARPA LM file conversion.

    See ConstArpaLmBuilder and ArpaLmCompiler for usage examples.
*/
class ArpaFileParser {
 public:
  /// Constructs the parser with the given options and optional symbol table.
  /// If symbol table is provided, then the file should contain text n-grams,
  /// and the words are mapped to symbols through it. bos_symbol and
  /// eos_symbol in the options structure must be valid symbols in the table,
  /// and so must be unk_symbol if provided. The table is not owned by the
  /// parser, but may be augmented, if oov_handling is set to kAddToSymbols.
  /// If symbol table is a null pointer, the file should contain integer
  /// symbol values, and oov_handling has no effect. bos_symbol and eos_symbol
  /// must be valid symbols still.
  ArpaFileParser(const ArpaParseOptions& options, fst::SymbolTable* symbols);
  virtual ~ArpaFileParser();

  /// Read ARPA LM file from a stream.
  void Read(std::istream& is);

  /// Parser options.
  const ArpaParseOptions& Options() const { return options_; }

 protected:
  /// Override called before reading starts. This is the point to prepare
  /// any state in the derived class.
  virtual void ReadStarted() {}

  /// Override function called to signal that ARPA header with the expected
  /// number of n-grams has been read, and ngram_counts() is now valid.
  virtual void HeaderAvailable() {}

  /// Pure override that must be implemented to process current n-gram. The
  /// n-grams are sent in the file order, which guarantees that all
  /// (k-1)-grams are processed before the first k-gram is.
  virtual void ConsumeNGram(const NGram&) = 0;

  /// Override function called after the last n-gram has been consumed.
  virtual void ReadComplete() {}

  /// Read-only access to symbol table. Not owned, do not make public.
  const fst::SymbolTable* Symbols() const { return symbols_; }

  /// Inside ConsumeNGram(), provides the current line number.
  int32 LineNumber() const { return line_number_; }

  /// Inside ConsumeNGram(), returns a formatted reference to the line being
  /// compiled, to print out as part of diagnostics.
  std::string LineReference() const;

  /// Increments warning count, and returns true if a warning should be
  /// printed or false if the count has exceeded the set maximum.
  bool ShouldWarn();

  /// N-gram counts. Valid from the point when HeaderAvailable() is called.
  const std::vector<int32>& NgramCounts() const { return ngram_counts_; }

 private:
  ArpaParseOptions options_;   // Copy of the construction-time options.
  fst::SymbolTable* symbols_;  // the pointer is not owned here.
  int32 line_number_;          // Number of lines read so far by Read().
  uint32 warning_count_;       // Warnings issued so far; see ShouldWarn().
  std::string current_line_;   // Text of the line currently being parsed.
  std::vector<int32> ngram_counts_;  // Per-order counts from \data\ header.
};
} // namespace kaldi
#endif // KALDI_LM_ARPA_FILE_PARSER_H_
// lm/arpa-lm-compiler.cc
// Copyright 2009-2011 Gilles Boulianne
// Copyright 2016 Smart Action LLC (kkm)
// Copyright 2017 Xiaohui Zhang
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <functional>
#include <limits>
#include <sstream>
#include <unordered_map>
#include <utility>
#include <vector>
#include "base/kaldi-math.h"
#include "fstext/remove-eps-local.h"
#include "lm/arpa-lm-compiler.h"
#include "util/stl-utils.h"
#include "util/text-utils.h"
namespace kaldi {
// Internal interface implemented by ArpaLmCompilerImpl<HistKey>. It lets
// ArpaLmCompiler pick a history-key representation at runtime (see
// ArpaLmCompiler::HeaderAvailable()).
class ArpaLmCompilerImplInterface {
 public:
  virtual ~ArpaLmCompilerImplInterface() {}
  // Adds one parsed n-gram to the FST under construction; is_highest is true
  // for n-grams of the model's highest order.
  virtual void ConsumeNGram(const NGram& ngram, bool is_highest) = 0;
};
namespace {
typedef int32 StateId;
typedef int32 Symbol;
// GeneralHistKey can represent state history in an arbitrarily large n
// n-gram model with symbol ids fitting int32.
class GeneralHistKey {
public:
// Construct key from being and end iterators.
template <class InputIt>
GeneralHistKey(InputIt begin, InputIt end) : vector_(begin, end) {}
// Construct empty history key.
GeneralHistKey() : vector_() {}
// Return tails of the key as a GeneralHistKey. The tails of an n-gram
// w[1..n] is the sequence w[2..n] (and the heads is w[1..n-1], but the
// key class does not need this operartion).
GeneralHistKey Tails() const {
return GeneralHistKey(vector_.begin() + 1, vector_.end());
}
// Keys are equal if represent same state.
friend bool operator==(const GeneralHistKey& a, const GeneralHistKey& b) {
return a.vector_ == b.vector_;
}
// Public typename HashType for hashing.
struct HashType : public std::unary_function<GeneralHistKey, size_t> {
size_t operator()(const GeneralHistKey& key) const {
return VectorHasher<Symbol>().operator()(key.vector_);
}
};
private:
std::vector<Symbol> vector_;
};
// OptimizedHistKey combines 3 21-bit symbol ID values into one 64-bit
// machine word, allowing significant memory reduction and some runtime
// benefit over GeneralHistKey. Since 3 symbols are enough to track history
// in a 4-gram model, this optimized key is used for smaller models with up
// to 4-gram and symbol values up to 2^21-1.
//
// See GeneralHistKey for interface requirements of a key class.
class OptimizedHistKey {
 public:
  enum {
    kShift = 21,  // 21 * 3 = 63 bits for data.
    kMaxData = (1 << kShift) - 1
  };
  // Pack up to 3 symbol ids; the first (oldest) symbol lands in the lowest
  // kShift bits, so Tails() is a plain right shift.
  template <class InputIt>
  OptimizedHistKey(InputIt begin, InputIt end) : data_(0) {
    for (uint32 shift = 0; begin != end; ++begin, shift += kShift) {
      data_ |= static_cast<uint64>(*begin) << shift;
    }
  }
  // Construct empty (0-gram) history key.
  OptimizedHistKey() : data_(0) {}
  // Drop the first symbol of the history.
  OptimizedHistKey Tails() const { return OptimizedHistKey(data_ >> kShift); }
  friend bool operator==(const OptimizedHistKey& a, const OptimizedHistKey& b) {
    return a.data_ == b.data_;
  }
  // Hash functor. Note: this no longer derives from std::unary_function,
  // which was deprecated in C++11 and removed in C++17; the typedefs it
  // provided were unused here.
  struct HashType {
    size_t operator()(const OptimizedHistKey& key) const { return key.data_; }
  };

 private:
  explicit OptimizedHistKey(uint64 data) : data_(data) {}
  uint64 data_;
};
} // namespace
// Templated implementation of the LM compilation state machine. HistKey is
// either GeneralHistKey or OptimizedHistKey, chosen by
// ArpaLmCompiler::HeaderAvailable() based on model order and symbol range.
template <class HistKey>
class ArpaLmCompilerImpl : public ArpaLmCompilerImplInterface {
 public:
  ArpaLmCompilerImpl(ArpaLmCompiler* parent, fst::StdVectorFst* fst,
                     Symbol sub_eps);

  virtual void ConsumeNGram(const NGram& ngram, bool is_highest);

 private:
  // Returns the FST state for the history 'key', creating it together with
  // its backoff arc if it was not seen before.
  StateId AddStateWithBackoff(HistKey key, float backoff);
  // Adds a backoff arc from 'state' to the state for 'key' (or the closest
  // existing lower-order history).
  void CreateBackoff(HistKey key, StateId state, float weight);

  ArpaLmCompiler* parent_;  // Not owned.
  fst::StdVectorFst* fst_;  // Not owned.
  Symbol bos_symbol_;       // Cached from parent->Options().
  Symbol eos_symbol_;       // Cached from parent->Options().
  Symbol sub_eps_;          // Input label for backoff arcs; 0 = plain epsilon.
  StateId eos_state_;       // Shared final </s> state; valid iff sub_eps_ == 0.
  typedef unordered_map<HistKey, StateId, typename HistKey::HashType>
      HistoryMap;
  HistoryMap history_;      // Maps n-gram histories to their FST states.
};
// Constructs the implementation: caches BOS/EOS symbol ids from the parent's
// options, creates the 0-gram (empty-history) state, and, when </s> is kept
// as a real symbol (sub_eps == 0), a shared final state for all </s> arcs.
template <class HistKey>
ArpaLmCompilerImpl<HistKey>::ArpaLmCompilerImpl(ArpaLmCompiler* parent,
                                                fst::StdVectorFst* fst,
                                                Symbol sub_eps)
    : parent_(parent),
      fst_(fst),
      bos_symbol_(parent->Options().bos_symbol),
      eos_symbol_(parent->Options().eos_symbol),
      sub_eps_(sub_eps) {
  // The algorithm maintains state per history. The 0-gram is a special state
  // for empty history. All unigrams (including BOS) backoff into this state.
  StateId zerogram = fst_->AddState();
  history_[HistKey()] = zerogram;

  // Also, if </s> is not treated as epsilon, create a common end state for
  // all transitions accepting the </s>, since they do not back off. This small
  // optimization saves about 2% states in an average grammar.
  if (sub_eps_ == 0) {
    eos_state_ = fst_->AddState();
    fst_->SetFinal(eos_state_, 0);
  }
}
template <class HistKey>
void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(const NGram& ngram,
                                               bool is_highest) {
  // Generally, we do the following. Suppose we are adding an n-gram "A B
  // C". Then find the node for "A B", add a new node for "A B C", and connect
  // them with the arc accepting "C" with the specified weight. Also, add a
  // backoff arc from the new "A B C" node to its backoff state "B C".
  //
  // Two notable exceptions are the highest order n-grams, and final n-grams.
  //
  // When adding a highest order n-gram (e. g., our "A B C" is in a 3-gram LM),
  // the following optimization is performed. There is no point adding a node
  // for "A B C" with a "C" arc from "A B", since there will be no other
  // arcs ingoing to this node, and an epsilon backoff arc into the backoff
  // model "B C", with the weight of \bar{1}. To save a node, create an arc
  // accepting "C" directly from "A B" to "B C". This saves as many nodes
  // as there are the highest order n-grams, which is typically about half
  // the size of a large 3-gram model.
  //
  // Indeed, this does not apply to n-grams ending in EOS, since they do not
  // back off. These are special, as they do not have a back-off state, and
  // the node for "(..anything..) </s>" is always final. These are handled
  // in one of the two possible ways, If symbols <s> and </s> are being
  // replaced by epsilons, neither node nor arc is created, and the logprob
  // of the n-gram is applied to its source node as final weight. If <s> and
  // </s> are preserved, then a special final node for </s> is allocated and
  // used as the destination of the "</s>" acceptor arc.

  // State for the (n-1)-gram prefix "A B"; it must already exist.
  HistKey heads(ngram.words.begin(), ngram.words.end() - 1);
  typename HistoryMap::iterator source_it = history_.find(heads);
  if (source_it == history_.end()) {
    // There was no "A B", therefore the probability of "A B C" is zero.
    // Print a warning and discard current n-gram.
    if (parent_->ShouldWarn())
      KALDI_WARN << parent_->LineReference()
                 << " skipped: no parent (n-1)-gram exists";
    return;
  }

  StateId source = source_it->second;
  StateId dest;
  Symbol sym = ngram.words.back();
  float weight = -ngram.logprob;  // Arc cost is the negated log-prob.
  if (sym == sub_eps_ || sym == 0) {
    KALDI_ERR << " <eps> or disambiguation symbol " << sym
              << "found in the ARPA file. ";
  }
  if (sym == eos_symbol_) {
    if (sub_eps_ == 0) {
      // Keep </s> as a real symbol when not substituting.
      dest = eos_state_;
    } else {
      // Treat </s> as if it was epsilon: mark source final, with the weight
      // of the n-gram.
      fst_->SetFinal(source, weight);
      return;
    }
  } else {
    // For the highest order n-gram, this may find an existing state, for
    // non-highest, will create one (unless there are duplicate n-grams
    // in the grammar, which cannot be reliably detected if highest order,
    // so we better do not do that at all).
    dest = AddStateWithBackoff(
        HistKey(ngram.words.begin() + (is_highest ? 1 : 0), ngram.words.end()),
        -ngram.backoff);
  }

  if (sym == bos_symbol_) {
    weight = 0;  // Accepting <s> is always free.
    if (sub_eps_ == 0) {
      // <s> is as a real symbol, only accepted in the start state.
      source = fst_->AddState();
      fst_->SetStart(source);
    } else {
      // The new state for <s> unigram history *is* the start state.
      fst_->SetStart(dest);
      return;
    }
  }

  // Add arc from source to dest, whichever way it was found.
  fst_->AddArc(source, fst::StdArc(sym, sym, weight, dest));
  return;
}
// Returns the state for the n-gram history 'key', creating it if needed.
// The key is the current n-gram for all but the highest order, or the tails
// of the n-gram for the highest order (the chain-collapsing optimization
// described in ConsumeNGram). Invariant: every state recorded in history_
// already has its backoff arc in the FST, so an existing state is returned
// untouched; a fresh state gets registered and wired to its backoff.
template <class HistKey>
StateId ArpaLmCompilerImpl<HistKey>::AddStateWithBackoff(HistKey key,
                                                         float backoff) {
  typename HistoryMap::const_iterator existing = history_.find(key);
  if (existing != history_.end())
    return existing->second;  // Backoff arc is already in place.

  // Unseen history: allocate a state, record it, then attach its backoff arc.
  StateId fresh = fst_->AddState();
  history_[key] = fresh;
  CreateBackoff(key.Tails(), fresh, backoff);
  return fresh;
}
// Attaches the backoff arc for 'state'. 'key' names the preferred backoff
// destination, which may not have a state yet; in that case fall back order
// by order (key.Tails()) until an existing one is found. The 0-gram state
// always exists, so the search is guaranteed to terminate.
template <class HistKey>
inline void ArpaLmCompilerImpl<HistKey>::CreateBackoff(HistKey key,
                                                       StateId state,
                                                       float weight) {
  typename HistoryMap::iterator target;
  for (target = history_.find(key); target == history_.end();
       target = history_.find(key)) {
    key = key.Tails();
  }

  // The arc should transduce either <eos> or #0 to <eps>, depending on the
  // epsilon substitution mode. This is the only case when input and output
  // label may differ.
  fst_->AddArc(state, fst::StdArc(sub_eps_, 0, weight, target->second));
}
// Releases the owned implementation; deleting a null pointer is a no-op,
// so no explicit guard is needed.
ArpaLmCompiler::~ArpaLmCompiler() {
  delete impl_;
}
// Instantiates the implementation once the \data\ header is known: the
// compact OptimizedHistKey when model order and symbol range permit,
// otherwise the general vector-based key.
void ArpaLmCompiler::HeaderAvailable() {
  KALDI_ASSERT(impl_ == NULL);
  // Use optimized implementation if the grammar is 4-gram or less, and the
  // maximum attained symbol id will fit into the optimized range.
  int64 max_symbol = 0;
  if (Symbols() != NULL) max_symbol = Symbols()->AvailableKey() - 1;
  // If augmenting the symbol table, assume the worst case when all words in
  // the model being read are novel.
  if (Options().oov_handling == ArpaParseOptions::kAddToSymbols)
    max_symbol += NgramCounts()[0];

  if (NgramCounts().size() <= 4 && max_symbol < OptimizedHistKey::kMaxData) {
    impl_ = new ArpaLmCompilerImpl<OptimizedHistKey>(this, &fst_, sub_eps_);
  } else {
    impl_ = new ArpaLmCompilerImpl<GeneralHistKey>(this, &fst_, sub_eps_);
    KALDI_LOG << "Reverting to slower state tracking because model is large: "
              << NgramCounts().size() << "-gram with symbols up to "
              << max_symbol;
  }
}
// Validates BOS/EOS placement and forwards the n-gram to the implementation.
// <s> is legal only as the first word of an n-gram, and </s> only as the
// last; an n-gram violating that is dropped with a warning.
void ArpaLmCompiler::ConsumeNGram(const NGram& ngram) {
  const size_t num_words = ngram.words.size();
  for (size_t i = 0; i < num_words; ++i) {
    const bool misplaced_bos = (i > 0 && ngram.words[i] == Options().bos_symbol);
    const bool misplaced_eos =
        (i + 1 < num_words && ngram.words[i] == Options().eos_symbol);
    if (misplaced_bos || misplaced_eos) {
      if (ShouldWarn())
        KALDI_WARN << LineReference()
                   << " skipped: n-gram has invalid BOS/EOS placement";
      return;
    }
  }

  // Highest-order n-grams trigger the chain-collapsing optimization inside
  // the implementation.
  const bool is_highest = num_words == NgramCounts().size();
  impl_->ConsumeNGram(ngram, is_highest);
}
// Shrinks the FST by rewriting the #0 backoff label to epsilon on arcs out
// of states that exist only to back off (non-final states with exactly one
// outgoing arc), then removing those epsilons via RemoveEpsLocal. Skipped
// entirely when no disambiguation symbol is in use; see the comment below.
void ArpaLmCompiler::RemoveRedundantStates() {
  fst::StdArc::Label backoff_symbol = sub_eps_;
  if (backoff_symbol == 0) {
    // The method of removing redundant states implemented in this function
    // leads to slow determinization of L o G when people use the older style of
    // usage of arpa2fst where the --disambig-symbol option was not specified.
    // The issue seems to be that it creates a non-deterministic FST, while G is
    // supposed to be deterministic. By 'return'ing below, we just disable this
    // method if people were using an older script. This method isn't really
    // that consequential anyway, and people will move to the newer-style
    // scripts (see current utils/format_lm.sh), so this isn't much of a
    // problem.
    return;
  }
  fst::StdArc::StateId num_states = fst_.NumStates();

  // replace the #0 symbols on the input of arcs out of redundant states (states
  // that are not final and have only a backoff arc leaving them), with <eps>.
  for (fst::StdArc::StateId state = 0; state < num_states; state++) {
    if (fst_.NumArcs(state) == 1 &&
        fst_.Final(state) == fst::TropicalWeight::Zero()) {
      fst::MutableArcIterator<fst::StdVectorFst> iter(&fst_, state);
      fst::StdArc arc = iter.Value();
      if (arc.ilabel == backoff_symbol) {
        arc.ilabel = 0;
        iter.SetValue(arc);
      }
    }
  }

  // we could call fst::RemoveEps, and it would have the same effect in normal
  // cases, where backoff_symbol != 0 and there are no epsilons in unexpected
  // places, but RemoveEpsLocal is a bit safer in case something weird is going
  // on; it guarantees not to blow up the FST.
  fst::RemoveEpsLocal(&fst_);
  KALDI_LOG << "Reduced num-states from " << num_states << " to "
            << fst_.NumStates();
}
void ArpaLmCompiler::Check() const {
if (fst_.Start() == fst::kNoStateId) {
KALDI_ERR << "Arpa file did not contain the beginning-of-sentence symbol "
<< Symbols()->Find(Options().bos_symbol) << ".";
}
}
// Finalizes the compiled FST after the last n-gram: attaches the symbol
// table (possibly NULL) to both tapes, strips redundant backoff-only states,
// and verifies that a start state exists.
void ArpaLmCompiler::ReadComplete() {
  fst_.SetInputSymbols(Symbols());
  fst_.SetOutputSymbols(Symbols());
  RemoveRedundantStates();
  Check();
}
} // namespace kaldi
// lm/arpa-lm-compiler.h
// Copyright 2009-2011 Gilles Boulianne
// Copyright 2016 Smart Action LLC (kkm)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_LM_ARPA_LM_COMPILER_H_
#define KALDI_LM_ARPA_LM_COMPILER_H_
#include <fst/fstlib.h>
#include "lm/arpa-file-parser.h"
namespace kaldi {
class ArpaLmCompilerImplInterface;
// Compiles an ARPA language model into a fst::StdVectorFst. sub_eps is the
// label placed on the input side of backoff arcs (e.g. a #0 disambiguation
// symbol), or 0 to use plain epsilons (in which case <s>/</s> are kept as
// real symbols; see the .cc for details).
class ArpaLmCompiler : public ArpaFileParser {
 public:
  ArpaLmCompiler(const ArpaParseOptions& options, int sub_eps,
                 fst::SymbolTable* symbols)
      : ArpaFileParser(options, symbols), sub_eps_(sub_eps), impl_(NULL) {}
  ~ArpaLmCompiler();

  // Access to the compiled FST; meaningful after Read() has completed.
  const fst::StdVectorFst& Fst() const { return fst_; }
  fst::StdVectorFst* MutableFst() { return &fst_; }

 protected:
  // ArpaFileParser overrides.
  virtual void HeaderAvailable();
  virtual void ConsumeNGram(const NGram& ngram);
  virtual void ReadComplete();

 private:
  // this function removes states that only have a backoff arc coming
  // out of them.
  void RemoveRedundantStates();
  void Check() const;

  int sub_eps_;                        // Backoff-arc input label; 0 = epsilon.
  ArpaLmCompilerImplInterface* impl_;  // Owned.
  fst::StdVectorFst fst_;              // The FST under construction.
  template <class HistKey>
  friend class ArpaLmCompilerImpl;
};
} // namespace kaldi
#endif // KALDI_LM_ARPA_LM_COMPILER_H_
// bin/arpa2fst.cc
//
// Copyright 2009-2011 Gilles Boulianne.
//
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "lm/arpa-lm-compiler.h"
#include "util/kaldi-io.h"
#include "util/parse-options.h"
// Command-line entry point for arpa2fst: parses flags, acquires or builds
// the symbol table, compiles the ARPA LM into G, optionally sorts it and
// writes symbols, then writes the FST.
int main(int argc, char *argv[]) {
  using namespace kaldi;  // NOLINT
  try {
    const char *usage =
        "Convert an ARPA format language model into an FST\n"
        "Usage: arpa2fst [opts] <input-arpa> <output-fst>\n"
        " e.g.: arpa2fst --disambig-symbol=#0 --read-symbol-table="
        "data/lang/words.txt lm/input.arpa G.fst\n\n"
        "Note: When called without switches, the output G.fst will contain\n"
        "an embedded symbol table. This is compatible with the way a previous\n"
        "version of arpa2fst worked.\n";

    ParseOptions po(usage);

    ArpaParseOptions options;
    options.Register(&po);

    // Option flags.
    std::string bos_symbol = "<s>";
    std::string eos_symbol = "</s>";
    std::string disambig_symbol;
    std::string read_syms_filename;
    std::string write_syms_filename;
    bool keep_symbols = false;
    bool ilabel_sort = true;

    po.Register("bos-symbol", &bos_symbol, "Beginning of sentence symbol");
    po.Register("eos-symbol", &eos_symbol, "End of sentence symbol");
    po.Register("disambig-symbol", &disambig_symbol,
                "Disambiguator. If provided (e. g. #0), used on input side of "
                "backoff links, and <s> and </s> are replaced with epsilons");
    po.Register("read-symbol-table", &read_syms_filename,
                "Use existing symbol table");
    po.Register("write-symbol-table", &write_syms_filename,
                "Write generated symbol table to a file");
    po.Register("keep-symbols", &keep_symbols,
                "Store symbol table with FST. Symbols always saved to FST if "
                "symbol tables are neither read or written (otherwise symbols "
                "would be lost entirely)");
    po.Register("ilabel-sort", &ilabel_sort, "Ilabel-sort the output FST");

    po.Read(argc, argv);

    // <output-fst> is optional; GetOptArg(2) yields "" when absent.
    // NOTE(review): an empty wxfilename conventionally means stdout in
    // Kaldi I/O -- confirm against util/kaldi-io.
    if (po.NumArgs() != 1 && po.NumArgs() != 2) {
      po.PrintUsage();
      exit(1);
    }
    std::string arpa_rxfilename = po.GetArg(1),
        fst_wxfilename = po.GetOptArg(2);

    int64 disambig_symbol_id = 0;

    // Owned here; deleted at the bottom of the try block. NOTE(review):
    // leaks if KALDI_ERR throws before the delete, which is tolerated since
    // the process exits immediately afterwards.
    fst::SymbolTable *symbols;
    if (!read_syms_filename.empty()) {
      // Use existing symbols. Required symbols must be in the table.
      kaldi::Input kisym(read_syms_filename);
      symbols = fst::SymbolTable::ReadText(
          kisym.Stream(), PrintableWxfilename(read_syms_filename));
      if (symbols == NULL)
        KALDI_ERR << "Could not read symbol table from file "
                  << read_syms_filename;

      options.oov_handling = ArpaParseOptions::kSkipNGram;
      if (!disambig_symbol.empty()) {
        disambig_symbol_id = symbols->Find(disambig_symbol);
        if (disambig_symbol_id == -1)  // fst::kNoSymbol
          KALDI_ERR << "Symbol table " << read_syms_filename
                    << " has no symbol for " << disambig_symbol;
      }
    } else {
      // Create a new symbol table and populate it from ARPA file.
      symbols = new fst::SymbolTable(PrintableWxfilename(fst_wxfilename));
      options.oov_handling = ArpaParseOptions::kAddToSymbols;
      symbols->AddSymbol("<eps>", 0);
      if (!disambig_symbol.empty()) {
        disambig_symbol_id = symbols->AddSymbol(disambig_symbol);
      }
    }

    // Add or use existing BOS and EOS.
    options.bos_symbol = symbols->AddSymbol(bos_symbol);
    options.eos_symbol = symbols->AddSymbol(eos_symbol);

    // If producing new (not reading existing) symbols and not saving them,
    // need to keep symbols with FST, otherwise they would be lost.
    if (read_syms_filename.empty() && write_syms_filename.empty())
      keep_symbols = true;

    // Actually compile LM.
    KALDI_ASSERT(symbols != NULL);
    ArpaLmCompiler lm_compiler(options, disambig_symbol_id, symbols);
    {
      Input ki(arpa_rxfilename);
      lm_compiler.Read(ki.Stream());
    }

    // Sort the FST in-place if requested by options.
    if (ilabel_sort) {
      fst::ArcSort(lm_compiler.MutableFst(), fst::StdILabelCompare());
    }

    // Write symbols if requested.
    if (!write_syms_filename.empty()) {
      kaldi::Output kosym(write_syms_filename, false);
      symbols->WriteText(kosym.Stream());
    }

    // Write LM FST.
    bool write_binary = true, write_header = false;
    kaldi::Output kofst(fst_wxfilename, write_binary, write_header);
    fst::FstWriteOptions wopts(PrintableWxfilename(fst_wxfilename));
    wopts.write_isymbols = wopts.write_osymbols = keep_symbols;
    lm_compiler.Fst().Write(kofst.Stream(), wopts);

    delete symbols;
  } catch (const std::exception &e) {
    std::cerr << e.what();
    return -1;
  }
}
// NOTE(review): removed trailing web-page capture residue ("Markdown is
// supported", discussion-widget text) accidentally appended during file
// extraction; it was not part of the source code.