/* Quantize into bins of equal size as described in
* M. Federico and N. Bertoldi. 2006. How many bits are needed
* to store probabilities for phrase-based translation? In Proc.
* of the Workshop on Statistical Machine Translation, pages
* 94–101, New York City, June. Association for Computational
* Linguistics.
*/
#include "quantize.hh"
#include "binary_format.hh"
#include "lm_exception.hh"
#include "../util/file.hh"
#include <algorithm>
#include <limits>
#include <numeric>
namespace lm {
namespace ngram {
namespace {
void MakeBins(std::vector<float> &values, float *centers, uint32_t bins) {
std::sort(values.begin(), values.end());
std::vector<float>::const_iterator start = values.begin(), finish;
for (uint32_t i = 0; i < bins; ++i, ++centers, start = finish) {
finish = values.begin() + ((values.size() * static_cast<uint64_t>(i + 1)) / bins);
if (finish == start) {
// zero length bucket.
*centers = i ? *(centers - 1) : -std::numeric_limits<float>::infinity();
} else {
*centers = std::accumulate(start, finish, 0.0) / static_cast<float>(finish - start);
}
}
}
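/* Worked example for MakeBins above (illustration only): values {-4, -3, -2, -1}
 * with bins = 2 split the sorted vector into {-4, -3} and {-2, -1}, storing the
 * bucket means -3.5 and -1.5 as centers. A zero-length bucket (possible when
 * bins > values.size()) reuses the previous center, or -infinity for the first. */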
const char kSeparatelyQuantizeVersion = 2;
} // namespace
void SeparatelyQuantize::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) {
unsigned char buffer[3];
file.ReadForConfig(buffer, 3, offset);
char version = buffer[0];
config.prob_bits = buffer[1];
config.backoff_bits = buffer[2];
if (version != kSeparatelyQuantizeVersion) UTIL_THROW(FormatLoadException, "This file has quantization version " << (unsigned)version << " but the code expects version " << (unsigned)kSeparatelyQuantizeVersion);
}
void SeparatelyQuantize::SetupMemory(void *base, unsigned char order, const Config &config) {
prob_bits_ = config.prob_bits;
backoff_bits_ = config.backoff_bits;
// We need the reserved values.
if (config.prob_bits == 0) UTIL_THROW(ConfigException, "You can't quantize probability to zero");
if (config.backoff_bits == 0) UTIL_THROW(ConfigException, "You can't quantize backoff to zero");
if (config.prob_bits > 25) UTIL_THROW(ConfigException, "For efficiency reasons, quantizing probability supports at most 25 bits. Currently you have requested " << static_cast<unsigned>(config.prob_bits) << " bits.");
if (config.backoff_bits > 25) UTIL_THROW(ConfigException, "For efficiency reasons, quantizing backoff supports at most 25 bits. Currently you have requested " << static_cast<unsigned>(config.backoff_bits) << " bits.");
// Reserve 8 byte header for bit counts.
actual_base_ = static_cast<uint8_t*>(base);
float *start = reinterpret_cast<float*>(actual_base_ + 8);
for (unsigned char i = 0; i < order - 2; ++i) {
tables_[i][0] = Bins(prob_bits_, start);
start += (1ULL << prob_bits_);
tables_[i][1] = Bins(backoff_bits_, start);
start += (1ULL << backoff_bits_);
}
longest_ = tables_[order - 2][0] = Bins(prob_bits_, start);
}
void SeparatelyQuantize::Train(uint8_t order, std::vector<float> &prob, std::vector<float> &backoff) {
TrainProb(order, prob);
// Backoff
float *centers = tables_[order - 2][1].Populate();
*(centers++) = kNoExtensionBackoff;
*(centers++) = kExtensionBackoff;
MakeBins(backoff, centers, (1ULL << backoff_bits_) - 2);
}
void SeparatelyQuantize::TrainProb(uint8_t order, std::vector<float> &prob) {
float *centers = tables_[order - 2][0].Populate();
MakeBins(prob, centers, (1ULL << prob_bits_));
}
void SeparatelyQuantize::FinishedLoading(const Config &config) {
uint8_t *actual_base = actual_base_;
*(actual_base++) = kSeparatelyQuantizeVersion; // version
*(actual_base++) = config.prob_bits;
*(actual_base++) = config.backoff_bits;
}
} // namespace ngram
} // namespace lm
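/* Sketch of the intended call sequence (hypothetical driver, not part of the
 * build; BuildTrie in search_trie.cc is the real caller):
 *   SeparatelyQuantize quant;
 *   quant.SetupMemory(mem, order, config);          // lay out center tables after the 8-byte header
 *   quant.Train(2, bigram_probs, bigram_backoffs);  // one call per middle order 2..order-1
 *   quant.TrainProb(order, longest_probs);          // longest order stores probabilities only
 *   quant.FinishedLoading(config);                  // record version and bit counts in the header
 */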
#ifndef LM_QUANTIZE_H
#define LM_QUANTIZE_H
#include "blank.hh"
#include "config.hh"
#include "max_order.hh"
#include "model_type.hh"
#include "../util/bit_packing.hh"
#include <algorithm>
#include <vector>
#include <stdint.h>
#include <iostream>
namespace lm {
namespace ngram {
struct Config;
class BinaryFormat;
/* Store values directly and don't quantize. */
class DontQuantize {
public:
static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &) {}
static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; }
static uint8_t MiddleBits(const Config &/*config*/) { return 63; } // 31-bit non-positive float prob + 32-bit float backoff
static uint8_t LongestBits(const Config &/*config*/) { return 31; } // 31-bit non-positive float prob only
class MiddlePointer {
public:
MiddlePointer(const DontQuantize & /*quant*/, unsigned char /*order_minus_2*/, util::BitAddress address) : address_(address) {}
MiddlePointer() : address_(NULL, 0) {}
bool Found() const {
return address_.base != NULL;
}
float Prob() const {
return util::ReadNonPositiveFloat31(address_.base, address_.offset);
}
float Backoff() const {
return util::ReadFloat32(address_.base, address_.offset + 31);
}
float Rest() const { return Prob(); }
void Write(float prob, float backoff) {
util::WriteNonPositiveFloat31(address_.base, address_.offset, prob);
util::WriteFloat32(address_.base, address_.offset + 31, backoff);
}
private:
util::BitAddress address_;
};
class LongestPointer {
public:
explicit LongestPointer(const DontQuantize &/*quant*/, util::BitAddress address) : address_(address) {}
LongestPointer() : address_(NULL, 0) {}
bool Found() const {
return address_.base != NULL;
}
float Prob() const {
return util::ReadNonPositiveFloat31(address_.base, address_.offset);
}
void Write(float prob) {
util::WriteNonPositiveFloat31(address_.base, address_.offset, prob);
}
private:
util::BitAddress address_;
};
DontQuantize() {}
void SetupMemory(void * /*start*/, unsigned char /*order*/, const Config & /*config*/) {}
static const bool kTrain = false;
// These should never be called because kTrain is false.
void Train(uint8_t /*order*/, std::vector<float> &/*prob*/, std::vector<float> &/*backoff*/) {}
void TrainProb(uint8_t, std::vector<float> &/*prob*/) {}
void FinishedLoading(const Config &) {}
};
class SeparatelyQuantize {
private:
class Bins {
public:
// Sigh C++ default constructor
Bins() {}
Bins(uint8_t bits, float *begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {}
float *Populate() { return begin_; }
uint64_t EncodeProb(float value) const {
return Encode(value, 0);
}
uint64_t EncodeBackoff(float value) const {
if (value == 0.0) {
return HasExtension(value) ? kExtensionQuant : kNoExtensionQuant;
}
return Encode(value, 2);
}
float Decode(std::size_t off) const { return begin_[off]; }
uint8_t Bits() const { return bits_; }
uint64_t Mask() const { return mask_; }
private:
uint64_t Encode(float value, size_t reserved) const {
const float *above = std::lower_bound(static_cast<const float*>(begin_) + reserved, end_, value);
if (above == begin_ + reserved) return reserved;
if (above == end_) return end_ - begin_ - 1;
return above - begin_ - (value - *(above - 1) < *above - value);
}
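// Example of the rounding above: with centers {-3.0, -1.0} and reserved == 0,
// Encode(-1.5, 0) finds -1.0 via lower_bound and returns index 1 because -1.5
// is closer to -1.0 (0.5 away) than to -3.0 (1.5 away).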
float *begin_;
const float *end_;
uint8_t bits_;
uint64_t mask_;
};
public:
static const ModelType kModelTypeAdd = kQuantAdd;
static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config);
static uint64_t Size(uint8_t order, const Config &config) {
uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float);
uint64_t middle_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.backoff_bits)) * sizeof(float) + longest_table;
// Unigrams are currently not quantized, so there is no table for them.
return (order - 2) * middle_table + longest_table + /* for the bit counts and alignment padding */ 8;
}
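// Worked example for Size above: order = 5 with prob_bits = backoff_bits = 8
// gives longest_table = 256 * sizeof(float) = 1024 bytes and middle_table =
// 1024 + 1024 = 2048 bytes, so Size returns 3 * 2048 + 1024 + 8 = 7176 bytes.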
static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; }
static uint8_t LongestBits(const Config &config) { return config.prob_bits; }
class MiddlePointer {
public:
MiddlePointer(const SeparatelyQuantize &quant, unsigned char order_minus_2, const util::BitAddress &address) : bins_(quant.GetTables(order_minus_2)), address_(address) {}
MiddlePointer() : address_(NULL, 0) {}
bool Found() const { return address_.base != NULL; }
float Prob() const {
return ProbBins().Decode(util::ReadInt25(address_.base, address_.offset + BackoffBins().Bits(), ProbBins().Bits(), ProbBins().Mask()));
}
float Backoff() const {
return BackoffBins().Decode(util::ReadInt25(address_.base, address_.offset, BackoffBins().Bits(), BackoffBins().Mask()));
}
float Rest() const { return Prob(); }
void Write(float prob, float backoff) const {
uint64_t prob_encoded = ProbBins().EncodeProb(prob);
uint64_t backoff_encoded = BackoffBins().EncodeBackoff(backoff);
#if BYTE_ORDER == LITTLE_ENDIAN
prob_encoded <<= BackoffBins().Bits();
#elif BYTE_ORDER == BIG_ENDIAN
backoff_encoded <<= ProbBins().Bits();
#endif
util::WriteInt57(address_.base, address_.offset, ProbBins().Bits() + BackoffBins().Bits(),
prob_encoded | backoff_encoded);
}
private:
const Bins &ProbBins() const { return bins_[0]; }
const Bins &BackoffBins() const { return bins_[1]; }
const Bins *bins_;
util::BitAddress address_;
};
class LongestPointer {
public:
LongestPointer(const SeparatelyQuantize &quant, const util::BitAddress &address) : table_(&quant.LongestTable()), address_(address) {}
LongestPointer() : address_(NULL, 0) {}
bool Found() const { return address_.base != NULL; }
void Write(float prob) const {
util::WriteInt25(address_.base, address_.offset, table_->Bits(), table_->EncodeProb(prob));
}
float Prob() const {
return table_->Decode(util::ReadInt25(address_.base, address_.offset, table_->Bits(), table_->Mask()));
}
private:
const Bins *table_;
util::BitAddress address_;
};
SeparatelyQuantize() {}
void SetupMemory(void *start, unsigned char order, const Config &config);
static const bool kTrain = true;
// Assumes 0.0 is removed from backoff.
void Train(uint8_t order, std::vector<float> &prob, std::vector<float> &backoff);
// Train just probabilities (for longest order).
void TrainProb(uint8_t order, std::vector<float> &prob);
void FinishedLoading(const Config &config);
const Bins *GetTables(unsigned char order_minus_2) const { return tables_[order_minus_2]; }
const Bins &LongestTable() const { return longest_; }
private:
Bins tables_[KENLM_MAX_ORDER - 1][2];
Bins longest_;
uint8_t *actual_base_;
uint8_t prob_bits_, backoff_bits_;
};
} // namespace ngram
} // namespace lm
#endif // LM_QUANTIZE_H
#include "ngram_query.hh"
#include "../util/getopt.hh"
#ifdef WITH_NPLM
#include "wrappers/nplm.hh"
#endif
#include <stdlib.h>
#include <string.h>
void Usage(const char *name) {
std::cerr <<
"KenLM was compiled with maximum order " << KENLM_MAX_ORDER << ".\n"
"Usage: " << name << " [-b] [-n] [-w] [-s] lm_file\n"
"-b: Do not buffer output.\n"
"-n: Do not wrap the input in <s> and </s>.\n"
"-v summary|sentence|word: Print statistics at this level.\n"
" Can be used multiple times: -v summary -v sentence -v word\n"
"-l lazy|populate|read|parallel: Load lazily, with populate, or malloc+read\n"
"The default loading method is populate on Linux and read on others.\n\n"
"Each word in the output is formatted as:\n"
" word=vocab_id ngram_length log10(p(word|context))\n"
"where ngram_length is the length of n-gram matched. A vocab_id of 0 indicates\n"
"the unknown word. Sentence-level output includes log10 probability of the\n"
"sentence and OOV count.\n";
exit(1);
}
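/* Example invocation (hypothetical file names):
 *   query -v sentence -l lazy 5gram.binary < test.txt
 * prints per-sentence log10 probabilities and OOV counts, loading the model
 * lazily via mmap. */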
int main(int argc, char *argv[]) {
if (argc == 1 || (argc == 2 && !strcmp(argv[1], "--help")))
Usage(argv[0]);
lm::ngram::Config config;
bool sentence_context = true;
bool print_word = false;
bool print_line = false;
bool print_summary = false;
bool flush = false;
int opt;
while ((opt = getopt(argc, argv, "bnv:l:")) != -1) {
switch (opt) {
case 'b':
flush = true;
break;
case 'n':
sentence_context = false;
break;
case 'v':
if (!strcmp(optarg, "2")) {
print_word = true;
print_line = true;
print_summary = true;
} else if (!strcmp(optarg, "1")) {
print_word = false;
print_line = true;
print_summary = true;
} else if (!strcmp(optarg, "0")) {
print_word = false;
print_line = false;
print_summary = true;
} else if (!strcmp(optarg, "word")) {
print_word = true;
} else if (!strcmp(optarg, "sentence")) {
print_line = true;
} else if (!strcmp(optarg, "summary")) {
print_summary = true;
} else {
Usage(argv[0]);
}
break;
case 'l':
if (!strcmp(optarg, "lazy")) {
config.load_method = util::LAZY;
} else if (!strcmp(optarg, "populate")) {
config.load_method = util::POPULATE_OR_READ;
} else if (!strcmp(optarg, "read")) {
config.load_method = util::READ;
} else if (!strcmp(optarg, "parallel")) {
config.load_method = util::PARALLEL_READ;
} else {
Usage(argv[0]);
}
break;
case 'h':
default:
Usage(argv[0]);
}
}
if (optind + 1 != argc)
Usage(argv[0]);
// No verbosity argument specified.
if (!print_word && !print_line && !print_summary) {
print_word = true;
print_line = true;
print_summary = true;
}
lm::ngram::QueryPrinter printer(1, print_word, print_line, print_summary, flush);
const char *file = argv[optind];
try {
using namespace lm::ngram;
ModelType model_type;
if (RecognizeBinary(file, model_type)) {
std::cerr << "This binary file contains " << lm::ngram::kModelNames[model_type] << "." << std::endl;
switch(model_type) {
case PROBING:
Query<lm::ngram::ProbingModel>(file, config, sentence_context, printer);
break;
case REST_PROBING:
Query<lm::ngram::RestProbingModel>(file, config, sentence_context, printer);
break;
case TRIE:
Query<TrieModel>(file, config, sentence_context, printer);
break;
case QUANT_TRIE:
Query<QuantTrieModel>(file, config, sentence_context, printer);
break;
case ARRAY_TRIE:
Query<ArrayTrieModel>(file, config, sentence_context, printer);
break;
case QUANT_ARRAY_TRIE:
Query<QuantArrayTrieModel>(file, config, sentence_context, printer);
break;
default:
std::cerr << "Unrecognized kenlm model type " << model_type << std::endl;
abort();
}
#ifdef WITH_NPLM
} else if (lm::np::Model::Recognize(file)) {
lm::np::Model model(file);
Query<lm::np::Model, lm::ngram::QueryPrinter>(model, sentence_context, printer);
#endif
} else {
Query<ProbingModel>(file, config, sentence_context, printer);
}
util::PrintUsage(std::cerr);
} catch (const std::exception &e) {
std::cerr << e.what() << std::endl;
return 1;
}
return 0;
}
#include "read_arpa.hh"
#include "blank.hh"
#include "../util/file.hh"
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <vector>
#include <cctype>
#include <cstring>
#include <stdint.h>
#ifdef WIN32
#include <float.h>
#endif
namespace lm {
// 1 for '\t', '\n', '\r', and ' '. This is stricter than isspace. Apparently ARPA allows vertical tab inside a word.
const bool kARPASpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
namespace {
bool IsEntirelyWhiteSpace(const StringPiece &line) {
for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) {
if (!isspace(line.data()[i])) return false;
}
return true;
}
const char kBinaryMagic[] = "mmap lm http://kheafield.com/code";
// strtoull isn't portable enough :-(
uint64_t ReadCount(const std::string &from) {
std::stringstream stream(from);
uint64_t ret;
stream >> ret;
UTIL_THROW_IF(!stream, FormatLoadException, "Bad count " << from);
return ret;
}
} // namespace
void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number) {
number.clear();
StringPiece line = in.ReadLine();
// In general, ARPA files can have arbitrary text before "\data\"
// But in KenLM, we require such lines to start with "#", so that
// we can do stricter error checking
while (IsEntirelyWhiteSpace(line) || starts_with(line, "#")) {
line = in.ReadLine();
}
if (line != "\\data\\") {
if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast<unsigned char>(line.data()[1]) == 0x8b)) {
UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip.");
}
if (static_cast<size_t>(line.size()) >= strlen(kBinaryMagic) && StringPiece(line.data(), strlen(kBinaryMagic)) == kBinaryMagic)
UTIL_THROW(FormatLoadException, "This looks like a binary file but got sent to the ARPA parser. Did you compress the binary file or pass a binary file where only ARPA files are accepted?");
UTIL_THROW_IF(line.size() >= 4 && StringPiece(line.data(), 4) == "blmt", FormatLoadException, "This looks like an IRSTLM binary file. Did you forget to pass --text yes to compile-lm?");
UTIL_THROW_IF(line == "iARPA", FormatLoadException, "This looks like an IRSTLM iARPA file. You need an ARPA file. Run\n compile-lm --text yes " << in.FileName() << " " << in.FileName() << ".arpa\nfirst.");
UTIL_THROW(FormatLoadException, "first non-empty line was \"" << line << "\" not \\data\\.");
}
while (!IsEntirelyWhiteSpace(line = in.ReadLine())) {
if (line.size() < 6 || strncmp(line.data(), "ngram ", 6)) UTIL_THROW(FormatLoadException, "count line \"" << line << "\" doesn't begin with \"ngram \"");
// So strtol doesn't go off the end of line.
std::string remaining(line.data() + 6, line.size() - 6);
char *end_ptr;
unsigned int length = std::strtol(remaining.c_str(), &end_ptr, 10);
if ((end_ptr == remaining.c_str()) || (length - 1 != number.size())) UTIL_THROW(FormatLoadException, "ngram count lengths should be consecutive starting with 1: " << line);
if (*end_ptr != '=') UTIL_THROW(FormatLoadException, "Expected = immediately following the first number in the count line " << line);
++end_ptr;
number.push_back(ReadCount(end_ptr));
}
}
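/* Example of the header ReadARPACounts consumes (counts from a made-up model):
 *   \data\
 *   ngram 1=4698
 *   ngram 2=98765
 *   ngram 3=123456
 * leaving number = {4698, 98765, 123456}. */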
void ReadNGramHeader(util::FilePiece &in, unsigned int length) {
StringPiece line;
while (IsEntirelyWhiteSpace(line = in.ReadLine())) {}
std::stringstream expected;
expected << '\\' << length << "-grams:";
if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead");
}
void ConsumeNewline(util::FilePiece &in) {
char follow = in.get();
UTIL_THROW_IF('\n' != follow, FormatLoadException, "Expected newline, got '" << follow << "'");
}
void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) {
switch (in.get()) {
case '\t':
{
float got = in.ReadFloat();
if (got != 0.0)
UTIL_THROW(FormatLoadException, "Non-zero backoff " << got << " provided for an n-gram that should have no backoff");
}
break;
case '\r':
ConsumeNewline(in);
// Intentionally no break.
case '\n':
break;
default:
UTIL_THROW(FormatLoadException, "Expected tab or newline for backoff");
}
}
void ReadBackoff(util::FilePiece &in, float &backoff) {
// Always make zero negative.
// Negative zero means that no (n+1)-gram has this n-gram as context.
// Therefore the hypothesis state can be shorter. Of course, many n-grams
// are context for (n+1)-grams. An algorithm in the data structure will go
// back and set the backoff to positive zero in these cases.
switch (in.get()) {
case '\t':
backoff = in.ReadFloat();
if (backoff == ngram::kExtensionBackoff) backoff = ngram::kNoExtensionBackoff;
{
#if defined(WIN32) && !defined(__MINGW32__)
int float_class = _fpclass(backoff);
UTIL_THROW_IF(float_class == _FPCLASS_SNAN || float_class == _FPCLASS_QNAN || float_class == _FPCLASS_NINF || float_class == _FPCLASS_PINF, FormatLoadException, "Bad backoff " << backoff);
#else
int float_class = std::fpclassify(backoff);
UTIL_THROW_IF(float_class == FP_NAN || float_class == FP_INFINITE, FormatLoadException, "Bad backoff " << backoff);
#endif
}
switch (char got = in.get()) {
case '\r':
ConsumeNewline(in);
// Intentionally no break.
case '\n':
break;
default:
UTIL_THROW(FormatLoadException, "Expected newline after backoffs, got " << got);
}
break;
case '\r':
ConsumeNewline(in);
// Intentionally no break.
case '\n':
backoff = ngram::kNoExtensionBackoff;
break;
default:
UTIL_THROW(FormatLoadException, "Expected tab or newline for backoff");
}
}
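// Both ReadBackoff overloads run after the words of an ARPA n-gram line
// ("prob<TAB>words[<TAB>backoff]") have been consumed, so they see either a
// tab followed by a float or just the line terminator.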
void ReadEnd(util::FilePiece &in) {
StringPiece line;
do {
line = in.ReadLine();
} while (IsEntirelyWhiteSpace(line));
if (line != "\\end\\") UTIL_THROW(FormatLoadException, "Expected \\end\\ but the ARPA file has " << line);
try {
while (true) {
line = in.ReadLine();
if (!IsEntirelyWhiteSpace(line)) UTIL_THROW(FormatLoadException, "Trailing line " << line);
}
} catch (const util::EndOfFileException &) {}
}
void PositiveProbWarn::Warn(float prob) {
switch (action_) {
case THROW_UP:
UTIL_THROW(FormatLoadException, "Positive log probability " << prob << " in the model. This is a bug in IRSTLM; you can set config.positive_log_probability = SILENT or pass -i to build_binary to substitute 0.0 for the log probability. Error");
case COMPLAIN:
std::cerr << "There's a positive log probability " << prob << " in the APRA file, probably because of a bug in IRSTLM. This and subsequent entires will be mapped to 0 log probability." << std::endl;
action_ = SILENT;
break;
case SILENT:
break;
}
}
} // namespace lm
#ifndef LM_READ_ARPA_H
#define LM_READ_ARPA_H
#include "lm_exception.hh"
#include "word_index.hh"
#include "weights.hh"
#include "../util/file_piece.hh"
#include <cstddef>
#include <iosfwd>
#include <vector>
namespace lm {
void ReadARPACounts(util::FilePiece &in, std::vector<uint64_t> &number);
void ReadNGramHeader(util::FilePiece &in, unsigned int length);
void ReadBackoff(util::FilePiece &in, Prob &weights);
void ReadBackoff(util::FilePiece &in, float &backoff);
inline void ReadBackoff(util::FilePiece &in, ProbBackoff &weights) {
ReadBackoff(in, weights.backoff);
}
inline void ReadBackoff(util::FilePiece &in, RestWeights &weights) {
ReadBackoff(in, weights.backoff);
}
void ReadEnd(util::FilePiece &in);
extern const bool kARPASpaces[256];
// Positive log probability warning.
class PositiveProbWarn {
public:
PositiveProbWarn() : action_(THROW_UP) {}
explicit PositiveProbWarn(WarningAction action) : action_(action) {}
void Warn(float prob);
private:
WarningAction action_;
};
template <class Voc, class Weights> void Read1Gram(util::FilePiece &f, Voc &vocab, Weights *unigrams, PositiveProbWarn &warn) {
try {
float prob = f.ReadFloat();
if (prob > 0.0) {
warn.Warn(prob);
prob = 0.0;
}
UTIL_THROW_IF(f.get() != '\t', FormatLoadException, "Expected tab after probability");
WordIndex word = vocab.Insert(f.ReadDelimited(kARPASpaces));
Weights &w = unigrams[word];
w.prob = prob;
ReadBackoff(f, w);
} catch(util::Exception &e) {
e << " in the 1-gram at byte " << f.Offset();
throw;
}
}
template <class Voc, class Weights> void Read1Grams(util::FilePiece &f, std::size_t count, Voc &vocab, Weights *unigrams, PositiveProbWarn &warn) {
ReadNGramHeader(f, 1);
for (std::size_t i = 0; i < count; ++i) {
Read1Gram(f, vocab, unigrams, warn);
}
vocab.FinishedLoading(unigrams);
}
// Read ngram, write vocab ids to indices_out.
template <class Voc, class Weights, class Iterator> void ReadNGram(util::FilePiece &f, const unsigned char n, const Voc &vocab, Iterator indices_out, Weights &weights, PositiveProbWarn &warn) {
try {
weights.prob = f.ReadFloat();
if (weights.prob > 0.0) {
warn.Warn(weights.prob);
weights.prob = 0.0;
}
for (unsigned char i = 0; i < n; ++i, ++indices_out) {
StringPiece word(f.ReadDelimited(kARPASpaces));
WordIndex index = vocab.Index(word);
*indices_out = index;
// Check for words mapped to <unk> that are not the string <unk>.
UTIL_THROW_IF(index == 0 /* mapped to <unk> */ && (word != StringPiece("<unk>", 5)) && (word != StringPiece("<UNK>", 5)),
FormatLoadException, "Word " << word << " was not seen in the unigrams (which are supposed to list the entire vocabulary) but appears");
}
ReadBackoff(f, weights);
} catch(util::Exception &e) {
e << " in the " << static_cast<unsigned int>(n) << "-gram at byte " << f.Offset();
throw;
}
}
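// ReadNGram above parses ARPA lines like (made-up numbers):
//   -0.4771\tfoo bar baz\t-0.30103
// i.e. log10 probability, the words, then an optional tab and backoff.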
} // namespace lm
#endif // LM_READ_ARPA_H
#ifndef LM_RETURN_H
#define LM_RETURN_H
#include <stdint.h>
namespace lm {
/* Structure returned by scoring routines. */
struct FullScoreReturn {
// log10 probability
float prob;
/* The length of n-gram matched. Do not use this for recombination.
* Consider a model containing only the following n-grams:
* -1 foo
* -3.14 bar
* -2.718 baz -5
* -6 foo bar
*
* If you score ``bar'' then ngram_length is 1 and recombination state is the
* empty string because bar has zero backoff and does not extend to the
* right.
* If you score ``foo'' then ngram_length is 1 and recombination state is
* ``foo''.
*
* Ideally, keep output states around and compare them. Failing that,
* get out_state.ValidLength() and use that length for recombination.
*/
unsigned char ngram_length;
/* Left extension information. If independent_left is set, then prob is
* independent of words to the left (up to additional backoff). Otherwise,
* extend_left indicates how to efficiently extend further to the left.
*/
bool independent_left;
uint64_t extend_left; // Defined only if independent_left is false.
// Rest cost for extension to the left.
float rest;
};
} // namespace lm
#endif // LM_RETURN_H
#include "search_hashed.hh"
#include "binary_format.hh"
#include "blank.hh"
#include "lm_exception.hh"
#include "model.hh"
#include "read_arpa.hh"
#include "value.hh"
#include "vocab.hh"
#include "../util/bit_packing.hh"
#include "../util/file_piece.hh"
#include <string>
namespace lm {
namespace ngram {
class ProbingModel;
namespace {
/* These are passed to ReadNGrams so that n-grams with zero backoff that appear as context will still be used in state. */
template <class Middle> class ActivateLowerMiddle {
public:
explicit ActivateLowerMiddle(Middle &middle) : modify_(middle) {}
void operator()(const WordIndex *vocab_ids, const unsigned int n) {
uint64_t hash = static_cast<uint64_t>(vocab_ids[1]);
for (const WordIndex *i = vocab_ids + 2; i < vocab_ids + n; ++i) {
hash = detail::CombineWordHash(hash, *i);
}
typename Middle::MutableIterator i;
// TODO: somehow get text of n-gram for this error message.
if (!modify_.UnsafeMutableFind(hash, i))
UTIL_THROW(FormatLoadException, "The context of every " << n << "-gram should appear as a " << (n-1) << "-gram");
SetExtension(i->value.backoff);
}
private:
Middle &modify_;
};
template <class Weights> class ActivateUnigram {
public:
explicit ActivateUnigram(Weights *unigram) : modify_(unigram) {}
void operator()(const WordIndex *vocab_ids, const unsigned int /*n*/) {
// assert(n == 2);
SetExtension(modify_[vocab_ids[1]].backoff);
}
private:
Weights *modify_;
};
// Find the lower order entry, inserting blanks along the way as necessary.
template <class Value> void FindLower(
const std::vector<uint64_t> &keys,
typename Value::Weights &unigram,
std::vector<util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> > &middle,
std::vector<typename Value::Weights *> &between) {
typename util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash>::MutableIterator iter;
typename Value::ProbingEntry entry;
// Backoff will always be 0.0. We'll get the probability and rest in another pass.
entry.value.backoff = kNoExtensionBackoff;
// Go back and find the longest right-aligned entry, informing it that it extends left. Normally this will match immediately, but sometimes SRI is dumb.
for (int lower = keys.size() - 2; ; --lower) {
if (lower == -1) {
between.push_back(&unigram);
return;
}
entry.key = keys[lower];
bool found = middle[lower].FindOrInsert(entry, iter);
between.push_back(&iter->value);
if (found) return;
}
}
// Between usually has a single entry, the value to adjust. But sometimes SRI stupidly pruned entries so it has uninitialized blank values to be set here.
template <class Added, class Build> void AdjustLower(
const Added &added,
const Build &build,
std::vector<typename Build::Value::Weights *> &between,
const unsigned int n,
const std::vector<WordIndex> &vocab_ids,
typename Build::Value::Weights *unigrams,
std::vector<util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash> > &middle) {
typedef typename Build::Value Value;
if (between.size() == 1) {
build.MarkExtends(*between.front(), added);
return;
}
typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
float prob = -fabs(between.back()->prob);
// Order of the n-gram on which probabilities are based.
unsigned char basis = n - between.size();
assert(basis != 0);
typename Build::Value::Weights **change = &between.back();
// Skip the basis.
--change;
if (basis == 1) {
// Hallucinate a bigram based on a unigram's backoff and a unigram probability.
float &backoff = unigrams[vocab_ids[1]].backoff;
SetExtension(backoff);
prob += backoff;
(*change)->prob = prob;
build.SetRest(&*vocab_ids.begin(), 2, **change);
basis = 2;
--change;
}
uint64_t backoff_hash = static_cast<uint64_t>(vocab_ids[1]);
for (unsigned char i = 2; i <= basis; ++i) {
backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[i]);
}
for (; basis < n - 1; ++basis, --change) {
typename Middle::MutableIterator gotit;
if (middle[basis - 2].UnsafeMutableFind(backoff_hash, gotit)) {
float &backoff = gotit->value.backoff;
SetExtension(backoff);
prob += backoff;
}
(*change)->prob = prob;
build.SetRest(&*vocab_ids.begin(), basis + 1, **change);
backoff_hash = detail::CombineWordHash(backoff_hash, vocab_ids[basis+1]);
}
typename std::vector<typename Value::Weights *>::const_iterator i(between.begin());
build.MarkExtends(**i, added);
const typename Value::Weights *longer = *i;
// Everything has probability but is not marked as extending.
for (++i; i != between.end(); ++i) {
build.MarkExtends(**i, *longer);
longer = *i;
}
}
// Continue marking lower entries even if they know that they extend left. This is used for upper/lower bounds.
template <class Build> void MarkLower(
const std::vector<uint64_t> &keys,
const Build &build,
typename Build::Value::Weights &unigram,
std::vector<util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash> > &middle,
int start_order,
const typename Build::Value::Weights &longer) {
if (start_order == 0) return;
// Hopefully the compiler will realize that if MarkExtends always returns false, it can simplify this code.
for (int even_lower = start_order - 2 /* index in middle */; ; --even_lower) {
if (even_lower == -1) {
build.MarkExtends(unigram, longer);
return;
}
if (!build.MarkExtends(
middle[even_lower].UnsafeMutableMustFind(keys[even_lower])->value,
longer)) return;
}
}
template <class Build, class Activate, class Store> void ReadNGrams(
util::FilePiece &f,
const unsigned int n,
const size_t count,
const ProbingVocabulary &vocab,
const Build &build,
typename Build::Value::Weights *unigrams,
std::vector<util::ProbingHashTable<typename Build::Value::ProbingEntry, util::IdentityHash> > &middle,
Activate activate,
Store &store,
PositiveProbWarn &warn) {
typedef typename Build::Value Value;
assert(n >= 2);
ReadNGramHeader(f, n);
// Both vocab_ids and keys are non-empty because n >= 2.
// Vocab ids of the words, in reverse order.
std::vector<WordIndex> vocab_ids(n);
std::vector<uint64_t> keys(n-1);
typename Store::Entry entry;
std::vector<typename Value::Weights *> between;
for (size_t i = 0; i < count; ++i) {
ReadNGram(f, n, vocab, vocab_ids.rbegin(), entry.value, warn);
build.SetRest(&*vocab_ids.begin(), n, entry.value);
keys[0] = detail::CombineWordHash(static_cast<uint64_t>(vocab_ids.front()), vocab_ids[1]);
for (unsigned int h = 1; h < n - 1; ++h) {
keys[h] = detail::CombineWordHash(keys[h-1], vocab_ids[h+1]);
}
// Initially the sign bit is on, indicating it does not extend left. Most already have this, but there might be +0.0.
util::SetSign(entry.value.prob);
entry.key = keys[n-2];
store.Insert(entry);
between.clear();
FindLower<Value>(keys, unigrams[vocab_ids.front()], middle, between);
AdjustLower<typename Store::Entry::Value, Build>(entry.value, build, between, n, vocab_ids, unigrams, middle);
if (Build::kMarkEvenLower) MarkLower<Build>(keys, build, unigrams[vocab_ids.front()], middle, n - between.size() - 1, *between.back());
activate(&*vocab_ids.begin(), n);
}
store.FinishedInserting();
}
} // namespace
namespace detail {
template <class Value> uint8_t *HashedSearch<Value>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
unigram_ = Unigram(start, counts[0]);
start += Unigram::Size(counts[0]);
std::size_t allocated;
middle_.clear();
for (unsigned int n = 2; n < counts.size(); ++n) {
allocated = Middle::Size(counts[n - 1], config.probing_multiplier);
middle_.push_back(Middle(start, allocated));
start += allocated;
}
allocated = Longest::Size(counts.back(), config.probing_multiplier);
longest_ = Longest(start, allocated);
start += allocated;
return start;
}
/*template <class Value> void HashedSearch<Value>::Relocate(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
unigram_ = Unigram(start, counts[0]);
start += Unigram::Size(counts[0]);
for (unsigned int n = 2; n < counts.size(); ++n) {
middle[n-2].Relocate(start);
start += Middle::Size(counts[n - 1], config.probing_multiplier)
}
longest_.Relocate(start);
}*/
template <class Value> void HashedSearch<Value>::InitializeFromARPA(const char * /*file*/, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing) {
void *vocab_rebase;
void *search_base = backing.GrowForSearch(Size(counts, config), vocab.UnkCountChangePadding(), vocab_rebase);
vocab.Relocate(vocab_rebase);
SetupMemory(reinterpret_cast<uint8_t*>(search_base), counts, config);
PositiveProbWarn warn(config.positive_log_probability);
Read1Grams(f, counts[0], vocab, unigram_.Raw(), warn);
CheckSpecials(config, vocab);
DispatchBuild(f, counts, config, vocab, warn);
}
template <> void HashedSearch<BackoffValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {
NoRestBuild build;
ApplyBuild(f, counts, vocab, warn, build);
}
template <> void HashedSearch<RestValue>::DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn) {
switch (config.rest_function) {
case Config::REST_MAX:
{
MaxRestBuild build;
ApplyBuild(f, counts, vocab, warn, build);
}
break;
case Config::REST_LOWER:
{
LowerRestBuild<ProbingModel> build(config, counts.size(), vocab);
ApplyBuild(f, counts, vocab, warn, build);
}
break;
}
}
template <class Value> template <class Build> void HashedSearch<Value>::ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build) {
for (WordIndex i = 0; i < counts[0]; ++i) {
build.SetRest(&i, (unsigned int)1, unigram_.Raw()[i]);
}
try {
if (counts.size() > 2) {
ReadNGrams<Build, ActivateUnigram<typename Value::Weights>, Middle>(
f, 2, counts[1], vocab, build, unigram_.Raw(), middle_, ActivateUnigram<typename Value::Weights>(unigram_.Raw()), middle_[0], warn);
}
for (unsigned int n = 3; n < counts.size(); ++n) {
ReadNGrams<Build, ActivateLowerMiddle<Middle>, Middle>(
f, n, counts[n-1], vocab, build, unigram_.Raw(), middle_, ActivateLowerMiddle<Middle>(middle_[n-3]), middle_[n-2], warn);
}
if (counts.size() > 2) {
ReadNGrams<Build, ActivateLowerMiddle<Middle>, Longest>(
f, counts.size(), counts[counts.size() - 1], vocab, build, unigram_.Raw(), middle_, ActivateLowerMiddle<Middle>(middle_.back()), longest_, warn);
} else {
ReadNGrams<Build, ActivateUnigram<typename Value::Weights>, Longest>(
f, counts.size(), counts[counts.size() - 1], vocab, build, unigram_.Raw(), middle_, ActivateUnigram<typename Value::Weights>(unigram_.Raw()), longest_, warn);
}
} catch (util::ProbingSizeException &e) {
UTIL_THROW(util::ProbingSizeException, "Avoid pruning n-grams like \"bar baz quux\" when \"foo bar baz quux\" is still in the model. KenLM will work when this pruning happens, but the probing model assumes these events are rare enough that using blank space in the probing hash table will cover all of them. Increase probing_multiplier (-p to build_binary) to add more blank spaces.\n");
}
ReadEnd(f);
}
template class HashedSearch<BackoffValue>;
template class HashedSearch<RestValue>;
} // namespace detail
} // namespace ngram
} // namespace lm
#ifndef LM_SEARCH_HASHED_H
#define LM_SEARCH_HASHED_H
#include "model_type.hh"
#include "config.hh"
#include "read_arpa.hh"
#include "return.hh"
#include "weights.hh"
#include "../util/bit_packing.hh"
#include "../util/probing_hash_table.hh"
#include <algorithm>
#include <iostream>
#include <vector>
namespace util { class FilePiece; }
namespace lm {
namespace ngram {
class BinaryFormat;
class ProbingVocabulary;
namespace detail {
inline uint64_t CombineWordHash(uint64_t current, const WordIndex next) {
uint64_t ret = (current * 8978948897894561157ULL) ^ (static_cast<uint64_t>(1 + next) * 17894857484156487943ULL);
return ret;
}
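// Example: the hash for a word sequence (w0, w1, w2) is built by chaining
//   CombineWordHash(CombineWordHash(static_cast<uint64_t>(w0), w1), w2)
// exactly as FastMakeNode below and the keys[] loop in ReadNGrams do.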
#pragma pack(push)
#pragma pack(4)
struct ProbEntry {
uint64_t key;
Prob value;
typedef uint64_t Key;
typedef Prob Value;
uint64_t GetKey() const {
return key;
}
};
#pragma pack(pop)
class LongestPointer {
public:
explicit LongestPointer(const float &to) : to_(&to) {}
LongestPointer() : to_(NULL) {}
bool Found() const {
return to_ != NULL;
}
float Prob() const {
return *to_;
}
private:
const float *to_;
};
template <class Value> class HashedSearch {
public:
typedef uint64_t Node;
typedef typename Value::ProbingProxy UnigramPointer;
typedef typename Value::ProbingProxy MiddlePointer;
typedef ::lm::ngram::detail::LongestPointer LongestPointer;
static const ModelType kModelType = Value::kProbingModelType;
static const bool kDifferentRest = Value::kDifferentRest;
static const unsigned int kVersion = 0;
// TODO: move probing_multiplier here with next binary file format update.
static void UpdateConfigFromBinary(const BinaryFormat &, const std::vector<uint64_t> &, uint64_t, Config &) {}
static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
uint64_t ret = Unigram::Size(counts[0]);
for (unsigned char n = 1; n < counts.size() - 1; ++n) {
ret += Middle::Size(counts[n], config.probing_multiplier);
}
return ret + Longest::Size(counts.back(), config.probing_multiplier);
}
uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
void InitializeFromARPA(const char *file, util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, ProbingVocabulary &vocab, BinaryFormat &backing);
unsigned char Order() const {
return middle_.size() + 2;
}
typename Value::Weights &UnknownUnigram() { return unigram_.Unknown(); }
UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const {
extend_left = static_cast<uint64_t>(word);
next = extend_left;
UnigramPointer ret(unigram_.Lookup(word));
independent_left = ret.IndependentLeft();
return ret;
}
MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {
node = extend_pointer;
return MiddlePointer(middle_[extend_length - 2].MustFind(extend_pointer)->value);
}
MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_pointer) const {
node = CombineWordHash(node, word);
typename Middle::ConstIterator found;
if (!middle_[order_minus_2].Find(node, found)) {
independent_left = true;
return MiddlePointer();
}
extend_pointer = node;
MiddlePointer ret(found->value);
independent_left = ret.IndependentLeft();
return ret;
}
LongestPointer LookupLongest(WordIndex word, const Node &node) const {
// Sign bit is always on because longest n-grams do not extend left.
typename Longest::ConstIterator found;
if (!longest_.Find(CombineWordHash(node, word), found)) return LongestPointer();
return LongestPointer(found->value.prob);
}
// Generate a node without necessarily checking that it actually exists.
// Optionally return false if it's known not to exist.
bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
assert(begin != end);
node = static_cast<Node>(*begin);
for (const WordIndex *i = begin + 1; i < end; ++i) {
node = CombineWordHash(node, *i);
}
return true;
}
private:
// Interpret config's rest cost build policy and pass the right template argument to ApplyBuild.
void DispatchBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const Config &config, const ProbingVocabulary &vocab, PositiveProbWarn &warn);
template <class Build> void ApplyBuild(util::FilePiece &f, const std::vector<uint64_t> &counts, const ProbingVocabulary &vocab, PositiveProbWarn &warn, const Build &build);
class Unigram {
public:
Unigram() {}
Unigram(void *start, uint64_t count) :
unigram_(static_cast<typename Value::Weights*>(start))
#ifdef DEBUG
, count_(count)
#endif
{}
static uint64_t Size(uint64_t count) {
return (count + 1) * sizeof(typename Value::Weights); // +1 for the hallucinated <unk>
}
const typename Value::Weights &Lookup(WordIndex index) const {
#ifdef DEBUG
assert(index < count_);
#endif
return unigram_[index];
}
typename Value::Weights &Unknown() { return unigram_[0]; }
// For building.
typename Value::Weights *Raw() { return unigram_; }
private:
typename Value::Weights *unigram_;
#ifdef DEBUG
uint64_t count_;
#endif
};
Unigram unigram_;
typedef util::ProbingHashTable<typename Value::ProbingEntry, util::IdentityHash> Middle;
std::vector<Middle> middle_;
typedef util::ProbingHashTable<ProbEntry, util::IdentityHash> Longest;
Longest longest_;
};
} // namespace detail
} // namespace ngram
} // namespace lm
#endif // LM_SEARCH_HASHED_H
/* This is where the trie is built. It's on-disk. */
#include "search_trie.hh"
#include "bhiksha.hh"
#include "binary_format.hh"
#include "blank.hh"
#include "lm_exception.hh"
#include "max_order.hh"
#include "quantize.hh"
#include "trie.hh"
#include "trie_sort.hh"
#include "vocab.hh"
#include "weights.hh"
#include "word_index.hh"
#include "../util/ersatz_progress.hh"
#include "../util/mmap.hh"
#include "../util/proxy_iterator.hh"
#include "../util/scoped.hh"
#include "../util/sized_iterator.hh"
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <cstdlib>
#include <queue>
#include <limits>
#include <numeric>
#include <vector>
#if defined(_WIN32) || defined(_WIN64)
#include <windows.h>
#endif
namespace lm {
namespace ngram {
namespace trie {
namespace {
void ReadOrThrow(FILE *from, void *data, size_t size) {
UTIL_THROW_IF(1 != std::fread(data, size, 1, from), util::ErrnoException, "Short read");
}
int Compare(unsigned char order, const void *first_void, const void *second_void) {
const WordIndex *first = reinterpret_cast<const WordIndex*>(first_void), *second = reinterpret_cast<const WordIndex*>(second_void);
const WordIndex *end = first + order;
for (; first != end; ++first, ++second) {
if (*first < *second) return -1;
if (*first > *second) return 1;
}
return 0;
}
struct ProbPointer {
unsigned char array;
uint64_t index;
};
// Array of n-grams and float indices.
class BackoffMessages {
public:
void Init(std::size_t entry_size) {
current_ = NULL;
allocated_ = NULL;
entry_size_ = entry_size;
}
void Add(const WordIndex *to, ProbPointer index) {
while (current_ + entry_size_ > allocated_) {
std::size_t allocated_size = allocated_ - (uint8_t*)backing_.get();
Resize(std::max<std::size_t>(allocated_size * 2, entry_size_));
}
memcpy(current_, to, entry_size_ - sizeof(ProbPointer));
*reinterpret_cast<ProbPointer*>(current_ + entry_size_ - sizeof(ProbPointer)) = index;
current_ += entry_size_;
}
void Apply(float *const *const base, FILE *unigrams) {
FinishedAdding();
if (current_ == allocated_) return;
rewind(unigrams);
ProbBackoff weights;
WordIndex unigram = 0;
ReadOrThrow(unigrams, &weights, sizeof(weights));
for (; current_ != allocated_; current_ += entry_size_) {
const WordIndex &cur_word = *reinterpret_cast<const WordIndex*>(current_);
for (; unigram < cur_word; ++unigram) {
ReadOrThrow(unigrams, &weights, sizeof(weights));
}
if (!HasExtension(weights.backoff)) {
weights.backoff = kExtensionBackoff;
UTIL_THROW_IF(fseek(unigrams, -sizeof(weights), SEEK_CUR), util::ErrnoException, "Seeking backwards to denote unigram extension failed.");
util::WriteOrThrow(unigrams, &weights, sizeof(weights));
}
const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + sizeof(WordIndex));
base[write_to.array][write_to.index] += weights.backoff;
}
backing_.reset();
}
void Apply(float *const *const base, RecordReader &reader) {
FinishedAdding();
if (current_ == allocated_) return;
// We'll also use the same buffer to record messages to blanks that they extend.
WordIndex *extend_out = reinterpret_cast<WordIndex*>(current_);
const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex);
for (reader.Rewind(); reader && (current_ != allocated_); ) {
switch (Compare(order, reader.Data(), current_)) {
case -1:
++reader;
break;
case 1:
// Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends.
for (const WordIndex *w = reinterpret_cast<const WordIndex *>(current_); w != reinterpret_cast<const WordIndex *>(current_) + order; ++w, ++extend_out) *extend_out = *w;
current_ += entry_size_;
break;
case 0:
float &backoff = reinterpret_cast<ProbBackoff*>((uint8_t*)reader.Data() + order * sizeof(WordIndex))->backoff;
if (!HasExtension(backoff)) {
backoff = kExtensionBackoff;
reader.Overwrite(&backoff, sizeof(float));
} else {
const ProbPointer &write_to = *reinterpret_cast<const ProbPointer*>(current_ + entry_size_ - sizeof(ProbPointer));
base[write_to.array][write_to.index] += backoff;
}
current_ += entry_size_;
break;
}
}
// Now this is a list of blanks that extend right.
entry_size_ = sizeof(WordIndex) * order;
Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get()));
current_ = (uint8_t*)backing_.get();
}
// Call after Apply
bool Extends(unsigned char order, const WordIndex *words) {
if (current_ == allocated_) return false;
assert(order * sizeof(WordIndex) == entry_size_);
while (true) {
switch(Compare(order, words, current_)) {
case 1:
current_ += entry_size_;
if (current_ == allocated_) return false;
break;
case -1:
return false;
case 0:
return true;
}
}
}
private:
void FinishedAdding() {
Resize(current_ - (uint8_t*)backing_.get());
// Sort requests in same order as files.
util::SizedSort(backing_.get(), current_, entry_size_, EntryCompare((entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex)));
current_ = (uint8_t*)backing_.get();
}
void Resize(std::size_t to) {
std::size_t current = current_ - (uint8_t*)backing_.get();
backing_.call_realloc(to);
current_ = (uint8_t*)backing_.get() + current;
allocated_ = (uint8_t*)backing_.get() + to;
}
util::scoped_malloc backing_;
uint8_t *current_, *allocated_;
std::size_t entry_size_;
};
const float kBadProb = std::numeric_limits<float>::infinity();
class SRISucks {
public:
SRISucks() {
for (BackoffMessages *i = messages_; i != messages_ + KENLM_MAX_ORDER - 1; ++i)
i->Init(sizeof(ProbPointer) + sizeof(WordIndex) * (i - messages_ + 1));
}
void Send(unsigned char begin, unsigned char order, const WordIndex *to, float prob_basis) {
assert(prob_basis != kBadProb);
ProbPointer pointer;
pointer.array = order - 1;
pointer.index = values_[order - 1].size();
for (unsigned char i = begin; i < order; ++i) {
messages_[i - 1].Add(to, pointer);
}
values_[order - 1].push_back(prob_basis);
}
void ObtainBackoffs(unsigned char total_order, FILE *unigram_file, RecordReader *reader) {
for (unsigned char i = 0; i < KENLM_MAX_ORDER - 1; ++i) {
it_[i] = values_[i].empty() ? NULL : &*values_[i].begin();
}
messages_[0].Apply(it_, unigram_file);
BackoffMessages *messages = messages_ + 1;
const RecordReader *end = reader + total_order - 2 /* exclude unigrams and longest order */;
for (; reader != end; ++messages, ++reader) {
messages->Apply(it_, *reader);
}
}
ProbBackoff GetBlank(unsigned char total_order, unsigned char order, const WordIndex *indices) {
assert(order > 1);
ProbBackoff ret;
ret.prob = *(it_[order - 1]++);
ret.backoff = ((order != total_order - 1) && messages_[order - 1].Extends(order, indices)) ? kExtensionBackoff : kNoExtensionBackoff;
return ret;
}
const std::vector<float> &Values(unsigned char order) const {
return values_[order - 1];
}
private:
// This used to be one array. Then I needed to separate it by order for quantization to work.
std::vector<float> values_[KENLM_MAX_ORDER - 1];
BackoffMessages messages_[KENLM_MAX_ORDER - 1];
float *it_[KENLM_MAX_ORDER - 1];
};
class FindBlanks {
public:
FindBlanks(unsigned char order, const ProbBackoff *unigrams, SRISucks &messages)
: counts_(order), unigrams_(unigrams), sri_(messages) {}
float UnigramProb(WordIndex index) const {
return unigrams_[index].prob;
}
void Unigram(WordIndex /*index*/) {
++counts_[0];
}
void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char lower, float prob_basis) {
sri_.Send(lower, order, indices + 1, prob_basis);
++counts_[order - 1];
}
void Middle(const unsigned char order, const void * /*data*/) {
++counts_[order - 1];
}
void Longest(const void * /*data*/) {
++counts_.back();
}
const std::vector<uint64_t> &Counts() const {
return counts_;
}
private:
std::vector<uint64_t> counts_;
const ProbBackoff *unigrams_;
SRISucks &sri_;
};
// Phase to actually write n-grams to the trie.
template <class Quant, class Bhiksha> class WriteEntries {
public:
WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle<Bhiksha> *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) :
contexts_(contexts),
quant_(quant),
unigrams_(unigrams),
middle_(middle),
longest_(longest),
bigram_pack_((order == 2) ? static_cast<BitPacked&>(longest_) : static_cast<BitPacked&>(*middle_)),
order_(order),
sri_(sri) {}
float UnigramProb(WordIndex index) const { return unigrams_[index].weights.prob; }
void Unigram(WordIndex word) {
unigrams_[word].next = bigram_pack_.InsertIndex();
}
void MiddleBlank(const unsigned char order, const WordIndex *indices, unsigned char /*lower*/, float /*prob_base*/) {
ProbBackoff weights = sri_.GetBlank(order_, order, indices);
typename Quant::MiddlePointer(quant_, order - 2, middle_[order - 2].Insert(indices[order - 1])).Write(weights.prob, weights.backoff);
}
void Middle(const unsigned char order, const void *data) {
RecordReader &context = contexts_[order - 1];
const WordIndex *words = reinterpret_cast<const WordIndex*>(data);
ProbBackoff weights = *reinterpret_cast<const ProbBackoff*>(words + order);
if (context && !memcmp(data, context.Data(), sizeof(WordIndex) * order)) {
SetExtension(weights.backoff);
++context;
}
typename Quant::MiddlePointer(quant_, order - 2, middle_[order - 2].Insert(words[order - 1])).Write(weights.prob, weights.backoff);
}
void Longest(const void *data) {
const WordIndex *words = reinterpret_cast<const WordIndex*>(data);
typename Quant::LongestPointer(quant_, longest_.Insert(words[order_ - 1])).Write(reinterpret_cast<const Prob*>(words + order_)->prob);
}
private:
RecordReader *contexts_;
const Quant &quant_;
UnigramValue *const unigrams_;
BitPackedMiddle<Bhiksha> *const middle_;
BitPackedLongest &longest_;
BitPacked &bigram_pack_;
const unsigned char order_;
SRISucks &sri_;
};
struct Gram {
Gram(const WordIndex *in_begin, unsigned char order) : begin(in_begin), end(in_begin + order) {}
const WordIndex *begin, *end;
// Reversed comparison on purpose: std::priority_queue is a max-heap, so inverting the order makes the lexicographically smallest gram surface first.
bool operator<(const Gram &other) const {
return std::lexicographical_compare(other.begin, other.end, begin, end);
}
};
template <class Doing> class BlankManager {
public:
BlankManager(unsigned char total_order, Doing &doing) : total_order_(total_order), been_length_(0), doing_(doing) {
for (float *i = basis_; i != basis_ + KENLM_MAX_ORDER - 1; ++i) *i = kBadProb;
}
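// Example: with total_order 3, if "a b c" arrives but the bigram "a b" was
// never visited, the prefix comparison below stops after "a" and MiddleBlank
// is told to hallucinate the missing bigram before "a b c" is processed.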
void Visit(const WordIndex *to, unsigned char length, float prob) {
basis_[length - 1] = prob;
unsigned char overlap = std::min<unsigned char>(length - 1, been_length_);
const WordIndex *cur;
WordIndex *pre;
for (cur = to, pre = been_; cur != to + overlap; ++cur, ++pre) {
if (*pre != *cur) break;
}
if (cur == to + length - 1) {
*pre = *cur;
been_length_ = length;
return;
}
// There are blanks to insert, starting with one of order blank.
unsigned char blank = cur - to + 1;
UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context.");
const float *lower_basis;
for (lower_basis = basis_ + blank - 2; *lower_basis == kBadProb; --lower_basis) {}
unsigned char based_on = lower_basis - basis_ + 1;
for (; cur != to + length - 1; ++blank, ++cur, ++pre) {
assert(*lower_basis != kBadProb);
doing_.MiddleBlank(blank, to, based_on, *lower_basis);
*pre = *cur;
// Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram.
basis_[blank - 1] = kBadProb;
}
*pre = *cur;
been_length_ = length;
}
private:
const unsigned char total_order_;
WordIndex been_[KENLM_MAX_ORDER];
unsigned char been_length_;
float basis_[KENLM_MAX_ORDER];
Doing &doing_;
};
template <class Doing> void RecursiveInsert(const unsigned char total_order, const WordIndex unigram_count, RecordReader *input, std::ostream *progress_out, const char *message, Doing &doing) {
util::ErsatzProgress progress(unigram_count + 1, progress_out, message);
WordIndex unigram = 0;
std::priority_queue<Gram> grams;
if (unigram_count) grams.push(Gram(&unigram, 1));
for (unsigned char i = 2; i <= total_order; ++i) {
if (input[i-2]) grams.push(Gram(reinterpret_cast<const WordIndex*>(input[i-2].Data()), i));
}
BlankManager<Doing> blank(total_order, doing);
while (!grams.empty()) {
Gram top = grams.top();
grams.pop();
unsigned char order = top.end - top.begin;
if (order == 1) {
blank.Visit(&unigram, 1, doing.UnigramProb(unigram));
doing.Unigram(unigram);
progress.Set(unigram);
if (++unigram < unigram_count) grams.push(top);
} else {
if (order == total_order) {
blank.Visit(top.begin, order, reinterpret_cast<const Prob*>(top.end)->prob);
doing.Longest(top.begin);
} else {
blank.Visit(top.begin, order, reinterpret_cast<const ProbBackoff*>(top.end)->prob);
doing.Middle(order, top.begin);
}
RecordReader &reader = input[order - 2];
if (++reader) grams.push(top);
}
}
}
void SanityCheckCounts(const std::vector<uint64_t> &initial, const std::vector<uint64_t> &fixed) {
if (fixed[0] != initial[0]) UTIL_THROW(util::Exception, "Unigram count should be constant but initial is " << initial[0] << " and recounted is " << fixed[0]);
if (fixed.back() != initial.back()) UTIL_THROW(util::Exception, "Longest count should be constant but it changed from " << initial.back() << " to " << fixed.back());
for (unsigned char i = 0; i < initial.size(); ++i) {
if (fixed[i] < initial[i]) UTIL_THROW(util::Exception, "Counts came out lower than expected. This shouldn't happen");
}
}
template <class Quant> void TrainQuantizer(uint8_t order, uint64_t count, const std::vector<float> &additional, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) {
std::vector<float> probs(additional), backoffs;
probs.reserve(count + additional.size());
backoffs.reserve(count);
for (reader.Rewind(); reader; ++reader) {
const ProbBackoff &weights = *reinterpret_cast<const ProbBackoff*>(reinterpret_cast<const uint8_t*>(reader.Data()) + sizeof(WordIndex) * order);
probs.push_back(weights.prob);
if (weights.backoff != 0.0) backoffs.push_back(weights.backoff);
++progress;
}
quant.Train(order, probs, backoffs);
}
template <class Quant> void TrainProbQuantizer(uint8_t order, uint64_t count, RecordReader &reader, util::ErsatzProgress &progress, Quant &quant) {
std::vector<float> probs, backoffs;
probs.reserve(count);
for (reader.Rewind(); reader; ++reader) {
const Prob &weights = *reinterpret_cast<const Prob*>(reinterpret_cast<const uint8_t*>(reader.Data()) + sizeof(WordIndex) * order);
probs.push_back(weights.prob);
++progress;
}
quant.TrainProb(order, probs);
}
void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) {
// Fill unigram probabilities.
try {
rewind(file);
for (WordIndex i = 0; i < unigram_count; ++i) {
ReadOrThrow(file, &unigrams[i].weights, sizeof(ProbBackoff));
if (contexts && *reinterpret_cast<const WordIndex*>(contexts.Data()) == i) {
SetExtension(unigrams[i].weights.backoff);
++contexts;
}
}
} catch (util::Exception &e) {
e << " while re-reading unigram probabilities";
throw;
}
}
} // namespace
template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing) {
RecordReader inputs[KENLM_MAX_ORDER - 1];
RecordReader contexts[KENLM_MAX_ORDER - 1];
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Init(files.Full(i), i * sizeof(WordIndex) + (i == counts.size() ? sizeof(Prob) : sizeof(ProbBackoff)));
contexts[i-2].Init(files.Context(i), (i-1) * sizeof(WordIndex));
}
SRISucks sri;
std::vector<uint64_t> fixed_counts;
util::scoped_FILE unigram_file;
util::scoped_fd unigram_fd(files.StealUnigram());
{
util::scoped_memory unigrams;
MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams);
FindBlanks finder(counts.size(), reinterpret_cast<const ProbBackoff*>(unigrams.get()), sri);
RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Identifying n-grams omitted by SRI", finder);
fixed_counts = finder.Counts();
}
unigram_file.reset(util::FDOpenOrThrow(unigram_fd));
for (const RecordReader *i = inputs; i != inputs + counts.size() - 2; ++i) {
if (*i) UTIL_THROW(FormatLoadException, "There's a bug in the trie implementation: the " << (i - inputs + 2) << "-gram table did not complete reading");
}
SanityCheckCounts(counts, fixed_counts);
counts = fixed_counts;
sri.ObtainBackoffs(counts.size(), unigram_file.get(), inputs);
void *vocab_relocate;
void *search_base = backing.GrowForSearch(TrieSearch<Quant, Bhiksha>::Size(fixed_counts, config), vocab.UnkCountChangePadding(), vocab_relocate);
vocab.Relocate(vocab_relocate);
out.SetupMemory(reinterpret_cast<uint8_t*>(search_base), fixed_counts, config);
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Rewind();
}
if (Quant::kTrain) {
util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), static_cast<uint64_t>(0)),
config.ProgressMessages(), "Quantizing");
for (unsigned char i = 2; i < counts.size(); ++i) {
TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant);
}
TrainProbQuantizer(counts.size(), counts.back(), inputs[counts.size() - 2], progress, quant);
quant.FinishedLoading(config);
}
UnigramValue *unigrams = out.unigram_.Raw();
PopulateUnigramWeights(unigram_file.get(), counts[0], contexts[0], unigrams);
unigram_file.reset();
for (unsigned char i = 2; i <= counts.size(); ++i) {
inputs[i-2].Rewind();
}
// Fill entries except unigram probabilities.
{
WriteEntries<Quant, Bhiksha> writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri);
RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer);
// Write the last unigram entry, which is the end pointer for the bigrams.
writer.Unigram(counts[0]);
}
// Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation.
for (unsigned char order = 2; order <= counts.size(); ++order) {
const RecordReader &context = contexts[order - 2];
if (context) {
FormatLoadException e;
e << "A " << static_cast<unsigned int>(order) << "-gram has context";
const WordIndex *ctx = reinterpret_cast<const WordIndex*>(context.Data());
for (const WordIndex *i = ctx; i != ctx + order - 1; ++i) {
e << ' ' << *i;
}
e << " so this context must appear in the model as a " << static_cast<unsigned int>(order - 1) << "-gram but it does not";
throw e;
}
}
/* Set ending offsets so the last entry will be sized properly */
// Last entry for unigrams was already set.
if (out.middle_begin_ != out.middle_end_) {
for (typename TrieSearch<Quant, Bhiksha>::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) {
i->FinishedLoading((i+1)->InsertIndex(), config);
}
(out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config);
}
}
template <class Quant, class Bhiksha> uint8_t *TrieSearch<Quant, Bhiksha>::SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config) {
quant_.SetupMemory(start, counts.size(), config);
start += Quant::Size(counts.size(), config);
unigram_.Init(start);
start += Unigram::Size(counts[0]);
FreeMiddles();
middle_begin_ = static_cast<Middle*>(malloc(sizeof(Middle) * (counts.size() - 2)));
middle_end_ = middle_begin_ + (counts.size() - 2);
std::vector<uint8_t*> middle_starts(counts.size() - 2);
for (unsigned char i = 2; i < counts.size(); ++i) {
middle_starts[i-2] = start;
start += Middle::Size(Quant::MiddleBits(config), counts[i-1], counts[0], counts[i], config);
}
// Crazy backwards thing so we initialize using pointers to ones that have already been initialized
for (unsigned char i = counts.size() - 1; i >= 2; --i) {
// use "placement new" syntax to initalize Middle in an already-allocated memory location
new (middle_begin_ + i - 2) Middle(
middle_starts[i-2],
quant_.MiddleBits(config),
counts[i-1],
counts[0],
counts[i],
(i == counts.size() - 1) ? static_cast<const BitPacked&>(longest_) : static_cast<const BitPacked &>(middle_begin_[i-1]),
config);
}
longest_.Init(start, quant_.LongestBits(config), counts[0]);
return start + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
template <class Quant, class Bhiksha> void TrieSearch<Quant, Bhiksha>::InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing) {
std::string temporary_prefix;
if (!config.temporary_directory_prefix.empty()) {
temporary_prefix = config.temporary_directory_prefix;
} else if (config.write_mmap) {
temporary_prefix = config.write_mmap;
} else {
temporary_prefix = file;
}
// At least 1MB sorting memory.
SortedFiles sorted(config, f, counts, std::max<size_t>(config.building_memory, 1048576), temporary_prefix, vocab);
BuildTrie(sorted, counts, config, *this, quant_, vocab, backing);
}
template class TrieSearch<DontQuantize, DontBhiksha>;
template class TrieSearch<DontQuantize, ArrayBhiksha>;
template class TrieSearch<SeparatelyQuantize, DontBhiksha>;
template class TrieSearch<SeparatelyQuantize, ArrayBhiksha>;
} // namespace trie
} // namespace ngram
} // namespace lm
#ifndef LM_SEARCH_TRIE_H
#define LM_SEARCH_TRIE_H
#include "config.hh"
#include "model_type.hh"
#include "return.hh"
#include "trie.hh"
#include "weights.hh"
#include "../util/file.hh"
#include "../util/file_piece.hh"
#include <vector>
#include <cstdlib>
#include <cassert>
namespace lm {
namespace ngram {
class BinaryFormat;
class SortedVocabulary;
namespace trie {
template <class Quant, class Bhiksha> class TrieSearch;
class SortedFiles;
template <class Quant, class Bhiksha> void BuildTrie(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);
template <class Quant, class Bhiksha> class TrieSearch {
public:
typedef NodeRange Node;
typedef ::lm::ngram::trie::UnigramPointer UnigramPointer;
typedef typename Quant::MiddlePointer MiddlePointer;
typedef typename Quant::LongestPointer LongestPointer;
static const bool kDifferentRest = false;
static const ModelType kModelType = static_cast<ModelType>(TRIE_SORTED + Quant::kModelTypeAdd + Bhiksha::kModelTypeAdd);
static const unsigned int kVersion = 1;
static void UpdateConfigFromBinary(const BinaryFormat &file, const std::vector<uint64_t> &counts, uint64_t offset, Config &config) {
Quant::UpdateConfigFromBinary(file, offset, config);
// Currently the unigram pointers are not compressed, so there will only be a header for order > 2.
if (counts.size() > 2)
Bhiksha::UpdateConfigFromBinary(file, offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0]), config);
}
static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config) {
uint64_t ret = Quant::Size(counts.size(), config) + Unigram::Size(counts[0]);
for (unsigned char i = 1; i < counts.size() - 1; ++i) {
ret += Middle::Size(Quant::MiddleBits(config), counts[i], counts[0], counts[i+1], config);
}
return ret + Longest::Size(Quant::LongestBits(config), counts.back(), counts[0]);
}
TrieSearch() : middle_begin_(NULL), middle_end_(NULL) {}
~TrieSearch() { FreeMiddles(); }
uint8_t *SetupMemory(uint8_t *start, const std::vector<uint64_t> &counts, const Config &config);
void InitializeFromARPA(const char *file, util::FilePiece &f, std::vector<uint64_t> &counts, const Config &config, SortedVocabulary &vocab, BinaryFormat &backing);
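// One Middle exists for each order in [2, N); unigrams and the longest order
// have their own structures, so N = #middles + 2.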
unsigned char Order() const {
return middle_end_ - middle_begin_ + 2;
}
ProbBackoff &UnknownUnigram() { return unigram_.Unknown(); }
UnigramPointer LookupUnigram(WordIndex word, Node &next, bool &independent_left, uint64_t &extend_left) const {
extend_left = static_cast<uint64_t>(word);
UnigramPointer ret(unigram_.Find(word, next));
independent_left = (next.begin == next.end);
return ret;
}
MiddlePointer Unpack(uint64_t extend_pointer, unsigned char extend_length, Node &node) const {
return MiddlePointer(quant_, extend_length - 2, middle_begin_[extend_length - 2].ReadEntry(extend_pointer, node));
}
MiddlePointer LookupMiddle(unsigned char order_minus_2, WordIndex word, Node &node, bool &independent_left, uint64_t &extend_left) const {
util::BitAddress address(middle_begin_[order_minus_2].Find(word, node, extend_left));
independent_left = (address.base == NULL) || (node.begin == node.end);
return MiddlePointer(quant_, order_minus_2, address);
}
LongestPointer LookupLongest(WordIndex word, const Node &node) const {
return LongestPointer(quant_, longest_.Find(word, node));
}
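// Chain lookups from the unigram through the middles to build the node for
// [begin, end); fails as soon as a word is missing or the path goes
// independent-left.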
bool FastMakeNode(const WordIndex *begin, const WordIndex *end, Node &node) const {
assert(begin != end);
bool independent_left;
uint64_t ignored;
LookupUnigram(*begin, node, independent_left, ignored);
for (const WordIndex *i = begin + 1; i < end; ++i) {
if (independent_left || !LookupMiddle(i - begin - 1, *i, node, independent_left, ignored).Found()) return false;
}
return true;
}
private:
friend void BuildTrie<Quant, Bhiksha>(SortedFiles &files, std::vector<uint64_t> &counts, const Config &config, TrieSearch<Quant, Bhiksha> &out, Quant &quant, SortedVocabulary &vocab, BinaryFormat &backing);
// Middles are managed manually so we can delay construction and they don't have to be copyable.
void FreeMiddles() {
for (const Middle *i = middle_begin_; i != middle_end_; ++i) {
i->~Middle();
}
std::free(middle_begin_);
}
typedef trie::BitPackedMiddle<Bhiksha> Middle;
typedef trie::BitPackedLongest Longest;
Longest longest_;
Middle *middle_begin_, *middle_end_;
Quant quant_;
typedef ::lm::ngram::trie::Unigram Unigram;
Unigram unigram_;
};
} // namespace trie
} // namespace ngram
} // namespace lm
#endif // LM_SEARCH_TRIE_H
#include "sizes.hh"
#include "model.hh"
#include "../util/file_piece.hh"
#include <vector>
#include <iomanip>
namespace lm {
namespace ngram {
void ShowSizes(const std::vector<uint64_t> &counts, const lm::ngram::Config &config) {
uint64_t sizes[6];
sizes[0] = ProbingModel::Size(counts, config);
sizes[1] = RestProbingModel::Size(counts, config);
sizes[2] = TrieModel::Size(counts, config);
sizes[3] = QuantTrieModel::Size(counts, config);
sizes[4] = ArrayTrieModel::Size(counts, config);
sizes[5] = QuantArrayTrieModel::Size(counts, config);
uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t));
uint64_t divide;
char prefix;
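// Scale to the largest unit that keeps the smallest estimate at or above
// roughly ten units, so every row prints with a couple of significant digits.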
if (min_length < (1 << 10) * 10) {
prefix = ' ';
divide = 1;
} else if (min_length < (1 << 20) * 10) {
prefix = 'k';
divide = 1 << 10;
} else if (min_length < (1ULL << 30) * 10) {
prefix = 'M';
divide = 1 << 20;
} else {
prefix = 'G';
divide = 1 << 30;
}
long int length = std::max<long int>(2, static_cast<long int>(std::ceil(std::log10(static_cast<double>(max_length) / divide))));
std::cerr << "Memory estimate for binary LM:\ntype ";
// right align bytes.
for (long int i = 0; i < length - 2; ++i) std::cerr << ' ';
std::cerr << prefix << "B\n"
"probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n"
"probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r models -p " << config.probing_multiplier << "\n"
"trie " << std::setw(length) << (sizes[2] / divide) << " without quantization\n"
"trie " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n"
"trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n"
"trie " << std::setw(length) << (sizes[5] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n";
}
void ShowSizes(const std::vector<uint64_t> &counts) {
lm::ngram::Config config;
ShowSizes(counts, config);
}
void ShowSizes(const char *file, const lm::ngram::Config &config) {
std::vector<uint64_t> counts;
util::FilePiece f(file);
lm::ReadARPACounts(f, counts);
ShowSizes(counts, config);
}
}} // namespaces
#ifndef LM_SIZES_H
#define LM_SIZES_H
#include <vector>
#include <stdint.h>
namespace lm { namespace ngram {
struct Config;
void ShowSizes(const std::vector<uint64_t> &counts, const lm::ngram::Config &config);
void ShowSizes(const std::vector<uint64_t> &counts);
void ShowSizes(const char *file, const lm::ngram::Config &config);
}} // namespaces
#endif // LM_SIZES_H
#ifndef LM_STATE_H
#define LM_STATE_H
#include "max_order.hh"
#include "word_index.hh"
#include "../util/murmur_hash.hh"
#include <cstring>
namespace lm {
namespace ngram {
// This is a POD but if you want memcmp to return the same as operator==, call
// ZeroRemaining first.
class State {
public:
bool operator==(const State &other) const {
if (length != other.length) return false;
return !memcmp(words, other.words, length * sizeof(WordIndex));
}
// Three way comparison function.
int Compare(const State &other) const {
if (length != other.length) return length < other.length ? -1 : 1;
return memcmp(words, other.words, length * sizeof(WordIndex));
}
bool operator<(const State &other) const {
if (length != other.length) return length < other.length;
return memcmp(words, other.words, length * sizeof(WordIndex)) < 0;
}
// Call this before using raw memcmp.
void ZeroRemaining() {
for (unsigned char i = length; i < KENLM_MAX_ORDER - 1; ++i) {
words[i] = 0;
backoff[i] = 0.0;
}
}
unsigned char Length() const { return length; }
// You shouldn't need to touch anything below this line, but the members are public so FullState will qualify as a POD.
// This order minimizes total size of the struct if WordIndex is 64 bit, float is 32 bit, and alignment of 64 bit integers is 64 bit.
WordIndex words[KENLM_MAX_ORDER - 1];
float backoff[KENLM_MAX_ORDER - 1];
unsigned char length;
};
typedef State Right;
inline uint64_t hash_value(const State &state, uint64_t seed = 0) {
return util::MurmurHashNative(state.words, sizeof(WordIndex) * state.length, seed);
}
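// Since hash_value is found by argument-dependent lookup, State can key a
// hash table directly; a minimal sketch (assuming boost::hash, which calls
// hash_value):
//   boost::unordered_map<State, float, boost::hash<State> > score_cache;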
struct Left {
bool operator==(const Left &other) const {
return
length == other.length &&
(!length || (pointers[length - 1] == other.pointers[length - 1] && full == other.full));
}
int Compare(const Left &other) const {
if (length < other.length) return -1;
if (length > other.length) return 1;
if (length == 0) return 0; // Must be full.
if (pointers[length - 1] > other.pointers[length - 1]) return 1;
if (pointers[length - 1] < other.pointers[length - 1]) return -1;
return (int)full - (int)other.full;
}
bool operator<(const Left &other) const {
return Compare(other) == -1;
}
void ZeroRemaining() {
for (uint64_t * i = pointers + length; i < pointers + KENLM_MAX_ORDER - 1; ++i)
*i = 0;
}
uint64_t pointers[KENLM_MAX_ORDER - 1];
unsigned char length;
bool full;
};
inline uint64_t hash_value(const Left &left) {
unsigned char add[2];
add[0] = left.length;
add[1] = left.full;
return util::MurmurHashNative(add, 2, left.length ? left.pointers[left.length - 1] : 0);
}
struct ChartState {
bool operator==(const ChartState &other) const {
return (right == other.right) && (left == other.left);
}
int Compare(const ChartState &other) const {
int lres = left.Compare(other.left);
if (lres) return lres;
return right.Compare(other.right);
}
bool operator<(const ChartState &other) const {
return Compare(other) < 0;
}
void ZeroRemaining() {
left.ZeroRemaining();
right.ZeroRemaining();
}
Left left;
State right;
};
inline uint64_t hash_value(const ChartState &state) {
return hash_value(state.right, hash_value(state.left));
}
} // namespace ngram
} // namespace lm
#endif // LM_STATE_H
\data\
ngram 1=37
ngram 2=47
ngram 3=11
ngram 4=6
ngram 5=4
\1-grams:
-1.383514 , -0.30103
-1.139057 . -0.845098
-1.029493 </s>
-99 <s> -0.4149733
-1.995635 <unk> -20
-1.285941 a -0.69897
-1.687872 also -0.30103
-1.687872 beyond -0.30103
-1.687872 biarritz -0.30103
-1.687872 call -0.30103
-1.687872 concerns -0.30103
-1.687872 consider -0.30103
-1.687872 considering -0.30103
-1.687872 for -0.30103
-1.509559 higher -0.30103
-1.687872 however -0.30103
-1.687872 i -0.30103
-1.687872 immediate -0.30103
-1.687872 in -0.30103
-1.687872 is -0.30103
-1.285941 little -0.69897
-1.383514 loin -0.30103
-1.687872 look -0.30103
-1.285941 looking -0.4771212
-1.206319 more -0.544068
-1.509559 on -0.4771212
-1.509559 screening -0.4771212
-1.687872 small -0.30103
-1.687872 the -0.30103
-1.687872 to -0.30103
-1.687872 watch -0.30103
-1.687872 watching -0.30103
-1.687872 what -0.30103
-1.687872 would -0.30103
-3.141592 foo
-2.718281 bar 3.0
-6.535897 baz -0.0
\2-grams:
-0.6925742 , .
-0.7522095 , however
-0.7522095 , is
-0.0602359 . </s>
-0.4846522 <s> looking -0.4771214
-1.051485 <s> screening
-1.07153 <s> the
-1.07153 <s> watching
-1.07153 <s> what
-0.09132547 a little -0.69897
-0.2922095 also call
-0.2922095 beyond immediate
-0.2705918 biarritz .
-0.2922095 call for
-0.2922095 concerns in
-0.2922095 consider watch
-0.2922095 considering consider
-0.2834328 for ,
-0.5511513 higher more
-0.5845945 higher small
-0.2834328 however ,
-0.2922095 i would
-0.2922095 immediate concerns
-0.2922095 in biarritz
-0.2922095 is to
-0.09021038 little more -0.1998621
-0.7273645 loin ,
-0.6925742 loin .
-0.6708385 loin </s>
-0.2922095 look beyond
-0.4638903 looking higher
-0.4638903 looking on -0.4771212
-0.5136299 more . -0.4771212
-0.3561665 more loin
-0.1649931 on a -0.4771213
-0.1649931 screening a -0.4771213
-0.2705918 small .
-0.287799 the screening
-0.2922095 to look
-0.2622373 watch </s>
-0.2922095 watching considering
-0.2922095 what i
-0.2922095 would also
-2 also would -6
-15 <unk> <unk> -2
-4 <unk> however -1
-6 foo bar
\3-grams:
-0.01916512 more . </s>
-0.0283603 on a little -0.4771212
-0.0283603 screening a little -0.4771212
-0.01660496 a little more -0.09409451
-0.3488368 <s> looking higher
-0.3488368 <s> looking on -0.4771212
-0.1892331 little more loin
-0.04835128 looking on a -0.4771212
-3 also would consider -7
-6 <unk> however <unk> -12
-7 to look a
\4-grams:
-0.009249173 looking on a little -0.4771212
-0.005464747 on a little more -0.4771212
-0.005464747 screening a little more
-0.1453306 a little more loin
-0.01552657 <s> looking on a -0.4771212
-4 also would consider higher -8
\5-grams:
-0.003061223 <s> looking on a little
-0.001813953 looking on a little more
-0.0432557 on a little more loin
-5 also would consider higher looking
\end\
\data\
ngram 1=36
ngram 2=45
ngram 3=10
ngram 4=6
ngram 5=4
\1-grams:
-1.383514 , -0.30103
-1.139057 . -0.845098
-1.029493 </s>
-99 <s> -0.4149733
-1.285941 a -0.69897
-1.687872 also -0.30103
-1.687872 beyond -0.30103
-1.687872 biarritz -0.30103
-1.687872 call -0.30103
-1.687872 concerns -0.30103
-1.687872 consider -0.30103
-1.687872 considering -0.30103
-1.687872 for -0.30103
-1.509559 higher -0.30103
-1.687872 however -0.30103
-1.687872 i -0.30103
-1.687872 immediate -0.30103
-1.687872 in -0.30103
-1.687872 is -0.30103
-1.285941 little -0.69897
-1.383514 loin -0.30103
-1.687872 look -0.30103
-1.285941 looking -0.4771212
-1.206319 more -0.544068
-1.509559 on -0.4771212
-1.509559 screening -0.4771212
-1.687872 small -0.30103
-1.687872 the -0.30103
-1.687872 to -0.30103
-1.687872 watch -0.30103
-1.687872 watching -0.30103
-1.687872 what -0.30103
-1.687872 would -0.30103
-3.141592 foo
-2.718281 bar 3.0
-6.535897 baz -0.0
\2-grams:
-0.6925742 , .
-0.7522095 , however
-0.7522095 , is
-0.0602359 . </s>
-0.4846522 <s> looking -0.4771214
-1.051485 <s> screening
-1.07153 <s> the
-1.07153 <s> watching
-1.07153 <s> what
-0.09132547 a little -0.69897
-0.2922095 also call
-0.2922095 beyond immediate
-0.2705918 biarritz .
-0.2922095 call for
-0.2922095 concerns in
-0.2922095 consider watch
-0.2922095 considering consider
-0.2834328 for ,
-0.5511513 higher more
-0.5845945 higher small
-0.2834328 however ,
-0.2922095 i would
-0.2922095 immediate concerns
-0.2922095 in biarritz
-0.2922095 is to
-0.09021038 little more -0.1998621
-0.7273645 loin ,
-0.6925742 loin .
-0.6708385 loin </s>
-0.2922095 look beyond
-0.4638903 looking higher
-0.4638903 looking on -0.4771212
-0.5136299 more . -0.4771212
-0.3561665 more loin
-0.1649931 on a -0.4771213
-0.1649931 screening a -0.4771213
-0.2705918 small .
-0.287799 the screening
-0.2922095 to look
-0.2622373 watch </s>
-0.2922095 watching considering
-0.2922095 what i
-0.2922095 would also
-2 also would -6
-6 foo bar
\3-grams:
-0.01916512 more . </s>
-0.0283603 on a little -0.4771212
-0.0283603 screening a little -0.4771212
-0.01660496 a little more -0.09409451
-0.3488368 <s> looking higher
-0.3488368 <s> looking on -0.4771212
-0.1892331 little more loin
-0.04835128 looking on a -0.4771212
-3 also would consider -7
-7 to look a
\4-grams:
-0.009249173 looking on a little -0.4771212
-0.005464747 on a little more -0.4771212
-0.005464747 screening a little more
-0.1453306 a little more loin
-0.01552657 <s> looking on a -0.4771212
-4 also would consider higher -8
\5-grams:
-0.003061223 <s> looking on a little
-0.001813953 looking on a little more
-0.0432557 on a little more loin
-5 also would consider higher looking
\end\
#include "trie.hh"
#include "bhiksha.hh"
#include "../util/bit_packing.hh"
#include "../util/exception.hh"
#include "../util/sorted_uniform.hh"
#include <cassert>
namespace lm {
namespace ngram {
namespace trie {
namespace {
class KeyAccessor {
public:
KeyAccessor(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits)
: base_(reinterpret_cast<const uint8_t*>(base)), key_mask_(key_mask), key_bits_(key_bits), total_bits_(total_bits) {}
typedef uint64_t Key;
Key operator()(uint64_t index) const {
return util::ReadInt57(base_, index * static_cast<uint64_t>(total_bits_), key_bits_, key_mask_);
}
private:
const uint8_t *const base_;
const WordIndex key_mask_;
const uint8_t key_bits_, total_bits_;
};
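// Interpolation-style search for key over the sorted, bit-packed entries in
// [begin_index, end_index), bounded by max_vocab.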
bool FindBitPacked(const void *base, uint64_t key_mask, uint8_t key_bits, uint8_t total_bits, uint64_t begin_index, uint64_t end_index, const uint64_t max_vocab, const uint64_t key, uint64_t &at_index) {
KeyAccessor accessor(base, key_mask, key_bits, total_bits);
if (!util::BoundedSortedUniformFind<uint64_t, KeyAccessor, util::PivotSelect<sizeof(WordIndex)>::T>(accessor, begin_index - 1, (uint64_t)0, end_index, max_vocab, key, at_index)) return false;
return true;
}
} // namespace
uint64_t BitPacked::BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits) {
uint8_t total_bits = util::RequiredBits(max_vocab) + remaining_bits;
// Extra entry for next pointer at the end.
// +7 then / 8 to round up bits and convert to bytes
// +sizeof(uint64_t) so that ReadInt57 etc don't go segfault.
// Note that this waste is O(order), not O(number of ngrams).
return ((1 + entries) * total_bits + 7) / 8 + sizeof(uint64_t);
}
void BitPacked::BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits) {
util::BitPackingSanity();
word_bits_ = util::RequiredBits(max_vocab);
word_mask_ = (1ULL << word_bits_) - 1ULL;
if (word_bits_ > 57) UTIL_THROW(util::Exception, "Sorry, word indices more than " << (1ULL << 57) << " are not implemented. Edit util/bit_packing.hh and fix the bit packing functions.");
total_bits_ = word_bits_ + remaining_bits;
base_ = static_cast<uint8_t*>(base);
insert_index_ = 0;
max_vocab_ = max_vocab;
}
template <class Bhiksha> uint64_t BitPackedMiddle<Bhiksha>::Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_ptr, const Config &config) {
return Bhiksha::Size(entries + 1, max_ptr, config) + BaseSize(entries, max_vocab, quant_bits + Bhiksha::InlineBits(entries + 1, max_ptr, config));
}
template <class Bhiksha> BitPackedMiddle<Bhiksha>::BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config) :
BitPacked(),
quant_bits_(quant_bits),
// If the offset of the method changes, also change TrieSearch::UpdateConfigFromBinary.
bhiksha_(base, entries + 1, max_next, config),
next_source_(&next_source) {
if (entries + 1 >= (1ULL << 57) || (max_next >= (1ULL << 57))) UTIL_THROW(util::Exception, "Sorry, this does not support more than " << (1ULL << 57) << " n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions.");
BaseInit(reinterpret_cast<uint8_t*>(base) + Bhiksha::Size(entries + 1, max_next, config), max_vocab, quant_bits_ + bhiksha_.InlineBits());
}
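// Bit layout of one middle entry, after the Bhiksha table at the base:
// [word index: word_bits_][quantized weights: quant_bits_][inline next pointer].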
template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Insert(WordIndex word) {
assert(word <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
util::WriteInt57(base_, at_pointer, word_bits_, word);
at_pointer += word_bits_;
util::BitAddress ret(base_, at_pointer);
at_pointer += quant_bits_;
uint64_t next = next_source_->InsertIndex();
bhiksha_.WriteNext(base_, at_pointer, insert_index_, next);
++insert_index_;
return ret;
}
template <class Bhiksha> util::BitAddress BitPackedMiddle<Bhiksha>::Find(WordIndex word, NodeRange &range, uint64_t &pointer) const {
uint64_t at_pointer;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) {
return util::BitAddress(NULL, 0);
}
pointer = at_pointer;
at_pointer *= total_bits_;
at_pointer += word_bits_;
bhiksha_.ReadNext(base_, at_pointer + quant_bits_, pointer, total_bits_, range);
return util::BitAddress(base_, at_pointer);
}
template <class Bhiksha> void BitPackedMiddle<Bhiksha>::FinishedLoading(uint64_t next_end, const Config &config) {
// Write at insert_index...
uint64_t last_next_write = insert_index_ * total_bits_ +
// at the offset where the next pointers are stored.
(total_bits_ - bhiksha_.InlineBits());
bhiksha_.WriteNext(base_, last_next_write, insert_index_, next_end);
bhiksha_.FinishedLoading(config);
}
util::BitAddress BitPackedLongest::Insert(WordIndex index) {
assert(index <= word_mask_);
uint64_t at_pointer = insert_index_ * total_bits_;
util::WriteInt57(base_, at_pointer, word_bits_, index);
at_pointer += word_bits_;
++insert_index_;
return util::BitAddress(base_, at_pointer);
}
util::BitAddress BitPackedLongest::Find(WordIndex word, const NodeRange &range) const {
uint64_t at_pointer;
if (!FindBitPacked(base_, word_mask_, word_bits_, total_bits_, range.begin, range.end, max_vocab_, word, at_pointer)) return util::BitAddress(NULL, 0);
at_pointer = at_pointer * total_bits_ + word_bits_;
return util::BitAddress(base_, at_pointer);
}
template class BitPackedMiddle<DontBhiksha>;
template class BitPackedMiddle<ArrayBhiksha>;
} // namespace trie
} // namespace ngram
} // namespace lm
#ifndef LM_TRIE_H
#define LM_TRIE_H
#include "weights.hh"
#include "word_index.hh"
#include "../util/bit_packing.hh"
#include <cstddef>
#include <stdint.h>
namespace lm {
namespace ngram {
struct Config;
namespace trie {
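// Half-open range [begin, end) of entries in the next order's bit-packed array.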
struct NodeRange {
uint64_t begin, end;
};
// TODO: if the number of unigrams is a concern, also bit pack these records.
struct UnigramValue {
ProbBackoff weights;
uint64_t next;
uint64_t Next() const { return next; }
};
class UnigramPointer {
public:
explicit UnigramPointer(const ProbBackoff &to) : to_(&to) {}
UnigramPointer() : to_(NULL) {}
bool Found() const { return to_ != NULL; }
float Prob() const { return to_->prob; }
float Backoff() const { return to_->backoff; }
float Rest() const { return Prob(); }
private:
const ProbBackoff *to_;
};
class Unigram {
public:
Unigram() {}
void Init(void *start) {
unigram_ = static_cast<UnigramValue*>(start);
}
static uint64_t Size(uint64_t count) {
// +1 in case unknown doesn't appear. +1 for the final next.
return (count + 2) * sizeof(UnigramValue);
}
const ProbBackoff &Lookup(WordIndex index) const { return unigram_[index].weights; }
ProbBackoff &Unknown() { return unigram_[0].weights; }
UnigramValue *Raw() {
return unigram_;
}
UnigramPointer Find(WordIndex word, NodeRange &next) const {
UnigramValue *val = unigram_ + word;
next.begin = val->next;
next.end = (val+1)->next;
return UnigramPointer(val->weights);
}
private:
UnigramValue *unigram_;
};
class BitPacked {
public:
BitPacked() {}
uint64_t InsertIndex() const {
return insert_index_;
}
protected:
static uint64_t BaseSize(uint64_t entries, uint64_t max_vocab, uint8_t remaining_bits);
void BaseInit(void *base, uint64_t max_vocab, uint8_t remaining_bits);
uint8_t word_bits_;
uint8_t total_bits_;
uint64_t word_mask_;
uint8_t *base_;
uint64_t insert_index_, max_vocab_;
};
template <class Bhiksha> class BitPackedMiddle : public BitPacked {
public:
static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const Config &config);
// next_source need not be initialized.
BitPackedMiddle(void *base, uint8_t quant_bits, uint64_t entries, uint64_t max_vocab, uint64_t max_next, const BitPacked &next_source, const Config &config);
util::BitAddress Insert(WordIndex word);
void FinishedLoading(uint64_t next_end, const Config &config);
util::BitAddress Find(WordIndex word, NodeRange &range, uint64_t &pointer) const;
util::BitAddress ReadEntry(uint64_t pointer, NodeRange &range) {
uint64_t addr = pointer * total_bits_;
addr += word_bits_;
bhiksha_.ReadNext(base_, addr + quant_bits_, pointer, total_bits_, range);
return util::BitAddress(base_, addr);
}
private:
uint8_t quant_bits_;
Bhiksha bhiksha_;
const BitPacked *next_source_;
};
class BitPackedLongest : public BitPacked {
public:
static uint64_t Size(uint8_t quant_bits, uint64_t entries, uint64_t max_vocab) {
return BaseSize(entries, max_vocab, quant_bits);
}
BitPackedLongest() {}
void Init(void *base, uint8_t quant_bits, uint64_t max_vocab) {
BaseInit(base, max_vocab, quant_bits);
}
util::BitAddress Insert(WordIndex word);
util::BitAddress Find(WordIndex word, const NodeRange &node) const;
};
} // namespace trie
} // namespace ngram
} // namespace lm
#endif // LM_TRIE_H
#include "trie_sort.hh"
#include "config.hh"
#include "lm_exception.hh"
#include "read_arpa.hh"
#include "vocab.hh"
#include "weights.hh"
#include "word_index.hh"
#include "../util/file_piece.hh"
#include "../util/mmap.hh"
#include "../util/pool.hh"
#include "../util/proxy_iterator.hh"
#include "../util/sized_iterator.hh"
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <cstdlib>
#include <deque>
#include <iterator>
#include <limits>
#include <vector>
namespace lm {
namespace ngram {
namespace trie {
namespace {
typedef util::SizedIterator NGramIter;
// Proxy for an entry except there is some extra cruft between the entries. This is used to sort (n-1)-grams using the same memory as the sorted n-grams.
class PartialViewProxy {
public:
PartialViewProxy() : attention_size_(0), inner_() {}
PartialViewProxy(void *ptr, std::size_t block_size, util::FreePool &pool) : attention_size_(pool.ElementSize()), inner_(ptr, block_size), pool_(&pool) {}
operator util::ValueBlock() const {
return util::ValueBlock(inner_.Data(), *pool_);
}
PartialViewProxy &operator=(const PartialViewProxy &from) {
memcpy(inner_.Data(), from.inner_.Data(), attention_size_);
return *this;
}
PartialViewProxy &operator=(const util::ValueBlock &from) {
memcpy(inner_.Data(), from.Data(), attention_size_);
return *this;
}
const void *Data() const { return inner_.Data(); }
void *Data() { return inner_.Data(); }
friend void swap(PartialViewProxy first, PartialViewProxy second);
private:
friend class util::ProxyIterator<PartialViewProxy>;
typedef util::ValueBlock value_type;
const std::size_t attention_size_;
typedef util::SizedInnerIterator InnerIterator;
InnerIterator &Inner() { return inner_; }
const InnerIterator &Inner() const { return inner_; }
InnerIterator inner_;
util::FreePool *pool_;
};
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-function"
#endif
void swap(PartialViewProxy first, PartialViewProxy second) {
std::swap_ranges(reinterpret_cast<char*>(first.Data()), reinterpret_cast<char*>(first.Data()) + first.attention_size_, reinterpret_cast<char*>(second.Data()));
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif
typedef util::ProxyIterator<PartialViewProxy> PartialIter;
FILE *DiskFlush(const void *mem_begin, const void *mem_end, const std::string &temp_prefix) {
util::scoped_fd file(util::MakeTemp(temp_prefix));
util::WriteOrThrow(file.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin);
return util::FDOpenOrThrow(file);
}
FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &temp_prefix, std::size_t entry_size, unsigned char order) {
const size_t context_size = sizeof(WordIndex) * (order - 1);
util::FreePool pool(context_size);
// Sort just the contexts using the same memory.
PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, pool));
PartialIter context_end(PartialViewProxy(end + sizeof(WordIndex), entry_size, pool));
#if defined(_WIN32) || defined(_WIN64)
std::stable_sort
#else
std::sort
#endif
(context_begin, context_end, util::SizedCompare<EntryCompare, PartialViewProxy>(EntryCompare(order - 1)));
util::scoped_FILE out(util::FMakeTemp(temp_prefix));
// Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator.
if (context_begin == context_end) return out.release();
PartialIter i(context_begin);
util::WriteOrThrow(out.get(), i->Data(), context_size);
const void *previous = i->Data();
++i;
for (; i != context_end; ++i) {
if (memcmp(previous, i->Data(), context_size)) {
util::WriteOrThrow(out.get(), i->Data(), context_size);
previous = i->Data();
}
}
return out.release();
}
struct ThrowCombine {
void operator()(std::size_t entry_size, unsigned char order, const void *first, const void *second, FILE * /*out*/) const {
const WordIndex *base = reinterpret_cast<const WordIndex*>(first);
FormatLoadException e;
e << "Duplicate n-gram detected with vocab ids";
for (const WordIndex *i = base; i != base + order; ++i) {
e << ' ' << *i;
}
throw e;
}
};
// Useful for context files that just contain records with no value.
struct FirstCombine {
void operator()(std::size_t entry_size, unsigned char /*order*/, const void *first, const void * /*second*/, FILE *out) const {
util::WriteOrThrow(out, first, entry_size);
}
};
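// Two-way merge of sorted record files.  Ties are resolved by Combine:
// ThrowCombine rejects duplicate n-grams, FirstCombine keeps one copy of a
// duplicate context.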
template <class Combine> FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const std::string &temp_prefix, std::size_t weights_size, unsigned char order, const Combine &combine) {
std::size_t entry_size = sizeof(WordIndex) * order + weights_size;
RecordReader first, second;
first.Init(first_file, entry_size);
second.Init(second_file, entry_size);
util::scoped_FILE out_file(util::FMakeTemp(temp_prefix));
EntryCompare less(order);
while (first && second) {
if (less(first.Data(), second.Data())) {
util::WriteOrThrow(out_file.get(), first.Data(), entry_size);
++first;
} else if (less(second.Data(), first.Data())) {
util::WriteOrThrow(out_file.get(), second.Data(), entry_size);
++second;
} else {
combine(entry_size, order, first.Data(), second.Data(), out_file.get());
++first; ++second;
}
}
for (RecordReader &remains = (first ? first : second); remains; ++remains) {
util::WriteOrThrow(out_file.get(), remains.Data(), entry_size);
}
return out_file.release();
}
} // namespace
void RecordReader::Init(FILE *file, std::size_t entry_size) {
entry_size_ = entry_size;
data_.reset(malloc(entry_size));
UTIL_THROW_IF(!data_.get(), util::ErrnoException, "Failed to malloc read buffer");
file_ = file;
if (file) {
rewind(file);
remains_ = true;
++*this;
} else {
remains_ = false;
}
}
void RecordReader::Overwrite(const void *start, std::size_t amount) {
long internal = (uint8_t*)start - (uint8_t*)data_.get();
UTIL_THROW_IF(fseek(file_, internal - entry_size_, SEEK_CUR), util::ErrnoException, "Couldn't seek backwards for revision");
util::WriteOrThrow(file_, start, amount);
long forward = entry_size_ - internal - amount;
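// The C standard requires a file-positioning call between a write and a
// subsequent read even when the offset is zero; common POSIX libcs tolerate
// skipping the no-op seek, so it is elided there.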
#if !defined(_WIN32) && !defined(_WIN64)
if (forward)
#endif
UTIL_THROW_IF(fseek(file_, forward, SEEK_CUR), util::ErrnoException, "Couldn't seek forwards past revision");
}
void RecordReader::Rewind() {
if (file_) {
rewind(file_);
remains_ = true;
++*this;
} else {
remains_ = false;
}
}
SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) {
PositiveProbWarn warn(config.positive_log_probability);
unigram_.reset(util::MakeTemp(file_prefix));
{
// Leave room for <unk> in case it does not appear in the ARPA file.
size_t size_out = (counts[0] + 1) * sizeof(ProbBackoff);
util::scoped_mmap unigram_mmap(util::MapZeroedWrite(unigram_.get(), size_out), size_out);
Read1Grams(f, counts[0], vocab, reinterpret_cast<ProbBackoff*>(unigram_mmap.get()), warn);
CheckSpecials(config, vocab);
if (!vocab.SawUnk()) ++counts[0];
}
// Only use as much buffer as we need.
size_t buffer_use = 0;
for (unsigned int order = 2; order < counts.size(); ++order) {
buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * order + 2 * sizeof(float)) * counts[order - 1]));
}
buffer_use = std::max<size_t>(buffer_use, static_cast<size_t>((sizeof(WordIndex) * counts.size() + sizeof(float)) * counts.back()));
buffer = std::min<size_t>(buffer, buffer_use);
util::scoped_malloc mem;
mem.reset(malloc(buffer));
if (!mem.get()) UTIL_THROW(util::ErrnoException, "malloc failed for sort buffer size " << buffer);
for (unsigned char order = 2; order <= counts.size(); ++order) {
ConvertToSorted(f, vocab, counts, file_prefix, order, warn, mem.get(), buffer);
}
ReadEnd(f);
}
namespace {
class Closer {
public:
explicit Closer(std::deque<FILE*> &files) : files_(files) {}
~Closer() {
for (std::deque<FILE*>::iterator i = files_.begin(); i != files_.end(); ++i) {
util::scoped_FILE deleter(*i);
}
}
void PopFront() {
util::scoped_FILE deleter(files_.front());
files_.pop_front();
}
private:
std::deque<FILE*> &files_;
};
} // namespace
void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) {
ReadNGramHeader(f, order);
const size_t count = counts[order - 1];
// Size of the weights: probability, plus backoff for every order except the longest.
const size_t words_size = sizeof(WordIndex) * order;
const size_t weights_size = sizeof(float) + ((order == counts.size()) ? 0 : sizeof(float));
const size_t entry_size = words_size + weights_size;
const size_t batch_size = std::min(count, mem_size / entry_size);
uint8_t *const begin = reinterpret_cast<uint8_t*>(mem);
std::deque<FILE*> files, contexts;
Closer files_closer(files), contexts_closer(contexts);
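// Read the section in batches that fit in the sort buffer: fill, sort, and
// flush each batch (plus its context records) to temporary files, then merge
// the files pairwise below.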
for (std::size_t batch = 0, done = 0; done < count; ++batch) {
uint8_t *out = begin;
uint8_t *out_end = out + std::min(count - done, batch_size) * entry_size;
if (order == counts.size()) {
for (; out != out_end; out += entry_size) {
std::reverse_iterator<WordIndex*> it(reinterpret_cast<WordIndex*>(out) + order);
ReadNGram(f, order, vocab, it, *reinterpret_cast<Prob*>(out + words_size), warn);
}
} else {
for (; out != out_end; out += entry_size) {
std::reverse_iterator<WordIndex*> it(reinterpret_cast<WordIndex*>(out) + order);
ReadNGram(f, order, vocab, it, *reinterpret_cast<ProbBackoff*>(out + words_size), warn);
}
}
// Sort full records by full n-gram.
util::SizedSort(begin, out_end, entry_size, EntryCompare(order));
files.push_back(DiskFlush(begin, out_end, file_prefix));
contexts.push_back(WriteContextFile(begin, out_end, file_prefix, entry_size, order));
done += (out_end - begin) / entry_size;
}
// All individual files created. Merge them.
while (files.size() > 1) {
files.push_back(MergeSortedFiles(files[0], files[1], file_prefix, weights_size, order, ThrowCombine()));
files_closer.PopFront();
files_closer.PopFront();
contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], file_prefix, 0, order - 1, FirstCombine()));
contexts_closer.PopFront();
contexts_closer.PopFront();
}
if (!files.empty()) {
// Steal from closers.
full_[order - 2].reset(files.front());
files.pop_front();
context_[order - 2].reset(contexts.front());
contexts.pop_front();
}
}
} // namespace trie
} // namespace ngram
} // namespace lm
// Step of trie builder: create sorted files.
#ifndef LM_TRIE_SORT_H
#define LM_TRIE_SORT_H
#include "max_order.hh"
#include "word_index.hh"
#include "../util/file.hh"
#include "../util/scoped.hh"
#include <cstddef>
#include <functional>
#include <string>
#include <vector>
#include <stdint.h>
namespace util {
class FilePiece;
} // namespace util
namespace lm {
class PositiveProbWarn;
namespace ngram {
class SortedVocabulary;
struct Config;
namespace trie {
class EntryCompare : public std::binary_function<const void*, const void*, bool> {
public:
explicit EntryCompare(unsigned char order) : order_(order) {}
bool operator()(const void *first_void, const void *second_void) const {
const WordIndex *first = static_cast<const WordIndex*>(first_void);
const WordIndex *second = static_cast<const WordIndex*>(second_void);
const WordIndex *end = first + order_;
for (; first != end; ++first, ++second) {
if (*first < *second) return true;
if (*first > *second) return false;
}
return false;
}
private:
unsigned char order_;
};
class RecordReader {
public:
RecordReader() : remains_(true) {}
void Init(FILE *file, std::size_t entry_size);
void *Data() { return data_.get(); }
const void *Data() const { return data_.get(); }
RecordReader &operator++() {
std::size_t ret = fread(data_.get(), entry_size_, 1, file_);
if (!ret) {
UTIL_THROW_IF(!feof(file_), util::ErrnoException, "Error reading temporary file");
remains_ = false;
}
return *this;
}
operator bool() const { return remains_; }
void Rewind();
std::size_t EntrySize() const { return entry_size_; }
void Overwrite(const void *start, std::size_t amount);
private:
FILE *file_;
util::scoped_malloc data_;
bool remains_;
std::size_t entry_size_;
};
class SortedFiles {
public:
// Build from ARPA
SortedFiles(const Config &config, util::FilePiece &f, std::vector<uint64_t> &counts, std::size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab);
int StealUnigram() {
return unigram_.release();
}
FILE *Full(unsigned char order) {
return full_[order - 2].get();
}
FILE *Context(unsigned char of_order) {
return context_[of_order - 2].get();
}
private:
void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector<uint64_t> &counts, const std::string &prefix, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size);
util::scoped_fd unigram_;
util::scoped_FILE full_[KENLM_MAX_ORDER - 1], context_[KENLM_MAX_ORDER - 1];
};
} // namespace trie
} // namespace ngram
} // namespace lm
#endif // LM_TRIE_SORT_H
#ifndef LM_VALUE_H
#define LM_VALUE_H
#include "config.hh"
#include "model_type.hh"
#include "value_build.hh"
#include "weights.hh"
#include "../util/bit_packing.hh"
#include <stdint.h>
namespace lm {
namespace ngram {
// Template proxy for probing unigrams and middle.
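// The probing format stores the 'independent left' flag in the sign bit of
// the probability, so Prob() re-imposes the sign bit before returning the
// log probability and IndependentLeft() reads it back.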
template <class Weights> class GenericProbingProxy {
public:
explicit GenericProbingProxy(const Weights &to) : to_(&to) {}
GenericProbingProxy() : to_(0) {}
bool Found() const { return to_ != 0; }
float Prob() const {
util::FloatEnc enc;
enc.f = to_->prob;
enc.i |= util::kSignBit;
return enc.f;
}
float Backoff() const { return to_->backoff; }
bool IndependentLeft() const {
util::FloatEnc enc;
enc.f = to_->prob;
return enc.i & util::kSignBit;
}
protected:
const Weights *to_;
};
// Basic proxy for trie unigrams.
template <class Weights> class GenericTrieUnigramProxy {
public:
explicit GenericTrieUnigramProxy(const Weights &to) : to_(&to) {}
GenericTrieUnigramProxy() : to_(0) {}
bool Found() const { return to_ != 0; }
float Prob() const { return to_->prob; }
float Backoff() const { return to_->backoff; }
float Rest() const { return Prob(); }
protected:
const Weights *to_;
};
struct BackoffValue {
typedef ProbBackoff Weights;
static const ModelType kProbingModelType = PROBING;
class ProbingProxy : public GenericProbingProxy<Weights> {
public:
explicit ProbingProxy(const Weights &to) : GenericProbingProxy<Weights>(to) {}
ProbingProxy() {}
float Rest() const { return Prob(); }
};
class TrieUnigramProxy : public GenericTrieUnigramProxy<Weights> {
public:
explicit TrieUnigramProxy(const Weights &to) : GenericTrieUnigramProxy<Weights>(to) {}
TrieUnigramProxy() {}
float Rest() const { return Prob(); }
};
struct ProbingEntry {
typedef uint64_t Key;
typedef Weights Value;
uint64_t key;
ProbBackoff value;
uint64_t GetKey() const { return key; }
};
struct TrieUnigramValue {
Weights weights;
uint64_t next;
uint64_t Next() const { return next; }
};
const static bool kDifferentRest = false;
template <class Model, class C> void Callback(const Config &, unsigned int, typename Model::Vocabulary &, C &callback) {
NoRestBuild build;
callback(build);
}
};
struct RestValue {
typedef RestWeights Weights;
static const ModelType kProbingModelType = REST_PROBING;
class ProbingProxy : public GenericProbingProxy<RestWeights> {
public:
explicit ProbingProxy(const Weights &to) : GenericProbingProxy<RestWeights>(to) {}
ProbingProxy() {}
float Rest() const { return to_->rest; }
};
class TrieUnigramProxy : public GenericTrieUnigramProxy<Weights> {
public:
explicit TrieUnigramProxy(const Weights &to) : GenericTrieUnigramProxy<Weights>(to) {}
TrieUnigramProxy() {}
float Rest() const { return to_->rest; }
};
// gcc 4.1 doesn't properly pack dependent types :-(.
#pragma pack(push)
#pragma pack(4)
struct ProbingEntry {
typedef uint64_t Key;
typedef Weights Value;
Key key;
Value value;
Key GetKey() const { return key; }
};
struct TrieUnigramValue {
Weights weights;
uint64_t next;
uint64_t Next() const { return next; }
};
#pragma pack(pop)
const static bool kDifferentRest = true;
template <class Model, class C> void Callback(const Config &config, unsigned int order, typename Model::Vocabulary &vocab, C &callback) {
switch (config.rest_function) {
case Config::REST_MAX:
{
MaxRestBuild build;
callback(build);
}
break;
case Config::REST_LOWER:
{
LowerRestBuild<Model> build(config, order, vocab);
callback(build);
}
break;
}
}
};
} // namespace ngram
} // namespace lm
#endif // LM_VALUE_H