"test/vscode:/vscode.git/clone" did not exist on "75964177327c85ee1de05a28ebf5eae4d88974d5"
Commit 688b6eac authored by SWHL's avatar SWHL
Browse files

Update files

parents
# Helper functions used across the CMake build system
include(CMakeParseArguments)
# Adds a bunch of executables to the build, each depending on the specified
# dependent object files and linking against the specified libraries
function(AddExes)
set(multiValueArgs EXES DEPENDS LIBRARIES)
cmake_parse_arguments(AddExes "" "" "${multiValueArgs}" ${ARGN})
# Iterate through the executable list
foreach(exe ${AddExes_EXES})
# Compile the executable, linking against the requisite dependent object files
add_executable(${exe} ${exe}_main.cc ${AddExes_DEPENDS})
# Link the executable against the supplied libraries
target_link_libraries(${exe} ${AddExes_LIBRARIES})
# Group executables together
set_target_properties(${exe} PROPERTIES FOLDER executables)
# End for loop
endforeach(exe)
# Install the executable files
install(TARGETS ${AddExes_EXES} DESTINATION bin)
endfunction()
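# Usage sketch (hypothetical target name): given a file hello_main.cc, the
# following builds and installs an executable `hello` linked against
# kenlm_util:
#   AddExes(EXES hello
#           LIBRARIES kenlm_util)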
# Adds a single test to the build, depending on the specified dependent
# object files, linking against the specified libraries, and with the
# specified command line arguments
function(KenLMAddTest)
cmake_parse_arguments(KenLMAddTest "" "TEST"
"DEPENDS;LIBRARIES;TEST_ARGS" ${ARGN})
# Compile the executable, linking against the requisite dependent object files
add_executable(${KenLMAddTest_TEST}
${KenLMAddTest_TEST}.cc
${KenLMAddTest_DEPENDS})
if (Boost_USE_STATIC_LIBS)
set(DYNLINK_FLAGS)
else()
set(DYNLINK_FLAGS COMPILE_FLAGS -DBOOST_TEST_DYN_LINK)
endif()
# Require the following compile flag
set_target_properties(${KenLMAddTest_TEST} PROPERTIES
${DYNLINK_FLAGS}
RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/tests)
target_link_libraries(${KenLMAddTest_TEST} ${KenLMAddTest_LIBRARIES} ${TIMER_LINK})
set(test_params "")
if(KenLMAddTest_TEST_ARGS)
set(test_params ${KenLMAddTest_TEST_ARGS})
endif()
# Specify command arguments for how to run each unit test
add_test(NAME ${KenLMAddTest_TEST}
COMMAND ${KenLMAddTest_TEST} ${test_params})
# Group unit tests together
set_target_properties(${KenLMAddTest_TEST} PROPERTIES FOLDER "unit_tests")
endfunction()
# Adds a bunch of tests to the build, each depending on the specified
# dependent object files and linking against the specified libraries
function(AddTests)
set(multiValueArgs TESTS DEPENDS LIBRARIES TEST_ARGS)
cmake_parse_arguments(AddTests "" "" "${multiValueArgs}" ${ARGN})
# Iterate through the Boost tests list
foreach(test ${AddTests_TESTS})
KenLMAddTest(TEST ${test}
DEPENDS ${AddTests_DEPENDS}
LIBRARIES ${AddTests_LIBRARIES}
TEST_ARGS ${AddTests_TEST_ARGS})
endforeach(test)
endfunction()
@PACKAGE_INIT@
include(CMakeFindDependencyMacro)
find_dependency(Boost)
find_dependency(Threads)
# Compression libs
if (@ZLIB_FOUND@)
find_dependency(ZLIB)
endif()
if (@BZIP2_FOUND@)
find_dependency(BZip2)
endif()
if (@LIBLZMA_FOUND@)
find_dependency(LibLZMA)
endif()
include("${CMAKE_CURRENT_LIST_DIR}/kenlmTargets.cmake")
# - Try to find Eigen3 lib
#
# This module supports requiring a minimum version, e.g. you can do
# find_package(Eigen3 3.1.2)
# to require version 3.1.2 or newer of Eigen3.
#
# Once done this will define
#
# EIGEN3_FOUND - system has eigen lib with correct version
# EIGEN3_INCLUDE_DIR - the eigen include directory
# EIGEN3_VERSION - eigen version
#
# This module reads hints about search locations from
# the following environment variables:
#
# EIGEN3_ROOT
# EIGEN3_ROOT_DIR
# Copyright (c) 2006, 2007 Montel Laurent, <montel@kde.org>
# Copyright (c) 2008, 2009 Gael Guennebaud, <g.gael@free.fr>
# Copyright (c) 2009 Benoit Jacob <jacob.benoit.1@gmail.com>
# Redistribution and use is allowed according to the terms of the 2-clause BSD license.
if(NOT Eigen3_FIND_VERSION)
if(NOT Eigen3_FIND_VERSION_MAJOR)
set(Eigen3_FIND_VERSION_MAJOR 2)
endif(NOT Eigen3_FIND_VERSION_MAJOR)
if(NOT Eigen3_FIND_VERSION_MINOR)
set(Eigen3_FIND_VERSION_MINOR 91)
endif(NOT Eigen3_FIND_VERSION_MINOR)
if(NOT Eigen3_FIND_VERSION_PATCH)
set(Eigen3_FIND_VERSION_PATCH 0)
endif(NOT Eigen3_FIND_VERSION_PATCH)
set(Eigen3_FIND_VERSION "${Eigen3_FIND_VERSION_MAJOR}.${Eigen3_FIND_VERSION_MINOR}.${Eigen3_FIND_VERSION_PATCH}")
endif(NOT Eigen3_FIND_VERSION)
macro(_eigen3_check_version)
file(READ "${EIGEN3_INCLUDE_DIR}/Eigen/src/Core/util/Macros.h" _eigen3_version_header)
string(REGEX MATCH "define[ \t]+EIGEN_WORLD_VERSION[ \t]+([0-9]+)" _eigen3_world_version_match "${_eigen3_version_header}")
set(EIGEN3_WORLD_VERSION "${CMAKE_MATCH_1}")
string(REGEX MATCH "define[ \t]+EIGEN_MAJOR_VERSION[ \t]+([0-9]+)" _eigen3_major_version_match "${_eigen3_version_header}")
set(EIGEN3_MAJOR_VERSION "${CMAKE_MATCH_1}")
string(REGEX MATCH "define[ \t]+EIGEN_MINOR_VERSION[ \t]+([0-9]+)" _eigen3_minor_version_match "${_eigen3_version_header}")
set(EIGEN3_MINOR_VERSION "${CMAKE_MATCH_1}")
set(EIGEN3_VERSION ${EIGEN3_WORLD_VERSION}.${EIGEN3_MAJOR_VERSION}.${EIGEN3_MINOR_VERSION})
if(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION})
set(EIGEN3_VERSION_OK FALSE)
else(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION})
set(EIGEN3_VERSION_OK TRUE)
endif(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION})
if(NOT EIGEN3_VERSION_OK)
message(STATUS "Eigen3 version ${EIGEN3_VERSION} found in ${EIGEN3_INCLUDE_DIR}, "
"but at least version ${Eigen3_FIND_VERSION} is required")
endif(NOT EIGEN3_VERSION_OK)
endmacro(_eigen3_check_version)
if (EIGEN3_INCLUDE_DIR)
# in cache already
_eigen3_check_version()
set(EIGEN3_FOUND ${EIGEN3_VERSION_OK})
else (EIGEN3_INCLUDE_DIR)
find_path(EIGEN3_INCLUDE_DIR NAMES signature_of_eigen3_matrix_library
HINTS
ENV EIGEN3_ROOT
ENV EIGEN3_ROOT_DIR
PATHS
${CMAKE_INSTALL_PREFIX}/include
${KDE4_INCLUDE_DIR}
PATH_SUFFIXES eigen3 eigen
)
if(EIGEN3_INCLUDE_DIR)
_eigen3_check_version()
endif(EIGEN3_INCLUDE_DIR)
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(Eigen3 DEFAULT_MSG EIGEN3_INCLUDE_DIR EIGEN3_VERSION_OK)
mark_as_advanced(EIGEN3_INCLUDE_DIR)
endif(EIGEN3_INCLUDE_DIR)
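# Typical use of this module (sketch):
#   find_package(Eigen3 3.1.2 REQUIRED)
#   include_directories(${EIGEN3_INCLUDE_DIR})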
#!/bin/bash
#This is just an example compilation. You should integrate these files into your build system. Boost jam is provided and preferred.
echo You must use ./bjam if you want language model estimation, filtering, or support for compressed files \(.gz, .bz2, .xz\) 1>&2
rm {lm,util}/*.o 2>/dev/null
set -e
CXX=${CXX:-g++}
CXXFLAGS+=" -I. -O3 -DNDEBUG -DKENLM_MAX_ORDER=6"
#If this fails for you, consider using bjam.
if [ ${#NPLM} != 0 ]; then
CXXFLAGS+=" -DHAVE_NPLM -lneuralLM -L$NPLM/src -I$NPLM/src -lboost_thread-mt -fopenmp"
ADDED_PATHS="lm/wrappers/*.cc"
fi
echo 'Compiling with '$CXX $CXXFLAGS
#Grab all cc files in these directories except those ending in test.cc or main.cc
objects=""
for i in util/double-conversion/*.cc util/*.cc lm/*.cc $ADDED_PATHS; do
if [ "${i%test.cc}" == "$i" ] && [ "${i%main.cc}" == "$i" ]; then
$CXX $CXXFLAGS -c $i -o ${i%.cc}.o
objects="$objects ${i%.cc}.o"
fi
done
mkdir -p bin
if [ "$(uname)" != Darwin ]; then
CXXFLAGS="$CXXFLAGS -lrt"
fi
$CXX lm/build_binary_main.cc $objects -o bin/build_binary $CXXFLAGS $LDFLAGS
$CXX lm/query_main.cc $objects -o bin/query $CXXFLAGS $LDFLAGS
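# Example invocation afterwards (file names are illustrative):
#   bin/build_binary text.arpa text.binary
#   echo "this is a sentence" | bin/query text.binary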
# Explicitly list the source files for this subdirectory
#
# If you add any source files to this subdirectory
# that should be included in the kenlm library,
# (this excludes any unit test files)
# you should add them to the following list:
set(KENLM_LM_SOURCE
bhiksha.cc
binary_format.cc
config.cc
lm_exception.cc
model.cc
quantize.cc
read_arpa.cc
search_hashed.cc
search_trie.cc
sizes.cc
trie.cc
trie_sort.cc
value_build.cc
virtual_interface.cc
vocab.cc
)
# Group these objects together for later use.
#
# Given add_library(foo OBJECT ${my_foo_sources}),
# refer to these objects as $<TARGET_OBJECTS:foo>
#
add_subdirectory(common)
add_library(kenlm ${KENLM_LM_SOURCE} ${KENLM_LM_COMMON_SOURCE})
set_target_properties(kenlm PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_link_libraries(kenlm PUBLIC kenlm_util Threads::Threads)
# Since headers are relative to `include/kenlm` at install time, not just `include`
target_include_directories(kenlm PUBLIC $<INSTALL_INTERFACE:include/kenlm>)
set(KENLM_MAX_ORDER 6 CACHE STRING "Maximum supported ngram order")
target_compile_definitions(kenlm PUBLIC -DKENLM_MAX_ORDER=${KENLM_MAX_ORDER})
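# KENLM_MAX_ORDER is a cache entry, so it can be overridden at configure time,
# e.g. cmake -DKENLM_MAX_ORDER=8 ..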
# This directory has children that need to be processed
add_subdirectory(builder)
add_subdirectory(filter)
add_subdirectory(interpolate)
# Explicitly list the executable files to be compiled
set(EXE_LIST
query
fragment
build_binary
kenlm_benchmark
)
set(LM_LIBS kenlm kenlm_util Threads::Threads)
install(
TARGETS kenlm
EXPORT kenlmTargets
RUNTIME DESTINATION bin
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib
INCLUDES DESTINATION include
)
AddExes(EXES ${EXE_LIST}
LIBRARIES ${LM_LIBS})
if(BUILD_TESTING)
set(KENLM_BOOST_TESTS_LIST left_test partial_test)
AddTests(TESTS ${KENLM_BOOST_TESTS_LIST}
LIBRARIES ${LM_LIBS}
TEST_ARGS ${CMAKE_CURRENT_SOURCE_DIR}/test.arpa)
# model_test requires an extra command line parameter
KenLMAddTest(TEST model_test
LIBRARIES ${LM_LIBS}
TEST_ARGS ${CMAKE_CURRENT_SOURCE_DIR}/test.arpa
${CMAKE_CURRENT_SOURCE_DIR}/test_nounk.arpa)
endif()
#include "bhiksha.hh"
#include "binary_format.hh"
#include "config.hh"
#include "../util/file.hh"
#include "../util/exception.hh"
#include <limits>
namespace lm {
namespace ngram {
namespace trie {
DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) :
next_(util::BitsMask::ByMax(max_next)) {}
const uint8_t kArrayBhikshaVersion = 0;
// TODO: put this in binary file header instead when I change the binary file format again.
void ArrayBhiksha::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) {
uint8_t buffer[2];
file.ReadForConfig(buffer, 2, offset);
uint8_t version = buffer[0];
uint8_t configured_bits = buffer[1];
if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion);
config.pointer_bhiksha_bits = configured_bits;
}
namespace {
// Find argmin_{chopped \in [0, RequiredBits(max_next)]} ChoppedDelta(max_offset)
uint8_t ChopBits(uint64_t max_offset, uint64_t max_next, const Config &config) {
uint8_t required = util::RequiredBits(max_next);
uint8_t best_chop = 0;
int64_t lowest_change = std::numeric_limits<int64_t>::max();
// There are probably faster ways but I don't care because this is only done once per order at construction time.
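// Chopping the top `chop` bits trades a table of roughly
// (max_next >> (required - chop)) offsets at 64 bits each against saving
// `chop` bits on each of max_offset inline pointers; take the chop with the
// lowest net cost.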
for (uint8_t chop = 0; chop <= std::min(required, config.pointer_bhiksha_bits); ++chop) {
int64_t change = (max_next >> (required - chop)) * 64 /* table cost in bits */
- max_offset * static_cast<int64_t>(chop); /* savings in bits*/
if (change < lowest_change) {
lowest_change = change;
best_chop = chop;
}
}
return best_chop;
}
std::size_t ArrayCount(uint64_t max_offset, uint64_t max_next, const Config &config) {
uint8_t required = util::RequiredBits(max_next);
uint8_t chopping = ChopBits(max_offset, max_next, config);
return (max_next >> (required - chopping)) + 1 /* we store 0 too */;
}
} // namespace
uint64_t ArrayBhiksha::Size(uint64_t max_offset, uint64_t max_next, const Config &config) {
return sizeof(uint64_t) * (1 /* header */ + ArrayCount(max_offset, max_next, config)) + 7 /* 8-byte alignment */;
}
uint8_t ArrayBhiksha::InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config) {
return util::RequiredBits(max_next) - ChopBits(max_offset, max_next, config);
}
namespace {
void *AlignTo8(void *from) {
uint8_t *val = reinterpret_cast<uint8_t*>(from);
std::size_t remainder = reinterpret_cast<std::size_t>(val) & 7;
if (!remainder) return val;
return val + 8 - remainder;
}
} // namespace
ArrayBhiksha::ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_next, const Config &config)
: next_inline_(util::BitsMask::ByBits(InlineBits(max_offset, max_next, config))),
offset_begin_(reinterpret_cast<const uint64_t*>(AlignTo8(base)) + 1 /* 8-byte header */),
offset_end_(offset_begin_ + ArrayCount(max_offset, max_next, config)),
write_to_(reinterpret_cast<uint64_t*>(AlignTo8(base)) + 1 /* 8-byte header */ + 1 /* first entry is 0 */),
original_base_(base) {}
void ArrayBhiksha::FinishedLoading(const Config &config) {
// *offset_begin_ = 0 but without a const_cast.
*(write_to_ - (write_to_ - offset_begin_)) = 0;
if (write_to_ != offset_end_) UTIL_THROW(util::Exception, "Did not get all the array entries that were expected.");
uint8_t *head_write = reinterpret_cast<uint8_t*>(original_base_);
*(head_write++) = kArrayBhikshaVersion;
*(head_write++) = config.pointer_bhiksha_bits;
}
} // namespace trie
} // namespace ngram
} // namespace lm
/* Simple implementation of
* @inproceedings{bhikshacompression,
* author={Bhiksha Raj and Ed Whittaker},
* year={2003},
* title={Lossless Compression of Language Model Structure and Word Identifiers},
* booktitle={Proceedings of IEEE International Conference on Acoustics, Speech and Signal Processing},
* pages={388--391},
* }
*
* Currently only used for next pointers.
*/
#ifndef LM_BHIKSHA_H
#define LM_BHIKSHA_H
#include "model_type.hh"
#include "trie.hh"
#include "../util/bit_packing.hh"
#include "../util/sorted_uniform.hh"
#include <algorithm>
#include <stdint.h>
#include <cassert>
namespace lm {
namespace ngram {
struct Config;
class BinaryFormat;
namespace trie {
class DontBhiksha {
public:
static const ModelType kModelTypeAdd = static_cast<ModelType>(0);
static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &/*config*/) {}
static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; }
static uint8_t InlineBits(uint64_t /*max_offset*/, uint64_t max_next, const Config &/*config*/) {
return util::RequiredBits(max_next);
}
DontBhiksha(const void *base, uint64_t max_offset, uint64_t max_next, const Config &config);
void ReadNext(const void *base, uint64_t bit_offset, uint64_t /*index*/, uint8_t total_bits, NodeRange &out) const {
out.begin = util::ReadInt57(base, bit_offset, next_.bits, next_.mask);
out.end = util::ReadInt57(base, bit_offset + total_bits, next_.bits, next_.mask);
//assert(out.end >= out.begin);
}
void WriteNext(void *base, uint64_t bit_offset, uint64_t /*index*/, uint64_t value) {
util::WriteInt57(base, bit_offset, next_.bits, value);
}
void FinishedLoading(const Config &/*config*/) {}
uint8_t InlineBits() const { return next_.bits; }
private:
util::BitsMask next_;
};
class ArrayBhiksha {
public:
static const ModelType kModelTypeAdd = kArrayAdd;
static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config);
static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config);
static uint8_t InlineBits(uint64_t max_offset, uint64_t max_next, const Config &config);
ArrayBhiksha(void *base, uint64_t max_offset, uint64_t max_value, const Config &config);
void ReadNext(const void *base, uint64_t bit_offset, uint64_t index, uint8_t total_bits, NodeRange &out) const {
// Some assertions are commented out because they are expensive.
// assert(*offset_begin_ == 0);
// std::upper_bound returns the first element that is greater. Want the
// last element that is <= to the index.
const uint64_t *begin_it = std::upper_bound(offset_begin_, offset_end_, index) - 1;
// Since *offset_begin_ == 0, the position should be in range.
// assert(begin_it >= offset_begin_);
const uint64_t *end_it;
for (end_it = begin_it + 1; (end_it < offset_end_) && (*end_it <= index + 1); ++end_it) {}
// assert(end_it == std::upper_bound(offset_begin_, offset_end_, index + 1));
--end_it;
// assert(end_it >= begin_it);
out.begin = ((begin_it - offset_begin_) << next_inline_.bits) |
util::ReadInt57(base, bit_offset, next_inline_.bits, next_inline_.mask);
out.end = ((end_it - offset_begin_) << next_inline_.bits) |
util::ReadInt57(base, bit_offset + total_bits, next_inline_.bits, next_inline_.mask);
// If this fails, consider rebuilding your model using KenLM after 1e333d786b748555e8f368d2bbba29a016c98052
assert(out.end >= out.begin);
}
void WriteNext(void *base, uint64_t bit_offset, uint64_t index, uint64_t value) {
uint64_t encode = value >> next_inline_.bits;
for (; write_to_ <= offset_begin_ + encode; ++write_to_) *write_to_ = index;
util::WriteInt57(base, bit_offset, next_inline_.bits, value & next_inline_.mask);
}
void FinishedLoading(const Config &config);
uint8_t InlineBits() const { return next_inline_.bits; }
private:
const util::BitsMask next_inline_;
const uint64_t *const offset_begin_;
const uint64_t *const offset_end_;
uint64_t *write_to_;
void *original_base_;
};
} // namespace trie
} // namespace ngram
} // namespace lm
#endif // LM_BHIKSHA_H
#include "binary_format.hh"
#include "lm_exception.hh"
#include "../util/file.hh"
#include "../util/file_piece.hh"
#include <cstddef>
#include <cstring>
#include <limits>
#include <string>
#include <cstdlib>
#include <stdint.h>
namespace lm {
namespace ngram {
const char *kModelNames[6] = {"probing hash tables", "probing hash tables with rest costs", "trie", "trie with quantization", "trie with array-compressed pointers", "trie with quantization and array-compressed pointers"};
namespace {
const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version";
const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0";
// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed).
const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n";
const long int kMagicVersion = 5;
// Old binary files built on 32-bit machines have this header.
// TODO: eliminate with next binary release.
struct OldSanity {
char magic[sizeof(kMagicBytes)];
float zero_f, one_f, minus_half_f;
WordIndex one_word_index, max_word_index;
uint64_t one_uint64;
void SetToReference() {
std::memset(this, 0, sizeof(OldSanity));
std::memcpy(magic, kMagicBytes, sizeof(magic));
zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
one_word_index = 1;
max_word_index = std::numeric_limits<WordIndex>::max();
one_uint64 = 1;
}
};
// Test values aligned to 8 bytes.
struct Sanity {
char magic[ALIGN8(sizeof(kMagicBytes))];
float zero_f, one_f, minus_half_f;
WordIndex one_word_index, max_word_index, padding_to_8;
uint64_t one_uint64;
void SetToReference() {
std::memset(this, 0, sizeof(Sanity));
std::memcpy(magic, kMagicBytes, sizeof(kMagicBytes));
zero_f = 0.0; one_f = 1.0; minus_half_f = -0.5;
one_word_index = 1;
max_word_index = std::numeric_limits<WordIndex>::max();
padding_to_8 = 0;
one_uint64 = 1;
}
};
std::size_t TotalHeaderSize(unsigned char order) {
return ALIGN8(sizeof(Sanity) + sizeof(FixedWidthParameters) + sizeof(uint64_t) * order);
}
void WriteHeader(void *to, const Parameters &params) {
Sanity header = Sanity();
header.SetToReference();
std::memcpy(to, &header, sizeof(Sanity));
char *out = reinterpret_cast<char*>(to) + sizeof(Sanity);
*reinterpret_cast<FixedWidthParameters*>(out) = params.fixed;
out += sizeof(FixedWidthParameters);
uint64_t *counts = reinterpret_cast<uint64_t*>(out);
for (std::size_t i = 0; i < params.counts.size(); ++i) {
counts[i] = params.counts[i];
}
}
} // namespace
bool IsBinaryFormat(int fd) {
const uint64_t size = util::SizeFile(fd);
if (size == util::kBadSize || (size <= static_cast<uint64_t>(sizeof(Sanity)))) return false;
// Try reading the header.
util::scoped_memory memory;
try {
util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory);
} catch (const util::Exception &e) {
return false;
}
Sanity reference_header = Sanity();
reference_header.SetToReference();
if (!std::memcmp(memory.get(), &reference_header, sizeof(Sanity))) return true;
if (!std::memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) {
UTIL_THROW(FormatLoadException, "This binary file did not finish building");
}
if (!std::memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) {
char *end_ptr;
const char *begin_version = static_cast<const char*>(memory.get()) + strlen(kMagicBeforeVersion);
long int version = std::strtol(begin_version, &end_ptr, 10);
if ((end_ptr != begin_version) && version != kMagicVersion) {
UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary");
}
OldSanity old_sanity = OldSanity();
old_sanity.SetToReference();
UTIL_THROW_IF(!std::memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable.");
UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture");
}
return false;
}
void ReadHeader(int fd, Parameters &out) {
util::SeekOrThrow(fd, sizeof(Sanity));
util::ReadOrThrow(fd, &out.fixed, sizeof(out.fixed));
if (out.fixed.probing_multiplier < 1.0)
UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0.");
out.counts.resize(static_cast<std::size_t>(out.fixed.order));
if (out.fixed.order) util::ReadOrThrow(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order);
}
void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters &params) {
if (params.fixed.model_type != model_type) {
if (static_cast<unsigned int>(params.fixed.model_type) >= (sizeof(kModelNames) / sizeof(const char *)))
UTIL_THROW(FormatLoadException, "The binary file claims to be model type " << static_cast<unsigned int>(params.fixed.model_type) << " but this is not implemented for in this inference code.");
UTIL_THROW(FormatLoadException, "The binary file was built for " << kModelNames[params.fixed.model_type] << " but the inference code is trying to load " << kModelNames[model_type]);
}
UTIL_THROW_IF(search_version != params.fixed.search_version, FormatLoadException, "The binary file has " << kModelNames[params.fixed.model_type] << " version " << params.fixed.search_version << " but this code expects " << kModelNames[params.fixed.model_type] << " version " << search_version);
}
const std::size_t kInvalidSize = static_cast<std::size_t>(-1);
BinaryFormat::BinaryFormat(const Config &config)
: write_method_(config.write_method), write_mmap_(config.write_mmap), load_method_(config.load_method),
header_size_(kInvalidSize), vocab_size_(kInvalidSize), vocab_string_offset_(kInvalidOffset) {}
void BinaryFormat::InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters &params) {
file_.reset(fd);
write_mmap_ = NULL; // Ignore write requests; this is already in binary format.
ReadHeader(fd, params);
MatchCheck(model_type, search_version, params);
header_size_ = TotalHeaderSize(params.counts.size());
}
void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const {
assert(header_size_ != kInvalidSize);
util::ErsatzPRead(file_.get(), to, amount, offset_excluding_header + header_size_);
}
void *BinaryFormat::LoadBinary(std::size_t size) {
assert(header_size_ != kInvalidSize);
const uint64_t file_size = util::SizeFile(file_.get());
// The header is smaller than a page, so we have to map the whole header as well.
uint64_t total_map = static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(size);
UTIL_THROW_IF(file_size != util::kBadSize && file_size < total_map, FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map);
util::MapRead(load_method_, file_.get(), 0, util::CheckOverflow(total_map), mapping_);
vocab_string_offset_ = total_map;
return reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_;
}
void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) {
vocab_size_ = memory_size;
if (!write_mmap_) {
header_size_ = 0;
util::HugeMalloc(memory_size, true, memory_vocab_);
return reinterpret_cast<uint8_t*>(memory_vocab_.get());
}
header_size_ = TotalHeaderSize(order);
std::size_t total = util::CheckOverflow(static_cast<uint64_t>(header_size_) + static_cast<uint64_t>(memory_size));
file_.reset(util::CreateOrThrow(write_mmap_));
// some gccs complain about uninitialized variables even though all enum values are covered.
void *vocab_base = NULL;
switch (write_method_) {
case Config::WRITE_MMAP:
mapping_.reset(util::MapZeroedWrite(file_.get(), total), total, util::scoped_memory::MMAP_ALLOCATED);
vocab_base = mapping_.get();
util::AdviseHugePages(vocab_base, total);
break;
case Config::WRITE_AFTER:
util::ResizeOrThrow(file_.get(), 0);
util::HugeMalloc(total, true, memory_vocab_);
vocab_base = memory_vocab_.get();
break;
}
strncpy(reinterpret_cast<char*>(vocab_base), kMagicIncomplete, header_size_);
return reinterpret_cast<uint8_t*>(vocab_base) + header_size_;
}
void *BinaryFormat::GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base) {
assert(vocab_size_ != kInvalidSize);
vocab_pad_ = vocab_pad;
std::size_t new_size = header_size_ + vocab_size_ + vocab_pad_ + memory_size;
vocab_string_offset_ = new_size;
if (!write_mmap_ || write_method_ == Config::WRITE_AFTER) {
util::HugeMalloc(memory_size, true, memory_search_);
assert(header_size_ == 0 || write_mmap_);
vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_;
util::AdviseHugePages(memory_search_.get(), memory_size);
return reinterpret_cast<uint8_t*>(memory_search_.get());
}
assert(write_method_ == Config::WRITE_MMAP);
// Also known as total size without vocab words.
// Grow the file to accommodate the search, using zeros.
// According to man mmap, behavior is undefined when the file is resized
// underneath a mmap that is not a multiple of the page size. So to be
// safe, we'll unmap it and map it again.
mapping_.reset();
util::ResizeOrThrow(file_.get(), new_size);
void *ret;
MapFile(vocab_base, ret);
util::AdviseHugePages(ret, new_size);
return ret;
}
void BinaryFormat::WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base) {
// Checking Config's include_vocab is the responsibility of the caller.
assert(header_size_ != kInvalidSize && vocab_size_ != kInvalidSize);
if (!write_mmap_) {
// Unchanged base.
vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get());
search_base = reinterpret_cast<uint8_t*>(memory_search_.get());
return;
}
if (write_method_ == Config::WRITE_MMAP) {
mapping_.reset();
}
util::SeekOrThrow(file_.get(), VocabStringReadingOffset());
util::WriteOrThrow(file_.get(), &buffer[0], buffer.size());
if (write_method_ == Config::WRITE_MMAP) {
MapFile(vocab_base, search_base);
} else {
vocab_base = reinterpret_cast<uint8_t*>(memory_vocab_.get()) + header_size_;
search_base = reinterpret_cast<uint8_t*>(memory_search_.get());
}
}
void BinaryFormat::FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts) {
if (!write_mmap_) return;
switch (write_method_) {
case Config::WRITE_MMAP:
util::SyncOrThrow(mapping_.get(), mapping_.size());
break;
case Config::WRITE_AFTER:
util::SeekOrThrow(file_.get(), 0);
util::WriteOrThrow(file_.get(), memory_vocab_.get(), memory_vocab_.size());
util::SeekOrThrow(file_.get(), header_size_ + vocab_size_ + vocab_pad_);
util::WriteOrThrow(file_.get(), memory_search_.get(), memory_search_.size());
util::FSyncOrThrow(file_.get());
break;
}
// header and vocab share the same mmap.
Parameters params = Parameters();
memset(&params, 0, sizeof(Parameters));
params.counts = counts;
params.fixed.order = counts.size();
params.fixed.probing_multiplier = config.probing_multiplier;
params.fixed.model_type = model_type;
params.fixed.has_vocabulary = config.include_vocab;
params.fixed.search_version = search_version;
switch (write_method_) {
case Config::WRITE_MMAP:
WriteHeader(mapping_.get(), params);
util::SyncOrThrow(mapping_.get(), mapping_.size());
break;
case Config::WRITE_AFTER:
{
std::vector<uint8_t> buffer(TotalHeaderSize(counts.size()));
WriteHeader(&buffer[0], params);
util::SeekOrThrow(file_.get(), 0);
util::WriteOrThrow(file_.get(), &buffer[0], buffer.size());
}
break;
}
}
void BinaryFormat::MapFile(void *&vocab_base, void *&search_base) {
mapping_.reset(util::MapOrThrow(vocab_string_offset_, true, util::kFileFlags, false, file_.get()), vocab_string_offset_, util::scoped_memory::MMAP_ALLOCATED);
vocab_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_;
search_base = reinterpret_cast<uint8_t*>(mapping_.get()) + header_size_ + vocab_size_ + vocab_pad_;
}
bool RecognizeBinary(const char *file, ModelType &recognized) {
util::scoped_fd fd(util::OpenReadOrThrow(file));
if (!IsBinaryFormat(fd.get())) {
return false;
}
Parameters params;
ReadHeader(fd.get(), params);
recognized = params.fixed.model_type;
return true;
}
} // namespace ngram
} // namespace lm
#ifndef LM_BINARY_FORMAT_H
#define LM_BINARY_FORMAT_H
#include "config.hh"
#include "model_type.hh"
#include "read_arpa.hh"
#include "../util/file_piece.hh"
#include "../util/mmap.hh"
#include "../util/scoped.hh"
#include <cstddef>
#include <vector>
#include <stdint.h>
namespace lm {
namespace ngram {
extern const char *kModelNames[6];
/* Inspect a file to determine if it is a binary lm. If not, return false.
* If so, return true and set recognized to the type. This is the only API in
* this header designed for use by decoder authors.
*/
bool RecognizeBinary(const char *file, ModelType &recognized);
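// Usage sketch for decoder authors (error handling omitted):
//   ModelType type;
//   if (RecognizeBinary("file.binary", type)) {
//     // Binary: dispatch on type.
//   } else {
//     // Not binary: treat the file as ARPA text.
//   }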
struct FixedWidthParameters {
unsigned char order;
float probing_multiplier;
// What type of model is this?
ModelType model_type;
// Does the end of the file have the actual strings in the vocabulary?
bool has_vocabulary;
unsigned int search_version;
};
// This is a macro instead of an inline function so constants can be assigned using it.
#define ALIGN8(a) ((std::ptrdiff_t(((a)-1)/8)+1)*8)
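// For example, ALIGN8(1) == 8, ALIGN8(8) == 8, and ALIGN8(9) == 16.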
// Parameters stored in the header of a binary file.
struct Parameters {
FixedWidthParameters fixed;
std::vector<uint64_t> counts;
};
class BinaryFormat {
public:
explicit BinaryFormat(const Config &config);
// Reading a binary file:
// Takes ownership of fd
void InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters &params);
// Used to read parts of the file to update the config object before figuring out full size.
void ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const;
// Actually load the binary file and return a pointer to the beginning of the search area.
void *LoadBinary(std::size_t size);
uint64_t VocabStringReadingOffset() const {
assert(vocab_string_offset_ != kInvalidOffset);
return vocab_string_offset_;
}
// Writing a binary file or initializing in RAM from ARPA:
// Size for vocabulary.
void *SetupJustVocab(std::size_t memory_size, uint8_t order);
// Warning: can change the vocabulary base pointer.
void *GrowForSearch(std::size_t memory_size, std::size_t vocab_pad, void *&vocab_base);
// Warning: can change vocabulary and search base addresses.
void WriteVocabWords(const std::string &buffer, void *&vocab_base, void *&search_base);
// Write the header at the beginning of the file.
void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector<uint64_t> &counts);
private:
void MapFile(void *&vocab_base, void *&search_base);
// Copied from configuration.
const Config::WriteMethod write_method_;
const char *write_mmap_;
util::LoadMethod load_method_;
// File behind memory, if any.
util::scoped_fd file_;
// If there is a file involved, a single mapping.
util::scoped_memory mapping_;
// If the data is only in memory, separately allocate each because the trie
// knows vocab's size before it knows search's size (because SRILM might
// have pruned).
util::scoped_memory memory_vocab_, memory_search_;
// Memory ranges. Note that these may not be contiguous and may not all
// exist.
std::size_t header_size_, vocab_size_, vocab_pad_;
// aka end of search.
uint64_t vocab_string_offset_;
static const uint64_t kInvalidOffset = (uint64_t)-1;
};
bool IsBinaryFormat(int fd);
} // namespace ngram
} // namespace lm
#endif // LM_BINARY_FORMAT_H
#ifndef LM_BLANK_H
#define LM_BLANK_H
#include <limits>
#include <stdint.h>
#include <cmath>
namespace lm {
namespace ngram {
/* Suppose "foo bar" appears with zero backoff but there is no trigram
* beginning with these words. Then, when scoring "foo bar", the model could
* return out_state containing "bar" or even null context if "bar" also has no
* backoff and is never followed by another word. Then the backoff is set to
* kNoExtensionBackoff. If the n-gram might be extended, then out_state must
* contain the full n-gram, in which case kExtensionBackoff is set. In any
* case, if an n-gram has non-zero backoff, the full state is returned so
* backoff can be properly charged.
* These differ only in sign bit because the backoff is in fact zero in either
* case.
*/
const float kNoExtensionBackoff = -0.0;
const float kExtensionBackoff = 0.0;
const uint64_t kNoExtensionQuant = 0;
const uint64_t kExtensionQuant = 1;
inline void SetExtension(float &backoff) {
if (backoff == kNoExtensionBackoff) backoff = kExtensionBackoff;
}
// This compiles down nicely.
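// The resulting semantics, spelled out:
//   HasExtension(-0.0f) == false  (kNoExtensionBackoff)
//   HasExtension( 0.0f) == true   (kExtensionBackoff)
//   HasExtension(-0.5f) == true   (any genuine backoff)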
inline bool HasExtension(const float &backoff) {
typedef union { float f; uint32_t i; } UnionValue;
UnionValue compare, interpret;
compare.f = kNoExtensionBackoff;
interpret.f = backoff;
return compare.i != interpret.i;
}
} // namespace ngram
} // namespace lm
#endif // LM_BLANK_H
#include "model.hh"
#include "sizes.hh"
#include "../util/file_piece.hh"
#include "../util/usage.hh"
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <iostream>
#include <iomanip>
#include <limits>
#include <cmath>
#ifdef WIN32
#include "../util/getopt.hh"
#else
#include <unistd.h>
#endif
namespace lm {
namespace ngram {
namespace {
void Usage(const char *name, const char *default_mem) {
std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-v] [-w mmap|after] [-p probing_multiplier] [-T trie_temporary] [-S trie_building_mem] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
"-u sets the log10 probability for <unk> if the ARPA file does not have one.\n"
" Default is -100. The ARPA file will always take precedence.\n"
"-s allows models to be built even if they do not have <s> and </s>.\n"
"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n"
"-v disables inclusion of the vocabulary in the binary file.\n"
"-w mmap|after determines how writing is done.\n"
" mmap maps the binary file and writes to it. Default for trie.\n"
" after allocates anonymous memory, builds, and writes. Default for probing.\n"
"-r \"order1.arpa order2 order3 order4\" adds lower-order rest costs from these\n"
" model files. order1.arpa must be an ARPA file. All others may be ARPA or\n"
" the same data structure as being built. All files must have the same\n"
" vocabulary. For probing, the unigrams must be in the same order.\n\n"
"type is either probing or trie. Default is probing.\n\n"
"probing uses a probing hash table. It is the fastest but uses the most memory.\n"
"-p sets the space multiplier and must be >1.0. The default is 1.5.\n\n"
"trie is a straightforward trie with bit-level packing. It uses the least\n"
"memory and is still faster than SRI or IRST. Building the trie format uses an\n"
"on-disk sort to save memory.\n"
"-T is the temporary directory prefix. Default is the output file name.\n"
"-S determines memory use for sorting. Default is " << default_mem << ". This is compatible\n"
" with GNU sort. The number is followed by a unit: \% for percent of physical\n"
" memory, b for bytes, K for Kilobytes, M for megabytes, then G,T,P,E,Z,Y. \n"
" Default unit is K for Kilobytes.\n"
"-q turns quantization on and sets the number of bits (e.g. -q 8).\n"
"-b sets backoff quantization bits. Requires -q and defaults to that value.\n"
"-a compresses pointers using an array of offsets. The parameter is the\n"
" maximum number of bits encoded by the array. Memory is minimized subject\n"
" to the maximum, so pick 255 to minimize memory.\n\n"
"-h print this help message.\n\n"
"Get a memory estimate by passing an ARPA file without an output file name.\n";
exit(1);
}
// I could really use boost::lexical_cast right about now.
float ParseFloat(const char *from) {
char *end;
float ret = strtod(from, &end);
if (*end) throw util::ParseNumberException(from);
return ret;
}
unsigned long int ParseUInt(const char *from) {
char *end;
unsigned long int ret = strtoul(from, &end, 10);
if (*end) throw util::ParseNumberException(from);
return ret;
}
uint8_t ParseBitCount(const char *from) {
unsigned long val = ParseUInt(from);
if (val > 25) {
util::ParseNumberException e(from);
e << " bit counts are limited to 25.";
throw e;
}
return val;
}
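// Splits a space-delimited list: e.g. "order1.arpa order2 order3" yields
// {"order1.arpa", "order2", "order3"}.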
void ParseFileList(const char *from, std::vector<std::string> &to) {
to.clear();
while (true) {
const char *i;
for (i = from; *i && *i != ' '; ++i) {}
to.push_back(std::string(from, i - from));
if (!*i) break;
from = i + 1;
}
}
void ProbingQuantizationUnsupported() {
std::cerr << "Quantization is only implemented in the trie data structure." << std::endl;
exit(1);
}
} // namespace
} // namespace ngram
} // namespace lm
int main(int argc, char *argv[]) {
using namespace lm::ngram;
const char *default_mem = util::GuessPhysicalMemory() ? "80%" : "1G";
if (argc == 2 && !strcmp(argv[1], "--help"))
Usage(argv[0], default_mem);
try {
bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false;
lm::ngram::Config config;
config.building_memory = util::ParseSize(default_mem);
int opt;
while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:vh")) != -1) {
switch(opt) {
case 'q':
config.prob_bits = ParseBitCount(optarg);
if (!set_backoff_bits) config.backoff_bits = config.prob_bits;
quantize = true;
break;
case 'b':
config.backoff_bits = ParseBitCount(optarg);
set_backoff_bits = true;
break;
case 'a':
config.pointer_bhiksha_bits = ParseBitCount(optarg);
bhiksha = true;
break;
case 'u':
config.unknown_missing_logprob = ParseFloat(optarg);
break;
case 'p':
config.probing_multiplier = ParseFloat(optarg);
break;
case 't': // legacy
case 'T':
config.temporary_directory_prefix = optarg;
util::NormalizeTempPrefix(config.temporary_directory_prefix);
break;
case 'm': // legacy
config.building_memory = ParseUInt(optarg) * 1048576;
break;
case 'S':
config.building_memory = std::min(static_cast<uint64_t>(std::numeric_limits<std::size_t>::max()), util::ParseSize(optarg));
break;
case 'w':
set_write_method = true;
if (!strcmp(optarg, "mmap")) {
config.write_method = Config::WRITE_MMAP;
} else if (!strcmp(optarg, "after")) {
config.write_method = Config::WRITE_AFTER;
} else {
Usage(argv[0], default_mem);
}
break;
case 's':
config.sentence_marker_missing = lm::SILENT;
break;
case 'i':
config.positive_log_probability = lm::SILENT;
break;
case 'r':
rest = true;
ParseFileList(optarg, config.rest_lower_files);
config.rest_function = Config::REST_LOWER;
break;
case 'v':
config.include_vocab = false;
break;
case 'h': // help
default:
Usage(argv[0], default_mem);
}
}
if (!quantize && set_backoff_bits) {
std::cerr << "You specified backoff quantization (-b) but not probability quantization (-q)" << std::endl;
abort();
}
if (optind + 1 == argc) {
ShowSizes(argv[optind], config);
return 0;
}
const char *model_type;
const char *from_file;
if (optind + 2 == argc) {
model_type = "probing";
from_file = argv[optind];
config.write_mmap = argv[optind + 1];
} else if (optind + 3 == argc) {
model_type = argv[optind];
from_file = argv[optind + 1];
config.write_mmap = argv[optind + 2];
} else {
Usage(argv[0], default_mem);
return 1;
}
if (!strcmp(model_type, "probing")) {
if (!set_write_method) config.write_method = Config::WRITE_AFTER;
if (quantize || set_backoff_bits) ProbingQuantizationUnsupported();
if (rest) {
RestProbingModel(from_file, config);
} else {
ProbingModel(from_file, config);
}
} else if (!strcmp(model_type, "trie")) {
if (rest) {
std::cerr << "Rest + trie is not supported yet." << std::endl;
return 1;
}
if (!set_write_method) config.write_method = Config::WRITE_MMAP;
if (quantize) {
if (bhiksha) {
QuantArrayTrieModel(from_file, config);
} else {
QuantTrieModel(from_file, config);
}
} else {
if (bhiksha) {
ArrayTrieModel(from_file, config);
} else {
TrieModel(from_file, config);
}
}
} else {
Usage(argv[0], default_mem);
}
}
catch (const std::exception &e) {
std::cerr << e.what() << std::endl;
std::cerr << "ERROR" << std::endl;
return 1;
}
std::cerr << "SUCCESS" << std::endl;
return 0;
}
# This CMake file was created by Lane Schwartz <dowobeha@gmail.com>
# Explicitly list the source files for this subdirectory
#
# If you add any source files to this subdirectory
# that should be included in the kenlm library,
# (this excludes any unit test files)
# you should add them to the following list:
#
# In order to set correct paths to these files
# in case this variable is referenced by CMake files in the parent directory,
# we prefix all files with ${CMAKE_CURRENT_SOURCE_DIR}.
#
set(KENLM_BUILDER_SOURCE
${CMAKE_CURRENT_SOURCE_DIR}/adjust_counts.cc
${CMAKE_CURRENT_SOURCE_DIR}/corpus_count.cc
${CMAKE_CURRENT_SOURCE_DIR}/initial_probabilities.cc
${CMAKE_CURRENT_SOURCE_DIR}/interpolate.cc
${CMAKE_CURRENT_SOURCE_DIR}/output.cc
${CMAKE_CURRENT_SOURCE_DIR}/pipeline.cc
)
# Group these objects together for later use.
#
# Given add_library(foo OBJECT ${my_foo_sources}),
# refer to these objects as $<TARGET_OBJECTS:foo>
#
add_library(kenlm_builder ${KENLM_BUILDER_SOURCE})
target_link_libraries(kenlm_builder PUBLIC kenlm kenlm_util Threads::Threads)
# Since headers are relative to `include/kenlm` at install time, not just `include`
target_include_directories(kenlm_builder PUBLIC $<INSTALL_INTERFACE:include/kenlm>)
AddExes(EXES lmplz
LIBRARIES kenlm_builder kenlm kenlm_util Threads::Threads)
AddExes(EXES count_ngrams
LIBRARIES kenlm_builder kenlm kenlm_util Threads::Threads)
install(
TARGETS kenlm_builder
EXPORT kenlmTargets
RUNTIME DESTINATION bin
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib
INCLUDES DESTINATION include
)
if(BUILD_TESTING)
# Explicitly list the Boost test files to be compiled
set(KENLM_BOOST_TESTS_LIST
adjust_counts_test
corpus_count_test
)
AddTests(TESTS ${KENLM_BOOST_TESTS_LIST}
LIBRARIES kenlm_builder kenlm kenlm_util Threads::Threads)
endif()
Dependencies
============
Boost >= 1.42.0 is required.
For Ubuntu,
```bash
sudo apt-get install libboost1.48-all-dev
```
Alternatively, you can download, compile, and install it yourself:
```bash
wget http://sourceforge.net/projects/boost/files/boost/1.52.0/boost_1_52_0.tar.gz/download -O boost_1_52_0.tar.gz
tar -xvzf boost_1_52_0.tar.gz
cd boost_1_52_0
./bootstrap.sh
./b2
sudo ./b2 install
```
Local install options (in a user-space prefix directory) are also possible. See http://www.boost.org/doc/libs/1_52_0/doc/html/bbv2/installation.html.
Building
========
```bash
bjam
```
Your distribution might package bjam and boost-build separately from Boost. Both are required.
Usage
=====
Run
```bash
$ bin/lmplz
```
to see the command line arguments.
Running
=======
```bash
bin/lmplz -o 5 <text >text.arpa
```
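This reads tokenized text on stdin and writes a 5-gram ARPA model to `text.arpa`; `-o` sets the n-gram order.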
TODO
====
- More tests!
- Sharding.
- Some way to manage all the crazy config options.
- Option to build the binary file directly.
- Interpolation of different orders.
#include "adjust_counts.hh"
#include "../common/ngram_stream.hh"
#include "payload.hh"
#include <algorithm>
#include <cstring>
#include <iostream>
#include <limits>
namespace lm { namespace builder {
BadDiscountException::BadDiscountException() throw() {}
BadDiscountException::~BadDiscountException() throw() {}
namespace {
// Return last word in full that is different.
const WordIndex* FindDifference(const NGram<BuildingPayload> &full, const NGram<BuildingPayload> &lower_last) {
const WordIndex *cur_word = full.end() - 1;
const WordIndex *pre_word = lower_last.end() - 1;
// Find last difference.
for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {}
return cur_word;
}
class StatCollector {
public:
StatCollector(std::size_t order, std::vector<uint64_t> &counts, std::vector<uint64_t> &counts_pruned, std::vector<Discount> &discounts)
: orders_(order), full_(orders_.back()), counts_(counts), counts_pruned_(counts_pruned), discounts_(discounts) {
memset(&orders_[0], 0, sizeof(OrderStat) * order);
}
~StatCollector() {}
void CalculateDiscounts(const DiscountConfig &config) {
counts_.resize(orders_.size());
counts_pruned_.resize(orders_.size());
for (std::size_t i = 0; i < orders_.size(); ++i) {
const OrderStat &s = orders_[i];
counts_[i] = s.count;
counts_pruned_[i] = s.count_pruned;
}
discounts_ = config.overwrite;
discounts_.resize(orders_.size());
for (std::size_t i = config.overwrite.size(); i < orders_.size(); ++i) {
const OrderStat &s = orders_[i];
try {
for (unsigned j = 1; j < 4; ++j) {
// TODO: Specialize error message for j == 3, meaning 3+
UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for "
<< (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any "
<< (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?\n"
<< "Try deduplicating the input. To override this error for e.g. a class-based model, rerun with --discount_fallback\n");
}
// See equation (26) in Chen and Goodman.
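// Concretely, with Y = n_1 / (n_1 + 2 * n_2), where n_j counts n-grams of
// this order with adjusted count j, the loop below computes
//   D_j = j - (j + 1) * Y * n_{j+1} / n_j   for j = 1, 2, 3.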
discounts_[i].amount[0] = 0.0;
float y = static_cast<float>(s.n[1]) / static_cast<float>(s.n[1] + 2.0 * s.n[2]);
for (unsigned j = 1; j < 4; ++j) {
discounts_[i].amount[j] = static_cast<float>(j) - static_cast<float>(j + 1) * y * static_cast<float>(s.n[j+1]) / static_cast<float>(s.n[j]);
UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j] << ". This means modified Kneser-Ney smoothing thinks something is weird about your data. To override this error for e.g. a class-based model, rerun with --discount_fallback\n");
}
} catch (const BadDiscountException &) {
switch (config.bad_action) {
case THROW_UP:
throw;
case COMPLAIN:
std::cerr << "Substituting fallback discounts for order " << i << ": D1=" << config.fallback.amount[1] << " D2=" << config.fallback.amount[2] << " D3+=" << config.fallback.amount[3] << std::endl;
case SILENT:
break;
}
discounts_[i] = config.fallback;
}
}
}
void Add(std::size_t order_minus_1, uint64_t count, bool pruned = false) {
OrderStat &stat = orders_[order_minus_1];
++stat.count;
if (!pruned)
++stat.count_pruned;
if (count < 5) ++stat.n[count];
}
void AddFull(uint64_t count, bool pruned = false) {
++full_.count;
if (!pruned)
++full_.count_pruned;
if (count < 5) ++full_.n[count];
}
private:
struct OrderStat {
// n_1 in equation 26 of Chen and Goodman etc
uint64_t n[5];
uint64_t count;
uint64_t count_pruned;
};
std::vector<OrderStat> orders_;
OrderStat &full_;
std::vector<uint64_t> &counts_;
std::vector<uint64_t> &counts_pruned_;
std::vector<Discount> &discounts_;
};
// Reads all entries in order like NGramStream does.
// But deletes any entries that have <s> in the 1st (not 0th) position on the
// way out by putting other entries in their place. This disrupts the sort
// order but we don't care because the data is going to be sorted again.
class CollapseStream {
public:
CollapseStream(const util::stream::ChainPosition &position, uint64_t prune_threshold, const std::vector<bool>& prune_words) :
current_(NULL, NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize())),
prune_threshold_(prune_threshold),
prune_words_(prune_words),
block_(position) {
StartBlock();
}
const NGram<BuildingPayload> &operator*() const { return current_; }
const NGram<BuildingPayload> *operator->() const { return &current_; }
operator bool() const { return block_; }
CollapseStream &operator++() {
assert(block_);
if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) {
memcpy(current_.Base(), copy_from_, current_.TotalSize());
UpdateCopyFrom();
// Mark highest order n-grams for later pruning
if(current_.Value().count <= prune_threshold_) {
current_.Value().Mark();
}
if(!prune_words_.empty()) {
for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
if(prune_words_[*i]) {
current_.Value().Mark();
break;
}
}
}
}
current_.NextInMemory();
uint8_t *block_base = static_cast<uint8_t*>(block_->Get());
if (current_.Base() == block_base + block_->ValidSize()) {
block_->SetValidSize(copy_from_ + current_.TotalSize() - block_base);
++block_;
StartBlock();
}
// Mark highest order n-grams for later pruning
if(current_.Value().count <= prune_threshold_) {
current_.Value().Mark();
}
if(!prune_words_.empty()) {
for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
if(prune_words_[*i]) {
current_.Value().Mark();
break;
}
}
}
return *this;
}
private:
void StartBlock() {
for (; ; ++block_) {
if (!block_) return;
if (block_->ValidSize()) break;
}
current_.ReBase(block_->Get());
copy_from_ = static_cast<uint8_t*>(block_->Get()) + block_->ValidSize();
UpdateCopyFrom();
// Mark highest order n-grams for later pruning
if(current_.Value().count <= prune_threshold_) {
current_.Value().Mark();
}
if(!prune_words_.empty()) {
for(WordIndex* i = current_.begin(); i != current_.end(); i++) {
if(prune_words_[*i]) {
current_.Value().Mark();
break;
}
}
}
}
// Find last without bos.
void UpdateCopyFrom() {
for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) {
if (NGram<BuildingPayload>(copy_from_, current_.Order()).begin()[1] != kBOS) break;
}
}
NGram<BuildingPayload> current_;
// Goes backwards in the block
uint8_t *copy_from_;
uint64_t prune_threshold_;
const std::vector<bool>& prune_words_;
util::stream::Link block_;
};
} // namespace
void AdjustCounts::Run(const util::stream::ChainPositions &positions) {
const std::size_t order = positions.size();
StatCollector stats(order, counts_, counts_pruned_, discounts_);
if (order == 1) {
// Only unigrams. Just collect stats.
for (NGramStream<BuildingPayload> full(positions[0]); full; ++full) {
// Do not prune <s> </s> <unk>
if(*full->begin() > 2) {
if(full->Value().count <= prune_thresholds_[0])
full->Value().Mark();
if(!prune_words_.empty() && prune_words_[*full->begin()])
full->Value().Mark();
}
stats.AddFull(full->Value().UnmarkedCount(), full->Value().IsMarked());
}
stats.CalculateDiscounts(discount_config_);
return;
}
NGramStreams<BuildingPayload> streams;
streams.Init(positions, positions.size() - 1);
CollapseStream full(positions[positions.size() - 1], prune_thresholds_.back(), prune_words_);
// Initialization: <unk> has count 0 and so does <s>.
NGramStream<BuildingPayload> *lower_valid = streams.begin();
const NGramStream<BuildingPayload> *const streams_begin = streams.begin();
streams[0]->Value().count = 0;
*streams[0]->begin() = kUNK;
stats.Add(0, 0);
(++streams[0])->Value().count = 0;
*streams[0]->begin() = kBOS;
// <s> is not in stats yet because it will get put in later.
// This keeps track of actual counts for lower orders. It is not output
// (only adjusted counts are), but used to determine pruning.
std::vector<uint64_t> actual_counts(positions.size(), 0);
// Something of a hack: don't prune <s>.
actual_counts[0] = std::numeric_limits<uint64_t>::max();
// Iterate over full (the stream of the highest order ngrams)
for (; full; ++full) {
const WordIndex *different = FindDifference(*full, **lower_valid);
std::size_t same = full->end() - 1 - different;
// STEP 1: Output all the n-grams that changed.
for (; lower_valid >= streams.begin() + same; --lower_valid) {
uint64_t order_minus_1 = lower_valid - streams_begin;
if(actual_counts[order_minus_1] <= prune_thresholds_[order_minus_1])
(*lower_valid)->Value().Mark();
if(!prune_words_.empty()) {
for(WordIndex* i = (*lower_valid)->begin(); i != (*lower_valid)->end(); i++) {
if(prune_words_[*i]) {
(*lower_valid)->Value().Mark();
break;
}
}
}
stats.Add(order_minus_1, (*lower_valid)->Value().UnmarkedCount(), (*lower_valid)->Value().IsMarked());
++*lower_valid;
}
// STEP 2: Update n-grams that still match.
// n-grams that match get count from the full entry.
for (std::size_t i = 0; i < same; ++i) {
actual_counts[i] += full->Value().UnmarkedCount();
}
// Increment the number of unique extensions for the longest match.
if (same) ++streams[same - 1]->Value().count;
// STEP 3: Initialize new n-grams.
// This is here because bos is also const WordIndex *, so copy gets
// consistent argument types.
const WordIndex *full_end = full->end();
// Initialize and mark as valid up to bos.
const WordIndex *bos;
for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) {
NGramStream<BuildingPayload> &to = *++lower_valid;
std::copy(bos, full_end, to->begin());
to->Value().count = 1;
actual_counts[lower_valid - streams_begin] = full->Value().UnmarkedCount();
}
// Now bos indicates where <s> is or is the 0th word of full.
if (bos != full->begin()) {
// There is an <s> beyond the 0th word.
NGramStream<BuildingPayload> &to = *++lower_valid;
std::copy(bos, full_end, to->begin());
// Anything that begins with <s> has full non adjusted count.
to->Value().count = full->Value().UnmarkedCount();
actual_counts[lower_valid - streams_begin] = full->Value().UnmarkedCount();
} else {
stats.AddFull(full->Value().UnmarkedCount(), full->Value().IsMarked());
}
assert(lower_valid >= &streams[0]);
}
// The above loop outputs n-grams when it observes changes. This outputs
// the last n-grams.
for (NGramStream<BuildingPayload> *s = streams.begin(); s <= lower_valid; ++s) {
uint64_t lower_count = actual_counts[(*s)->Order() - 1];
if(lower_count <= prune_thresholds_[(*s)->Order() - 1])
(*s)->Value().Mark();
if(!prune_words_.empty()) {
for(WordIndex* i = (*s)->begin(); i != (*s)->end(); i++) {
if(prune_words_[*i]) {
(*s)->Value().Mark();
break;
}
}
}
stats.Add(s - streams.begin(), lower_count, (*s)->Value().IsMarked());
++*s;
}
// Poison everyone! Except the N-grams which were already poisoned by the input.
for (NGramStream<BuildingPayload> *s = streams.begin(); s != streams.end(); ++s)
s->Poison();
stats.CalculateDiscounts(discount_config_);
// NOTE: See special early-return case for unigrams near the top of this function
}
}} // namespaces
#ifndef LM_BUILDER_ADJUST_COUNTS_H
#define LM_BUILDER_ADJUST_COUNTS_H
#include "discount.hh"
#include "../lm_exception.hh"
#include "../../util/exception.hh"
#include <vector>
#include <stdint.h>
namespace util { namespace stream { class ChainPositions; } }
namespace lm {
namespace builder {
class BadDiscountException : public util::Exception {
public:
BadDiscountException() throw();
~BadDiscountException() throw();
};
struct DiscountConfig {
// Overrides discounts for orders [1,discount_override.size()].
std::vector<Discount> overwrite;
// If discounting fails for an order, copy them from here.
Discount fallback;
// What to do when discounts are out of range or would trigger division by
// zero. If it does something other than THROW_UP, the fallback discount is used.
WarningAction bad_action;
};
/* Compute adjusted counts.
* Input: unique suffix sorted N-grams (and just the N-grams) with raw counts.
* Output: [1,N]-grams with adjusted counts.
* [1,N)-grams are in suffix order
* N-grams are in undefined order (they're going to be sorted anyway).
*/
class AdjustCounts {
public:
// counts: output
// counts_pruned: output
// discounts: mostly output. If the input already has entries, they will be kept.
// prune_thresholds: input. n-grams with normal (not adjusted) count below this will be pruned.
AdjustCounts(
const std::vector<uint64_t> &prune_thresholds,
std::vector<uint64_t> &counts,
std::vector<uint64_t> &counts_pruned,
const std::vector<bool> &prune_words,
const DiscountConfig &discount_config,
std::vector<Discount> &discounts)
: prune_thresholds_(prune_thresholds), counts_(counts), counts_pruned_(counts_pruned),
prune_words_(prune_words), discount_config_(discount_config), discounts_(discounts)
{}
void Run(const util::stream::ChainPositions &positions);
private:
const std::vector<uint64_t> &prune_thresholds_;
std::vector<uint64_t> &counts_;
std::vector<uint64_t> &counts_pruned_;
const std::vector<bool> &prune_words_;
DiscountConfig discount_config_;
std::vector<Discount> &discounts_;
};
} // namespace builder
} // namespace lm
#endif // LM_BUILDER_ADJUST_COUNTS_H
#include "adjust_counts.hh"
#include "../common/ngram_stream.hh"
#include "payload.hh"
#include "../../util/scoped.hh"
#include <boost/thread/thread.hpp>
#define BOOST_TEST_MODULE AdjustCounts
#include <boost/test/unit_test.hpp>
namespace lm { namespace builder { namespace {
class KeepCopy {
public:
KeepCopy() : size_(0) {}
void Run(const util::stream::ChainPosition &position) {
for (util::stream::Link link(position); link; ++link) {
mem_.call_realloc(size_ + link->ValidSize());
memcpy(static_cast<uint8_t*>(mem_.get()) + size_, link->Get(), link->ValidSize());
size_ += link->ValidSize();
}
}
uint8_t *Get() { return static_cast<uint8_t*>(mem_.get()); }
std::size_t Size() const { return size_; }
private:
util::scoped_malloc mem_;
std::size_t size_;
};
struct Gram4 {
WordIndex ids[4];
uint64_t count;
};
class WriteInput {
public:
void Run(const util::stream::ChainPosition &position) {
NGramStream<BuildingPayload> input(position);
Gram4 grams[] = {
{{0,0,0,0},10},
{{0,0,3,0},3},
// begins with <s>: word id 1 is kBOS, 2 is kEOS
{{1,1,1,2},5},
{{0,0,3,2},5},
};
for (size_t i = 0; i < sizeof(grams) / sizeof(Gram4); ++i, ++input) {
memcpy(input->begin(), grams[i].ids, sizeof(WordIndex) * 4);
input->Value().count = grams[i].count;
}
input.Poison();
}
};
BOOST_AUTO_TEST_CASE(Simple) {
KeepCopy outputs[4];
std::vector<uint64_t> counts;
std::vector<Discount> discount;
{
util::stream::ChainConfig config;
config.total_memory = 100;
config.block_count = 1;
util::stream::Chains chains(4);
for (unsigned i = 0; i < 4; ++i) {
config.entry_size = NGram<BuildingPayload>::TotalSize(i + 1);
chains.push_back(config);
}
chains[3] >> WriteInput();
util::stream::ChainPositions for_adjust(chains);
for (unsigned i = 0; i < 4; ++i) {
chains[i] >> boost::ref(outputs[i]);
}
chains >> util::stream::kRecycle;
std::vector<uint64_t> counts_pruned(4);
std::vector<uint64_t> prune_thresholds(4);
DiscountConfig discount_config;
discount_config.fallback = Discount();
discount_config.bad_action = THROW_UP;
BOOST_CHECK_THROW(AdjustCounts(prune_thresholds, counts, counts_pruned, std::vector<bool>(), discount_config, discount).Run(for_adjust), BadDiscountException);
}
BOOST_REQUIRE_EQUAL(4UL, counts.size());
BOOST_CHECK_EQUAL(4UL, counts[0]);
// These are no longer set because the discounts are bad.
/* BOOST_CHECK_EQUAL(4UL, counts[1]);
BOOST_CHECK_EQUAL(3UL, counts[2]);
BOOST_CHECK_EQUAL(3UL, counts[3]);*/
BOOST_REQUIRE_EQUAL(NGram<BuildingPayload>::TotalSize(1) * 4, outputs[0].Size());
NGram<BuildingPayload> uni(outputs[0].Get(), 1);
BOOST_CHECK_EQUAL(kUNK, *uni.begin());
BOOST_CHECK_EQUAL(0ULL, uni.Value().count);
uni.NextInMemory();
BOOST_CHECK_EQUAL(kBOS, *uni.begin());
BOOST_CHECK_EQUAL(0ULL, uni.Value().count);
uni.NextInMemory();
BOOST_CHECK_EQUAL(0UL, *uni.begin());
BOOST_CHECK_EQUAL(2ULL, uni.Value().count);
uni.NextInMemory();
BOOST_CHECK_EQUAL(2ULL, uni.Value().count);
BOOST_CHECK_EQUAL(2UL, *uni.begin());
BOOST_REQUIRE_EQUAL(NGram<BuildingPayload>::TotalSize(2) * 4, outputs[1].Size());
NGram<BuildingPayload> bi(outputs[1].Get(), 2);
BOOST_CHECK_EQUAL(0UL, *bi.begin());
BOOST_CHECK_EQUAL(0UL, *(bi.begin() + 1));
BOOST_CHECK_EQUAL(1ULL, bi.Value().count);
bi.NextInMemory();
}
}}} // namespaces
#ifndef LM_BUILDER_COMBINE_COUNTS_H
#define LM_BUILDER_COMBINE_COUNTS_H
#include "payload.hh"
#include "../common/ngram.hh"
#include "../common/compare.hh"
#include "../word_index.hh"
#include "../../util/stream/sort.hh"
#include <functional>
#include <string>
namespace lm {
namespace builder {
// Sum counts for the same n-gram.
struct CombineCounts {
bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const {
NGram<BuildingPayload> first(first_void, compare.Order());
// There isn't a const version of NGram.
NGram<BuildingPayload> second(const_cast<void*>(second_void), compare.Order());
if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false;
first.Value().count += second.Value().count;
return true;
}
};
} // namespace builder
} // namespace lm
#endif // LM_BUILDER_COMBINE_COUNTS_H
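CombineCounts implements the combiner contract used by the stream sorter: when two adjacent entries compare equal during a merge, the combiner folds the second entry's count into the first and returns true so the duplicate is dropped; returning false keeps both. A hedged wiring sketch, following the pattern used elsewhere in the builder (chain and sort_config are assumed to be an input Chain and a util::stream::SortConfig):
util::stream::Sort<SuffixOrder, CombineCounts> sorter(
    chain, sort_config, SuffixOrder(order), CombineCounts());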
#include "corpus_count.hh"
#include "payload.hh"
#include "../common/ngram.hh"
#include "../lm_exception.hh"
#include "../vocab.hh"
#include "../word_index.hh"
#include "../../util/file_stream.hh"
#include "../../util/file.hh"
#include "../../util/file_piece.hh"
#include "../../util/murmur_hash.hh"
#include "../../util/probing_hash_table.hh"
#include "../../util/scoped.hh"
#include "../../util/stream/chain.hh"
#include "../../util/tokenize_piece.hh"
#include <functional>
#include <stdint.h>
namespace lm {
namespace builder {
namespace {
class DedupeHash : public std::unary_function<const WordIndex *, bool> {
public:
explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {}
std::size_t operator()(const WordIndex *start) const {
return util::MurmurHashNative(start, size_);
}
private:
const std::size_t size_;
};
class DedupeEquals : public std::binary_function<const WordIndex *, const WordIndex *, bool> {
public:
explicit DedupeEquals(std::size_t order) : size_(order * sizeof(WordIndex)) {}
bool operator()(const WordIndex *first, const WordIndex *second) const {
return !memcmp(first, second, size_);
}
private:
const std::size_t size_;
};
struct DedupeEntry {
typedef WordIndex *Key;
Key GetKey() const { return key; }
void SetKey(WordIndex *to) { key = to; }
Key key;
static DedupeEntry Construct(WordIndex *at) {
DedupeEntry ret;
ret.key = at;
return ret;
}
};
// TODO: this constant doesn't belong here; move it alongside the probing hash table defaults.
const float kProbingMultiplier = 1.5;
typedef util::ProbingHashTable<DedupeEntry, DedupeHash, DedupeEquals> Dedupe;
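// The table keys are raw pointers into the current output block; DedupeHash
// and DedupeEquals interpret each pointer as an order-length array of
// WordIndex, so lookup is by n-gram content while counts stay in the block.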
class Writer {
public:
Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size)
: block_(position), gram_(block_->Get(), order),
dedupe_invalid_(order, std::numeric_limits<WordIndex>::max()),
dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)),
buffer_(new WordIndex[order - 1]),
block_size_(position.GetChain().BlockSize()) {
dedupe_.Clear();
assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size);
if (order == 1) {
// Add the special words <unk> and <s>. AdjustCounts takes care of them when order != 1.
AddUnigramWord(kUNK);
AddUnigramWord(kBOS);
}
}
~Writer() {
block_->SetValidSize(reinterpret_cast<const uint8_t*>(gram_.begin()) - static_cast<const uint8_t*>(block_->Get()));
(++block_).Poison();
}
// Write context with a bunch of <s>
void StartSentence() {
for (WordIndex *i = gram_.begin(); i != gram_.end() - 1; ++i) {
*i = kBOS;
}
}
void Append(WordIndex word) {
*(gram_.end() - 1) = word;
Dedupe::MutableIterator at;
bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at);
if (found) {
// Already present.
NGram<BuildingPayload> already(at->key, gram_.Order());
++(already.Value().count);
// Shift left by one.
memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1));
return;
}
// Complete the write.
gram_.Value().count = 1;
// Prepare the next n-gram.
if (reinterpret_cast<uint8_t*>(gram_.begin()) + gram_.TotalSize() != static_cast<uint8_t*>(block_->Get()) + block_size_) {
NGram<BuildingPayload> last(gram_);
gram_.NextInMemory();
std::copy(last.begin() + 1, last.end(), gram_.begin());
return;
}
// Block end. Need to store the context in a temporary buffer.
std::copy(gram_.begin() + 1, gram_.end(), buffer_.get());
dedupe_.Clear();
block_->SetValidSize(block_size_);
gram_.ReBase((++block_)->Get());
std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin());
}
private:
void AddUnigramWord(WordIndex index) {
*gram_.begin() = index;
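// <unk> and <s> keep a count of 0: special words are filtered from the
// input, so nothing increments these entries.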
gram_.Value().count = 0;
gram_.NextInMemory();
if (gram_.Base() == static_cast<uint8_t*>(block_->Get()) + block_size_) {
block_->SetValidSize(block_size_);
gram_.ReBase((++block_)->Get());
}
}
util::stream::Link block_;
NGram<BuildingPayload> gram_;
// This is the memory behind the invalid value in dedupe_.
std::vector<WordIndex> dedupe_invalid_;
// Hash table combiner implementation.
Dedupe dedupe_;
// Small buffer to hold existing ngrams when shifting across a block boundary.
boost::scoped_array<WordIndex> buffer_;
const std::size_t block_size_;
};
} // namespace
float CorpusCount::DedupeMultiplier(std::size_t order) {
return kProbingMultiplier * static_cast<float>(sizeof(DedupeEntry)) / static_cast<float>(NGram<BuildingPayload>::TotalSize(order));
}
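// A worked example, assuming a 64-bit pointer, a 4-byte WordIndex, and an
// 8-byte BuildingPayload: for order 3, sizeof(DedupeEntry) == 8 and
// TotalSize(3) == 3*4 + 8 == 20, giving 1.5 * 8 / 20 == 0.6 bytes of dedupe
// table per byte of n-gram block.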
std::size_t CorpusCount::VocabUsage(std::size_t vocab_estimate) {
return ngram::GrowableVocab<ngram::WriteUniqueWords>::MemUsage(vocab_estimate);
}
CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, bool dynamic_vocab,
                         uint64_t &token_count, WordIndex &type_count,
                         std::vector<bool> &prune_words, const std::string &prune_vocab_filename,
                         std::size_t entries_per_block, WarningAction disallowed_symbol)
: from_(from), vocab_write_(vocab_write), dynamic_vocab_(dynamic_vocab), token_count_(token_count), type_count_(type_count),
prune_words_(prune_words), prune_vocab_filename_(prune_vocab_filename),
dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)),
dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)),
disallowed_symbol_action_(disallowed_symbol) {
}
namespace {
void ComplainDisallowed(StringPiece word, WarningAction &action) {
switch (action) {
case SILENT:
return;
case COMPLAIN:
std::cerr << "Warning: " << word << " appears in the input. All instances of <s>, </s>, and <unk> will be interpreted as whitespace." << std::endl;
action = SILENT;
return;
case THROW_UP:
UTIL_THROW(FormatLoadException, "Special word " << word << " is not allowed in the corpus. I plan to support models containing <unk> in the future. Pass --skip_symbols to convert these symbols to whitespace.");
}
}
// Vocab ids are given in a precompiled hash table.
class VocabGiven {
public:
explicit VocabGiven(int fd) {
util::MapRead(util::POPULATE_OR_READ, fd, 0, util::CheckOverflow(util::SizeOrThrow(fd)), table_backing_);
// Leave space for header with size.
table_ = Table(static_cast<char*>(table_backing_.get()) + sizeof(uint64_t), table_backing_.size() - sizeof(uint64_t));
bos_ = FindOrInsert("<s>");
eos_ = FindOrInsert("</s>");
}
WordIndex FindOrInsert(const StringPiece &word) const {
Table::ConstIterator it;
if (table_.Find(util::MurmurHash64A(word.data(), word.size()), it)) {
return it->value;
} else {
return 0; // <unk>.
}
}
WordIndex Index(const StringPiece &word) const {
return FindOrInsert(word);
}
WordIndex Size() const {
return *static_cast<const uint64_t*>(table_backing_.get());
}
bool IsSpecial(WordIndex word) const {
return word == 0 || word == bos_ || word == eos_;
}
private:
util::scoped_memory table_backing_;
typedef util::ProbingHashTable<ngram::ProbingVocabularyEntry, util::IdentityHash> Table;
Table table_;
WordIndex bos_, eos_;
};
} // namespace
void CorpusCount::Run(const util::stream::ChainPosition &position) {
if (dynamic_vocab_) {
ngram::GrowableVocab<ngram::WriteUniqueWords> vocab(type_count_, vocab_write_);
RunWithVocab(position, vocab);
} else {
VocabGiven vocab(vocab_write_);
RunWithVocab(position, vocab);
}
}
template <class Vocab> void CorpusCount::RunWithVocab(const util::stream::ChainPosition &position, Vocab &vocab) {
token_count_ = 0;
type_count_ = 0;
const WordIndex end_sentence = vocab.FindOrInsert("</s>");
Writer writer(NGram<BuildingPayload>::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_);
uint64_t count = 0;
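// Tokens are separated by NUL, tab, newline, carriage return, or space.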
bool delimiters[256];
util::BoolCharacter::Build("\0\t\n\r ", delimiters);
StringPiece w;
while(true) {
writer.StartSentence();
while (from_.ReadWordSameLine(w, delimiters)) {
WordIndex word = vocab.FindOrInsert(w);
if (UTIL_UNLIKELY(vocab.IsSpecial(word))) {
ComplainDisallowed(w, disallowed_symbol_action_);
continue;
}
writer.Append(word);
++count;
}
if (!from_.ReadLineOrEOF(w)) break;
writer.Append(end_sentence);
}
token_count_ = count;
type_count_ = vocab.Size();
// Build the pruning list: every unigram is marked for pruning except those
// listed in the prune vocab file (and the special words).
if (!prune_vocab_filename_.empty()) {
try {
util::FilePiece prune_vocab_file(prune_vocab_filename_.c_str());
prune_words_.resize(vocab.Size(), true);
try {
while (true) {
StringPiece word(prune_vocab_file.ReadDelimited(delimiters));
prune_words_[vocab.Index(word)] = false;
}
} catch (const util::EndOfFileException &e) {}
// Never prune <unk>, <s>, </s>
prune_words_[kUNK] = false;
prune_words_[kBOS] = false;
prune_words_[kEOS] = false;
} catch (const util::Exception &e) {
std::cerr << e.what() << std::endl;
abort();
}
}
}
} // namespace builder
} // namespace lm
#ifndef LM_BUILDER_CORPUS_COUNT_H
#define LM_BUILDER_CORPUS_COUNT_H
#include "../lm_exception.hh"
#include "../word_index.hh"
#include "../../util/scoped.hh"
#include <cstddef>
#include <string>
#include <stdint.h>
#include <vector>
namespace util {
class FilePiece;
namespace stream {
class ChainPosition;
} // namespace stream
} // namespace util
namespace lm {
namespace builder {
class CorpusCount {
public:
// Memory usage will be DedupeMultiplier(order) * block_size + total_chain_size + an unknown vocab_hash_size
static float DedupeMultiplier(std::size_t order);
// How much memory vocabulary will use based on estimated size of the vocab.
static std::size_t VocabUsage(std::size_t vocab_estimate);
// token_count: output.
// type_count: in/out. The vocabulary size: initialize to an estimate; it is set to the exact value on completion.
CorpusCount(util::FilePiece &from, int vocab_write, bool dynamic_vocab,
            uint64_t &token_count, WordIndex &type_count,
            std::vector<bool> &prune_words, const std::string &prune_vocab_filename,
            std::size_t entries_per_block, WarningAction disallowed_symbol);
void Run(const util::stream::ChainPosition &position);
private:
template <class Vocab> void RunWithVocab(const util::stream::ChainPosition &position, Vocab &vocab);
util::FilePiece &from_;
int vocab_write_;
bool dynamic_vocab_;
uint64_t &token_count_;
WordIndex &type_count_;
std::vector<bool> &prune_words_;
const std::string prune_vocab_filename_;
std::size_t dedupe_mem_size_;
util::scoped_malloc dedupe_mem_;
WarningAction disallowed_symbol_action_;
};
} // namespace builder
} // namespace lm
#endif // LM_BUILDER_CORPUS_COUNT_H
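A hedged sketch of wiring CorpusCount to count a text corpus. The file names, memory sizes, and the vocab_fd descriptor are illustrative assumptions; the authoritative wiring lives in the builder's pipeline.
util::stream::ChainConfig config;
config.total_memory = 64 * 1024 * 1024;
config.block_count = 2;
config.entry_size = NGram<BuildingPayload>::TotalSize(order);
util::stream::Chain chain(config);
uint64_t token_count = 0;
WordIndex type_count = 10000;  // initial vocabulary estimate
std::vector<bool> prune_words;
util::FilePiece text("corpus.txt");
int vocab_fd = util::CreateOrThrow("vocab.bin");  // writable vocabulary file
CorpusCount counter(text, vocab_fd, true /* dynamic vocab */, token_count,
                    type_count, prune_words, "" /* no prune vocab file */,
                    chain.BlockSize() / chain.EntrySize(), THROW_UP);
chain >> boost::ref(counter);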