Commit 688b6eac authored by SWHL's avatar SWHL
Browse files

Update files

parents
#include "arpa_io.hh"
#include "../../util/file_piece.hh"
#include "../../util/string_stream.hh"
#include <iostream>
#include <ostream>
#include <string>
#include <vector>
#include <cctype>
#include <cerrno>
#include <cstring>
namespace lm {
// Exception raised for malformed ARPA input.  The message is streamed into
// the util::Exception base, so callers see it via what().
ARPAInputException::ARPAInputException(const StringPiece &message) throw() {
  *this << message;
}

// Variant that also records the offending line for diagnostics.
ARPAInputException::ARPAInputException(const StringPiece &message, const StringPiece &line) throw() {
  *this << message << " in line " << line;
}

ARPAInputException::~ARPAInputException() throw() {}
// Seeking is the responsibility of the caller.
// Render the ARPA "\data\" header listing the n-gram count for each order.
// Seeking is the responsibility of the caller.
template <class Stream> void WriteCounts(Stream &out, const std::vector<uint64_t> &number) {
  out << "\n\\data\\\n";
  std::size_t order = 0;
  while (order < number.size()) {
    out << "ngram " << order + 1 << "=" << number[order] << '\n';
    ++order;
  }
  out << '\n';
}
// Number of bytes the counts header will occupy, computed by rendering it to
// an in-memory stream.  Used to reserve space at the start of the output file.
size_t SizeNeededForCounts(const std::vector<uint64_t> &number) {
  util::StringStream stream;
  WriteCounts(stream, number);
  return stream.str().size();
}
// True iff every byte of the line is whitespace (an empty line qualifies).
bool IsEntirelyWhiteSpace(const StringPiece &line) {
  for (size_t i = 0; i < static_cast<size_t>(line.size()); ++i) {
    // Cast to unsigned char: passing a negative char value to isspace is
    // undefined behavior per the C standard.
    if (!isspace(static_cast<unsigned char>(line.data()[i]))) return false;
  }
  return true;
}
// Opens (or creates/truncates) the named output file and wraps the file
// descriptor in a buffered stream of buffer_size bytes.
ARPAOutput::ARPAOutput(const char *name, size_t buffer_size)
  : file_backing_(util::CreateOrThrow(name)), file_(file_backing_.get(), buffer_size) {}
// Pad the start of the file with `reserve` newlines so Finish() can seek back
// and overwrite them with the counts header once the counts are known.
void ARPAOutput::ReserveForCounts(std::streampos reserve) {
  for (std::streampos i = 0; i < reserve; i += std::streampos(1)) {
    file_ << '\n';
  }
}
// Emit the "\<length>-grams:" section header and reset the per-section
// n-gram counter.
void ARPAOutput::BeginLength(unsigned int length) {
  file_ << '\\' << length << "-grams:" << '\n';
  fast_counter_ = 0;
}
// Terminate the section with a blank line and record how many n-grams were
// written, so Finish() can emit accurate counts.
void ARPAOutput::EndLength(unsigned int length) {
  file_ << '\n';
  if (length > counts_.size()) {
    counts_.resize(length);
  }
  counts_[length - 1] = fast_counter_;
}
// Write the ARPA terminator, then seek to the start and overwrite the
// reserved newlines (see ReserveForCounts) with the real counts header.
void ARPAOutput::Finish() {
  file_ << "\\end\\\n";
  file_.seekp(0);
  WriteCounts(file_, counts_);
  file_.flush();
}
} // namespace lm
#ifndef LM_FILTER_ARPA_IO_H
#define LM_FILTER_ARPA_IO_H
/* Input and output for ARPA format language model files.
*/
#include "../read_arpa.hh"
#include "../../util/exception.hh"
#include "../../util/file_stream.hh"
#include "../../util/string_piece.hh"
#include "../../util/tokenize_piece.hh"
#include <boost/noncopyable.hpp>
#include <boost/scoped_array.hpp>
#include <fstream>
#include <string>
#include <vector>
#include <cstring>
#include <stdint.h>
namespace util { class FilePiece; }
namespace lm {
// Thrown when ARPA input is malformed; optionally carries the offending line.
class ARPAInputException : public util::Exception {
  public:
    explicit ARPAInputException(const StringPiece &message) throw();
    explicit ARPAInputException(const StringPiece &message, const StringPiece &line) throw();
    virtual ~ARPAInputException() throw();
};
// Handling for the counts of n-grams at the beginning of ARPA files.
size_t SizeNeededForCounts(const std::vector<uint64_t> &number);
/* Writes an ARPA file. This has to be seekable so the counts can be written
* at the end. Hence, I just have it own a std::fstream instead of accepting
* a separately held std::ostream. TODO: use the fast one from estimation.
*/
/* Writes an ARPA file.  This has to be seekable so the counts can be written
 * at the end.  Hence, I just have it own a std::fstream instead of accepting
 * a separately held std::ostream.  TODO: use the fast one from estimation.
 */
class ARPAOutput : boost::noncopyable {
  public:
    explicit ARPAOutput(const char *name, size_t buffer_size = 65536);

    // Pad the file head so counts can be written there by Finish().
    void ReserveForCounts(std::streampos reserve);

    // Start a "\n-grams:" section; resets the per-section counter.
    void BeginLength(unsigned int length);

    // Write one n-gram line verbatim and count it.
    void AddNGram(const StringPiece &line) {
      file_ << line << '\n';
      ++fast_counter_;
    }

    // The parsed ngram / token iterators are ignored; the raw line suffices.
    void AddNGram(const StringPiece &ngram, const StringPiece &line) {
      AddNGram(line);
    }

    template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
      AddNGram(line);
    }

    // Close a section, recording its count for Finish().
    void EndLength(unsigned int length);

    // Write "\end\" and back-fill the counts header.
    void Finish();

  private:
    util::scoped_fd file_backing_;
    util::FileStream file_;
    uint64_t fast_counter_;  // n-grams written in the current section
    std::vector<uint64_t> counts_;  // per-order totals, filled by EndLength
};
// Read `number` n-grams of order `length` from `in`, forwarding each line to
// `out` bracketed by BeginLength/EndLength.  Each line must contain a tab
// (probability<TAB>n-gram text); the second field is passed as the ngram.
template <class Output> void ReadNGrams(util::FilePiece &in, unsigned int length, uint64_t number, Output &out) {
  ReadNGramHeader(in, length);
  out.BeginLength(length);
  for (uint64_t i = 0; i < number; ++i) {
    StringPiece line = in.ReadLine();
    util::TokenIter<util::SingleCharacter> tabber(line, '\t');
    if (!tabber) throw ARPAInputException("blank line", line);
    if (!++tabber) throw ARPAInputException("no tab", line);
    // *tabber is the n-gram text (field after the first tab).
    out.AddNGram(*tabber, line);
  }
  out.EndLength(length);
}
// Drive a full ARPA read: counts header, every n-gram section in order, then
// the "\end\" marker.  Space for the output counts header is reserved first.
template <class Output> void ReadARPA(util::FilePiece &in_lm, Output &out) {
  std::vector<uint64_t> number;
  ReadARPACounts(in_lm, number);
  out.ReserveForCounts(SizeNeededForCounts(number));
  for (unsigned int i = 0; i < number.size(); ++i) {
    ReadNGrams(in_lm, i + 1, number[i], out);
  }
  ReadEnd(in_lm);
  out.Finish();
}
} // namespace lm
#endif // LM_FILTER_ARPA_IO_H
#ifndef LM_FILTER_COUNT_IO_H
#define LM_FILTER_COUNT_IO_H
#include <fstream>
#include <iostream>
#include <string>
#include "../../util/file_stream.hh"
#include "../../util/file.hh"
#include "../../util/file_piece.hh"
namespace lm {
// Writes an n-gram count file: one line per n-gram, no header, so the output
// needs no seeking.
class CountOutput : boost::noncopyable {
  public:
    explicit CountOutput(const char *name) : file_(util::CreateOrThrow(name)) {}

    void AddNGram(const StringPiece &line) {
      file_ << line << '\n';
    }

    // Parsed-token overloads ignore the pieces; the raw line is enough.
    template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
      AddNGram(line);
    }

    void AddNGram(const StringPiece &ngram, const StringPiece &line) {
      AddNGram(line);
    }

  private:
    util::FileStream file_;
};
// Reads a batch of roughly initial_read bytes of count lines, extended to the
// next newline so no line is split across batches.
class CountBatch {
  public:
    explicit CountBatch(std::streamsize initial_read)
      : initial_read_(initial_read) {
      buffer_.reserve(initial_read);
    }

    // Fill the buffer with up to initial_read_ bytes, then keep reading single
    // characters until the batch ends on a newline (or EOF).
    void Read(std::istream &in) {
      buffer_.resize(initial_read_);
      // Guard: dereferencing begin() of an empty vector is undefined behavior.
      if (!buffer_.empty()) {
        in.read(&*buffer_.begin(), initial_read_);
        buffer_.resize(in.gcount());
      }
      char got;
      while (in.get(got) && got != '\n')
        buffer_.push_back(got);
    }

    // Split the buffered text into lines and feed each through `out`.  Lines
    // without a tab, or with a tab but no words, are dropped with a warning.
    template <class Output> void Send(Output &out) {
      if (buffer_.empty()) return;  // avoid &*begin() on an empty vector
      for (util::TokenIter<util::SingleCharacter> line(StringPiece(&*buffer_.begin(), buffer_.size()), '\n'); line; ++line) {
        util::TokenIter<util::SingleCharacter> tabber(*line, '\t');
        if (!tabber) {
          std::cerr << "Warning: empty n-gram count line being removed\n";
          continue;
        }
        util::TokenIter<util::SingleCharacter, true> words(*tabber, ' ');
        if (!words) {
          std::cerr << "Line has a tab but no words.\n";
          continue;
        }
        out.AddNGram(words, util::TokenIter<util::SingleCharacter, true>::end(), *line);
      }
    }

  private:
    std::streamsize initial_read_;

    // This could have been a std::string but that's less happy with raw writes.
    std::vector<char> buffer_;
};
// Stream an entire count file through `out`, one line at a time.  Lines
// lacking a tab are dropped with a warning.  EndOfFileException is the normal
// termination path.
template <class Output> void ReadCount(util::FilePiece &in_file, Output &out) {
  try {
    while (true) {
      StringPiece line = in_file.ReadLine();
      util::TokenIter<util::SingleCharacter> tabber(line, '\t');
      if (!tabber) {
        std::cerr << "Warning: empty n-gram count line being removed\n";
        continue;
      }
      out.AddNGram(*tabber, line);
    }
  } catch (const util::EndOfFileException &) {}
}
} // namespace lm
#endif // LM_FILTER_COUNT_IO_H
#include "arpa_io.hh"
#include "format.hh"
#include "phrase.hh"
#ifndef NTHREAD
#include "thread.hh"
#endif
#include "vocab.hh"
#include "wrapper.hh"
#include "../../util/exception.hh"
#include "../../util/file_piece.hh"
#include <boost/ptr_container/ptr_vector.hpp>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
namespace lm {
namespace {
// Print command-line usage to stderr.  (Fixes the "conccurrency" typo in the
// threads help text.)
void DisplayHelp(const char *name) {
  std::cerr
    << "Usage: " << name << " mode [context] [phrase] [raw|arpa] [threads:m] [batch_size:m] (vocab|model):input_file output_file\n\n"
    "copy mode just copies, but makes the format nicer for e.g. irstlm's broken\n"
    " parser.\n"
    "single mode treats the entire input as a single sentence.\n"
    "multiple mode filters to multiple sentences in parallel. Each sentence is on\n"
    " a separate line. A separate file is created for each sentence by appending\n"
    " the 0-indexed line number to the output file name.\n"
    "union mode produces one filtered model that is the union of models created by\n"
    " multiple mode.\n\n"
    "context means only the context (all but last word) has to pass the filter, but\n"
    " the entire n-gram is output.\n\n"
    "phrase means that the vocabulary is actually tab-delimited phrases and that the\n"
    " phrases can generate the n-gram when assembled in arbitrary order and\n"
    " clipped. Currently works with multiple or union mode.\n\n"
    "The file format is set by [raw|arpa] with default arpa:\n"
    "raw means space-separated tokens, optionally followed by a tab and arbitrary\n"
    " text. This is useful for ngram count files.\n"
    "arpa means the ARPA file format for n-gram language models.\n\n"
#ifndef NTHREAD
    "threads:m sets m threads (default: concurrency detected by boost)\n"
    "batch_size:m sets the batch size for threading. Expect memory usage from this\n"
    " of 2*threads*batch_size n-grams.\n\n"
#else
    "This binary was compiled with -DNTHREAD, disabling threading. If you wanted\n"
    " threading, compile without this flag against Boost >=1.42.0.\n\n"
#endif
    "There are two inputs: vocabulary and model. Either may be given as a file\n"
    " while the other is on stdin. Specify the type given as a file using\n"
    " vocab: or model: before the file name. \n\n"
    "For ARPA format, the output must be seekable. For raw format, it can be a\n"
    " stream i.e. /dev/stdout\n";
}
typedef enum {MODE_COPY, MODE_SINGLE, MODE_MULTIPLE, MODE_UNION, MODE_UNSET} FilterMode;
typedef enum {FORMAT_ARPA, FORMAT_COUNT} Format;
struct Config {
Config() :
#ifndef NTHREAD
batch_size(25000),
threads(boost::thread::hardware_concurrency()),
#endif
phrase(false),
context(false),
format(FORMAT_ARPA)
{
#ifndef NTHREAD
if (!threads) threads = 1;
#endif
}
#ifndef NTHREAD
size_t batch_size;
size_t threads;
#endif
bool phrase;
bool context;
FilterMode mode;
Format format;
};
// Run Format::RunFilter either directly (single-threaded) or through the
// threaded Controller.  With NTHREAD defined only the direct path is compiled.
template <class Format, class Filter, class OutputBuffer, class Output> void RunThreadedFilter(const Config &config, util::FilePiece &in_lm, Filter &filter, Output &output) {
#ifndef NTHREAD
  if (config.threads == 1) {
#endif
    Format::RunFilter(in_lm, filter, output);
#ifndef NTHREAD
  } else {
    typedef Controller<Filter, OutputBuffer, Output> Threaded;
    // Keep two batches in flight per thread.
    Threaded threading(config.batch_size, config.threads * 2, config.threads, filter, output);
    Format::RunFilter(in_lm, threading, output);
  }
#endif
}
// Optionally wrap the filter in ContextFilter (only the n-gram's context must
// pass), then hand off to the threaded/direct runner.
template <class Format, class Filter, class OutputBuffer, class Output> void RunContextFilter(const Config &config, util::FilePiece &in_lm, Filter filter, Output &output) {
  if (config.context) {
    ContextFilter<Filter> context_filter(filter);
    RunThreadedFilter<Format, ContextFilter<Filter>, OutputBuffer, Output>(config, in_lm, context_filter, output);
  } else {
    RunThreadedFilter<Format, Filter, OutputBuffer, Output>(config, in_lm, filter, output);
  }
}
// Adapt a binary (pass/fail) vocabulary test into a full filter and run it.
template <class Format, class Binary> void DispatchBinaryFilter(const Config &config, util::FilePiece &in_lm, const Binary &binary, typename Format::Output &out) {
  typedef BinaryFilter<Binary> Filter;
  RunContextFilter<Format, Filter, BinaryOutputBuffer, typename Format::Output>(config, in_lm, Filter(binary), out);
}
// Select the filter implementation from the configured mode (and phrase flag)
// and run it.  MODE_MULTIPLE writes one output file per input sentence; the
// other modes share a single Format::Output.
template <class Format> void DispatchFilterModes(const Config &config, std::istream &in_vocab, util::FilePiece &in_lm, const char *out_name) {
  if (config.mode == MODE_MULTIPLE) {
    if (config.phrase) {
      typedef phrase::Multiple Filter;
      phrase::Substrings substrings;
      // ReadMultiple returns the sentence count, which sizes the output set.
      typename Format::Multiple out(out_name, phrase::ReadMultiple(in_vocab, substrings));
      RunContextFilter<Format, Filter, MultipleOutputBuffer, typename Format::Multiple>(config, in_lm, Filter(substrings), out);
    } else {
      typedef vocab::Multiple Filter;
      boost::unordered_map<std::string, std::vector<unsigned int> > words;
      typename Format::Multiple out(out_name, vocab::ReadMultiple(in_vocab, words));
      RunContextFilter<Format, Filter, MultipleOutputBuffer, typename Format::Multiple>(config, in_lm, Filter(words), out);
    }
    return;
  }
  typename Format::Output out(out_name);
  if (config.mode == MODE_COPY) {
    // No filtering at all; just normalize the format.
    Format::Copy(in_lm, out);
    return;
  }
  if (config.mode == MODE_SINGLE) {
    vocab::Single::Words words;
    vocab::ReadSingle(in_vocab, words);
    DispatchBinaryFilter<Format, vocab::Single>(config, in_lm, vocab::Single(words), out);
    return;
  }
  if (config.mode == MODE_UNION) {
    if (config.phrase) {
      phrase::Substrings substrings;
      phrase::ReadMultiple(in_vocab, substrings);
      DispatchBinaryFilter<Format, phrase::Union>(config, in_lm, phrase::Union(substrings), out);
    } else {
      vocab::Union::Words words;
      vocab::ReadMultiple(in_vocab, words);
      DispatchBinaryFilter<Format, vocab::Union>(config, in_lm, vocab::Union(words), out);
    }
    return;
  }
}
} // namespace
} // namespace lm
// Entry point: parse flags, pick vocab/model inputs, and dispatch by format.
int main(int argc, char *argv[]) {
  try {
    if (argc < 4) {
      lm::DisplayHelp(argv[0]);
      return 1;
    }

    // I used to have boost::program_options, but some users didn't want to compile boost.
    lm::Config config;
    config.mode = lm::MODE_UNSET;
    // Everything except the final two arguments (input, output) is a flag.
    for (int i = 1; i < argc - 2; ++i) {
      const char *str = argv[i];
      if (!std::strcmp(str, "copy")) {
        config.mode = lm::MODE_COPY;
      } else if (!std::strcmp(str, "single")) {
        config.mode = lm::MODE_SINGLE;
      } else if (!std::strcmp(str, "multiple")) {
        config.mode = lm::MODE_MULTIPLE;
      } else if (!std::strcmp(str, "union")) {
        config.mode = lm::MODE_UNION;
      } else if (!std::strcmp(str, "phrase")) {
        config.phrase = true;
      } else if (!std::strcmp(str, "context")) {
        config.context = true;
      } else if (!std::strcmp(str, "arpa")) {
        config.format = lm::FORMAT_ARPA;
      } else if (!std::strcmp(str, "raw")) {
        config.format = lm::FORMAT_COUNT;
#ifndef NTHREAD
      } else if (!std::strncmp(str, "threads:", 8)) {
        config.threads = boost::lexical_cast<size_t>(str + 8);
        if (!config.threads) {
          std::cerr << "Specify at least one thread." << std::endl;
          return 1;
        }
      } else if (!std::strncmp(str, "batch_size:", 11)) {
        config.batch_size = boost::lexical_cast<size_t>(str + 11);
        if (config.batch_size < 5000) {
          // Small batches get a warning; only zero is fatal.
          std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl;
          if (!config.batch_size) return 1;
        }
#endif
      } else {
        lm::DisplayHelp(argv[0]);
        return 1;
      }
    }

    if (config.mode == lm::MODE_UNSET) {
      lm::DisplayHelp(argv[0]);
      return 1;
    }

    if (config.phrase && config.mode != lm::MODE_UNION && config.mode != lm::MODE_MULTIPLE) {
      std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl;
      return 1;
    }

    // The penultimate argument names a file: "vocab:"/"model:" says which
    // stream it carries; the other stream comes from stdin.
    bool cmd_is_model = true;
    const char *cmd_input = argv[argc - 2];
    if (!strncmp(cmd_input, "vocab:", 6)) {
      cmd_is_model = false;
      cmd_input += 6;
    } else if (!strncmp(cmd_input, "model:", 6)) {
      cmd_input += 6;
    } else if (strchr(cmd_input, ':')) {
      std::cerr << "Specify vocab: or model: before the input file name, not " << cmd_input << std::endl;
      return 1;
    } else {
      std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl;
    }
    std::ifstream cmd_file;
    std::istream *vocab;
    if (cmd_is_model) {
      vocab = &std::cin;
    } else {
      cmd_file.open(cmd_input, std::ios::in);
      UTIL_THROW_IF(!cmd_file, util::ErrnoException, "Failed to open " << cmd_input);
      vocab = &cmd_file;
    }
    // When the named file holds the vocabulary, the model is fd 0 (stdin).
    util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr);

    if (config.format == lm::FORMAT_ARPA) {
      lm::DispatchFilterModes<lm::ARPAFormat>(config, *vocab, model, argv[argc - 1]);
    } else if (config.format == lm::FORMAT_COUNT) {
      lm::DispatchFilterModes<lm::CountFormat>(config, *vocab, model, argv[argc - 1]);
    }
    return 0;
  } catch (const std::exception &e) {
    std::cerr << e.what() << std::endl;
    return 1;
  }
}
#ifndef LM_FILTER_FORMAT_H
#define LM_FILTER_FORMAT_H
#include "arpa_io.hh"
#include "count_io.hh"
#include <boost/lexical_cast.hpp>
#include <boost/ptr_container/ptr_vector.hpp>
#include <iosfwd>
namespace lm {
// Owns one Single output per sentence, named prefix0, prefix1, ...; supports
// broadcasting a line to all outputs or writing to exactly one.
template <class Single> class MultipleOutput {
  private:
    typedef boost::ptr_vector<Single> Singles;
    typedef typename Singles::iterator SinglesIterator;

  public:
    MultipleOutput(const char *prefix, size_t number) {
      files_.reserve(number);
      std::string tmp;
      for (unsigned int i = 0; i < number; ++i) {
        tmp = prefix;
        tmp += boost::lexical_cast<std::string>(i);
        files_.push_back(new Single(tmp.c_str()));
      }
    }

    // Broadcast to every output.
    void AddNGram(const StringPiece &line) {
      for (SinglesIterator i = files_.begin(); i != files_.end(); ++i)
        i->AddNGram(line);
    }

    template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
      for (SinglesIterator i = files_.begin(); i != files_.end(); ++i)
        i->AddNGram(begin, end, line);
    }

    // Write only to the output for sentence `offset`.
    void SingleAddNGram(size_t offset, const StringPiece &line) {
      files_[offset].AddNGram(line);
    }

    template <class Iterator> void SingleAddNGram(size_t offset, const Iterator &begin, const Iterator &end, const StringPiece &line) {
      files_[offset].AddNGram(begin, end, line);
    }

  protected:
    Singles files_;
};
// Fan-out wrapper: forwards each ARPA lifecycle event to every output file.
class MultipleARPAOutput : public MultipleOutput<ARPAOutput> {
  public:
    MultipleARPAOutput(const char *prefix, size_t number) : MultipleOutput<ARPAOutput>(prefix, number) {}

    void ReserveForCounts(std::streampos reserve) {
      typedef boost::ptr_vector<ARPAOutput>::iterator Iter;
      for (Iter file = files_.begin(); file != files_.end(); ++file)
        file->ReserveForCounts(reserve);
    }

    void BeginLength(unsigned int length) {
      typedef boost::ptr_vector<ARPAOutput>::iterator Iter;
      for (Iter file = files_.begin(); file != files_.end(); ++file)
        file->BeginLength(length);
    }

    void EndLength(unsigned int length) {
      typedef boost::ptr_vector<ARPAOutput>::iterator Iter;
      for (Iter file = files_.begin(); file != files_.end(); ++file)
        file->EndLength(length);
    }

    void Finish() {
      typedef boost::ptr_vector<ARPAOutput>::iterator Iter;
      for (Iter file = files_.begin(); file != files_.end(); ++file)
        file->Finish();
    }
};
// Glue object handed to the readers: routes each parsed line through the
// filter, which decides what (if anything) reaches the output.
template <class Filter, class Output> class DispatchInput {
  public:
    DispatchInput(Filter &filter, Output &output) : filter_(filter), output_(output) {}

/*    template <class Iterator> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) {
      filter_.AddNGram(begin, end, line, output_);
    }*/

    void AddNGram(const StringPiece &ngram, const StringPiece &line) {
      filter_.AddNGram(ngram, line, output_);
    }

  protected:
    Filter &filter_;
    Output &output_;
};
// ARPA-aware dispatcher: forwards section/lifecycle events straight to the
// output, flushing the filter at each section boundary.
template <class Filter, class Output> class DispatchARPAInput : public DispatchInput<Filter, Output> {
  private:
    typedef DispatchInput<Filter, Output> B;

  public:
    DispatchARPAInput(Filter &filter, Output &output) : B(filter, output) {}

    void ReserveForCounts(std::streampos reserve) { B::output_.ReserveForCounts(reserve); }
    void BeginLength(unsigned int length) { B::output_.BeginLength(length); }

    void EndLength(unsigned int length) {
      // Drain anything the filter buffered before the section is closed.
      B::filter_.Flush();
      B::output_.EndLength(length);
    }
    void Finish() { B::output_.Finish(); }
};
// Format traits for ARPA language model files: output types plus the copy and
// filtered-read drivers.
struct ARPAFormat {
  typedef ARPAOutput Output;
  typedef MultipleARPAOutput Multiple;

  static void Copy(util::FilePiece &in, Output &out) {
    ReadARPA(in, out);
  }

  template <class Filter, class Out> static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) {
    DispatchARPAInput<Filter, Out> dispatcher(filter, output);
    ReadARPA(in, dispatcher);
  }
};
// Format traits for raw n-gram count files (parallel to ARPAFormat).
struct CountFormat {
  typedef CountOutput Output;
  typedef MultipleOutput<Output> Multiple;

  static void Copy(util::FilePiece &in, Output &out) {
    ReadCount(in, out);
  }

  template <class Filter, class Out> static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) {
    DispatchInput<Filter, Out> dispatcher(filter, output);
    ReadCount(in, dispatcher);
  }
};
/* For multithreading, the buffer classes hold batches of filter inputs and
* outputs in memory. The strings get reused a lot, so keep them around
* instead of clearing each time.
*/
/* For multithreading, buffers one batch of filter inputs.  Lines are copied
 * in because the source buffer may be recycled before the batch is processed;
 * Line objects are reused across batches (Clear only resets the count).
 */
class InputBuffer {
  public:
    InputBuffer() : actual_(0) {}

    void Reserve(size_t size) { lines_.reserve(size); }

    // Copy the line in and re-point ngram at the same offset within the copy.
    template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
      if (lines_.size() == actual_) lines_.resize(lines_.size() + 1);
      // TODO avoid this copy.
      std::string &copied = lines_[actual_].line;
      copied.assign(line.data(), line.size());
      // ngram points into line; preserve its offset inside the copied string.
      lines_[actual_].ngram.set(copied.data() + (ngram.data() - line.data()), ngram.size());
      ++actual_;
    }

    // Run every buffered line through the filter.
    template <class Filter, class Output> void CallFilter(Filter &filter, Output &output) const {
      for (std::vector<Line>::const_iterator i = lines_.begin(); i != lines_.begin() + actual_; ++i) {
        filter.AddNGram(i->ngram, i->line, output);
      }
    }

    void Clear() { actual_ = 0; }
    // const added: these are read-only queries.
    bool Empty() const { return actual_ == 0; }
    size_t Size() const { return actual_; }

  private:
    struct Line {
      std::string line;
      StringPiece ngram;  // view into `line`
    };

    size_t actual_;  // number of valid entries in lines_
    std::vector<Line> lines_;
};
class BinaryOutputBuffer {
public:
BinaryOutputBuffer() {}
void Reserve(size_t size) {
lines_.reserve(size);
}
void AddNGram(const StringPiece &line) {
lines_.push_back(line);
}
template <class Output> void Flush(Output &output) {
for (std::vector<StringPiece>::const_iterator i = lines_.begin(); i != lines_.end(); ++i) {
output.AddNGram(*i);
}
lines_.clear();
}
private:
std::vector<StringPiece> lines_;
};
// Buffers output for MultipleOutput targets: a line either goes to all
// systems (empty `systems`) or to the listed subset.  Consecutive
// SingleAddNGram calls for the same line are merged.
class MultipleOutputBuffer {
  public:
    MultipleOutputBuffer() : last_(NULL) {}

    void Reserve(size_t size) {
      annotated_.reserve(size);
    }

    // Broadcast entry: empty systems means "send to all".
    void AddNGram(const StringPiece &line) {
      annotated_.resize(annotated_.size() + 1);
      annotated_.back().line = line;
    }

    void SingleAddNGram(size_t offset, const StringPiece &line) {
      // Identity comparison (same data pointer and length), not content
      // comparison: detects repeated calls for the exact same line object.
      if ((line.data() == last_.data()) && (line.length() == last_.length())) {
        annotated_.back().systems.push_back(offset);
      } else {
        annotated_.resize(annotated_.size() + 1);
        annotated_.back().systems.push_back(offset);
        annotated_.back().line = line;
        last_ = line;
      }
    }

    // Replay the buffered decisions against the real output.
    template <class Output> void Flush(Output &output) {
      for (std::vector<Annotated>::const_iterator i = annotated_.begin(); i != annotated_.end(); ++i) {
        if (i->systems.empty()) {
          output.AddNGram(i->line);
        } else {
          for (std::vector<size_t>::const_iterator j = i->systems.begin(); j != i->systems.end(); ++j) {
            output.SingleAddNGram(*j, i->line);
          }
        }
      }
      annotated_.clear();
    }

  private:
    struct Annotated {
      // If this is empty, send to all systems.
      // A filter should never send to all systems and send to a single one.
      std::vector<size_t> systems;
      StringPiece line;
    };

    StringPiece last_;  // most recent SingleAddNGram line, for merging
    std::vector<Annotated> annotated_;
};
} // namespace lm
#endif // LM_FILTER_FORMAT_H
#include "phrase.hh"
#include "format.hh"
#include <algorithm>
#include <functional>
#include <iostream>
#include <queue>
#include <string>
#include <vector>
#include <cctype>
namespace lm {
namespace phrase {
// Parse the vocabulary input: space-separated words form phrases, tabs (or
// vertical tabs) separate phrases, newlines separate sentences.  Each phrase
// is hashed and added to `out` under its 0-indexed sentence id.  Returns the
// number of sentences containing at least one phrase.
unsigned int ReadMultiple(std::istream &in, Substrings &out) {
  bool sentence_content = false;
  unsigned int sentence_id = 0;
  std::vector<Hash> phrase;
  std::string word;
  while (in) {
    // Gather a word.  Read as int so EOF is distinguishable from data and
    // isspace never receives a negative char value (undefined behavior).
    int raw;
    while ((raw = in.get()) != std::char_traits<char>::eof() &&
           !isspace(static_cast<unsigned char>(raw))) {
      word += static_cast<char>(raw);
    }
    // Treat EOF (or any stream failure) like a newline.
    char c = in ? static_cast<char>(raw) : '\n';
    // Add the word to the phrase.
    if (!word.empty()) {
      phrase.push_back(util::MurmurHashNative(word.data(), word.size()));
      word.clear();
    }
    if (c == ' ') continue;
    // It's more than just a space. Close out the phrase.
    if (!phrase.empty()) {
      sentence_content = true;
      out.AddPhrase(sentence_id, phrase.begin(), phrase.end());
      phrase.clear();
    }
    if (c == '\t' || c == '\v') continue;
    // It's more than a space or tab: a newline.
    if (sentence_content) {
      ++sentence_id;
      sentence_content = false;
    }
  }
  // A non-EOF failure is an error: arm exceptions so it throws immediately.
  if (!in.eof()) in.exceptions(std::istream::failbit | std::istream::badbit);
  return sentence_id + sentence_content;
}
namespace {
typedef unsigned int Sentence;
typedef std::vector<Sentence> Sentences;
} // namespace
namespace detail {
// End-of-sentence token; n-gram hashing stops when it is encountered.
const StringPiece kEndSentence("</s>");
// An edge in the phrase-match graph: a sorted list of sentence ids (the
// intersection set) plus an optional source vertex that must also match.
class Arc {
  public:
    Arc() {}

    // For arcs from one vertex to another.
    void SetPhrase(detail::Vertex &from, detail::Vertex &to, const Sentences &intersect) {
      Set(to, intersect);
      from_ = &from;
    }

    /* For arcs from before the n-gram begins to somewhere in the n-gram (right
     * aligned).  These have no from_ vertex; it implicitly matches every
     * sentence.  This also handles when the n-gram is a substring of a phrase.
     */
    void SetRight(detail::Vertex &to, const Sentences &complete) {
      Set(to, complete);
      from_ = NULL;
    }

    Sentence Current() const {
      return *current_;
    }

    bool Empty() const {
      return current_ == last_;
    }

    /* When this function returns:
     * If Empty() then there's nothing left from this intersection.
     *
     * If Current() == to then to is part of the intersection.
     *
     * Otherwise, Current() > to.  In this case, to is not part of the
     * intersection and neither is anything < Current().  To determine if
     * any value >= Current() is in the intersection, call LowerBound again
     * with the value.
     */
    void LowerBound(const Sentence to);

  private:
    void Set(detail::Vertex &to, const Sentences &sentences);

    const Sentence *current_;  // position in the sorted sentence-id array
    const Sentence *last_;     // one past the end of that array
    detail::Vertex *from_;     // prerequisite vertex, or NULL for right-aligned arcs
};
struct ArcGreater : public std::binary_function<const Arc *, const Arc *, bool> {
bool operator()(const Arc *first, const Arc *second) const {
return first->Current() > second->Current();
}
};
// A node in the phrase-match graph.  Its matchable sentence set is the union
// of its incoming arcs, maintained lazily via a min-heap keyed on each arc's
// Current() value.
class Vertex {
  public:
    Vertex() : current_(0) {}

    Sentence Current() const {
      return current_;
    }

    // Empty once every incoming arc has been exhausted.
    bool Empty() const {
      return incoming_.empty();
    }

    void LowerBound(const Sentence to);

  private:
    friend class Arc;

    void AddIncoming(Arc *arc) {
      // Exhausted arcs never enter the heap.
      if (!arc->Empty()) incoming_.push(arc);
    }

    unsigned int current_;  // latest lower bound produced
    std::priority_queue<Arc*, std::vector<Arc*>, ArcGreater> incoming_;
};
// Advance this arc to the first sentence id >= to that is also matched by the
// prerequisite vertex (if any).  See the contract on the declaration.
void Arc::LowerBound(const Sentence to) {
  current_ = std::lower_bound(current_, last_, to);
  // If *current_ > to, don't advance from_.  The intervening values of
  // from_ may be useful for another one of its outgoing arcs.
  if (!from_ || Empty() || (Current() > to)) return;
  assert(Current() == to);
  from_->LowerBound(to);
  if (from_->Empty()) {
    // The prerequisite can never match again, so neither can this arc.
    current_ = last_;
    return;
  }
  assert(from_->Current() >= to);
  if (from_->Current() > to) {
    // `to` was rejected by the prerequisite; re-bound against its new minimum.
    current_ = std::lower_bound(current_ + 1, last_, from_->Current());
  }
}
// Point this arc at the sentence-id array and register it with its target
// vertex.  Avoids dereferencing sentences.end() (undefined behavior): the
// range is derived from front() instead, with an explicit empty case.
void Arc::Set(Vertex &to, const Sentences &sentences) {
  if (sentences.empty()) {
    current_ = last_ = NULL;
  } else {
    current_ = &sentences.front();
    last_ = current_ + sentences.size();
  }
  to.AddIncoming(this);
}
// Union lower bound over all incoming arcs: set current_ to the smallest
// sentence id >= to matched by any arc, popping/reinserting arcs as they are
// advanced.
void Vertex::LowerBound(const Sentence to) {
  if (Empty()) return;
  // Union lower bound.
  while (true) {
    Arc *top = incoming_.top();
    if (top->Current() > to) {
      // Every arc is already past `to`; the heap top is the union minimum.
      current_ = top->Current();
      return;
    }
    // If top->Current() == to, we still need to verify that's an actual
    // element and not just a bound.
    incoming_.pop();
    top->LowerBound(to);
    if (!top->Empty()) {
      incoming_.push(top);
      if (top->Current() == to) {
        current_ = to;
        return;
      }
    } else if (Empty()) {
      // All arcs exhausted; the vertex can never match again.
      return;
    }
  }
}
} // namespace detail
namespace {
void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, detail::Vertex *const vertices, detail::Arc *free_arc) {
using detail::Vertex;
using detail::Arc;
assert(!hashes.empty());
const Hash *const first_word = &*hashes.begin();
const Hash *const last_word = &*hashes.end() - 1;
Hash hash = 0;
const Sentences *found;
// Phrases starting at or before the first word in the n-gram.
{
Vertex *vertex = vertices;
for (const Hash *word = first_word; ; ++word, ++vertex) {
hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word);
// Now hash is [hashes.begin(), word].
if (word == last_word) {
if (phrase.FindSubstring(hash, found))
(free_arc++)->SetRight(*vertex, *found);
break;
}
if (!phrase.FindRight(hash, found)) break;
(free_arc++)->SetRight(*vertex, *found);
}
}
// Phrases starting at the second or later word in the n-gram.
Vertex *vertex_from = vertices;
for (const Hash *word_from = first_word + 1; word_from != &*hashes.end(); ++word_from, ++vertex_from) {
hash = 0;
Vertex *vertex_to = vertex_from + 1;
for (const Hash *word_to = word_from; ; ++word_to, ++vertex_to) {
// Notice that word_to and vertex_to have the same index.
hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word_to);
// Now hash covers [word_from, word_to].
if (word_to == last_word) {
if (phrase.FindLeft(hash, found))
(free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found);
break;
}
if (!phrase.FindPhrase(hash, found)) break;
(free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found);
}
}
}
} // namespace
namespace detail {
// Here instead of header due to forward declaration.
// Defined here instead of the header because of the forward declaration.
ConditionCommon::ConditionCommon(const Substrings &substrings) : substrings_(substrings) {}

// Rest of the member variables are temporaries anyway, so only the reference
// to the shared substring table is copied.
ConditionCommon::ConditionCommon(const ConditionCommon &from) : substrings_(from.substrings_) {}

ConditionCommon::~ConditionCommon() {}
// Rebuild the match graph for the n-gram currently in hashes_, reusing the
// vertex/arc storage.  Returns the final vertex, which represents "the whole
// n-gram is covered".
detail::Vertex &ConditionCommon::MakeGraph() {
  assert(!hashes_.empty());
  vertices_.clear();
  vertices_.resize(hashes_.size());
  arcs_.clear();
  // One arc for every substring: n(n+1)/2 upper bound.
  arcs_.resize(((hashes_.size() + 1) * hashes_.size()) / 2);
  BuildGraph(substrings_, hashes_, &*vertices_.begin(), &*arcs_.begin());
  return vertices_[hashes_.size() - 1];
}
} // namespace detail
// True iff at least one sentence's phrases can cover the whole n-gram.
// Repeatedly proposes a candidate sentence id; a fixed point where the final
// vertex produces exactly the proposed id is a match.
bool Union::Evaluate() {
  detail::Vertex &last_vertex = MakeGraph();
  unsigned int lower = 0;
  while (true) {
    last_vertex.LowerBound(lower);
    if (last_vertex.Empty()) return false;
    if (last_vertex.Current() == lower) return true;
    // No match at `lower`; retry from the vertex's next possible id.
    lower = last_vertex.Current();
  }
}
// Like Union::Evaluate, but instead of stopping at the first match, emits the
// line to every sentence's individual output that covers the n-gram.
template <class Output> void Multiple::Evaluate(const StringPiece &line, Output &output) {
  detail::Vertex &last_vertex = MakeGraph();
  unsigned int lower = 0;
  while (true) {
    last_vertex.LowerBound(lower);
    if (last_vertex.Empty()) return;
    if (last_vertex.Current() == lower) {
      // Sentence `lower` matches; emit and continue with the next id.
      output.SingleAddNGram(lower, line);
      ++lower;
    } else {
      lower = last_vertex.Current();
    }
  }
}
template void Multiple::Evaluate<CountFormat::Multiple>(const StringPiece &line, CountFormat::Multiple &output);
template void Multiple::Evaluate<ARPAFormat::Multiple>(const StringPiece &line, ARPAFormat::Multiple &output);
template void Multiple::Evaluate<MultipleOutputBuffer>(const StringPiece &line, MultipleOutputBuffer &output);
} // namespace phrase
} // namespace lm
#ifndef LM_FILTER_PHRASE_H
#define LM_FILTER_PHRASE_H
#include "../../util/murmur_hash.hh"
#include "../../util/string_piece.hh"
#include "../../util/tokenize_piece.hh"
#include <boost/unordered_map.hpp>
#include <iosfwd>
#include <vector>
// Generates bool Find<caps>(Hash key, const std::vector<unsigned int> *&out):
// looks up `key` in table_ and, on success, points `out` at the `lower`
// sentence-id set.  (No comments inside the macro: backslash continuations
// are spliced before comments are stripped.)
#define LM_FILTER_PHRASE_METHOD(caps, lower) \
  bool Find##caps(Hash key, const std::vector<unsigned int> *&out) const {\
    Table::const_iterator i(table_.find(key));\
    if (i==table_.end()) return false; \
    out = &i->second.lower; \
    return true; \
  }
namespace lm {
namespace phrase {
typedef uint64_t Hash;
class Substrings {
private:
/* This is the value in a hash table where the key is a string. It indicates
* four sets of sentences:
* substring is sentences with a phrase containing the key as a substring.
* left is sentencess with a phrase that begins with the key (left aligned).
* right is sentences with a phrase that ends with the key (right aligned).
* phrase is sentences where the key is a phrase.
* Each set is encoded as a vector of sentence ids in increasing order.
*/
struct SentenceRelation {
std::vector<unsigned int> substring, left, right, phrase;
};
/* Most of the CPU is hash table lookups, so let's not complicate it with
* vector equality comparisons. If a collision happens, the SentenceRelation
* structure will contain the union of sentence ids over the colliding strings.
* In that case, the filter will be slightly more permissive.
* The key here is the same as boost's hash of std::vector<std::string>.
*/
typedef boost::unordered_map<Hash, SentenceRelation> Table;
public:
Substrings() {}
/* If the string isn't a substring of any phrase, return NULL. Otherwise,
* return a pointer to std::vector<unsigned int> listing sentences with
* matching phrases. This set may be empty for Left, Right, or Phrase.
* Example: const std::vector<unsigned int> *FindSubstring(Hash key)
*/
LM_FILTER_PHRASE_METHOD(Substring, substring)
LM_FILTER_PHRASE_METHOD(Left, left)
LM_FILTER_PHRASE_METHOD(Right, right)
LM_FILTER_PHRASE_METHOD(Phrase, phrase)
#pragma GCC diagnostic ignored "-Wuninitialized" // end != finish so there's always an initialization
// sentence_id must be non-decreasing. Iterators are over words in the phrase.
template <class Iterator> void AddPhrase(unsigned int sentence_id, const Iterator &begin, const Iterator &end) {
// Iterate over all substrings.
for (Iterator start = begin; start != end; ++start) {
Hash hash = 0;
SentenceRelation *relation;
for (Iterator finish = start; finish != end; ++finish) {
hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *finish);
// Now hash is of [start, finish].
relation = &table_[hash];
AppendSentence(relation->substring, sentence_id);
if (start == begin) AppendSentence(relation->left, sentence_id);
}
AppendSentence(relation->right, sentence_id);
if (start == begin) AppendSentence(relation->phrase, sentence_id);
}
}
private:
// Append sentence_id to vec unless it is already the most recent entry.
// Callers supply non-decreasing ids, so checking only the back keeps the
// vector duplicate-free without a full scan.
void AppendSentence(std::vector<unsigned int> &vec, unsigned int sentence_id) {
  if (!vec.empty() && vec.back() == sentence_id) return;
  vec.push_back(sentence_id);
}
Table table_;
};
// Read a file with one sentence per line containing tab-delimited phrases of
// space-separated words.
unsigned int ReadMultiple(std::istream &in, Substrings &out);
namespace detail {
extern const StringPiece kEndSentence;
// Hash each word in [i, end) into hashes, stopping early at the end-of-
// sentence marker.  A single leading tag token (anything of the form <...>)
// is skipped.
template <class Iterator> void MakeHashes(Iterator i, const Iterator &end, std::vector<Hash> &hashes) {
  hashes.clear();
  if (i == end) return;
  // TODO: check strict phrase boundaries after <s> and before </s>. For now, just skip tags.
  if ((i->data()[0] == '<') && (i->data()[i->size() - 1] == '>')) {
    ++i;
  }
  while (i != end && (*i != kEndSentence)) {
    hashes.push_back(util::MurmurHashNative(i->data(), i->size()));
    ++i;
  }
}
class Vertex;
class Arc;
// State shared by the Union and Multiple phrase conditions: hash scratch
// space and the match graph (vertices/arcs) that MakeGraph rebuilds for each
// n-gram.  Definitions live in the corresponding .cc file.
class ConditionCommon {
protected:
ConditionCommon(const Substrings &substrings);
ConditionCommon(const ConditionCommon &from);
~ConditionCommon();
detail::Vertex &MakeGraph();
// Temporaries in PassNGram and Evaluate to avoid reallocation.
std::vector<Hash> hashes_;
private:
std::vector<detail::Vertex> vertices_;
std::vector<detail::Arc> arcs_;
const Substrings &substrings_;
};
} // namespace detail
// Binary phrase filter: an n-gram passes if its real words can be covered by
// phrases from the substring table (or if it contains no real words at all).
class Union : public detail::ConditionCommon {
public:
  explicit Union(const Substrings &substrings) : detail::ConditionCommon(substrings) {}
  // True iff the n-gram should be kept.
  template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
    detail::MakeHashes(begin, end, hashes_);
    if (hashes_.empty()) return true;
    return Evaluate();
  }
private:
  bool Evaluate();
};
// Multiple-output phrase filter: Evaluate routes the line to the outputs of
// the sentences whose phrases cover the n-gram; an n-gram with no real words
// is sent to every output.
class Multiple : public detail::ConditionCommon {
public:
  explicit Multiple(const Substrings &substrings) : detail::ConditionCommon(substrings) {}
  template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
    detail::MakeHashes(begin, end, hashes_);
    if (!hashes_.empty()) {
      Evaluate(line, output);
    } else {
      output.AddNGram(line);
    }
  }
  // Convenience overload: tokenize the n-gram on single spaces first.
  template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
    AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
  }
  void Flush() const {}
private:
  template <class Output> void Evaluate(const StringPiece &line, Output &output);
};
} // namespace phrase
} // namespace lm
#endif // LM_FILTER_PHRASE_H
#include "../../util/file_stream.hh"
#include "../../util/file_piece.hh"
#include "../../util/murmur_hash.hh"
#include "../../util/pool.hh"
#include "../../util/string_piece.hh"
#include "../../util/string_piece_hash.hh"
#include "../../util/tokenize_piece.hh"
#include <boost/unordered_map.hpp>
#include <boost/unordered_set.hpp>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
namespace {
// A StringPiece wrapper whose pointee may be repointed even while stored as a
// hash-set key: hashing and equality depend only on the referenced bytes, so
// swapping in an identical copy does not corrupt the set (see InternString).
struct MutablePiece {
mutable StringPiece behind;
bool operator==(const MutablePiece &other) const {
return behind == other.behind;
}
};
// Boost hash extension point (found by ADL) so MutablePiece can key an
// unordered_set.
std::size_t hash_value(const MutablePiece &m) {
return hash_value(m.behind);
}
// String interning: Add copies a string into pool-backed storage once and
// returns a stable NUL-terminated pointer; adding the same bytes again
// returns the same pointer, so pointer equality substitutes for string
// equality downstream.
class InternString {
public:
const char *Add(StringPiece str) {
MutablePiece mut;
mut.behind = str;
std::pair<boost::unordered_set<MutablePiece>::iterator, bool> res(strs_.insert(mut));
if (res.second) {
// Newly inserted: copy the bytes into the pool and append a NUL.
void *mem = backing_.Allocate(str.size() + 1);
memcpy(mem, str.data(), str.size());
static_cast<char*>(mem)[str.size()] = 0;
// Repoint the stored key at the owned copy.  Safe despite being a set key:
// the bytes are identical, so hash and equality are unchanged.
res.first->behind = StringPiece(static_cast<char*>(mem), str.size());
}
return res.first->behind.data();
}
private:
util::Pool backing_;
boost::unordered_set<MutablePiece> strs_;
};
// Accumulates, for each source sentence, the set of interned target-side
// words from phrase-table entries whose source matched that sentence.
class TargetWords {
public:
// Register a new source sentence; its id is the new vocab_ index.  The
// sentence's own text seeds its vocabulary.
void Introduce(StringPiece source) {
vocab_.resize(vocab_.size() + 1);
std::vector<unsigned int> temp(1, vocab_.size() - 1);
Add(temp, source);
}
// Add every space-delimited word of target to each listed sentence's set.
void Add(const std::vector<unsigned int> &sentences, StringPiece target) {
if (sentences.empty()) return;
// Intern once, then insert pointers into each sentence's set; interned
// pointers make set membership a pointer comparison.
interns_.clear();
for (util::TokenIter<util::SingleCharacter, true> i(target, ' '); i; ++i) {
interns_.push_back(intern_.Add(*i));
}
for (std::vector<unsigned int>::const_iterator i(sentences.begin()); i != sentences.end(); ++i) {
boost::unordered_set<const char *> &vocab = vocab_[*i];
for (std::vector<const char *>::const_iterator j = interns_.begin(); j != interns_.end(); ++j) {
vocab.insert(*j);
}
}
}
// Print one vocabulary per line to stdout (fd 1), sentences in order; word
// order within a line follows hash-set iteration order.
void Print() const {
util::FileStream out(1);
for (std::vector<boost::unordered_set<const char *> >::const_iterator i = vocab_.begin(); i != vocab_.end(); ++i) {
for (boost::unordered_set<const char *>::const_iterator j = i->begin(); j != i->end(); ++j) {
out << *j << ' ';
}
out << '\n';
}
}
private:
InternString intern_;
std::vector<boost::unordered_set<const char *> > vocab_;
// Temporary in Add.
std::vector<const char *> interns_;
};
// Indexes every phrase of up to max_length words from the source text, so
// that phrase-table source sides can later be matched to the sentences
// containing them via Matches.
class Input {
public:
explicit Input(std::size_t max_length)
: max_length_(max_length), sentence_id_(0), empty_() {}
// Canonicalize the sentence (tokens joined by single spaces, each followed
// by a trailing space), register it with targets, and record the hash of
// every phrase of 1..max_length_ consecutive words.
void AddSentence(StringPiece sentence, TargetWords &targets) {
canonical_.clear();
starts_.clear();
starts_.push_back(0);
// Delimiters are NUL, space, and tab; the explicit 3-byte StringPiece is
// needed because the literal contains an embedded NUL.
for (util::TokenIter<util::AnyCharacter, true> i(sentence, StringPiece("\0 \t", 3)); i; ++i) {
canonical_.append(i->data(), i->size());
canonical_ += ' ';
starts_.push_back(canonical_.size());
}
targets.Introduce(canonical_);
for (std::size_t i = 0; i < starts_.size() - 1; ++i) {
std::size_t subtract = starts_[i];
const char *start = &canonical_[subtract];
for (std::size_t j = i + 1; j < std::min(starts_.size(), i + max_length_ + 1); ++j) {
// Hash spans word i up to word j's start, minus 1 to exclude the
// trailing space, matching the no-trailing-space form used by Matches.
map_[util::MurmurHash64A(start, &canonical_[starts_[j]] - start - 1)].push_back(sentence_id_);
}
}
++sentence_id_;
}
// Assumes single space-delimited phrase with no space at the beginning or end.
const std::vector<unsigned int> &Matches(StringPiece phrase) const {
Map::const_iterator i = map_.find(util::MurmurHash64A(phrase.data(), phrase.size()));
return i == map_.end() ? empty_ : i->second;
}
private:
const std::size_t max_length_;
// hash of phrase is the key, array of sentences is the value.
typedef boost::unordered_map<uint64_t, std::vector<unsigned int> > Map;
Map map_;
std::size_t sentence_id_;
// Temporaries in AddSentence.
std::string canonical_;
std::vector<std::size_t> starts_;
const std::vector<unsigned int> empty_;
};
} // namespace
// Usage: prog source_text < phrase_table
// Indexes the source text's phrases, then reads a "|||"-delimited phrase
// table on stdin and prints, per source sentence, the target words of
// matching entries.
int main(int argc, char *argv[]) {
if (argc != 2) {
std::cerr << "Expected source text on the command line" << std::endl;
return 1;
}
// 7 is the maximum indexed source-phrase length -- presumably matching the
// extraction limit of the phrase table; confirm against the table's config.
Input input(7);
TargetWords targets;
try {
util::FilePiece inputs(argv[1], &std::cerr);
while (true)
input.AddSentence(inputs.ReadLine(), targets);
} catch (const util::EndOfFileException &e) {}
// Phrase table arrives on stdin (fd 0).
util::FilePiece table(0, NULL, &std::cerr);
StringPiece line;
const StringPiece pipes("|||");
while (true) {
try {
line = table.ReadLine();
} catch (const util::EndOfFileException &e) { break; }
// First field is the source phrase, second the target.  A trailing space on
// the source is stripped so its hash matches Input's canonical form.
// NOTE(review): *++it assumes every line has a second "|||" field; a
// malformed line would dereference an exhausted iterator -- confirm input.
util::TokenIter<util::MultiCharacter> it(line, pipes);
StringPiece source(*it);
if (!source.empty() && source[source.size() - 1] == ' ')
source.remove_suffix(1);
targets.Add(input.Matches(source), *++it);
}
targets.Print();
}
#ifndef LM_FILTER_THREAD_H
#define LM_FILTER_THREAD_H
#include "../../util/thread_pool.hh"
#include <boost/utility/in_place_factory.hpp>
#include <deque>
#include <stack>
namespace lm {
// One unit of work in the threaded filter pipeline: an input buffer of
// n-grams plus the filtered output, tagged with a sequence number so the
// writer can reassemble batches in order.
template <class OutputBuffer> class ThreadBatch {
public:
ThreadBatch() {}
void Reserve(size_t size) {
input_.Reserve(size);
output_.Reserve(size);
}
// File reading thread.
InputBuffer &Fill(uint64_t sequence) {
sequence_ = sequence;
// Why wait until now to clear instead of after output? free in the same
// thread as allocated.
input_.Clear();
return input_;
}
// Filter worker thread.
template <class Filter> void CallFilter(Filter &filter) {
input_.CallFilter(filter, output_);
}
uint64_t Sequence() const { return sequence_; }
// File writing thread.
template <class RealOutput> void Flush(RealOutput &output) {
output_.Flush(output);
}
private:
InputBuffer input_;
OutputBuffer output_;
uint64_t sequence_;
};
// Thread-pool functor: applies the filter to a batch, then forwards the
// batch to the output queue.  The filter is stored by value, so each worker
// owns an independent copy (filters are cheap reference wrappers).
template <class Batch, class Filter> class FilterWorker {
public:
typedef Batch *Request;
FilterWorker(const Filter &filter, util::PCQueue<Request> &done) : filter_(filter), done_(done) {}
void operator()(Request request) {
request->CallFilter(filter_);
done_.Produce(request);
}
private:
Filter filter_;
util::PCQueue<Request> &done_;
};
// There should only be one OutputWorker.
// Reassembles batches into sequence order: out-of-order arrivals are parked
// in ordering_ until the next expected batch shows up, then the maximal
// in-order prefix is flushed and recycled via done_.
template <class Batch, class Output> class OutputWorker {
public:
typedef Batch *Request;
OutputWorker(Output &output, util::PCQueue<Request> &done) : output_(output), done_(done), base_sequence_(0) {}
void operator()(Request request) {
assert(request->Sequence() >= base_sequence_);
// Assemble the output in order.
uint64_t pos = request->Sequence() - base_sequence_;
if (pos >= ordering_.size()) {
ordering_.resize(pos + 1, NULL);
}
ordering_[pos] = request;
// base_sequence_ always equals the sequence number of ordering_.front(),
// i.e. the next batch owed to the output.
while (!ordering_.empty() && ordering_.front()) {
ordering_.front()->Flush(output_);
done_.Produce(ordering_.front());
ordering_.pop_front();
++base_sequence_;
}
}
private:
Output &output_;
util::PCQueue<Request> &done_;
std::deque<Request> ordering_;
uint64_t base_sequence_;
};
// Orchestrates the threaded filter pipeline.  Owns a fixed pool of batches
// that cycle: reader (this class) -> filter workers -> output worker ->
// back to the reader through to_read_.
template <class Filter, class OutputBuffer, class RealOutput> class Controller : boost::noncopyable {
private:
typedef ThreadBatch<OutputBuffer> Batch;
public:
Controller(size_t batch_size, size_t queue, size_t workers, const Filter &filter, RealOutput &output)
: batch_size_(batch_size), queue_size_(queue),
batches_(queue),
to_read_(queue),
output_(queue, 1, boost::in_place(boost::ref(output), boost::ref(to_read_)), NULL),
filter_(queue, workers, boost::in_place(boost::ref(filter), boost::ref(output_.In())), NULL),
sequence_(0) {
// All batches start out locally owned by the reader.
for (size_t i = 0; i < queue; ++i) {
batches_[i].Reserve(batch_size);
local_read_.push(&batches_[i]);
}
NewInput();
}
// Buffer one n-gram; dispatch the batch to the filter pool when full.
void AddNGram(const StringPiece &ngram, const StringPiece &line, RealOutput &output) {
input_->AddNGram(ngram, line, output);
if (input_->Size() == batch_size_) {
FlushInput();
NewInput();
}
}
// Drain: send any partial batch, then wait until every batch has cycled
// back, which proves all pending work has been written.
void Flush() {
FlushInput();
while (local_read_.size() < queue_size_) {
MoveRead();
}
NewInput();
}
private:
// Hand the current batch to the filter pool (unless empty), keeping at
// least one batch locally so NewInput always has one available.
void FlushInput() {
if (input_->Empty()) return;
filter_.Produce(local_read_.top());
local_read_.pop();
if (local_read_.empty()) MoveRead();
}
void NewInput() {
input_ = &local_read_.top()->Fill(sequence_++);
}
// Blocks until the output worker recycles a batch.
void MoveRead() {
local_read_.push(to_read_.Consume());
}
const size_t batch_size_;
const size_t queue_size_;
std::vector<Batch> batches_;
util::PCQueue<Batch*> to_read_;
std::stack<Batch*> local_read_;
util::ThreadPool<OutputWorker<Batch, RealOutput> > output_;
util::ThreadPool<FilterWorker<Batch, Filter> > filter_;
uint64_t sequence_;
InputBuffer *input_;
};
} // namespace lm
#endif // LM_FILTER_THREAD_H
#include "vocab.hh"
#include <istream>
#include <iostream>
#include <cctype>
namespace lm {
namespace vocab {
// Read whitespace-delimited words from in into a single vocabulary set.
// badbit raises an exception; normal EOF just ends the loop.
void ReadSingle(std::istream &in, boost::unordered_set<std::string> &out) {
  in.exceptions(std::istream::badbit);
  for (std::string word; in >> word; ) {
    out.insert(word);
  }
}
namespace {
// Consume whitespace.  Returns true if the line has ended (newline consumed
// or stream exhausted); otherwise puts the first non-space character back
// and returns false.
bool IsLineEnd(std::istream &in) {
  for (;;) {
    int got = in.get();
    if (!in || got == '\n') return true;
    if (!isspace(got)) {
      in.unget();
      return false;
    }
  }
}
}// namespace
// Read space separated words in enter separated lines. These lines can be
// very long, so don't read an entire line at a time.
// Builds an inverted index: word -> sorted list of sentence (line) ids in
// which it appears.  Returns the number of sentences read.
unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out) {
in.exceptions(std::istream::badbit);
unsigned int sentence = 0;
// True while the current sentence has words but no newline yet; counts a
// final line that lacks a trailing newline.
bool used_id = false;
std::string word;
while (in >> word) {
used_id = true;
std::vector<unsigned int> &posting = out[word];
// Sentence ids are generated in increasing order, so checking only the
// back keeps each posting list sorted and duplicate-free.
if (posting.empty() || (posting.back() != sentence))
posting.push_back(sentence);
if (IsLineEnd(in)) {
++sentence;
used_id = false;
}
}
return sentence + used_id;
}
} // namespace vocab
} // namespace lm
#ifndef LM_FILTER_VOCAB_H
#define LM_FILTER_VOCAB_H
// Vocabulary-based filters for language models.
#include "../../util/multi_intersection.hh"
#include "../../util/string_piece.hh"
#include "../../util/string_piece_hash.hh"
#include "../../util/tokenize_piece.hh"
#include <boost/noncopyable.hpp>
#include <boost/range/iterator_range.hpp>
#include <boost/unordered/unordered_map.hpp>
#include <boost/unordered/unordered_set.hpp>
#include <string>
#include <vector>
namespace lm {
namespace vocab {
void ReadSingle(std::istream &in, boost::unordered_set<std::string> &out);
// Read one sentence vocabulary per line. Return the number of sentences.
unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out);
/* Is this a special tag like <s> or <UNK>? This actually includes anything
* surrounded with < and >, which most tokenizers separate for real words, so
* this should not catch real words as it looks at a single token.
*/
inline bool IsTag(const StringPiece &value) {
  // The parser should never give an empty string.
  assert(!value.empty());
  const char first = value.data()[0];
  const char last = value.data()[value.size() - 1];
  return first == '<' && last == '>';
}
// Binary filter against one vocabulary: an n-gram passes only when every
// non-tag word is present in the set.
class Single {
public:
  typedef boost::unordered_set<std::string> Words;
  explicit Single(const Words &vocab) : vocab_(vocab) {}
  // True iff all real (non-tag) words are in the vocabulary.
  template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
    for (Iterator word = begin; word != end; ++word) {
      if (IsTag(*word)) continue;
      if (FindStringPiece(vocab_, *word) == vocab_.end()) return false;
    }
    return true;
  }
private:
  const Words &vocab_;
};
/* Binary filter over per-sentence vocabularies: an n-gram passes if some
 * single sentence contains every non-tag word (i.e. the words' posting
 * lists of sentence ids intersect), or if it has no real words at all.
 */
class Union {
public:
  typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words;
  explicit Union(const Words &vocabs) : vocabs_(vocabs) {}
  // True iff every non-tag word is known and they co-occur in a sentence.
  template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
    sets_.clear();
    for (Iterator i(begin); i != end; ++i) {
      if (IsTag(*i)) continue;
      Words::const_iterator found(FindStringPiece(vocabs_, *i));
      if (vocabs_.end() == found) return false;
      // Use data()/size() rather than &*begin()/&*end(): dereferencing the
      // end iterator is undefined behavior and &*begin() is invalid when the
      // posting list is empty.
      const std::vector<unsigned int> &posting = found->second;
      sets_.push_back(boost::iterator_range<const unsigned int*>(posting.data(), posting.data() + posting.size()));
    }
    return (sets_.empty() || util::FirstIntersection(sets_));
  }
private:
  const Words &vocabs_;
  std::vector<boost::iterator_range<const unsigned int*> > sets_;
};
/* Multiple-output vocabulary filter: the n-gram's line is forwarded (via
 * SingleAddNGram) to every output whose sentence contains all of the
 * n-gram's non-tag words.  A tag-only n-gram goes to every output.
 */
class Multiple {
public:
  typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words;
  Multiple(const Words &vocabs) : vocabs_(vocabs) {}
private:
  // Callback from AllIntersection that does AddNGram.
  template <class Output> class Callback {
  public:
    Callback(Output &out, const StringPiece &line) : out_(out), line_(line) {}
    void operator()(unsigned int index) {
      out_.SingleAddNGram(index, line_);
    }
  private:
    Output &out_;
    const StringPiece &line_;
  };
public:
  // Route the line: unknown word -> nowhere; no real words -> everywhere;
  // otherwise to each sentence in the intersection of the posting lists.
  template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
    sets_.clear();
    for (Iterator i(begin); i != end; ++i) {
      if (IsTag(*i)) continue;
      Words::const_iterator found(FindStringPiece(vocabs_, *i));
      if (vocabs_.end() == found) return;
      // Use data()/size() rather than &*begin()/&*end(): dereferencing the
      // end iterator is undefined behavior and &*begin() is invalid when the
      // posting list is empty.
      const std::vector<unsigned int> &posting = found->second;
      sets_.push_back(boost::iterator_range<const unsigned int*>(posting.data(), posting.data() + posting.size()));
    }
    if (sets_.empty()) {
      output.AddNGram(line);
      return;
    }
    Callback<Output> cb(output, line);
    util::AllIntersection(sets_, cb);
  }
  // Convenience overload: tokenize the n-gram on single spaces first.
  template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
    AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
  }
  void Flush() const {}
private:
  const Words &vocabs_;
  std::vector<boost::iterator_range<const unsigned int*> > sets_;
};
} // namespace vocab
} // namespace lm
#endif // LM_FILTER_VOCAB_H
#ifndef LM_FILTER_WRAPPER_H
#define LM_FILTER_WRAPPER_H
#include "../../util/string_piece.hh"
#include <algorithm>
#include <string>
#include <vector>
namespace lm {
// Provide a single-output filter with the same interface as a
// multiple-output filter so clients code against one interface.
// Adapts a binary (pass/fail) filter to the multiple-output AddNGram
// interface: the line is forwarded only when the n-gram passes.
template <class Binary> class BinaryFilter {
public:
  // Binary modes are just references (and a set) and it makes the API cleaner to copy them.
  explicit BinaryFilter(Binary binary) : binary_(binary) {}
  template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
    if (!binary_.PassNGram(begin, end)) return;
    output.AddNGram(line);
  }
  // Convenience overload: tokenize the n-gram on single spaces first.
  template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
    AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
  }
  void Flush() const {}
private:
  Binary binary_;
};
// Wrap another filter to pay attention only to context words
// The context is the n-gram minus its final word: the last word is stripped
// and only the remainder is handed to the backend filter.
template <class FilterT> class ContextFilter {
public:
typedef FilterT Filter;
// NOTE(review): backend_ is stored by value, so the caller's filter is
// copied rather than referenced -- confirm that is intended.
explicit ContextFilter(Filter &backend) : backend_(backend) {}
template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
// Find beginning of string or last space.
const char *last_space;
for (last_space = ngram.data() + ngram.size() - 1; last_space > ngram.data() && *last_space != ' '; --last_space) {}
// For a unigram no space is found and the context passed on is empty.
backend_.AddNGram(StringPiece(ngram.data(), last_space - ngram.data()), line, output);
}
void Flush() const {}
private:
Filter backend_;
};
} // namespace lm
#endif // LM_FILTER_WRAPPER_H
#include "binary_format.hh"
#include "model.hh"
#include "left.hh"
#include "../util/tokenize_piece.hh"
template <class Model> void Query(const char *name) {
Model model(name);
std::string line;
lm::ngram::ChartState ignored;
while (getline(std::cin, line)) {
lm::ngram::RuleScore<Model> scorer(model, ignored);
for (util::TokenIter<util::SingleCharacter, true> i(line, ' '); i; ++i) {
scorer.Terminal(model.GetVocabulary().Index(*i));
}
std::cout << scorer.Finish() << '\n';
}
}
// Usage: prog model_file
// Detects the binary model type on disk and dispatches to the matching
// Query instantiation; only the probing variants are handled.
int main(int argc, char *argv[]) {
if (argc != 2) {
std::cerr << "Expected model file name." << std::endl;
return 1;
}
const char *name = argv[1];
// PROBING is the default if RecognizeBinary leaves model_type untouched.
lm::ngram::ModelType model_type = lm::ngram::PROBING;
lm::ngram::RecognizeBinary(name, model_type);
switch (model_type) {
case lm::ngram::PROBING:
Query<lm::ngram::ProbingModel>(name);
break;
case lm::ngram::REST_PROBING:
Query<lm::ngram::RestProbingModel>(name);
break;
default:
std::cerr << "Model type not supported yet." << std::endl;
}
}
# Eigen3 less than 3.1.0 has a race condition: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=466
if(ENABLE_INTERPOLATE)
find_package(Eigen3 3.1.0 CONFIG REQUIRED)
include_directories(${EIGEN3_INCLUDE_DIR})
# Sources for the interpolation library.
set(KENLM_INTERPOLATE_SOURCE
backoff_reunification.cc
bounded_sequence_encoding.cc
merge_probabilities.cc
merge_vocab.cc
normalize.cc
pipeline.cc
split_worker.cc
tune_derivatives.cc
tune_instances.cc
tune_weights.cc
universal_vocab.cc)
add_library(kenlm_interpolate ${KENLM_INTERPOLATE_SOURCE})
target_link_libraries(kenlm_interpolate PUBLIC kenlm Eigen3::Eigen)
# Since headers are relative to `include/kenlm` at install time, not just `include`
target_include_directories(kenlm_interpolate PUBLIC $<INSTALL_INTERFACE:include/kenlm>)
# OpenMP is optional; link it only when available.
find_package(OpenMP)
if (OPENMP_CXX_FOUND)
target_link_libraries(kenlm_interpolate PUBLIC OpenMP::OpenMP_CXX)
endif()
set(KENLM_INTERPOLATE_EXES
interpolate
streaming_example)
set(KENLM_INTERPOLATE_LIBS
kenlm_interpolate)
AddExes(EXES ${KENLM_INTERPOLATE_EXES}
LIBRARIES ${KENLM_INTERPOLATE_LIBS})
install(
TARGETS kenlm_interpolate
EXPORT kenlmTargets
RUNTIME DESTINATION bin
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib
INCLUDES DESTINATION include
)
if(BUILD_TESTING)
AddTests(TESTS backoff_reunification_test bounded_sequence_encoding_test merge_vocab_test normalize_test tune_derivatives_test
LIBRARIES ${KENLM_INTERPOLATE_LIBS} Threads::Threads)
# tune_instances_test needs an extra command line parameter
KenLMAddTest(TEST tune_instances_test
LIBRARIES ${KENLM_INTERPOLATE_LIBS}
TEST_ARGS -- ${CMAKE_CURRENT_SOURCE_DIR}/../common/test_data)
endif()
endif()
#ifndef LM_INTERPOLATE_BACKOFF_MATRIX_H
#define LM_INTERPOLATE_BACKOFF_MATRIX_H
#include <cstddef>
#include <vector>
namespace lm { namespace interpolate {
// Dense num_models x max_order matrix of backoff weights, stored row-major
// with one row per model.  Entries are zero-initialized.
class BackoffMatrix {
public:
  BackoffMatrix(std::size_t num_models, std::size_t max_order)
    : max_order_(max_order), backing_(num_models * max_order) {}
  // Mutable backoff for (model, order - 1).
  float &Backoff(std::size_t model, std::size_t order_minus_1) {
    return backing_[Flatten(model, order_minus_1)];
  }
  // Read-only backoff for (model, order - 1).
  float Backoff(std::size_t model, std::size_t order_minus_1) const {
    return backing_[Flatten(model, order_minus_1)];
  }
private:
  // Row-major flattening of the (model, order) index pair.
  std::size_t Flatten(std::size_t model, std::size_t order_minus_1) const {
    return model * max_order_ + order_minus_1;
  }
  const std::size_t max_order_;
  std::vector<float> backing_;
};
}} // namespaces
#endif // LM_INTERPOLATE_BACKOFF_MATRIX_H
#include "backoff_reunification.hh"
#include "../common/model_buffer.hh"
#include "../common/ngram_stream.hh"
#include "../common/ngram.hh"
#include "../common/compare.hh"
#include <algorithm>
#include <cassert>
namespace lm {
namespace interpolate {
namespace {
// Chain worker that zips a stream of (word ids, probability) records with a
// parallel stream of raw backoff floats into ProbBackoff records on the
// output chain.  Both inputs must have the same record count.
class MergeWorker {
public:
MergeWorker(std::size_t order, const util::stream::ChainPosition &prob_pos,
const util::stream::ChainPosition &boff_pos)
: order_(order), prob_pos_(prob_pos), boff_pos_(boff_pos) {
// nothing
}
void Run(const util::stream::ChainPosition &position) {
lm::NGramStream<ProbBackoff> stream(position);
lm::NGramStream<float> prob_input(prob_pos_);
util::stream::Stream boff_input(boff_pos_);
for (; prob_input && boff_input; ++prob_input, ++boff_input, ++stream) {
// Copy the word ids, then attach probability and backoff.  The
// probability is clamped to <= 0 so it stays a valid log probability.
std::copy(prob_input->begin(), prob_input->end(), stream->begin());
stream->Value().prob = std::min(0.0f, prob_input->Value());
stream->Value().backoff = *reinterpret_cast<float *>(boff_input.Get());
}
UTIL_THROW_IF2(prob_input || boff_input,
"Streams were not the same size during merging");
stream.Poison();
}
private:
std::size_t order_;
util::stream::ChainPosition prob_pos_;
util::stream::ChainPosition boff_pos_;
};
}
// Since we are *adding* something to the output chain here, we pass in the
// chain itself so that we can safely add a new step to the chain without
// creating a deadlock situation (since creating a new ChainPosition will
// make a new input/output pair---we want that position to be created
// *here*, not before).
// Attach one MergeWorker per order to the output chains; order is i + 1
// because chain index 0 carries unigrams.
void ReunifyBackoff(util::stream::ChainPositions &prob_pos,
util::stream::ChainPositions &boff_pos,
util::stream::Chains &output_chains) {
assert(prob_pos.size() == boff_pos.size());
for (size_t i = 0; i < prob_pos.size(); ++i)
output_chains[i] >> MergeWorker(i + 1, prob_pos[i], boff_pos[i]);
}
}
}
#ifndef KENLM_INTERPOLATE_BACKOFF_REUNIFICATION_
#define KENLM_INTERPOLATE_BACKOFF_REUNIFICATION_
#include "../../util/stream/stream.hh"
#include "../../util/stream/multi_stream.hh"
namespace lm {
namespace interpolate {
/**
* The third pass for the offline log-linear interpolation algorithm. This
* reads **suffix-ordered** probability values (ngram-id, float) and
* **suffix-ordered** backoff values (float) and writes the merged contents
* to the output.
*
* @param prob_pos The chain position for each order from which to read
* the probability values
* @param boff_pos The chain position for each order from which to read
* the backoff values
* @param output_chains The output chains for each order
*/
void ReunifyBackoff(util::stream::ChainPositions &prob_pos,
util::stream::ChainPositions &boff_pos,
util::stream::Chains &output_chains);
}
}
#endif
#include "backoff_reunification.hh"
#include "../common/ngram_stream.hh"
#define BOOST_TEST_MODULE InterpolateBackoffReunificationTest
#include <boost/test/unit_test.hpp>
namespace lm {
namespace interpolate {
namespace {
// none of this input actually makes sense, all we care about is making
// sure the merging works
// One test record: N word ids plus a probability and a backoff weight.
template <uint8_t N>
struct Gram {
WordIndex ids[N];
float prob;
float boff;
};
// Per-order table of test records; specialized below for orders 1 through 3.
template <uint8_t N>
struct Grams {
const static Gram<N> grams[];
};
template <>
const Gram<1> Grams<1>::grams[]
= {{{0}, -0.1f, -0.1f}, {{1}, -0.4f, -0.2f}, {{2}, -0.5f, -0.1f}};
template <>
const Gram<2> Grams<2>::grams[] = {{{0, 0}, -0.05f, -0.05f},
{{1, 0}, -0.05f, -0.02f},
{{1, 1}, -0.2f, -0.04f},
{{2, 2}, -0.2f, -0.01f}};
template <>
const Gram<3> Grams<3>::grams[] = {{{0, 0, 0}, -0.001f, -0.005f},
{{1, 0, 0}, -0.001f, -0.002f},
{{2, 0, 0}, -0.001f, -0.003f},
{{0, 1, 0}, -0.1f, -0.008f},
{{1, 1, 0}, -0.1f, -0.09f},
{{1, 1, 1}, -0.2f, -0.08f}};
// Chain worker emitting the test data's ids and probabilities as
// NGram<float> records, one per Gram<N> entry.
template <uint8_t N>
class WriteInput {
public:
void Run(const util::stream::ChainPosition &position) {
lm::NGramStream<float> output(position);
for (std::size_t i = 0; i < sizeof(Grams<N>::grams) / sizeof(Gram<N>);
++i, ++output) {
std::copy(Grams<N>::grams[i].ids, Grams<N>::grams[i].ids + N,
output->begin());
output->Value() = Grams<N>::grams[i].prob;
}
output.Poison();
}
};
// Chain worker emitting the test data's backoff weights as raw floats, in
// the same record order as WriteInput.
template <uint8_t N>
class WriteBackoffs {
public:
void Run(const util::stream::ChainPosition &position) {
util::stream::Stream output(position);
for (std::size_t i = 0; i < sizeof(Grams<N>::grams) / sizeof(Gram<N>);
++i, ++output) {
*reinterpret_cast<float *>(output.Get()) = Grams<N>::grams[i].boff;
}
output.Poison();
}
};
// Chain worker verifying that the merged output reproduces the test data's
// ids, probabilities, and backoffs, and that the record count matches.
template <uint8_t N>
class CheckOutput {
public:
void Run(const util::stream::ChainPosition &position) {
lm::NGramStream<ProbBackoff> stream(position);
std::size_t i = 0;
for (; stream; ++stream, ++i) {
// ss renders the ids for the (currently commented-out) failure messages.
std::stringstream ss;
for (WordIndex *idx = stream->begin(); idx != stream->end(); ++idx)
ss << "(" << *idx << ")";
BOOST_CHECK(std::equal(stream->begin(), stream->end(), Grams<N>::grams[i].ids));
//"Mismatched id in CheckOutput<" << (int)N << ">: " << ss.str();
BOOST_CHECK_EQUAL(stream->Value().prob, Grams<N>::grams[i].prob);
/* "Mismatched probability in CheckOutput<"
<< (int)N << ">, got " << stream->Value().prob
<< ", expected " << Grams<N>::grams[i].prob;*/
BOOST_CHECK_EQUAL(stream->Value().backoff, Grams<N>::grams[i].boff);
/* "Mismatched backoff in CheckOutput<"
<< (int)N << ">, got " << stream->Value().backoff
<< ", expected " << Grams<N>::grams[i].boff);*/
}
BOOST_CHECK_EQUAL(i , sizeof(Grams<N>::grams) / sizeof(Gram<N>));
/* "Did not get correct number of "
<< (int)N << "-grams: expected "
<< sizeof(Grams<N>::grams) / sizeof(Gram<N>)
<< ", got " << i;*/
}
};
}
// End-to-end check of ReunifyBackoff: probability chains and backoff chains
// for orders 1-3 are fed by the writers above, merged, and verified by
// CheckOutput.
BOOST_AUTO_TEST_CASE(BackoffReunificationTest) {
util::stream::ChainConfig config;
config.total_memory = 100;
config.block_count = 1;
// Probability inputs, one chain per order.
util::stream::Chains prob_chains(3);
config.entry_size = NGram<float>::TotalSize(1);
prob_chains.push_back(config);
prob_chains.back() >> WriteInput<1>();
config.entry_size = NGram<float>::TotalSize(2);
prob_chains.push_back(config);
prob_chains.back() >> WriteInput<2>();
config.entry_size = NGram<float>::TotalSize(3);
prob_chains.push_back(config);
prob_chains.back() >> WriteInput<3>();
// Backoff inputs are bare floats, so all three chains share one config.
util::stream::Chains boff_chains(3);
config.entry_size = sizeof(float);
boff_chains.push_back(config);
boff_chains.back() >> WriteBackoffs<1>();
boff_chains.push_back(config);
boff_chains.back() >> WriteBackoffs<2>();
boff_chains.push_back(config);
boff_chains.back() >> WriteBackoffs<3>();
util::stream::ChainPositions prob_pos(prob_chains);
util::stream::ChainPositions boff_pos(boff_chains);
util::stream::Chains output_chains(3);
for (std::size_t i = 0; i < 3; ++i) {
config.entry_size = NGram<ProbBackoff>::TotalSize(i + 1);
output_chains.push_back(config);
}
ReunifyBackoff(prob_pos, boff_pos, output_chains);
output_chains[0] >> CheckOutput<1>();
output_chains[1] >> CheckOutput<2>();
output_chains[2] >> CheckOutput<3>();
prob_chains >> util::stream::kRecycle;
boff_chains >> util::stream::kRecycle;
output_chains.Wait();
}
}
}
#include "bounded_sequence_encoding.hh"
#include <algorithm>
namespace lm { namespace interpolate {
// Precompute the packing layout: each bound gets a bit width, a mask, and a
// shift within a 64-bit word; fields are packed LSB-first and never straddle
// a word boundary (a field that would is moved to the next word).
BoundedSequenceEncoding::BoundedSequenceEncoding(const unsigned char *bound_begin, const unsigned char *bound_end)
: entries_(bound_end - bound_begin) {
std::size_t full = 0;
Entry entry;
entry.shift = 0;
for (const unsigned char *i = bound_begin; i != bound_end; ++i) {
uint8_t length;
if (*i <= 1) {
// Bound 0 or 1: the only storable value is 0, so zero bits suffice.
length = 0;
} else {
// Bit width of the bound itself (position of highest set bit).
// NOTE: __builtin_clz is GCC/Clang-specific; behavior on *i == 0 would be
// undefined, but that case is handled by the branch above.
length = sizeof(unsigned int) * 8 - __builtin_clz((unsigned int)*i);
}
entry.mask = (1ULL << length) - 1ULL;
if (entry.shift + length > 64) {
// Field would straddle the current 64-bit word: start a new word.
entry.shift = 0;
entry.next = true;
++full;
} else {
entry.next = false;
}
entries_.push_back(entry);
entry.shift += length;
}
// Total bytes: completed words plus the occupied bytes of the final word.
byte_length_ = full * sizeof(uint64_t) + (entry.shift + 7) / 8;
first_copy_ = std::min<std::size_t>(byte_length_, sizeof(uint64_t));
// Size of last uint64_t. Zero if empty, otherwise [1,8] depending on mod.
overhang_ = byte_length_ == 0 ? 0 : ((byte_length_ - 1) % 8 + 1);
}
}} // namespaces
#ifndef LM_INTERPOLATE_BOUNDED_SEQUENCE_ENCODING_H
#define LM_INTERPOLATE_BOUNDED_SEQUENCE_ENCODING_H
/* Encodes fixed-length sequences of integers with known bounds on each entry.
* This is used to encode how far each model has backed off.
* TODO: make this class efficient. Bit-level packing or multiply by bound and
* add.
*/
#include "../../util/exception.hh"
#include "../../util/fixed_array.hh"
#include <algorithm>
#include <cstring>
namespace lm {
namespace interpolate {
// Packs a fixed-length sequence of small bounded integers into a compact
// byte string and unpacks it again.  The layout (per-entry shift/mask/word
// boundaries) is fixed at construction from the bounds.
class BoundedSequenceEncoding {
public:
// Encode [0, bound_begin[0]) x [0, bound_begin[1]) x [0, bound_begin[2]) x ... x [0, *(bound_end - 1)) for entries in the sequence
BoundedSequenceEncoding(const unsigned char *bound_begin, const unsigned char *bound_end);
std::size_t Entries() const { return entries_.size(); }
std::size_t EncodedLength() const { return byte_length_; }
// Pack Entries() values from from into EncodedLength() bytes at to_void.
void Encode(const unsigned char *from, void *to_void) const {
uint8_t *to = static_cast<uint8_t*>(to_void);
uint64_t cur = 0;
for (const Entry *i = entries_.begin(); i != entries_.end(); ++i, ++from) {
// i->next marks the first field of a new 64-bit word: flush the
// completed word and start accumulating a fresh one.
if (UTIL_UNLIKELY(i->next)) {
std::memcpy(to, &cur, sizeof(uint64_t));
to += sizeof(uint64_t);
cur = 0;
}
cur |= static_cast<uint64_t>(*from) << i->shift;
}
// The layout is little-endian on disk; on big-endian hosts shift the
// partial word so its occupied bytes come first.
#if BYTE_ORDER == BIG_ENDIAN
cur <<= (8 - overhang_) * 8;
#endif
// Only the occupied bytes of the final (possibly partial) word are written.
memcpy(to, &cur, overhang_);
}
// Inverse of Encode: unpack EncodedLength() bytes into Entries() values.
void Decode(const void *from_void, unsigned char *to) const {
const uint8_t *from = static_cast<const uint8_t*>(from_void);
uint64_t cur = 0;
memcpy(&cur, from, first_copy_);
#if BYTE_ORDER == BIG_ENDIAN
cur >>= (8 - first_copy_) * 8;
#endif
for (const Entry *i = entries_.begin(); i != entries_.end(); ++i, ++to) {
if (UTIL_UNLIKELY(i->next)) {
// Advance to the next word, reading only the bytes that remain in the
// buffer (the final word may be partial).
from += sizeof(uint64_t);
cur = 0;
std::memcpy(&cur, from,
std::min<std::size_t>(sizeof(uint64_t), static_cast<const uint8_t*>(from_void) + byte_length_ - from));
#if BYTE_ORDER == BIG_ENDIAN
cur >>= (8 - (static_cast<const uint8_t*>(from_void) + byte_length_ - from)) * 8;
#endif
}
*to = (cur >> i->shift) & i->mask;
}
}
private:
// Layout of one field: whether it starts a new 64-bit word, its bit offset
// within the word, and a mask of its width.
struct Entry {
bool next;
uint8_t shift;
uint64_t mask;
};
util::FixedArray<Entry> entries_;
std::size_t byte_length_;
std::size_t first_copy_;
std::size_t overhang_;
};
}} // namespaces
#endif // LM_INTERPOLATE_BOUNDED_SEQUENCE_ENCODING_H
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment