Commit 4de29dca

Moving over training data decontamination code

parent f7992789
env
*.pyc
data/
.idea
janitor.py contains a script to remove benchmark data contamination from training data sets.
It uses the approach described in the [GPT-3 paper](https://arxiv.org/abs/2005.14165).
## Algorithm
1) Collects all contamination text files that are to be removed from training data
2) Filters training data by finding `N`gram matches between the training data
   and any contamination
   1) `N`grams ignore case and punctuation and are split on whitespace.
   2) Matching `N`gram substrings are removed, as is a `window_to_remove` character window around
      the match, splitting the training data into chunks
   3) Any chunks shorter than `minimum_slice_length` characters are removed
   4) Training data sets split into more than `too_dirty_cutoff` chunks are considered
      completely contaminated and removed
OpenAI used:
```
ngram_n = 13
window_to_remove = 200
minimum_slice_length = 200
too_dirty_cutoff = 10
```
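For example, a minimal usage sketch (`benchmark_document` and `train_document` are placeholder strings; `Janitor` is defined in janitor.py):
```python
from janitor import Janitor

janitor = Janitor(ngram_n=13, window_to_remove=200,
                  minimum_slice_length=200, too_dirty_cutoff=10)

# Register benchmark text as contamination, then filter a training document
janitor.register_contaminant(benchmark_document)
clean_chunks = janitor.clean(train_document)  # [] if the document was too dirty
```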
## Compiling
Janitor can be used as a pure Python program, but it is much faster if the ngram
code is run in C++. To compile the C++ code, run
```
pip install pybind11
c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix)
```
If your compiler isn't linked against Python, you may need to append `-undefined dynamic_lookup` to the command above.
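As a quick sanity check that the extension built and imports, you can call it directly (an illustrative snippet; note the trailing space in the input, since ngrams are only emitted at whitespace boundaries):
```python
import string
import janitor_util

print(janitor_util.clean_ngram("Hello, world! Hello again. ", string.punctuation, 2))
# ['hello world', 'world hello', 'hello again']
```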
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <cctype>   // tolower
#include <utility>
#include <string>
#include <vector>
#include <tuple>
#include <queue>
bool is_whitespace(char ch) noexcept {
// " \t\n\r\x0b\x0c" (python string.whitespace)
return ch == 32 or (9 <= ch and ch <= 13);
// return ch <= 32; // arguably too general, but slightly faster
}
bool is_punctuation(char c) noexcept {
// '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~' ascii values: 33-47, 58-64, 91-96, 123-126
return (33 <= c and c <= 47) or (58 <= c and c <= 64) or (91 <= c and c <= 96) or (123 <= c and c <= 126);
}
// Takes a string and makes ngrams of length N, splitting grams on whitespace and ignoring ignored characters
// Returns a LARGE array of ngrams
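// e.g. clean_ngram("Hello, world! Hello again. ", punctuation, 2)
//      -> {"hello world", "world hello", "hello again"}
// Note: grams are only flushed when whitespace is reached, so input should
// end in whitespace or the final ngram will be dropped.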
std::vector<std::string> clean_ngram(
std::string const & input, std::string const & ignore, size_t ngram_n
) noexcept {
size_t num_grams = 0;
std::vector<std::string> ngram_list;
std::vector<uint8_t> gram_lengths;
std::string current_ngram;
// Max gram length is set to 10 below.
current_ngram.reserve(11*ngram_n);
gram_lengths.reserve(ngram_n);
bool started_gram = false;
gram_lengths.push_back(0);
//for (size_t i=0; i<input.length(); i++) {
// this is slightly faster, and we don't need the index in this one
for (auto iter = input.begin(); iter != input.end(); iter++) {
// If whitespace, end the current ngram and start the next
// alternatively, (perhaps marginally) faster: if (is_whitespace(ch)) { ... }
if (is_whitespace(*iter) || gram_lengths.back() > 10) {
// Skip all whitespace
while (++iter != input.end() && is_whitespace(*iter));
iter--;
if (started_gram){
num_grams += 1;
// Building 1grams is a special case
if (ngram_n == 1){
ngram_list.push_back(current_ngram);
current_ngram = current_ngram.substr(gram_lengths.front());
gram_lengths.back() = 0;
// If there are enough grams to form an ngram, save
} else if (num_grams >= ngram_n){
// Save the current ngram
ngram_list.push_back(current_ngram);
// Start the next ngram by dropping the first gram and its space from the ngram
current_ngram = current_ngram.substr(gram_lengths.front() + 1);
current_ngram += ' ';
// Drop the length of the first gram and prepare to record the length of the new gram
gram_lengths.erase(gram_lengths.begin());
gram_lengths.push_back(0);
                // Otherwise, continue building
} else {
current_ngram += ' ';
gram_lengths.push_back(0);
}
started_gram = false;
}
// Skip ignored characters
// alternatively, (perhaps marginally) faster: if (is_punctuation(ch)) continue;
} else if (ignore.find(*iter) != std::string::npos) {
continue;
}
// If it is a non-ignored character, add it to the ngram and update the last gram's length
else {
current_ngram += tolower(*iter);
gram_lengths.back() += 1;
started_gram = true;
}
}
return ngram_list;
}
// Takes a string and makes ngrams of length N, splitting grams on whitespace and ignoring ignored characters
// Returns a LARGE array of tuples of (ngram, start_idx, end_idx)
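// e.g. clean_ngram_with_indices("foo bar ", punctuation, 1)
//      -> {("foo", 0, 3), ("bar", 4, 7)}, where [start_idx, end_idx) spans the ngram in the input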
std::vector<std::tuple<std::string, size_t, size_t> > clean_ngram_with_indices(
std::string const & input, std::string const & ignore, size_t ngram_n
) noexcept {
size_t num_grams = 0;
std::vector<std::tuple<std::string, size_t, size_t> > ngram_list;
std::vector<uint8_t> gram_lengths;
std::vector<size_t> gram_start_indices;
std::string current_ngram;
// Max gram length is set to 10 below.
current_ngram.reserve(11*ngram_n);
bool started_gram = false;
gram_lengths.push_back(0);
gram_start_indices.push_back(0);
for (size_t i=0; i<input.length(); i++) {
char ch = input[i];
// If whitespace, end the current ngram and start the next
if (is_whitespace(ch) || gram_lengths.back() > 10) {
// Skip all whitespace
while (++i < input.length() && is_whitespace(input[i]));
i--;
if (started_gram){
num_grams += 1;
// Building 1grams is a special case
if (ngram_n == 1){
ngram_list.push_back(std::make_tuple(current_ngram, gram_start_indices.front(), i));
current_ngram = current_ngram.substr(gram_lengths.front());
gram_lengths.back() = 0;
gram_start_indices.back() = i+1;
// If there are enough grams to form an ngram, save
} else if (num_grams >= ngram_n){
// Save the current ngram
ngram_list.push_back(
std::make_tuple(current_ngram, gram_start_indices.front(), i)
);
// Start the next ngram by dropping the first gram and its space from the ngram
current_ngram = current_ngram.substr(gram_lengths.front() + 1);
current_ngram += ' ';
// Drop the length of the first gram and prepare to record the length of the new gram
gram_lengths.erase(gram_lengths.begin());
gram_lengths.push_back(0);
gram_start_indices.erase(gram_start_indices.begin());
gram_start_indices.push_back(i+1);
                // Otherwise, continue building
} else {
current_ngram += ' ';
gram_lengths.push_back(0);
gram_start_indices.push_back(i+1);
}
started_gram = false;
}
// Skip ignored characters
        } else if (ignore.find(ch) != std::string::npos) {
continue;
// If it is a non-ignored character, add it to the ngram and update the last gram's length
} else {
current_ngram += tolower(ch);
gram_lengths.back() += 1;
started_gram = true;
}
}
return ngram_list;
}
PYBIND11_MODULE(janitor_util, m) {
    m.doc() = "Fast ngram routines for training data decontamination";
m.def("clean_ngram", &clean_ngram, "Create ngrams of words, ignoring some characters");
m.def("clean_ngram_with_indices", &clean_ngram_with_indices, "Create ngrams of words with indices, ignoring some characters");
}
// Example compile
// c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix)
// If python and gcc aren't linked, append to the above: -undefined dynamic_lookup
import re
import string
import timeit
import pickle
import traceback
from pprint import pprint
# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
try:
import janitor_util
JANITOR_CPP = True
except Exception as e:
print("WARNING: C++ module could not be loaded. Janitor running in python mode")
traceback.print_exc()
JANITOR_CPP = False
# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence, n):
history = []
while n > 1:
# PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
try:
next_item = next(sequence)
except StopIteration:
# no more data, terminate the generator
return
history.append(next_item)
n -= 1
for item in sequence:
history.append(item)
yield tuple(history)
del history[0]
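# e.g. list(form_ngrams(iter(["a", "b", "c", "d"]), 2))
#      -> [("a", "b"), ("b", "c"), ("c", "d")]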
def word_ngrams(s, n):
"""Splits a string into ngram words"""
tokens = s.split() # not a generator :(
ngram_seqs = form_ngrams(iter(tokens), n)
return (" ".join(ngram) for ngram in ngram_seqs)
# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
def split_indices(s):
"""Splits a string on whitespaces and records the indices of each in the original string.
@:return generator((word, (start_idx, end_idx)), ...)
"""
return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r'\S+', s))
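# e.g. list(split_indices("hi there"))
#      -> [("hi", (0, 1)), ("there", (3, 7))]  (end_idx is inclusive)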
def word_ngrams_indices(s, n):
"""Splits a string into pairs of (ngram words, their start/end indices)"""
tokens_with_indices = split_indices(s)
# Generator of ngrams of (word, idx_pairs)
# (
# [(word, (start,end)), (word, (start, end))...],
# [(word, (start, end)), ...],
# ...
# )
ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n)
# Generator of pairs of word and index ngrams
# (
# ([word, word, ...], [(start,end), (start,end), ...]),
# ...
# )
ngram_indices_pairs = (zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices)
# Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
return ((" ".join(ngram_seq), (indices[0][0], indices[-1][1])) for ngram_seq, indices in ngram_indices_pairs)
class Janitor:
# FIXME delete_chars: Should anything else go here? Special chars?
def __init__(
self,
ngram_n=13,
window_to_remove=200,
too_dirty_cutoff=10,
minimum_slice_length=200,
delete_chars=string.punctuation
):
self.ngram_n = ngram_n
self.window_to_remove = window_to_remove
self.too_dirty_cutoff = too_dirty_cutoff
self.minimum_slice_length = minimum_slice_length
self.delete_chars = delete_chars
self.dirt_ngrams = set()
# If in python, we'll translate uppercase to lowercase and delete naughty characters.
# This is fast by python standards
# https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
self.translation_table = str.maketrans(
string.ascii_lowercase + string.ascii_uppercase, # These characters
string.ascii_lowercase * 2, # Become these characters
self.delete_chars # These are deleted
)
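        # e.g. "Hello, World!".translate(self.translation_table) -> "hello world"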
##############
# I/O for saving contamination ngrams
##############
def save_contamination_ngrams(self, filename):
with open(filename, 'wb') as fp:
            pickle.dump(self.dirt_ngrams, fp)
def load_contamination_ngrams(self, filename):
with open(filename, 'rb') as fp:
self.dirt_ngrams = pickle.load(fp)
##############
# Call these :)
##############
def register_contaminant(self, dirt_string):
"""Register a string as contamination to be removed, e.g. a test set
This breaks the dirt_string into ngrams to store for future cleaning"""
if JANITOR_CPP:
return self.register_contaminant_cpp(dirt_string)
else:
print("WARNING: Janitor running in python mode")
return self.register_contaminant_python(dirt_string)
def clean(self, dirty_string):
"""Clean a string (e.g. a training set) by removing all ngrams previously
reigstered as contaminants. Returns a list of clean chunks, or empty if
the string was too dirty"""
if JANITOR_CPP:
return self.clean_cpp(dirty_string)
else:
print("WARNING: Janitor running in python mode")
return self.clean_python(dirty_string)
    def _split_chunks(self, dirty_string, dirty_parts):
        clean_chunks = []
        splice_idx = 0
        for i, (ngram, start, end) in enumerate(dirty_parts):
            if i > self.too_dirty_cutoff:
                return []
            start = max(0, start - self.window_to_remove)
            end = min(len(dirty_string), end + self.window_to_remove)
            if start - splice_idx > self.minimum_slice_length:
                clean_chunks.append(dirty_string[splice_idx: start])
            splice_idx = end
        # Keep any clean tail after the last contaminated window
        if len(dirty_string) - splice_idx > self.minimum_slice_length:
            clean_chunks.append(dirty_string[splice_idx:])
        return clean_chunks
##############
# Fast C++
##############
def register_contaminant_cpp(self, dirt_string):
self.dirt_ngrams.update(janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n))
def clean_cpp(self, dirty_string):
contamination_indices = janitor_util.clean_ngram_with_indices(dirty_string, self.delete_chars, self.ngram_n)
return self._split_chunks(dirty_string, contamination_indices)
##############
# Slow python
##############
def normalize_string(self, s):
return s.translate(self.translation_table)
def register_contaminant_python(self, dirt_string):
self.dirt_ngrams.update(word_ngrams(self.normalize_string(dirt_string), self.ngram_n))
def clean_python(self, dirty_string):
contamination_indices = (
(None, *idx_pair)
for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
if self.normalize_string(dirty_ngram) in self.dirt_ngrams
)
return self._split_chunks(dirty_string, contamination_indices)
##################################################################
# Tests
#################################################################
def print_cpp():
source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
for i in range(1, 10, 2):
pprint(janitor_util.clean_ngram(source, string.punctuation, i))
for ngram, start, end in \
janitor_util.clean_ngram_with_indices(source, string.punctuation, i):
print(ngram, "\t", start, end, source[start:end].replace("\n", "\\n"))
def test_cpp():
source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
contaminant = "dirty boy. Clean he he"
jan_python = Janitor()
jan_cpp = Janitor()
jan_python.register_contaminant_python(contaminant)
jan_cpp.register_contaminant(contaminant)
assert jan_python.dirt_ngrams == jan_cpp.dirt_ngrams, (jan_python.dirt_ngrams, jan_cpp.dirt_ngrams)
assert jan_python.clean_python(source) == jan_cpp.clean(source), \
(jan_python.clean_python(source), jan_cpp.clean(source))
print("Passed test, python==cpp")
def benchmark():
# Download and put in data folder: enwik8 (100 MB) from https://cs.fit.edu/~mmahoney/compression/textdata.html
setup = \
"""
with open("data/enwik8", "r") as f:
data = f.read()
jan = Janitor(too_dirty_cutoff=1000)
jan.register_contaminant('''
theories is that there is a connection between &quot;geekdom&quot; and autism.
This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled &quot;
The [[Geek]] Syndrome&quot;, which is a point argued by many in the autism rights
movement{{ref|Wired}}. This article, many professionals assert, is just one example of
the media's application of mental disease labels to what is actually variant normal behavior
&amp;mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual
interests, even when they seem unusual to others, are not in themselves signs of autism or
Asperger's syndrome. Others assert that it is actually the medical profession which is applying
mental disease labels to children who in the past would have simply been accepted as a little
different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue.
Due to the recent publicity surrounding autism and autis
ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first,
oil money had a marginal impact. A few lowrise concete buildings were erected, and the first
paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties
would last, took a cautious approach, prefering to save the revenue rather than investing it in
development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential
to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his
brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]],
with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M,
''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995),
ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the
Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the
[[United Arab Emirates]]. After the Emirates gained independence in 1971,
''')
"""
n = 1
print(f"Timing {n} run on 100 MB")
print("Register contaminant")
# print("\tPython", timeit.timeit("jan.register_contaminant_python(data)", setup=setup, globals=globals(), number=n))
print("\tCpp", timeit.timeit("jan.register_contaminant(data)", setup=setup, globals=globals(), number=n))
print("Clean")
# print("\tPython", timeit.timeit("jan.clean_python(data)", setup=setup, globals=globals(), number=n))
print("\tCpp", timeit.timeit("jan.clean(data)", setup=setup, globals=globals(), number=n))
def test():
source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
contaminant = "dirty boy. Clean he he"
jan = Janitor(ngram_n=3)
jan.register_contaminant(contaminant)
cleaned = " ".join(jan.clean(source))
for contam in jan.dirt_ngrams:
assert contam not in cleaned, contam
filename = "data/saved_contam"
jan.save_contamination_ngrams(filename)
jan = Janitor(ngram_n=3)
jan.load_contamination_ngrams(filename)
cleaned = " ".join(jan.clean(source))
for contam in jan.dirt_ngrams:
assert contam not in cleaned, contam
if __name__ == "__main__":
test()
# print_cpp()
# test_cpp()
# benchmark()