Commit 688b6eac authored by SWHL's avatar SWHL
Browse files

Update files

parents
#include "model_buffer.hh"
#include "compare.hh"
#include "../state.hh"
#include "../weights.hh"
#include "../../util/exception.hh"
#include "../../util/file_stream.hh"
#include "../../util/file.hh"
#include "../../util/file_piece.hh"
#include "../../util/stream/io.hh"
#include "../../util/stream/multi_stream.hh"
#include <boost/lexical_cast.hpp>
#include <numeric>
namespace lm {
namespace {
const char kMetadataHeader[] = "KenLM intermediate binary file";
} // namespace
// Construct for writing.  When keep_buffer is set, the vocabulary goes to a
// named file <file_base>.vocab that survives this object; otherwise it lives
// in an unnamed temporary created under file_base_.  output_q selects the
// payload marker written by Sink ("q" versus "pb").
ModelBuffer::ModelBuffer(StringPiece file_base, bool keep_buffer, bool output_q)
  : file_base_(file_base.data(), file_base.size()), keep_buffer_(keep_buffer), output_q_(output_q),
    vocab_file_(keep_buffer ? util::CreateOrThrow((file_base_ + ".vocab").c_str()) : util::MakeTemp(file_base_)) {}
// Construct for reading a previously written buffer.  Parses the
// <file_base>.kenlm_intermediate metadata file (header line, "Counts" line,
// "Payload" line), then opens the vocabulary file and one n-gram file per
// order.  Throws util::Exception on any format mismatch.
ModelBuffer::ModelBuffer(StringPiece file_base)
  : file_base_(file_base.data(), file_base.size()), keep_buffer_(false) {
  const std::string full_name = file_base_ + ".kenlm_intermediate";
  util::FilePiece in(full_name.c_str());

  // First line must be the exact magic header.
  StringPiece token = in.ReadLine();
  UTIL_THROW_IF2(token != kMetadataHeader, "File " << full_name << " begins with \"" << token << "\" not " << kMetadataHeader);

  // "Counts" followed by one space-separated count per order, newline-terminated.
  token = in.ReadDelimited();
  UTIL_THROW_IF2(token != "Counts", "Expected Counts, got \"" << token << "\" in " << full_name);
  char got;
  while ((got = in.get()) == ' ') {
    counts_.push_back(in.ReadULong());
  }
  UTIL_THROW_IF2(got != '\n', "Expected newline at end of counts.");

  // "Payload" is either "q" or "pb"; anything else is an error.
  token = in.ReadDelimited();
  UTIL_THROW_IF2(token != "Payload", "Expected Payload, got \"" << token << "\" in " << full_name);
  token = in.ReadDelimited();
  if (token == "q") {
    output_q_ = true;
  } else if (token == "pb") {
    output_q_ = false;
  } else {
    UTIL_THROW(util::Exception, "Unknown payload " << token);
  }

  vocab_file_.reset(util::OpenReadOrThrow((file_base_ + ".vocab").c_str()));

  // N-gram files are named <file_base>.1, <file_base>.2, ... by order,
  // matching the names Sink creates.
  files_.Init(counts_.size());
  for (unsigned long i = 0; i < counts_.size(); ++i) {
    files_.push_back(util::OpenReadOrThrow((file_base_ + '.' + boost::lexical_cast<std::string>(i + 1)).c_str()));
  }
}
// Attach a disk writer to the end of each chain so every n-gram flowing
// through chain i is persisted to a file for order i+1.  counts[i] is the
// number of (i+1)-grams.  If keep_buffer_ is set, files get the names the
// loading constructor expects and a .kenlm_intermediate metadata file is
// written; otherwise anonymous temporaries are used.
void ModelBuffer::Sink(util::stream::Chains &chains, const std::vector<uint64_t> &counts) {
  counts_ = counts;
  // Open files.
  files_.Init(chains.size());
  for (std::size_t i = 0; i < chains.size(); ++i) {
    if (keep_buffer_) {
      files_.push_back(util::CreateOrThrow(
        (file_base_ + '.' + boost::lexical_cast<std::string>(i + 1)).c_str()
      ));
    } else {
      files_.push_back(util::MakeTemp(file_base_));
    }
    chains[i] >> util::stream::Write(files_.back().get());
  }
  if (keep_buffer_) {
    // Metadata format mirrors what the loading constructor parses:
    // header line, "Counts <c1> <c2> ...", then "Payload q|pb".
    util::scoped_fd metadata(util::CreateOrThrow((file_base_ + ".kenlm_intermediate").c_str()));
    util::FileStream meta(metadata.get(), 200);
    meta << kMetadataHeader << "\nCounts";
    for (std::vector<uint64_t>::const_iterator i = counts_.begin(); i != counts_.end(); ++i) {
      meta << ' ' << *i;
    }
    meta << "\nPayload " << (output_q_ ? "q" : "pb") << '\n';
  }
}
// Feed the on-disk n-gram files into the provided chains, lowest orders
// first.  Providing fewer chains than files reads only the lower orders.
void ModelBuffer::Source(util::stream::Chains &chains) {
  assert(chains.size() <= files_.size());
  for (std::size_t order_index = 0; order_index < chains.size(); ++order_index) {
    const int fd = files_[order_index].get();
    chains[order_index].SetProgressTarget(util::SizeOrThrow(fd));
    chains[order_index] >> util::stream::PRead(fd);
  }
}
// Feed a single order's n-gram file into the given chain.
void ModelBuffer::Source(std::size_t order_minus_1, util::stream::Chain &chain) {
  const int fd = files_[order_minus_1].get();
  chain >> util::stream::PRead(fd);
}
// Compute P(word | context) by binary-searching the on-disk n-gram files
// directly, filling `out` with the resulting state.  Much slower than a real
// query data structure; intended for occasional queries (e.g. interpolation
// tuning), not bulk scoring.
float ModelBuffer::SlowQuery(const ngram::State &context, WordIndex word, ngram::State &out) const {
  // Lookup unigram.  Unigram records are fixed-size [WordIndex, ProbBackoff]
  // so the payload for `word` lives at a directly computable offset.
  ProbBackoff value;
  util::ErsatzPRead(RawFile(0), &value, sizeof(value), word * (sizeof(WordIndex) + sizeof(value)) + sizeof(WordIndex));
  out.backoff[0] = value.backoff;
  out.words[0] = word;
  out.length = 1;

  // Build the full query n-gram: the context words reversed into natural
  // order, followed by the predicted word.
  std::vector<WordIndex> buffer(context.length + 1), query(context.length + 1);
  std::reverse_copy(context.words, context.words + context.length, query.begin());
  query[context.length] = word;

  // Try successively longer suffixes of the query against each order's file.
  for (std::size_t order = 2; order <= query.size() && order <= context.length + 1; ++order) {
    SuffixOrder less(order);
    // The last `order` words of the query form the search key.
    const WordIndex *key = &*query.end() - order;
    int file = RawFile(order - 1);
    // Each record: `order` word ids followed by a ProbBackoff payload.
    std::size_t length = order * sizeof(WordIndex) + sizeof(ProbBackoff);
    // TODO: cache file size?
    uint64_t begin = 0, end = util::SizeOrThrow(file) / length;
    // Binary search over records (file assumed sorted by SuffixOrder).
    while (true) {
      if (end <= begin) {
        // Did not find for order.  Charge the backoffs of the context words
        // beyond the longest match onto the best probability found so far.
        return std::accumulate(context.backoff + out.length - 1, context.backoff + context.length, value.prob);
      }
      uint64_t test = begin + (end - begin) / 2;
      util::ErsatzPRead(file, &*buffer.begin(), sizeof(WordIndex) * order, test * length);
      if (less(&*buffer.begin(), key)) {
        begin = test + 1;
      } else if (less(key, &*buffer.begin())) {
        end = test;
      } else {
        // Found it.  Read the payload stored right after the word ids.
        util::ErsatzPRead(file, &value, sizeof(value), test * length + sizeof(WordIndex) * order);
        if (order != Order()) {
          // Only non-highest orders extend the state: the top order carries
          // no usable backoff for continuation.
          out.length = order;
          out.backoff[order - 1] = value.backoff;
          out.words[order - 1] = *key;
        }
        break;
      }
    }
  }
  return value.prob;
}
} // namespace
#ifndef LM_COMMON_MODEL_BUFFER_H
#define LM_COMMON_MODEL_BUFFER_H
/* Format with separate files in suffix order. Each file contains
* n-grams of the same order.
*/
#include "../word_index.hh"
#include "../../util/file.hh"
#include "../../util/fixed_array.hh"
#include "../../util/string_piece.hh"
#include <string>
#include <vector>
namespace util { namespace stream {
class Chains;
class Chain;
}} // namespaces
namespace lm {
namespace ngram { class State; }
// Reads and writes the intermediate on-disk format: one file of fixed-size
// records per n-gram order, a null-delimited vocabulary file, and optionally
// a small metadata file tying them together.
class ModelBuffer {
  public:
    // Construct for writing. Must call VocabFile() and fill it with null-delimited vocab words.
    ModelBuffer(StringPiece file_base, bool keep_buffer, bool output_q);

    // Load from file.
    explicit ModelBuffer(StringPiece file_base);

    // Must call VocabFile and populate before calling this function.
    void Sink(util::stream::Chains &chains, const std::vector<uint64_t> &counts);

    // Read files and write to the given chains. If fewer chains are provided,
    // only do the lower orders.
    void Source(util::stream::Chains &chains);

    // Read only the file for order order_minus_1 + 1 into the chain.
    void Source(std::size_t order_minus_1, util::stream::Chain &chain);

    // The order of the n-gram model that is associated with the model buffer.
    std::size_t Order() const { return counts_.size(); }

    // Requires Sink or load from file.
    const std::vector<uint64_t> &Counts() const {
      assert(!counts_.empty());
      return counts_;
    }

    // File descriptor of the null-delimited vocabulary file.
    int VocabFile() const { return vocab_file_.get(); }

    // File descriptor holding the n-grams of order order_minus_1 + 1.
    int RawFile(std::size_t order_minus_1) const {
      return files_[order_minus_1].get();
    }

    bool Keep() const { return keep_buffer_; }

    // Slowly execute a language model query with binary search.
    // This is used by interpolation to gather tuning probabilities rather than
    // scanning the files.
    float SlowQuery(const ngram::State &context, WordIndex word, ngram::State &out) const;

  private:
    const std::string file_base_;   // Prefix for all on-disk file names.
    const bool keep_buffer_;        // Keep named files after destruction?
    bool output_q_;                 // True for "q" payload, false for "pb" (see metadata).
    std::vector<uint64_t> counts_;  // Number of n-grams per order.
    util::scoped_fd vocab_file_;
    util::FixedArray<util::scoped_fd> files_;  // One fd per order, ascending.
};
} // namespace lm
#endif // LM_COMMON_MODEL_BUFFER_H
#include "model_buffer.hh"
#include "../model.hh"
#include "../state.hh"
#define BOOST_TEST_MODULE ModelBufferTest
#include <boost/test/unit_test.hpp>
namespace lm { namespace {
// End-to-end check that ModelBuffer::SlowQuery agrees with a regular
// ngram::Model loaded from the same toy ARPA file.
BOOST_AUTO_TEST_CASE(Query) {
  // Test fixtures live in test_data by default; an optional argv[1] overrides.
  std::string dir("test_data");
  if (boost::unit_test::framework::master_test_suite().argc == 2) {
    dir = boost::unit_test::framework::master_test_suite().argv[1];
  }
  ngram::Model ref((dir + "/toy0.arpa").c_str());
  // The intermediate binary format is endian-dependent, so load the matching
  // pre-built fixture directory.
#if BYTE_ORDER == LITTLE_ENDIAN
  std::string endian = "little";
#elif BYTE_ORDER == BIG_ENDIAN
  std::string endian = "big";
#else
#error "Unsupported byte order."
#endif
  ModelBuffer test(dir + "/" + endian + "endian/toy0");
  ngram::State ref_state, test_state;

  // Query "a" from the begin-of-sentence state; probability and the full
  // resulting state must match the reference.
  WordIndex a = ref.GetVocabulary().Index("a");
  BOOST_CHECK_CLOSE(
      ref.FullScore(ref.BeginSentenceState(), a, ref_state).prob,
      test.SlowQuery(ref.BeginSentenceState(), a, test_state),
      0.001);
  BOOST_CHECK_EQUAL((unsigned)ref_state.length, (unsigned)test_state.length);
  BOOST_CHECK_EQUAL(ref_state.words[0], test_state.words[0]);
  BOOST_CHECK_EQUAL(ref_state.backoff[0], test_state.backoff[0]);
  BOOST_CHECK(ref_state == test_state);

  // Chain a second query "b" off the state produced above.
  ngram::State ref_state2, test_state2;
  WordIndex b = ref.GetVocabulary().Index("b");
  BOOST_CHECK_CLOSE(
      ref.FullScore(ref_state, b, ref_state2).prob,
      test.SlowQuery(test_state, b, test_state2),
      0.001);
  BOOST_CHECK(ref_state2 == test_state2);
  BOOST_CHECK_EQUAL(ref_state2.backoff[0], test_state2.backoff[0]);

  // Query word id 0 (presumably <unk>); only the probability is compared.
  BOOST_CHECK_CLOSE(
      ref.FullScore(ref_state2, 0, ref_state).prob,
      test.SlowQuery(test_state2, 0, test_state),
      0.001);
  // The reference does state minimization but this doesn't.
}
}} // namespaces
#ifndef LM_COMMON_NGRAM_H
#define LM_COMMON_NGRAM_H
#include "../weights.hh"
#include "../word_index.hh"
#include <cstddef>
#include <cassert>
#include <stdint.h>
#include <cstring>
namespace lm {
// Non-owning view of the word-id portion of an n-gram record.  Records are
// laid out as `order` WordIndex values followed by a payload; this header
// spans only the word ids.
class NGramHeader {
  public:
    NGramHeader(void *begin, std::size_t order)
      : begin_(static_cast<WordIndex*>(begin)), end_(begin_ + order) {}

    // Invalid/empty view.
    NGramHeader() : begin_(NULL), end_(NULL) {}

    // Raw byte access to the start of the record.
    const uint8_t *Base() const { return reinterpret_cast<const uint8_t*>(begin_); }
    uint8_t *Base() { return reinterpret_cast<uint8_t*>(begin_); }

    // Point this view at a different record, preserving the order.
    void ReBase(void *to) {
      std::size_t difference = end_ - begin_;
      begin_ = reinterpret_cast<WordIndex*>(to);
      end_ = begin_ + difference;
    }

    // These are for the vocab index.
    // Lower-case in deference to STL.
    const WordIndex *begin() const { return begin_; }
    WordIndex *begin() { return begin_; }
    const WordIndex *end() const { return end_; }
    WordIndex *end() { return end_; }

    std::size_t size() const { return end_ - begin_; }
    // Same as size(): the n-gram order of this record.
    std::size_t Order() const { return end_ - begin_; }

  private:
    WordIndex *begin_, *end_;
};
// Typed view of a full n-gram record: the word ids (NGramHeader) followed
// immediately by a Payload value.
template <class PayloadT> class NGram : public NGramHeader {
  public:
    typedef PayloadT Payload;

    NGram() : NGramHeader(NULL, 0) {}
    NGram(void *begin, std::size_t order) : NGramHeader(begin, order) {}

    // Advance the view to the record laid out directly after this one.
    // Would do operator++ but that can get confusing for a stream.
    void NextInMemory() {
      ReBase(&Value() + 1);
    }

    // Bytes occupied by one record of the given order.
    static std::size_t TotalSize(std::size_t order) {
      return order * sizeof(WordIndex) + sizeof(Payload);
    }
    std::size_t TotalSize() const {
      // Compiler should optimize this.
      return TotalSize(Order());
    }

    // Inverse of TotalSize: recover the order from a record's byte size.
    static std::size_t OrderFromSize(std::size_t size) {
      std::size_t ret = (size - sizeof(Payload)) / sizeof(WordIndex);
      assert(size == TotalSize(ret));
      return ret;
    }

    // The payload is stored immediately after the word ids, i.e. at end().
    const Payload &Value() const { return *reinterpret_cast<const Payload *>(end()); }
    Payload &Value() { return *reinterpret_cast<Payload *>(end()); }
};
} // namespace lm
#endif // LM_COMMON_NGRAM_H
#ifndef LM_BUILDER_NGRAM_STREAM_H
#define LM_BUILDER_NGRAM_STREAM_H
#include "ngram.hh"
#include "../../util/stream/chain.hh"
#include "../../util/stream/multi_stream.hh"
#include "../../util/stream/stream.hh"
#include <cstddef>
namespace lm {
// Wraps a util::stream::Stream together with a Proxy object (e.g. NGram)
// that is kept rebased onto the stream's current record, so *stream yields
// a typed view over the raw bytes.
template <class Proxy> class ProxyStream {
  public:
    // Make an invalid stream.
    ProxyStream() {}

    explicit ProxyStream(const util::stream::ChainPosition &position, const Proxy &proxy = Proxy())
      : proxy_(proxy), stream_(position) {
      proxy_.ReBase(stream_.Get());
    }

    // Typed access to the current record.
    Proxy &operator*() { return proxy_; }
    const Proxy &operator*() const { return proxy_; }
    Proxy *operator->() { return &proxy_; }
    const Proxy *operator->() const { return &proxy_; }

    // Raw access to the current record.
    void *Get() { return stream_.Get(); }
    const void *Get() const { return stream_.Get(); }

    // False once the underlying stream is exhausted.
    operator bool() const { return stream_; }
    bool operator!() const { return !stream_; }

    void Poison() { stream_.Poison(); }

    // Advance to the next record and rebase the proxy onto it.
    ProxyStream<Proxy> &operator++() {
      ++stream_;
      proxy_.ReBase(stream_.Get());
      return *this;
    }

  private:
    Proxy proxy_;
    util::stream::Stream stream_;
};
// ProxyStream specialized to NGram<Payload> records; the n-gram order is
// inferred from the chain's entry size via OrderFromSize.
template <class Payload> class NGramStream : public ProxyStream<NGram<Payload> > {
  public:
    // Make an invalid stream.
    NGramStream() {}

    explicit NGramStream(const util::stream::ChainPosition &position) :
      ProxyStream<NGram<Payload> >(position, NGram<Payload>(NULL, NGram<Payload>::OrderFromSize(position.GetChain().EntrySize()))) {}
};
// Collection of NGramStream, one per chain position (typically one per order).
template <class Payload> class NGramStreams : public util::stream::GenericStreams<NGramStream<Payload> > {
  private:
    typedef util::stream::GenericStreams<NGramStream<Payload> > P;
  public:
    NGramStreams() : P() {}
    NGramStreams(const util::stream::ChainPositions &positions) : P(positions) {}
};
} // namespace
#endif // LM_BUILDER_NGRAM_STREAM_H
#include "print.hh"
#include "ngram_stream.hh"
#include "../../util/file_stream.hh"
#include "../../util/file.hh"
#include "../../util/mmap.hh"
#include "../../util/scoped.hh"
#include <sstream>
#include <cstring>
namespace lm {
// Memory-map the null-delimited vocabulary file and index every string
// start, so word id -> string is an array lookup.  One extra sentinel entry
// (one past the final string) is appended so LookupPiece can compute lengths.
VocabReconstitute::VocabReconstitute(int fd) {
  uint64_t size = util::SizeOrThrow(fd);
  util::MapRead(util::POPULATE_OR_READ, fd, 0, size, memory_);
  const char *const start = static_cast<const char*>(memory_.get());
  const char *i;
  // Each word is NUL-terminated; hop from string to string.
  for (i = start; i != start + size; i += strlen(i) + 1) {
    map_.push_back(i);
  }
  // Last one for LookupPiece.
  map_.push_back(i);
}
namespace {
template <class Payload> void PrintLead(const VocabReconstitute &vocab, ProxyStream<Payload> &stream, util::FileStream &out) {
out << stream->Value().prob << '\t' << vocab.Lookup(*stream->begin());
for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) {
out << ' ' << vocab.Lookup(*i);
}
}
} // namespace
// Thread body: stream n-grams from the chains and emit a complete ARPA file:
// \data\ section with per-order counts, one \N-grams: section per order,
// then \end\.  Orders below the maximum print a backoff column; the highest
// order does not.
void PrintARPA::Run(const util::stream::ChainPositions &positions) {
  VocabReconstitute vocab(vocab_fd_);
  util::FileStream out(out_fd_);
  out << "\\data\\\n";
  for (size_t i = 0; i < positions.size(); ++i) {
    out << "ngram " << (i+1) << '=' << counts_[i] << '\n';
  }
  out << '\n';

  // Lower orders carry ProbBackoff payloads: prob, words, then backoff.
  for (unsigned order = 1; order < positions.size(); ++order) {
    out << "\\" << order << "-grams:" << '\n';
    for (ProxyStream<NGram<ProbBackoff> > stream(positions[order - 1], NGram<ProbBackoff>(NULL, order)); stream; ++stream) {
      PrintLead(vocab, stream, out);
      out << '\t' << stream->Value().backoff << '\n';
    }
    out << '\n';
  }

  // Highest order carries Prob only: no backoff column.
  out << "\\" << positions.size() << "-grams:" << '\n';
  for (ProxyStream<NGram<Prob> > stream(positions.back(), NGram<Prob>(NULL, positions.size())); stream; ++stream) {
    PrintLead(vocab, stream, out);
    out << '\n';
  }
  out << '\n';
  out << "\\end\\\n";
}
} // namespace lm
#ifndef LM_COMMON_PRINT_H
#define LM_COMMON_PRINT_H
#include "../word_index.hh"
#include "../../util/mmap.hh"
#include "../../util/string_piece.hh"
#include <cassert>
#include <vector>
namespace util { namespace stream { class ChainPositions; }}
// Warning: PrintARPA routines read all unigrams before all bigrams before all
// trigrams etc. So if other parts of the chain move jointly, you'll have to
// buffer.
namespace lm {
// Maps word ids back to surface strings using a memory-mapped,
// null-delimited vocabulary file (see constructor in the .cc file).
class VocabReconstitute {
  public:
    // fd must be alive for life of this object; does not take ownership.
    explicit VocabReconstitute(int fd);

    // Null-terminated surface string for a word id.
    const char *Lookup(WordIndex index) const {
      assert(index < map_.size() - 1);
      return map_[index];
    }

    // Like Lookup but returns a StringPiece whose length excludes the
    // trailing NUL, using the sentinel entry after the last word.
    StringPiece LookupPiece(WordIndex index) const {
      // Bounds check consistent with Lookup; this also guards the
      // map_[index + 1] read below.
      assert(index < map_.size() - 1);
      return StringPiece(map_[index], map_[index + 1] - 1 - map_[index]);
    }

    // Number of words in the vocabulary.
    std::size_t Size() const {
      // There's an extra entry to support StringPiece lengths.
      return map_.size() - 1;
    }

  private:
    util::scoped_memory memory_;       // Owns the mapped file contents.
    std::vector<const char*> map_;     // Word id -> string start, plus sentinel.
};
// Thread body object that renders streamed n-grams as an ARPA file.
class PrintARPA {
  public:
    // Does not take ownership of vocab_fd or out_fd.  counts holds the
    // number of n-grams at each order for the \data\ section.
    explicit PrintARPA(int vocab_fd, int out_fd, const std::vector<uint64_t> &counts)
      : vocab_fd_(vocab_fd), out_fd_(out_fd), counts_(counts) {}

    void Run(const util::stream::ChainPositions &positions);

  private:
    int vocab_fd_;   // Null-delimited vocabulary file (not owned).
    int out_fd_;     // Destination for ARPA output (not owned).
    std::vector<uint64_t> counts_;
};
} // namespace lm
#endif // LM_COMMON_PRINT_H
#include "renumber.hh"
#include "ngram.hh"
#include "../../util/stream/stream.hh"
namespace lm {
// Thread body: rewrite every word id of every n-gram record flowing through
// the chain in place, via the new_numbers_ lookup table.
void Renumber::Run(const util::stream::ChainPosition &position) {
  for (util::stream::Stream stream(position); stream; ++stream) {
    NGramHeader gram(stream.Get(), order_);
    WordIndex *word = gram.begin();
    WordIndex *const stop = gram.end();
    for (; word != stop; ++word) {
      *word = new_numbers_[*word];
    }
  }
}
} // namespace lm
/* Map vocab ids. This is useful to merge independently collected counts or
* change the vocab ids to the order used by the trie.
*/
#ifndef LM_COMMON_RENUMBER_H
#define LM_COMMON_RENUMBER_H
#include "../word_index.hh"
#include <cstddef>
namespace util { namespace stream { class ChainPosition; }}
namespace lm {
// Thread body object that remaps word ids in a stream of n-gram records.
class Renumber {
  public:
    // Assumes the array is large enough to map all words and stays alive while
    // the thread is active.
    Renumber(const WordIndex *new_numbers, std::size_t order)
      : new_numbers_(new_numbers), order_(order) {}

    // Rewrites word ids of every record in the chain in place.
    void Run(const util::stream::ChainPosition &position);

  private:
    const WordIndex *new_numbers_;  // Old id -> new id table (not owned).
    std::size_t order_;             // Order of records in the chain.
};
} // namespace lm
#endif // LM_COMMON_RENUMBER_H
#include <boost/program_options.hpp>
#include "../../util/usage.hh"
namespace lm {
namespace {
// Functor handed to boost::program_options as a notifier: parses a size
// string such as "1T" or "10k" and stores the byte count in the referenced
// variable.
class SizeNotify {
  public:
    explicit SizeNotify(std::size_t &out) : target_(out) {}

    void operator()(const std::string &from) {
      target_ = util::ParseSize(from);
    }

  private:
    std::size_t &target_;  // Destination for the parsed size (not owned).
};
}
// Build a program_options value that accepts human-readable sizes ("1T",
// "10k"), writing the parsed byte count into `to` when the option is set.
boost::program_options::typed_value<std::string> *SizeOption(std::size_t &to, const char *default_value) {
  namespace po = boost::program_options;
  // notifier() and default_value() both return the same object, so the
  // chained form and this statement form are equivalent.
  po::typed_value<std::string> *option = po::value<std::string>();
  option->notifier(SizeNotify(to));
  option->default_value(default_value);
  return option;
}
} // namespace lm
#include <boost/program_options.hpp>
#include <cstddef>
#include <string>
namespace lm {
// Create a boost program option for data sizes. This parses sizes like 1T and 10k.
boost::program_options::typed_value<std::string> *SizeOption(std::size_t &to, const char *default_value);
} // namespace lm
#ifndef LM_COMMON_SPECIAL_H
#define LM_COMMON_SPECIAL_H
#include "../word_index.hh"
namespace lm {
// Holds the ids of the special tokens: unknown word, begin of sentence, and
// end of sentence.
class SpecialVocab {
  public:
    SpecialVocab(WordIndex bos, WordIndex eos) : bos_(bos), eos_(eos) {}

    // True for <unk>, the begin-of-sentence id, or the end-of-sentence id.
    bool IsSpecial(WordIndex word) const {
      if (word == kUNK) return true;
      return (word == bos_) || (word == eos_);
    }

    WordIndex UNK() const { return kUNK; }
    WordIndex BOS() const { return bos_; }
    WordIndex EOS() const { return eos_; }

  private:
    WordIndex bos_;
    WordIndex eos_;
};
} // namespace lm
#endif // LM_COMMON_SPECIAL_H
KenLM intermediate binary file
Counts 5 7 7
Payload pb
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment