Commit 764b3a75 authored by Sugon_ldc

add new model

add_library(http STATIC
http_client.cc
http_server.cc
)
target_link_libraries(http PUBLIC decoder)
// Copyright (c) 2023 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "http/http_client.h"
#include "boost/json/src.hpp"
#include "utils/log.h"
namespace wenet {
namespace beast = boost::beast; // from <boost/beast.hpp>
namespace http = beast::http; // from <boost/beast/http.hpp>
namespace net = boost::asio; // from <boost/asio.hpp>
using tcp = net::ip::tcp; // from <boost/asio/ip/tcp.hpp>
namespace json = boost::json;
HttpClient::HttpClient(const std::string& hostname, int port)
: hostname_(hostname), port_(port) {
Connect();
}
void HttpClient::Connect() {
tcp::resolver resolver{ioc_};
// Look up the domain name
auto const results = resolver.resolve(hostname_, std::to_string(port_));
stream_.connect(results);
}
void HttpClient::SendBinaryData(const void* data, size_t size) {
try {
json::value start_tag = {{"nbest", nbest_},
{"continuous_decoding", continuous_decoding_}};
std::string config = json::serialize(start_tag);
req_.set("config", config);
std::size_t encode_size = beast::detail::base64::encoded_size(size);
// base64::encode does not null-terminate its output, so encode into a
// pre-sized std::string instead of a raw stack buffer.
std::string encode_data(encode_size, '\0');
beast::detail::base64::encode(&encode_data[0], data, size);
req_.body() = std::move(encode_data);
req_.prepare_payload();
http::write(stream_, req_, ec_);
http::read(stream_, buffer_, res_);
std::string message = res_.body();
// Parse to validate that the response is well-formed JSON.
json::object obj = json::parse(message).as_object();
LOG(INFO) << message;
} catch (std::exception const& e) {
LOG(ERROR) << e.what();
}
stream_.socket().shutdown(tcp::socket::shutdown_both, ec_);
}
} // namespace wenet
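// A minimal usage sketch for HttpClient, assuming a raw 16-bit / 16 kHz PCM
// buffer is already in memory; the host, port, and nbest value below are
// placeholders, not values from the commit.
#include <cstdint>
#include <vector>
#include "http/http_client.h"

int main() {
  std::vector<int16_t> pcm(16000, 0);  // one second of silence at 16 kHz
  wenet::HttpClient client("127.0.0.1", 10086);
  client.set_nbest(3);
  // SendBinaryData base64-encodes the samples, sends the decoding config in
  // a "config" header, and logs the server's JSON response.
  client.SendBinaryData(pcm.data(), pcm.size() * sizeof(int16_t));
  return 0;
}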
// Copyright (c) 2023 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HTTP_HTTP_CLIENT_H_
#define HTTP_HTTP_CLIENT_H_
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <boost/asio/connect.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/beast/core.hpp>
#include <boost/beast/core/detail/base64.hpp>
#include <boost/beast/http.hpp>
#include <boost/beast/version.hpp>
#include "utils/utils.h"
namespace wenet {
namespace beast = boost::beast; // from <boost/beast.hpp>
namespace http = beast::http; // from <boost/beast/http.hpp>
namespace net = boost::asio; // from <boost/asio.hpp>
using tcp = net::ip::tcp; // from <boost/asio/ip/tcp.hpp>
class HttpClient {
public:
HttpClient(const std::string& host, int port);
void SendBinaryData(const void* data, size_t size);
void set_nbest(int nbest) { nbest_ = nbest; }
private:
void Connect();
std::string hostname_;
int port_;
std::string target_ = "/";
int version_ = 11;
int nbest_ = 1;
const bool continuous_decoding_ = false;
net::io_context ioc_;
beast::tcp_stream stream_{ioc_};
beast::flat_buffer buffer_;
// Use POST: the request carries a base64-encoded PCM body, and the server
// side is written around POST requests.
http::request<http::string_body> req_{http::verb::post, target_, version_};
http::response<http::string_body> res_{http::status::ok, version_};
beast::error_code ec_;
WENET_DISALLOW_COPY_AND_ASSIGN(HttpClient);
};
} // namespace wenet
#endif // HTTP_HTTP_CLIENT_H_
// Copyright (c) 2023 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "http/http_server.h"
#include <thread>
#include <utility>
#include <vector>
#include "boost/json/src.hpp"
#include "utils/log.h"
namespace wenet {
namespace beast = boost::beast; // from <boost/beast.hpp>
namespace http = beast::http; // from <boost/beast/http.hpp>
namespace net = boost::asio; // from <boost/asio.hpp>
using tcp = boost::asio::ip::tcp; // from <boost/asio/ip/tcp.hpp>
namespace json = boost::json;
ConnectionHandler::ConnectionHandler(
tcp::socket&& socket, std::shared_ptr<FeaturePipelineConfig> feature_config,
std::shared_ptr<DecodeOptions> decode_config,
std::shared_ptr<DecodeResource> decode_resource)
: socket_(std::move(socket)),
feature_config_(std::move(feature_config)),
decode_config_(std::move(decode_config)),
decode_resource_(std::move(decode_resource)),
req_(std::make_shared<http::request<http::string_body>>(
http::verb::post, target_, version_)),
res_(std::make_shared<http::response<http::string_body>>(http::status::ok,
version_)) {}
void ConnectionHandler::OnSpeechStart() {
feature_pipeline_ = std::make_shared<FeaturePipeline>(*feature_config_);
decoder_ = std::make_shared<AsrDecoder>(feature_pipeline_, decode_resource_,
*decode_config_);
// Start decoder thread
decode_thread_ =
std::make_shared<std::thread>(&ConnectionHandler::DecodeThreadFunc, this);
}
void ConnectionHandler::OnSpeechEnd() {
if (feature_pipeline_ != nullptr) {
feature_pipeline_->set_input_finished();
}
}
void ConnectionHandler::OnFinalResult(const std::string& result) {
LOG(INFO) << "Final result: " << result;
json::value rv = {
{"status", "ok"}, {"type", "final_result"}, {"nbest", result}};
std::string message = json::serialize(rv);
res_->body() = message;
http::write(socket_, *res_, ec_);
}
void ConnectionHandler::OnSpeechData(const std::string& message) {
std::size_t decode_size =
beast::detail::base64::decoded_size(message.length());
int num_samples = decode_size / sizeof(int16_t);
std::vector<int16_t> decode_data(num_samples);
beast::detail::base64::decode(decode_data.data(), message.c_str(),
message.length());
// Read binary PCM data
VLOG(2) << "Received " << num_samples << " samples";
CHECK(feature_pipeline_ != nullptr);
CHECK(decoder_ != nullptr);
feature_pipeline_->AcceptWaveform(decode_data.data(), num_samples);
}
std::string ConnectionHandler::SerializeResult(bool finish) {
json::array nbest;
for (const DecodeResult& path : decoder_->result()) {
json::object jpath({{"sentence", path.sentence}});
if (finish) {
json::array word_pieces;
for (const WordPiece& word_piece : path.word_pieces) {
json::object jword_piece({{"word", word_piece.word},
{"start", word_piece.start},
{"end", word_piece.end}});
word_pieces.emplace_back(jword_piece);
}
jpath.emplace("word_pieces", word_pieces);
}
nbest.emplace_back(jpath);
if (nbest.size() == nbest_) {
break;
}
}
return json::serialize(nbest);
}
void ConnectionHandler::DecodeThreadFunc() {
try {
while (true) {
DecodeState state = decoder_->Decode();
if (state == DecodeState::kEndFeats || state == DecodeState::kEndpoint) {
decoder_->Rescoring();
std::string result = SerializeResult(true);
OnFinalResult(result);
break;
}
}
} catch (std::exception const& e) {
LOG(ERROR) << e.what();
}
}
void ConnectionHandler::OnError(const std::string& message) {
json::value rv = {{"status", "failed"}, {"message", message}};
res_->body() = json::serialize(rv);
http::write(socket_, *res_, ec_);
// Send a TCP shutdown
socket_.shutdown(tcp::socket::shutdown_send, ec_);
}
void ConnectionHandler::OnText(const std::string& message) {
LOG(INFO) << message;
json::value v = json::parse(message);
if (v.is_object()) {
json::object obj = v.get_object();
if (obj.find("nbest") != obj.end()) {
if (obj["nbest"].is_int64()) {
nbest_ = obj["nbest"].as_int64();
} else {
OnError("integer is expected for nbest option");
}
}
} else {
OnError("Wrong protocol");
}
}
void ConnectionHandler::operator()() {
try {
http::read(socket_, buffer_, *req_, ec_);
if (ec_) {
LOG(ERROR) << ec_.message();
} else {
OnText(req_->base()["config"].to_string());
OnSpeechStart();
OnSpeechData(req_->body());
OnSpeechEnd();
}
LOG(INFO) << "Read all pcm data, wait for decoding thread";
if (decode_thread_ != nullptr) {
decode_thread_->join();
}
} catch (beast::system_error const& se) {
LOG(INFO) << se.code().message();
if (decode_thread_ != nullptr) {
decode_thread_->join();
}
} catch (std::exception const& e) {
LOG(ERROR) << e.what();
}
socket_.shutdown(tcp::socket::shutdown_send, ec_);
}
void HttpServer::Start() {
try {
auto const address = net::ip::make_address("0.0.0.0");
tcp::acceptor acceptor{ioc_, {address, static_cast<uint16_t>(port_)}};
for (;;) {
// This will receive the new connection
tcp::socket socket{ioc_};
// Block until we get a connection
acceptor.accept(socket);
// Launch the session, transferring ownership of the socket
ConnectionHandler handler(std::move(socket), feature_config_,
decode_config_, decode_resource_);
std::thread t(std::move(handler));
t.detach();
}
} catch (const std::exception& e) {
LOG(FATAL) << e.what();
}
}
} // namespace wenet
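// A sketch of a server entry point, assuming the flag helpers from
// decoder/params.h that the other wenet server binaries use; the port is a
// placeholder.
#include "decoder/params.h"
#include "http/http_server.h"
#include "utils/log.h"

int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  google::InitGoogleLogging(argv[0]);
  auto feature_config = wenet::InitFeaturePipelineConfigFromFlags();
  auto decode_config = wenet::InitDecodeOptionsFromFlags();
  auto decode_resource = wenet::InitDecodeResourceFromFlags();
  wenet::HttpServer server(10086, feature_config, decode_config,
                           decode_resource);
  LOG(INFO) << "Listening at port 10086";
  // Start() blocks, accepting connections and handling each one in a
  // detached thread.
  server.Start();
  return 0;
}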
// Copyright (c) 2023 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HTTP_HTTP_SERVER_H_
#define HTTP_HTTP_SERVER_H_
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <boost/asio/ip/tcp.hpp>
#include <boost/beast/core.hpp>
#include <boost/beast/core/detail/base64.hpp>
#include <boost/beast/http.hpp>
#include <boost/beast/version.hpp>
#include <boost/config.hpp>
#include "decoder/asr_decoder.h"
#include "frontend/feature_pipeline.h"
#include "utils/log.h"
namespace wenet {
namespace beast = boost::beast; // from <boost/beast.hpp>
namespace http = beast::http; // from <boost/beast/http.hpp>
namespace net = boost::asio; // from <boost/asio.hpp>
using tcp = boost::asio::ip::tcp; // from <boost/asio/ip/tcp.hpp>
class ConnectionHandler {
public:
ConnectionHandler(tcp::socket&& socket,
std::shared_ptr<FeaturePipelineConfig> feature_config,
std::shared_ptr<DecodeOptions> decode_config,
std::shared_ptr<DecodeResource> decode_resource);
void operator()();
private:
void OnSpeechStart();
void OnSpeechEnd();
void OnText(const std::string& message);
void OnSpeechData(const std::string& message);
void OnError(const std::string& message);
void OnFinalResult(const std::string& result);
void DecodeThreadFunc();
std::string SerializeResult(bool finish);
std::string target_ = "/";
int version_ = 11;
const bool continuous_decoding_ = false;
int nbest_ = 1;
tcp::socket socket_;
beast::flat_buffer buffer_;
beast::error_code ec_;
std::shared_ptr<http::request<http::string_body>> req_;
std::shared_ptr<http::response<http::string_body>> res_;
std::shared_ptr<FeaturePipelineConfig> feature_config_;
std::shared_ptr<DecodeOptions> decode_config_;
std::shared_ptr<DecodeResource> decode_resource_;
std::shared_ptr<FeaturePipeline> feature_pipeline_ = nullptr;
std::shared_ptr<AsrDecoder> decoder_ = nullptr;
std::shared_ptr<std::thread> decode_thread_ = nullptr;
};
class HttpServer {
public:
HttpServer(int port, std::shared_ptr<FeaturePipelineConfig> feature_config,
std::shared_ptr<DecodeOptions> decode_config,
std::shared_ptr<DecodeResource> decode_resource)
: port_(port),
feature_config_(std::move(feature_config)),
decode_config_(std::move(decode_config)),
decode_resource_(std::move(decode_resource)) {}
void Start();
private:
int port_;
// The io_context is required for all I/O
net::io_context ioc_{1};
std::shared_ptr<FeaturePipelineConfig> feature_config_;
std::shared_ptr<DecodeOptions> decode_config_;
std::shared_ptr<DecodeResource> decode_resource_;
WENET_DISALLOW_COPY_AND_ASSIGN(HttpServer);
};
} // namespace wenet
#endif // HTTP_HTTP_SERVER_H_
cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(kaldi)
# include_directories() is called in the root CMakeLists.txt
add_library(kaldi-util
base/kaldi-error.cc
base/kaldi-math.cc
util/kaldi-io.cc
util/parse-options.cc
util/simple-io-funcs.cc
util/text-utils.cc
)
target_link_libraries(kaldi-util PUBLIC utils)
add_library(kaldi-decoder
lat/determinize-lattice-pruned.cc
lat/lattice-functions.cc
decoder/lattice-faster-decoder.cc
decoder/lattice-faster-online-decoder.cc
)
target_link_libraries(kaldi-decoder PUBLIC kaldi-util)
if(GRAPH_TOOLS)
# Arpa binary
add_executable(arpa2fst
lm/arpa-file-parser.cc
lm/arpa-lm-compiler.cc
lmbin/arpa2fst.cc
)
target_link_libraries(arpa2fst PUBLIC kaldi-util)
# FST tools binary
set(FST_BINS
fstaddselfloops
fstdeterminizestar
fstisstochastic
fstminimizeencoded
fsttablecompose
)
if(NOT MSVC)
# Link against libdl; otherwise there is a linker error on Linux.
link_libraries(dl)
endif()
foreach(name IN LISTS FST_BINS)
add_executable(${name}
fstbin/${name}.cc
fstext/kaldi-fst-io.cc
)
target_link_libraries(${name} PUBLIC kaldi-util)
endforeach()
endif()
We use the Kaldi decoder to implement TLG-based language model integration,
so we copied the related files to this directory.
The main changes are:
1. To minimize the changes, we keep the same directory tree as Kaldi.
2. We replace the Kaldi logging system with glog in the following way.
``` c++
#define KALDI_WARN \
google::LogMessage(__FILE__, __LINE__, google::GLOG_WARNING).stream()
#define KALDI_ERR \
google::LogMessage(__FILE__, __LINE__, google::GLOG_ERROR).stream()
#define KALDI_LOG \
google::LogMessage(__FILE__, __LINE__, google::GLOG_INFO).stream()
#define KALDI_VLOG(v) VLOG(v)
#define KALDI_ASSERT(condition) CHECK(condition)
```
3. We lint all the files to satisfy the WeNet linter.
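With this mapping in place, copied Kaldi code logs through glog transparently. A small illustrative snippet (not from the Kaldi sources; the helper below is hypothetical):
``` c++
#include "base/kaldi-error.h"

// Hypothetical helper, only to show the mapped macros in action.
void CheckSampleRate(float sample_rate) {
  KALDI_ASSERT(sample_rate > 0);  // expands to glog's CHECK()
  if (sample_rate != 16000.0f) {
    KALDI_WARN << "unexpected sample rate: " << sample_rate;  // glog WARNING
  }
  KALDI_VLOG(2) << "sample rate ok";  // glog verbose logging
}
```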
// base/io-funcs-inl.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
// Jan Silovsky; Yanmin Qian;
// Johns Hopkins University (Author: Daniel Povey)
// 2016 Xiaohui Zhang
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_BASE_IO_FUNCS_INL_H_
#define KALDI_BASE_IO_FUNCS_INL_H_ 1
// Do not include this file directly. It is included by base/io-funcs.h
#include <limits>
#include <vector>
#include <utility>
namespace kaldi {
// Template that covers integers.
template <class T>
void WriteBasicType(std::ostream &os, bool binary, T t) {
// Compile time assertion that this is not called with a wrong type.
KALDI_ASSERT_IS_INTEGER_TYPE(T);
if (binary) {
char len_c = (std::numeric_limits<T>::is_signed ? 1 : -1) *
static_cast<char>(sizeof(t));
os.put(len_c);
os.write(reinterpret_cast<const char *>(&t), sizeof(t));
} else {
if (sizeof(t) == 1)
os << static_cast<int16>(t) << " ";
else
os << t << " ";
}
if (os.fail()) {
KALDI_ERR << "Write failure in WriteBasicType.";
}
}
// Template that covers integers.
template <class T>
inline void ReadBasicType(std::istream &is, bool binary, T *t) {
KALDI_PARANOID_ASSERT(t != NULL);
// Compile time assertion that this is not called with a wrong type.
KALDI_ASSERT_IS_INTEGER_TYPE(T);
if (binary) {
int len_c_in = is.get();
if (len_c_in == -1)
KALDI_ERR << "ReadBasicType: encountered end of stream.";
char len_c = static_cast<char>(len_c_in),
len_c_expected = (std::numeric_limits<T>::is_signed ? 1 : -1) *
static_cast<char>(sizeof(*t));
if (len_c != len_c_expected) {
KALDI_ERR << "ReadBasicType: did not get expected integer type, "
<< static_cast<int>(len_c) << " vs. "
<< static_cast<int>(len_c_expected)
<< ". You can change this code to successfully"
<< " read it later, if needed.";
// insert code here to read "wrong" type. Might have a switch statement.
}
is.read(reinterpret_cast<char *>(t), sizeof(*t));
} else {
if (sizeof(*t) == 1) {
int16 i;
is >> i;
*t = i;
} else {
is >> *t;
}
}
if (is.fail()) {
KALDI_ERR << "Read failure in ReadBasicType, file position is "
<< is.tellg() << ", next char is " << is.peek();
}
}
// Template that covers integers.
template <class T>
inline void WriteIntegerPairVector(std::ostream &os, bool binary,
const std::vector<std::pair<T, T> > &v) {
// Compile time assertion that this is not called with a wrong type.
KALDI_ASSERT_IS_INTEGER_TYPE(T);
if (binary) {
char sz = sizeof(T); // this is currently just a check.
os.write(&sz, 1);
int32 vecsz = static_cast<int32>(v.size());
KALDI_ASSERT((size_t)vecsz == v.size());
os.write(reinterpret_cast<const char *>(&vecsz), sizeof(vecsz));
if (vecsz != 0) {
os.write(reinterpret_cast<const char *>(&(v[0])), sizeof(T) * vecsz * 2);
}
} else {
// focus here is on prettiness of text form rather than
// efficiency of reading-in.
// reading-in is dominated by low-level operations anyway:
// for efficiency use binary.
os << "[ ";
typename std::vector<std::pair<T, T> >::const_iterator iter = v.begin(),
end = v.end();
for (; iter != end; ++iter) {
if (sizeof(T) == 1)
os << static_cast<int16>(iter->first) << ','
<< static_cast<int16>(iter->second) << ' ';
else
os << iter->first << ',' << iter->second << ' ';
}
os << "]\n";
}
if (os.fail()) {
KALDI_ERR << "Write failure in WriteIntegerPairVector.";
}
}
// Template that covers integers.
template <class T>
inline void ReadIntegerPairVector(std::istream &is, bool binary,
std::vector<std::pair<T, T> > *v) {
KALDI_ASSERT_IS_INTEGER_TYPE(T);
KALDI_ASSERT(v != NULL);
if (binary) {
int sz = is.peek();
if (sz == sizeof(T)) {
is.get();
} else { // this is currently just a check.
KALDI_ERR << "ReadIntegerPairVector: expected to see type of size "
<< sizeof(T) << ", saw instead " << sz << ", at file position "
<< is.tellg();
}
int32 vecsz;
is.read(reinterpret_cast<char *>(&vecsz), sizeof(vecsz));
if (is.fail() || vecsz < 0) goto bad;
v->resize(vecsz);
if (vecsz > 0) {
is.read(reinterpret_cast<char *>(&((*v)[0])), sizeof(T) * vecsz * 2);
}
} else {
std::vector<std::pair<T, T> > tmp_v; // use temporary so v doesn't use
// extra memory due to resizing.
is >> std::ws;
if (is.peek() != static_cast<int>('[')) {
KALDI_ERR << "ReadIntegerPairVector: expected to see [, saw " << is.peek()
<< ", at file position " << is.tellg();
}
is.get(); // consume the '['.
is >> std::ws; // consume whitespace.
while (is.peek() != static_cast<int>(']')) {
if (sizeof(T) == 1) { // read/write chars as numbers.
int16 next_t1, next_t2;
is >> next_t1;
if (is.fail()) goto bad;
if (is.peek() != static_cast<int>(','))
KALDI_ERR << "ReadIntegerPairVector: expected to see ',', saw "
<< is.peek() << ", at file position " << is.tellg();
is.get(); // consume the ','.
is >> next_t2 >> std::ws;
if (is.fail())
goto bad;
else
tmp_v.push_back(std::make_pair((T)next_t1, (T)next_t2));
} else {
T next_t1, next_t2;
is >> next_t1;
if (is.fail()) goto bad;
if (is.peek() != static_cast<int>(','))
KALDI_ERR << "ReadIntegerPairVector: expected to see ',', saw "
<< is.peek() << ", at file position " << is.tellg();
is.get(); // consume the ','.
is >> next_t2 >> std::ws;
if (is.fail())
goto bad;
else
tmp_v.push_back(std::pair<T, T>(next_t1, next_t2));
}
}
is.get(); // get the final ']'.
*v = tmp_v; // could use std::swap to use less temporary memory, but this
// uses less permanent memory.
}
if (!is.fail()) return;
bad:
KALDI_ERR << "ReadIntegerPairVector: read failure at file position "
<< is.tellg();
}
template <class T>
inline void WriteIntegerVector(std::ostream &os, bool binary,
const std::vector<T> &v) {
// Compile time assertion that this is not called with a wrong type.
KALDI_ASSERT_IS_INTEGER_TYPE(T);
if (binary) {
char sz = sizeof(T); // this is currently just a check.
os.write(&sz, 1);
int32 vecsz = static_cast<int32>(v.size());
KALDI_ASSERT((size_t)vecsz == v.size());
os.write(reinterpret_cast<const char *>(&vecsz), sizeof(vecsz));
if (vecsz != 0) {
os.write(reinterpret_cast<const char *>(&(v[0])), sizeof(T) * vecsz);
}
} else {
// focus here is on prettiness of text form rather than
// efficiency of reading-in.
// reading-in is dominated by low-level operations anyway:
// for efficiency use binary.
os << "[ ";
typename std::vector<T>::const_iterator iter = v.begin(), end = v.end();
for (; iter != end; ++iter) {
if (sizeof(T) == 1)
os << static_cast<int16>(*iter) << " ";
else
os << *iter << " ";
}
os << "]\n";
}
if (os.fail()) {
KALDI_ERR << "Write failure in WriteIntegerVector.";
}
}
template <class T>
inline void ReadIntegerVector(std::istream &is, bool binary,
std::vector<T> *v) {
KALDI_ASSERT_IS_INTEGER_TYPE(T);
KALDI_ASSERT(v != NULL);
if (binary) {
int sz = is.peek();
if (sz == sizeof(T)) {
is.get();
} else { // this is currently just a check.
KALDI_ERR << "ReadIntegerVector: expected to see type of size "
<< sizeof(T) << ", saw instead " << sz << ", at file position "
<< is.tellg();
}
int32 vecsz;
is.read(reinterpret_cast<char *>(&vecsz), sizeof(vecsz));
if (is.fail() || vecsz < 0) goto bad;
v->resize(vecsz);
if (vecsz > 0) {
is.read(reinterpret_cast<char *>(&((*v)[0])), sizeof(T) * vecsz);
}
} else {
std::vector<T> tmp_v; // use temporary so v doesn't use extra memory
// due to resizing.
is >> std::ws;
if (is.peek() != static_cast<int>('[')) {
KALDI_ERR << "ReadIntegerVector: expected to see [, saw " << is.peek()
<< ", at file position " << is.tellg();
}
is.get(); // consume the '['.
is >> std::ws; // consume whitespace.
while (is.peek() != static_cast<int>(']')) {
if (sizeof(T) == 1) { // read/write chars as numbers.
int16 next_t;
is >> next_t >> std::ws;
if (is.fail())
goto bad;
else
tmp_v.push_back((T)next_t);
} else {
T next_t;
is >> next_t >> std::ws;
if (is.fail())
goto bad;
else
tmp_v.push_back(next_t);
}
}
is.get(); // get the final ']'.
*v = tmp_v; // could use std::swap to use less temporary memory, but this
// uses less permanent memory.
}
if (!is.fail()) return;
bad:
KALDI_ERR << "ReadIntegerVector: read failure at file position "
<< is.tellg();
}
// Initialize an opened stream for writing by writing an optional binary
// header and modifying the floating-point precision.
inline void InitKaldiOutputStream(std::ostream &os, bool binary) {
// This does not throw exceptions (does not check for errors).
if (binary) {
os.put('\0');
os.put('B');
}
// Note, in non-binary mode we may at some point want to mess with
// the precision a bit.
// 7 is a bit more than the precision of float..
if (os.precision() < 7) os.precision(7);
}
/// Initialize an opened stream for reading by detecting the binary header and
/// setting the "binary" value appropriately.
inline bool InitKaldiInputStream(std::istream &is, bool *binary) {
// Sets the 'binary' variable.
// Throws exception in the very unusual situation that stream
// starts with '\0' but not then 'B'.
if (is.peek() == '\0') { // seems to be binary
is.get();
if (is.peek() != 'B') {
return false;
}
is.get();
*binary = true;
return true;
} else {
*binary = false;
return true;
}
}
} // end namespace kaldi.
#endif // KALDI_BASE_IO_FUNCS_INL_H_
// base/io-funcs.cc
// Copyright 2009-2011 Microsoft Corporation; Saarland University
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/io-funcs.h"
#include "base/kaldi-math.h"
namespace kaldi {
template <>
void WriteBasicType<bool>(std::ostream &os, bool binary, bool b) {
os << (b ? "T" : "F");
if (!binary) os << " ";
if (os.fail()) KALDI_ERR << "Write failure in WriteBasicType<bool>";
}
template <>
void ReadBasicType<bool>(std::istream &is, bool binary, bool *b) {
KALDI_PARANOID_ASSERT(b != NULL);
if (!binary) is >> std::ws; // eat up whitespace.
char c = is.peek();
if (c == 'T') {
*b = true;
is.get();
} else if (c == 'F') {
*b = false;
is.get();
} else {
KALDI_ERR << "Read failure in ReadBasicType<bool>, file position is "
<< is.tellg() << ", next char is " << CharToString(c);
}
}
template <>
void WriteBasicType<float>(std::ostream &os, bool binary, float f) {
if (binary) {
char c = sizeof(f);
os.put(c);
os.write(reinterpret_cast<const char *>(&f), sizeof(f));
} else {
os << f << " ";
}
}
template <>
void WriteBasicType<double>(std::ostream &os, bool binary, double f) {
if (binary) {
char c = sizeof(f);
os.put(c);
os.write(reinterpret_cast<const char *>(&f), sizeof(f));
} else {
os << f << " ";
}
}
template <>
void ReadBasicType<float>(std::istream &is, bool binary, float *f) {
KALDI_PARANOID_ASSERT(f != NULL);
if (binary) {
double d;
int c = is.peek();
if (c == sizeof(*f)) {
is.get();
is.read(reinterpret_cast<char *>(f), sizeof(*f));
} else if (c == sizeof(d)) {
ReadBasicType(is, binary, &d);
*f = d;
} else {
KALDI_ERR << "ReadBasicType: expected float, saw " << is.peek()
<< ", at file position " << is.tellg();
}
} else {
is >> *f;
}
if (is.fail()) {
KALDI_ERR << "ReadBasicType: failed to read, at file position "
<< is.tellg();
}
}
template <>
void ReadBasicType<double>(std::istream &is, bool binary, double *d) {
KALDI_PARANOID_ASSERT(d != NULL);
if (binary) {
float f;
int c = is.peek();
if (c == sizeof(*d)) {
is.get();
is.read(reinterpret_cast<char *>(d), sizeof(*d));
} else if (c == sizeof(f)) {
ReadBasicType(is, binary, &f);
*d = f;
} else {
KALDI_ERR << "ReadBasicType: expected float, saw " << is.peek()
<< ", at file position " << is.tellg();
}
} else {
is >> *d;
}
if (is.fail()) {
KALDI_ERR << "ReadBasicType: failed to read, at file position "
<< is.tellg();
}
}
void CheckToken(const char *token) {
if (*token == '\0') KALDI_ERR << "Token is empty (not a valid token)";
const char *orig_token = token;
while (*token != '\0') {
if (::isspace(*token))
KALDI_ERR << "Token is not a valid token (contains space): '"
<< orig_token << "'";
token++;
}
}
void WriteToken(std::ostream &os, bool binary, const char *token) {
// binary mode is ignored;
// we use space as termination character in either case.
KALDI_ASSERT(token != NULL);
CheckToken(token); // make sure it's valid (can be read back)
os << token << " ";
if (os.fail()) {
KALDI_ERR << "Write failure in WriteToken.";
}
}
int Peek(std::istream &is, bool binary) {
if (!binary) is >> std::ws; // eat up whitespace.
return is.peek();
}
void WriteToken(std::ostream &os, bool binary, const std::string &token) {
WriteToken(os, binary, token.c_str());
}
void ReadToken(std::istream &is, bool binary, std::string *str) {
KALDI_ASSERT(str != NULL);
if (!binary) is >> std::ws; // consume whitespace.
is >> *str;
if (is.fail()) {
KALDI_ERR << "ReadToken, failed to read token at file position "
<< is.tellg();
}
if (!isspace(is.peek())) {
KALDI_ERR << "ReadToken, expected space after token, saw instead "
<< CharToString(static_cast<char>(is.peek()))
<< ", at file position " << is.tellg();
}
is.get(); // consume the space.
}
int PeekToken(std::istream &is, bool binary) {
if (!binary) is >> std::ws; // consume whitespace.
bool read_bracket;
if (static_cast<char>(is.peek()) == '<') {
read_bracket = true;
is.get();
} else {
read_bracket = false;
}
int ans = is.peek();
if (read_bracket) {
if (!is.unget()) {
// Clear the bad bit. This code can be (and is in fact) reached, since the
// C++ standard does not guarantee that a call to unget() must succeed.
is.clear();
}
}
return ans;
}
void ExpectToken(std::istream &is, bool binary, const char *token) {
int pos_at_start = is.tellg();
KALDI_ASSERT(token != NULL);
CheckToken(token); // make sure it's valid (can be read back)
if (!binary) is >> std::ws; // consume whitespace.
std::string str;
is >> str;
is.get(); // consume the space.
if (is.fail()) {
KALDI_ERR << "Failed to read token [started at file position "
<< pos_at_start << "], expected " << token;
}
// The second half of the '&&' expression below is so that if we're expecting
// "<Foo>", we will accept "Foo>" instead. This is so that the model-reading
// code will tolerate errors in PeekToken where is.unget() failed; search for
// is.clear() in PeekToken() for an explanation.
if (strcmp(str.c_str(), token) != 0 &&
!(token[0] == '<' && strcmp(str.c_str(), token + 1) == 0)) {
KALDI_ERR << "Expected token \"" << token << "\", got instead \"" << str
<< "\".";
}
}
void ExpectToken(std::istream &is, bool binary, const std::string &token) {
ExpectToken(is, binary, token.c_str());
}
} // end namespace kaldi
// base/io-funcs.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
// Jan Silovsky; Yanmin Qian
// 2016 Xiaohui Zhang
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_BASE_IO_FUNCS_H_
#define KALDI_BASE_IO_FUNCS_H_
// This header only contains some relatively low-level I/O functions.
// The full Kaldi I/O declarations are in ../util/kaldi-io.h
// and ../util/kaldi-table.h
// They were put in util/ in order to avoid making the Matrix library
// dependent on them.
#include <cctype>
#include <string>
#include <utility>
#include <vector>
#include "base/io-funcs-inl.h"
#include "base/kaldi-common.h"
namespace kaldi {
/*
This comment describes the Kaldi approach to I/O. All objects can be written
and read in two modes: binary and text. In addition we want to make the I/O
work if we redefine the typedef "BaseFloat" between floats and doubles.
We also want to have control over whitespace in text mode without affecting
the meaning of the file, for pretty-printing purposes.
Errors are handled by throwing a KaldiFatalError exception.
For integer and floating-point types (and boolean values):
WriteBasicType(std::ostream &, bool binary, const T&);
ReadBasicType(std::istream &, bool binary, T*);
and we expect these functions to be defined in such a way that they work when
the type T changes between float and double, so you can read float into double
and vice versa. Note that for efficiency and space-saving reasons, the
Vector and Matrix classes do not use these functions [but they preserve the
type interchangeability in their own way].
For a class (or struct) C:
class C {
..
Write(std::ostream &, bool binary,
[possibly extra optional args for specific classes]) const;
Read(std::istream &, bool binary,
[possibly extra optional args for specific classes]);
..
}
NOTE: The only actual optional args we used are the "add" arguments in
Vector/Matrix classes, which specify whether we should sum the data already
in the class with the data being read.
For types which are typedef's involving stl classes, I/O is as follows:
typedef std::vector<std::pair<A, B> > MyTypedefName;
The user should define something like:
WriteMyTypedefName(std::ostream &, bool binary, const MyTypedefName &t);
ReadMyTypedefName(std::istream &, bool binary, MyTypedefName *t);
The user would have to write these functions.
For a type std::vector<T>:
void WriteIntegerVector(std::ostream &os, bool binary, const std::vector<T> &v);
void ReadIntegerVector(std::istream &is, bool binary, std::vector<T> *v);
For other types, e.g. vectors of pairs, the user should create a routine of
the type WriteMyTypedefName. This is to avoid introducing confusing templated
functions; we could easily create templated functions to handle most of these
cases but they would have to share the same name.
It also often happens that the user needs to write/read special tokens as part
of a file. These might be class headers, or separators/identifiers in the
class. We provide special functions for manipulating these. These special
tokens must be nonempty and must not contain any whitespace.
void WriteToken(std::ostream &os, bool binary, const char*);
void WriteToken(std::ostream &os, bool binary, const std::string & token);
int Peek(std::istream &is, bool binary);
void ReadToken(std::istream &is, bool binary, std::string *str);
int PeekToken(std::istream &is, bool binary);
WriteToken writes the token and one space (whether in binary or text mode).
Peek returns the first character of the next token, by consuming whitespace
(in text mode) and then returning the peek() character. It returns -1 at EOF;
it doesn't throw. It's useful if a class can have various forms based on
typedefs and virtual classes, and wants to know which version to read.
ReadToken allows the caller to obtain the next token. PeekToken works just
like ReadToken, but seeks back to the beginning of the token. A subsequent
call to ReadToken will read the same token again. This is useful when
different object types are written to the same file; using PeekToken one can
decide which of the objects to read.
There is currently no special functionality for writing/reading strings (where
the strings contain data rather than "special tokens" that are whitespace-free
and nonempty). This is because Kaldi is structured in such a way that strings
don't appear, except as OpenFst symbol table entries (and these have their own
format).
NOTE: you should not call ReadBasicType and WriteBasicType with types,
such as int and size_t, that are machine-dependent -- at least not
if you want your file formats to port between machines. Use int32 and
int64 where necessary. There is no way to detect this using compile-time
assertions because C++ only keeps track of the internal representation of
the type.
*/
/// \addtogroup io_funcs_basic
/// @{
/// WriteBasicType is the name of the write function for bool, integer types,
/// and floating-point types. They all throw on error.
template <class T>
void WriteBasicType(std::ostream &os, bool binary, T t);
/// ReadBasicType is the name of the read function for bool, integer types,
/// and floating-point types. They all throw on error.
template <class T>
void ReadBasicType(std::istream &is, bool binary, T *t);
// Declare specialization for bool.
template <>
void WriteBasicType<bool>(std::ostream &os, bool binary, bool b);
template <>
void ReadBasicType<bool>(std::istream &is, bool binary, bool *b);
// Declare specializations for float and double.
template <>
void WriteBasicType<float>(std::ostream &os, bool binary, float f);
template <>
void WriteBasicType<double>(std::ostream &os, bool binary, double f);
template <>
void ReadBasicType<float>(std::istream &is, bool binary, float *f);
template <>
void ReadBasicType<double>(std::istream &is, bool binary, double *f);
// Define ReadBasicType that accepts an "add" parameter to add to
// the destination. Caution: if used in Read functions, be careful
// to initialize the parameters concerned to zero in the default
// constructor.
template <class T>
inline void ReadBasicType(std::istream &is, bool binary, T *t, bool add) {
if (!add) {
ReadBasicType(is, binary, t);
} else {
T tmp = T(0);
ReadBasicType(is, binary, &tmp);
*t += tmp;
}
}
/// Function for writing STL vectors of integer types.
template <class T>
inline void WriteIntegerVector(std::ostream &os, bool binary,
const std::vector<T> &v);
/// Function for reading STL vector of integer types.
template <class T>
inline void ReadIntegerVector(std::istream &is, bool binary, std::vector<T> *v);
/// Function for writing STL vectors of pairs of integer types.
template <class T>
inline void WriteIntegerPairVector(std::ostream &os, bool binary,
const std::vector<std::pair<T, T> > &v);
/// Function for reading STL vector of pairs of integer types.
template <class T>
inline void ReadIntegerPairVector(std::istream &is, bool binary,
std::vector<std::pair<T, T> > *v);
/// The WriteToken functions are for writing nonempty sequences of non-space
/// characters. They are not for general strings.
void WriteToken(std::ostream &os, bool binary, const char *token);
void WriteToken(std::ostream &os, bool binary, const std::string &token);
/// Peek consumes whitespace (if binary == false) and then returns the peek()
/// value of the stream.
int Peek(std::istream &is, bool binary);
/// ReadToken gets the next token and puts it in str (exception on failure). If
/// PeekToken() had been previously called, it is possible that the stream had
/// failed to unget the starting '<' character. In this case ReadToken() returns
/// the token string without the leading '<'. You must be prepared to handle
/// this case. ExpectToken() handles this internally, and is not affected.
void ReadToken(std::istream &is, bool binary, std::string *token);
/// PeekToken will return the first character of the next token, or -1 if end of
/// file. It's the same as Peek(), except if the first character is '<' it will
/// skip over it and will return the next character. It will attempt to unget
/// the '<' so the stream is where it was before you did PeekToken(), however,
/// this is not guaranteed (see ReadToken()).
int PeekToken(std::istream &is, bool binary);
/// ExpectToken tries to read in the given token, and throws an exception
/// on failure.
void ExpectToken(std::istream &is, bool binary, const char *token);
void ExpectToken(std::istream &is, bool binary, const std::string &token);
/// ExpectPretty attempts to read the text in "token", but only in non-binary
/// mode. Throws exception on failure. It expects an exact match except that
/// arbitrary whitespace matches arbitrary whitespace.
void ExpectPretty(std::istream &is, bool binary, const char *token);
void ExpectPretty(std::istream &is, bool binary, const std::string &token);
/// @} end "addtogroup io_funcs_basic"
/// InitKaldiOutputStream initializes an opened stream for writing by writing an
/// optional binary header and modifying the floating-point precision; it will
/// typically not be called by users directly.
inline void InitKaldiOutputStream(std::ostream &os, bool binary);
/// InitKaldiInputStream initializes an opened stream for reading by detecting
/// the binary header and setting the "binary" value appropriately;
/// It will typically not be called by users directly.
inline bool InitKaldiInputStream(std::istream &is, bool *binary);
} // end namespace kaldi.
#endif // KALDI_BASE_IO_FUNCS_H_
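// A round-trip sketch of the basic I/O API declared above; the token and
// values are arbitrary.
#include <sstream>
#include "base/io-funcs.h"

void IoRoundTripExample() {
  std::ostringstream os;
  bool binary = true;
  kaldi::InitKaldiOutputStream(os, binary);  // writes the "\0B" header
  kaldi::WriteToken(os, binary, "<MyData>");
  kaldi::WriteBasicType(os, binary, static_cast<kaldi::int32>(42));

  std::istringstream is(os.str());
  bool binary_in = false;
  kaldi::InitKaldiInputStream(is, &binary_in);  // detects the header
  kaldi::ExpectToken(is, binary_in, "<MyData>");
  kaldi::int32 value;
  kaldi::ReadBasicType(is, binary_in, &value);  // value == 42
}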
// base/kaldi-common.h
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_BASE_KALDI_COMMON_H_
#define KALDI_BASE_KALDI_COMMON_H_ 1
#include <cstddef>
#include <cstdlib>
#include <cstring> // C string stuff like strcpy
#include <string>
#include <sstream>
#include <stdexcept>
#include <cassert>
#include <vector>
#include <iostream>
#include <fstream>
#include "base/kaldi-utils.h"
#include "base/kaldi-error.h"
#include "base/kaldi-types.h"
// #include "base/io-funcs.h"
#include "base/kaldi-math.h"
// #include "base/timer.h"
#endif // KALDI_BASE_KALDI_COMMON_H_
// base/kaldi-error.cc
// Copyright 2019 LAIX (Yi Sun)
// Copyright 2019 SmartAction LLC (kkm)
// Copyright 2016 Brno University of Technology (author: Karel Vesely)
// Copyright 2009-2011 Microsoft Corporation; Lukas Burget; Ondrej Glembek
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-error.h"
#include <string>
namespace kaldi {
/***** GLOBAL VARIABLES FOR LOGGING *****/
int32 g_kaldi_verbose_level = 0;
static std::string program_name; // NOLINT
void SetProgramName(const char *basename) {
// Using the 'static std::string' for the program name is mostly harmless,
// because (a) Kaldi logging is undefined before main(), and (b) no stdc++
// string implementation has been found in the wild that would not be just
// an empty string when zero-initialized but not yet constructed.
program_name = basename;
}
} // namespace kaldi
// base/kaldi-error.h
// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_BASE_KALDI_ERROR_H_
#define KALDI_BASE_KALDI_ERROR_H_ 1
#include "utils/log.h"
namespace kaldi {
#define KALDI_WARN \
google::LogMessage(__FILE__, __LINE__, google::GLOG_WARNING).stream()
#define KALDI_ERR \
google::LogMessage(__FILE__, __LINE__, google::GLOG_ERROR).stream()
#define KALDI_LOG \
google::LogMessage(__FILE__, __LINE__, google::GLOG_INFO).stream()
#define KALDI_VLOG(v) VLOG(v)
#define KALDI_ASSERT(condition) CHECK(condition)
/***** PROGRAM NAME AND VERBOSITY LEVEL *****/
/// Called by ParseOptions to set base name (no directory) of the executing
/// program. The name is printed in logging code along with every message,
/// because in our scripts, we often mix together the stderr of many programs.
/// This function is very thread-unsafe.
void SetProgramName(const char *basename);
/// This is set by util/parse-options.{h,cc} if you set --verbose=? option.
/// Do not use directly, prefer {Get,Set}VerboseLevel().
extern int32 g_kaldi_verbose_level;
/// Get verbosity level, usually set via command line '--verbose=' switch.
inline int32 GetVerboseLevel() { return g_kaldi_verbose_level; }
/// This should be rarely used, except by programs using Kaldi as library;
/// command-line programs set the verbose level automatically from ParseOptions.
inline void SetVerboseLevel(int32 i) { g_kaldi_verbose_level = i; }
} // namespace kaldi
#endif // KALDI_BASE_KALDI_ERROR_H_
// base/kaldi-math.cc
// Copyright 2009-2011 Microsoft Corporation; Yanmin Qian;
// Saarland University; Jan Silovsky
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-math.h"
#ifndef _MSC_VER
#include <stdlib.h>
#include <unistd.h>
#endif
#include <string>
#include <mutex>
namespace kaldi {
// These routines are tested in matrix/matrix-test.cc
int32 RoundUpToNearestPowerOfTwo(int32 n) {
KALDI_ASSERT(n > 0);
n--;
n |= n >> 1;
n |= n >> 2;
n |= n >> 4;
n |= n >> 8;
n |= n >> 16;
return n+1;
}
static std::mutex _RandMutex;
int Rand(struct RandomState* state) {
#if !defined(_POSIX_THREAD_SAFE_FUNCTIONS)
// On Windows and Cygwin, just call rand()
return rand();
#else
if (state) {
return rand_r(&(state->seed));
} else {
std::lock_guard<std::mutex> lock(_RandMutex);
return rand();
}
#endif
}
RandomState::RandomState() {
// we initialize it as Rand() + 27437 instead of just Rand(), because on some
// systems, e.g. at the very least Mac OSX Yosemite and later, it seems to be
// the case that rand_r when initialized with rand() will give you the exact
// same sequence of numbers that rand() will give if you keep calling rand()
// after that initial call. This can cause problems with repeated sequences.
// For example if you initialize two RandomState structs one after the other
// without calling rand() in between, they would give you the same sequence
// offset by one (if we didn't have the "+ 27437" in the code). 27437 is just
// a randomly chosen prime number.
seed = unsigned(Rand()) + 27437;
}
bool WithProb(BaseFloat prob, struct RandomState* state) {
KALDI_ASSERT(prob >= 0 && prob <= 1.1); // prob should be <= 1.0,
// but we allow slightly larger values that could arise from roundoff in
// previous calculations.
KALDI_COMPILE_TIME_ASSERT(RAND_MAX > 128 * 128);
if (prob == 0) {
return false;
} else if (prob == 1.0) {
return true;
} else if (prob * RAND_MAX < 128.0) {
// prob is very small but nonzero, and the "main algorithm"
// wouldn't work that well. So: with probability 1/128, we
// return WithProb (prob * 128), else return false.
if (Rand(state) < RAND_MAX / 128) { // with probability 1/128...
// Note: we know that prob * 128.0 < 1.0, because
// we asserted RAND_MAX > 128 * 128.
return WithProb(prob * 128.0);
} else {
return false;
}
} else {
return (Rand(state) < ((RAND_MAX + static_cast<BaseFloat>(1.0)) * prob));
}
}
int32 RandInt(int32 min_val, int32 max_val, struct RandomState* state) {
// This is not exact.
KALDI_ASSERT(max_val >= min_val);
if (max_val == min_val) return min_val;
#ifdef _MSC_VER
// RAND_MAX is quite small on Windows -> may need to handle larger numbers.
if (RAND_MAX > (max_val-min_val)*8) {
// *8 to avoid large inaccuracies in probability, from the modulus...
return min_val +
((unsigned int)Rand(state) % (unsigned int)(max_val+1-min_val));
} else {
if ((unsigned int)(RAND_MAX*RAND_MAX) >
(unsigned int)((max_val+1-min_val)*8)) {
// *8 to avoid inaccuracies in probability, from the modulus...
return min_val + ( (unsigned int)( (Rand(state)+RAND_MAX*Rand(state)))
% (unsigned int)(max_val+1-min_val));
} else {
KALDI_ERR << "rand_int failed because we do not support such large "
"random numbers. (Extend this function).";
}
}
#else
return min_val +
(static_cast<int32>(Rand(state)) % static_cast<int32>(max_val+1-min_val));
#endif
}
// Returns poisson-distributed random number.
// Take care: this takes time proportional
// to lambda. Faster algorithms exist but are more complex.
int32 RandPoisson(float lambda, struct RandomState* state) {
// Knuth's algorithm.
KALDI_ASSERT(lambda >= 0);
float L = expf(-lambda), p = 1.0;
int32 k = 0;
do {
k++;
float u = RandUniform(state);
p *= u;
} while (p > L);
return k-1;
}
void RandGauss2(float *a, float *b, RandomState *state) {
KALDI_ASSERT(a);
KALDI_ASSERT(b);
float u1 = RandUniform(state);
float u2 = RandUniform(state);
u1 = sqrtf(-2.0f * logf(u1));
u2 = 2.0f * M_PI * u2;
*a = u1 * cosf(u2);
*b = u1 * sinf(u2);
}
void RandGauss2(double *a, double *b, RandomState *state) {
KALDI_ASSERT(a);
KALDI_ASSERT(b);
float a_float, b_float;
// Just because we're using doubles doesn't mean we need super-high-quality
// random numbers, so we just use the floating-point version internally.
RandGauss2(&a_float, &b_float, state);
*a = a_float;
*b = b_float;
}
} // end namespace kaldi
// base/kaldi-math.h
// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Yanmin Qian;
// Jan Silovsky; Saarland University
//
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_BASE_KALDI_MATH_H_
#define KALDI_BASE_KALDI_MATH_H_ 1
#ifdef _MSC_VER
#include <float.h>
#endif
#include <cmath>
#include <limits>
#include <vector>
#include "base/kaldi-types.h"
#include "base/kaldi-common.h"
#ifndef DBL_EPSILON
#define DBL_EPSILON 2.2204460492503131e-16
#endif
#ifndef FLT_EPSILON
#define FLT_EPSILON 1.19209290e-7f
#endif
#ifndef M_PI
#define M_PI 3.1415926535897932384626433832795
#endif
#ifndef M_SQRT2
#define M_SQRT2 1.4142135623730950488016887
#endif
#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
#endif
#ifndef M_SQRT1_2
#define M_SQRT1_2 0.7071067811865475244008443621048490
#endif
#ifndef M_LOG_2PI
#define M_LOG_2PI 1.8378770664093454835606594728112
#endif
#ifndef M_LN2
#define M_LN2 0.693147180559945309417232121458
#endif
#ifndef M_LN10
#define M_LN10 2.302585092994045684017991454684
#endif
#define KALDI_ISNAN std::isnan
#define KALDI_ISINF std::isinf
#define KALDI_ISFINITE(x) std::isfinite(x)
#if !defined(KALDI_SQR)
# define KALDI_SQR(x) ((x) * (x))
#endif
namespace kaldi {
#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
inline double Exp(double x) { return exp(x); }
#ifndef KALDI_NO_EXPF
inline float Exp(float x) { return expf(x); }
#else
inline float Exp(float x) { return exp(static_cast<double>(x)); }
#endif // KALDI_NO_EXPF
#else
inline double Exp(double x) { return exp(x); }
#if !defined(__INTEL_COMPILER) && _MSC_VER == 1800 && defined(_M_X64)
// Microsoft CL v18.0 buggy 64-bit implementation of
// expf() incorrectly returns -inf for exp(-inf).
inline float Exp(float x) { return exp(static_cast<double>(x)); }
#else
inline float Exp(float x) { return expf(x); }
#endif // !defined(__INTEL_COMPILER) && _MSC_VER == 1800 && defined(_M_X64)
#endif // !defined(_MSC_VER) || (_MSC_VER >= 1900)
inline double Log(double x) { return log(x); }
inline float Log(float x) { return logf(x); }
#if !defined(_MSC_VER) || (_MSC_VER >= 1700)
inline double Log1p(double x) { return log1p(x); }
inline float Log1p(float x) { return log1pf(x); }
#else
inline double Log1p(double x) {
const double cutoff = 1.0e-08;
if (x < cutoff)
return x - 0.5 * x * x;
else
return Log(1.0 + x);
}
inline float Log1p(float x) {
const float cutoff = 1.0e-07;
if (x < cutoff)
return x - 0.5 * x * x;
else
return Log(1.0 + x);
}
#endif
static const double kMinLogDiffDouble = Log(DBL_EPSILON); // negative!
static const float kMinLogDiffFloat = Log(FLT_EPSILON); // negative!
// -infinity
const float kLogZeroFloat = -std::numeric_limits<float>::infinity();
const double kLogZeroDouble = -std::numeric_limits<double>::infinity();
const BaseFloat kLogZeroBaseFloat = -std::numeric_limits<BaseFloat>::infinity();
// Returns a random integer between 0 and RAND_MAX, inclusive
int Rand(struct RandomState* state = NULL);
// State for thread-safe random number generator
struct RandomState {
RandomState();
unsigned seed;
};
// Returns a random integer between first and last inclusive.
int32 RandInt(int32 first, int32 last, struct RandomState* state = NULL);
// Returns true with probability "prob", with 0 <= prob <= 1 [we check this].
// Internally calls Rand(). This function is carefully implemented so
// that it should work even if prob is very small.
bool WithProb(BaseFloat prob, struct RandomState* state = NULL);
/// Returns a random number strictly between 0 and 1.
inline float RandUniform(struct RandomState* state = NULL) {
return static_cast<float>((Rand(state) + 1.0) / (RAND_MAX+2.0));
}
inline float RandGauss(struct RandomState* state = NULL) {
return static_cast<float>(sqrtf (-2 * Log(RandUniform(state)))
* cosf(2*M_PI*RandUniform(state)));
}
// Returns poisson-distributed random number. Uses Knuth's algorithm.
// Take care: this takes time proportional
// to lambda. Faster algorithms exist but are more complex.
int32 RandPoisson(float lambda, struct RandomState* state = NULL);
// Returns a pair of gaussian random numbers. Uses Box-Muller transform
void RandGauss2(float *a, float *b, RandomState *state = NULL);
void RandGauss2(double *a, double *b, RandomState *state = NULL);
// Also see Vector<float,double>::RandCategorical().
// This is a randomized pruning mechanism that preserves expectations,
// that we typically use to prune posteriors.
template<class Float>
inline Float RandPrune(Float post, BaseFloat prune_thresh,
struct RandomState* state = NULL) {
KALDI_ASSERT(prune_thresh >= 0.0);
if (post == 0.0 || std::abs(post) >= prune_thresh)
return post;
return (post >= 0 ? 1.0 : -1.0) *
(RandUniform(state) <= fabs(post)/prune_thresh ? prune_thresh : 0.0);
}
// returns log(exp(x) + exp(y)).
inline double LogAdd(double x, double y) {
double diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff >= kMinLogDiffDouble) {
double res;
res = x + Log1p(Exp(diff));
return res;
} else {
return x; // return the larger one.
}
}
// returns log(exp(x) + exp(y)).
inline float LogAdd(float x, float y) {
float diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff >= kMinLogDiffFloat) {
float res;
res = x + Log1p(Exp(diff));
return res;
} else {
return x; // return the larger one.
}
}
// returns log(exp(x) - exp(y)).
inline double LogSub(double x, double y) {
if (y >= x) { // Throws exception if y>=x.
if (y == x)
return kLogZeroDouble;
else
KALDI_ERR << "Cannot subtract a larger from a smaller number.";
}
double diff = y - x; // Will be negative.
double res = x + Log(1.0 - Exp(diff));
// res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine precision
if (KALDI_ISNAN(res))
return kLogZeroDouble;
return res;
}
// returns log(exp(x) - exp(y)).
inline float LogSub(float x, float y) {
if (y >= x) { // Throws exception if y>=x.
if (y == x)
return kLogZeroFloat;
else
KALDI_ERR << "Cannot subtract a larger from a smaller number.";
}
float diff = y - x; // Will be negative.
float res = x + Log(1.0f - Exp(diff));
// res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine precision
if (KALDI_ISNAN(res))
return kLogZeroFloat;
return res;
}
/// return abs(a - b) <= relative_tolerance * (abs(a)+abs(b)).
static inline bool ApproxEqual(float a, float b,
float relative_tolerance = 0.001) {
// a==b handles infinities.
if (a == b) return true;
float diff = std::abs(a-b);
if (diff == std::numeric_limits<float>::infinity()
|| diff != diff) return false; // diff is +inf or nan.
return (diff <= relative_tolerance*(std::abs(a)+std::abs(b)));
}
/// assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b))
static inline void AssertEqual(float a, float b,
float relative_tolerance = 0.001) {
// a==b handles infinities.
KALDI_ASSERT(ApproxEqual(a, b, relative_tolerance));
}
// RoundUpToNearestPowerOfTwo returns the smallest power of two that is >= n.
// It crashes if n <= 0.
int32 RoundUpToNearestPowerOfTwo(int32 n);
/// Returns a / b, rounding towards negative infinity in all cases.
static inline int32 DivideRoundingDown(int32 a, int32 b) {
KALDI_ASSERT(b != 0);
if (a * b >= 0)
return a / b;
else if (a < 0)
return (a - b + 1) / b;
else
return (a - b - 1) / b;
}
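// Examples (illustrative): plain C++ integer division truncates toward zero,
// so -3 / 2 == -1, whereas DivideRoundingDown(-3, 2) == (-3 - 2 + 1) / 2 == -2,
// matching floor(-1.5). For operands of the same sign the result is the
// ordinary quotient, e.g. DivideRoundingDown(3, 2) == 1.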
template<class I> I Gcd(I m, I n) {
if (m == 0 || n == 0) {
if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors.
KALDI_ERR << "Undefined GCD since m = 0, n = 0.";
}
return (m == 0 ? (n > 0 ? n : -n) : ( m > 0 ? m : -m));
// return absolute value of whichever is nonzero
}
// could use compile-time assertion
// but involves messing with complex template stuff.
KALDI_ASSERT(std::numeric_limits<I>::is_integer);
while (1) {
m %= n;
if (m == 0) return (n > 0 ? n : -n);
n %= m;
if (n == 0) return (m > 0 ? m : -m);
}
}
/// Returns the least common multiple of two integers. Will
/// crash unless the inputs are positive.
template<class I> I Lcm(I m, I n) {
KALDI_ASSERT(m > 0 && n > 0);
I gcd = Gcd(m, n);
return gcd * (m/gcd) * (n/gcd);
}
template<class I> void Factorize(I m, std::vector<I> *factors) {
// Splits a number into its prime factors, in sorted order from
// least to greatest, with duplication. A very inefficient
// algorithm, which is mainly intended for use in the
// mixed-radix FFT computation (where we assume most factors
// are small).
KALDI_ASSERT(factors != NULL);
KALDI_ASSERT(m >= 1); // Doesn't work for zero or negative numbers.
factors->clear();
I small_factors[10] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29 };
// First try small factors.
for (I i = 0; i < 10; i++) {
if (m == 1) return; // We're done.
while (m % small_factors[i] == 0) {
m /= small_factors[i];
factors->push_back(small_factors[i]);
}
}
// Next try all odd numbers starting from 31.
for (I j = 31;; j += 2) {
if (m == 1) return;
while (m % j == 0) {
m /= j;
factors->push_back(j);
}
}
}
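// Example (illustrative): Factorize(360, &factors) fills factors with
// {2, 2, 2, 3, 3, 5}, since 360 = 2^3 * 3^2 * 5. Note that a composite trial
// divisor j (e.g. 33 or 35) can never divide the remaining m, because j's own
// prime factors were already divided out earlier, so the output really does
// contain only primes.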
inline double Hypot(double x, double y) { return hypot(x, y); }
inline float Hypot(float x, float y) { return hypotf(x, y); }
} // namespace kaldi
#endif // KALDI_BASE_KALDI_MATH_H_
// base/kaldi-types.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
// Jan Silovsky; Yanmin Qian
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_BASE_KALDI_TYPES_H_
#define KALDI_BASE_KALDI_TYPES_H_ 1
namespace kaldi {
// TYPEDEFS ..................................................................
#if (KALDI_DOUBLEPRECISION != 0)
typedef double BaseFloat;
#else
typedef float BaseFloat;
#endif
}
#ifdef _MSC_VER
#include <basetsd.h>
#define ssize_t SSIZE_T
#endif
// we can do this a different way if some platform
// we find in the future lacks stdint.h
#include <stdint.h>
// For discussion of what to do if you need to compile Kaldi
// without OpenFST, see the bottom of this file.
#include <fst/types.h>
namespace kaldi {
using ::int16;
using ::int32;
using ::int64;
using ::uint16;
using ::uint32;
using ::uint64;
typedef float float32;
typedef double double64;
} // end namespace kaldi
// In the (theoretical) case that you decide to compile Kaldi without OpenFST,
// comment out the previous namespace block and uncomment the following one.
/*
namespace kaldi {
typedef int8_t int8;
typedef int16_t int16;
typedef int32_t int32;
typedef int64_t int64;
typedef uint8_t uint8;
typedef uint16_t uint16;
typedef uint32_t uint32;
typedef uint64_t uint64;
typedef float float32;
typedef double double64;
} // end namespace kaldi
*/
#endif // KALDI_BASE_KALDI_TYPES_H_
// base/kaldi-utils.h
// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation;
// Saarland University; Karel Vesely; Yanmin Qian
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_BASE_KALDI_UTILS_H_
#define KALDI_BASE_KALDI_UTILS_H_ 1
#if defined(_MSC_VER)
# define WIN32_LEAN_AND_MEAN
# define NOMINMAX
# include <windows.h>
#endif
#ifdef _MSC_VER
#include <stdio.h>
#define unlink _unlink
#else
#include <unistd.h>
#endif
#include <limits>
#include <string>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4056 4305 4800 4267 4996 4756 4661)
#if _MSC_VER < 1400
#define __restrict__
#else
#define __restrict__ __restrict
#endif
#endif
#if defined(_MSC_VER)
# define KALDI_MEMALIGN(align, size, pp_orig) \
(*(pp_orig) = _aligned_malloc(size, align))
# define KALDI_MEMALIGN_FREE(x) _aligned_free(x)
#elif defined(__CYGWIN__)
# define KALDI_MEMALIGN(align, size, pp_orig) \
(*(pp_orig) = aligned_alloc(align, size))
# define KALDI_MEMALIGN_FREE(x) free(x)
#else
# define KALDI_MEMALIGN(align, size, pp_orig) \
(!posix_memalign(pp_orig, align, size) ? *(pp_orig) : NULL)
# define KALDI_MEMALIGN_FREE(x) free(x)
#endif
#ifdef __ICC
#pragma warning(disable: 383) // ICPC remark we don't want.
#pragma warning(disable: 810) // ICPC remark we don't want.
#pragma warning(disable: 981) // ICPC remark we don't want.
#pragma warning(disable: 1418) // ICPC remark we don't want.
#pragma warning(disable: 444) // ICPC remark we don't want.
#pragma warning(disable: 869) // ICPC remark we don't want.
#pragma warning(disable: 1287) // ICPC remark we don't want.
#pragma warning(disable: 279) // ICPC remark we don't want.
#pragma warning(disable: 981) // ICPC remark we don't want.
#endif
namespace kaldi {
// CharToString prints the character in a human-readable form, for debugging.
std::string CharToString(const char &c);
inline int MachineIsLittleEndian() {
int check = 1;
return (*reinterpret_cast<char*>(&check) != 0);
}
// This function kaldi::Sleep() provides a portable way
// to sleep for a possibly fractional
// number of seconds. On Windows it's only accurate to microseconds.
void Sleep(float seconds);
} // namespace kaldi
#define KALDI_SWAP8(a) do { \
int t = (reinterpret_cast<char*>(&a))[0];\
(reinterpret_cast<char*>(&a))[0]=(reinterpret_cast<char*>(&a))[7];\
(reinterpret_cast<char*>(&a))[7] = t;\
t = (reinterpret_cast<char*>(&a))[1];\
(reinterpret_cast<char*>(&a))[1]=(reinterpret_cast<char*>(&a))[6];\
(reinterpret_cast<char*>(&a))[6] = t;\
t = (reinterpret_cast<char*>(&a))[2];\
(reinterpret_cast<char*>(&a))[2]=(reinterpret_cast<char*>(&a))[5];\
(reinterpret_cast<char*>(&a))[5] = t;\
t = (reinterpret_cast<char*>(&a))[3];\
(reinterpret_cast<char*>(&a))[3]=(reinterpret_cast<char*>(&a))[4];\
(reinterpret_cast<char*>(&a))[4] = t;} while (0)
#define KALDI_SWAP4(a) do { \
int t = (reinterpret_cast<char*>(&a))[0];\
(reinterpret_cast<char*>(&a))[0]=(reinterpret_cast<char*>(&a))[3];\
(reinterpret_cast<char*>(&a))[3] = t;\
t = (reinterpret_cast<char*>(&a))[1];\
(reinterpret_cast<char*>(&a))[1]=(reinterpret_cast<char*>(&a))[2];\
(reinterpret_cast<char*>(&a))[2]=t;} while (0)
#define KALDI_SWAP2(a) do { \
int t = (reinterpret_cast<char*>(&a))[0];\
(reinterpret_cast<char*>(&a))[0]=(reinterpret_cast<char*>(&a))[1];\
(reinterpret_cast<char*>(&a))[1] = t;} while (0)
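// Example (illustrative): the swap macros reverse the byte order of a value
// in place, e.g.
//   uint32 v = 0x11223344;
//   KALDI_SWAP4(v);  // v is now 0x44332211
// Combined with MachineIsLittleEndian() above, this lets binary I/O code
// normalize data that was written on a machine of the opposite endianness.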
// Makes copy constructor and operator= private.
#define KALDI_DISALLOW_COPY_AND_ASSIGN(type) \
type(const type&); \
void operator = (const type&)
template<bool B> class KaldiCompileTimeAssert { };
template<> class KaldiCompileTimeAssert<true> {
public:
static inline void Check() { }
};
#define KALDI_COMPILE_TIME_ASSERT(b) KaldiCompileTimeAssert<(b)>::Check()
#define KALDI_ASSERT_IS_INTEGER_TYPE(I) \
KaldiCompileTimeAssert<std::numeric_limits<I>::is_specialized \
&& std::numeric_limits<I>::is_integer>::Check()
#define KALDI_ASSERT_IS_FLOATING_TYPE(F) \
KaldiCompileTimeAssert<std::numeric_limits<F>::is_specialized \
&& !std::numeric_limits<F>::is_integer>::Check()
#if defined(_MSC_VER)
#define KALDI_STRCASECMP _stricmp
#elif defined(__CYGWIN__)
#include <strings.h>
#define KALDI_STRCASECMP strcasecmp
#else
#define KALDI_STRCASECMP strcasecmp
#endif
#ifdef _MSC_VER
# define KALDI_STRTOLL(cur_cstr, end_cstr) _strtoi64(cur_cstr, end_cstr, 10);
#else
# define KALDI_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10);
#endif
#endif // KALDI_BASE_KALDI_UTILS_H_
// decoder/lattice-faster-decoder.cc
// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann
// 2013-2018 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2018 Zhehuai Chen
// 2021 Binbin Zhang, Zhendong Peng
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <unordered_set>
#include "decoder/lattice-faster-decoder.h"
// #include "lat/lattice-functions.h"
namespace kaldi {
// instantiate this class once for each thing you have to decode.
template <typename FST, typename Token>
LatticeFasterDecoderTpl<FST, Token>::LatticeFasterDecoderTpl(
const FST &fst, const LatticeFasterDecoderConfig &config,
const std::shared_ptr<wenet::ContextGraph> &context_graph)
: fst_(&fst),
delete_fst_(false),
config_(config),
num_toks_(0),
context_graph_(context_graph) {
config.Check();
toks_.SetSize(
1000); // just so on the first frame we do something reasonable.
}
template <typename FST, typename Token>
LatticeFasterDecoderTpl<FST, Token>::LatticeFasterDecoderTpl(
const LatticeFasterDecoderConfig &config, FST *fst)
: fst_(fst), delete_fst_(true), config_(config), num_toks_(0) {
config.Check();
toks_.SetSize(
1000); // just so on the first frame we do something reasonable.
}
template <typename FST, typename Token>
LatticeFasterDecoderTpl<FST, Token>::~LatticeFasterDecoderTpl() {
DeleteElems(toks_.Clear());
ClearActiveTokens();
if (delete_fst_) delete fst_;
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::InitDecoding() {
// clean up from last time:
DeleteElems(toks_.Clear());
cost_offsets_.clear();
ClearActiveTokens();
warned_ = false;
num_toks_ = 0;
decoding_finalized_ = false;
final_costs_.clear();
StateId start_state = fst_->Start();
KALDI_ASSERT(start_state != fst::kNoStateId);
active_toks_.resize(1);
Token *start_tok = new Token(0.0, 0.0, NULL, NULL, NULL);
active_toks_[0].toks = start_tok;
toks_.Insert(start_state, start_tok);
num_toks_++;
ProcessNonemitting(config_.beam);
}
// Returns true if any kind of traceback is available (not necessarily from
// a final state). It should only very rarely return false; this indicates
// an unusual search error.
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::Decode(
DecodableInterface *decodable) {
InitDecoding();
// We use 1-based indexing for frames in this decoder (if you view it in
// terms of features), but note that the decodable object uses zero-based
// numbering, which we have to correct for when we call it.
AdvanceDecoding(decodable);
FinalizeDecoding();
// Returns true if we have any kind of traceback available (not necessarily
// to the end state; query ReachedFinal() for that).
return !active_toks_.empty() && active_toks_.back().toks != NULL;
}
// Outputs an FST corresponding to the single best path through the lattice.
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::GetBestPath(
Lattice *olat, bool use_final_probs) const {
Lattice raw_lat;
GetRawLattice(&raw_lat, use_final_probs);
ShortestPath(raw_lat, olat);
return (olat->NumStates() != 0);
}
// Outputs an FST corresponding to the raw, state-level lattice
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::GetRawLattice(
Lattice *ofst, bool use_final_probs) const {
typedef LatticeArc Arc;
typedef Arc::StateId StateId;
typedef Arc::Weight Weight;
typedef Arc::Label Label;
// Note: you can't use the old interface (Decode()) if you want to
// get the lattice with use_final_probs = false. You'd have to do
// InitDecoding() and then AdvanceDecoding().
if (decoding_finalized_ && !use_final_probs)
KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
<< "GetRawLattice() with use_final_probs == false";
unordered_map<Token *, BaseFloat> final_costs_local;
const unordered_map<Token *, BaseFloat> &final_costs =
(decoding_finalized_ ? final_costs_ : final_costs_local);
if (!decoding_finalized_ && use_final_probs)
ComputeFinalCosts(&final_costs_local, NULL, NULL);
ofst->DeleteStates();
  // active_toks_.size() is num-frames plus one (frames are one-based, and
  // there is an extra entry for the start-state), hence the -1 below.
int32 num_frames = active_toks_.size() - 1;
KALDI_ASSERT(num_frames > 0);
const int32 bucket_count = num_toks_ / 2 + 3;
unordered_map<Token *, StateId> tok_map(bucket_count);
// First create all states.
std::vector<Token *> token_list;
for (int32 f = 0; f <= num_frames; f++) {
if (active_toks_[f].toks == NULL) {
KALDI_WARN << "GetRawLattice: no tokens active on frame " << f
<< ": not producing lattice.\n";
return false;
}
TopSortTokens(active_toks_[f].toks, &token_list);
for (size_t i = 0; i < token_list.size(); i++)
if (token_list[i] != NULL) tok_map[token_list[i]] = ofst->AddState();
}
// The next statement sets the start state of the output FST. Because we
// topologically sorted the tokens, state zero must be the start-state.
ofst->SetStart(0);
KALDI_VLOG(4) << "init:" << num_toks_ / 2 + 3
<< " buckets:" << tok_map.bucket_count()
<< " load:" << tok_map.load_factor()
<< " max:" << tok_map.max_load_factor();
// Now create all arcs.
for (int32 f = 0; f <= num_frames; f++) {
for (Token *tok = active_toks_[f].toks; tok != NULL; tok = tok->next) {
StateId cur_state = tok_map[tok];
for (ForwardLinkT *l = tok->links; l != NULL; l = l->next) {
        typename unordered_map<Token *, StateId>::const_iterator iter =
            tok_map.find(l->next_tok);
        KALDI_ASSERT(iter != tok_map.end());
        StateId nextstate = iter->second;
BaseFloat cost_offset = 0.0;
if (l->ilabel != 0) { // emitting..
KALDI_ASSERT(f >= 0 && f < cost_offsets_.size());
cost_offset = cost_offsets_[f];
}
StateId state = cur_state;
if (l->is_start_boundary) {
StateId tmp = ofst->AddState();
Arc arc(0, context_graph_->start_tag_id(), Weight(0, 0), tmp);
ofst->AddArc(state, arc);
state = tmp;
}
if (l->is_end_boundary) {
StateId tmp = ofst->AddState();
Arc arc(0, context_graph_->end_tag_id(), Weight(0, 0), nextstate);
ofst->AddArc(tmp, arc);
nextstate = tmp;
}
Arc arc(l->ilabel, l->olabel,
Weight(l->graph_cost, l->acoustic_cost - cost_offset),
nextstate);
ofst->AddArc(state, arc);
}
if (f == num_frames) {
if (use_final_probs && !final_costs.empty()) {
typename unordered_map<Token *, BaseFloat>::const_iterator iter =
final_costs.find(tok);
if (iter != final_costs.end())
ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
} else {
ofst->SetFinal(cur_state, LatticeWeight::One());
}
}
}
}
fst::TopSort(ofst);
return (ofst->NumStates() > 0);
}
// This function is now deprecated, since we now do determinization outside
// the LatticeFasterDecoder class. Outputs an FST corresponding to the
// lattice-determinized lattice (one path per word sequence).
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::GetLattice(
CompactLattice *ofst, bool use_final_probs) const {
Lattice raw_fst;
GetRawLattice(&raw_fst, use_final_probs);
Invert(&raw_fst); // make it so word labels are on the input.
// (in phase where we get backward-costs).
fst::ILabelCompare<LatticeArc> ilabel_comp;
ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes
// lattice-determinization more efficient.
fst::DeterminizeLatticePrunedOptions lat_opts;
lat_opts.max_mem = config_.det_opts.max_mem;
DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts);
raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed.
Connect(ofst); // Remove unreachable states... there might be
// a small number of these, in some cases.
// Note: if something went wrong and the raw lattice was empty,
// we should still get to this point in the code without warnings or failures.
return (ofst->NumStates() != 0);
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PossiblyResizeHash(size_t num_toks) {
size_t new_sz = static_cast<size_t>(static_cast<BaseFloat>(num_toks) *
config_.hash_ratio);
if (new_sz > toks_.Size()) {
toks_.SetSize(new_sz);
}
}
/*
A note on the definition of extra_cost.
extra_cost is used in pruning tokens, to save memory.
extra_cost can be thought of as a beta (backward) cost assuming
we had set the betas on currently-active tokens to all be the negative
of the alphas for those tokens. (So all currently active tokens would
be on (tied) best paths).
We can use the extra_cost to accurately prune away tokens that we know will
never appear in the lattice. If the extra_cost is greater than the desired
lattice beam, the token would provably never appear in the lattice, so we can
prune away the token.
(Note: we don't update all the extra_costs every time we update a frame; we
only do it every 'config_.prune_interval' frames).
*/
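/*
  A worked example (illustrative, with made-up numbers): suppose
  lattice_beam = 8.0 and the best complete path through the current frames
  has cost 100.0. A token whose best continuation has cost 105.0 gets
  extra_cost = 5.0 and must be kept, since it could still appear in the
  lattice; a token with extra_cost = 9.0 > lattice_beam provably cannot lie
  on any lattice path within the beam and can be pruned away.
*/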
// FindOrAddToken either locates a token in the hash map toks_, or if
// necessary inserts a new, empty token (i.e. with no forward links)
// for the current frame. [Note: it is inserted if necessary into the hash
// toks_ and also into the singly linked list of tokens active on this frame
// (whose head is at active_toks_[frame]).]
template <typename FST, typename Token>
inline typename LatticeFasterDecoderTpl<FST, Token>::Elem *
LatticeFasterDecoderTpl<FST, Token>::FindOrAddToken(StateId state,
int32 frame_plus_one,
BaseFloat tot_cost,
Token *backpointer,
bool *changed) {
// Returns the Token pointer. Sets "changed" (if non-NULL) to true
// if the token was newly created or the cost changed.
KALDI_ASSERT(frame_plus_one < active_toks_.size());
Token *&toks = active_toks_[frame_plus_one].toks;
Elem *e_found = toks_.Insert(state, NULL);
if (e_found->val == NULL) { // no such token presently.
const BaseFloat extra_cost = 0.0;
// tokens on the currently final frame have zero extra_cost
// as any of them could end up
// on the winning path.
Token *new_tok = new Token(tot_cost, extra_cost, NULL, toks, backpointer);
// NULL: no forward links yet
toks = new_tok;
num_toks_++;
e_found->val = new_tok;
if (changed) *changed = true;
return e_found;
} else {
Token *tok = e_found->val; // There is an existing Token for this state.
if (tok->tot_cost > tot_cost) { // replace old token
tok->tot_cost = tot_cost;
// SetBackpointer() just does tok->backpointer = backpointer in
// the case where Token == BackpointerToken, else nothing.
tok->SetBackpointer(backpointer);
// we don't allocate a new token, the old stays linked in active_toks_
// we only replace the tot_cost
// in the current frame, there are no forward links (and no extra_cost)
// only in ProcessNonemitting we have to delete forward links
// in case we visit a state for the second time
// those forward links, that lead to this replaced token before:
// they remain and will hopefully be pruned later (PruneForwardLinks...)
if (changed) *changed = true;
} else {
if (changed) *changed = false;
}
return e_found;
}
}
// prunes outgoing links for all tokens in active_toks_[frame]
// it's called by PruneActiveTokens
// all links, that have link_extra_cost > lattice_beam are pruned
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneForwardLinks(
int32 frame_plus_one, bool *extra_costs_changed, bool *links_pruned,
BaseFloat delta) {
  // delta is the amount by which the extra_costs must change before we
  // consider them as having changed. If delta is larger, we'll tend to go
  // back less far toward the beginning of the file.
// extra_costs_changed is set to true if extra_cost was changed for any token
// links_pruned is set to true if any link in any token was pruned
*extra_costs_changed = false;
*links_pruned = false;
KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
if (active_toks_[frame_plus_one].toks ==
NULL) { // empty list; should not happen.
if (!warned_) {
KALDI_WARN << "No tokens alive [doing pruning].. warning first "
"time only for each utterance\n";
warned_ = true;
}
}
// We have to iterate until there is no more change, because the links
// are not guaranteed to be in topological order.
bool changed = true; // difference new minus old extra cost >= delta ?
while (changed) {
changed = false;
for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL;
tok = tok->next) {
ForwardLinkT *link, *prev_link = NULL;
// will recompute tok_extra_cost for tok.
BaseFloat tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
// tok_extra_cost is the best (min) of link_extra_cost of outgoing links
for (link = tok->links; link != NULL;) {
// See if we need to excise this link...
Token *next_tok = link->next_tok;
BaseFloat link_extra_cost =
next_tok->extra_cost +
((tok->tot_cost + link->acoustic_cost + link->graph_cost) -
next_tok->tot_cost); // difference in brackets is >= 0
        // link_extra_cost is the difference in score between the best paths
        // through link source state and through link destination state
KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN
        // The graph_cost contains the context score; if it is the (negative)
        // score of a backoff arc, it should be removed for pruning purposes.
if (link->context_score < 0) {
link_extra_cost += link->context_score;
}
if (link_extra_cost > config_.lattice_beam) { // excise link
ForwardLinkT *next_link = link->next;
if (prev_link != NULL)
prev_link->next = next_link;
else
tok->links = next_link;
delete link;
link = next_link; // advance link but leave prev_link the same.
*links_pruned = true;
} else { // keep the link and update the tok_extra_cost if needed.
if (link_extra_cost < 0.0) { // this is just a precaution.
// if (link_extra_cost < -0.01)
// KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
link_extra_cost = 0.0;
}
if (link_extra_cost < tok_extra_cost)
tok_extra_cost = link_extra_cost;
prev_link = link; // move to next link
link = link->next;
}
} // for all outgoing links
if (fabs(tok_extra_cost - tok->extra_cost) > delta)
changed = true; // difference new minus old is bigger than delta
tok->extra_cost = tok_extra_cost;
// will be +infinity or <= lattice_beam_.
// infinity indicates, that no forward link survived pruning
} // for all Token on active_toks_[frame]
if (changed) *extra_costs_changed = true;
// Note: it's theoretically possible that aggressive compiler
// optimizations could cause an infinite loop here for small delta and
// high-dynamic-range scores.
} // while changed
}
// PruneForwardLinksFinal is a version of PruneForwardLinks that we call
// on the final frame. If there are final tokens active, it uses
// the final-probs for pruning, otherwise it treats all tokens as final.
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneForwardLinksFinal() {
KALDI_ASSERT(!active_toks_.empty());
int32 frame_plus_one = active_toks_.size() - 1;
if (active_toks_[frame_plus_one].toks ==
NULL) // empty list; should not happen.
KALDI_WARN << "No tokens alive at end of file";
typedef typename unordered_map<Token *, BaseFloat>::const_iterator IterType;
ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_);
decoding_finalized_ = true;
// We call DeleteElems() as a nicety, not because it's really necessary;
// otherwise there would be a time, after calling PruneTokensForFrame() on the
// final frame, when toks_.GetList() or toks_.Clear() would contain pointers
// to nonexistent tokens.
DeleteElems(toks_.Clear());
// Now go through tokens on this frame, pruning forward links... may have to
// iterate a few times until there is no more change, because the list is not
// in topological order. This is a modified version of the code in
// PruneForwardLinks, but here we also take account of the final-probs.
bool changed = true;
BaseFloat delta = 1.0e-05;
while (changed) {
changed = false;
for (Token *tok = active_toks_[frame_plus_one].toks; tok != NULL;
tok = tok->next) {
ForwardLinkT *link, *prev_link = NULL;
// will recompute tok_extra_cost. It has a term in it that corresponds
// to the "final-prob", so instead of initializing tok_extra_cost to
// infinity below we set it to the difference between the
// (score+final_prob) of this token, and the best such (score+final_prob).
BaseFloat final_cost;
if (final_costs_.empty()) {
final_cost = 0.0;
} else {
IterType iter = final_costs_.find(tok);
if (iter != final_costs_.end())
final_cost = iter->second;
else
final_cost = std::numeric_limits<BaseFloat>::infinity();
}
BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_;
// tok_extra_cost will be a "min" over either directly being final, or
// being indirectly final through other links, and the loop below may
// decrease its value:
for (link = tok->links; link != NULL;) {
// See if we need to excise this link...
Token *next_tok = link->next_tok;
BaseFloat link_extra_cost =
next_tok->extra_cost +
((tok->tot_cost + link->acoustic_cost + link->graph_cost) -
next_tok->tot_cost);
if (link_extra_cost > config_.lattice_beam) { // excise link
ForwardLinkT *next_link = link->next;
if (prev_link != NULL)
prev_link->next = next_link;
else
tok->links = next_link;
delete link;
link = next_link; // advance link but leave prev_link the same.
} else { // keep the link and update the tok_extra_cost if needed.
if (link_extra_cost < 0.0) { // this is just a precaution.
// if (link_extra_cost < -0.01)
// KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
link_extra_cost = 0.0;
}
if (link_extra_cost < tok_extra_cost)
tok_extra_cost = link_extra_cost;
prev_link = link;
link = link->next;
}
}
// prune away tokens worse than lattice_beam above best path. This step
// was not necessary in the non-final case because then, this case
// showed up as having no forward links. Here, the tok_extra_cost has
// an extra component relating to the final-prob.
if (tok_extra_cost > config_.lattice_beam)
tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
// to be pruned in PruneTokensForFrame
if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) changed = true;
tok->extra_cost =
tok_extra_cost; // will be +infinity or <= lattice_beam_.
}
} // while changed
}
template <typename FST, typename Token>
BaseFloat LatticeFasterDecoderTpl<FST, Token>::FinalRelativeCost() const {
if (!decoding_finalized_) {
BaseFloat relative_cost;
ComputeFinalCosts(NULL, &relative_cost, NULL);
return relative_cost;
} else {
// we're not allowed to call that function if FinalizeDecoding() has
// been called; return a cached value.
return final_relative_cost_;
}
}
// Prune away any tokens on this frame that have no forward links.
// [we don't do this in PruneForwardLinks because it would give us
// a problem with dangling pointers].
// It's called by PruneActiveTokens if any forward links have been pruned
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneTokensForFrame(
int32 frame_plus_one) {
KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
Token *&toks = active_toks_[frame_plus_one].toks;
if (toks == NULL) KALDI_WARN << "No tokens alive [doing pruning]";
Token *tok, *next_tok, *prev_tok = NULL;
for (tok = toks; tok != NULL; tok = next_tok) {
next_tok = tok->next;
if (tok->extra_cost == std::numeric_limits<BaseFloat>::infinity()) {
// token is unreachable from end of graph; (no forward links survived)
// excise tok from list and delete tok.
if (prev_tok != NULL)
prev_tok->next = tok->next;
else
toks = tok->next;
delete tok;
num_toks_--;
} else { // fetch next Token
prev_tok = tok;
}
}
}
// Go backwards through still-alive tokens, pruning them, starting not from
// the current frame (where we want to keep all tokens) but from the frame
// before that. We go backwards through the frames and stop when we reach a
// point where the delta-costs are not changing (and the delta controls when we
// consider a cost to have "not changed").
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneActiveTokens(BaseFloat delta) {
int32 cur_frame_plus_one = NumFramesDecoded();
int32 num_toks_begin = num_toks_;
// The index "f" below represents a "frame plus one", i.e. you'd have to
// subtract one to get the corresponding index for the decodable object.
for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) {
// Reason why we need to prune forward links in this situation:
// (1) we have never pruned them (new TokenList)
// (2) we have not yet pruned the forward links to the next f,
// after any of those tokens have changed their extra_cost.
if (active_toks_[f].must_prune_forward_links) {
bool extra_costs_changed = false, links_pruned = false;
PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta);
if (extra_costs_changed && f > 0) // any token has changed extra_cost
active_toks_[f - 1].must_prune_forward_links = true;
if (links_pruned) // any link was pruned
active_toks_[f].must_prune_tokens = true;
active_toks_[f].must_prune_forward_links = false; // job done
}
if (f + 1 < cur_frame_plus_one && // except for last f (no forward links)
active_toks_[f + 1].must_prune_tokens) {
PruneTokensForFrame(f + 1);
active_toks_[f + 1].must_prune_tokens = false;
}
}
KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin
<< " to " << num_toks_;
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::ComputeFinalCosts(
unordered_map<Token *, BaseFloat> *final_costs,
BaseFloat *final_relative_cost, BaseFloat *final_best_cost) const {
KALDI_ASSERT(!decoding_finalized_);
if (final_costs != NULL) final_costs->clear();
const Elem *final_toks = toks_.GetList();
BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
BaseFloat best_cost = infinity, best_cost_with_final = infinity;
while (final_toks != NULL) {
StateId state = final_toks->key;
Token *tok = final_toks->val;
const Elem *next = final_toks->tail;
BaseFloat final_cost = fst_->Final(state).Value();
BaseFloat cost = tok->tot_cost, cost_with_final = cost + final_cost;
best_cost = std::min(cost, best_cost);
best_cost_with_final = std::min(cost_with_final, best_cost_with_final);
if (final_costs != NULL && final_cost != infinity)
(*final_costs)[tok] = final_cost;
final_toks = next;
}
if (final_relative_cost != NULL) {
if (best_cost == infinity && best_cost_with_final == infinity) {
// Likely this will only happen if there are no tokens surviving.
// This seems the least bad way to handle it.
*final_relative_cost = infinity;
} else {
*final_relative_cost = best_cost_with_final - best_cost;
}
}
if (final_best_cost != NULL) {
if (best_cost_with_final != infinity) { // final-state exists.
*final_best_cost = best_cost_with_final;
} else { // no final-state exists.
*final_best_cost = best_cost;
}
}
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::AdvanceDecoding(
DecodableInterface *decodable, int32 max_num_frames) {
if (std::is_same<FST, fst::Fst<fst::StdArc> >::value) {
// if the type 'FST' is the FST base-class, then see if the FST type of fst_
// is actually VectorFst or ConstFst. If so, call the AdvanceDecoding()
// function after casting *this to the more specific type.
if (fst_->Type() == "const") {
LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, Token> *this_cast =
reinterpret_cast<
LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, Token> *>(
this);
this_cast->AdvanceDecoding(decodable, max_num_frames);
return;
} else if (fst_->Type() == "vector") {
LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, Token> *this_cast =
reinterpret_cast<
LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, Token> *>(
this);
this_cast->AdvanceDecoding(decodable, max_num_frames);
return;
}
}
KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ &&
"You must call InitDecoding() before AdvanceDecoding");
int32 num_frames_ready = decodable->NumFramesReady();
// num_frames_ready must be >= num_frames_decoded, or else
// the number of frames ready must have decreased (which doesn't
// make sense) or the decodable object changed between calls
// (which isn't allowed).
KALDI_ASSERT(num_frames_ready >= NumFramesDecoded());
int32 target_frames_decoded = num_frames_ready;
if (max_num_frames >= 0)
target_frames_decoded =
std::min(target_frames_decoded, NumFramesDecoded() + max_num_frames);
while (NumFramesDecoded() < target_frames_decoded) {
if (NumFramesDecoded() % config_.prune_interval == 0) {
PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
}
BaseFloat cost_cutoff = ProcessEmitting(decodable);
ProcessNonemitting(cost_cutoff);
}
}
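// A minimal streaming-usage sketch (hypothetical; the decodable type and its
// feeding helpers below are assumptions for illustration, not part of this
// file):
//
//   LatticeFasterDecoderConfig config;
//   LatticeFasterDecoderTpl<fst::StdFst> decoder(fst, config, context_graph);
//   decoder.InitDecoding();
//   while (MoreAudioAvailable()) {        // hypothetical input loop
//     FeedFramesToDecodable(&decodable);  // hypothetical helper
//     decoder.AdvanceDecoding(&decodable);
//   }
//   decoder.FinalizeDecoding();
//   Lattice best_path;
//   decoder.GetBestPath(&best_path);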
// FinalizeDecoding() is a version of PruneActiveTokens that we call
// (optionally) on the final frame. Takes into account the final-prob of
// tokens. This function used to be called PruneActiveTokensFinal().
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::FinalizeDecoding() {
int32 final_frame_plus_one = NumFramesDecoded();
int32 num_toks_begin = num_toks_;
// PruneForwardLinksFinal() prunes final frame (with final-probs), and
// sets decoding_finalized_.
PruneForwardLinksFinal();
for (int32 f = final_frame_plus_one - 1; f >= 0; f--) {
bool b1, b2; // values not used.
BaseFloat dontcare = 0.0; // delta of zero means we must always update
PruneForwardLinks(f, &b1, &b2, dontcare);
PruneTokensForFrame(f + 1);
}
PruneTokensForFrame(0);
KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin << " to "
<< num_toks_;
}
/// Gets the weight cutoff. Also counts the active tokens.
template <typename FST, typename Token>
BaseFloat LatticeFasterDecoderTpl<FST, Token>::GetCutoff(
Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam,
Elem **best_elem) {
BaseFloat best_weight = std::numeric_limits<BaseFloat>::infinity();
// positive == high cost == bad.
size_t count = 0;
if (config_.max_active == std::numeric_limits<int32>::max() &&
config_.min_active == 0) {
for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
BaseFloat w = static_cast<BaseFloat>(e->val->tot_cost);
if (w < best_weight) {
best_weight = w;
if (best_elem) *best_elem = e;
}
}
if (tok_count != NULL) *tok_count = count;
if (adaptive_beam != NULL) *adaptive_beam = config_.beam;
return best_weight + config_.beam;
} else {
tmp_array_.clear();
for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
BaseFloat w = e->val->tot_cost;
tmp_array_.push_back(w);
if (w < best_weight) {
best_weight = w;
if (best_elem) *best_elem = e;
}
}
if (tok_count != NULL) *tok_count = count;
BaseFloat beam_cutoff = best_weight + config_.beam,
min_active_cutoff = std::numeric_limits<BaseFloat>::infinity(),
max_active_cutoff = std::numeric_limits<BaseFloat>::infinity();
KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded()
<< " is " << tmp_array_.size();
if (tmp_array_.size() > static_cast<size_t>(config_.max_active)) {
std::nth_element(tmp_array_.begin(),
tmp_array_.begin() + config_.max_active,
tmp_array_.end());
max_active_cutoff = tmp_array_[config_.max_active];
}
if (max_active_cutoff < beam_cutoff) { // max_active is tighter than beam.
if (adaptive_beam)
*adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta;
return max_active_cutoff;
}
if (tmp_array_.size() > static_cast<size_t>(config_.min_active)) {
if (config_.min_active == 0) {
min_active_cutoff = best_weight;
} else {
std::nth_element(
tmp_array_.begin(), tmp_array_.begin() + config_.min_active,
tmp_array_.size() > static_cast<size_t>(config_.max_active)
? tmp_array_.begin() + config_.max_active
: tmp_array_.end());
min_active_cutoff = tmp_array_[config_.min_active];
}
}
if (min_active_cutoff > beam_cutoff) { // min_active is looser than beam.
if (adaptive_beam)
*adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta;
return min_active_cutoff;
    } else {
      if (adaptive_beam) *adaptive_beam = config_.beam;
      return beam_cutoff;
}
}
}
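// A worked example of the adaptive-beam logic above (illustrative): with
// beam = 16, beam_delta = 0.5, max_active = 2 (and min_active = 0), and token
// costs {10, 11, 12, 30}, the plain beam cutoff would be 10 + 16 = 26, but
// the max_active cutoff (the cost of the 3rd-best token) is 12, which is
// tighter. GetCutoff() then returns 12 and sets the adaptive beam to
// 12 - 10 + 0.5 = 2.5, so the next frame keeps roughly max_active tokens.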
template <typename FST, typename Token>
BaseFloat LatticeFasterDecoderTpl<FST, Token>::ProcessEmitting(
DecodableInterface *decodable) {
KALDI_ASSERT(active_toks_.size() > 0);
int32 frame =
active_toks_.size() - 1; // frame is the frame-index
// (zero-based) used to get likelihoods
// from the decodable object.
active_toks_.resize(active_toks_.size() + 1);
Elem *final_toks =
toks_.Clear(); // analogous to swapping prev_toks_ / cur_toks_
// in simple-decoder.h. Removes the Elems from
// being indexed in the hash in toks_.
Elem *best_elem = NULL;
BaseFloat adaptive_beam;
size_t tok_cnt;
BaseFloat cur_cutoff =
GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem);
KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is "
<< adaptive_beam;
PossiblyResizeHash(
tok_cnt); // This makes sure the hash is always big enough.
BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
// pruning "online" before having seen all tokens
BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good
// dynamic range.
// First process the best token to get a hopefully
// reasonably tight bound on the next cutoff. The only
// products of the next block are "next_cutoff" and "cost_offset".
if (best_elem) {
StateId state = best_elem->key;
Token *tok = best_elem->val;
cost_offset = -tok->tot_cost;
for (fst::ArcIterator<FST> aiter(*fst_, state); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel != 0) { // propagate..
BaseFloat new_weight = arc.weight.Value() + cost_offset -
decodable->LogLikelihood(frame, arc.ilabel) +
tok->tot_cost;
if (state != arc.nextstate) {
new_weight += config_.length_penalty;
}
if (new_weight + adaptive_beam < next_cutoff)
next_cutoff = new_weight + adaptive_beam;
}
}
}
// Store the offset on the acoustic likelihoods that we're applying.
// Could just do cost_offsets_.push_back(cost_offset), but we
// do it this way as it's more robust to future code changes.
cost_offsets_.resize(frame + 1, 0.0);
cost_offsets_[frame] = cost_offset;
// the tokens are now owned here, in final_toks, and the hash is empty.
// 'owned' is a complex thing here; the point is we need to call DeleteElem
// on each elem 'e' to let toks_ know we're done with them.
for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) {
// loop this way because we delete "e" as we go.
StateId state = e->key;
Token *tok = e->val;
if (tok->tot_cost <= cur_cutoff) {
for (fst::ArcIterator<FST> aiter(*fst_, state); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel != 0) { // propagate..
BaseFloat ac_cost = cost_offset -
decodable->LogLikelihood(frame, arc.ilabel),
graph_cost = arc.weight.Value();
if (state != arc.nextstate) {
graph_cost += config_.length_penalty;
}
BaseFloat cur_cost = tok->tot_cost,
tot_cost = cur_cost + ac_cost + graph_cost;
if (tot_cost >= next_cutoff)
continue;
else if (tot_cost + adaptive_beam < next_cutoff)
next_cutoff =
tot_cost + adaptive_beam; // prune by best current token
// Note: the frame indexes into active_toks_ are one-based,
// hence the + 1.
Elem *e_next =
FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, NULL);
// NULL: no change indicator needed
bool is_start_boundary = false;
bool is_end_boundary = false;
float context_score = 0;
if (context_graph_) {
if (arc.olabel == 0) {
e_next->val->context_state = tok->context_state;
} else {
e_next->val->context_state = context_graph_->GetNextState(
tok->context_state, arc.olabel, &context_score,
&is_start_boundary, &is_end_boundary);
graph_cost -= context_score;
}
}
// Add ForwardLink from tok to next_tok (put on head of list
// tok->links)
tok->links = new ForwardLinkT(e_next->val, arc.ilabel, arc.olabel,
graph_cost, ac_cost, is_start_boundary,
is_end_boundary, tok->links);
tok->links->context_score = context_score;
}
} // for all arcs
}
e_tail = e->tail;
toks_.Delete(e); // delete Elem
}
return next_cutoff;
}
// static inline
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::DeleteForwardLinks(Token *tok) {
ForwardLinkT *l = tok->links, *m;
while (l != NULL) {
m = l->next;
delete l;
l = m;
}
tok->links = NULL;
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::ProcessNonemitting(BaseFloat cutoff) {
KALDI_ASSERT(!active_toks_.empty());
int32 frame = static_cast<int32>(active_toks_.size()) - 2;
// Note: "frame" is the time-index we just processed, or -1 if
// we are processing the nonemitting transitions before the
// first frame (called from InitDecoding()).
// Processes nonemitting arcs for one frame. Propagates within toks_.
// Note-- this queue structure is not very optimal as
// it may cause us to process states unnecessarily (e.g. more than once),
// but in the baseline code, turning this vector into a set to fix this
// problem did not improve overall speed.
KALDI_ASSERT(queue_.empty());
if (toks_.GetList() == NULL) {
if (!warned_) {
KALDI_WARN << "Error, no surviving tokens: frame is " << frame;
warned_ = true;
}
}
int before = 0, after = 0;
for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
StateId state = e->key;
if (fst_->NumInputEpsilons(state) != 0) queue_.push_back(e);
++before;
}
while (!queue_.empty()) {
++after;
const Elem *e = queue_.back();
queue_.pop_back();
StateId state = e->key;
Token *tok =
e->val; // would segfault if e is a NULL pointer but this can't happen.
BaseFloat cur_cost = tok->tot_cost;
if (cur_cost >= cutoff) // Don't bother processing successors.
continue;
// If "tok" has any existing forward links, delete them,
// because we're about to regenerate them. This is a kind
// of non-optimality (remember, this is the simple decoder),
// but since most states are emitting it's not a huge issue.
DeleteForwardLinks(tok); // necessary when re-visiting
tok->links = NULL;
for (fst::ArcIterator<FST> aiter(*fst_, state); !aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel == 0) { // propagate nonemitting only...
BaseFloat graph_cost = arc.weight.Value(),
tot_cost = cur_cost + graph_cost;
if (tot_cost < cutoff) {
bool changed;
Elem *e_new =
FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, &changed);
bool is_start_boundary = false;
bool is_end_boundary = false;
float context_score = 0;
if (context_graph_) {
if (arc.olabel == 0) {
e_new->val->context_state = tok->context_state;
} else {
e_new->val->context_state = context_graph_->GetNextState(
tok->context_state, arc.olabel, &context_score,
&is_start_boundary, &is_end_boundary);
graph_cost -= context_score;
}
}
tok->links =
new ForwardLinkT(e_new->val, 0, arc.olabel, graph_cost, 0,
is_start_boundary, is_end_boundary, tok->links);
tok->links->context_score = context_score;
// "changed" tells us whether the new token has a different
// cost from before, or is new [if so, add into queue].
if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0)
queue_.push_back(e_new);
}
}
} // for all arcs
} // while queue not empty
KALDI_VLOG(3) << "ProcessNonemitting " << before << " " << after;
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::DeleteElems(Elem *list) {
for (Elem *e = list, *e_tail; e != NULL; e = e_tail) {
e_tail = e->tail;
toks_.Delete(e);
}
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<
FST, Token>::ClearActiveTokens() { // a cleanup routine, at utt end/begin
for (size_t i = 0; i < active_toks_.size(); i++) {
// Delete all tokens alive on this frame, and any forward
// links they may have.
for (Token *tok = active_toks_[i].toks; tok != NULL;) {
DeleteForwardLinks(tok);
Token *next_tok = tok->next;
delete tok;
num_toks_--;
tok = next_tok;
}
}
active_toks_.clear();
KALDI_ASSERT(num_toks_ == 0);
}
// static
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::TopSortTokens(
Token *tok_list, std::vector<Token *> *topsorted_list) {
unordered_map<Token *, int32> token2pos;
using std::unordered_set;
typedef typename unordered_map<Token *, int32>::iterator IterType;
int32 num_toks = 0;
for (Token *tok = tok_list; tok != NULL; tok = tok->next) num_toks++;
int32 cur_pos = 0;
// We assign the tokens numbers num_toks - 1, ... , 2, 1, 0.
  // This is likely to be closer to topological order than
// if we had given them ascending order, because of the way
// new tokens are put at the front of the list.
for (Token *tok = tok_list; tok != NULL; tok = tok->next)
token2pos[tok] = num_toks - ++cur_pos;
unordered_set<Token *> reprocess;
for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) {
Token *tok = iter->first;
int32 pos = iter->second;
for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
if (link->ilabel == 0) {
// We only need to consider epsilon links, since non-epsilon links
// transition between frames and this function only needs to sort a list
// of tokens from a single frame.
IterType following_iter = token2pos.find(link->next_tok);
if (following_iter != token2pos.end()) { // another token on this
// frame, so must consider it.
int32 next_pos = following_iter->second;
if (next_pos < pos) { // reassign the position of the next Token.
following_iter->second = cur_pos++;
reprocess.insert(link->next_tok);
}
}
}
}
// In case we had previously assigned this token to be reprocessed, we can
// erase it from that set because it's "happy now" (we just processed it).
reprocess.erase(tok);
}
size_t max_loop = 1000000,
loop_count; // max_loop is to detect epsilon cycles.
for (loop_count = 0; !reprocess.empty() && loop_count < max_loop;
++loop_count) {
std::vector<Token *> reprocess_vec;
for (typename unordered_set<Token *>::iterator iter = reprocess.begin();
iter != reprocess.end(); ++iter)
reprocess_vec.push_back(*iter);
reprocess.clear();
for (typename std::vector<Token *>::iterator iter = reprocess_vec.begin();
iter != reprocess_vec.end(); ++iter) {
Token *tok = *iter;
int32 pos = token2pos[tok];
// Repeat the processing we did above (for comments, see above).
for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
if (link->ilabel == 0) {
IterType following_iter = token2pos.find(link->next_tok);
if (following_iter != token2pos.end()) {
int32 next_pos = following_iter->second;
if (next_pos < pos) {
following_iter->second = cur_pos++;
reprocess.insert(link->next_tok);
}
}
}
}
}
}
KALDI_ASSERT(loop_count < max_loop &&
"Epsilon loops exist in your decoding "
"graph (this is not allowed!)");
topsorted_list->clear();
topsorted_list->resize(cur_pos,
NULL); // create a list with NULLs in between.
for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter)
(*topsorted_list)[iter->second] = iter->first;
}
// Instantiate the template for the combination of token types and FST types
// that we'll need.
template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc>,
decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>,
decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>,
decoder::StdToken>;
// template class LatticeFasterDecoderTpl<fst::ConstGrammarFst,
//                                        decoder::StdToken>;
// template class LatticeFasterDecoderTpl<fst::VectorGrammarFst,
//                                        decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc>,
decoder::BackpointerToken>;
template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>,
decoder::BackpointerToken>;
template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>,
decoder::BackpointerToken>;
// template class LatticeFasterDecoderTpl<fst::ConstGrammarFst,
//                                        decoder::BackpointerToken>;
// template class LatticeFasterDecoderTpl<fst::VectorGrammarFst,
//                                        decoder::BackpointerToken>;
} // end namespace kaldi.
// decoder/lattice-faster-decoder.h
// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2018 Zhehuai Chen
// 2021 Binbin Zhang, Zhendong Peng
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_
#define KALDI_DECODER_LATTICE_FASTER_DECODER_H_
#include <limits>
#include <memory>
#include <unordered_map>
#include <vector>
#include "base/kaldi-common.h"
#include "decoder/context_graph.h"
#include "fst/fstlib.h"
#include "fstext/fstext-lib.h"
#include "itf/decodable-itf.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"
#include "util/hash-list.h"
namespace kaldi {
struct LatticeFasterDecoderConfig {
BaseFloat beam;
int32 max_active;
int32 min_active;
BaseFloat lattice_beam;
int32 prune_interval;
bool determinize_lattice; // not inspected by this class... used in
// command-line program.
BaseFloat beam_delta;
BaseFloat hash_ratio;
// Note: we don't make prune_scale configurable on the command line, it's not
// a very important parameter. It affects the algorithm that prunes the
// tokens as we go.
BaseFloat prune_scale;
BaseFloat length_penalty; // for balancing the del/ins ratio, suggested -3.0
// Most of the options inside det_opts are not actually queried by the
// LatticeFasterDecoder class itself, but by the code that calls it, for
// example in the function DecodeUtteranceLatticeFaster.
fst::DeterminizeLatticePhonePrunedOptions det_opts;
LatticeFasterDecoderConfig()
: beam(16.0),
max_active(std::numeric_limits<int32>::max()),
min_active(200),
lattice_beam(10.0),
prune_interval(25),
determinize_lattice(true),
beam_delta(0.5),
hash_ratio(2.0),
prune_scale(0.1),
length_penalty(0.0) {}
void Register(OptionsItf *opts) {
det_opts.Register(opts);
opts->Register("beam", &beam,
"Decoding beam. Larger->slower, more accurate.");
opts->Register("max-active", &max_active,
"Decoder max active states. Larger->slower; "
"more accurate");
opts->Register("min-active", &min_active,
"Decoder minimum #active states.");
opts->Register("lattice-beam", &lattice_beam,
"Lattice generation beam. Larger->slower, "
"and deeper lattices");
opts->Register("prune-interval", &prune_interval,
"Interval (in frames) at "
"which to prune tokens");
opts->Register(
"determinize-lattice", &determinize_lattice,
"If true, "
"determinize the lattice (lattice-determinization, keeping only "
"best pdf-sequence for each word-sequence).");
opts->Register(
"beam-delta", &beam_delta,
"Increment used in decoding-- this "
"parameter is obscure and relates to a speedup in the way the "
"max-active constraint is applied. Larger is more accurate.");
opts->Register("hash-ratio", &hash_ratio,
"Setting used in decoder to "
"control hash behavior");
}
void Check() const {
KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 &&
min_active <= max_active && prune_interval > 0 &&
beam_delta > 0.0 && hash_ratio >= 1.0 && prune_scale > 0.0 &&
prune_scale < 1.0);
}
};
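// A minimal configuration sketch (hypothetical values, for illustration):
//
//   LatticeFasterDecoderConfig config;
//   config.beam = 13.0;        // narrower than the 16.0 default -> faster
//   config.max_active = 7000;  // cap on simultaneously active states
//   config.lattice_beam = 6.0; // smaller -> shallower lattices
//   config.Check();            // enforces the invariants asserted above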
namespace decoder {
// We will template the decoder on the token type as well as the FST type; this
// is a mechanism so that we can use the same underlying decoder code for
// versions of the decoder that support quickly getting the best path
// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also
// those that do not (LatticeFasterDecoder).
// ForwardLinks are the links from a token to a token on the next frame,
// or sometimes on the current frame (for input-epsilon links).
template <typename Token>
struct ForwardLink {
using Label = fst::StdArc::Label;
Token *next_tok; // the next token [or NULL if represents final-state]
Label ilabel; // ilabel on arc
Label olabel; // olabel on arc
BaseFloat graph_cost; // graph cost of traversing arc (contains LM, etc.)
BaseFloat acoustic_cost; // acoustic cost (pre-scaled) of traversing arc
bool is_start_boundary;
bool is_end_boundary;
float context_score;
ForwardLink *next; // next in singly-linked list of forward arcs (arcs
// in the state-level lattice) from a token.
inline ForwardLink(Token *next_tok, Label ilabel, Label olabel,
BaseFloat graph_cost, BaseFloat acoustic_cost,
bool is_start_boundary, bool is_end_boundary,
ForwardLink *next)
: next_tok(next_tok),
ilabel(ilabel),
olabel(olabel),
graph_cost(graph_cost),
acoustic_cost(acoustic_cost),
is_start_boundary(is_start_boundary),
is_end_boundary(is_end_boundary),
context_score(0),
next(next) {}
};
struct StdToken {
using ForwardLinkT = ForwardLink<StdToken>;
using Token = StdToken;
// Standard token type for LatticeFasterDecoder. Each active HCLG
// (decoding-graph) state on each frame has one token.
// tot_cost is the total (LM + acoustic) cost from the beginning of the
// utterance up to this point. (but see cost_offset_, which is subtracted
// to keep it in a good numerical range).
BaseFloat tot_cost;
  // extra_cost is >= 0. After calling PruneForwardLinks, this equals the
// minimum difference between the cost of the best path that this link is a
// part of, and the cost of the absolute best path, under the assumption that
// any of the currently active states at the decoding front may eventually
// succeed (e.g. if you were to take the currently active states one by one
// and compute this difference, and then take the minimum).
BaseFloat extra_cost;
int context_state = 0;
// 'links' is the head of singly-linked list of ForwardLinks, which is what we
// use for lattice generation.
ForwardLinkT *links;
// 'next' is the next in the singly-linked list of tokens for this frame.
Token *next;
// This function does nothing and should be optimized out; it's needed
// so we can share the regular LatticeFasterDecoderTpl code and the code
// for LatticeFasterOnlineDecoder that supports fast traceback.
inline void SetBackpointer(Token *backpointer) {}
// This constructor just ignores the 'backpointer' argument. That argument is
// needed so that we can use the same decoder code for LatticeFasterDecoderTpl
// and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
// fast way to obtain the best path).
  inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
                  Token *next, Token *backpointer)
      : tot_cost(tot_cost),
        extra_cost(extra_cost),
        context_state(0),
        links(links),
        next(next) {}
};
struct BackpointerToken {
using ForwardLinkT = ForwardLink<BackpointerToken>;
using Token = BackpointerToken;
  // BackpointerToken is like StdToken, but additionally stores a backpointer
  // to the best preceding token (see 'backpointer' below). Each active HCLG
  // (decoding-graph) state on each frame has one token.
// tot_cost is the total (LM + acoustic) cost from the beginning of the
// utterance up to this point. (but see cost_offset_, which is subtracted
// to keep it in a good numerical range).
BaseFloat tot_cost;
  // extra_cost is >= 0. After calling PruneForwardLinks, this equals the
  // minimum difference between the cost of the best path that this token is
  // part of, and the cost of the absolute best path, under the assumption
  // that any of the currently active states at the decoding front may
  // eventually succeed (e.g. if you were to take the currently active states
  // one by one and compute this difference, and then take the minimum).
BaseFloat extra_cost;
  // Current state in the context (hotword-biasing) graph; see
  // wenet::ContextGraph.
  int context_state = 0;
// 'links' is the head of singly-linked list of ForwardLinks, which is what we
// use for lattice generation.
ForwardLinkT *links;
// 'next' is the next in the singly-linked list of tokens for this frame.
BackpointerToken *next;
  // Best preceding BackpointerToken (could be on this frame, connected to
  // this one via an epsilon transition, or on a previous frame). This is only
// required for an efficient GetBestPath function in
// LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
// (the "links" list is what stores the forward links, for that).
Token *backpointer;
inline void SetBackpointer(Token *backpointer) {
this->backpointer = backpointer;
}
  inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost,
                          ForwardLinkT *links, Token *next, Token *backpointer)
      : tot_cost(tot_cost),
        extra_cost(extra_cost),
        context_state(0),
        links(links),
        next(next),
        backpointer(backpointer) {}
};
} // namespace decoder
/** This is the "normal" lattice-generating decoder.
    See \ref lattices_generation, \ref decoders_faster and \ref decoders_simple
    for more information.
    The decoder is templated on the FST type and the token type. The token
    type will normally be StdToken, but may also be BackpointerToken, which
    supports quick lookup of the current best path (see
    lattice-faster-online-decoder.h).
    The FST type with which you invoke this decoder is expected to be
    fst::Fst<fst::StdArc>, a.k.a. StdFst, or GrammarFst. If you invoke it with
    FST == StdFst and it notices that the actual FST type is
    fst::VectorFst<fst::StdArc> or fst::ConstFst<fst::StdArc>, the decoder
    object will internally cast itself to one that is templated on those more
    specific types; this is an optimization for speed.
*/
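/*
   A minimal usage sketch (assumes an `hclg_fst`, a `config`, a
   `context_graph` and a DecodableInterface `decodable` have been set up by
   the caller; error handling omitted):

     LatticeFasterDecoderTpl<fst::StdFst> decoder(hclg_fst, config,
                                                  context_graph);
     if (decoder.Decode(&decodable)) {
       Lattice best_path;
       decoder.GetBestPath(&best_path);
     }
*/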
template <typename FST, typename Token = decoder::StdToken>
class LatticeFasterDecoderTpl {
public:
using Arc = typename FST::Arc;
using Label = typename Arc::Label;
using StateId = typename Arc::StateId;
using Weight = typename Arc::Weight;
using ForwardLinkT = decoder::ForwardLink<Token>;
// Instantiate this class once for each thing you have to decode.
// This version of the constructor does not take ownership of
// 'fst'.
LatticeFasterDecoderTpl(
const FST &fst, const LatticeFasterDecoderConfig &config,
const std::shared_ptr<wenet::ContextGraph> &context_graph);
// This version of the constructor takes ownership of the fst, and will delete
// it when this object is destroyed.
LatticeFasterDecoderTpl(const LatticeFasterDecoderConfig &config, FST *fst);
void SetOptions(const LatticeFasterDecoderConfig &config) {
config_ = config;
}
const LatticeFasterDecoderConfig &GetOptions() const { return config_; }
~LatticeFasterDecoderTpl();
  /// Decodes until there are no more frames left in the "decodable" object.
/// note, this may block waiting for input if the "decodable" object blocks.
/// Returns true if any kind of traceback is available (not necessarily from a
/// final state).
bool Decode(DecodableInterface *decodable);
/// says whether a final-state was active on the last frame. If it was not,
/// the lattice (or traceback) will end with states that are not final-states.
bool ReachedFinal() const {
return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
}
/// Outputs an FST corresponding to the single best path through the lattice.
/// Returns true if result is nonempty (using the return status is deprecated,
/// it will become void). If "use_final_probs" is true AND we reached the
/// final-state of the graph then it will include those as final-probs, else
/// it will treat all final-probs as one. Note: this just calls
/// GetRawLattice() and figures out the shortest path.
bool GetBestPath(Lattice *ofst, bool use_final_probs = true) const;
/// Outputs an FST corresponding to the raw, state-level
/// tracebacks. Returns true if result is nonempty.
/// If "use_final_probs" is true AND we reached the final-state
/// of the graph then it will include those as final-probs, else
/// it will treat all final-probs as one.
/// The raw lattice will be topologically sorted.
///
/// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
/// which also supports a pruning beam, in case for some reason
/// you want it pruned tighter than the regular lattice beam.
  /// We could put that here in the future if needed.
bool GetRawLattice(Lattice *ofst, bool use_final_probs = true) const;
/// [Deprecated, users should now use GetRawLattice and determinize it
/// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper].
/// Outputs an FST corresponding to the lattice-determinized
/// lattice (one path per word sequence). Returns true if result is
/// nonempty. If "use_final_probs" is true AND we reached the final-state of
/// the graph then it will include those as final-probs, else it will treat
/// all final-probs as one.
bool GetLattice(CompactLattice *ofst, bool use_final_probs = true) const;
/// InitDecoding initializes the decoding, and should only be used if you
/// intend to call AdvanceDecoding(). If you call Decode(), you don't need to
/// call this. You can also call InitDecoding if you have already decoded an
/// utterance and want to start with a new utterance.
void InitDecoding();
/// This will decode until there are no more frames ready in the decodable
/// object. You can keep calling it each time more frames become available.
/// If max_num_frames is specified, it specifies the maximum number of frames
/// the function will decode before returning.
void AdvanceDecoding(DecodableInterface *decodable,
int32 max_num_frames = -1);
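  /// A typical streaming pattern looks like the sketch below (the feeding
  /// helpers are hypothetical stand-ins for the caller's audio pipeline):
  ///
  ///   decoder.InitDecoding();
  ///   while (MoreChunksAvailable()) {        // hypothetical
  ///     AppendChunkToDecodable(&decodable);  // hypothetical
  ///     decoder.AdvanceDecoding(&decodable);
  ///   }
  ///   decoder.FinalizeDecoding();  // optional extra pruning; see below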
/// This function may be optionally called after AdvanceDecoding(), when you
/// do not plan to decode any further. It does an extra pruning step that
/// will help to prune the lattices output by GetLattice and (particularly)
/// GetRawLattice more completely, particularly toward the end of the
/// utterance. If you call this, you cannot call AdvanceDecoding again (it
/// will fail), and you cannot call GetLattice() and related functions with
/// use_final_probs = false. Used to be called PruneActiveTokensFinal().
void FinalizeDecoding();
/// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives
/// more information. It returns the difference between the best (final-cost
/// plus cost) of any token on the final frame, and the best cost of any token
/// on the final frame. If it is infinity it means no final-states were
  /// present on the final frame. It will usually be nonnegative. If it is not
  /// too positive (e.g. < 5 is my first guess, but this is not tested) you can
/// take it as a good indication that we reached the final-state with
/// reasonable likelihood.
BaseFloat FinalRelativeCost() const;
// Returns the number of frames decoded so far. The value returned changes
// whenever we call ProcessEmitting().
inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }
protected:
// we make things protected instead of private, as code in
// LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the
// internals.
// Deletes the elements of the singly linked list tok->links.
inline static void DeleteForwardLinks(Token *tok);
// head of per-frame list of Tokens (list is in topological order),
// and something saying whether we ever pruned it using PruneForwardLinks.
struct TokenList {
Token *toks;
bool must_prune_forward_links;
bool must_prune_tokens;
TokenList()
: toks(NULL), must_prune_forward_links(true), must_prune_tokens(true) {}
};
using Elem = typename HashList<StateId, Token *>::Elem;
// Equivalent to:
// struct Elem {
// StateId key;
// Token *val;
// Elem *tail;
// };
void PossiblyResizeHash(size_t num_toks);
  // FindOrAddToken either locates a token in the hash toks_, or if necessary
  // inserts a new, empty token (i.e. with no forward links) for the current
  // frame. [Note: it's inserted if necessary into the hash toks_ and also
  // into the singly linked list of tokens active on this frame (whose head
  // is at active_toks_[frame]).] The frame_plus_one argument is the acoustic
  // frame index plus one, which is used to index into the active_toks_ array.
  // Returns the Elem in the hash (its 'val' field holds the Token pointer).
  // Sets "changed" (if non-NULL) to true if the token was newly created or
  // the cost changed.
  // If Token == StdToken, the 'backpointer' argument has no purpose (and will
  // hopefully be optimized out).
inline Elem *FindOrAddToken(StateId state, int32 frame_plus_one,
BaseFloat tot_cost, Token *backpointer,
bool *changed);
  // Prunes outgoing links for all tokens in active_toks_[frame];
  // it's called by PruneActiveTokens.
  // All links that have link_extra_cost > lattice_beam are pruned.
// delta is the amount by which the extra_costs must change
// before we set *extra_costs_changed = true.
// If delta is larger, we'll tend to go back less far
// toward the beginning of the file.
// extra_costs_changed is set to true if extra_cost was changed for any token
// links_pruned is set to true if any link in any token was pruned
void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
bool *links_pruned, BaseFloat delta);
// This function computes the final-costs for tokens active on the final
// frame. It outputs to final-costs, if non-NULL, a map from the Token*
// pointer to the final-prob of the corresponding state, for all Tokens
// that correspond to states that have final-probs. This map will be
// empty if there were no final-probs. It outputs to
// final_relative_cost, if non-NULL, the difference between the best
// forward-cost including the final-prob cost, and the best forward-cost
// without including the final-prob cost (this will usually be positive), or
  // infinity if there were no final-probs. [c.f. FinalRelativeCost(), which
  // outputs this quantity]. It outputs to final_best_cost, if
  // non-NULL, the lowest, for any token t active on the final frame, of
// forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in
// the graph of the state corresponding to token t, or the best of
// forward-cost[t] if there were no final-probs active on the final frame.
// You cannot call this after FinalizeDecoding() has been called; in that
// case you should get the answer from class-member variables.
void ComputeFinalCosts(unordered_map<Token *, BaseFloat> *final_costs,
BaseFloat *final_relative_cost,
BaseFloat *final_best_cost) const;
// PruneForwardLinksFinal is a version of PruneForwardLinks that we call
// on the final frame. If there are final tokens active, it uses
// the final-probs for pruning, otherwise it treats all tokens as final.
void PruneForwardLinksFinal();
// Prune away any tokens on this frame that have no forward links.
// [we don't do this in PruneForwardLinks because it would give us
// a problem with dangling pointers].
// It's called by PruneActiveTokens if any forward links have been pruned
void PruneTokensForFrame(int32 frame_plus_one);
// Go backwards through still-alive tokens, pruning them if the
// forward+backward cost is more than lat_beam away from the best path. It's
// possible to prove that this is "correct" in the sense that we won't lose
// anything outside of lat_beam, regardless of what happens in the future.
// delta controls when it considers a cost to have changed enough to continue
// going backward and propagating the change. larger delta -> will recurse
// less far.
void PruneActiveTokens(BaseFloat delta);
/// Gets the weight cutoff. Also counts the active tokens.
BaseFloat GetCutoff(Elem *list_head, size_t *tok_count,
BaseFloat *adaptive_beam, Elem **best_elem);
/// Processes emitting arcs for one frame. Propagates from prev_toks_ to
/// cur_toks_. Returns the cost cutoff for subsequent ProcessNonemitting() to
/// use.
BaseFloat ProcessEmitting(DecodableInterface *decodable);
/// Processes nonemitting (epsilon) arcs for one frame. Called after
/// ProcessEmitting() on each frame. The cost cutoff is computed by the
/// preceding ProcessEmitting().
void ProcessNonemitting(BaseFloat cost_cutoff);
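  // For each decoded frame, AdvanceDecoding() pairs these two as
  // (a sketch):
  //
  //   BaseFloat cost_cutoff = ProcessEmitting(decodable);
  //   ProcessNonemitting(cost_cutoff);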
// HashList defined in ../util/hash-list.h. It actually allows us to maintain
// more than one list (e.g. for current and previous frames), but only one of
// them at a time can be indexed by StateId. It is indexed by frame-index
// plus one, where the frame-index is zero-based, as used in decodable object.
// That is, the emitting probs of frame t are accounted for in tokens at
// toks_[t+1]. The zeroth frame is for nonemitting transition at the start of
// the graph.
HashList<StateId, Token *> toks_;
std::vector<TokenList> active_toks_; // Lists of tokens, indexed by
// frame (members of TokenList are toks, must_prune_forward_links,
// must_prune_tokens).
  std::vector<const Elem *>
      queue_;  // temp variable used in ProcessNonemitting.
std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
// fst_ is a pointer to the FST we are decoding from.
const FST *fst_;
// delete_fst_ is true if the pointer fst_ needs to be deleted when this
// object is destroyed.
bool delete_fst_;
std::vector<BaseFloat> cost_offsets_; // This contains, for each
// frame, an offset that was added to the acoustic log-likelihoods on that
// frame in order to keep everything in a nice dynamic range i.e. close to
// zero, to reduce roundoff errors.
LatticeFasterDecoderConfig config_;
int32 num_toks_; // current total #toks allocated...
bool warned_;
/// decoding_finalized_ is true if someone called FinalizeDecoding(). [note,
/// calling this is optional]. If true, it's forbidden to decode more. Also,
/// if this is set, then the output of ComputeFinalCosts() is in the next
/// three variables. The reason we need to do this is that after
/// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some
/// of the tokens on the last frame are freed, so we free the list from toks_
/// to avoid having dangling pointers hanging around.
bool decoding_finalized_;
  /// For the meaning of the next 3 variables, see the comment for
  /// decoding_finalized_ above, and ComputeFinalCosts().
unordered_map<Token *, BaseFloat> final_costs_;
BaseFloat final_relative_cost_;
BaseFloat final_best_cost_;
std::shared_ptr<wenet::ContextGraph> context_graph_ = nullptr;
// There are various cleanup tasks... the toks_ structure contains
// singly linked lists of Token pointers, where Elem is the list type.
// It also indexes them in a hash, indexed by state (this hash is only
// maintained for the most recent frame). toks_.Clear()
// deletes them from the hash and returns the list of Elems. The
// function DeleteElems calls toks_.Delete(elem) for each elem in
// the list, which returns ownership of the Elem to the toks_ structure
// for reuse, but does not delete the Token pointer. The Token pointers
// are reference-counted and are ultimately deleted in PruneTokensForFrame,
// but are also linked together on each frame by their own linked-list,
// using the "next" pointer. We delete them manually.
void DeleteElems(Elem *list);
// This function takes a singly linked list of tokens for a single frame, and
// outputs a list of them in topological order (it will crash if no such order
// can be found, which will typically be due to decoding graphs with epsilon
// cycles, which are not allowed). Note: the output list may contain NULLs,
// which the caller should pass over; it just happens to be more efficient for
// the algorithm to output a list that contains NULLs.
static void TopSortTokens(Token *tok_list,
std::vector<Token *> *topsorted_list);
void ClearActiveTokens();
KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderTpl);
};
typedef LatticeFasterDecoderTpl<fst::StdFst, decoder::StdToken>
LatticeFasterDecoder;
} // end namespace kaldi.
#endif // KALDI_DECODER_LATTICE_FASTER_DECODER_H_
// decoder/lattice-faster-online-decoder.cc
// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2014 IMSL, PKU-HKUST (author: Wei Shi)
// 2018 Zhehuai Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
// see note at the top of lattice-faster-decoder.cc, about how to maintain this
// file in sync with lattice-faster-decoder.cc
#include <limits>
#include <queue>
#include <unordered_map>
#include <utility>
#include "decoder/lattice-faster-online-decoder.h"
namespace kaldi {
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::TestGetBestPath(
bool use_final_probs) const {
Lattice lat1;
{
Lattice raw_lat;
this->GetRawLattice(&raw_lat, use_final_probs);
ShortestPath(raw_lat, &lat1);
}
Lattice lat2;
GetBestPath(&lat2, use_final_probs);
BaseFloat delta = 0.1;
int32 num_paths = 1;
if (!fst::RandEquivalent(lat1, lat2, num_paths, delta, rand())) {
KALDI_WARN << "Best-path test failed";
return false;
} else {
return true;
}
}
// Outputs an FST corresponding to the single best path through the lattice.
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::GetBestPath(
Lattice *olat, bool use_final_probs) const {
olat->DeleteStates();
BaseFloat final_graph_cost;
BestPathIterator iter = BestPathEnd(use_final_probs, &final_graph_cost);
if (iter.Done()) return false; // would have printed warning.
StateId state = olat->AddState();
olat->SetFinal(state, LatticeWeight(final_graph_cost, 0.0));
while (!iter.Done()) {
LatticeArc arc;
iter = TraceBackBestPath(iter, &arc);
arc.nextstate = state;
StateId new_state = olat->AddState();
olat->AddArc(new_state, arc);
state = new_state;
}
olat->SetStart(state);
return true;
}
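// Callers usually unpack the linear lattice produced by GetBestPath(), e.g.
// (a sketch; error handling omitted):
//
//   Lattice best_path;
//   if (decoder.GetBestPath(&best_path)) {
//     std::vector<int32> alignment, words;
//     LatticeWeight weight;
//     fst::GetLinearSymbolSequence(best_path, &alignment, &words, &weight);
//   }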
template <typename FST>
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator
LatticeFasterOnlineDecoderTpl<FST>::BestPathEnd(
bool use_final_probs, BaseFloat *final_cost_out) const {
if (this->decoding_finalized_ && !use_final_probs)
KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
<< "BestPathEnd() with use_final_probs == false";
KALDI_ASSERT(this->NumFramesDecoded() > 0 &&
"You cannot call BestPathEnd if no frames were decoded.");
unordered_map<Token *, BaseFloat> final_costs_local;
const unordered_map<Token *, BaseFloat> &final_costs =
(this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
if (!this->decoding_finalized_ && use_final_probs)
this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
// Singly linked list of tokens on last frame (access list through "next"
// pointer).
BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
BaseFloat best_final_cost = 0;
Token *best_tok = NULL;
for (Token *tok = this->active_toks_.back().toks; tok != NULL;
tok = tok->next) {
BaseFloat cost = tok->tot_cost, final_cost = 0.0;
if (use_final_probs && !final_costs.empty()) {
// if we are instructed to use final-probs, and any final tokens were
// active on final frame, include the final-prob in the cost of the token.
typename unordered_map<Token *, BaseFloat>::const_iterator iter =
final_costs.find(tok);
if (iter != final_costs.end()) {
final_cost = iter->second;
cost += final_cost;
} else {
cost = std::numeric_limits<BaseFloat>::infinity();
}
}
if (cost < best_cost) {
best_cost = cost;
best_tok = tok;
best_final_cost = final_cost;
}
}
if (best_tok ==
NULL) { // this should not happen, and is likely a code error or
// caused by infinities in likelihoods, but I'm not making
// it a fatal error for now.
KALDI_WARN << "No final token found.";
}
if (final_cost_out) *final_cost_out = best_final_cost;
return BestPathIterator(best_tok, this->NumFramesDecoded() - 1);
}
template <typename FST>
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator
LatticeFasterOnlineDecoderTpl<FST>::TraceBackBestPath(BestPathIterator iter,
LatticeArc *oarc) const {
KALDI_ASSERT(!iter.Done() && oarc != NULL);
Token *tok = static_cast<Token *>(iter.tok);
int32 cur_t = iter.frame, step_t = 0;
if (tok->backpointer != NULL) {
    // retrieve the correct forward link (with the best link cost)
BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
ForwardLinkT *link;
for (link = tok->backpointer->links; link != NULL; link = link->next) {
if (link->next_tok == tok) { // this is a link to "tok"
BaseFloat graph_cost = link->graph_cost,
acoustic_cost = link->acoustic_cost;
BaseFloat cost = graph_cost + acoustic_cost;
if (cost < best_cost) {
oarc->ilabel = link->ilabel;
oarc->olabel = link->olabel;
if (link->ilabel != 0) {
KALDI_ASSERT(static_cast<size_t>(cur_t) <
this->cost_offsets_.size());
acoustic_cost -= this->cost_offsets_[cur_t];
step_t = -1;
} else {
step_t = 0;
}
oarc->weight = LatticeWeight(graph_cost, acoustic_cost);
best_cost = cost;
}
}
}
if (link == NULL &&
best_cost ==
std::numeric_limits<BaseFloat>::infinity()) { // Did not find
// correct link.
KALDI_ERR << "Error tracing best-path back (likely "
<< "bug in token-pruning algorithm)";
}
} else {
oarc->ilabel = 0;
oarc->olabel = 0;
oarc->weight = LatticeWeight::One(); // zero costs.
}
return BestPathIterator(tok->backpointer, cur_t + step_t);
}
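// BestPathEnd() and TraceBackBestPath() together give a cheap traceback of
// the current best path without constructing a lattice, e.g. (a sketch on a
// LatticeFasterOnlineDecoderTpl `decoder`; assumes at least one frame has
// been decoded):
//
//   BaseFloat final_cost;
//   auto iter = decoder.BestPathEnd(false, &final_cost);
//   while (!iter.Done()) {
//     LatticeArc arc;
//     iter = decoder.TraceBackBestPath(iter, &arc);
//     // 'arc' holds one best-path arc, visited from the last frame backward.
//   }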
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::GetRawLatticePruned(
Lattice *ofst, bool use_final_probs, BaseFloat beam) const {
typedef LatticeArc Arc;
typedef Arc::StateId StateId;
typedef Arc::Weight Weight;
typedef Arc::Label Label;
// Note: you can't use the old interface (Decode()) if you want to
// get the lattice with use_final_probs = false. You'd have to do
// InitDecoding() and then AdvanceDecoding().
if (this->decoding_finalized_ && !use_final_probs)
KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
<< "GetRawLattice() with use_final_probs == false";
unordered_map<Token *, BaseFloat> final_costs_local;
const unordered_map<Token *, BaseFloat> &final_costs =
(this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
if (!this->decoding_finalized_ && use_final_probs)
this->ComputeFinalCosts(&final_costs_local, NULL, NULL);
ofst->DeleteStates();
  // active_toks_.size() is num-frames plus one (since frames are one-based,
  // and we have an extra frame for the start-state).
int32 num_frames = this->active_toks_.size() - 1;
KALDI_ASSERT(num_frames > 0);
for (int32 f = 0; f <= num_frames; f++) {
if (this->active_toks_[f].toks == NULL) {
KALDI_WARN << "No tokens active on frame " << f
<< ": not producing lattice.\n";
return false;
}
}
unordered_map<Token *, StateId> tok_map;
std::queue<std::pair<Token *, int32> > tok_queue;
// First initialize the queue and states. Put the initial state on the queue;
// this is the last token in the list active_toks_[0].toks.
for (Token *tok = this->active_toks_[0].toks; tok != NULL; tok = tok->next) {
if (tok->next == NULL) {
tok_map[tok] = ofst->AddState();
ofst->SetStart(tok_map[tok]);
std::pair<Token *, int32> tok_pair(tok, 0); // #frame = 0
tok_queue.push(tok_pair);
}
}
// Next create states for "good" tokens
while (!tok_queue.empty()) {
std::pair<Token *, int32> cur_tok_pair = tok_queue.front();
tok_queue.pop();
Token *cur_tok = cur_tok_pair.first;
int32 cur_frame = cur_tok_pair.second;
KALDI_ASSERT(cur_frame >= 0 && cur_frame <= this->cost_offsets_.size());
typename unordered_map<Token *, StateId>::const_iterator iter =
tok_map.find(cur_tok);
KALDI_ASSERT(iter != tok_map.end());
StateId cur_state = iter->second;
for (ForwardLinkT *l = cur_tok->links; l != NULL; l = l->next) {
Token *next_tok = l->next_tok;
if (next_tok->extra_cost < beam) {
// so both the current and the next token are good; create the arc
int32 next_frame = l->ilabel == 0 ? cur_frame : cur_frame + 1;
StateId nextstate;
if (tok_map.find(next_tok) == tok_map.end()) {
nextstate = tok_map[next_tok] = ofst->AddState();
tok_queue.push(std::pair<Token *, int32>(next_tok, next_frame));
} else {
nextstate = tok_map[next_tok];
}
BaseFloat cost_offset =
(l->ilabel != 0 ? this->cost_offsets_[cur_frame] : 0);
Arc arc(l->ilabel, l->olabel,
Weight(l->graph_cost, l->acoustic_cost - cost_offset),
nextstate);
ofst->AddArc(cur_state, arc);
}
}
if (cur_frame == num_frames) {
if (use_final_probs && !final_costs.empty()) {
typename unordered_map<Token *, BaseFloat>::const_iterator iter =
final_costs.find(cur_tok);
if (iter != final_costs.end())
ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
} else {
ofst->SetFinal(cur_state, LatticeWeight::One());
}
}
}
return (ofst->NumStates() != 0);
}
// Instantiate the template for the FST types that we'll need.
template class LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >;
} // end namespace kaldi.