Commit 688b6eac authored by SWHL's avatar SWHL
Browse files

Update files

parents
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ctc_beam_search_decoder.h"
#include <algorithm>
#include <cmath>
#include <future>
#include <iostream>
#include <limits>
#include <map>
#include <utility>
#include "ThreadPool/ThreadPool.h"
#include "decoder_utils.h"
#include "fst/fstlib.h"
#include "path_trie.h"
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
std::vector<std::pair<double, std::vector<int>>> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &log_probs_seq,
const std::vector<std::vector<int>> &log_probs_idx, PathTrie &root,
const bool start, size_t beam_size, int blank_id, int space_id,
double cutoff_prob, Scorer *ext_scorer) {
if (start) {
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
root.set_dictionary(dict_ptr);
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
root.set_matcher(matcher);
}
}
int timesteps = log_probs_seq.size();
std::vector<PathTrie *> prefixes;
// update log probs
if (root.log_prob_b_prev == -NUM_FLT_INF && start) {
root.score = root.log_prob_b_prev = 0.0;
}
root.iterate_to_vec_only(prefixes);
int prev_id = -1;
// prefix search over time
for (size_t time_step = 0; time_step < timesteps; ++time_step) {
float min_cutoff = -NUM_FLT_INF;
bool full_beam = false;
auto &log_prob = log_probs_seq[time_step];
auto &log_prob_idx = log_probs_idx[time_step];
double top_prob = exp(log_prob[0]);
auto top_id = log_prob_idx[0];
if (top_prob >= cutoff_prob && top_id == blank_id)
if (prev_id == blank_id) {
continue; // skip this round
} else
prev_id = top_id;
else
prev_id = -1;
// loop over chars
double cur_acc_prob = 0.0;
for (size_t index = 0; index < log_prob.size(); index++) {
auto c = log_prob_idx[index];
float log_prob_c = log_prob[index];
cur_acc_prob += exp(log_prob_c);
if (cur_acc_prob > cutoff_prob && index >= 1) break;
for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
auto prefix = prefixes[i];
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
// blank
if (c == blank_id) {
prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
continue;
}
// repeated character
if (c == prefix->character) {
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c);
if (prefix_new != nullptr) {
float log_p = -NUM_FLT_INF;
if (c == prefix->character &&
prefix->log_prob_b_prev > -NUM_FLT_INF) {
log_p = log_prob_c + prefix->log_prob_b_prev;
} else if (c != prefix->character) {
log_p = log_prob_c + prefix->score;
}
// language model scoring
if (ext_scorer != nullptr &&
(c == space_id || ext_scorer->is_character_based())) {
PathTrie *prefix_to_score = nullptr;
// skip scoring the space
if (ext_scorer->is_character_based()) {
prefix_to_score = prefix_new;
} else {
prefix_to_score = prefix;
}
float score = 0.0;
std::vector<std::string> ngram;
ngram = ext_scorer->make_ngram(prefix_to_score);
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
log_p += score;
log_p += ext_scorer->beta;
}
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
} // end of loop over prefix
} // end of loop over vocabulary
prefixes.clear();
// update log probs
root.iterate_to_vec(prefixes);
// only preserve top beam_size prefixes
if (prefixes.size() >= beam_size) {
std::nth_element(prefixes.begin(), prefixes.begin() + beam_size,
prefixes.end(), prefix_compare);
for (size_t i = beam_size; i < prefixes.size(); ++i) {
prefixes[i]->remove();
}
}
} // end of loop over time
size_t num_prefixes = std::min(prefixes.size(), beam_size);
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
return get_beam_search_result(prefixes, beam_size);
}
std::string map_sent(const std::vector<int> &sent,
const std::vector<std::string> &vocabulary, bool greedy,
int blank_id) {
std::string output_str;
if (!greedy) {
for (size_t j = 0; j < sent.size(); j++) {
output_str += vocabulary[sent[j]];
}
} else {
// greedy search
int prev = -1;
for (size_t i = 0; i < sent.size(); i++) {
int cur = sent[i];
if (cur != prev && cur != blank_id) output_str += vocabulary[cur];
prev = cur;
}
}
return output_str;
}
std::vector<std::string> map_batch(
const std::vector<std::vector<int>> &batch_sents,
const std::vector<std::string> &vocabulary, size_t num_processes,
bool greedy, int blank_id) {
ThreadPool pool(num_processes);
size_t batch_size = batch_sents.size();
std::vector<std::future<std::string>> res;
for (size_t i = 0; i < batch_size; ++i) {
res.emplace_back(pool.enqueue(map_sent, std::ref(batch_sents[i]),
std::ref(vocabulary), greedy, blank_id));
}
// get decoding results
std::vector<std::string> batch_results;
for (size_t i = 0; i < batch_size; ++i) {
batch_results.emplace_back(res[i].get());
}
return batch_results;
}
std::vector<std::vector<std::pair<double, std::vector<int>>>>
ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>> &batch_log_probs_seq,
const std::vector<std::vector<std::vector<int>>> &batch_log_probs_idx,
std::vector<PathTrie *> &batch_root_trie,
const std::vector<bool> &batch_start, size_t beam_size,
size_t num_processes, int blank_id, int space_id, double cutoff_prob,
Scorer *ext_scorer) {
// thread pool
ThreadPool pool(num_processes);
// number of samples
size_t batch_size = batch_log_probs_seq.size();
// enqueue the tasks of decoding
std::vector<std::future<std::vector<std::pair<double, std::vector<int>>>>>
res;
for (size_t i = 0; i < batch_size; ++i) {
res.emplace_back(
pool.enqueue(ctc_beam_search_decoder, std::ref(batch_log_probs_seq[i]),
std::ref(batch_log_probs_idx[i]),
std::ref(*batch_root_trie[i]), batch_start[i], beam_size,
blank_id, space_id, cutoff_prob, ext_scorer));
}
// get decoding results
std::vector<std::vector<std::pair<double, std::vector<int>>>> batch_results;
for (size_t i = 0; i < batch_size; ++i) {
batch_results.emplace_back(res[i].get());
}
return batch_results;
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_
#include <string>
#include <utility>
#include <vector>
#include "path_trie.h"
#include "scorer.h"
/* CTC Beam Search Decoder
* Parameters:
* log_probs_seq: 2-D vector that each element is a vector of log
probabilities
* for one time step, it is sorted (topk)
* log_probs_idx: 2-D vector that the index of every element in
log_probs_seq
* topk index
* root: A PathTrie root
* start: whether this the first chunk of this sequence
* beam_size: The width of beam search.
* blank_id: default is 0
* space_id: default is -1
* cutoff_prob: Cutoff probability for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A vector that each element is a pair of score and decoding result,
* in desending order.
*/
std::vector<std::pair<double, std::vector<int>>> ctc_beam_search_decoder(
const std::vector<std::vector<double>> &log_probs_seq,
const std::vector<std::vector<int>> &log_probs_idx, PathTrie &root,
const bool start, size_t beam_size, int blank_id = 0, int space_id = -1,
double cutoff_prob = 0.999, Scorer *ext_scorer = nullptr);
/* CTC Beam Search Decoder for batch data
* Parameters:
* batch_log_probs_seq: 3-D vector that each element is a 2-D vector that
can be used
* by ctc_beam_search_decoder().
* batch_log_probs_idx: 3-D vector that each element is a 2-D vector that
can be used
* by ctc_beam_search_decoder().
* batch_root_trie: a batch of Path trie for each sequence
* batch_start: a batch of boolean value to indicate whether this is the
first
* chunk of each sequence
* beam_size: The width of beam search.
* num_processes: Number of threads for beam search.
* blank_id: default blank_id is 0
* space_id: default space_id is -1, this is for word based scorer
* cutoff_prob: Cutoff probability for pruning.
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* Return:
* A 2-D vector that each element is a vector of beam search decoding
* result for one audio sample.
*/
std::vector<std::vector<std::pair<double, std::vector<int>>>>
ctc_beam_search_decoder_batch(
const std::vector<std::vector<std::vector<double>>> &batch_log_probs_seq,
const std::vector<std::vector<std::vector<int>>> &batch_log_probs_idx,
std::vector<PathTrie *> &batch_root_trie,
const std::vector<bool> &batch_start, size_t beam_size,
size_t num_processes, int blank_id = 0, int space_id = -1,
double cutoff_prob = 0.999, Scorer *ext_scorer = nullptr);
/* Map vector of int to string
* Parameters:
* sent: a vector of int ids
* vocabulary: vocabulary
* Return:
* A decoded string
*/
std::string map_sent(const std::vector<int> &sent,
const std::vector<std::string> &vocabulary,
bool greedy = false, int blank_id = 0);
/* Map batch vector of int to string
* Parameters:
* batch_sents: a batch of vector of int ids
* vocabulary: vocabulary
* num_processes: number of processes to use
* Return:
* A vector decoded string
*/
std::vector<std::string> map_batch(
const std::vector<std::vector<int>> &batch_sents,
const std::vector<std::string> &vocabulary, size_t num_processes,
bool greedy = false, int blank_id = 0);
#endif // CTC_BEAM_SEARCH_DECODER_H_
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder_utils.h"
#include <algorithm>
#include <cmath>
#include <limits>
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
const std::vector<double> &prob_step, double cutoff_prob,
size_t cutoff_top_n) {
std::vector<std::pair<int, double>> prob_idx;
for (size_t i = 0; i < prob_step.size(); ++i) {
prob_idx.push_back(std::pair<int, double>(i, prob_step[i]));
}
// pruning of vacobulary
size_t cutoff_len = prob_step.size();
if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
std::sort(prob_idx.begin(), prob_idx.end(),
pair_comp_second_rev<int, double>);
if (cutoff_prob < 1.0) {
double cum_prob = 0.0;
cutoff_len = 0;
for (size_t i = 0; i < prob_idx.size(); ++i) {
cum_prob += prob_idx[i].second;
cutoff_len += 1;
if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break;
}
}
prob_idx = std::vector<std::pair<int, double>>(
prob_idx.begin(), prob_idx.begin() + cutoff_len);
}
std::vector<std::pair<size_t, float>> log_prob_idx;
for (size_t i = 0; i < cutoff_len; ++i) {
log_prob_idx.push_back(std::pair<int, float>(
prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
}
return log_prob_idx;
}
std::vector<std::pair<double, std::vector<int>>> get_beam_search_result(
const std::vector<PathTrie *> &prefixes, size_t beam_size) {
// allow for the post processing
std::vector<PathTrie *> space_prefixes;
if (space_prefixes.empty()) {
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
space_prefixes.push_back(prefixes[i]);
}
}
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
std::vector<std::pair<double, std::vector<int>>> output_vecs;
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
std::vector<int> output;
space_prefixes[i]->get_path_vec(output);
// convert index to string
std::pair<double, std::vector<int>> output_pair(space_prefixes[i]->score,
output);
output_vecs.emplace_back(output_pair);
}
return output_vecs;
}
size_t get_utf8_str_len(const std::string &str) {
size_t str_len = 0;
for (char c : str) {
str_len += ((c & 0xc0) != 0x80);
}
return str_len;
}
std::vector<std::string> split_utf8_str(const std::string &str) {
std::vector<std::string> result;
std::string out_str;
for (char c : str) {
if ((c & 0xc0) != 0x80) // new UTF-8 character
{
if (!out_str.empty()) {
result.push_back(out_str);
out_str.clear();
}
}
out_str.append(1, c);
}
result.push_back(out_str);
return result;
}
std::vector<std::string> split_str(const std::string &s,
const std::string &delim) {
std::vector<std::string> result;
std::size_t start = 0, delim_len = delim.size();
while (true) {
std::size_t end = s.find(delim, start);
if (end == std::string::npos) {
if (start < s.size()) {
result.push_back(s.substr(start));
}
break;
}
if (end > start) {
result.push_back(s.substr(start, end - start));
}
start = end + delim_len;
}
return result;
}
bool prefix_compare(const PathTrie *x, const PathTrie *y) {
if (x->score == y->score) {
if (x->character == y->character) {
return false;
} else {
return (x->character < y->character);
}
} else {
return x->score > y->score;
}
}
void add_word_to_fst(const std::vector<int> &word,
fst::StdVectorFst *dictionary) {
if (dictionary->NumStates() == 0) {
fst::StdVectorFst::StateId start = dictionary->AddState();
assert(start == 0);
dictionary->SetStart(start);
}
fst::StdVectorFst::StateId src = dictionary->Start();
fst::StdVectorFst::StateId dst;
for (auto c : word) {
dst = dictionary->AddState();
dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
src = dst;
}
dictionary->SetFinal(dst, fst::StdArc::Weight::One());
}
bool add_word_to_dictionary(
const std::string &word,
const std::unordered_map<std::string, int> &char_map, bool add_space,
int SPACE_ID, fst::StdVectorFst *dictionary) {
auto characters = split_utf8_str(word);
std::vector<int> int_word;
for (auto &c : characters) {
if (c == " ") {
int_word.push_back(SPACE_ID);
} else {
auto int_c = char_map.find(c);
if (int_c != char_map.end()) {
int_word.push_back(int_c->second);
} else {
return false; // return without adding
}
}
}
if (add_space) {
int_word.push_back(SPACE_ID);
}
add_word_to_fst(int_word, dictionary);
return true; // return with successful adding
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_UTILS_H_
#define DECODER_UTILS_H_
#include <utility>
#include "fst/log.h"
#include "path_trie.h"
const float NUM_FLT_INF = std::numeric_limits<float>::max();
const float NUM_FLT_MIN = std::numeric_limits<float>::min();
// inline function for validation check
inline void check(bool x, const char *expr, const char *file, int line,
const char *err) {
if (!x) {
std::cout << "[" << file << ":" << line << "] ";
LOG(FATAL) << "\"" << expr << "\" check failed. " << err;
}
}
#define VALID_CHECK(x, info) \
check(static_cast<bool>(x), #x, __FILE__, __LINE__, info)
#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info)
#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info)
#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info)
// Function template for comparing two pairs
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b) {
return a.first > b.first;
}
// Function template for comparing two pairs
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a,
const std::pair<T1, T2> &b) {
return a.second > b.second;
}
// Return the sum of two probabilities in log scale
template <typename T>
T log_sum_exp(const T &x, const T &y) {
static T num_min = -std::numeric_limits<T>::max();
if (x <= num_min) return y;
if (y <= num_min) return x;
T xmax = std::max(x, y);
return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
}
// Get pruned probability vector for each time step's beam search
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
const std::vector<double> &prob_step, double cutoff_prob,
size_t cutoff_top_n);
// Get beam search result from prefixes in trie tree
std::vector<std::pair<double, std::string>> get_beam_search_result(
const std::vector<PathTrie *> &prefixes,
const std::vector<std::string> &vocabulary, size_t beam_size);
std::vector<std::pair<double, std::vector<int>>> get_beam_search_result(
const std::vector<PathTrie *> &prefixes, size_t beam_size);
// Functor for prefix comparsion
bool prefix_compare(const PathTrie *x, const PathTrie *y);
/* Get length of utf8 encoding string
* See: http://stackoverflow.com/a/4063229
*/
size_t get_utf8_str_len(const std::string &str);
/* Split a string into a list of strings on a given string
* delimiter. NB: delimiters on beginning / end of string are
* trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
*/
std::vector<std::string> split_str(const std::string &s,
const std::string &delim);
/* Splits string into vector of strings representing
* UTF-8 characters (not same as chars)
*/
std::vector<std::string> split_utf8_str(const std::string &str);
// Add a word in index to the dicionary of fst
void add_word_to_fst(const std::vector<int> &word,
fst::StdVectorFst *dictionary);
// Add a word in string to dictionary
bool add_word_to_dictionary(
const std::string &word,
const std::unordered_map<std::string, int> &char_map, bool add_space,
int SPACE_ID, fst::StdVectorFst *dictionary);
#endif // DECODER_UTILS_H
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "path_trie.h"
#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include "decoder_utils.h"
PathTrie::PathTrie() {
log_prob_b_prev = -NUM_FLT_INF;
log_prob_nb_prev = -NUM_FLT_INF;
log_prob_b_cur = -NUM_FLT_INF;
log_prob_nb_cur = -NUM_FLT_INF;
score = -NUM_FLT_INF;
ROOT_ = -1;
character = ROOT_;
exists_ = true;
parent = nullptr;
dictionary_ = nullptr;
dictionary_state_ = 0;
has_dictionary_ = false;
matcher_ = nullptr;
}
PathTrie::~PathTrie() {
for (auto child : children_) {
delete child.second;
}
}
PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
auto child = children_.begin();
for (child = children_.begin(); child != children_.end(); ++child) {
if (child->first == new_char) {
break;
}
}
if (child != children_.end()) {
if (!child->second->exists_) {
child->second->exists_ = true;
child->second->log_prob_b_prev = -NUM_FLT_INF;
child->second->log_prob_nb_prev = -NUM_FLT_INF;
child->second->log_prob_b_cur = -NUM_FLT_INF;
child->second->log_prob_nb_cur = -NUM_FLT_INF;
}
return (child->second);
} else {
if (has_dictionary_) {
matcher_->SetState(dictionary_state_);
bool found = matcher_->Find(new_char + 1);
if (!found) {
// Adding this character causes word outside dictionary
auto FSTZERO = fst::TropicalWeight::Zero();
auto final_weight = dictionary_->Final(dictionary_state_);
bool is_final = (final_weight != FSTZERO);
if (is_final && reset) {
dictionary_state_ = dictionary_->Start();
}
return nullptr;
} else {
PathTrie* new_path = new PathTrie;
new_path->character = new_char;
new_path->parent = this;
new_path->dictionary_ = dictionary_;
new_path->dictionary_state_ = matcher_->Value().nextstate;
new_path->has_dictionary_ = true;
new_path->matcher_ = matcher_;
children_.push_back(std::make_pair(new_char, new_path));
return new_path;
}
} else {
PathTrie* new_path = new PathTrie;
new_path->character = new_char;
new_path->parent = this;
children_.push_back(std::make_pair(new_char, new_path));
return new_path;
}
}
}
PathTrie* PathTrie::get_path_vec(std::vector<int>& output) {
return get_path_vec(output, ROOT_);
}
PathTrie* PathTrie::get_path_vec(std::vector<int>& output, int stop,
size_t max_steps) {
if (character == stop || character == ROOT_ || output.size() == max_steps) {
std::reverse(output.begin(), output.end());
return this;
} else {
output.push_back(character);
return parent->get_path_vec(output, stop, max_steps);
}
}
void PathTrie::iterate_to_vec_only(std::vector<PathTrie*>& output) {
if (exists_) {
output.push_back(this);
}
for (auto child : children_) {
child.second->iterate_to_vec_only(output);
}
}
void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
if (exists_) {
log_prob_b_prev = log_prob_b_cur;
log_prob_nb_prev = log_prob_nb_cur;
log_prob_b_cur = -NUM_FLT_INF;
log_prob_nb_cur = -NUM_FLT_INF;
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
output.push_back(this);
}
for (auto child : children_) {
child.second->iterate_to_vec(output);
}
}
void PathTrie::remove() {
exists_ = false;
if (children_.size() == 0) {
auto child = parent->children_.begin();
for (child = parent->children_.begin(); child != parent->children_.end();
++child) {
if (child->first == character) {
parent->children_.erase(child);
break;
}
}
if (parent->children_.size() == 0 && !parent->exists_) {
parent->remove();
}
delete this;
}
}
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
dictionary_ = dictionary;
dictionary_state_ = dictionary->Start();
has_dictionary_ = true;
}
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
matcher_ = matcher;
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PATH_TRIE_H
#define PATH_TRIE_H
#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include "fst/fstlib.h"
/* Trie tree for prefix storing and manipulating, with a dictionary in
* finite-state transducer for spelling correction.
*/
class PathTrie {
public:
PathTrie();
~PathTrie();
// get new prefix after appending new char
PathTrie* get_path_trie(int new_char, bool reset = true);
// get the prefix in index from root to current node
PathTrie* get_path_vec(std::vector<int>& output);
// get the prefix in index from some stop node to current nodel
PathTrie* get_path_vec(std::vector<int>& output, int stop,
size_t max_steps = std::numeric_limits<size_t>::max());
// update log probs
void iterate_to_vec(std::vector<PathTrie*>& output);
void iterate_to_vec_only(std::vector<PathTrie*>& output);
// set dictionary for FST
void set_dictionary(fst::StdVectorFst* dictionary);
void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);
bool is_empty() { return ROOT_ == character; }
// remove current path from root
void remove();
float log_prob_b_prev;
float log_prob_nb_prev;
float log_prob_b_cur;
float log_prob_nb_cur;
float score;
float approx_ctc;
int character;
PathTrie* parent;
private:
int ROOT_;
bool exists_;
bool has_dictionary_;
std::vector<std::pair<int, PathTrie*>> children_;
// pointer to dictionary of FST
fst::StdVectorFst* dictionary_;
fst::StdVectorFst::StateId dictionary_state_;
// true if finding ars in FST
std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_;
};
#endif // PATH_TRIE_H
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "scorer.h"
#include <unistd.h>
#include <iostream>
#include "lm/config.hh"
#include "lm/model.hh"
#include "lm/state.hh"
#include "util/string_piece.hh"
#include "util/tokenize_piece.hh"
#include "decoder_utils.h"
using namespace lm::ngram;
Scorer::Scorer(double alpha, double beta, const std::string& lm_path,
const std::vector<std::string>& vocab_list) {
this->alpha = alpha;
this->beta = beta;
dictionary = nullptr;
is_character_based_ = true;
language_model_ = nullptr;
max_order_ = 0;
dict_size_ = 0;
SPACE_ID_ = -1;
setup(lm_path, vocab_list);
}
Scorer::~Scorer() {
if (language_model_ != nullptr) {
delete static_cast<lm::base::Model*>(language_model_);
}
if (dictionary != nullptr) {
delete static_cast<fst::StdVectorFst*>(dictionary);
}
}
void Scorer::setup(const std::string& lm_path,
const std::vector<std::string>& vocab_list) {
// load language model
load_lm(lm_path);
// set char map for scorer
set_char_map(vocab_list);
// fill the dictionary for FST
if (!is_character_based()) {
fill_dictionary(true);
}
}
void Scorer::load_lm(const std::string& lm_path) {
const char* filename = lm_path.c_str();
VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path");
RetriveStrEnumerateVocab enumerate;
lm::ngram::Config config;
config.enumerate_vocab = &enumerate;
language_model_ = lm::ngram::LoadVirtual(filename, config);
max_order_ = static_cast<lm::base::Model*>(language_model_)->Order();
vocabulary_ = enumerate.vocabulary;
for (size_t i = 0; i < vocabulary_.size(); ++i) {
if (is_character_based_ && vocabulary_[i] != UNK_TOKEN &&
vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN &&
get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
is_character_based_ = false;
}
}
}
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
lm::base::Model* model = static_cast<lm::base::Model*>(language_model_);
double cond_prob;
lm::ngram::State state, tmp_state, out_state;
// avoid to inserting <s> in begin
model->NullContextWrite(&state);
for (size_t i = 0; i < words.size(); ++i) {
lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
// encounter OOV
if (word_index == 0) {
return OOV_SCORE;
}
cond_prob = model->BaseScore(&state, word_index, &out_state);
tmp_state = state;
state = out_state;
out_state = tmp_state;
}
// return log10 prob
return cond_prob;
}
double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
std::vector<std::string> sentence;
if (words.size() == 0) {
for (size_t i = 0; i < max_order_; ++i) {
sentence.push_back(START_TOKEN);
}
} else {
for (size_t i = 0; i < max_order_ - 1; ++i) {
sentence.push_back(START_TOKEN);
}
sentence.insert(sentence.end(), words.begin(), words.end());
}
sentence.push_back(END_TOKEN);
return get_log_prob(sentence);
}
double Scorer::get_log_prob(const std::vector<std::string>& words) {
assert(words.size() > max_order_);
double score = 0.0;
for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) {
std::vector<std::string> ngram(words.begin() + i,
words.begin() + i + max_order_);
score += get_log_cond_prob(ngram);
}
return score;
}
void Scorer::reset_params(float alpha, float beta) {
this->alpha = alpha;
this->beta = beta;
}
std::string Scorer::vec2str(const std::vector<int>& input) {
std::string word;
for (auto ind : input) {
word += char_list_[ind];
}
return word;
}
std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
if (labels.empty()) return {};
std::string s = vec2str(labels);
std::vector<std::string> words;
if (is_character_based_) {
words = split_utf8_str(s);
} else {
words = split_str(s, " ");
}
return words;
}
void Scorer::set_char_map(const std::vector<std::string>& char_list) {
char_list_ = char_list;
char_map_.clear();
// Set the char map for the FST for spelling correction
for (size_t i = 0; i < char_list_.size(); i++) {
if (char_list_[i] == " ") {
SPACE_ID_ = i;
}
// The initial state of FST is state 0, hence the index of chars in
// the FST should start from 1 to avoid the conflict with the initial
// state, otherwise wrong decoding results would be given.
char_map_[char_list_[i]] = i + 1;
}
}
std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
std::vector<std::string> ngram;
PathTrie* current_node = prefix;
PathTrie* new_node = nullptr;
for (int order = 0; order < max_order_; order++) {
std::vector<int> prefix_vec;
if (is_character_based_) {
new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_, 1);
current_node = new_node;
} else {
new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_);
current_node = new_node->parent; // Skipping spaces
}
// reconstruct word
std::string word = vec2str(prefix_vec);
ngram.push_back(word);
if (new_node->character == -1) {
// No more spaces, but still need order
for (int i = 0; i < max_order_ - order - 1; i++) {
ngram.push_back(START_TOKEN);
}
break;
}
}
std::reverse(ngram.begin(), ngram.end());
return ngram;
}
void Scorer::fill_dictionary(bool add_space) {
fst::StdVectorFst dictionary;
// For each unigram convert to ints and put in trie
int dict_size = 0;
for (const auto& word : vocabulary_) {
bool added = add_word_to_dictionary(word, char_map_, add_space,
SPACE_ID_ + 1, &dictionary);
dict_size += added ? 1 : 0;
}
dict_size_ = dict_size;
/* Simplify FST
* This gets rid of "epsilon" transitions in the FST.
* These are transitions that don't require a string input to be taken.
* Getting rid of them is necessary to make the FST determinisitc, but
* can greatly increase the size of the FST
*/
fst::RmEpsilon(&dictionary);
fst::StdVectorFst* new_dict = new fst::StdVectorFst;
/* This makes the FST deterministic, meaning for any string input there's
* only one possible state the FST could be in. It is assumed our
* dictionary is deterministic when using it.
* (lest we'd have to check for multiple transitions at each state)
*/
fst::Determinize(dictionary, new_dict);
/* Finds the simplest equivalent fst. This is unnecessary but decreases
* memory usage of the dictionary
*/
fst::Minimize(new_dict);
this->dictionary = new_dict;
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SCORER_H_
#define SCORER_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lm/enumerate_vocab.hh"
#include "lm/virtual_interface.hh"
#include "lm/word_index.hh"
#include "util/string_piece.hh"
#include "path_trie.h"
const double OOV_SCORE = -1000.0;
const std::string START_TOKEN = "<s>";
const std::string UNK_TOKEN = "<unk>";
const std::string END_TOKEN = "</s>";
// Implement a callback to retrive the dictionary of language model.
class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
public:
RetriveStrEnumerateVocab() {}
void Add(lm::WordIndex index, const StringPiece &str) {
vocabulary.push_back(std::string(str.data(), str.length()));
}
std::vector<std::string> vocabulary;
};
/* External scorer to query score for n-gram or sentence, including language
* model scoring and word insertion.
*
* Example:
* Scorer scorer(alpha, beta, "path_of_language_model");
* scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
*/
class Scorer {
public:
Scorer(double alpha, double beta, const std::string &lm_path,
const std::vector<std::string> &vocabulary);
~Scorer();
double get_log_cond_prob(const std::vector<std::string> &words);
double get_sent_log_prob(const std::vector<std::string> &words);
// return the max order
size_t get_max_order() const { return max_order_; }
// return the dictionary size of language model
size_t get_dict_size() const { return dict_size_; }
// retrun true if the language model is character based
bool is_character_based() const { return is_character_based_; }
// reset params alpha & beta
void reset_params(float alpha, float beta);
// make ngram for a given prefix
std::vector<std::string> make_ngram(PathTrie *prefix);
// trransform the labels in index to the vector of words (word based lm) or
// the vector of characters (character based lm)
std::vector<std::string> split_labels(const std::vector<int> &labels);
// language model weight
double alpha;
// word insertion weight
double beta;
// pointer to the dictionary of FST
void *dictionary;
protected:
// necessary setup: load language model, set char map, fill FST's dictionary
void setup(const std::string &lm_path,
const std::vector<std::string> &vocab_list);
// load language model from given path
void load_lm(const std::string &lm_path);
// fill dictionary for FST
void fill_dictionary(bool add_space);
// set char map
void set_char_map(const std::vector<std::string> &char_list);
double get_log_prob(const std::vector<std::string> &words);
// translate the vector in index to string
std::string vec2str(const std::vector<int> &input);
private:
void *language_model_;
bool is_character_based_;
size_t max_order_;
size_t dict_size_;
int SPACE_ID_;
std::vector<std::string> char_list_;
std::unordered_map<std::string, int> char_map_;
std::vector<std::string> vocabulary_;
};
#endif // SCORER_H_
#include "path_trie.h"
#include "scorer.h"
#include "decoder_utils.h"
#include "ctc_beam_search_decoder.h"
int main()
{
return 0;
}
\ No newline at end of file
util/file_piece.cc.gz
*.swp
*.o
doc/
build/
/bin
/lib
/tests
._*
windows/Win32
windows/x64
windows/*.user
windows/*.sdf
windows/*.opensdf
windows/*.suo
CMakeFiles
cmake_install.cmake
CMakeCache.txt
CTestTestfile.cmake
DartConfiguration.tcl
Makefile
KenLM has switched to cmake
cmake .
make -j 4
But they recommend building out of tree
mkdir -p build && cd build
cmake ..
make -j 4
If you only want the query code and do not care about compression (.gz, .bz2, and .xz):
./compile_query_only.sh
Windows:
The windows directory has visual studio files. Note that you need to compile
the kenlm project before build_binary and ngram_query projects.
OSX:
Missing dependencies can be remedied with brew.
brew install cmake boost eigen
Debian/Ubuntu:
sudo apt install build-essential cmake libboost-system-dev libboost-thread-dev libboost-program-options-dev libboost-test-dev libeigen3-dev zlib1g-dev libbz2-dev liblzma-dev
cmake_minimum_required(VERSION 3.1)
if (WIN32)
set(Boost_USE_STATIC_LIBS OFF)
# The auto-linking feature has problems with USE_STATIC_LIBS off, so we use
# BOOST_ALL_NO_LIB to turn it off.
# Several boost libraries headers aren't configured correctly if
# USE_STATIC_LIBS is off, so we explicitly say they are dynamic with the
# remaining definitions.
add_definitions(-DBOOST_ALL_NO_LIB -DBOOST_PROGRAM_OPTIONS_DYN_LINK -DBOOST_IOSTREAMS_DYN_LINK -DBOOST_THREAD_DYN_LINK)
endif( )
# Define a single cmake project
project(kenlm)
option(FORCE_STATIC "Build static executables" OFF)
option(COMPILE_TESTS "Compile tests" OFF)
option(ENABLE_PYTHON "Build Python bindings" OFF)
# Eigen3 less than 3.1.0 has a race condition: http://eigen.tuxfamily.org/bz/show_bug.cgi?id=466
find_package(Eigen3 3.1.0 CONFIG)
include(CMakeDependentOption)
cmake_dependent_option(ENABLE_INTERPOLATE "Build interpolation program (depends on Eigen3)" ON "EIGEN3_FOUND AND NOT WIN32" OFF)
if (FORCE_STATIC)
#presumably overkill, is there a better way?
#http://cmake.3232098.n2.nabble.com/Howto-compile-static-executable-td5580269.html
set(Boost_USE_STATIC_LIBS ON)
set_property(GLOBAL PROPERTY LINK_SEARCH_START_STATIC ON)
set_property(GLOBAL PROPERTY LINK_SEARCH_END_STATIC ON)
set(BUILD_SHARED_LIBRARIES OFF)
if (MSVC)
set(flag_vars
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO
CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
CMAKE_C_FLAGS_MINSIZEREL CMAKE_C_FLAGS_RELWITHDEBINFO)
foreach(flag_var ${flag_vars})
if(${flag_var} MATCHES "/MD")
string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
endif(${flag_var} MATCHES "/MD")
endforeach(flag_var)
else (MSVC)
if (NOT CMAKE_C_COMPILER_ID MATCHES ".*Clang")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -static-libgcc -static-libstdc++ -static")
endif ()
set(CMAKE_FIND_LIBRARY_SUFFIXES ".a")
endif ()
#Annoyingly the exectuables say "File not found" unless these are set
set(CMAKE_EXE_LINK_DYNAMIC_C_FLAGS)
set(CMAKE_EXE_LINK_DYNAMIC_CXX_FLAGS)
set(CMAKE_SHARED_LIBRARY_C_FLAGS)
set(CMAKE_SHARED_LIBRARY_CXX_FLAGS)
set(CMAKE_SHARED_LIBRARY_LINK_C_FLAGS)
set(CMAKE_SHARED_LIBRARY_LINK_CXX_FLAGS)
endif ()
# Compile all executables into bin/
set(EXECUTABLE_OUTPUT_PATH ${PROJECT_BINARY_DIR}/bin)
# Compile all libraries into lib/
set(LIBRARY_OUTPUT_PATH ${PROJECT_BINARY_DIR}/lib)
if (NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()
if (COMPILE_TESTS)
# Tell cmake that we want unit tests to be compiled
include(CTest)
enable_testing()
endif()
# Add our CMake helper functions
include(cmake/KenLMFunctions.cmake)
if(MSVC)
set(CMAKE_C_FLAGS "${CMAKE_CXX_FLAGS} /w34716")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /w34716")
endif()
# And our helper modules
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/modules)
# We need boost
find_package(Boost 1.41.0 REQUIRED COMPONENTS
program_options
system
thread
unit_test_framework
)
# Define where include files live
include_directories(${Boost_INCLUDE_DIRS})
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)
# Process subdirectories
add_subdirectory(util)
add_subdirectory(lm)
if(ENABLE_PYTHON)
add_subdirectory(python)
endif()
# Install targets
install(EXPORT kenlmTargets
FILE kenlmTargets.cmake
NAMESPACE kenlm::
DESTINATION share/kenlm/cmake
)
foreach(SUBDIR IN ITEMS util util/double-conversion util/stream lm lm/builder lm/common lm/filter lm/interpolate)
file(GLOB HEADERS ${CMAKE_CURRENT_LIST_DIR}/${SUBDIR}/*.h ${CMAKE_CURRENT_LIST_DIR}/${SUBDIR}/*.hh)
install(FILES ${HEADERS} DESTINATION include/kenlm/${SUBDIR} COMPONENT headers)
endforeach(SUBDIR)
# Config
include(CMakePackageConfigHelpers)
# generate the config file that is includes the exports
configure_package_config_file(${PROJECT_SOURCE_DIR}/cmake/kenlmConfig.cmake.in
"${CMAKE_CURRENT_BINARY_DIR}/kenlmConfig.cmake"
INSTALL_DESTINATION share/kenlm/cmake
NO_SET_AND_CHECK_MACRO
NO_CHECK_REQUIRED_COMPONENTS_MACRO
)
# install the configuration file
install(FILES
${CMAKE_CURRENT_BINARY_DIR}/kenlmConfig.cmake
DESTINATION share/kenlm/cmake
)
This diff is collapsed.
This diff is collapsed.
GNU LESSER GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
This version of the GNU Lesser General Public License incorporates
the terms and conditions of version 3 of the GNU General Public
License, supplemented by the additional permissions listed below.
0. Additional Definitions.
As used herein, "this License" refers to version 3 of the GNU Lesser
General Public License, and the "GNU GPL" refers to version 3 of the GNU
General Public License.
"The Library" refers to a covered work governed by this License,
other than an Application or a Combined Work as defined below.
An "Application" is any work that makes use of an interface provided
by the Library, but which is not otherwise based on the Library.
Defining a subclass of a class defined by the Library is deemed a mode
of using an interface provided by the Library.
A "Combined Work" is a work produced by combining or linking an
Application with the Library. The particular version of the Library
with which the Combined Work was made is also called the "Linked
Version".
The "Minimal Corresponding Source" for a Combined Work means the
Corresponding Source for the Combined Work, excluding any source code
for portions of the Combined Work that, considered in isolation, are
based on the Application, and not on the Linked Version.
The "Corresponding Application Code" for a Combined Work means the
object code and/or source code for the Application, including any data
and utility programs needed for reproducing the Combined Work from the
Application, but excluding the System Libraries of the Combined Work.
1. Exception to Section 3 of the GNU GPL.
You may convey a covered work under sections 3 and 4 of this License
without being bound by section 3 of the GNU GPL.
2. Conveying Modified Versions.
If you modify a copy of the Library, and, in your modifications, a
facility refers to a function or data to be supplied by an Application
that uses the facility (other than as an argument passed when the
facility is invoked), then you may convey a copy of the modified
version:
a) under this License, provided that you make a good faith effort to
ensure that, in the event an Application does not supply the
function or data, the facility still operates, and performs
whatever part of its purpose remains meaningful, or
b) under the GNU GPL, with none of the additional permissions of
this License applicable to that copy.
3. Object Code Incorporating Material from Library Header Files.
The object code form of an Application may incorporate material from
a header file that is part of the Library. You may convey such object
code under terms of your choice, provided that, if the incorporated
material is not limited to numerical parameters, data structure
layouts and accessors, or small macros, inline functions and templates
(ten or fewer lines in length), you do both of the following:
a) Give prominent notice with each copy of the object code that the
Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the object code with a copy of the GNU GPL and this license
document.
4. Combined Works.
You may convey a Combined Work under terms of your choice that,
taken together, effectively do not restrict modification of the
portions of the Library contained in the Combined Work and reverse
engineering for debugging such modifications, if you also do each of
the following:
a) Give prominent notice with each copy of the Combined Work that
the Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the Combined Work with a copy of the GNU GPL and this license
document.
c) For a Combined Work that displays copyright notices during
execution, include the copyright notice for the Library among
these notices, as well as a reference directing the user to the
copies of the GNU GPL and this license document.
d) Do one of the following:
0) Convey the Minimal Corresponding Source under the terms of this
License, and the Corresponding Application Code in a form
suitable for, and under terms that permit, the user to
recombine or relink the Application with a modified version of
the Linked Version to produce a modified Combined Work, in the
manner specified by section 6 of the GNU GPL for conveying
Corresponding Source.
1) Use a suitable shared library mechanism for linking with the
Library. A suitable mechanism is one that (a) uses at run time
a copy of the Library already present on the user's computer
system, and (b) will operate properly with a modified version
of the Library that is interface-compatible with the Linked
Version.
e) Provide Installation Information, but only if you would otherwise
be required to provide such information under section 6 of the
GNU GPL, and only to the extent that such information is
necessary to install and execute a modified version of the
Combined Work produced by recombining or relinking the
Application with a modified version of the Linked Version. (If
you use option 4d0, the Installation Information must accompany
the Minimal Corresponding Source and Corresponding Application
Code. If you use option 4d1, you must provide the Installation
Information in the manner specified by section 6 of the GNU GPL
for conveying Corresponding Source.)
5. Combined Libraries.
You may place library facilities that are a work based on the
Library side by side in a single library together with other library
facilities that are not Applications and are not covered by this
License, and convey such a combined library under terms of your
choice, if you do both of the following:
a) Accompany the combined library with a copy of the same work based
on the Library, uncombined with any other library facilities,
conveyed under the terms of this License.
b) Give prominent notice with the combined library that part of it
is a work based on the Library, and explaining where to find the
accompanying uncombined form of the same work.
6. Revised Versions of the GNU Lesser General Public License.
The Free Software Foundation may publish revised and/or new versions
of the GNU Lesser General Public License from time to time. Such new
versions will be similar in spirit to the present version, but may
differ in detail to address new problems or concerns.
Each version is given a distinguishing version number. If the
Library as you received it specifies that a certain numbered version
of the GNU Lesser General Public License "or any later version"
applies to it, you have the option of following the terms and
conditions either of that published version or of any later version
published by the Free Software Foundation. If the Library as you
received it does not specify a version number of the GNU Lesser
General Public License, you may choose any version of the GNU Lesser
General Public License ever published by the Free Software Foundation.
If the Library as you received it specifies that a proxy can decide
whether future versions of the GNU Lesser General Public License shall
apply, that proxy's public statement of acceptance of any version is
permanent authorization for you to choose that version for the
Library.
This diff is collapsed.
Most of the code here is licensed under the LGPL. There are exceptions that
have their own licenses, listed below. See comments in those files for more
details.
util/getopt.* is getopt for Windows
util/murmur_hash.cc
util/string_piece.hh and util/string_piece.cc
util/double-conversion/LICENSE covers util/double-conversion except the build files
util/file.cc contains a modified implementation of mkstemp under the LGPL
util/integer_to_string.* is BSD
For the rest:
KenLM is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published
by the Free Software Foundation, either version 2.1 of the License, or
(at your option) any later version.
KenLM is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License 2.1
along with KenLM code. If not, see <http://www.gnu.org/licenses/lgpl-2.1.html>.
# file GENERATED by distutils, do NOT edit
include setup.py
include lm/*.cc
include lm/*.hh
include python/*.cpp
include util/*.cc
include util/*.hh
include util/double-conversion/*.cc
include util/double-conversion/*.h
# kenlm
Language model inference code by Kenneth Heafield (kenlm at kheafield.com)
The website https://kheafield.com/code/kenlm/ has more documentation. If you're a decoder developer, please download the latest version from there instead of copying from another decoder.
## Compiling
Use cmake, see [BUILDING](BUILDING) for build dependencies and more detail.
```bash
mkdir -p build
cd build
cmake ..
make -j 4
```
## Compiling with your own build system
If you want to compile with your own build system (Makefile etc) or to use as a library, there are a number of macros you can set on the g++ command line or in util/have.hh .
* `KENLM_MAX_ORDER` is the maximum order that can be loaded. This is done to make state an efficient POD rather than a vector.
* `HAVE_ICU` If your code links against ICU, define this to disable the internal StringPiece and replace it with ICU's copy of StringPiece, avoiding naming conflicts.
ARPA files can be read in compressed format with these options:
* `HAVE_ZLIB` Supports gzip. Link with -lz.
* `HAVE_BZLIB` Supports bzip2. Link with -lbz2.
* `HAVE_XZLIB` Supports xz. Link with -llzma.
Note that these macros impact only `read_compressed.cc` and `read_compressed_test.cc`. The bjam build system will auto-detect bzip2 and xz support.
## Estimation
lmplz estimates unpruned language models with modified Kneser-Ney smoothing. After compiling with bjam, run
```bash
bin/lmplz -o 5 <text >text.arpa
```
The algorithm is on-disk, using an amount of memory that you specify. See https://kheafield.com/code/kenlm/estimation/ for more.
MT Marathon 2012 team members Ivan Pouzyrevsky and Mohammed Mediani contributed to the computation design and early implementation. Jon Clark contributed to the design, clarified points about smoothing, and added logging.
## Filtering
filter takes an ARPA or count file and removes entries that will never be queried. The filter criterion can be corpus-level vocabulary, sentence-level vocabulary, or sentence-level phrases. Run
```bash
bin/filter
```
and see https://kheafield.com/code/kenlm/filter/ for more documentation.
## Querying
Two data structures are supported: probing and trie. Probing is a probing hash table with keys that are 64-bit hashes of n-grams and floats as values. Trie is a fairly standard trie but with bit-level packing so it uses the minimum number of bits to store word indices and pointers. The trie node entries are sorted by word index. Probing is the fastest and uses the most memory. Trie uses the least memory and is a bit slower.
As is the custom in language modeling, all probabilities are log base 10.
With trie, resident memory is 58% of IRST's smallest version and 21% of SRI's compact version. Simultaneously, trie CPU's use is 81% of IRST's fastest version and 84% of SRI's fast version. KenLM's probing hash table implementation goes even faster at the expense of using more memory. See https://kheafield.com/code/kenlm/benchmark/.
Binary format via mmap is supported. Run `./build_binary` to make one then pass the binary file name to the appropriate Model constructor.
## Platforms
`murmur_hash.cc` and `bit_packing.hh` perform unaligned reads and writes that make the code architecture-dependent.
It has been sucessfully tested on x86\_64, x86, and PPC64.
ARM support is reportedly working, at least on the iphone.
Runs on Linux, OS X, Cygwin, and MinGW.
Hideo Okuma and Tomoyuki Yoshimura from NICT contributed ports to ARM and MinGW.
## Decoder developers
- I recommend copying the code and distributing it with your decoder. However, please send improvements upstream.
- It's possible to compile the query-only code without Boost, but useful things like estimating models require Boost.
- Select the macros you want, listed in the previous section.
- There are two build systems: compile.sh and cmake. They're pretty simple and are intended to be reimplemented in your build system.
- Use either the interface in `lm/model.hh` or `lm/virtual_interface.hh`. Interface documentation is in comments of `lm/virtual_interface.hh` and `lm/model.hh`.
- There are several possible data structures in `model.hh`. Use `RecognizeBinary` in `binary_format.hh` to determine which one a user has provided. You probably already implement feature functions as an abstract virtual base class with several children. I suggest you co-opt this existing virtual dispatch by templatizing the language model feature implementation on the KenLM model identified by `RecognizeBinary`. This is the strategy used in Moses and cdec.
- See `lm/config.hh` for run-time tuning options.
## Contributors
Contributions to KenLM are welcome. Please base your contributions on https://github.com/kpu/kenlm and send pull requests (or I might give you commit access). Downstream copies in Moses and cdec are maintained by overwriting them so do not make changes there.
## Python module
Contributed by Victor Chahuneau.
### Installation
```bash
pip install https://github.com/kpu/kenlm/archive/master.zip
```
### Basic Usage
```python
import kenlm
model = kenlm.Model('lm/test.arpa')
print(model.score('this is a sentence .', bos = True, eos = True))
```
See [python/example.py](python/example.py) and [python/kenlm.pyx](python/kenlm.pyx) for more, including stateful APIs.
---
The name was Hieu Hoang's idea, not mine.
#!/bin/bash
rm -rf {lm,util,util/double-conversion}/*.o bin/{query,build_binary}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment