// Commit 688b6eac authored by SWHL
//
// Update files
#ifndef LM_INTERPOLATE_TUNE_INSTANCE_H
#define LM_INTERPOLATE_TUNE_INSTANCE_H
#include "tune_matrix.hh"
#include "../word_index.hh"
#include "../../util/scoped.hh"
#include "../../util/stream/config.hh"
#include "../../util/string_piece.hh"
#include <boost/optional.hpp>
#include <vector>
namespace util { namespace stream {
class Chain;
class FileBuffer;
}} // namespaces
namespace lm { namespace interpolate {
typedef uint32_t InstanceIndex;
typedef uint32_t ModelIndex;
// One probability a single model assigns for a tuning instance, beyond the
// shared data; these records are sorted and streamed during tuning.
struct Extension {
// Which tuning instance does this belong to?
InstanceIndex instance;
// Vocab id of the predicted word.
WordIndex word;
// Which model assigned this probability.
ModelIndex model;
// ln p_{model} (word | context(instance))
float ln_prob;
// Sort order (defined out of line): used when merging extension streams.
bool operator<(const Extension &other) const;
};
class ExtensionsFirstIteration;
// Memory and sorting knobs for loading tuning instances.
struct InstancesConfig {
// For batching the model reads. This is per order.
std::size_t model_read_chain_mem;
// This is being sorted, make it larger.
std::size_t extension_write_chain_mem;
// NOTE(review): budget for lazily-streamed data; exact use is in the .cc — confirm there.
std::size_t lazy_memory;
// Configuration of the external sort applied to extensions.
util::stream::SortConfig sort;
};
// Loads tuning text against several models and exposes the statistics the
// weight tuner needs: per-instance backoffs, unigram matrices, and a stream
// of Extension records.
class Instances {
private:
// ln backoff values: one row per instance, one column per model.
typedef Eigen::Matrix<Accum, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> BackoffMatrix;
public:
// tune_file is a file descriptor of tuning text; ownership is transferred
// (callers pass e.g. scoped_fd::release()).
Instances(int tune_file, const std::vector<StringPiece> &model_names, const InstancesConfig &config);
// For destruction of forward-declared classes.
~Instances();
// Full backoff from unigram for each model.
typedef BackoffMatrix::ConstRowXpr FullBackoffs;
// Row of ln backoffs (one entry per model) for the given instance.
FullBackoffs LNBackoffs(InstanceIndex instance) const {
return ln_backoffs_.row(instance);
}
// Number of tuning instances loaded.
InstanceIndex NumInstances() const { return ln_backoffs_.rows(); }
// -sum over instances of ln p_model(correct | context), per model.
const Vector &CorrectGradientTerm() const { return neg_ln_correct_sum_; }
// Matrix with (word, model) -> ln p_{model}(word).
const Matrix &LNUnigrams() const { return ln_unigrams_; }
// Entry size to use to configure the chain (since in practice order is needed).
std::size_t ReadExtensionsEntrySize() const;
// Feed the sorted Extension records into the given chain.
void ReadExtensions(util::stream::Chain &chain);
// Vocab id of the beginning of sentence. Used to ignore it for normalization.
WordIndex BOS() const { return bos_; }
private:
// Allow the derivatives test to get access.
friend class MockInstances;
Instances();
// backoffs_(instance, model) is the backoff all the way to unigrams.
BackoffMatrix ln_backoffs_;
// neg_correct_sum_(model) = -\sum_{instances} ln p_{model}(correct(instance) | context(instance)).
// This appears as a term in the gradient.
Vector neg_ln_correct_sum_;
// ln_unigrams_(word, model) = ln p_{model}(word).
Matrix ln_unigrams_;
// This is the source of data for the first iteration.
util::scoped_ptr<ExtensionsFirstIteration> extensions_first_;
// Source of data for subsequent iterations. This contains already-sorted data.
util::scoped_ptr<util::stream::FileBuffer> extensions_subsequent_;
// Beginning-of-sentence vocab id (see BOS()).
WordIndex bos_;
// Prefix for temporary files created while loading.
std::string temp_prefix_;
};
}} // namespaces
#endif // LM_INTERPOLATE_TUNE_INSTANCE_H
#include "tune_instances.hh"
#include "../../util/file.hh"
#include "../../util/file_stream.hh"
#include "../../util/stream/chain.hh"
#include "../../util/stream/config.hh"
#include "../../util/stream/typed_stream.hh"
#include "../../util/string_piece.hh"
#define BOOST_TEST_MODULE InstanceTest
#include <boost/test/unit_test.hpp>
#include <vector>
#include <math.h>
namespace lm { namespace interpolate { namespace {
// End-to-end check of Instances using two tiny models ("toy0", "toy1") and a
// tuning file containing the single word "c".
BOOST_AUTO_TEST_CASE(Toy) {
util::scoped_fd test_input(util::MakeTemp("temporary"));
util::FileStream(test_input.get()) << "c\n";
std::string dir("../common/test_data");
if (boost::unit_test::framework::master_test_suite().argc == 2) {
dir = boost::unit_test::framework::master_test_suite().argv[1];
}
// Test data is stored per byte order; pick the matching directory.
#if BYTE_ORDER == LITTLE_ENDIAN
std::string endian = "little";
#elif BYTE_ORDER == BIG_ENDIAN
std::string endian = "big";
#else
#error "Unsupported byte order."
#endif
dir += "/" + endian + "endian/";
std::vector<StringPiece> model_names;
std::string full0 = dir + "toy0";
std::string full1 = dir + "toy1";
model_names.push_back(full0);
model_names.push_back(full1);
// Tiny buffer sizes.
InstancesConfig config;
config.model_read_chain_mem = 100;
config.extension_write_chain_mem = 100;
config.lazy_memory = 100;
config.sort.temp_prefix = "temporary";
config.sort.buffer_size = 100;
config.sort.total_memory = 1024;
// Rewind the tuning file before handing ownership to Instances.
util::SeekOrThrow(test_input.get(), 0);
Instances inst(test_input.release(), model_names, config);
BOOST_CHECK_EQUAL(1, inst.BOS());
const Matrix &ln_unigrams = inst.LNUnigrams();
// Expected values are base-10 logs from the models, converted via M_LN10.
// <unk>=0
BOOST_CHECK_CLOSE(-0.90309 * M_LN10, ln_unigrams(0, 0), 0.001);
BOOST_CHECK_CLOSE(-1 * M_LN10, ln_unigrams(0, 1), 0.001);
// <s>=1 doesn't matter as long as it doesn't cause NaNs.
BOOST_CHECK(!isnan(ln_unigrams(1, 0)));
BOOST_CHECK(!isnan(ln_unigrams(1, 1)));
// a = 2
BOOST_CHECK_CLOSE(-0.46943438 * M_LN10, ln_unigrams(2, 0), 0.001);
BOOST_CHECK_CLOSE(-0.6146491 * M_LN10, ln_unigrams(2, 1), 0.001);
// </s> = 3
BOOST_CHECK_CLOSE(-0.5720968 * M_LN10, ln_unigrams(3, 0), 0.001);
BOOST_CHECK_CLOSE(-0.6146491 * M_LN10, ln_unigrams(3, 1), 0.001);
// c = 4
BOOST_CHECK_CLOSE(-0.90309 * M_LN10, ln_unigrams(4, 0), 0.001); // <unk>
BOOST_CHECK_CLOSE(-0.7659168 * M_LN10, ln_unigrams(4, 1), 0.001);
// too lazy to do b = 5.
// Two instances:
// <s> predicts c
// <s> c predicts </s>
BOOST_REQUIRE_EQUAL(2, inst.NumInstances());
BOOST_CHECK_CLOSE(-0.30103 * M_LN10, inst.LNBackoffs(0)(0), 0.001);
BOOST_CHECK_CLOSE(-0.30103 * M_LN10, inst.LNBackoffs(0)(1), 0.001);
// Backoffs of <s> c
BOOST_CHECK_CLOSE(0.0, inst.LNBackoffs(1)(0), 0.001);
BOOST_CHECK_CLOSE((-0.30103 - 0.30103) * M_LN10, inst.LNBackoffs(1)(1), 0.001);
// Stream the extensions back and verify them in sorted order.
util::stream::Chain extensions(util::stream::ChainConfig(inst.ReadExtensionsEntrySize(), 2, 300));
inst.ReadExtensions(extensions);
util::stream::TypedStream<Extension> stream(extensions.Add());
extensions >> util::stream::kRecycle;
// The extensions are (in order of instance, vocab id, and model as they should be sorted):
// <s> a from both models 0 and 1 (so two instances)
// <s> c from model 1
// <s> b from model 0
// c </s> from model 1
// Magic probabilities come from querying the models directly.
// <s> a from model 0
BOOST_REQUIRE(stream);
BOOST_CHECK_EQUAL(0, stream->instance);
BOOST_CHECK_EQUAL(2 /* a */, stream->word);
BOOST_CHECK_EQUAL(0, stream->model);
BOOST_CHECK_CLOSE(-0.37712017 * M_LN10, stream->ln_prob, 0.001);
// <s> a from model 1
BOOST_REQUIRE(++stream);
BOOST_CHECK_EQUAL(0, stream->instance);
BOOST_CHECK_EQUAL(2 /* a */, stream->word);
BOOST_CHECK_EQUAL(1, stream->model);
BOOST_CHECK_CLOSE(-0.4301247 * M_LN10, stream->ln_prob, 0.001);
// <s> c from model 1
BOOST_REQUIRE(++stream);
BOOST_CHECK_EQUAL(0, stream->instance);
BOOST_CHECK_EQUAL(4 /* c */, stream->word);
BOOST_CHECK_EQUAL(1, stream->model);
BOOST_CHECK_CLOSE(-0.4740302 * M_LN10, stream->ln_prob, 0.001);
// <s> b from model 0
BOOST_REQUIRE(++stream);
BOOST_CHECK_EQUAL(0, stream->instance);
BOOST_CHECK_EQUAL(5 /* b */, stream->word);
BOOST_CHECK_EQUAL(0, stream->model);
BOOST_CHECK_CLOSE(-0.41574955 * M_LN10, stream->ln_prob, 0.001);
// c </s> from model 1
BOOST_REQUIRE(++stream);
BOOST_CHECK_EQUAL(1, stream->instance);
BOOST_CHECK_EQUAL(3 /* </s> */, stream->word);
BOOST_CHECK_EQUAL(1, stream->model);
BOOST_CHECK_CLOSE(-0.09113217 * M_LN10, stream->ln_prob, 0.001);
// Exactly five extensions total.
BOOST_CHECK(!++stream);
}
}}} // namespaces
#ifndef LM_INTERPOLATE_TUNE_MATRIX_H
#define LM_INTERPOLATE_TUNE_MATRIX_H
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas" // Older gcc doesn't have "-Wunused-local-typedefs" and complains.
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
#include <Eigen/Core>
#pragma GCC diagnostic pop
namespace lm { namespace interpolate {
// Dense single-precision matrix/vector types shared by the tuning code.
typedef Eigen::MatrixXf Matrix;
typedef Eigen::VectorXf Vector;
// Scalar type used for accumulation (float, taken from Matrix).
typedef Matrix::Scalar Accum;
}} // namespaces
#endif // LM_INTERPOLATE_TUNE_MATRIX_H
#include "tune_weights.hh"
#include "tune_derivatives.hh"
#include "tune_instances.hh"
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas" // Older gcc doesn't have "-Wunused-local-typedefs" and complains.
#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
#include <Eigen/Dense>
#pragma GCC diagnostic pop
#include <boost/program_options.hpp>
#include <iostream>
namespace lm { namespace interpolate {
void TuneWeights(int tune_file, const std::vector<StringPiece> &model_names, const InstancesConfig &config, std::vector<float> &weights_out) {
Instances instances(tune_file, model_names, config);
Vector weights = Vector::Constant(model_names.size(), 1.0 / model_names.size());
Vector gradient;
Matrix hessian;
for (std::size_t iteration = 0; iteration < 10 /*TODO fancy stopping criteria */; ++iteration) {
std::cerr << "Iteration " << iteration << ": weights =";
for (Vector::Index i = 0; i < weights.rows(); ++i) {
std::cerr << ' ' << weights(i);
}
std::cerr << std::endl;
std::cerr << "Perplexity = " << Derivatives(instances, weights, gradient, hessian) << std::endl;
// TODO: 1.0 step size was too big and it kept getting unstable. More math.
weights -= 0.7 * hessian.inverse() * gradient;
}
weights_out.assign(weights.data(), weights.data() + weights.size());
}
}} // namespaces
#ifndef LM_INTERPOLATE_TUNE_WEIGHTS_H
#define LM_INTERPOLATE_TUNE_WEIGHTS_H
#include "../../util/string_piece.hh"
#include <vector>
namespace lm { namespace interpolate {
struct InstancesConfig;
// Run a tuning loop, producing weights as output.
void TuneWeights(int tune_file, const std::vector<StringPiece> &model_names, const InstancesConfig &config, std::vector<float> &weights);
}} // namespaces
#endif // LM_INTERPOLATE_TUNE_WEIGHTS_H
#include "universal_vocab.hh"
namespace lm {
namespace interpolate {
// Allocate one mapping table per model, sized to that model's vocabulary.
UniversalVocab::UniversalVocab(const std::vector<WordIndex>& model_vocab_sizes) {
  const std::size_t num_models = model_vocab_sizes.size();
  model_index_map_.resize(num_models);
  for (std::size_t model = 0; model != num_models; ++model) {
    model_index_map_[model].resize(model_vocab_sizes[model]);
  }
}
}} // namespaces
#ifndef LM_INTERPOLATE_UNIVERSAL_VOCAB_H
#define LM_INTERPOLATE_UNIVERSAL_VOCAB_H
#include "../word_index.hh"

#include <algorithm>
#include <cstddef>
#include <vector>
namespace lm {
namespace interpolate {
// Maps each model's private vocabulary ids to ids in a shared universal
// vocabulary, and (slowly) back again.
class UniversalVocab {
public:
// model_vocab_sizes[i] is model i's vocabulary size; allocates the tables.
explicit UniversalVocab(const std::vector<WordIndex>& model_vocab_sizes);
// GetUniversalIndex takes the model number and index for the specific
// model and returns the universal model number
WordIndex GetUniversalIdx(std::size_t model_num, WordIndex model_word_index) const {
return model_index_map_[model_num][model_word_index];
}
// Raw pointer to the model's mapping table, indexed by model-local word id.
const WordIndex *Mapping(std::size_t model) const {
// data() is well-defined even for an empty vector; &*begin() was UB then.
return model_index_map_[model].data();
}
// Reverse lookup: model-local id whose mapping equals index, or 0 if absent.
// NOTE(review): relies on the mapping for `model` being sorted ascending.
WordIndex SlowConvertToModel(std::size_t model, WordIndex index) const {
std::vector<WordIndex>::const_iterator i = std::lower_bound(model_index_map_[model].begin(), model_index_map_[model].end(), index);
if (i == model_index_map_[model].end() || *i != index) return 0;
return i - model_index_map_[model].begin();
}
// Record that model_num's word_index corresponds to universal_word_index.
void InsertUniversalIdx(std::size_t model_num, WordIndex word_index,
WordIndex universal_word_index) {
model_index_map_[model_num][word_index] = universal_word_index;
}
private:
// model_index_map_[model][model_word_id] = universal word id.
std::vector<std::vector<WordIndex> > model_index_map_;
};
} // namespace interpolate
} // namespace lm
#endif // LM_INTERPOLATE_UNIVERSAL_VOCAB_H
#include "model.hh"
#include "../util/file_stream.hh"
#include "../util/file.hh"
#include "../util/file_piece.hh"
#include "../util/usage.hh"
#include "../util/thread_pool.hh"
#include <boost/range/iterator_range.hpp>
#include <boost/program_options.hpp>
#include <iostream>
#include <stdint.h>
namespace {
// Read whitespace-delimited words from fd_in and write their vocab ids to
// stdout as raw Width-sized binary, emitting the </s> id at each line end.
template <class Model, class Width> void ConvertToBytes(const Model &model, int fd_in) {
  util::FilePiece in(fd_in);
  util::FileStream out(1);
  const Width end_sentence = (Width)model.GetVocabulary().EndSentence();
  StringPiece word;
  for (;;) {
    while (in.ReadWordSameLine(word)) {
      Width id = (Width)model.GetVocabulary().Index(word);
      out.write(&id, sizeof(Width));
    }
    if (!in.ReadLineOrEOF(word)) break;
    out.write(&end_sentence, sizeof(Width));
  }
}
// Scores batches of vocab ids on one thread, accumulating a probability sum
// that is folded into a shared total at destruction.
template <class Model, class Width> class Worker {
public:
explicit Worker(const Model &model, double &add_total) : model_(model), total_(0.0), add_total_(add_total) {}
// Destructors happen in the main thread, so there's no race for add_total_.
~Worker() { add_total_ += total_; }
// A batch of vocab ids; producers end batches on a sentence boundary.
typedef boost::iterator_range<Width *> Request;
// Score every word in the batch, resetting context to <s> after each </s>.
void operator()(Request request) {
const lm::ngram::State *const begin_state = &model_.BeginSentenceState();
const lm::ngram::State *next_state = begin_state;
const Width kEOS = model_.GetVocabulary().EndSentence();
float sum = 0.0;
// Do even stuff first.
const Width *even_end = request.begin() + (request.size() & ~1);
// Alternating states
const Width *i;
for (i = request.begin(); i != even_end;) {
// Unrolled x2: output alternates between state_[1] and state_[0] so the
// input state of one query is never the output state of the same query.
sum += model_.FullScore(*next_state, *i, state_[1]).prob;
next_state = (*i++ == kEOS) ? begin_state : &state_[1];
sum += model_.FullScore(*next_state, *i, state_[0]).prob;
next_state = (*i++ == kEOS) ? begin_state : &state_[0];
}
// Odd corner case.
if (request.size() & 1) {
sum += model_.FullScore(*next_state, *i, state_[2]).prob;
next_state = (*i++ == kEOS) ? begin_state : &state_[2];
}
total_ += sum;
}
private:
const Model &model_;
// Thread-local running sum, added to add_total_ in the destructor.
double total_;
double &add_total_;
lm::ngram::State state_[3];
};
// Runtime options for the benchmark binary.
struct Config {
// File descriptor to read input from (0 = stdin).
int fd_in;
// Number of query worker threads.
std::size_t threads;
// Number of Width-sized vocab ids buffered per task.
std::size_t buf_per_thread;
// True: query from vocab ids. False: convert text to vocab ids.
bool query;
};
// Timed query driver: reads raw vocab ids from config.fd_in, slices them into
// sentence-aligned buffers, scores them on a recycling thread pool, and
// prints throughput statistics.
template <class Model, class Width> void QueryFromBytes(const Model &model, const Config &config) {
util::FileStream out(1);
out << "Threads: " << config.threads << '\n';
const Width kEOS = model.GetVocabulary().EndSentence();
double total = 0.0;
// Number of items to have in queue in addition to everything in flight.
const std::size_t kInQueue = 3;
std::size_t total_queue = config.threads + kInQueue;
std::vector<Width> backing(config.buf_per_thread * total_queue);
double loaded_cpu;
double loaded_wall;
uint64_t queries = 0;
{
util::RecyclingThreadPool<Worker<Model, Width> > pool(total_queue, config.threads, Worker<Model, Width>(model, total), boost::iterator_range<Width *>((Width*)0, (Width*)0));
// Hand each queue slot its slice of the backing store.
// NOTE(review): the ranges are seeded empty (begin == end); presumably the
// pool only needs the start pointer and capacity is implied — confirm.
for (std::size_t i = 0; i < total_queue; ++i) {
pool.PopulateRecycling(boost::iterator_range<Width *>(&backing[i * config.buf_per_thread], &backing[i * config.buf_per_thread]));
}
loaded_cpu = util::CPUTime();
loaded_wall = util::WallTime();
out << "To Load, CPU: " << loaded_cpu << " Wall: " << loaded_wall << '\n';
// Words after the last complete sentence in a read are carried over.
boost::iterator_range<Width *> overhang((Width*)0, (Width*)0);
while (true) {
boost::iterator_range<Width *> buf = pool.Consume();
// Move the previous read's leftover words to the front of this buffer.
std::memmove(buf.begin(), overhang.begin(), overhang.size() * sizeof(Width));
std::size_t got = util::ReadOrEOF(config.fd_in, buf.begin() + overhang.size(), (config.buf_per_thread - overhang.size()) * sizeof(Width));
if (!got && overhang.empty()) break;
UTIL_THROW_IF2(got % sizeof(Width), "File size not a multiple of vocab id size " << sizeof(Width));
Width *read_end = buf.begin() + overhang.size() + got / sizeof(Width);
// Find the last </s> so each task receives only whole sentences.
Width *last_eos;
for (last_eos = read_end - 1; ; --last_eos) {
UTIL_THROW_IF2(last_eos <= buf.begin(), "Encountered a sentence longer than the buffer size of " << config.buf_per_thread << " words. Rerun with increased buffer size. TODO: adaptable buffer");
if (*last_eos == kEOS) break;
}
buf = boost::iterator_range<Width*>(buf.begin(), last_eos + 1);
overhang = boost::iterator_range<Width*>(last_eos + 1, read_end);
queries += buf.size();
pool.Produce(buf);
}
} // Drain pool.
double after_cpu = util::CPUTime();
double after_wall = util::WallTime();
util::FileStream(2, 70) << "Probability sum: " << total << '\n';
out << "Queries: " << queries << '\n';
out << "Excluding load, CPU: " << (after_cpu - loaded_cpu) << " Wall: " << (after_wall - loaded_wall) << '\n';
double cpu_per_entry = ((after_cpu - loaded_cpu) / static_cast<double>(queries));
double wall_per_entry = ((after_wall - loaded_wall) / static_cast<double>(queries));
out << "Seconds per query excluding load, CPU: " << cpu_per_entry << " Wall: " << wall_per_entry << '\n';
out << "Queries per second excluding load, CPU: " << (1.0/cpu_per_entry) << " Wall: " << (1.0/wall_per_entry) << '\n';
out << "RSSMax: " << util::RSSMax() << '\n';
}
// Run either the timed query path or the vocab-conversion path per config.
template <class Model, class Width> void DispatchFunction(const Model &model, const Config &config) {
  if (!config.query) {
    ConvertToBytes<Model, Width>(model, config.fd_in);
    return;
  }
  QueryFromBytes<Model, Width>(model, config);
}
// Load the model, then pick the narrowest integer type that can hold every
// vocab id and dispatch on it.
template <class Model> void DispatchWidth(const char *file, const Config &config) {
  lm::ngram::Config model_config;
  model_config.load_method = util::READ;
  Model model(file, model_config);
  const uint64_t bound = model.GetVocabulary().Bound();
  if (bound <= 256) return DispatchFunction<Model, uint8_t>(model, config);
  if (bound <= 65536) return DispatchFunction<Model, uint16_t>(model, config);
  if (bound <= (1ULL << 32)) return DispatchFunction<Model, uint32_t>(model, config);
  DispatchFunction<Model, uint64_t>(model, config);
}
void Dispatch(const char *file, const Config &config) {
using namespace lm::ngram;
lm::ngram::ModelType model_type;
if (lm::ngram::RecognizeBinary(file, model_type)) {
switch(model_type) {
case PROBING:
DispatchWidth<lm::ngram::ProbingModel>(file, config);
break;
case REST_PROBING:
DispatchWidth<lm::ngram::RestProbingModel>(file, config);
break;
case TRIE:
DispatchWidth<lm::ngram::TrieModel>(file, config);
break;
case QUANT_TRIE:
DispatchWidth<lm::ngram::QuantTrieModel>(file, config);
break;
case ARRAY_TRIE:
DispatchWidth<lm::ngram::ArrayTrieModel>(file, config);
break;
case QUANT_ARRAY_TRIE:
DispatchWidth<lm::ngram::QuantArrayTrieModel>(file, config);
break;
default:
UTIL_THROW(util::Exception, "Unrecognized kenlm model type " << model_type);
}
} else {
UTIL_THROW(util::Exception, "Binarize before running benchmarks.");
}
}
} // namespace
// Parse the command line, then either convert text to vocab ids (-v) or run
// timed queries over vocab ids (-q). Returns nonzero on error.
int main(int argc, char *argv[]) {
  try {
    Config config;
    config.fd_in = 0;  // stdin
    std::string model;
    namespace po = boost::program_options;
    po::options_description options("Benchmark options");
    options.add_options()
      ("help,h", po::bool_switch(), "Show help message")
      ("model,m", po::value<std::string>(&model)->required(), "Model to query or convert vocab ids")
      ("threads,t", po::value<std::size_t>(&config.threads)->default_value(boost::thread::hardware_concurrency()), "Threads to use (querying only; TODO vocab conversion)")
      ("buffer,b", po::value<std::size_t>(&config.buf_per_thread)->default_value(4096), "Number of words to buffer per task.")
      ("vocab,v", po::bool_switch(), "Convert strings to vocab ids")
      ("query,q", po::bool_switch(), "Query from vocab ids");
    po::variables_map vm;
    po::store(po::parse_command_line(argc, argv, options), vm);
    // Show usage before notify() so --help works without required options.
    if (argc == 1 || vm["help"].as<bool>()) {
      std::cerr << "Benchmark program for KenLM. Intended usage:\n"
        << "#Convert text to vocabulary ids offline. These ids are tied to a model.\n"
        << argv[0] << " -v -m $model <$text >$text.vocab\n"
        << "#Ensure files are in RAM.\n"
        << "cat $text.vocab $model >/dev/null\n"
        << "#Timed query against the model.\n"
        << argv[0] << " -q -m $model <$text.vocab\n";
      return 0;
    }
    po::notify(vm);
    // Exactly one of -v and -q must be given.
    if (!(vm["vocab"].as<bool>() ^ vm["query"].as<bool>())) {
      std::cerr << "Specify exactly one of -v (vocab conversion) or -q (query)." << std::endl;
      return 0;
    }
    config.query = vm["query"].as<bool>();
    if (!config.threads) {
      std::cerr << "Specify a non-zero number of threads with -t." << std::endl;
      // Bug fix: previously fell through and ran Dispatch with zero threads.
      return 1;
    }
    Dispatch(model.c_str(), config);
  } catch (const std::exception &e) {
    std::cerr << e.what() << std::endl;
    return 1;
  }
  return 0;
}
/* Efficient left and right language model state for sentence fragments.
* Intended usage:
* Store ChartState with every chart entry.
* To do a rule application:
* 1. Make a ChartState object for your new entry.
* 2. Construct RuleScore.
* 3. Going from left to right, call Terminal or NonTerminal.
* For terminals, just pass the vocab id.
* For non-terminals, pass that non-terminal's ChartState.
* If your decoder expects scores inclusive of subtree scores (i.e. you
* label entries with the highest-scoring path), pass the non-terminal's
* score as prob.
* If your decoder expects relative scores and will walk the chart later,
* pass prob = 0.0.
* In other words, the only effect of prob is that it gets added to the
* returned log probability.
* 4. Call Finish. It returns the log probability.
*
* There's a couple more details:
* Do not pass <s> to Terminal as it is formally not a word in the sentence,
* only context. Instead, call BeginSentence. If called, it should be the
* first call after RuleScore is constructed (since <s> is always the
* leftmost).
*
* If the leftmost RHS is a non-terminal, it's faster to call BeginNonTerminal.
*
* Hashing and sorting comparison operators are provided. All state objects
* are POD. If you intend to use memcmp on raw state objects, you must call
* ZeroRemaining first, as the value of array entries beyond length is
* otherwise undefined.
*
* Usage is of course not limited to chart decoding. Anything that generates
* sentence fragments missing left context could benefit. For example, a
* phrase-based decoder could pre-score phrases, storing ChartState with each
* phrase, even if hypotheses are generated left-to-right.
*/
#ifndef LM_LEFT_H
#define LM_LEFT_H
#include "max_order.hh"
#include "state.hh"
#include "return.hh"
#include "../util/murmur_hash.hh"
#include <algorithm>
namespace lm {
namespace ngram {
// Scores a sentence fragment built right-to-left/bottom-up (chart decoding),
// maintaining a left state (words that may still be extended by new left
// context) and a right state (context for words appended to the right).
// See the file header comment for the intended call sequence.
template <class M> class RuleScore {
public:
explicit RuleScore(const M &model, ChartState &out) : model_(model), out_(&out), left_done_(false), prob_(0.0) {
out.left.length = 0;
out.right.length = 0;
}
// Declare that the fragment begins with <s>. If used, call it first.
void BeginSentence() {
out_->right = model_.BeginSentenceState();
// out_->left is empty.
left_done_ = true;
}
// Append a terminal (vocab id) to the right of the fragment.
void Terminal(WordIndex word) {
State copy(out_->right);
FullScoreReturn ret(model_.FullScore(copy, word, out_->right));
if (left_done_) { prob_ += ret.prob; return; }
if (ret.independent_left) {
prob_ += ret.prob;
left_done_ = true;
return;
}
out_->left.pointers[out_->left.length++] = ret.extend_left;
prob_ += ret.rest;
// Right state did not grow by exactly one word: no further left extension.
if (out_->right.length != copy.length + 1)
left_done_ = true;
}
// Faster version of NonTerminal for the case where the rule begins with a non-terminal.
void BeginNonTerminal(const ChartState &in, float prob = 0.0) {
prob_ = prob;
*out_ = in;
left_done_ = in.left.full;
}
// Append a non-terminal with state in and (optionally) its subtree score.
void NonTerminal(const ChartState &in, float prob = 0.0) {
prob_ += prob;
if (!in.left.length) {
if (in.left.full) {
// Incoming fragment is independent of left context: charge all backoffs.
for (const float *i = out_->right.backoff; i < out_->right.backoff + out_->right.length; ++i) prob_ += *i;
left_done_ = true;
out_->right = in.right;
}
return;
}
if (!out_->right.length) {
// Our fragment is empty so far; adopt the incoming state.
out_->right = in.right;
if (left_done_) {
prob_ += model_.UnRest(in.left.pointers, in.left.pointers + in.left.length, 1);
return;
}
if (out_->left.length) {
left_done_ = true;
} else {
out_->left = in.left;
left_done_ = in.left.full;
}
return;
}
// General case: extend the incoming left words with our right context.
float backoffs[KENLM_MAX_ORDER - 1], backoffs2[KENLM_MAX_ORDER - 1];
float *back = backoffs, *back2 = backoffs2;
unsigned char next_use = out_->right.length;
// First word
if (ExtendLeft(in, next_use, 1, out_->right.backoff, back)) return;
// Words after the first, so extending a bigram to begin with
for (unsigned char extend_length = 2; extend_length <= in.left.length; ++extend_length) {
if (ExtendLeft(in, next_use, extend_length, back, back2)) return;
std::swap(back, back2);
}
if (in.left.full) {
// Charge the remaining backoffs; the incoming left state was complete.
for (const float *i = back; i != back + next_use; ++i) prob_ += *i;
left_done_ = true;
out_->right = in.right;
return;
}
// Right state was minimized, so it's already independent of the new words to the left.
if (in.right.length < in.left.length) {
out_->right = in.right;
return;
}
// Shift exisiting words down.
for (WordIndex *i = out_->right.words + next_use - 1; i >= out_->right.words; --i) {
*(i + in.right.length) = *i;
}
// Add words from in.right.
std::copy(in.right.words, in.right.words + in.right.length, out_->right.words);
// Assemble backoff composed on the existing state's backoff followed by the new state's backoff.
std::copy(in.right.backoff, in.right.backoff + in.right.length, out_->right.backoff);
std::copy(back, back + next_use, out_->right.backoff + in.right.length);
out_->right.length = in.right.length + next_use;
}
// Finalize the fragment: seal the left state and return the accumulated
// log probability.
float Finish() {
// A N-1-gram might extend left and right but we should still set full to true because it's an N-1-gram.
out_->left.full = left_done_ || (out_->left.length == model_.Order() - 1);
return prob_;
}
// Reuse this object for a new fragment written to the same ChartState.
void Reset() {
prob_ = 0.0;
left_done_ = false;
out_->left.length = 0;
out_->right.length = 0;
}
// Reuse this object, writing the new fragment to replacement instead.
void Reset(ChartState &replacement) {
out_ = &replacement;
Reset();
}
private:
// Extend the incoming left n-gram of length extend_length into our right
// context. Returns true if scoring finished early (caller should return).
bool ExtendLeft(const ChartState &in, unsigned char &next_use, unsigned char extend_length, const float *back_in, float *back_out) {
ProcessRet(model_.ExtendLeft(
out_->right.words, out_->right.words + next_use, // Words to extend into
back_in, // Backoffs to use
in.left.pointers[extend_length - 1], extend_length, // Words to be extended
back_out, // Backoffs for the next score
next_use)); // Length of n-gram to use in next scoring.
if (next_use != out_->right.length) {
left_done_ = true;
if (!next_use) {
// Early exit.
out_->right = in.right;
prob_ += model_.UnRest(in.left.pointers + extend_length, in.left.pointers + in.left.length, extend_length + 1);
return true;
}
}
// Continue scoring.
return false;
}
// Fold a query result into prob_ and the left state, mirroring Terminal.
void ProcessRet(const FullScoreReturn &ret) {
if (left_done_) {
prob_ += ret.prob;
return;
}
if (ret.independent_left) {
prob_ += ret.prob;
left_done_ = true;
return;
}
out_->left.pointers[out_->left.length++] = ret.extend_left;
prob_ += ret.rest;
}
const M &model_;
ChartState *out_;
// True once the left state can no longer change.
bool left_done_;
float prob_;
};
} // namespace ngram
} // namespace lm
#endif // LM_LEFT_H
#include "left.hh"
#include "model.hh"
#include "../util/tokenize_piece.hh"
#include <vector>
#define BOOST_TEST_MODULE LeftTest
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
namespace lm {
namespace ngram {
namespace {
#define Term(word) score.Terminal(m.GetVocabulary().Index(word));
#define VCheck(word, value) BOOST_CHECK_EQUAL(m.GetVocabulary().Index(word), value);
// Apparently some Boost versions use templates and are pretty strict about types matching.
#define SLOPPY_CHECK_CLOSE(ref, value, tol) BOOST_CHECK_CLOSE(static_cast<double>(ref), static_cast<double>(value), static_cast<double>(tol));
// Score "more loin", then extend it left with "little" and "to", checking
// state lengths and probabilities against hand-queried model values.
template <class M> void Short(const M &m) {
ChartState base;
{
RuleScore<M> score(m, base);
Term("more");
Term("loin");
SLOPPY_CHECK_CLOSE(-1.206319 - 0.3561665, score.Finish(), 0.001);
}
BOOST_CHECK(base.left.full);
BOOST_CHECK_EQUAL(2, base.left.length);
BOOST_CHECK_EQUAL(1, base.right.length);
VCheck("loin", base.right.words[0]);
ChartState more_left;
{
RuleScore<M> score(m, more_left);
Term("little");
score.NonTerminal(base, -1.206319 - 0.3561665);
// p(little more loin | null context)
SLOPPY_CHECK_CLOSE(-1.56538, score.Finish(), 0.001);
}
BOOST_CHECK_EQUAL(3, more_left.left.length);
BOOST_CHECK_EQUAL(1, more_left.right.length);
VCheck("loin", more_left.right.words[0]);
BOOST_CHECK(more_left.left.full);
ChartState shorter;
{
RuleScore<M> score(m, shorter);
Term("to");
score.NonTerminal(base, -1.206319 - 0.3561665);
SLOPPY_CHECK_CLOSE(-0.30103 - 1.687872 - 1.206319 - 0.3561665, score.Finish(), 0.01);
}
BOOST_CHECK_EQUAL(1, shorter.left.length);
BOOST_CHECK_EQUAL(1, shorter.right.length);
VCheck("loin", shorter.right.words[0]);
BOOST_CHECK(shorter.left.full);
}
// Build "on more", extend left with "looking", then attach <s>, checking
// backoff charging when non-terminals combine.
template <class M> void Charge(const M &m) {
ChartState base;
{
RuleScore<M> score(m, base);
Term("on");
Term("more");
SLOPPY_CHECK_CLOSE(-1.509559 -0.4771212 -1.206319, score.Finish(), 0.001);
}
BOOST_CHECK_EQUAL(1, base.left.length);
BOOST_CHECK_EQUAL(1, base.right.length);
VCheck("more", base.right.words[0]);
BOOST_CHECK(base.left.full);
ChartState extend;
{
RuleScore<M> score(m, extend);
Term("looking");
score.NonTerminal(base, -1.509559 -0.4771212 -1.206319);
SLOPPY_CHECK_CLOSE(-3.91039, score.Finish(), 0.001);
}
BOOST_CHECK_EQUAL(2, extend.left.length);
BOOST_CHECK_EQUAL(1, extend.right.length);
VCheck("more", extend.right.words[0]);
BOOST_CHECK(extend.left.full);
ChartState tobos;
{
RuleScore<M> score(m, tobos);
score.BeginSentence();
score.NonTerminal(extend, -3.91039);
SLOPPY_CHECK_CLOSE(-3.471169, score.Finish(), 0.001);
}
BOOST_CHECK_EQUAL(0, tobos.left.length);
BOOST_CHECK_EQUAL(1, tobos.right.length);
}
template <class M> float LeftToRight(const M &m, const std::vector<WordIndex> &words, bool begin_sentence = false) {
float ret = 0.0;
State right = begin_sentence ? m.BeginSentenceState() : m.NullContextState();
for (std::vector<WordIndex>::const_iterator i = words.begin(); i != words.end(); ++i) {
State copy(right);
ret += m.Score(copy, *i, right);
}
return ret;
}
template <class M> float RightToLeft(const M &m, const std::vector<WordIndex> &words, bool begin_sentence = false) {
float ret = 0.0;
ChartState state;
state.left.length = 0;
state.right.length = 0;
state.left.full = false;
for (std::vector<WordIndex>::const_reverse_iterator i = words.rbegin(); i != words.rend(); ++i) {
ChartState copy(state);
RuleScore<M> score(m, state);
score.Terminal(*i);
score.NonTerminal(copy, ret);
ret = score.Finish();
}
if (begin_sentence) {
ChartState copy(state);
RuleScore<M> score(m, state);
score.BeginSentence();
score.NonTerminal(copy, ret);
ret = score.Finish();
}
return ret;
}
template <class M> float TreeMiddle(const M &m, const std::vector<WordIndex> &words, bool begin_sentence = false) {
std::vector<std::pair<ChartState, float> > states(words.size());
for (unsigned int i = 0; i < words.size(); ++i) {
RuleScore<M> score(m, states[i].first);
score.Terminal(words[i]);
states[i].second = score.Finish();
}
while (states.size() > 1) {
std::vector<std::pair<ChartState, float> > upper((states.size() + 1) / 2);
for (unsigned int i = 0; i < states.size() / 2; ++i) {
RuleScore<M> score(m, upper[i].first);
score.NonTerminal(states[i*2].first, states[i*2].second);
score.NonTerminal(states[i*2+1].first, states[i*2+1].second);
upper[i].second = score.Finish();
}
if (states.size() % 2) {
upper.back() = states.back();
}
std::swap(states, upper);
}
if (states.empty()) return 0.0;
if (begin_sentence) {
ChartState ignored;
RuleScore<M> score(m, ignored);
score.BeginSentence();
score.NonTerminal(states.front().first, states.front().second);
return score.Finish();
} else {
return states.front().second;
}
}
template <class M> void LookupVocab(const M &m, const StringPiece &str, std::vector<WordIndex> &out) {
out.clear();
for (util::TokenIter<util::SingleCharacter, true> i(str, ' '); i; ++i) {
out.push_back(m.GetVocabulary().Index(*i));
}
}
// Check that right-to-left and tree-structured scoring agree with plain
// left-to-right scoring for the given sentence fragment.
// Fix: the last line previously ended with a trailing backslash, splicing the
// next source line into the macro definition.
#define TEXT_TEST(str) \
LookupVocab(m, str, words); \
expect = LeftToRight(m, words, rest); \
SLOPPY_CHECK_CLOSE(expect, RightToLeft(m, words, rest), 0.001); \
SLOPPY_CHECK_CLOSE(expect, TreeMiddle(m, words, rest), 0.001);
// Build sentences, or parts thereof, from right to left.
// Cross-check scoring strategies on a long sentence and progressively
// shorter fragments of it.
template <class M> void GrowBig(const M &m, bool rest = false) {
std::vector<WordIndex> words;
float expect;
TEXT_TEST("in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>");
TEXT_TEST("on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>");
TEXT_TEST("on a little more loin also would consider higher to look good");
TEXT_TEST("more loin also would consider higher to look good");
TEXT_TEST("more loin also would consider higher to look");
TEXT_TEST("also would consider higher to look");
TEXT_TEST("also would consider higher");
TEXT_TEST("would consider higher to look");
TEXT_TEST("consider higher to look");
TEXT_TEST("consider higher to");
TEXT_TEST("consider higher");
}
// Smaller variant of GrowBig for models with less training data.
template <class M> void GrowSmall(const M &m, bool rest = false) {
std::vector<WordIndex> words;
float expect;
TEXT_TEST("in biarritz watching considering looking . </s>");
TEXT_TEST("in biarritz watching considering looking .");
TEXT_TEST("in biarritz");
}
// Step-by-step chart construction for "also would consider higher": build
// each word's state, combine them pairwise, and pin the expected scores and
// the left/right state lengths at every step. The constants are hand-derived
// from the test ARPA file.
template <class M> void AlsoWouldConsiderHigher(const M &m) {
  ChartState also;
  {
    RuleScore<M> score(m, also);
    score.Terminal(m.GetVocabulary().Index("also"));
    SLOPPY_CHECK_CLOSE(-1.687872, score.Finish(), 0.001);
  }
  ChartState would;
  {
    RuleScore<M> score(m, would);
    score.Terminal(m.GetVocabulary().Index("would"));
    SLOPPY_CHECK_CLOSE(-1.687872, score.Finish(), 0.001);
  }
  // Combining two single-word nonterminals must score the same as the
  // two-terminal rule below.
  ChartState combine_also_would;
  {
    RuleScore<M> score(m, combine_also_would);
    score.NonTerminal(also, -1.687872);
    score.NonTerminal(would, -1.687872);
    SLOPPY_CHECK_CLOSE(-1.687872 - 2.0, score.Finish(), 0.001);
  }
  BOOST_CHECK_EQUAL(2, combine_also_would.right.length);
  ChartState also_would;
  {
    RuleScore<M> score(m, also_would);
    score.Terminal(m.GetVocabulary().Index("also"));
    score.Terminal(m.GetVocabulary().Index("would"));
    SLOPPY_CHECK_CLOSE(-1.687872 - 2.0, score.Finish(), 0.001);
  }
  BOOST_CHECK_EQUAL(2, also_would.right.length);
  ChartState consider;
  {
    RuleScore<M> score(m, consider);
    score.Terminal(m.GetVocabulary().Index("consider"));
    SLOPPY_CHECK_CLOSE(-1.687872, score.Finish(), 0.001);
  }
  BOOST_CHECK_EQUAL(1, consider.left.length);
  BOOST_CHECK_EQUAL(1, consider.right.length);
  BOOST_CHECK(!consider.left.full);
  ChartState higher;
  float higher_score;
  {
    RuleScore<M> score(m, higher);
    score.Terminal(m.GetVocabulary().Index("higher"));
    higher_score = score.Finish();
  }
  SLOPPY_CHECK_CLOSE(-1.509559, higher_score, 0.001);
  BOOST_CHECK_EQUAL(1, higher.left.length);
  BOOST_CHECK_EQUAL(1, higher.right.length);
  BOOST_CHECK(!higher.left.full);
  VCheck("higher", higher.right.words[0]);
  SLOPPY_CHECK_CLOSE(-0.30103, higher.right.backoff[0], 0.001);
  ChartState consider_higher;
  {
    RuleScore<M> score(m, consider_higher);
    score.NonTerminal(consider, -1.687872);
    score.NonTerminal(higher, higher_score);
    SLOPPY_CHECK_CLOSE(-1.509559 - 1.687872 - 0.30103, score.Finish(), 0.001);
  }
  BOOST_CHECK_EQUAL(2, consider_higher.left.length);
  BOOST_CHECK(!consider_higher.left.full);
  // Final combination covering all four words.
  ChartState full;
  {
    RuleScore<M> score(m, full);
    score.NonTerminal(combine_also_would, -1.687872 - 2.0);
    score.NonTerminal(consider_higher, -1.509559 - 1.687872 - 0.30103);
    SLOPPY_CHECK_CLOSE(-10.6879, score.Finish(), 0.001);
  }
  BOOST_CHECK_EQUAL(4, full.right.length);
}
// Assert that chart score `val` for the phrase `str` matches an independent
// left-to-right rescoring of the same words. Relies on `m` at the expansion
// site; braces keep the locals out of the caller's scope.
#define CHECK_SCORE(str, val) \
{ \
  float got = val; \
  std::vector<WordIndex> indices; \
  LookupVocab(m, str, indices); \
  SLOPPY_CHECK_CLOSE(LeftToRight(m, indices), got, 0.001); \
}
// Bottom-up binary combination of a full sentence: score each of the seven
// words, then pair them level by level (l1, l2, top), checking every partial
// result against left-to-right scoring.
template <class M> void FullGrow(const M &m) {
  std::vector<WordIndex> words;
  LookupVocab(m, "in biarritz watching considering looking . </s>", words);
  // Leaf level: one state per word.
  ChartState lexical[7];
  float lexical_scores[7];
  for (unsigned int i = 0; i < 7; ++i) {
    RuleScore<M> score(m, lexical[i]);
    score.Terminal(words[i]);
    lexical_scores[i] = score.Finish();
  }
  CHECK_SCORE("in", lexical_scores[0]);
  CHECK_SCORE("biarritz", lexical_scores[1]);
  CHECK_SCORE("watching", lexical_scores[2]);
  CHECK_SCORE("</s>", lexical_scores[6]);
  // First combination level: adjacent word pairs.
  ChartState l1[4];
  float l1_scores[4];
  {
    RuleScore<M> score(m, l1[0]);
    score.NonTerminal(lexical[0], lexical_scores[0]);
    score.NonTerminal(lexical[1], lexical_scores[1]);
    CHECK_SCORE("in biarritz", l1_scores[0] = score.Finish());
  }
  {
    RuleScore<M> score(m, l1[1]);
    score.NonTerminal(lexical[2], lexical_scores[2]);
    score.NonTerminal(lexical[3], lexical_scores[3]);
    CHECK_SCORE("watching considering", l1_scores[1] = score.Finish());
  }
  {
    RuleScore<M> score(m, l1[2]);
    score.NonTerminal(lexical[4], lexical_scores[4]);
    score.NonTerminal(lexical[5], lexical_scores[5]);
    CHECK_SCORE("looking .", l1_scores[2] = score.Finish());
  }
  BOOST_CHECK_EQUAL(l1[2].left.length, 1);
  // Odd element carries over unchanged.
  l1[3] = lexical[6];
  l1_scores[3] = lexical_scores[6];
  // Second combination level.
  ChartState l2[2];
  float l2_scores[2];
  {
    RuleScore<M> score(m, l2[0]);
    score.NonTerminal(l1[0], l1_scores[0]);
    score.NonTerminal(l1[1], l1_scores[1]);
    CHECK_SCORE("in biarritz watching considering", l2_scores[0] = score.Finish());
  }
  {
    RuleScore<M> score(m, l2[1]);
    score.NonTerminal(l1[2], l1_scores[2]);
    score.NonTerminal(l1[3], l1_scores[3]);
    CHECK_SCORE("looking . </s>", l2_scores[1] = score.Finish());
  }
  BOOST_CHECK_EQUAL(l2[1].left.length, 1);
  BOOST_CHECK(l2[1].left.full);
  // Root: entire sentence.
  ChartState top;
  {
    RuleScore<M> score(m, top);
    score.NonTerminal(l2[0], l2_scores[0]);
    score.NonTerminal(l2[1], l2_scores[1]);
    CHECK_SCORE("in biarritz watching considering looking . </s>", score.Finish());
  }
}
// Path of the test ARPA model: the first command-line argument when one was
// supplied, otherwise "test.arpa" in the working directory.
const char *FileLocation() {
  if (boost::unit_test::framework::master_test_suite().argc >= 2) {
    return boost::unit_test::framework::master_test_suite().argv[1];
  }
  return "test.arpa";
}
// Load the test model once as type M and run every generic suite against it.
template <class M> void Everything() {
  Config config;
  // Suppress progress/warning output during tests.
  config.messages = NULL;
  M m(FileLocation(), config);
  Short(m);
  Charge(m);
  GrowBig(m);
  AlsoWouldConsiderHigher(m);
  GrowSmall(m);
  FullGrow(m);
}
// Instantiate the full battery for each shipped model representation.
BOOST_AUTO_TEST_CASE(ProbingAll) {
  Everything<Model>();
}
BOOST_AUTO_TEST_CASE(TrieAll) {
  Everything<TrieModel>();
}
BOOST_AUTO_TEST_CASE(QuantTrieAll) {
  Everything<QuantTrieModel>();
}
BOOST_AUTO_TEST_CASE(ArrayQuantTrieAll) {
  Everything<QuantArrayTrieModel>();
}
BOOST_AUTO_TEST_CASE(ArrayTrieAll) {
  Everything<ArrayTrieModel>();
}
// Rest-cost scoring path: only run against RestProbingModel, with rest=true.
BOOST_AUTO_TEST_CASE(RestProbing) {
  Config config;
  config.messages = NULL;
  RestProbingModel m(FileLocation(), config);
  GrowBig(m, true);
}
} // namespace
} // namespace ngram
} // namespace lm
#include "lm_exception.hh"
#include <cerrno>
#include <cstdio>
namespace lm {

// Out-of-line, empty constructor/destructor bodies for the exception
// hierarchy declared in lm_exception.hh; keeping them in one translation
// unit avoids emitting them in every includer.
ConfigException::ConfigException() throw() {}
ConfigException::~ConfigException() throw() {}

LoadException::LoadException() throw() {}
LoadException::~LoadException() throw() {}

FormatLoadException::FormatLoadException() throw() {}
FormatLoadException::~FormatLoadException() throw() {}

VocabLoadException::VocabLoadException() throw() {}
VocabLoadException::~VocabLoadException() throw() {}

SpecialWordMissingException::SpecialWordMissingException() throw() {}
SpecialWordMissingException::~SpecialWordMissingException() throw() {}

} // namespace lm
#ifndef LM_LM_EXCEPTION_H
#define LM_LM_EXCEPTION_H
// Named to avoid conflict with util/exception.hh.
#include "../util/exception.hh"
#include "../util/string_piece.hh"
#include <exception>
#include <string>
namespace lm {

// What to do about a recoverable problem: throw, print a warning, or stay
// silent.
typedef enum {THROW_UP, COMPLAIN, SILENT} WarningAction;

// Invalid values in the Config supplied by the caller.
class ConfigException : public util::Exception {
  public:
    ConfigException() throw();
    ~ConfigException() throw();
};

// Base for errors while loading a model. Constructor is protected: only the
// concrete subclasses below are thrown.
class LoadException : public util::Exception {
   public:
      virtual ~LoadException() throw();

   protected:
      LoadException() throw();
};

// The file's bytes do not match the expected ARPA/binary format.
class FormatLoadException : public LoadException {
  public:
    FormatLoadException() throw();
    ~FormatLoadException() throw();
};

// A problem with the vocabulary itself, e.g. a required word is absent.
class VocabLoadException : public LoadException {
  public:
    virtual ~VocabLoadException() throw();
    VocabLoadException() throw();
};

// A special token (such as <unk> or <s>) is missing from the vocabulary.
class SpecialWordMissingException : public VocabLoadException {
  public:
    explicit SpecialWordMissingException() throw();
    ~SpecialWordMissingException() throw();
};

} // namespace lm
#endif // LM_LM_EXCEPTION_H
#ifndef LM_MAX_ORDER_H
#define LM_MAX_ORDER_H
/* IF YOUR BUILD SYSTEM PASSES -DKENLM_MAX_ORDER, THEN CHANGE THE BUILD SYSTEM.
 * If not, this is the default maximum order.
 * Having this limit means that State can be
 * (kMaxOrder - 1) * sizeof(float) bytes instead of
 * sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
 */
#define KENLM_MAX_ORDER 10
// Shown to the user when a model's order exceeds the compile-time limit.
#ifndef KENLM_ORDER_MESSAGE
#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. With cmake:\n cmake -DKENLM_MAX_ORDER=10 ..\nWith Moses:\n bjam --max-kenlm-order=10 -a\nOtherwise, edit lm/max_order.hh."
#endif
#endif // LM_MAX_ORDER_H
#include "model.hh"
#include "blank.hh"
#include "lm_exception.hh"
#include "search_hashed.hh"
#include "search_trie.hh"
#include "read_arpa.hh"
#include "../util/have.hh"
#include "../util/murmur_hash.hh"
#include <algorithm>
#include <functional>
#include <numeric>
#include <cmath>
#include <limits>
namespace lm {
namespace ngram {
namespace detail {
// The search strategy determines which binary model type this instantiation
// reads and writes.
template <class Search, class VocabularyT> const ModelType GenericModel<Search, VocabularyT>::kModelType = Search::kModelType;

// Total bytes of mapped memory required: vocabulary table followed by the
// search data structures.
template <class Search, class VocabularyT> uint64_t GenericModel<Search, VocabularyT>::Size(const std::vector<uint64_t> &counts, const Config &config) {
  return VocabularyT::Size(counts[0], config) + Search::Size(counts, config);
}
// Carve the single mapped region at base into the vocabulary followed by the
// search structure, and verify the layout consumed exactly Size() bytes.
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::SetupMemory(void *base, const std::vector<uint64_t> &counts, const Config &config) {
  size_t goal_size = util::CheckOverflow(Size(counts, config));
  uint8_t *start = static_cast<uint8_t*>(base);
  // Vocabulary comes first in the layout.
  size_t allocated = VocabularyT::Size(counts[0], config);
  vocab_.SetupMemory(start, allocated, counts[0], config);
  start += allocated;
  start = search_.SetupMemory(start, counts, config);
  if (static_cast<std::size_t>(start - static_cast<uint8_t*>(base)) != goal_size) UTIL_THROW(FormatLoadException, "The data structures took " << (start - static_cast<uint8_t*>(base)) << " but Size says they should take " << goal_size);
}
namespace {
// Advise the user (on config.messages) that loading ARPA text is slower than
// a binary file. Silent when a binary is being written anyway or when
// messages is NULL.
void ComplainAboutARPA(const Config &config, ModelType model_type) {
  if (config.write_mmap) return;
  if (!config.messages) return;
  const bool expensive_type =
      model_type == TRIE || model_type == QUANT_TRIE ||
      model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE;
  if (config.arpa_complain == Config::ALL) {
    *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl;
  } else if (config.arpa_complain == Config::EXPENSIVE && expensive_type) {
    *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl;
  }
}
// Reject models whose order exceeds the compile-time limit, and where size_t
// is narrower than uint64_t, reject n-gram counts that do not fit in size_t.
void CheckCounts(const std::vector<uint64_t> &counts) {
  UTIL_THROW_IF(counts.size() > KENLM_MAX_ORDER, FormatLoadException, "This model has order " << counts.size() << " but KenLM was compiled to support up to " << KENLM_MAX_ORDER << ". " << KENLM_ORDER_MESSAGE);
  // Compile-time constant condition: the loop is dead code on 64-bit builds.
  if (sizeof(uint64_t) > sizeof(std::size_t)) {
    for (std::vector<uint64_t>::const_iterator i = counts.begin(); i != counts.end(); ++i) {
      UTIL_THROW_IF(*i > static_cast<uint64_t>(std::numeric_limits<size_t>::max()), util::OverflowException, "This model has " << *i << " " << (i - counts.begin() + 1) << "-grams which is too many for 32-bit machines.");
    }
  }
}
} // namespace
// Load a model from `file`, autodetecting whether it is in this class's
// binary format or ARPA text. Throws FormatLoadException on mismatched
// binaries and ConfigException on bad settings (see InitializeFromARPA).
template <class Search, class VocabularyT> GenericModel<Search, VocabularyT>::GenericModel(const char *file, const Config &init_config) : backing_(init_config) {
  util::scoped_fd fd(util::OpenReadOrThrow(file));
  if (IsBinaryFormat(fd.get())) {
    Parameters parameters;
    // Ownership of the descriptor passes to backing_ / vocab_ below.
    int fd_shallow = fd.release();
    backing_.InitializeBinary(fd_shallow, kModelType, kVersion, parameters);
    CheckCounts(parameters.counts);

    // Settings baked into the binary (e.g. probing multiplier) override the
    // caller's values.
    Config new_config(init_config);
    new_config.probing_multiplier = parameters.fixed.probing_multiplier;
    Search::UpdateConfigFromBinary(backing_, parameters.counts, VocabularyT::Size(parameters.counts[0], new_config), new_config);
    UTIL_THROW_IF(new_config.enumerate_vocab && !parameters.fixed.has_vocabulary, FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them.  You may need to rebuild the binary file with an updated version of build_binary.");

    SetupMemory(backing_.LoadBinary(Size(parameters.counts, new_config)), parameters.counts, new_config);
    vocab_.LoadedBinary(parameters.fixed.has_vocabulary, fd_shallow, new_config.enumerate_vocab, backing_.VocabStringReadingOffset());
  } else {
    ComplainAboutARPA(init_config, kModelType);
    InitializeFromARPA(fd.release(), file, init_config);
  }

  // g++ prints warnings unless these are fully initialized.
  State begin_sentence = State();
  begin_sentence.length = 1;
  begin_sentence.words[0] = vocab_.BeginSentence();
  // Only the unigram backoff of <s> is needed; the other lookup outputs are
  // discarded.
  typename Search::Node ignored_node;
  bool ignored_independent_left;
  uint64_t ignored_extend_left;
  begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff();
  State null_context = State();
  null_context.length = 0;
  P::Init(begin_sentence, null_context, vocab_, search_.Order());
}
// Parse an ARPA text file (already opened as fd) and populate vocab_ and
// search_, optionally writing the binary format as a side effect when
// config.write_mmap is set. Any parse error is annotated with the byte
// offset where it occurred.
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::InitializeFromARPA(int fd, const char *file, const Config &config) {
  // Backing file is the ARPA.
  util::FilePiece f(fd, file, config.ProgressMessages());
  try {
    std::vector<uint64_t> counts;
    // File counts do not include pruned trigrams that extend to quadgrams etc.  These will be fixed by search_.
    ReadARPACounts(f, counts);
    CheckCounts(counts);
    if (counts.size() < 2) UTIL_THROW(FormatLoadException, "This ngram implementation assumes at least a bigram model.");
    if (config.probing_multiplier <= 1.0) UTIL_THROW(ConfigException, "probing multiplier must be > 1.0");

    std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config));
    // Setup the binary file for writing the vocab lookup table.  The search_ is responsible for growing the binary file to its needs.
    vocab_.SetupMemory(backing_.SetupJustVocab(vocab_size, counts.size()), vocab_size, counts[0], config);

    if (config.write_mmap && config.include_vocab) {
      // Capture the word strings while parsing so they can be appended to
      // the binary afterwards.
      WriteWordsWrapper wrap(config.enumerate_vocab);
      vocab_.ConfigureEnumerate(&wrap, counts[0]);
      search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
      void *vocab_rebase, *search_rebase;
      backing_.WriteVocabWords(wrap.Buffer(), vocab_rebase, search_rebase);
      // Due to writing at the end of file, mmap may have relocated data.  So remap.
      vocab_.Relocate(vocab_rebase);
      search_.SetupMemory(reinterpret_cast<uint8_t*>(search_rebase), counts, config);
    } else {
      vocab_.ConfigureEnumerate(config.enumerate_vocab, counts[0]);
      search_.InitializeFromARPA(file, f, counts, config, vocab_, backing_);
    }

    if (!vocab_.SawUnk()) {
      assert(config.unknown_missing != THROW_UP);
      // Default probabilities for unknown.
      search_.UnknownUnigram().backoff = 0.0;
      search_.UnknownUnigram().prob = config.unknown_missing_logprob;
    }
    backing_.FinishFile(config, kModelType, kVersion, counts);
  } catch (util::Exception &e) {
    e << " Byte: " << f.Offset();
    throw;
  }
}
// Score p(new_word | in_state) and fill out_state. ScoreExceptBackoff finds
// the longest matching n-gram; backoff weights for the remaining context
// positions are then charged.
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScore(const State &in_state, const WordIndex new_word, State &out_state) const {
  FullScoreReturn ret = ScoreExceptBackoff(in_state.words, in_state.words + in_state.length, new_word, out_state);
  // Charge backoff for context positions [ngram_length - 1, in_state.length).
  for (unsigned char pos = ret.ngram_length - 1; pos < in_state.length; ++pos) {
    ret.prob += in_state.backoff[pos];
  }
  return ret;
}
// Score without a precomputed State: the reversed context array replaces
// in_state, so the backoffs the context would have carried must be looked up
// here.
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
  // Only up to Order() - 1 context words can matter.
  context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
  FullScoreReturn ret = ScoreExceptBackoff(context_rbegin, context_rend, new_word, out_state);

  // Add the backoff weights for n-grams of order start to (context_rend - context_rbegin).
  unsigned char start = ret.ngram_length;
  if (context_rend - context_rbegin < static_cast<std::ptrdiff_t>(start)) return ret;

  bool independent_left;
  uint64_t extend_left;
  typename Search::Node node;
  if (start <= 1) {
    ret.prob += search_.LookupUnigram(*context_rbegin, node, independent_left, extend_left).Backoff();
    start = 2;
  } else if (!search_.FastMakeNode(context_rbegin, context_rbegin + start - 1, node)) {
    // The context prefix itself is not in the model; no more backoffs exist.
    return ret;
  }
  // i is the order of the backoff we're looking for.
  unsigned char order_minus_2 = start - 2;
  for (const WordIndex *i = context_rbegin + start - 1; i < context_rend; ++i, ++order_minus_2) {
    typename Search::MiddlePointer p(search_.LookupMiddle(order_minus_2, *i, node, independent_left, extend_left));
    if (!p.Found()) break;
    ret.prob += p.Backoff();
  }
  return ret;
}
// Build a State for an arbitrary reversed context: record each context
// word's backoff and keep only as much context as can still extend to
// longer n-grams.
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const {
  // Generate a state from context.
  context_rend = std::min(context_rend, context_rbegin + P::Order() - 1);
  if (context_rend == context_rbegin) {
    out_state.length = 0;
    return;
  }
  typename Search::Node node;
  bool independent_left;
  uint64_t extend_left;
  out_state.backoff[0] = search_.LookupUnigram(*context_rbegin, node, independent_left, extend_left).Backoff();
  // length tracks the longest prefix whose backoff indicates a possible
  // right extension.
  out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
  float *backoff_out = out_state.backoff + 1;
  unsigned char order_minus_2 = 0;
  for (const WordIndex *i = context_rbegin + 1; i < context_rend; ++i, ++backoff_out, ++order_minus_2) {
    typename Search::MiddlePointer p(search_.LookupMiddle(order_minus_2, *i, node, independent_left, extend_left));
    if (!p.Found()) {
      // Longer context is not in the model; keep what was found so far.
      std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words);
      return;
    }
    *backoff_out = p.Backoff();
    if (HasExtension(*backoff_out)) out_state.length = i - context_rbegin + 1;
  }
  std::copy(context_rbegin, context_rbegin + out_state.length, out_state.words);
}
// Resume scoring where a previous query stopped: extend_pointer/extend_length
// identify the n-gram already matched; [add_rbegin, add_rend) is additional
// reversed context to try extending with. Returned prob/rest are deltas
// relative to the rest cost previously returned (see declaration in
// model.hh).
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ExtendLeft(
    const WordIndex *add_rbegin, const WordIndex *add_rend,
    const float *backoff_in,
    uint64_t extend_pointer,
    unsigned char extend_length,
    float *backoff_out,
    unsigned char &next_use) const {
  FullScoreReturn ret;
  typename Search::Node node;
  if (extend_length == 1) {
    // A unigram pointer is just the word id.
    typename Search::UnigramPointer ptr(search_.LookupUnigram(static_cast<WordIndex>(extend_pointer), node, ret.independent_left, ret.extend_left));
    ret.rest = ptr.Rest();
    ret.prob = ptr.Prob();
    assert(!ret.independent_left);
  } else {
    typename Search::MiddlePointer ptr(search_.Unpack(extend_pointer, extend_length, node));
    ret.rest = ptr.Rest();
    ret.prob = ptr.Prob();
    ret.extend_left = extend_pointer;
    // If this function is called, then it does depend on left words.
    ret.independent_left = false;
  }
  // The previously returned rest cost is subtracted so the result is a delta.
  float subtract_me = ret.rest;
  ret.ngram_length = extend_length;
  next_use = extend_length;
  ResumeScore(add_rbegin, add_rend, extend_length - 1, node, backoff_out, next_use, ret);
  // next_use is reported relative to the already-consumed extend_length.
  next_use -= extend_length;
  // Charge backoffs.
  for (const float *b = backoff_in + ret.ngram_length - extend_length; b < backoff_in + (add_rend - add_rbegin); ++b) ret.prob += *b;
  ret.prob -= subtract_me;
  ret.rest -= subtract_me;
  return ret;
}
namespace {
// Do a paraonoid copy of history, assuming new_word has already been copied
// (hence the -1). out_state.length could be zero so I avoided using
// std::copy.
void CopyRemainingHistory(const WordIndex *from, State &out_state) {
WordIndex *out = out_state.words + 1;
const WordIndex *in_end = from + static_cast<ptrdiff_t>(out_state.length) - 1;
for (const WordIndex *in = from; in < in_end; ++in, ++out) *out = *in;
}
} // namespace
/* Ugly optimized function.  Produce a score excluding backoff.
 * The search goes in increasing order of ngram length.
 * Context goes backward, so context_rbegin is the word immediately preceding
 * new_word.
 */
template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search, VocabularyT>::ScoreExceptBackoff(
    const WordIndex *const context_rbegin,
    const WordIndex *const context_rend,
    const WordIndex new_word,
    State &out_state) const {
  assert(new_word < vocab_.Bound());
  FullScoreReturn ret;
  // ret.ngram_length contains the last known non-blank ngram length.
  ret.ngram_length = 1;

  typename Search::Node node;
  typename Search::UnigramPointer uni(search_.LookupUnigram(new_word, node, ret.independent_left, ret.extend_left));
  out_state.backoff[0] = uni.Backoff();
  ret.prob = uni.Prob();
  ret.rest = uni.Rest();

  // This is the length of the context that should be used for continuation to the right.
  out_state.length = HasExtension(out_state.backoff[0]) ? 1 : 0;
  // We'll write the word anyway since it will probably be used and does no harm being there.
  out_state.words[0] = new_word;
  if (context_rbegin == context_rend) return ret;

  // Extend the match leftward through the context, then copy the surviving
  // history into out_state.
  ResumeScore(context_rbegin, context_rend, 0, node, out_state.backoff + 1, out_state.length, ret);
  CopyRemainingHistory(context_rbegin, out_state);
  return ret;
}
// Score bigrams and above, extending an existing match through
// [hist_iter, context_rend). Each longer match REPLACES prob/rest (the
// stored values are full probabilities, not deltas). Stops when the context
// is exhausted, the search says longer matches are impossible
// (independent_left), a lookup misses, or the maximum order is reached --
// in which case the longest-order table is consulted.
template <class Search, class VocabularyT> void GenericModel<Search, VocabularyT>::ResumeScore(const WordIndex *hist_iter, const WordIndex *const context_rend, unsigned char order_minus_2, typename Search::Node &node, float *backoff_out, unsigned char &next_use, FullScoreReturn &ret) const {
  for (; ; ++order_minus_2, ++hist_iter, ++backoff_out) {
    if (hist_iter == context_rend) return;
    if (ret.independent_left) return;
    // Reached (Order() - 1)-grams; drop to the longest-order lookup below.
    if (order_minus_2 == P::Order() - 2) break;

    typename Search::MiddlePointer pointer(search_.LookupMiddle(order_minus_2, *hist_iter, node, ret.independent_left, ret.extend_left));
    if (!pointer.Found()) return;
    *backoff_out = pointer.Backoff();
    ret.prob = pointer.Prob();
    ret.rest = pointer.Rest();
    ret.ngram_length = order_minus_2 + 2;
    if (HasExtension(*backoff_out)) {
      next_use = ret.ngram_length;
    }
  }
  // A full-order n-gram can never be extended left.
  ret.independent_left = true;
  typename Search::LongestPointer longest(search_.LookupLongest(*hist_iter, node));
  if (longest.Found()) {
    ret.prob = longest.Prob();
    ret.rest = ret.prob;
    // There is no blank in longest_.
    ret.ngram_length = P::Order();
  }
}
// Sum (prob - rest) over the n-grams identified by [pointers_begin,
// pointers_end), where first_length is the order of the first pointer.
// Used to convert an accumulated rest-cost score back to real probabilities.
template <class Search, class VocabularyT> float GenericModel<Search, VocabularyT>::InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const {
  float ret;
  typename Search::Node node;
  if (first_length == 1) {
    if (pointers_begin >= pointers_end) return 0.0;
    bool independent_left;
    uint64_t extend_left;
    // Unigram pointers are word ids and need a different lookup path.
    typename Search::UnigramPointer ptr(search_.LookupUnigram(static_cast<WordIndex>(*pointers_begin), node, independent_left, extend_left));
    ret = ptr.Prob() - ptr.Rest();
    ++first_length;
    ++pointers_begin;
  } else {
    ret = 0.0;
  }
  for (const uint64_t *i = pointers_begin; i < pointers_end; ++i, ++first_length) {
    typename Search::MiddlePointer ptr(search_.Unpack(*i, first_length, node));
    ret += ptr.Prob() - ptr.Rest();
  }
  return ret;
}
// Explicit instantiations for the six shipped model configurations; these
// keep the template definitions out of the header.
template class GenericModel<HashedSearch<BackoffValue>, ProbingVocabulary>;
template class GenericModel<HashedSearch<RestValue>, ProbingVocabulary>;
template class GenericModel<trie::TrieSearch<DontQuantize, trie::DontBhiksha>, SortedVocabulary>;
template class GenericModel<trie::TrieSearch<DontQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::DontBhiksha>, SortedVocabulary>;
template class GenericModel<trie::TrieSearch<SeparatelyQuantize, trie::ArrayBhiksha>, SortedVocabulary>;
} // namespace detail
// Autodetect the file type and construct the matching concrete model behind
// the virtual base class. For an ARPA file, RecognizeBinary leaves
// model_type at the caller-supplied default. The caller owns the returned
// pointer.
base::Model *LoadVirtual(const char *file_name, const Config &config, ModelType model_type) {
  RecognizeBinary(file_name, model_type);
  switch (model_type) {
    case PROBING:
      return new ProbingModel(file_name, config);
    case REST_PROBING:
      return new RestProbingModel(file_name, config);
    case TRIE:
      return new TrieModel(file_name, config);
    case QUANT_TRIE:
      return new QuantTrieModel(file_name, config);
    case ARRAY_TRIE:
      return new ArrayTrieModel(file_name, config);
    case QUANT_ARRAY_TRIE:
      return new QuantArrayTrieModel(file_name, config);
    default:
      UTIL_THROW(FormatLoadException, "Confused by model type " << model_type);
  }
}
} // namespace ngram
} // namespace lm
#ifndef LM_MODEL_H
#define LM_MODEL_H
#include "bhiksha.hh"
#include "binary_format.hh"
#include "config.hh"
#include "facade.hh"
#include "quantize.hh"
#include "search_hashed.hh"
#include "search_trie.hh"
#include "state.hh"
#include "value.hh"
#include "vocab.hh"
#include "weights.hh"
#include "../util/murmur_hash.hh"
#include <algorithm>
#include <vector>
#include <cstring>
namespace util { class FilePiece; }
namespace lm {
namespace ngram {
namespace detail {
// Should return the same results as SRI.
// ModelFacade typedefs Vocabulary so we use VocabularyT to avoid naming conflicts.
// Should return the same results as SRI.
// ModelFacade typedefs Vocabulary so we use VocabularyT to avoid naming conflicts.
template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> {
  private:
    typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;
  public:
    // This is the model type returned by RecognizeBinary.
    static const ModelType kModelType;

    static const unsigned int kVersion = Search::kVersion;

    /* Get the size of memory that will be mapped given ngram counts.  This
     * does not include small non-mapped control structures, such as this class
     * itself.
     */
    static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());

    /* Load the model from a file.  It may be an ARPA or binary file.  Binary
     * files must have the format expected by this class or you'll get an
     * exception.  So TrieModel can only load ARPA or binary created by
     * TrieModel.  To classify binary files, call RecognizeBinary in
     * lm/binary_format.hh.
     */
    explicit GenericModel(const char *file, const Config &config = Config());

    /* Score p(new_word | in_state) and incorporate new_word into out_state.
     * Note that in_state and out_state must be different references:
     * &in_state != &out_state.
     */
    FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;

    /* Slower call without in_state.  Try to remember state, but sometimes it
     * would cost too much memory or your decoder isn't setup properly.
     * To use this function, make an array of WordIndex containing the context
     * vocabulary ids in reverse order.  Then, pass the bounds of the array:
     * [context_rbegin, context_rend).  The new_word is not part of the context
     * array unless you intend to repeat words.
     */
    FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;

    /* Get the state for a context.  Don't use this if you can avoid it.  Use
     * BeginSentenceState or NullContextState and extend from those.  If
     * you're only going to use this state to call FullScore once, use
     * FullScoreForgotState.
     * To use this function, make an array of WordIndex containing the context
     * vocabulary ids in reverse order.  Then, pass the bounds of the array:
     * [context_rbegin, context_rend).
     */
    void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const;

    /* More efficient version of FullScore where a partial n-gram has already
     * been scored.
     * NOTE: THE RETURNED .rest AND .prob ARE RELATIVE TO THE .rest RETURNED BEFORE.
     */
    FullScoreReturn ExtendLeft(
        // Additional context in reverse order.  This will update add_rend to
        const WordIndex *add_rbegin, const WordIndex *add_rend,
        // Backoff weights to use.
        const float *backoff_in,
        // extend_left returned by a previous query.
        uint64_t extend_pointer,
        // Length of n-gram that the pointer corresponds to.
        unsigned char extend_length,
        // Where to write additional backoffs for [extend_length + 1, min(Order() - 1, return.ngram_length)]
        float *backoff_out,
        // Amount of additional content that should be considered by the next call.
        unsigned char &next_use) const;

    /* Return probabilities minus rest costs for an array of pointers.  The
     * first length should be the length of the n-gram to which pointers_begin
     * points.
     */
    float UnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const {
      // Compiler should optimize this if away.
      return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0;
    }

  private:
    FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;

    // Score bigrams and above.  Do not include backoff.
    void ResumeScore(const WordIndex *context_rbegin, const WordIndex *const context_rend, unsigned char starting_order_minus_2, typename Search::Node &node, float *backoff_out, unsigned char &next_use, FullScoreReturn &ret) const;

    // Appears after Size in the cc file.
    void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);

    void InitializeFromARPA(int fd, const char *file, const Config &config);

    float InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const;

    // File/mmap management for the binary format.
    BinaryFormat backing_;

    // Word string -> id mapping.
    VocabularyT vocab_;

    // N-gram probability/backoff storage and lookup.
    Search search_;
};
} // namespace detail
// Instead of typedef, inherit. This allows the Model etc to be forward declared.
// Oh the joys of C and C++.
// LM_COMMA lets a comma appear inside a macro argument below.
#define LM_COMMA() ,
// Declare a named class deriving from a GenericModel instantiation so the
// concrete model types can be forward declared (a typedef could not be).
#define LM_NAME_MODEL(name, from)\
class name : public from {\
  public:\
    name(const char *file, const Config &config = Config()) : from(file, config) {}\
};

LM_NAME_MODEL(ProbingModel, detail::GenericModel<detail::HashedSearch<BackoffValue> LM_COMMA() ProbingVocabulary>);
LM_NAME_MODEL(RestProbingModel, detail::GenericModel<detail::HashedSearch<RestValue> LM_COMMA() ProbingVocabulary>);
LM_NAME_MODEL(TrieModel, detail::GenericModel<trie::TrieSearch<DontQuantize LM_COMMA() trie::DontBhiksha> LM_COMMA() SortedVocabulary>);
LM_NAME_MODEL(ArrayTrieModel, detail::GenericModel<trie::TrieSearch<DontQuantize LM_COMMA() trie::ArrayBhiksha> LM_COMMA() SortedVocabulary>);
LM_NAME_MODEL(QuantTrieModel, detail::GenericModel<trie::TrieSearch<SeparatelyQuantize LM_COMMA() trie::DontBhiksha> LM_COMMA() SortedVocabulary>);
LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel<trie::TrieSearch<SeparatelyQuantize LM_COMMA() trie::ArrayBhiksha> LM_COMMA() SortedVocabulary>);

// Default implementation.  No real reason for it to be the default.
typedef ::lm::ngram::ProbingVocabulary Vocabulary;
typedef ProbingModel Model;

/* Autorecognize the file type, load, and return the virtual base class.  Don't
 * use the virtual base class if you can avoid it.  Instead, use the above
 * classes as template arguments to your own virtual feature function.*/
base::Model *LoadVirtual(const char *file_name, const Config &config = Config(), ModelType if_arpa = PROBING);
} // namespace ngram
} // namespace lm
#endif // LM_MODEL_H
#include "model.hh"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <vector>
#define BOOST_TEST_MODULE ModelTest
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
// Apparently some Boost versions use templates and are pretty strict about types matching.
// Casting everything to double keeps BOOST_CHECK_CLOSE happy regardless of
// whether the operands are float or double.
#define SLOPPY_CHECK_CLOSE(ref, value, tol) BOOST_CHECK_CLOSE(static_cast<double>(ref), static_cast<double>(value), static_cast<double>(tol));
namespace lm {
namespace ngram {
// Render a State as its length followed by the context word ids, so failed
// BOOST_CHECK_EQUAL comparisons on states print something readable.
std::ostream &operator<<(std::ostream &o, const State &state) {
  o << "State length " << static_cast<unsigned int>(state.length) << ':';
  for (unsigned char idx = 0; idx < state.length; ++idx) {
    o << ' ' << state.words[idx];
  }
  return o;
}
namespace {
// Stupid bjam reverses the command line arguments randomly.
// Path of the ARPA file that contains <unk>. bjam may pass the two test
// files in either order, so pick whichever argument does NOT mention
// "nounk".
const char *TestLocation() {
  if (boost::unit_test::framework::master_test_suite().argc < 3) {
    return "test.arpa";
  }
  char **argv = boost::unit_test::framework::master_test_suite().argv;
  return argv[strstr(argv[1], "nounk") ? 2 : 1];
}
// Path of the ARPA file without <unk>: the mirror-image selection of
// TestLocation.
const char *TestNoUnkLocation() {
  if (boost::unit_test::framework::master_test_suite().argc < 3) {
    return "test_nounk.arpa";
  }
  char **argv = boost::unit_test::framework::master_test_suite().argv;
  return argv[strstr(argv[1], "nounk") ? 1 : 2];
}
// Recompute the state for (word + in's context) from scratch via
// Model::GetState, for cross-checking against the state FullScore builds
// incrementally.
template <class Model> State GetState(const Model &model, const char *word, const State &in) {
  // std::vector instead of `WordIndex context[in.length + 1]`: runtime-sized
  // arrays are a compiler extension, not ISO C++.
  std::vector<WordIndex> context(in.length + 1);
  context[0] = model.GetVocabulary().Index(word);
  std::copy(in.words, in.words + in.length, context.begin() + 1);
  State ret;
  model.GetState(&context[0], &context[0] + context.size(), ret);
  return ret;
}
// Score `word` from `state` into `out` and validate the probability, the
// matched n-gram length, that out's length does not exceed what the match
// allows, the independent_left flag, and that the incrementally built state
// equals one recomputed from scratch via GetState.
#define StartTest(word, ngram, score, indep_left) \
  ret = model.FullScore( \
      state, \
      model.GetVocabulary().Index(word), \
      out);\
  SLOPPY_CHECK_CLOSE(score, ret.prob, 0.001); \
  BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
  BOOST_CHECK_GE(std::min<unsigned char>(ngram, 5 - 1), out.length); \
  BOOST_CHECK_EQUAL(indep_left, ret.independent_left); \
  BOOST_CHECK_EQUAL(out, GetState(model, word, state));

// Like StartTest, but also advances state so successive calls walk left to
// right through a sentence.
#define AppendTest(word, ngram, score, indep_left) \
  StartTest(word, ngram, score, indep_left) \
  state = out;
// Words scored directly after <s>. Note Model::State (the shared ngram
// State type) is used rather than typename M::State.
template <class M> void Starters(const M &model) {
  FullScoreReturn ret;
  Model::State state(model.BeginSentenceState());
  Model::State out;

  StartTest("looking", 2, -0.4846522, true);

  // , probability plus <s> backoff
  StartTest(",", 1, -1.383514 + -0.4149733, true);
  // <unk> probability plus <s> backoff
  StartTest("this_is_not_found", 1, -1.995635 + -0.4149733, true);
}
// Chains words with AppendTest (which advances state) to exercise backoff
// through a full sentence, then rewinds to a saved state and extends it a
// different way.  Expected values come from the test ARPA file.
template <class M> void Continuation(const M &model) {
FullScoreReturn ret;
Model::State state(model.BeginSentenceState());
Model::State out;
AppendTest("looking", 2, -0.484652, true);
AppendTest("on", 3, -0.348837, true);
AppendTest("a", 4, -0.0155266, true);
AppendTest("little", 5, -0.00306122, true);
// Save "<s> looking on a little" so we can branch two continuations from it.
State preserve = state;
AppendTest("the", 1, -4.04005, true);
AppendTest("biarritz", 1, -1.9889, true);
AppendTest("not_found", 1, -2.29666, true);
AppendTest("more", 1, -1.20632 - 20.0, true);
AppendTest(".", 2, -0.51363, true);
AppendTest("</s>", 3, -0.0191651, true);
// After </s> no context survives.
BOOST_CHECK_EQUAL(0, state.length);
state = preserve;
// The alternate continuation matches the full 5-gram.
AppendTest("more", 5, -0.00181395, true);
BOOST_CHECK_EQUAL(4, state.length);
AppendTest("loin", 5, -0.0432557, true);
BOOST_CHECK_EQUAL(1, state.length);
}
// Exercises "blank" n-grams: entries that exist only as context for longer
// n-grams (e.g. "higher looking" below).  Checks both probabilities and how
// much context the model keeps after scoring through a blank.
template <class M> void Blanks(const M &model) {
FullScoreReturn ret;
State state(model.NullContextState());
State out;
AppendTest("also", 1, -1.687872, false);
AppendTest("would", 2, -2, true);
AppendTest("consider", 3, -3, true);
// Save "also would consider" to branch two continuations from it.
State preserve = state;
AppendTest("higher", 4, -4, true);
AppendTest("looking", 5, -5, true);
BOOST_CHECK_EQUAL(1, state.length);
state = preserve;
// also would consider not_found
AppendTest("not_found", 1, -1.995635 - 7.0 - 0.30103, true);
state = model.NullContextState();
// higher looking is a blank.
AppendTest("higher", 1, -1.509559, false);
AppendTest("looking", 2, -1.285941 - 0.30103, false);
State higher_looking = state;
BOOST_CHECK_EQUAL(1, state.length);
AppendTest("not_found", 1, -1.995635 - 0.4771212, true);
state = higher_looking;
// higher looking consider
AppendTest("consider", 1, -1.687872 - 0.4771212, true);
state = model.NullContextState();
AppendTest("would", 1, -1.687872, false);
BOOST_CHECK_EQUAL(1, state.length);
AppendTest("consider", 2, -1.687872 -0.30103, false);
BOOST_CHECK_EQUAL(2, state.length);
AppendTest("higher", 3, -1.509559 - 0.30103, false);
BOOST_CHECK_EQUAL(3, state.length);
AppendTest("looking", 4, -1.285941 - 0.30103, false);
}
// Out-of-vocabulary handling: words absent from the ARPA map to <unk>, and
// subsequent scores back off through the <unk> context.
template <class M> void Unknowns(const M &model) {
FullScoreReturn ret;
State state(model.NullContextState());
State out;
AppendTest("not_found", 1, -1.995635, false);
// Save the <unk> context to try two different continuations.
State preserve = state;
AppendTest("not_found2", 2, -15.0, true);
AppendTest("not_found3", 2, -15.0 - 2.0, true);
state = preserve;
AppendTest("however", 2, -4, true);
AppendTest("not_found3", 3, -6, true);
}
// Checks that the model keeps only the minimal context needed for future
// scoring: state.length should shrink when longer context cannot matter.
template <class M> void MinimalState(const M &model) {
FullScoreReturn ret;
State state(model.NullContextState());
State out;
AppendTest("baz", 1, -6.535897, true);
// baz cannot extend anything, so no context is kept.
BOOST_CHECK_EQUAL(0, state.length);
state = model.NullContextState();
AppendTest("foo", 1, -3.141592, true);
BOOST_CHECK_EQUAL(1, state.length);
AppendTest("bar", 2, -6.0, true);
// Has to include the backoff weight.
BOOST_CHECK_EQUAL(1, state.length);
AppendTest("bar", 1, -2.718281 + 3.0, true);
BOOST_CHECK_EQUAL(1, state.length);
state = model.NullContextState();
AppendTest("to", 1, -1.687872, false);
AppendTest("look", 2, -0.2922095, true);
BOOST_CHECK_EQUAL(2, state.length);
AppendTest("a", 3, -7, true);
}
// Exercises Model::ExtendLeft, which extends an n-gram leftward given an
// extension pointer from a prior FullScore/ExtendLeft call.  Checks returned
// probabilities (relative to rest costs), backoff outputs, next_use counts,
// and that extending one word at a time agrees with extending two at once.
template <class M> void ExtendLeftTest(const M &model) {
State right;
FullScoreReturn little(model.FullScore(model.NullContextState(), model.GetVocabulary().Index("little"), right));
const float kLittleProb = -1.285941;
SLOPPY_CHECK_CLOSE(kLittleProb, little.prob, 0.001);
unsigned char next_use;
float backoff_out[4];
// Extending with no new words is a no-op: same extension pointer comes back.
FullScoreReturn extend_none(model.ExtendLeft(NULL, NULL, NULL, little.extend_left, 1, NULL, next_use));
BOOST_CHECK_EQUAL(0, next_use);
BOOST_CHECK_EQUAL(little.extend_left, extend_none.extend_left);
SLOPPY_CHECK_CLOSE(little.prob - little.rest, extend_none.prob, 0.001);
BOOST_CHECK_EQUAL(1, extend_none.ngram_length);
const WordIndex a = model.GetVocabulary().Index("a");
// Deliberately garbage: backoff_in should not be read for this call.
float backoff_in = 3.14;
// a little
FullScoreReturn extend_a(model.ExtendLeft(&a, &a + 1, &backoff_in, little.extend_left, 1, backoff_out, next_use));
BOOST_CHECK_EQUAL(1, next_use);
SLOPPY_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
// Probabilities are reported relative to the rest cost already charged.
SLOPPY_CHECK_CLOSE(-0.09132547 - little.rest, extend_a.prob, 0.001);
BOOST_CHECK_EQUAL(2, extend_a.ngram_length);
BOOST_CHECK(!extend_a.independent_left);
const WordIndex on = model.GetVocabulary().Index("on");
FullScoreReturn extend_on(model.ExtendLeft(&on, &on + 1, &backoff_in, extend_a.extend_left, 2, backoff_out, next_use));
BOOST_CHECK_EQUAL(1, next_use);
SLOPPY_CHECK_CLOSE(-0.4771212, backoff_out[0], 0.001);
SLOPPY_CHECK_CLOSE(-0.0283603 - (extend_a.rest + little.rest), extend_on.prob, 0.001);
BOOST_CHECK_EQUAL(3, extend_on.ngram_length);
BOOST_CHECK(!extend_on.independent_left);
// Now extend by both words at once; must match the word-at-a-time result.
const WordIndex both[2] = {a, on};
float backoff_in_arr[4];
FullScoreReturn extend_both(model.ExtendLeft(both, both + 2, backoff_in_arr, little.extend_left, 1, backoff_out, next_use));
BOOST_CHECK_EQUAL(2, next_use);
SLOPPY_CHECK_CLOSE(-0.69897, backoff_out[0], 0.001);
SLOPPY_CHECK_CLOSE(-0.4771212, backoff_out[1], 0.001);
SLOPPY_CHECK_CLOSE(-0.0283603 - little.rest, extend_both.prob, 0.001);
BOOST_CHECK_EQUAL(3, extend_both.ngram_length);
BOOST_CHECK(!extend_both.independent_left);
BOOST_CHECK_EQUAL(extend_on.extend_left, extend_both.extend_left);
}
// Score the `word`-th token (counting from the end of the `indices` array,
// which holds the sentence reversed) given `provide` words of context, via
// FullScoreForgotState; then rebuild the same context with GetState and
// check FullScore agrees on state, probability, and matched n-gram length.
#define StatelessTest(word, provide, ngram, score) \
ret = model.FullScoreForgotState(indices + num_words - word, indices + num_words - word + provide, indices[num_words - word - 1], state); \
SLOPPY_CHECK_CLOSE(score, ret.prob, 0.001); \
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length); \
model.GetState(indices + num_words - word, indices + num_words - word + provide, before); \
ret = model.FullScore(before, indices[num_words - word - 1], out); \
BOOST_CHECK(state == out); \
SLOPPY_CHECK_CLOSE(score, ret.prob, 0.001); \
BOOST_CHECK_EQUAL(static_cast<unsigned int>(ngram), ret.ngram_length);
// Tests FullScoreForgotState, which scores from raw reversed context arrays
// instead of carrying a State.  `indices` holds the sentence word ids in
// reverse so that context pointers read right-to-left.
template <class M> void Stateless(const M &model) {
const char *words[] = {"<s>", "looking", "on", "a", "little", "the", "biarritz", "not_found", "more", ".", "</s>"};
const size_t num_words = sizeof(words) / sizeof(const char*);
// Silence "array subscript is above array bounds" when extracting end pointer.
WordIndex indices[num_words + 1];
for (unsigned int i = 0; i < num_words; ++i) {
indices[num_words - 1 - i] = model.GetVocabulary().Index(words[i]);
}
FullScoreReturn ret;
State state, out, before;
// Score "looking" with only <s> as context.
ret = model.FullScoreForgotState(indices + num_words - 1, indices + num_words, indices[num_words - 2], state);
SLOPPY_CHECK_CLOSE(-0.484652, ret.prob, 0.001);
StatelessTest(1, 1, 2, -0.484652);
// looking
StatelessTest(1, 2, 2, -0.484652);
// on
AppendTest("on", 3, -0.348837, true);
StatelessTest(2, 3, 3, -0.348837);
StatelessTest(2, 2, 3, -0.348837);
// Truncated context changes the score.
StatelessTest(2, 1, 2, -0.4638903);
// a
StatelessTest(3, 4, 4, -0.0155266);
// little
AppendTest("little", 5, -0.00306122, true);
StatelessTest(4, 5, 5, -0.00306122);
// the
AppendTest("the", 1, -4.04005, true);
StatelessTest(5, 5, 1, -4.04005);
// No context of the.
StatelessTest(5, 0, 1, -1.687872);
// biarritz
StatelessTest(6, 1, 1, -1.9889);
// not found
StatelessTest(7, 1, 1, -2.29666);
StatelessTest(7, 0, 1, -1.995635);
// <unk> (id 0) should survive as context of length 1.
WordIndex unk[1];
unk[0] = 0;
model.GetState(unk, unk + 1, state);
BOOST_CHECK_EQUAL(1, state.length);
BOOST_CHECK_EQUAL(static_cast<WordIndex>(0), state.words[0]);
}
// For models built from an ARPA without <unk>: scoring <unk> (id 0) should
// yield the configured default of -100.
template <class M> void NoUnkCheck(const M &model) {
WordIndex unk_index = 0;
State state;
FullScoreReturn ret = model.FullScoreForgotState(&unk_index, &unk_index + 1, unk_index, state);
SLOPPY_CHECK_CLOSE(-100.0, ret.prob, 0.001);
}
// Runs the full battery of query tests against one loaded model.
template <class M> void Everything(const M &m) {
Starters(m);
Continuation(m);
Blanks(m);
Unknowns(m);
MinimalState(m);
ExtendLeftTest(m);
Stateless(m);
}
// Receives vocabulary enumeration callbacks while a model loads and verifies
// that word ids are assigned densely in insertion order (0, 1, 2, ...).
class ExpectEnumerateVocab : public EnumerateVocab {
  public:
    ExpectEnumerateVocab() {}

    // Callback from model loading; ids must arrive strictly in order.
    void Add(WordIndex index, const StringPiece &str) {
      BOOST_CHECK_EQUAL(seen.size(), index);
      seen.push_back(std::string(str.data(), str.length()));
    }

    // The test ARPA files contain exactly 37 words, with <unk> at id 0.
    // Every recorded word must map back to its recorded id.
    void Check(const base::Vocabulary &vocab) {
      BOOST_CHECK_EQUAL(37ULL, seen.size());
      BOOST_REQUIRE(!seen.empty());
      BOOST_CHECK_EQUAL("<unk>", seen[0]);
      WordIndex id = 0;
      for (std::vector<std::string>::const_iterator it = seen.begin(); it != seen.end(); ++it, ++id) {
        BOOST_CHECK_EQUAL(id, vocab.Index(*it));
      }
    }

    // Forget everything recorded so far, ready for the next model load.
    void Clear() {
      seen.clear();
    }

    std::vector<std::string> seen;
};
// Loads a model of type ModelT directly from each test ARPA (no binary) and
// runs the full test battery, checking the enumerated vocabulary as well.
template <class ModelT> void LoadingTest() {
Config config;
config.arpa_complain = Config::NONE;
config.messages = NULL;
// Oversized multiplier just to exercise the probing configuration path.
config.probing_multiplier = 2.0;
{
ExpectEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
ModelT m(TestLocation(), config);
enumerate.Check(m.GetVocabulary());
BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound());
Everything(m);
}
{
ExpectEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
ModelT m(TestNoUnkLocation(), config);
enumerate.Check(m.GetVocabulary());
BOOST_CHECK_EQUAL((WordIndex)37, m.GetVocabulary().Bound());
NoUnkCheck(m);
}
}
// Run the ARPA-loading battery once per model representation.
BOOST_AUTO_TEST_CASE(probing) {
LoadingTest<Model>();
}
BOOST_AUTO_TEST_CASE(trie) {
LoadingTest<TrieModel>();
}
BOOST_AUTO_TEST_CASE(quant_trie) {
LoadingTest<QuantTrieModel>();
}
BOOST_AUTO_TEST_CASE(bhiksha_trie) {
LoadingTest<ArrayTrieModel>();
}
BOOST_AUTO_TEST_CASE(quant_bhiksha_trie) {
LoadingTest<QuantArrayTrieModel>();
}
// Round-trips a model through a binary file: load the ARPA with write_mmap
// set (which writes the binary as a side effect), then reload from the
// binary and re-run the checks.  Repeats the process for the <unk>-less
// ARPA.
//
// Fix: the second half wrote test_nounk.binary but then constructed the
// "binary" model from TestNoUnkLocation() — the ARPA — so the written
// binary was never actually read back or recognized.  Load the binary file
// and verify its type, mirroring the <unk> half above.
template <class ModelT> void BinaryTest(Config::WriteMethod write_method) {
  Config config;
  config.write_mmap = "test.binary";
  config.messages = NULL;
  config.write_method = write_method;
  ExpectEnumerateVocab enumerate;
  config.enumerate_vocab = &enumerate;
  {
    // Loading the ARPA with write_mmap set also writes test.binary.
    ModelT copy_model(TestLocation(), config);
    enumerate.Check(copy_model.GetVocabulary());
    enumerate.Clear();
    Everything(copy_model);
  }
  config.write_mmap = NULL;
  ModelType type;
  BOOST_REQUIRE(RecognizeBinary("test.binary", type));
  BOOST_CHECK_EQUAL(ModelT::kModelType, type);
  {
    ModelT binary("test.binary", config);
    enumerate.Check(binary.GetVocabulary());
    Everything(binary);
  }
  unlink("test.binary");
  // Now test without <unk>.
  config.write_mmap = "test_nounk.binary";
  config.messages = NULL;
  enumerate.Clear();
  {
    ModelT copy_model(TestNoUnkLocation(), config);
    enumerate.Check(copy_model.GetVocabulary());
    enumerate.Clear();
    NoUnkCheck(copy_model);
  }
  config.write_mmap = NULL;
  BOOST_REQUIRE(RecognizeBinary("test_nounk.binary", type));
  BOOST_CHECK_EQUAL(ModelT::kModelType, type);
  {
    ModelT binary("test_nounk.binary", config);
    enumerate.Check(binary.GetVocabulary());
    NoUnkCheck(binary);
  }
  unlink("test_nounk.binary");
}
// Exercise both binary write strategies (mmap-in-place and write-after).
template <class ModelT> void BinaryTest() {
BinaryTest<ModelT>(Config::WRITE_MMAP);
BinaryTest<ModelT>(Config::WRITE_AFTER);
}
// Run the binary round-trip battery once per model representation.
BOOST_AUTO_TEST_CASE(write_and_read_probing) {
BinaryTest<ProbingModel>();
}
BOOST_AUTO_TEST_CASE(write_and_read_rest_probing) {
BinaryTest<RestProbingModel>();
}
BOOST_AUTO_TEST_CASE(write_and_read_trie) {
BinaryTest<TrieModel>();
}
BOOST_AUTO_TEST_CASE(write_and_read_quant_trie) {
BinaryTest<QuantTrieModel>();
}
BOOST_AUTO_TEST_CASE(write_and_read_array_trie) {
BinaryTest<ArrayTrieModel>();
}
BOOST_AUTO_TEST_CASE(write_and_read_quant_array_trie) {
BinaryTest<QuantArrayTrieModel>();
}
// Spot-checks the `rest` cost field produced by the rest-cost probing model.
BOOST_AUTO_TEST_CASE(rest_max) {
Config config;
config.arpa_complain = Config::NONE;
config.messages = NULL;
RestProbingModel model(TestLocation(), config);
State state, out;
FullScoreReturn ret(model.FullScore(model.NullContextState(), model.GetVocabulary().Index("."), state));
SLOPPY_CHECK_CLOSE(-0.2705918, ret.rest, 0.001);
SLOPPY_CHECK_CLOSE(-0.01916512, model.FullScore(state, model.GetVocabulary().EndSentence(), out).rest, 0.001);
}
} // namespace
} // namespace ngram
} // namespace lm
#ifndef LM_MODEL_TYPE_H
#define LM_MODEL_TYPE_H
namespace lm {
namespace ngram {
/* Not the best numbering system, but it grew this way for historical reasons
* and I want to preserve existing binary files. */
typedef enum {PROBING=0, REST_PROBING=1, TRIE=2, QUANT_TRIE=3, ARRAY_TRIE=4, QUANT_ARRAY_TRIE=5} ModelType;
// Historical names.
const ModelType HASH_PROBING = PROBING;
const ModelType TRIE_SORTED = TRIE;
const ModelType QUANT_TRIE_SORTED = QUANT_TRIE;
const ModelType ARRAY_TRIE_SORTED = ARRAY_TRIE;
const ModelType QUANT_ARRAY_TRIE_SORTED = QUANT_ARRAY_TRIE;
// Additive offsets that map TRIE to its quantized / Bhiksha-array variants;
// valid only because of the numbering above (QUANT_TRIE = TRIE + kQuantAdd,
// ARRAY_TRIE = TRIE + kArrayAdd, QUANT_ARRAY_TRIE = TRIE + both).
const static ModelType kQuantAdd = static_cast<ModelType>(QUANT_TRIE - TRIE);
const static ModelType kArrayAdd = static_cast<ModelType>(ARRAY_TRIE - TRIE);
} // namespace ngram
} // namespace lm
#endif // LM_MODEL_TYPE_H
#ifndef LM_NGRAM_QUERY_H
#define LM_NGRAM_QUERY_H
#include "enumerate_vocab.hh"
#include "model.hh"
#include "../util/file_stream.hh"
#include "../util/file_piece.hh"
#include "../util/usage.hh"
#include <cstdlib>
#include <string>
#include <cmath>
namespace lm {
namespace ngram {
// Emits query results to a file descriptor at three granularities — one
// record per word, per line (sentence), and per corpus — each gated by its
// own print_* flag.  When flush_ is set, word and line records are flushed
// immediately (useful when output is piped interactively).
class QueryPrinter {
  public:
    QueryPrinter(int fd, bool print_word, bool print_line, bool print_summary, bool flush)
      : out_(fd), print_word_(print_word), print_line_(print_line), print_summary_(print_summary), flush_(flush) {}

    // One scored token: "surface=vocab_id matched_length log10_prob<tab>".
    void Word(StringPiece surface, WordIndex vocab, const FullScoreReturn &ret) {
      if (print_word_) {
        out_ << surface << '=' << vocab << ' ' << static_cast<unsigned int>(ret.ngram_length) << ' ' << ret.prob << '\t';
        if (flush_) out_.flush();
      }
    }

    // Per-sentence total log10 probability and OOV count.
    void Line(uint64_t oov, float total) {
      if (print_line_) {
        out_ << "Total: " << total << " OOV: " << oov << '\n';
        if (flush_) out_.flush();
      }
    }

    // Corpus-level perplexities and counts; always flushed at the end.
    void Summary(double ppl_including_oov, double ppl_excluding_oov, uint64_t corpus_oov, uint64_t corpus_tokens) {
      if (!print_summary_) return;
      out_ <<
        "Perplexity including OOVs:\t" << ppl_including_oov << "\n"
        "Perplexity excluding OOVs:\t" << ppl_excluding_oov << "\n"
        "OOVs:\t" << corpus_oov << "\n"
        "Tokens:\t" << corpus_tokens << '\n';
      out_.flush();
    }

  private:
    util::FileStream out_;   // Buffered stream over the raw fd.
    bool print_word_;
    bool print_line_;
    bool print_summary_;
    bool flush_;
};
// Reads whitespace-delimited queries from stdin (one sentence per line),
// scores each word with `model`, and reports through `printer`.  When
// sentence_context is true, each line starts from <s> and is closed with
// </s>.  Finishes with corpus-level perplexities (base-10, since the probs
// are log10).
template <class Model, class Printer> void Query(const Model &model, bool sentence_context, Printer &printer) {
typename Model::State state, out;
lm::FullScoreReturn ret;
StringPiece word;
// fd 0 = stdin.
util::FilePiece in(0);
double corpus_total = 0.0;
double corpus_total_oov_only = 0.0;
uint64_t corpus_oov = 0;
uint64_t corpus_tokens = 0;
while (true) {
state = sentence_context ? model.BeginSentenceState() : model.NullContextState();
float total = 0.0;
uint64_t oov = 0;
while (in.ReadWordSameLine(word)) {
lm::WordIndex vocab = model.GetVocabulary().Index(word);
ret = model.FullScore(state, vocab, out);
if (vocab == model.GetVocabulary().NotFound()) {
++oov;
// Tracked separately so OOV mass can be excluded from one perplexity.
corpus_total_oov_only += ret.prob;
}
total += ret.prob;
printer.Word(word, vocab, ret);
++corpus_tokens;
state = out;
}
// If people don't have a newline after their last query, this won't add a </s>.
// Sue me.
try {
UTIL_THROW_IF('\n' != in.get(), util::Exception, "FilePiece is confused.");
} catch (const util::EndOfFileException &e) { break; }
if (sentence_context) {
ret = model.FullScore(state, model.GetVocabulary().EndSentence(), out);
total += ret.prob;
++corpus_tokens;
printer.Word("</s>", model.GetVocabulary().EndSentence(), ret);
}
printer.Line(oov, total);
corpus_total += total;
corpus_oov += oov;
}
printer.Summary(
pow(10.0, -(corpus_total / static_cast<double>(corpus_tokens))), // PPL including OOVs
pow(10.0, -((corpus_total - corpus_total_oov_only) / static_cast<double>(corpus_tokens - corpus_oov))), // PPL excluding OOVs
corpus_oov,
corpus_tokens);
}
// Convenience overload: load the model from `file`, then run the query loop.
template <class Model> void Query(const char *file, const Config &config, bool sentence_context, QueryPrinter &printer) {
Model model(file, config);
Query<Model, QueryPrinter>(model, sentence_context, printer);
}
} // namespace ngram
} // namespace lm
#endif // LM_NGRAM_QUERY_H
#ifndef LM_PARTIAL_H
#define LM_PARTIAL_H
#include "return.hh"
#include "state.hh"
#include <algorithm>
#include <cassert>
namespace lm {
namespace ngram {
// Result bundle from ExtendLoop (below).
struct ExtendReturn {
// Accumulated score adjustment (sum of probs / rests seen while extending).
float adjust;
// True when the left state should be treated as complete.
bool make_full;
// How many of the added context words remain usable for backoff.
unsigned char next_use;
};
// Core worker shared by RevealBefore/RevealAfter/Subsume: extends each left
// extension pointer in [pointers, pointers_end) leftward with the context
// words [add_rbegin, add_rend) (reversed order), double-buffering backoffs
// between iterations.  `seen` is the number of context words already
// accounted for.  New extension pointers go to *pointers_write (advanced);
// final backoffs are copied to backoff_write.
template <class Model> ExtendReturn ExtendLoop(
const Model &model,
unsigned char seen, const WordIndex *add_rbegin, const WordIndex *add_rend, const float *backoff_start,
const uint64_t *pointers, const uint64_t *pointers_end,
uint64_t *&pointers_write,
float *backoff_write) {
unsigned char add_length = add_rend - add_rbegin;
// Double buffer: ExtendLeft reads backoff_in and writes backoff_out;
// the two are swapped after each call.
float backoff_buf[2][KENLM_MAX_ORDER - 1];
float *backoff_in = backoff_buf[0], *backoff_out = backoff_buf[1];
std::copy(backoff_start, backoff_start + add_length, backoff_in);
ExtendReturn value;
value.make_full = false;
value.adjust = 0.0;
value.next_use = add_length;
unsigned char i = 0;
unsigned char length = pointers_end - pointers;
// pointers_write is NULL means that the existing left state is full, so we should use completed probabilities.
if (pointers_write) {
// Using full context, writing to new left state.
for (; i < length; ++i) {
FullScoreReturn ret(model.ExtendLeft(
add_rbegin, add_rbegin + value.next_use,
backoff_in,
pointers[i], i + seen + 1,
backoff_out,
value.next_use));
std::swap(backoff_in, backoff_out);
if (ret.independent_left) {
// No further left extension possible: charge the full probability
// and stop recording extension pointers.
value.adjust += ret.prob;
value.make_full = true;
++i;
break;
}
// Still extendable: charge only the rest cost and keep the pointer.
value.adjust += ret.rest;
*pointers_write++ = ret.extend_left;
if (value.next_use != add_length) {
// Some context words fell out of use; the state is effectively full.
value.make_full = true;
++i;
break;
}
}
}
// Using some of the new context.
for (; i < length && value.next_use; ++i) {
FullScoreReturn ret(model.ExtendLeft(
add_rbegin, add_rbegin + value.next_use,
backoff_in,
pointers[i], i + seen + 1,
backoff_out,
value.next_use));
std::swap(backoff_in, backoff_out);
value.adjust += ret.prob;
}
// Remaining pointers get their rest costs replaced by full probabilities.
float unrest = model.UnRest(pointers + i, pointers_end, i + seen + 1);
// Using none of the new context.
value.adjust += unrest;
std::copy(backoff_in, backoff_in + value.next_use, backoff_write);
return value;
}
// Reveal additional context to the left of a chart state: extend `left`
// with the words of `reveal` (skipping the first `seen`, already applied).
// reveal_full indicates reveal's own left state is complete.  Returns the
// log10 score adjustment.
template <class Model> float RevealBefore(const Model &model, const Right &reveal, const unsigned char seen, bool reveal_full, Left &left, Right &right) {
assert(seen < reveal.length || reveal_full);
// NULL pointers_write tells ExtendLoop to use completed probabilities.
uint64_t *pointers_write = reveal_full ? NULL : left.pointers;
float backoff_buffer[KENLM_MAX_ORDER - 1];
ExtendReturn value(ExtendLoop(
model,
seen, reveal.words + seen, reveal.words + reveal.length, reveal.backoff + seen,
left.pointers, left.pointers + left.length,
pointers_write,
// If left is already full, backoffs are only needed temporarily.
left.full ? backoff_buffer : (right.backoff + right.length)));
if (reveal_full) {
left.length = 0;
value.make_full = true;
} else {
left.length = pointers_write - left.pointers;
value.make_full |= (left.length == model.Order() - 1);
}
if (left.full) {
// Charge the leftover backoffs directly.
for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i];
} else {
// If left wasn't full when it came in, put words into right state.
std::copy(reveal.words + seen, reveal.words + seen + value.next_use, right.words + right.length);
right.length += value.next_use;
left.full = value.make_full || (right.length == model.Order() - 1);
}
return value.adjust;
}
// Reveal additional context to the right: extend reveal's left pointers
// (skipping the first `seen`) with the words currently in `right`.  Returns
// the log10 score adjustment.
template <class Model> float RevealAfter(const Model &model, Left &left, Right &right, const Left &reveal, unsigned char seen) {
assert(seen < reveal.length || reveal.full);
uint64_t *pointers_write = left.full ? NULL : (left.pointers + left.length);
ExtendReturn value(ExtendLoop(
model,
seen, right.words, right.words + right.length, right.backoff,
reveal.pointers + seen, reveal.pointers + reveal.length,
pointers_write,
// Surviving backoffs overwrite right.backoff in place.
right.backoff));
if (reveal.full) {
// Everything matched through a full state: charge remaining backoffs
// and drop the right context entirely.
for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += right.backoff[i];
right.length = 0;
value.make_full = true;
} else {
right.length = value.next_use;
value.make_full |= (right.length == model.Order() - 1);
}
if (!left.full) {
left.length = pointers_write - left.pointers;
left.full = value.make_full || (left.length == model.Order() - 1);
}
return value.adjust;
}
// Merge two adjacent chart states: score second_left's extension pointers
// against first_right's words (with between_length words already between
// them), updating first_left and second_right to describe the combined
// span.  Returns the log10 score adjustment.
template <class Model> float Subsume(const Model &model, Left &first_left, const Right &first_right, const Left &second_left, Right &second_right, const unsigned int between_length) {
assert(first_right.length < KENLM_MAX_ORDER);
assert(second_left.length < KENLM_MAX_ORDER);
assert(between_length < KENLM_MAX_ORDER - 1);
uint64_t *pointers_write = first_left.full ? NULL : (first_left.pointers + first_left.length);
float backoff_buffer[KENLM_MAX_ORDER - 1];
ExtendReturn value(ExtendLoop(
model,
between_length, first_right.words, first_right.words + first_right.length, first_right.backoff,
second_left.pointers, second_left.pointers + second_left.length,
pointers_write,
// If second_left is full its backoffs are consumed here, not kept.
second_left.full ? backoff_buffer : (second_right.backoff + second_right.length)));
if (second_left.full) {
for (unsigned char i = 0; i < value.next_use; ++i) value.adjust += backoff_buffer[i];
} else {
// Surviving words from first_right become part of the combined right state.
std::copy(first_right.words, first_right.words + value.next_use, second_right.words + second_right.length);
second_right.length += value.next_use;
value.make_full |= (second_right.length == model.Order() - 1);
}
if (!first_left.full) {
first_left.length = pointers_write - first_left.pointers;
first_left.full = value.make_full || second_left.full || (first_left.length == model.Order() - 1);
}
assert(first_left.length < KENLM_MAX_ORDER);
assert(second_right.length < KENLM_MAX_ORDER);
return value.adjust;
}
} // namespace ngram
} // namespace lm
#endif // LM_PARTIAL_H
#include "partial.hh"
#include "left.hh"
#include "model.hh"
#include "../util/tokenize_piece.hh"
#define BOOST_TEST_MODULE PartialTest
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>
namespace lm {
namespace ngram {
namespace {
// ARPA file for the partial-scoring tests; overridable via the first
// command line argument.
const char *TestLocation() {
  if (boost::unit_test::framework::master_test_suite().argc >= 2) {
    return boost::unit_test::framework::master_test_suite().argv[1];
  }
  return "test.arpa";
}
// A Config that suppresses ARPA complaints and status messages.
Config SilentConfig() {
  Config ret;
  ret.messages = NULL;
  ret.arpa_complain = Config::NONE;
  return ret;
}
// Boost.Test fixture: loads the rest-cost probing model once per test case.
struct ModelFixture {
ModelFixture() : m(TestLocation(), SilentConfig()) {}
RestProbingModel m;
};
BOOST_FIXTURE_TEST_SUITE(suite, ModelFixture)
// Revealing context before an empty state should cost nothing and simply
// accumulate the revealed words and backoffs into the right state.
BOOST_AUTO_TEST_CASE(SimpleBefore) {
Left left;
left.full = false;
left.length = 0;
Right right;
right.length = 0;
Right reveal;
reveal.length = 1;
WordIndex period = m.GetVocabulary().Index(".");
reveal.words[0] = period;
reveal.backoff[0] = -0.845098;
BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 0, false, left, right), 0.001);
BOOST_CHECK_EQUAL(0, left.length);
BOOST_CHECK(!left.full);
BOOST_CHECK_EQUAL(1, right.length);
BOOST_CHECK_EQUAL(period, right.words[0]);
BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001);
// Reveal one more word (seen = 1 skips the already-applied period).
WordIndex more = m.GetVocabulary().Index("more");
reveal.words[1] = more;
reveal.backoff[1] = -0.4771212;
reveal.length = 2;
BOOST_CHECK_CLOSE(0.0, RevealBefore(m, reveal, 1, false, left, right), 0.001);
BOOST_CHECK_EQUAL(0, left.length);
BOOST_CHECK(!left.full);
BOOST_CHECK_EQUAL(2, right.length);
BOOST_CHECK_EQUAL(period, right.words[0]);
BOOST_CHECK_EQUAL(more, right.words[1]);
BOOST_CHECK_CLOSE(-0.845098, right.backoff[0], 0.001);
BOOST_CHECK_CLOSE(-0.4771212, right.backoff[1], 0.001);
}
// Builds the chart state for "would", reveals "consider" after it and then
// "also" before it, checking the score adjustments against hand-computed
// values from the test ARPA (comments below give the derivations).
BOOST_AUTO_TEST_CASE(AlsoWouldConsider) {
WordIndex would = m.GetVocabulary().Index("would");
WordIndex consider = m.GetVocabulary().Index("consider");
ChartState current;
current.left.length = 1;
current.left.pointers[0] = would;
current.left.full = false;
current.right.length = 1;
current.right.words[0] = would;
current.right.backoff[0] = -0.30103;
Left after;
after.full = false;
after.length = 1;
after.pointers[0] = consider;
// adjustment for would consider
BOOST_CHECK_CLOSE(-1.687872 - -0.2922095 - 0.30103, RevealAfter(m, current.left, current.right, after, 0), 0.001);
BOOST_CHECK_EQUAL(2, current.left.length);
BOOST_CHECK_EQUAL(would, current.left.pointers[0]);
BOOST_CHECK_EQUAL(false, current.left.full);
WordIndex also = m.GetVocabulary().Index("also");
Right before;
before.length = 1;
before.words[0] = also;
before.backoff[0] = -0.30103;
// r(would) = -0.2922095 [i would], r(would -> consider) = -1.988902 [b(would) + p(consider)]
// p(also -> would) = -2, p(also would -> consider) = -3
BOOST_CHECK_CLOSE(-2 + 0.2922095 -3 + 1.988902, RevealBefore(m, before, 0, false, current.left, current.right), 0.001);
// "also would consider" matches fully, so the left state closes.
BOOST_CHECK_EQUAL(0, current.left.length);
BOOST_CHECK(current.left.full);
BOOST_CHECK_EQUAL(2, current.right.length);
BOOST_CHECK_EQUAL(would, current.right.words[0]);
BOOST_CHECK_EQUAL(also, current.right.words[1]);
}
// Revealing a full right context before a state whose left begins with </s>.
BOOST_AUTO_TEST_CASE(EndSentence) {
WordIndex loin = m.GetVocabulary().Index("loin");
WordIndex period = m.GetVocabulary().Index(".");
WordIndex eos = m.GetVocabulary().EndSentence();
ChartState between;
between.left.length = 1;
between.left.pointers[0] = eos;
between.left.full = true;
between.right.length = 0;
Right before;
before.words[0] = period;
before.words[1] = loin;
before.backoff[0] = -0.845098;
before.backoff[1] = 0.0;
before.length = 1;
// reveal_full = true: the revealed context is complete.
BOOST_CHECK_CLOSE(-0.0410707, RevealBefore(m, before, 0, true, between.left, between.right), 0.001);
BOOST_CHECK_EQUAL(0, between.left.length);
}
// Score the word ids in [begin, end) as one rule, recording the resulting
// chart state in `out`; returns the fragment's log10 score.
float ScoreFragment(const RestProbingModel &model, unsigned int *begin, unsigned int *end, ChartState &out) {
  RuleScore<RestProbingModel> scorer(model, out);
  while (begin != end) {
    scorer.Terminal(*begin++);
  }
  return scorer.Finish();
}
// Reveals the before/after context around `between` one word at a time (in
// alternating order) and checks that the summed adjustments equal `expect`,
// the difference between the full-string score and the three fragments.
void CheckAdjustment(const RestProbingModel &model, float expect, const Right &before_in, bool before_full, ChartState between, const Left &after_in) {
Right before(before_in);
Left after(after_in);
after.full = false;
float got = 0.0;
// Interleave: reveal the i-th before word then the i-th after pointer.
for (unsigned int i = 1; i < 5; ++i) {
if (before_in.length >= i) {
before.length = i;
got += RevealBefore(model, before, i - 1, false, between.left, between.right);
}
if (after_in.length >= i) {
after.length = i;
got += RevealAfter(model, between.left, between.right, after, i - 1);
}
}
// Finally reveal fullness, if any, on each side.
if (after_in.full) {
after.full = true;
got += RevealAfter(model, between.left, between.right, after, after.length);
}
if (before_full) {
got += RevealBefore(model, before, before.length, true, between.left, between.right);
}
// Sometimes they're zero and BOOST_CHECK_CLOSE fails for this.
BOOST_CHECK(fabs(expect - got) < 0.001);
}
// Exhaustive consistency check: for every split of `str` into prefix /
// middle / suffix, scoring the middle fragment and then revealing the
// prefix and suffix must reproduce the full-string score.
void FullDivide(const RestProbingModel &model, StringPiece str) {
std::vector<WordIndex> indices;
for (util::TokenIter<util::SingleCharacter, true> i(str, ' '); i; ++i) {
indices.push_back(model.GetVocabulary().Index(*i));
}
ChartState full_state;
float full = ScoreFragment(model, &indices.front(), &indices.back() + 1, full_state);
ChartState before_state;
before_state.left.full = false;
RuleScore<RestProbingModel> before_scorer(model, before_state);
float before_score = 0.0;
// `before` grows one word per outer iteration via before_scorer below.
for (unsigned int before = 0; before < indices.size(); ++before) {
for (unsigned int after = before; after <= indices.size(); ++after) {
ChartState after_state, between_state;
float after_score = ScoreFragment(model, &indices.front() + after, &indices.front() + indices.size(), after_state);
float between_score = ScoreFragment(model, &indices.front() + before, &indices.front() + after, between_state);
// The reveal adjustments must make up exactly the difference.
CheckAdjustment(model, full - before_score - after_score - between_score, before_state.right, before_state.left.full, between_state, after_state.left);
}
before_scorer.Terminal(indices[before]);
before_score = before_scorer.Finish();
}
}
// Run the exhaustive split check on strings of increasing length, the last
// long enough to exercise full states, blanks, and OOVs.
BOOST_AUTO_TEST_CASE(Strings) {
FullDivide(m, "also would consider");
FullDivide(m, "looking on a little more loin . </s>");
FullDivide(m, "in biarritz watching considering looking . on a little more loin also would consider higher to look good unknown the screening foo bar , unknown however unknown </s>");
}
BOOST_AUTO_TEST_SUITE_END()
} // namespace
} // namespace ngram
} // namespace lm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment