Commit 688b6eac authored by SWHL's avatar SWHL
Browse files

Update files

parents
#include "value_build.hh"
#include "model.hh"
#include "read_arpa.hh"
namespace lm {
namespace ngram {
/* Load the lower-order models used for rest costs.
 * Orders 2..(order-1) are loaded as full Models from
 * config.rest_lower_files; order 1 (unigram) is read with a custom ARPA
 * loader into unigrams_ because a standalone unigram Model is not supported.
 * Throws ConfigException if the file count does not match the order, and
 * FormatLoadException if a file has the wrong order.
 */
template <class Model> LowerRestBuild<Model>::LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab) {
  UTIL_THROW_IF(config.rest_lower_files.size() != order - 1, ConfigException, "This model has order " << order << " so there should be " << (order - 1) << " lower-order models for rest cost purposes.");
  // Lower-order models must not write binaries or recurse into their own
  // rest-cost models, so strip those settings from the copied Config.
  Config for_lower = config;
  for_lower.write_mmap = NULL;
  for_lower.rest_lower_files.clear();

  // Unigram models aren't supported, so this is a custom loader.
  // TODO: optimize the unigram loading?
  {
    util::FilePiece uni(config.rest_lower_files[0].c_str());
    std::vector<uint64_t> number;
    ReadARPACounts(uni, number);
    UTIL_THROW_IF(number.size() != 1, FormatLoadException, "Expected the unigram model to have order 1, not " << number.size());
    ReadNGramHeader(uni, 1);
    unigrams_.resize(number[0]);
    // Default for id 0 (<unk>) in case the file does not list it.
    unigrams_[0] = config.unknown_missing_logprob;
    PositiveProbWarn warn;
    for (uint64_t i = 0; i < number[0]; ++i) {
      WordIndex w;
      Prob entry;
      ReadNGram(uni, 1, vocab, &w, entry, warn);
      unigrams_[w] = entry.prob;
    }
  }

  // Load orders 2..order-1.  models_ holds raw owning pointers, so delete
  // anything already loaded before rethrowing on failure.
  try {
    for (unsigned int i = 2; i < order; ++i) {
      models_.push_back(new Model(config.rest_lower_files[i - 1].c_str(), for_lower));
      UTIL_THROW_IF(models_.back()->Order() != i, FormatLoadException, "Lower order file " << config.rest_lower_files[i-1] << " should have order " << i);
    }
  } catch (...) {
    for (typename std::vector<const Model*>::const_iterator i = models_.begin(); i != models_.end(); ++i) {
      delete *i;
    }
    models_.clear();
    throw;
  }
  // TODO: force/check same vocab.
}
// Destructor: models_ stores raw owning pointers, so free each one.
template <class Model> LowerRestBuild<Model>::~LowerRestBuild() {
  typedef typename std::vector<const Model*>::const_iterator Iterator;
  for (Iterator m = models_.begin(); m != models_.end(); ++m) {
    delete *m;
  }
}
template class LowerRestBuild<ProbingModel>;
} // namespace ngram
} // namespace lm
#ifndef LM_VALUE_BUILD_H
#define LM_VALUE_BUILD_H
#include "weights.hh"
#include "word_index.hh"
#include "../util/bit_packing.hh"
#include <vector>
namespace lm {
namespace ngram {
struct Config;
struct BackoffValue;
struct RestValue;
// Rest-cost policy for models that store no rest costs: storing rest values
// is a no-op and extension marking never propagates to lower orders.
class NoRestBuild {
  public:
    typedef BackoffValue Value;

    NoRestBuild() {}

    // No rest cost stored; these exist to satisfy the policy interface.
    void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {}
    void SetRest(const WordIndex *, unsigned int, const ProbBackoff &) const {}

    // Mark (via the sign bit of prob) that this entry extends to the right.
    // Returns false: no lower-order entry needs updating.
    template <class Second> bool MarkExtends(ProbBackoff &weights, const Second &) const {
      util::UnsetSign(weights.prob);
      return false;
    }

    // Probing doesn't need to go back to unigram.
    const static bool kMarkEvenLower = false;
};
// Rest-cost policy where the rest cost of an n-gram is the maximum
// probability of any n-gram it prefixes (maintained incrementally by
// MarkExtends as longer n-grams are seen).
class MaxRestBuild {
  public:
    typedef RestValue Value;

    MaxRestBuild() {}

    // Highest order stores no rest cost.
    void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {}

    // Initialize the rest cost to the entry's own probability; the sign bit
    // is set as the initial "does not extend" marker.
    void SetRest(const WordIndex *, unsigned int, RestWeights &weights) const {
      weights.rest = weights.prob;
      util::SetSign(weights.rest);
    }

    // Raise weights.rest to to.rest if larger.  Returns true iff the value
    // changed, in which case even lower orders must be updated too.
    bool MarkExtends(RestWeights &weights, const RestWeights &to) const {
      util::UnsetSign(weights.prob);
      if (weights.rest >= to.rest) return false;
      weights.rest = to.rest;
      return true;
    }

    // Same, but the extending entry is highest-order and has only a prob.
    bool MarkExtends(RestWeights &weights, const Prob &to) const {
      util::UnsetSign(weights.prob);
      if (weights.rest >= to.prob) return false;
      weights.rest = to.prob;
      return true;
    }

    // Probing does need to go back to unigram.
    const static bool kMarkEvenLower = true;
};
// Rest-cost policy that queries separately-loaded lower-order models: the
// rest cost of an n-gram is its score under the model of the next order down.
template <class Model> class LowerRestBuild {
  public:
    typedef RestValue Value;

    // Loads the lower-order models named in config.rest_lower_files.
    LowerRestBuild(const Config &config, unsigned int order, const typename Model::Vocabulary &vocab);

    ~LowerRestBuild();

    // Highest order stores no rest cost.
    void SetRest(const WordIndex *, unsigned int, const Prob &/*prob*/) const {}

    // Look up the rest cost for the n words at vocab_ids: unigrams come from
    // the table read at construction, longer n-grams from the order n-1
    // model (context is vocab_ids[1..n), predicted word is vocab_ids[0]).
    void SetRest(const WordIndex *vocab_ids, unsigned int n, RestWeights &weights) const {
      typename Model::State ignored;
      if (n == 1) {
        weights.rest = unigrams_[*vocab_ids];
      } else {
        weights.rest = models_[n-2]->FullScoreForgotState(vocab_ids + 1, vocab_ids + n, *vocab_ids, ignored).prob;
      }
    }

    // Rest costs come from the lower models, so extension marking never
    // propagates; only the sign-bit extension flag is cleared.
    template <class Second> bool MarkExtends(RestWeights &weights, const Second &) const {
      util::UnsetSign(weights.prob);
      return false;
    }

    const static bool kMarkEvenLower = false;

    // Unigram rest costs, indexed by WordIndex.
    std::vector<float> unigrams_;

    // Owning raw pointers: models_[i] has order i + 2.
    std::vector<const Model*> models_;
};
} // namespace ngram
} // namespace lm
#endif // LM_VALUE_BUILD_H
#include "virtual_interface.hh"
#include "lm_exception.hh"
namespace lm {
namespace base {
Vocabulary::~Vocabulary() {}
// Record the ids of the three special words (<s>, </s>, <unk>) so the inline
// accessors can return them without a lookup.
void Vocabulary::SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found) {
  not_found_ = not_found;
  end_sentence_ = end_sentence;
  begin_sentence_ = begin_sentence;
}
Model::~Model() {}
} // namespace base
} // namespace lm
#ifndef LM_VIRTUAL_INTERFACE_H
#define LM_VIRTUAL_INTERFACE_H
#include "return.hh"
#include "word_index.hh"
#include "../util/string_piece.hh"
#include <string>
#include <cstring>
namespace lm {
namespace base {
template <class T, class U, class V> class ModelFacade;
/* Vocabulary interface. Call Index(string) and get a word index for use in
* calling Model. It provides faster convenience functions for <s>, </s>, and
* <unk> although you can also find these using Index.
*
* Some models do not load the mapping from index to string. If you need this,
* check if the model Vocabulary class implements such a function and access it
* directly.
*
* The Vocabulary object is always owned by the Model and can be retrieved from
* the Model using BaseVocabulary() for this abstract interface or
* GetVocabulary() for the actual implementation (in which case you'll need the
* actual implementation of the Model too).
*/
class Vocabulary {
  public:
    virtual ~Vocabulary();

    // Fast accessors for the special words; valid after SetSpecial.
    WordIndex BeginSentence() const { return begin_sentence_; }
    WordIndex EndSentence() const { return end_sentence_; }
    WordIndex NotFound() const { return not_found_; }

    /* Most implementations allow StringPiece lookups and need only override
     * Index(StringPiece).  SRI requires null termination and overrides all
     * three methods.
     */
    virtual WordIndex Index(const StringPiece &str) const = 0;
    virtual WordIndex Index(const std::string &str) const {
      return Index(StringPiece(str));
    }
    virtual WordIndex Index(const char *str) const {
      return Index(StringPiece(str));
    }

  protected:
    // Call SetSpecial afterward.
    Vocabulary() {}

    Vocabulary(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found) {
      SetSpecial(begin_sentence, end_sentence, not_found);
    }

    // Caches the special-word ids used by the accessors above.
    void SetSpecial(WordIndex begin_sentence, WordIndex end_sentence, WordIndex not_found);

    WordIndex begin_sentence_, end_sentence_, not_found_;

  private:
    // Disable copy constructors.  They're private and undefined.
    // Ersatz boost::noncopyable.
    Vocabulary(const Vocabulary &);
    Vocabulary &operator=(const Vocabulary &);
};
/* There are two ways to access a Model.
*
*
* OPTION 1: Access the Model directly (e.g. lm::ngram::Model in model.hh).
*
* Every Model implements the scoring function:
* float Score(
* const Model::State &in_state,
* const WordIndex new_word,
* Model::State &out_state) const;
*
* It can also return the length of n-gram matched by the model:
* FullScoreReturn FullScore(
* const Model::State &in_state,
* const WordIndex new_word,
* Model::State &out_state) const;
*
*
* There are also accessor functions:
* const State &BeginSentenceState() const;
* const State &NullContextState() const;
* const Vocabulary &GetVocabulary() const;
* unsigned int Order() const;
*
* NB: In case you're wondering why the model implementation looks like it's
* missing these methods, see facade.hh.
*
* This is the fastest way to use a model and presents a normal State class to
* be included in a hypothesis state structure.
*
*
* OPTION 2: Use the virtual interface below.
*
* The virtual interface allow you to decide which Model to use at runtime
* without templatizing everything on the Model type. However, each Model has
* its own State class, so a single State cannot be efficiently provided (it
* would require using the maximum memory of any Model's State or memory
* allocation with each lookup). This means you become responsible for
* allocating memory with size StateSize() and passing it to the Score or
* FullScore functions provided here.
*
* For example, cdec has a std::string containing the entire state of a
* hypothesis. It can reserve StateSize bytes in this string for the model
* state.
*
* All the State objects are POD, so it's ok to use raw memory for storing
* State.
* in_state and out_state must not have the same address.
*/
class Model {
  public:
    virtual ~Model();

    // Size in bytes of this model's opaque State; callers allocate this much
    // for the void* state arguments below.
    size_t StateSize() const { return state_size_; }

    // Raw bytes of the canonical begin-sentence / null-context states, and
    // helpers that copy them into caller-provided storage.
    const void *BeginSentenceMemory() const { return begin_sentence_memory_; }
    void BeginSentenceWrite(void *to) const { memcpy(to, begin_sentence_memory_, StateSize()); }
    const void *NullContextMemory() const { return null_context_memory_; }
    void NullContextWrite(void *to) const { memcpy(to, null_context_memory_, StateSize()); }

    // Requires in_state != out_state
    virtual float BaseScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0;

    // Requires in_state != out_state
    virtual FullScoreReturn BaseFullScore(const void *in_state, const WordIndex new_word, void *out_state) const = 0;

    // Prefer to use FullScore.  The context words should be provided in reverse order.
    virtual FullScoreReturn BaseFullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, void *out_state) const = 0;

    unsigned char Order() const { return order_; }

    const Vocabulary &BaseVocabulary() const { return *base_vocab_; }

  private:
    // Only ModelFacade may construct; it fills in the remaining members.
    template <class T, class U, class V> friend class ModelFacade;
    explicit Model(size_t state_size) : state_size_(state_size) {}

    const size_t state_size_;
    const void *begin_sentence_memory_, *null_context_memory_;
    const Vocabulary *base_vocab_;
    unsigned char order_;

    // Disable copy constructors.  They're private and undefined.
    // Ersatz boost::noncopyable.
    Model(const Model &);
    Model &operator=(const Model &);
};
} // namespace base
} // namespace lm
#endif // LM_VIRTUAL_INTERFACE_H
#include "vocab.hh"
#include "binary_format.hh"
#include "enumerate_vocab.hh"
#include "lm_exception.hh"
#include "config.hh"
#include "weights.hh"
#include "../util/exception.hh"
#include "../util/file_stream.hh"
#include "../util/file.hh"
#include "../util/joint_sort.hh"
#include "../util/murmur_hash.hh"
#include "../util/probing_hash_table.hh"
#include <cstring>
#include <string>
namespace lm {
namespace ngram {
namespace detail {
// Hash a vocabulary word.  Used for both lookup and the binary format, so it
// must remain stable across platforms.
uint64_t HashForVocab(const char *str, std::size_t len) {
  // This proved faster than Boost's hash in speed trials: total load time Murmur 67090000, Boost 72210000
  // Chose to use 64A instead of native so binary format will be portable across 64 and 32 bit.
  return util::MurmurHash64A(str, len, 0);
}
} // namespace detail
namespace {
// Normally static initialization is a bad idea but MurmurHash is pure arithmetic, so this is ok.
const uint64_t kUnknownHash = detail::HashForVocab("<unk>", 5);
// Sadly some LMs have <UNK>.
const uint64_t kUnknownCapHash = detail::HashForVocab("<UNK>", 5);
// Read the null-delimited word list stored at `offset` in a binary file and
// feed each (index, word) pair to `enumerate`.  The list must begin with
// <unk> and contain exactly expected_count words.
void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) {
  util::SeekOrThrow(fd, offset);
  // Check that we're at the right place by reading <unk> which is always first.
  char check_unk[6];  // "<unk>" plus its null delimiter.
  util::ReadOrThrow(fd, check_unk, 6);
  UTIL_THROW_IF(
      memcmp(check_unk, "<unk>", 6),
      FormatLoadException,
      "Vocabulary words are in the wrong place.  This could be because the binary file was built with stale gcc and old kenlm.  Stale gcc, including the gcc distributed with RedHat and OS X, has a bug that ignores pragma pack for template-dependent types.  New kenlm works around this, so you'll save memory but have to rebuild any binary files using the probing data structure.");
  if (!enumerate) return;
  enumerate->Add(0, "<unk>");

  WordIndex index = 1; // Read <unk> already.
  // Dup the fd: FilePiece takes ownership and would otherwise close it.
  util::FilePiece in(util::DupOrThrow(fd));
  for (util::LineIterator w(in, '\0'); w; ++w, ++index) {
    enumerate->Add(index, *w);
  }
  UTIL_THROW_IF(expected_count != index, FormatLoadException, "The binary file has the wrong number of words at the end.  This could be caused by a truncated binary file.");
}
// Constructor ordering madness.
// Seek fd to start and pass it through.  Exists so a member initializer can
// seek before constructing a stream on the same fd (see the caller below).
int SeekAndReturn(int fd, uint64_t start) {
  util::SeekOrThrow(fd, start);
  return fd;
}
} // namespace
// inner may be NULL.  The fd is seeked to `start` before stream_ wraps it.
ImmediateWriteWordsWrapper::ImmediateWriteWordsWrapper(EnumerateVocab *inner, int fd, uint64_t start)
  : inner_(inner), stream_(SeekAndReturn(fd, start)) {}
WriteWordsWrapper::WriteWordsWrapper(EnumerateVocab *inner) : inner_(inner) {}
// Forward the word to the wrapped enumerator (if any), then append it to the
// in-memory buffer as a null-delimited string.
void WriteWordsWrapper::Add(WordIndex index, const StringPiece &str) {
  if (inner_) inner_->Add(index, str);
  buffer_.append(str.data(), str.size());
  buffer_.append(1, '\0');
}
// Write the buffered null-delimited words at offset `start` of fd, then
// release the buffer's storage.
void WriteWordsWrapper::Write(int fd, uint64_t start) {
  util::SeekOrThrow(fd, start);
  util::WriteOrThrow(fd, buffer_.data(), buffer_.size());
  // Swap with a temporary to actually free the memory; clear() would keep
  // the capacity allocated.
  std::string().swap(buffer_);
}
SortedVocabulary::SortedVocabulary() : begin_(NULL), end_(NULL), enumerate_(NULL) {}
// Bytes of storage needed: a leading uint64_t entry count followed by one
// uint64_t hash per entry.  Config is unused; it keeps the signature uniform
// with ProbingVocabulary::Size.
uint64_t SortedVocabulary::Size(uint64_t entries, const Config &/*config*/) {
  // Lead with the number of entries.
  return (entries + 1) * sizeof(uint64_t);
}
// Point the vocabulary at its (caller-owned) backing memory.  The first
// uint64_t slot is reserved for the entry count; hashes start after it.
void SortedVocabulary::SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config) {
  assert(allocated >= Size(entries, config));
  // Leave space for number of entries.
  begin_ = reinterpret_cast<uint64_t*>(start) + 1;
  end_ = begin_;  // Empty until Insert is called.
  saw_unk_ = false;
}
// Re-point begin_/end_ after the backing memory has moved (e.g. mremap),
// preserving the current number of entries.
void SortedVocabulary::Relocate(void *new_start) {
  const std::size_t count = end_ - begin_;
  begin_ = reinterpret_cast<uint64_t*>(new_start) + 1;
  end_ = begin_ + count;
}
// Install an optional callback that will receive (index, word) pairs.  <unk>
// is reported immediately; other words are buffered until GenericFinished
// because their final ids are only known after sorting.
void SortedVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries) {
  enumerate_ = to;
  if (enumerate_) {
    enumerate_->Add(0, "<unk>");
    strings_to_enumerate_.resize(max_entries);
  }
}
// Insert a word, returning its provisional id (final ids are assigned when
// FinishedLoading sorts the hashes).  <unk> (either case) is never stored
// and always maps to 0.
WordIndex SortedVocabulary::Insert(const StringPiece &str) {
  uint64_t hashed = detail::HashForVocab(str);
  if (hashed == kUnknownHash || hashed == kUnknownCapHash) {
    saw_unk_ = true;
    return 0;
  }
  *end_ = hashed;
  if (enumerate_) {
    // Copy the string into pool-backed storage; str's bytes may not outlive
    // this call.
    void *copied = string_backing_.Allocate(str.size());
    memcpy(copied, str.data(), str.size());
    strings_to_enumerate_[end_ - begin_] = StringPiece(static_cast<const char*>(copied), str.size());
  }
  ++end_;
  // This is 1 + the offset where it was inserted to make room for unk.
  return end_ - begin_;
}
// Sort the vocabulary and permute the parallel unigram weights to match.
void SortedVocabulary::FinishedLoading(ProbBackoff *reorder) {
  GenericFinished(reorder);
}
namespace {
#pragma pack(push)
#pragma pack(4)
// One word during renumbering: its hash (sort key), a pointer into the
// mapped vocab file, and its original id.  Packed to 4 bytes like the other
// on-disk-adjacent structs in this file.
struct RenumberEntry {
  uint64_t hash;
  const char *str;
  WordIndex old;
  // Sort by hash, matching SortedVocabulary's storage order.
  bool operator<(const RenumberEntry &other) const {
    return hash < other.hash;
  }
};
#pragma pack(pop)
} // namespace
// Read null-delimited words from from_words, renumber them into hash order,
// write the reordered list to to_words, and fill mapping[old_id] = new_id.
// types is the total word count including <unk>, which must come first.
void SortedVocabulary::ComputeRenumbering(WordIndex types, int from_words, int to_words, std::vector<WordIndex> &mapping) {
  mapping.clear();
  uint64_t file_size = util::SizeOrThrow(from_words);
  util::scoped_memory strings;
  util::MapRead(util::POPULATE_OR_READ, from_words, 0, file_size, strings);
  const char *const start = static_cast<const char*>(strings.get());
  UTIL_THROW_IF(memcmp(start, "<unk>", 6), FormatLoadException, "Vocab file does not begin with <unk> followed by null");

  // Walk the null-delimited words, recording hash, position, and old id.
  std::vector<RenumberEntry> entries;
  entries.reserve(types - 1);
  RenumberEntry entry;
  entry.old = 1;
  for (entry.str = start + 6 /* skip <unk>\0 */; entry.str < start + file_size; ++entry.old) {
    StringPiece str(entry.str, strlen(entry.str));
    entry.hash = detail::HashForVocab(str);
    entries.push_back(entry);
    entry.str += str.size() + 1;
  }
  UTIL_THROW_IF2(entries.size() != types - 1, "Wrong number of vocab ids.  Got " << (entries.size() + 1) << " expected " << types);

  // Hash order is the id order SortedVocabulary uses.
  std::sort(entries.begin(), entries.end());

  // Write out new vocab file.
  {
    util::FileStream out(to_words);
    out << "<unk>" << '\0';
    for (std::vector<RenumberEntry>::const_iterator i = entries.begin(); i != entries.end(); ++i) {
      out << i->str << '\0';
    }
  }
  // Unmap before building the mapping; the string pointers are dead now.
  strings.reset();

  mapping.resize(types);
  mapping[0] = 0; // <unk>
  for (std::vector<RenumberEntry>::const_iterator i = entries.begin(); i != entries.end(); ++i) {
    // New id is 1 + the entry's sorted position (<unk> occupies 0).
    mapping[i->old] = i + 1 - entries.begin();
  }
}
// Finalize a vocabulary whose hash table was filled externally (already in
// sorted order): record the entry count and cache the special-word ids.
void SortedVocabulary::Populated() {
  saw_unk_ = true;
  SetSpecial(Index("<s>"), Index("</s>"), 0);
  // Bound includes <unk>, hence the +1.
  bound_ = end_ - begin_ + 1;
  // Store the entry count in the reserved slot before begin_.
  *(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_;
}
// Attach to an already-mapped binary image: recover end_ from the stored
// entry count, cache special ids, and optionally stream the word strings.
void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
  end_ = begin_ + *(reinterpret_cast<const uint64_t*>(begin_) - 1);
  SetSpecial(Index("<s>"), Index("</s>"), 0);
  bound_ = end_ - begin_ + 1;  // Includes <unk>.
  if (have_words) ReadWords(fd, to, bound_, offset);
}
// Sort the hashes into final id order, permuting the parallel unigram
// weights (and, if enumerating, the buffered word strings) along with them,
// then report final (id, word) pairs and record sizes.
template <class T> void SortedVocabulary::GenericFinished(T *reorder) {
  if (enumerate_) {
    if (!strings_to_enumerate_.empty()) {
      // reorder + 1 skips the <unk> weight, which stays at index 0.
      util::PairedIterator<T*, StringPiece*> values(reorder + 1, &*strings_to_enumerate_.begin());
      util::JointSort(begin_, end_, values);
    }
    for (WordIndex i = 0; i < static_cast<WordIndex>(end_ - begin_); ++i) {
      // <unk> strikes again: +1 here.
      enumerate_->Add(i + 1, strings_to_enumerate_[i]);
    }
    // The pooled string copies are no longer needed.
    strings_to_enumerate_.clear();
    string_backing_.FreeAll();
  } else {
    util::JointSort(begin_, end_, reorder + 1);
  }
  SetSpecial(Index("<s>"), Index("</s>"), 0);
  // Save size.  Excludes UNK.
  *(reinterpret_cast<uint64_t*>(begin_) - 1) = end_ - begin_;
  // Includes UNK.
  bound_ = end_ - begin_ + 1;
}
namespace {
const unsigned int kProbingVocabularyVersion = 0;
} // namespace
namespace detail {
// Header stored at the front of a binary ProbingVocabulary image.
struct ProbingVocabularyHeader {
  // Binary-format version, checked against kProbingVocabularyVersion on load.
  unsigned int version;
  // Lowest unused vocab id.  This is also the number of words, including <unk>.
  WordIndex bound;
};
} // namespace detail
ProbingVocabulary::ProbingVocabulary() : enumerate_(NULL) {}
// Bytes needed: the 8-byte-aligned header plus the probing hash table sized
// for `entries` at the given load-factor multiplier.
uint64_t ProbingVocabulary::Size(uint64_t entries, float probing_multiplier) {
  return ALIGN8(sizeof(detail::ProbingVocabularyHeader)) + Lookup::Size(entries, probing_multiplier);
}
// Convenience overload: unwraps Config to get the probing multiplier.
uint64_t ProbingVocabulary::Size(uint64_t entries, const Config &config) {
  return Size(entries, config.probing_multiplier);
}
// Lay the header and the hash table out over caller-owned memory.
void ProbingVocabulary::SetupMemory(void *start, std::size_t allocated) {
  header_ = static_cast<detail::ProbingVocabularyHeader*>(start);
  lookup_ = Lookup(static_cast<uint8_t*>(start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)), allocated);
  bound_ = 1;  // Id 0 is reserved for <unk>.
  saw_unk_ = false;
}
// Re-point header and table after the backing memory has moved.
void ProbingVocabulary::Relocate(void *new_start) {
  header_ = static_cast<detail::ProbingVocabularyHeader*>(new_start);
  lookup_.Relocate(static_cast<uint8_t*>(new_start) + ALIGN8(sizeof(detail::ProbingVocabularyHeader)));
}
// Install an optional (index, word) callback.  Unlike SortedVocabulary, ids
// are final at insertion time, so no string buffering is needed; <unk> is
// reported immediately as id 0.
void ProbingVocabulary::ConfigureEnumerate(EnumerateVocab *to, std::size_t /*max_entries*/) {
  enumerate_ = to;
  if (enumerate_) enumerate_->Add(0, "<unk>");
}
// Insert a word and return its final id.  <unk> (either case) never enters
// the table and always maps to 0; other words get sequential ids.
WordIndex ProbingVocabulary::Insert(const StringPiece &str) {
  uint64_t hashed = detail::HashForVocab(str);
  // Prevent unknown from going into the table.
  if (hashed == kUnknownHash || hashed == kUnknownCapHash) {
    saw_unk_ = true;
    return 0;
  } else {
    if (enumerate_) enumerate_->Add(bound_, str);
    lookup_.Insert(ProbingVocabularyEntry::Make(hashed, bound_));
    return bound_++;
  }
}
// Seal the hash table, persist bound/version into the header, and cache the
// special-word ids.
void ProbingVocabulary::InternalFinishedLoading() {
  lookup_.FinishedInserting();
  header_->bound = bound_;
  header_->version = kProbingVocabularyVersion;
  SetSpecial(Index("<s>"), Index("</s>"), 0);
}
// Attach to an already-mapped binary image: verify the format version,
// restore bound_, cache special ids, and optionally stream the word strings.
void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset) {
  UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ".  Please rerun build_binary using the same version of the code.");
  bound_ = header_->bound;
  SetSpecial(Index("<s>"), Index("</s>"), 0);
  if (have_words) ReadWords(fd, to, bound_, offset);
}
// React to an ARPA file with no <unk> entry according to
// config.unknown_missing: stay silent, warn on config.messages, or throw.
void MissingUnknown(const Config &config) {
  if (config.unknown_missing == SILENT) return;
  if (config.unknown_missing == COMPLAIN) {
    if (config.messages) *config.messages << "The ARPA file is missing <unk>.  Substituting log10 probability " << config.unknown_missing_logprob << "." << std::endl;
  } else if (config.unknown_missing == THROW_UP) {
    UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing <unk> and the model is configured to throw an exception.");
  }
}
// React to an ARPA file missing the sentence marker `str` (<s> or </s>)
// according to config.sentence_marker_missing: stay silent, warn on
// config.messages, or throw.
void MissingSentenceMarker(const Config &config, const char *str) {
  switch (config.sentence_marker_missing) {
    case SILENT:
      return;
    case COMPLAIN:
      // End with std::endl (newline + flush) so the warning doesn't run into
      // subsequent output; matches MissingUnknown above.
      if (config.messages) *config.messages << "Missing special word " << str << "; will treat it as <unk>." << std::endl;
      break;
    case THROW_UP:
      UTIL_THROW(SpecialWordMissingException, "The ARPA file is missing " << str << " and the model is configured to reject these models.  Run build_binary -s to disable this check.");
  }
}
} // namespace ngram
} // namespace lm
#ifndef LM_VOCAB_H
#define LM_VOCAB_H
#include "enumerate_vocab.hh"
#include "lm_exception.hh"
#include "virtual_interface.hh"
#include "../util/file_stream.hh"
#include "../util/murmur_hash.hh"
#include "../util/pool.hh"
#include "../util/probing_hash_table.hh"
#include "../util/sorted_uniform.hh"
#include "../util/string_piece.hh"
#include <limits>
#include <string>
#include <vector>
namespace lm {
struct ProbBackoff;
class EnumerateVocab;
namespace ngram {
struct Config;
namespace detail {
uint64_t HashForVocab(const char *str, std::size_t len);
inline uint64_t HashForVocab(const StringPiece &str) {
return HashForVocab(str.data(), str.length());
}
struct ProbingVocabularyHeader;
} // namespace detail
// Writes words immediately to a file instead of buffering, because we know
// where in the file to put them.
// Writes words immediately to a file instead of buffering, because we know
// where in the file to put them.
class ImmediateWriteWordsWrapper : public EnumerateVocab {
  public:
    // inner may be NULL; fd is seeked to start before writing begins.
    ImmediateWriteWordsWrapper(EnumerateVocab *inner, int fd, uint64_t start);

    // Append the word null-delimited to the file and forward to inner (if any).
    void Add(WordIndex index, const StringPiece &str) {
      stream_ << str << '\0';
      if (inner_) inner_->Add(index, str);
    }

  private:
    EnumerateVocab *inner_;  // Not owned.
    util::FileStream stream_;
};
// When the binary size isn't known yet.
// When the binary size isn't known yet: buffers null-delimited words in
// memory until Write is called with the final file offset.
class WriteWordsWrapper : public EnumerateVocab {
  public:
    // inner may be NULL.
    WriteWordsWrapper(EnumerateVocab *inner);

    void Add(WordIndex index, const StringPiece &str);

    // Buffered null-delimited words accumulated so far.
    const std::string &Buffer() const { return buffer_; }

    // Flush the buffer to fd at offset start and release its memory.
    void Write(int fd, uint64_t start);

  private:
    EnumerateVocab *inner_;  // Not owned.
    std::string buffer_;
};
// Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices.
// Vocabulary based on sorted uniform find storing only uint64_t values and using their offsets as indices.
class SortedVocabulary : public base::Vocabulary {
  public:
    SortedVocabulary();

    // Binary-search the sorted hash array; 0 (<unk>) if absent.
    WordIndex Index(const StringPiece &str) const {
      const uint64_t *found;
      if (util::BoundedSortedUniformFind<const uint64_t*, util::IdentityAccessor<uint64_t>, util::Pivot64>(
            util::IdentityAccessor<uint64_t>(),
            begin_ - 1, 0,
            end_, std::numeric_limits<uint64_t>::max(),
            detail::HashForVocab(str), found)) {
        return found - begin_ + 1; // +1 because <unk> is 0 and does not appear in the lookup table.
      } else {
        return 0;
      }
    }

    // Size for purposes of file writing
    static uint64_t Size(uint64_t entries, const Config &config);

    /* Read null-delimited words from file from_words, renumber according to
     * hash order, write null-delimited words to to_words, and create a mapping
     * from old id to new id.  The 0th vocab word must be <unk>.
     */
    static void ComputeRenumbering(WordIndex types, int from_words, int to_words, std::vector<WordIndex> &mapping);

    // Vocab words are [0, Bound())  Only valid after FinishedLoading/LoadedBinary.
    WordIndex Bound() const { return bound_; }

    // Everything else is for populating.  I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
    void SetupMemory(void *start, std::size_t allocated, std::size_t entries, const Config &config);

    void Relocate(void *new_start);

    void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);

    // Insert and FinishedLoading go together.
    WordIndex Insert(const StringPiece &str);
    // Reorders reorder_vocab so that the IDs are sorted.
    void FinishedLoading(ProbBackoff *reorder_vocab);

    // Trie stores the correct counts including <unk> in the header.  If this was previously sized based on a count excluding <unk>, padding with 8 bytes will make it the correct size based on a count including <unk>.
    std::size_t UnkCountChangePadding() const { return SawUnk() ? 0 : sizeof(uint64_t); }

    bool SawUnk() const { return saw_unk_; }

    void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);

    // Escape hatch for callers that fill the table directly (see Populated).
    uint64_t *&EndHack() { return end_; }

    void Populated();

  private:
    template <class T> void GenericFinished(T *reorder);

    // Sorted array of word hashes; the uint64_t before begin_ holds the count.
    uint64_t *begin_, *end_;

    // Number of vocab ids, including <unk>.
    WordIndex bound_;

    bool saw_unk_;

    EnumerateVocab *enumerate_;  // Not owned; may be NULL.

    // Actual strings.  Used only when loading from ARPA and enumerate_ != NULL
    util::Pool string_backing_;

    std::vector<StringPiece> strings_to_enumerate_;
};
#pragma pack(push)
#pragma pack(4)
// Hash-table entry mapping a word hash to its WordIndex.  Packed to 4 bytes
// so the on-disk layout is identical across platforms.
struct ProbingVocabularyEntry {
  uint64_t key;     // Word hash (detail::HashForVocab).
  WordIndex value;  // Word id.

  typedef uint64_t Key;
  uint64_t GetKey() const { return key; }
  void SetKey(uint64_t to) { key = to; }

  // Convenience aggregate-style constructor (the struct must stay POD).
  static ProbingVocabularyEntry Make(uint64_t key, WordIndex value) {
    ProbingVocabularyEntry ret;
    ret.key = key;
    ret.value = value;
    return ret;
  }
};
#pragma pack(pop)
// Vocabulary storing a map from uint64_t to WordIndex.
// Vocabulary storing a map from uint64_t to WordIndex.
class ProbingVocabulary : public base::Vocabulary {
  public:
    ProbingVocabulary();

    // Hash-table lookup; 0 (<unk>) if absent.
    WordIndex Index(const StringPiece &str) const {
      Lookup::ConstIterator i;
      return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
    }

    static uint64_t Size(uint64_t entries, float probing_multiplier);
    // This just unwraps Config to get the probing_multiplier.
    static uint64_t Size(uint64_t entries, const Config &config);

    // Vocab words are [0, Bound()).
    WordIndex Bound() const { return bound_; }

    // Everything else is for populating.  I'm too lazy to hide and friend these, but you'll only get a const reference anyway.
    void SetupMemory(void *start, std::size_t allocated);
    void SetupMemory(void *start, std::size_t allocated, std::size_t /*entries*/, const Config &/*config*/) {
      SetupMemory(start, allocated);
    }

    void Relocate(void *new_start);

    void ConfigureEnumerate(EnumerateVocab *to, std::size_t max_entries);

    WordIndex Insert(const StringPiece &str);

    // Ids are final at insert time, so no reordering is needed; the weights
    // pointer is ignored.
    template <class Weights> void FinishedLoading(Weights * /*reorder_vocab*/) {
      InternalFinishedLoading();
    }

    // Probing sizes already include <unk>; no padding adjustment needed.
    std::size_t UnkCountChangePadding() const { return 0; }

    bool SawUnk() const { return saw_unk_; }

    void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset);

  private:
    void InternalFinishedLoading();

    typedef util::ProbingHashTable<ProbingVocabularyEntry, util::IdentityHash> Lookup;

    Lookup lookup_;

    // Lowest unused id == number of words including <unk>.
    WordIndex bound_;

    bool saw_unk_;

    EnumerateVocab *enumerate_;  // Not owned; may be NULL.

    // Points into the mapped binary image; set by SetupMemory/Relocate.
    detail::ProbingVocabularyHeader *header_;
};
void MissingUnknown(const Config &config);
void MissingSentenceMarker(const Config &config, const char *str);
// After loading, report (per Config policy) any special words the vocabulary
// lacks: <unk>, then <s>, then </s>.  A sentence marker that hashed to
// NotFound() was never seen, so it would be scored as <unk>.
template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) {
  if (!vocab.SawUnk()) MissingUnknown(config);
  if (vocab.BeginSentence() == vocab.NotFound()) MissingSentenceMarker(config, "<s>");
  if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>");
}
// GrowableVocab callback that appends each newly-seen word, null-delimited,
// to the given file descriptor.
class WriteUniqueWords {
  public:
    explicit WriteUniqueWords(int fd) : word_list_(fd) {}

    void operator()(const StringPiece &word) {
      word_list_ << word << '\0';
    }

  private:
    util::FileStream word_list_;
};
// Default GrowableVocab callback: ignore newly-seen words.
class NoOpUniqueWords {
  public:
    NoOpUniqueWords() {}
    void operator()(const StringPiece &word) {}
};
// In-memory vocabulary that grows as words are added.  Ids 0, 1, 2 are
// forced to <unk>, <s>, </s>; NewWordAction is invoked once per new word.
template <class NewWordAction = NoOpUniqueWords> class GrowableVocab {
  public:
    // Estimated memory for `content` words (table requires at least 2).
    static std::size_t MemUsage(WordIndex content) {
      return Lookup::MemUsage(content > 2 ? content : 2);
    }

    // Does not take ownership of new_word_construct
    template <class NewWordConstruct> GrowableVocab(WordIndex initial_size, const NewWordConstruct &new_word_construct = NewWordAction())
      : lookup_(initial_size), new_word_(new_word_construct) {
      FindOrInsert("<unk>"); // Force 0
      FindOrInsert("<s>"); // Force 1
      FindOrInsert("</s>"); // Force 2
    }

    // NOTE(review): this hashes with detail::HashForVocab (MurmurHash64A)
    // while FindOrInsert uses util::MurmurHashNative; these appear to differ
    // on 32-bit builds, which would make Index miss inserted words there —
    // confirm intended behavior before relying on Index off 64-bit.
    WordIndex Index(const StringPiece &str) const {
      Lookup::ConstIterator i;
      return lookup_.Find(detail::HashForVocab(str), i) ? i->value : 0;
    }

    // Return the word's id, inserting it (and firing NewWordAction) if new.
    WordIndex FindOrInsert(const StringPiece &word) {
      ProbingVocabularyEntry entry = ProbingVocabularyEntry::Make(util::MurmurHashNative(word.data(), word.size()), Size());
      Lookup::MutableIterator it;
      if (!lookup_.FindOrInsert(entry, it)) {
        new_word_(word);
        UTIL_THROW_IF(Size() >= std::numeric_limits<lm::WordIndex>::max(), VocabLoadException, "Too many vocabulary words.  Change WordIndex to uint64_t in lm/word_index.hh");
      }
      return it->value;
    }

    WordIndex Size() const { return lookup_.Size(); }

    // True for the forced ids <unk>, <s>, </s>.
    bool IsSpecial(WordIndex word) const {
      return word <= 2;
    }

  private:
    typedef util::AutoProbing<ProbingVocabularyEntry, util::IdentityHash> Lookup;

    Lookup lookup_;

    NewWordAction new_word_;
};
} // namespace ngram
} // namespace lm
#endif // LM_VOCAB_H
#ifndef LM_WEIGHTS_H
#define LM_WEIGHTS_H
// Weights for n-grams. Probability and possibly a backoff.
namespace lm {
struct Prob {
float prob;
};
// No inheritance so this will be a POD.
struct ProbBackoff {
float prob;
float backoff;
};
struct RestWeights {
float prob;
float backoff;
float rest;
};
} // namespace lm
#endif // LM_WEIGHTS_H
// Separate header because this is used often.
#ifndef LM_WORD_INDEX_H
#define LM_WORD_INDEX_H
#include <climits>
namespace lm {
typedef unsigned int WordIndex;
const WordIndex kMaxWordIndex = UINT_MAX;
const WordIndex kUNK = 0;
} // namespace lm
typedef lm::WordIndex LMWordIndex;
#endif
This directory is for wrappers around other people's LMs, presenting an interface similar to KenLM's. You will need to have their LM installed.
NPLM is a work in progress.
#include "nplm.hh"
#include "../../util/exception.hh"
#include "../../util/file.hh"
#include <algorithm>
#include <cstring>
#include "neuralLM.h"
namespace lm {
namespace np {
// Wrap an NPLM vocabulary, caching the special-word ids including NPLM's
// <null> padding word.
Vocabulary::Vocabulary(const nplm::vocabulary &vocab)
  : base::Vocabulary(vocab.lookup_word("<s>"), vocab.lookup_word("</s>"), vocab.lookup_word("<unk>")),
    vocab_(vocab), null_word_(vocab.lookup_word("<null>")) {}
Vocabulary::~Vocabulary() {}
// Delegate lookup to the underlying NPLM vocabulary.
WordIndex Vocabulary::Index(const std::string &str) const {
  return vocab_.lookup_word(str);
}
// Per-thread copy of the NPLM model plus a staging buffer for assembling
// n-gram queries (nplm::neuralLM's cache is not thread-safe to share).
class Backend {
  public:
    Backend(const nplm::neuralLM &from, const std::size_t cache_size) : lm_(from), ngram_(from.get_order()) {
      lm_.set_cache(cache_size);
    }

    nplm::neuralLM &LM() { return lm_; }
    const nplm::neuralLM &LM() const { return lm_; }

    // Column vector of word ids to be scored; fill then call lookup_from_staging.
    Eigen::Matrix<int,Eigen::Dynamic,1> &staging_ngram() { return ngram_; }

    double lookup_from_staging() { return lm_.lookup_ngram(ngram_); }

    int order() const { return lm_.get_order(); }

  private:
    nplm::neuralLM lm_;
    Eigen::Matrix<int,Eigen::Dynamic,1> ngram_;
};
// True iff the file at `name` starts with the NPLM magic bytes.  Any read
// failure (missing file, too short) is treated as "not an NPLM model".
bool Model::Recognize(const std::string &name) {
  try {
    util::scoped_fd file(util::OpenReadOrThrow(name.c_str()));
    char magic_check[16];
    util::ReadOrThrow(file.get(), magic_check, sizeof(magic_check));
    const char nnlm_magic[] = "\\config\nversion ";
    return !memcmp(magic_check, nnlm_magic, 16);
  } catch (const util::Exception &) {
    return false;
  }
}
namespace {
// Load an NPLM model, using scoped_ptr so a throwing read() doesn't leak.
// Ownership transfers to the caller.
nplm::neuralLM *LoadNPLM(const std::string &file) {
  util::scoped_ptr<nplm::neuralLM> ret(new nplm::neuralLM());
  ret->read(file);
  return ret.release();
}
} // namespace
// Load an NPLM model from `file` and initialize the begin-sentence and
// null-context states.  `cache` is the per-thread n-gram cache size passed
// to each Backend.  Throws util::Exception if the model's order exceeds the
// compiled-in NPLM_MAX_ORDER.
Model::Model(const std::string &file, std::size_t cache)
  : base_instance_(LoadNPLM(file)), vocab_(base_instance_->get_vocabulary()), cache_size_(cache) {
  // Fixed error-message typo: "defintion" -> "definition".
  UTIL_THROW_IF(base_instance_->get_order() > NPLM_MAX_ORDER, util::Exception, "This NPLM has order " << (unsigned int)base_instance_->get_order() << " but the KenLM wrapper was compiled with " << NPLM_MAX_ORDER << ".  Change the definition of NPLM_MAX_ORDER and recompile.");
  // log10 compatible with backoff models.
  base_instance_->set_log_base(10.0);

  // Begin-sentence state is all <s>; null-context state is all <null>.
  State begin_sentence, null_context;
  std::fill(begin_sentence.words, begin_sentence.words + NPLM_MAX_ORDER - 1, base_instance_->lookup_word("<s>"));
  null_word_ = base_instance_->lookup_word("<null>");
  std::fill(null_context.words, null_context.words + NPLM_MAX_ORDER - 1, null_word_);

  Init(begin_sentence, null_context, vocab_, base_instance_->get_order());
}
Model::~Model() {}
// Score new_word given the context in `from`, writing the successor context
// to out_state.  Uses a lazily-created per-thread Backend.
FullScoreReturn Model::FullScore(const State &from, const WordIndex new_word, State &out_state) const {
  Backend *backend = backend_.get();
  if (!backend) {
    // First call on this thread: make a private copy of the model.
    backend = new Backend(*base_instance_, cache_size_);
    backend_.reset(backend);
  }
  // State is in natural word order.
  FullScoreReturn ret;
  for (int i = 0; i < backend->order() - 1; ++i) {
    backend->staging_ngram()(i) = from.words[i];
  }
  backend->staging_ngram()(backend->order() - 1) = new_word;
  ret.prob = backend->lookup_from_staging();
  // Always say full order.
  ret.ngram_length = backend->order();
  // Shift everything down by one.
  memcpy(out_state.words, from.words + 1, sizeof(WordIndex) * (backend->order() - 2));
  out_state.words[backend->order() - 2] = new_word;
  // Fill in trailing words with zeros so state comparison works.
  memset(out_state.words + backend->order() - 1, 0, sizeof(WordIndex) * (NPLM_MAX_ORDER - backend->order()));
  return ret;
}
// TODO: optimize with direct call?
// Rebuild a State from a reversed context array, then delegate to FullScore.
FullScoreReturn Model::FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const {
  // State is in natural word order.  The API here specifies reverse order.
  std::size_t state_length = std::min<std::size_t>(Order() - 1, context_rend - context_rbegin);
  State state;
  // Pad with null words.
  for (lm::WordIndex *i = state.words; i < state.words + Order() - 1 - state_length; ++i) {
    *i = null_word_;
  }
  // Put new words at the end.
  std::reverse_copy(context_rbegin, context_rbegin + state_length, state.words + Order() - 1 - state_length);
  return FullScore(state, new_word, out_state);
}
} // namespace np
} // namespace lm
#ifndef LM_WRAPPERS_NPLM_H
#define LM_WRAPPERS_NPLM_H
#include "../facade.hh"
#include "../max_order.hh"
#include "../../util/string_piece.hh"
#include <boost/thread/tss.hpp>
#include <boost/scoped_ptr.hpp>
/* Wrapper to NPLM "by Ashish Vaswani, with contributions from David Chiang
* and Victoria Fossum."
* http://nlg.isi.edu/software/nplm/
*/
namespace nplm {
class vocabulary;
class neuralLM;
} // namespace nplm
namespace lm {
namespace np {
// Adapter presenting an nplm::vocabulary through KenLM's base::Vocabulary
// interface.
class Vocabulary : public base::Vocabulary {
  public:
    Vocabulary(const nplm::vocabulary &vocab);

    ~Vocabulary();

    WordIndex Index(const std::string &str) const;

    // TODO: lobby them to support StringPiece
    WordIndex Index(const StringPiece &str) const {
      return Index(std::string(str.data(), str.size()));
    }

    // Index of the padding word used for short/empty contexts.
    lm::WordIndex NullWord() const { return null_word_; }

  private:
    // Underlying NPLM vocabulary (not owned).
    const nplm::vocabulary &vocab_;

    const lm::WordIndex null_word_;
};
// Sorry for imposing my limitations on your code.
// Maximum n-gram order the wrapper supports; a State stores order - 1 words.
#define NPLM_MAX_ORDER 7

// Fixed-size context state: the preceding words in natural order.
struct State {
  WordIndex words[NPLM_MAX_ORDER - 1];
};
class Backend;
// KenLM facade over an NPLM neural language model.
class Model : public lm::base::ModelFacade<Model, State, Vocabulary> {
  private:
    typedef lm::base::ModelFacade<Model, State, Vocabulary> P;

  public:
    // Does this look like an NPLM?
    static bool Recognize(const std::string &file);

    explicit Model(const std::string &file, std::size_t cache_size = 1 << 20);

    ~Model();

    FullScoreReturn FullScore(const State &from, const WordIndex new_word, State &out_state) const;

    FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;

  private:
    // Shared instance loaded once; each thread derives its own Backend from it.
    boost::scoped_ptr<nplm::neuralLM> base_instance_;

    // Per-thread backend, created lazily on the first FullScore call.
    mutable boost::thread_specific_ptr<Backend> backend_;

    Vocabulary vocab_;

    // Index of "<null>", used to pad contexts shorter than order - 1.
    lm::WordIndex null_word_;

    const std::size_t cache_size_;
};
} // namespace np
} // namespace lm
#endif // LM_WRAPPERS_NPLM_H
# Build the Python extension module that wraps kenlm.
# NOTE(review): FindPythonInterp/FindPythonLibs are deprecated in newer CMake;
# consider migrating to FindPython3 when the minimum CMake version allows.
find_package(PythonInterp REQUIRED)
find_package(PythonLibs ${PYTHON_VERSION_STRING} EXACT REQUIRED)
include_directories(${PYTHON_INCLUDE_DIRS})

add_library(kenlm_python MODULE kenlm.cpp score_sentence.cc)
# The importable module must be called plain "kenlm" with no "lib" prefix.
set_target_properties(kenlm_python PROPERTIES OUTPUT_NAME kenlm)
set_target_properties(kenlm_python PROPERTIES PREFIX "")
if(APPLE)
  # Python extension modules use .so even on macOS (default would be .dylib).
  set_target_properties(kenlm_python PROPERTIES SUFFIX ".so")
elseif(WIN32)
  set_target_properties(kenlm_python PROPERTIES SUFFIX ".pyd")
endif()

target_link_libraries(kenlm_python PUBLIC kenlm)
if(WIN32)
  target_link_libraries(kenlm_python PUBLIC ${PYTHON_LIBRARIES})
elseif(APPLE)
  # Leave Python symbols undefined; the host interpreter provides them at load time.
  set_target_properties(kenlm_python PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
endif()

# Install next to other Python packages for this interpreter version.
if (WIN32)
  set (PYTHON_SITE_PACKAGES Lib/site-packages)
else ()
  set (PYTHON_SITE_PACKAGES lib/python${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR}/site-packages)
endif ()
install(TARGETS kenlm_python DESTINATION ${PYTHON_SITE_PACKAGES})
from libcpp cimport bool

# Vocabulary index type used throughout KenLM.
cdef extern from "lm/word_index.hh" namespace "lm":
    ctypedef unsigned WordIndex

# Result of a full-score query: log10 probability plus matched n-gram length.
cdef extern from "lm/return.hh" namespace "lm":
    cdef struct FullScoreReturn:
        float prob
        unsigned char ngram_length

cdef extern from "lm/state.hh" namespace "lm::ngram":
    cdef cppclass State :
        int Compare(const State &other) const

    int hash_value(const State &state)

# Virtual interface shared by all model types.
cdef extern from "lm/virtual_interface.hh" namespace "lm::base":
    cdef cppclass Vocabulary:
        WordIndex Index(char*)
        WordIndex BeginSentence()
        WordIndex EndSentence()
        WordIndex NotFound()

    ctypedef Vocabulary const_Vocabulary "const lm::base::Vocabulary"

    cdef cppclass Model:
        void BeginSentenceWrite(void *)
        void NullContextWrite(void *)
        unsigned int Order()
        const_Vocabulary& BaseVocabulary()
        float BaseScore(void *in_state, WordIndex new_word, void *out_state)
        FullScoreReturn BaseFullScore(void *in_state, WordIndex new_word, void *out_state)

# How a binary model file is mapped/read into memory.
cdef extern from "util/mmap.hh" namespace "util":
    cdef enum LoadMethod:
        LAZY
        POPULATE_OR_LAZY
        POPULATE_OR_READ
        READ
        PARALLEL_READ

cdef extern from "lm/config.hh" namespace "lm::ngram::Config":
    cdef enum ARPALoadComplain:
        ALL
        EXPENSIVE
        NONE

cdef extern from "lm/config.hh" namespace "lm::ngram":
    cdef cppclass Config:
        Config()
        float probing_multiplier
        LoadMethod load_method
        bool show_progress
        ARPALoadComplain arpa_complain
        float unknown_missing_logprob

cdef extern from "lm/model.hh" namespace "lm::ngram":
    cdef Model *LoadVirtual(char *, Config &config) except +
    # Overload using the default-constructed Config.
    cdef Model *LoadVirtual(char *) except +

cdef extern from "python/score_sentence.hh" namespace "lm::base":
    cdef float ScoreSentence(const Model *model, const char *sentence)
#!/usr/bin/env python
import os
import kenlm

# Load the small ARPA model that ships with the KenLM tests.
LM = os.path.join(os.path.dirname(__file__), '..', 'lm', 'test.arpa')
model = kenlm.LanguageModel(LM)
print('{0}-gram model'.format(model.order))

sentence = 'language modeling is fun .'
print(sentence)
print(model.score(sentence))

# Check that total full score = direct score
def score(s):
    # Sum the per-word log10 probabilities reported by full_scores.
    return sum(prob for prob, _, _ in model.full_scores(s))

assert (abs(score(sentence) - model.score(sentence)) < 1e-3)

# Show scores and n-gram matches
words = ['<s>'] + sentence.split() + ['</s>']
for i, (prob, length, oov) in enumerate(model.full_scores(sentence)):
    print('{0} {1}: {2}'.format(prob, length, ' '.join(words[i+2-length:i+2])))
    if oov:
        print('\t"{0}" is an OOV'.format(words[i+1]))

# Find out-of-vocabulary words
for w in words:
    if not w in model:
        print('"{0}" is an OOV'.format(w))

# Stateful query
state = kenlm.State()
state2 = kenlm.State()
# Use <s> as context.  If you don't want <s>, use model.NullContextWrite(state).
model.BeginSentenceWrite(state)
accum = 0.0
accum += model.BaseScore(state, "a", state2)
accum += model.BaseScore(state2, "sentence", state)
# score defaults to bos = True and eos = True.  Here we'll check without the end
# of sentence marker.
assert (abs(accum - model.score("a sentence", eos = False)) < 1e-3)
accum += model.BaseScore(state, "</s>", state2)
assert (abs(accum - model.score("a sentence")) < 1e-3)
This source diff could not be displayed because it is too large. You can view the blob instead.
import os
cimport _kenlm
cdef bytes as_str(data):
    """Coerce text input to bytes: unicode is UTF-8 encoded, bytes pass through."""
    if isinstance(data, unicode):
        return data.encode('utf8')
    elif isinstance(data, bytes):
        return data
    raise TypeError('Cannot convert %s to string' % type(data))
cdef class FullScoreReturn:
    """
    Wrapper around FullScoreReturn.

    Notes:
        `prob` has been renamed to `log_prob`
        `oov` has been added to flag whether the word is OOV
    """

    # These cdef attributes back the identically-named read-only
    # properties declared below.
    cdef float log_prob
    cdef int ngram_length
    cdef bint oov

    def __cinit__(self, log_prob, ngram_length, oov):
        self.log_prob = log_prob
        self.ngram_length = ngram_length
        self.oov = oov

    def __repr__(self):
        return '{0}({1}, {2}, {3})'.format(self.__class__.__name__, repr(self.log_prob), repr(self.ngram_length), repr(self.oov))

    property log_prob:
        def __get__(self):
            return self.log_prob

    property ngram_length:
        def __get__(self):
            return self.ngram_length

    property oov:
        def __get__(self):
            return self.oov
cdef class State:
    """
    Wrapper around lm::ngram::State so that python code can make incremental queries.

    Notes:
        * rich comparisons
        * hashable
        * copyable via copy.copy() and copy.deepcopy()
    """

    cdef _kenlm.State _c_state

    def __richcmp__(State qa, State qb, int op):
        # Map Cython's single rich-comparison hook onto State::Compare().
        r = qa._c_state.Compare(qb._c_state)
        if op == 0:    # <
            return r < 0
        elif op == 1:  # <=
            return r <= 0
        elif op == 2:  # ==
            return r == 0
        elif op == 3:  # !=
            return r != 0
        elif op == 4:  # >
            return r > 0
        else:          # >=
            return r >= 0

    def __hash__(self):
        return _kenlm.hash_value(self._c_state)

    def __copy__(self):
        ret = State()
        ret._c_state = self._c_state
        return ret

    def __deepcopy__(self, memo=None):
        # Fix: copy.deepcopy() invokes __deepcopy__(memo); the previous
        # signature took no memo argument and raised TypeError.  The state
        # holds no Python object references, so deep copy == shallow copy.
        return self.__copy__()
class LoadMethod:
    """Python-visible constants mirroring util::LoadMethod (how the binary file is read)."""
    LAZY = _kenlm.LAZY
    POPULATE_OR_LAZY = _kenlm.POPULATE_OR_LAZY
    POPULATE_OR_READ = _kenlm.POPULATE_OR_READ
    READ = _kenlm.READ
    PARALLEL_READ = _kenlm.PARALLEL_READ
class ARPALoadComplain:
    """Python-visible constants mirroring lm::ngram::Config::ARPALoadComplain."""
    ALL = _kenlm.ALL
    EXPENSIVE = _kenlm.EXPENSIVE
    NONE = _kenlm.NONE
cdef class Config:
    """
    Wrapper around lm::ngram::Config.
    Pass this to Model's constructor to set configuration options.
    """

    cdef _kenlm.Config _c_config

    def __init__(self):
        # Start from the C++ default configuration.
        self._c_config = _kenlm.Config()

    property load_method:
        # How the binary model file is mapped/read; see LoadMethod constants.
        def __get__(self):
            return self._c_config.load_method
        def __set__(self, to):
            self._c_config.load_method = to

    property show_progress:
        # Whether loading reports progress (see lm/config.hh).
        def __get__(self):
            return self._c_config.show_progress
        def __set__(self, to):
            self._c_config.show_progress = to

    property arpa_complain:
        # Complaint level when loading an ARPA (text) file; see ARPALoadComplain.
        def __get__(self):
            return self._c_config.arpa_complain
        def __set__(self, to):
            self._c_config.arpa_complain = to
cdef class Model:
    """
    Wrapper around lm::ngram::Model.
    """

    # Owned pointer to the C++ model; released in __dealloc__.
    cdef _kenlm.Model* model
    # Absolute path the model was loaded from; used by __repr__ and pickling.
    cdef public bytes path
    # Borrowed pointer to the model's vocabulary (owned by `model`).
    cdef _kenlm.const_Vocabulary* vocab

    def __init__(self, path, Config config = Config()):
        """
        Load the language model.

        :param path: path to an arpa file or a kenlm binary file.
        :param config: configuration options (see lm/config.hh for documentation)
        """
        # NOTE(review): the default Config() is constructed once at function
        # definition time and shared across calls -- confirm this is intended.
        self.path = os.path.abspath(as_str(path))
        try:
            self.model = _kenlm.LoadVirtual(self.path, config._c_config)
        except RuntimeError as exception:
            exception_message = str(exception).replace('\n', ' ')
            raise IOError('Cannot read model \'{}\' ({})'.format(path, exception_message))\
                    from exception
        self.vocab = &self.model.BaseVocabulary()

    def __dealloc__(self):
        # Free the C++ model allocated by LoadVirtual.
        del self.model

    property order:
        # The model's n-gram order.
        def __get__(self):
            return self.model.Order()

    def score(self, sentence, bos = True, eos = True):
        """
        Return the log10 probability of a string.  By default, the string is
        treated as a sentence.
          return log10 p(sentence </s> | <s>)

        If you do not want to condition on the beginning of sentence, pass
          bos = False
        Never include <s> as part of the string.  That would be predicting the
        beginning of sentence.  Language models are only supposed to condition
        on it as context.

        Similarly, the end of sentence token </s> can be omitted with
          eos = False
        Since language models explicitly predict </s>, it can be part of the
        string.

        Examples:

        #Good: returns log10 p(this is a sentence . </s> | <s>)
        model.score("this is a sentence .")
        #Good: same as the above but more explicit
        model.score("this is a sentence .", bos = True, eos = True)

        #Bad: never include <s>
        model.score("<s> this is a sentence")
        #Bad: never include <s>, even if bos = False.
        model.score("<s> this is a sentence", bos = False)

        #Good: returns log10 p(a fragment)
        model.score("a fragment", bos = False, eos = False)

        #Good: returns log10 p(a fragment </s>)
        model.score("a fragment", bos = False, eos = True)

        #Ok, but bad practice: returns log10 p(a fragment </s>)
        #Unlike <s>, the end of sentence token </s> can appear explicitly.
        model.score("a fragment </s>", bos = False, eos = False)
        """
        if bos and eos:
            # Fast path: one C++ call scores the whole sentence.
            return _kenlm.ScoreSentence(self.model, as_str(sentence))
        cdef list words = as_str(sentence).split()
        cdef _kenlm.State state
        if bos:
            self.model.BeginSentenceWrite(&state)
        else:
            self.model.NullContextWrite(&state)
        cdef _kenlm.State out_state
        cdef float total = 0
        # Score each word given the running state, then advance the state.
        for word in words:
            total += self.model.BaseScore(&state, self.vocab.Index(word), &out_state)
            state = out_state
        if eos:
            total += self.model.BaseScore(&state, self.vocab.EndSentence(), &out_state)
        return total

    def perplexity(self, sentence):
        """
        Compute perplexity of a sentence.
        @param sentence One full sentence to score.  Do not include <s> or </s>.
        """
        # +1 accounts for the implicit </s> that score() appends.
        words = len(as_str(sentence).split()) + 1 # For </s>
        return 10.0**(-self.score(sentence) / words)

    def full_scores(self, sentence, bos = True, eos = True):
        """
        full_scores(sentence, bos = True, eos = True) -> generate full scores (prob, ngram length, oov)
        @param sentence is a string (do not use boundary symbols)
        @param bos should kenlm add a bos state
        @param eos should kenlm add an eos state
        """
        cdef list words = as_str(sentence).split()
        cdef _kenlm.State state
        if bos:
            self.model.BeginSentenceWrite(&state)
        else:
            self.model.NullContextWrite(&state)
        cdef _kenlm.State out_state
        cdef _kenlm.FullScoreReturn ret
        cdef float total = 0
        cdef _kenlm.WordIndex wid
        for word in words:
            wid = self.vocab.Index(word)
            ret = self.model.BaseFullScore(&state, wid, &out_state)
            # Index 0 is the vocabulary's "not found" index, i.e. an OOV word.
            yield (ret.prob, ret.ngram_length, wid == 0)
            state = out_state
        if eos:
            ret = self.model.BaseFullScore(&state,
                self.vocab.EndSentence(), &out_state)
            yield (ret.prob, ret.ngram_length, False)

    def BeginSentenceWrite(self, State state):
        """Change the given state to a BOS state."""
        self.model.BeginSentenceWrite(&state._c_state)

    def NullContextWrite(self, State state):
        """Change the given state to a NULL state."""
        self.model.NullContextWrite(&state._c_state)

    def BaseScore(self, State in_state, str word, State out_state):
        """
        Return p(word|in_state) and update the output state.
        Wrapper around model.BaseScore(in_state, Index(word), out_state)

        :param word: the suffix
        :param state: the context (defaults to NullContext)
        :returns: p(word|state)
        """
        cdef float total = self.model.BaseScore(&in_state._c_state, self.vocab.Index(as_str(word)), &out_state._c_state)
        return total

    def BaseFullScore(self, State in_state, str word, State out_state):
        """
        Wrapper around model.BaseFullScore(in_state, Index(word), out_state)

        :param word: the suffix
        :param state: the context (defaults to NullContext)
        :returns: FullScoreReturn(word|state)
        """
        cdef _kenlm.WordIndex wid = self.vocab.Index(as_str(word))
        cdef _kenlm.FullScoreReturn ret = self.model.BaseFullScore(&in_state._c_state, wid, &out_state._c_state)
        return FullScoreReturn(ret.prob, ret.ngram_length, wid == 0)

    def __contains__(self, word):
        cdef bytes w = as_str(word)
        # Index 0 means the word was not found in the vocabulary.
        return (self.vocab.Index(w) != 0)

    def __repr__(self):
        return '<Model from {0}>'.format(os.path.basename(self.path))

    def __reduce__(self):
        # Pickle by path only; unpickling reloads the model from disk.
        return (Model, (self.path,))
class LanguageModel(Model):
    """Backwards compatibility stub. Use Model."""
#include "lm/state.hh"
#include "lm/virtual_interface.hh"
#include "util/tokenize_piece.hh"
#include <algorithm>
#include <utility>
namespace lm {
namespace base {
// Total log10 probability of a whitespace-delimited sentence, conditioned on
// begin-of-sentence and including the end-of-sentence token.
float ScoreSentence(const base::Model *model, const char *sentence) {
  // TODO: reduce virtual dispatch to one per sentence?
  const base::Vocabulary &vocab = model->BaseVocabulary();
  // We know it's going to be a KenLM State; ping-pong between two of them.
  lm::ngram::State states[2];
  unsigned current = 0;
  model->BeginSentenceWrite(&states[current]);
  float total = 0.0;
  for (util::TokenIter<util::BoolCharacter, true> word(sentence, util::kSpaces); word; ++word) {
    lm::WordIndex index = vocab.Index(*word);
    total += model->BaseScore(&states[current], index, &states[1 - current]);
    current = 1 - current;
  }
  total += model->BaseScore(&states[current], vocab.EndSentence(), &states[1 - current]);
  return total;
}
} // namespace base
} // namespace lm
// Score an entire sentence splitting on whitespace. This should not be needed
// for C++ users (who should do it themselves), but it's faster for python users.
#pragma once

namespace lm {
namespace base {

class Model;

// Returns total log10 probability of `sentence` (whitespace-delimited words),
// conditioned on begin-of-sentence and including the end-of-sentence token.
float ScoreSentence(const Model *model, const char *sentence);

} // namespace base
} // namespace lm
from setuptools import setup, Extension
import glob
import platform
import os
import sys
import re
#Does gcc compile with this header and library?
def compile_test(header, library):
    """Return True if g++ can compile and link a trivial program that
    includes `header` and links against `library`."""
    dummy_path = os.path.join(os.path.dirname(__file__), "dummy")
    command = (
        'bash -c "g++ -include {0} -l{1} -x c++ - '
        "<<<'int main() {{}}' -o {2} >/dev/null 2>/dev/null "
        '&& rm {2} 2>/dev/null"'
    ).format(header, library, dummy_path)
    return os.system(command) == 0
# Default maximum n-gram order; override with "--max_order=N" (or "--max_order N").
max_order = "6"
is_max_order = [s for s in sys.argv if "--max_order" in s]
for element in is_max_order:
    # Accept both "--max_order=N" and "--max_order N" forms, and strip the
    # option from argv so setuptools does not see it.
    max_order = re.split('[= ]',element)[1]
    sys.argv.remove(element)

FILES = glob.glob('util/*.cc') + glob.glob('lm/*.cc') + glob.glob('util/double-conversion/*.cc') + glob.glob('python/*.cc')
# Exclude standalone binaries and unit tests from the extension build.
FILES = [fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc'))]

if platform.system() == 'Linux':
    LIBS = ['stdc++', 'rt']
elif platform.system() == 'Darwin':
    LIBS = ['c++']
else:
    LIBS = []

#We don't need -std=c++11 but python seems to be compiled with it now. https://github.com/kpu/kenlm/issues/86
ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER='+max_order, '-std=c++11']

#Attempted fix to https://github.com/kpu/kenlm/issues/186 and https://github.com/kpu/kenlm/issues/197
if platform.system() == 'Darwin':
    ARGS += ["-stdlib=libc++", "-mmacosx-version-min=10.7"]

# Enable optional compression support when headers/libraries are present.
if compile_test('zlib.h', 'z'):
    ARGS.append('-DHAVE_ZLIB')
    LIBS.append('z')
if compile_test('bzlib.h', 'bz2'):
    ARGS.append('-DHAVE_BZLIB')
    LIBS.append('bz2')
if compile_test('lzma.h', 'lzma'):
    ARGS.append('-DHAVE_XZLIB')
    LIBS.append('lzma')

ext_modules = [
    Extension(name='kenlm',
        sources=FILES + ['python/kenlm.cpp'],
        language='C++',
        include_dirs=['.'],
        libraries=LIBS,
        extra_compile_args=ARGS)
]

setup(
    name='kenlm',
    ext_modules=ext_modules,
    include_package_data=True,
)
# Explicitly list the source files for this subdirectory
#
# If you add any source files to this subdirectory
# that should be included in the kenlm library,
# (this excludes any unit test files)
# you should add them to the following list:
#
# Because we do not set PARENT_SCOPE in the following definition,
# CMake files in the parent directory won't be able to access this variable.
#
set(KENLM_UTIL_SOURCE
    bit_packing.cc
    ersatz_progress.cc
    exception.cc
    file.cc
    file_piece.cc
    float_to_string.cc
    integer_to_string.cc
    mmap.cc
    murmur_hash.cc
    parallel_read.cc
    pool.cc
    read_compressed.cc
    scoped.cc
    spaces.cc
    string_piece.cc
    usage.cc
  )

# Windows builds include the bundled getopt implementation.
if (WIN32)
  set(KENLM_UTIL_SOURCE ${KENLM_UTIL_SOURCE} getopt.c)
endif()
# This directory has children that need to be processed
add_subdirectory(double-conversion)
add_subdirectory(stream)

add_library(kenlm_util ${KENLM_UTIL_DOUBLECONVERSION_SOURCE} ${KENLM_UTIL_STREAM_SOURCE} ${KENLM_UTIL_SOURCE})

# Since headers are relative to `include/kenlm` at install time, not just `include`
target_include_directories(kenlm_util PUBLIC $<INSTALL_INTERFACE:include/kenlm>)

# Detect the optional compression libraries and accumulate HAVE_* defines for
# the sources that read compressed input.
set(READ_COMPRESSED_FLAGS)
find_package(ZLIB)
if (ZLIB_FOUND)
  set(READ_COMPRESSED_FLAGS "${READ_COMPRESSED_FLAGS} -DHAVE_ZLIB")
  target_link_libraries(kenlm_util PRIVATE ${ZLIB_LIBRARIES})
  include_directories(${ZLIB_INCLUDE_DIR})
endif()

find_package(BZip2)
if (BZIP2_FOUND)
  set(READ_COMPRESSED_FLAGS "${READ_COMPRESSED_FLAGS} -DHAVE_BZLIB")
  target_link_libraries(kenlm_util PRIVATE ${BZIP2_LIBRARIES})
  include_directories(${BZIP2_INCLUDE_DIR})
endif()

find_package(LibLZMA)
if (LIBLZMA_FOUND)
  set(READ_COMPRESSED_FLAGS "${READ_COMPRESSED_FLAGS} -DHAVE_XZLIB")
  target_link_libraries(kenlm_util PRIVATE ${LIBLZMA_LIBRARIES})
  include_directories(${LIBLZMA_INCLUDE_DIRS})
endif()

# Apply the accumulated defines only to the files that need them.
if (NOT "${READ_COMPRESSED_FLAGS}" STREQUAL "")
  set_source_files_properties(read_compressed.cc PROPERTIES COMPILE_FLAGS ${READ_COMPRESSED_FLAGS})
  set_source_files_properties(read_compressed_test.cc PROPERTIES COMPILE_FLAGS ${READ_COMPRESSED_FLAGS})
  set_source_files_properties(file_piece_test.cc PROPERTIES COMPILE_FLAGS ${READ_COMPRESSED_FLAGS})
endif()
if(UNIX)
  # clock_gettime lives in librt on older glibc; elsewhere it is in libc.
  include(CheckLibraryExists)
  check_library_exists(rt clock_gettime "clock_gettime from librt" HAVE_CLOCKGETTIME_RT)
  if (HAVE_CLOCKGETTIME_RT)
    set(RT rt)
  else()
    check_library_exists(c clock_gettime "clock_gettime from the libc" HAVE_CLOCKGETTIME)
  endif()

  if (HAVE_CLOCKGETTIME_RT OR HAVE_CLOCKGETTIME)
    add_definitions(-DHAVE_CLOCKGETTIME)
  endif()
endif()

# Group these objects together for later use.
set_target_properties(kenlm_util PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_link_libraries(kenlm_util
  PUBLIC
    # Boost is required for building binaries and tests
    "$<BUILD_INTERFACE:${Boost_LIBRARIES}>"
  PRIVATE
    Threads::Threads
    ${RT})
install(
  TARGETS kenlm_util
  EXPORT kenlmTargets
  RUNTIME DESTINATION bin
  LIBRARY DESTINATION lib
  ARCHIVE DESTINATION lib
  INCLUDES DESTINATION include
)

# The probing hash table benchmark is not built on Windows.
if (NOT WIN32)
  AddExes(EXES probing_hash_table_benchmark
          LIBRARIES kenlm_util Threads::Threads)
endif()

# Only compile and run unit tests if tests should be run
if(BUILD_TESTING)
  set(KENLM_BOOST_TESTS_LIST
    bit_packing_test
    integer_to_string_test
    joint_sort_test
    multi_intersection_test
    pcqueue_test
    probing_hash_table_test
    read_compressed_test
    sized_iterator_test
    sorted_uniform_test
    string_stream_test
    tokenize_piece_test
  )

  AddTests(TESTS ${KENLM_BOOST_TESTS_LIST}
           LIBRARIES kenlm_util Threads::Threads)

  # file_piece_test requires an extra command line parameter
  KenLMAddTest(TEST file_piece_test
               LIBRARIES kenlm_util Threads::Threads
               TEST_ARGS ${CMAKE_CURRENT_SOURCE_DIR}/file_piece.cc)
endif()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment