Commit 1f8a8c1d authored by jon-tow

Merge branch 'master' of https://github.com/EleutherAI/lm-evaluation-harness into remove-dataset

parents b4c0275d b0acb337
""" """
Outputs all 13-grams found in The Pile. Outputs all 13-grams found in The Pile.
Loops through all documents and uses the logic found in janitor.py to extract 13-grams. Loops through all documents and uses the logic found in janitor.py to extract 13-grams.
We bucket each 13-gram by hash into separate file buckets to allow easy parallel processing in the We bucket each 13-gram by hash into separate file buckets to allow easy parallel processing in the
next stage. We also include the current pile document_id with each ngram instance to allow the next stage. We also include the current pile document_id with each ngram instance to allow the
filtering to exclude 13-grams that match more then 10 unique documents (done further down the pipeline). filtering to exclude 13-grams that match more then 10 unique documents (done further down the pipeline).
We didn't use lm_dataformat to output as it increases time 4x (slow jsonify) and makes We didn't use lm_dataformat to output as it increases time 4x (slow jsonify) and makes
...@@ -21,8 +21,10 @@ Arguments ...@@ -21,8 +21,10 @@ Arguments
""" """
import argparse
import json
import pickle
import os
import sys
from pathlib import Path
import glob
import signal
...@@ -30,32 +32,98 @@ from signal import SIGINT
from tqdm import tqdm

from lm_eval.decontamination.janitor import Janitor, word_ngrams
from lm_eval.decontamination.archiver import TextArchive, Reader

import logging
from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)

terminate = False


def handler(signal_received, frame):
    global terminate
    terminate = True


def yield_pile(start_offsets=None, checkpoint_offset=None):
    directory = "pile"

    if not os.path.exists(directory):
        print(
            "We expect the pile archives to be in the 'pile' directory, but this was not found."
        )
        raise Exception("Pile directory not found.")

    files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))

    pile_global_offset = 0
    start_file = 0
    if checkpoint_offset:
        for file_i, start_offset in enumerate(start_offsets):
            if start_offset > checkpoint_offset:
                break

            start_file = file_i
            pile_global_offset = start_offset

    for file_i, file in enumerate(files):
        if file_i < start_file:
            logger.info(f"Skipping file {file}")
            continue

        logger.info(f"Reading from pile file: {file}")
        reader = Reader()
        for document in reader.read(file):
            yield (pile_global_offset, document)
            pile_global_offset += 1
# Hash buckets > disk backed files. Supports file position checkpointing and resuming
# Allows you to write continuously and checkpoint intermittently. If a failure occurs
# the buckets are simply truncated at your last checkpoint.
class Buckets:
    def __init__(self, directory, num_buckets):
        self.bucket_files = [
            os.path.join(directory, f"ngrams_{i}.bkt.txt") for i in range(num_buckets)
        ]
        self.buckets = list(map(TextArchive, self.bucket_files))
        self.checkpoint_file = os.path.join(directory, f"bucket_offsets.ckpt")

        if os.path.exists(self.checkpoint_file):
            self.bucket_offsets = pickle.load(open(self.checkpoint_file, "rb"))
        else:
            self.bucket_offsets = [0 for i in range(len(self.buckets))]

        for i, offset in enumerate(self.bucket_offsets):
            bucket = self.buckets[i]
            bucket.fh.seek(offset)
            bucket.fh.truncate()

    def add_data(self, key, value):
        i = hash(key) % len(self.buckets)
        bucket = self.buckets[i]
        bucket.add_data(value)

    def save_checkpoint(self):
        for bucket in self.buckets:
            bucket.fh.flush()

        bucket_offsets = [bucket.fh.tell() for bucket in self.buckets]
        pickle.dump(bucket_offsets, open(self.checkpoint_file, "wb"))

    def close_buckets(self):
        for bucket in self.buckets:
            bucket.commit()
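
A rough usage sketch of the Buckets checkpoint/resume cycle (illustrative only; the directory name and values are invented). On construction, each bucket file is truncated back to its last checkpointed offset, so anything written after a crash is discarded:

buckets = Buckets("output", 4)                      # reopens and truncates bucket files
buckets.add_data("some ngram", "some ngram 12345")  # hash(key) picks the bucket
buckets.save_checkpoint()                           # flush and record current file offsets
buckets.close_buckets()                             # commit all bucket archives
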
def do_ngrams_in_buckets(n_value, working_directory, bucket_count):

    pile_statistics = json.load(open("pile_statistics.json", "r"))
    pile_document_count = pile_statistics["Document Count"]
    start_offsets = pile_statistics["File Start Offsets"]

    output_directory = os.path.join(working_directory, "output")
    os.makedirs(output_directory, exist_ok=True)
...@@ -68,58 +136,71 @@ def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
        return

    # Checkpoint
    checkpoint_file = os.path.join(working_directory, f"pile_offset.ckpt")
    if os.path.exists(checkpoint_file):
        checkpoint_offset = pickle.load(open(checkpoint_file, "rb"))
        iterate = True
    else:
        checkpoint_offset = 0
        iterate = False

    logger.info(f"Starting at pile document index {checkpoint_offset}")

    buckets = Buckets(output_directory, bucket_count)

    janitor = Janitor()
    batch_size = 1000
    batch_counter = 0

    with tqdm(total=checkpoint_offset, dynamic_ncols=True, unit="docs") as progress:
        for offset, document in yield_pile(start_offsets, checkpoint_offset):
            if iterate:
                logger.info(f"Iterating to offset {checkpoint_offset} from {offset}")
                progress.update(offset)
                iterate = False

            if offset < checkpoint_offset:
                progress.update()

                if terminate:
                    return
                continue

            if offset == checkpoint_offset:
                progress.reset(total=pile_document_count)
                progress.update(checkpoint_offset)

            # Save checkpoint every "batch_size", only allow terminate after checkpoint
            if batch_counter == batch_size:
                progress.update(batch_size)
                batch_counter = 0
                buckets.save_checkpoint()
                pickle.dump(offset, open(checkpoint_file, "wb"))
                if terminate:
                    buckets.close_buckets()
                    return

            ngrams = word_ngrams(janitor.normalize_string(document), n_value)
            for ngram in ngrams:
                buckets.add_data(ngram, f"{ngram} {offset}")

            batch_counter += 1

    buckets.close_buckets()
    Path(done_file).touch()
parser = argparse.ArgumentParser(description="Generate 13 grams from Pile.")
parser.add_argument("-dir", "--working_directory", default="")
parser.add_argument("-n", "--n_value", type=int, default=13)
parser.add_argument("-buckets", "--bucket_count", type=int, default=500)

if __name__ == "__main__":
    version = 1.00
    print(f"Running version {version}")

    if "PYTHONHASHSEED" not in os.environ or os.environ["PYTHONHASHSEED"] != "0":
        print("Please run 'export PYTHONHASHSEED=0' before running generate.")
        sys.exit()

    # Handle sigint (ctrl-c) cleanly
    previous_signal_int = signal.signal(SIGINT, handler)
...@@ -128,4 +209,8 @@ if __name__ == '__main__':
    setup_logger_tqdm(logfile_path)

    args = parser.parse_args()
    do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)

    info_dict = {"title": "dataset ngrams", "ngram_size": 13}
    info_dict_path = os.path.join(args.working_directory, "info.json")
    json.dump(info_dict, open(info_dict_path, "w"))
from lm_eval.decontamination.archiver import Reader
import os
import json
from functools import reduce
import glob
import tqdm
from tqdm_multiprocess import TqdmMultiProcessPool


def get_file_stats(file_path, tqdm_func, global_tqdm):
    reader = Reader()
    total_documents = 0
    total_size = 0
    update_frequency = 10000
    current_file_position = 0

    with tqdm_func(
        total=os.path.getsize(file_path), dynamic_ncols=True, unit="byte", unit_scale=1
    ) as progress:
        for document in reader.read(file_path, get_meta=True):
            total_size += len(document)
            total_documents += 1

            if total_documents % update_frequency == 0:
                new_file_pos = reader.fh.tell()
                bytes_read = new_file_pos - current_file_position
                current_file_position = new_file_pos
                progress.update(bytes_read)
                global_tqdm.update(bytes_read)

    return (total_documents, total_size)


def get_files():
    directory = "pile"
    files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))
    print(files)
    return files


def get_stats():
    files = get_files()
    total_size_bytes = sum(map(lambda x: os.path.getsize(x), files))

    pool = TqdmMultiProcessPool(4)
    global_tqdm = tqdm.tqdm(
        total=total_size_bytes, dynamic_ncols=True, unit="byte", unit_scale=1
    )

    # Gather per-file document counts and sizes with the pool
    tasks = [(get_file_stats, (file,)) for file in files]

    def on_done(_):
        return None

    def on_error(_):
        return None

    results = pool.map(global_tqdm, tasks, on_error, on_done)

    total_documents, total_size = reduce(
        lambda x, y: (x[0] + y[0], x[1] + y[1]), results
    )

    start_offsets = []
    current_offset = 0
    for file_document_count, _ in results:
        start_offsets.append(current_offset)
        current_offset += file_document_count

    return (total_documents, total_size, start_offsets)


if __name__ == "__main__":
    version = 1.01
    print(f"Running version {version}")

    stats_file_path = "pile_statistics.json"
    if os.path.exists(stats_file_path):
        stats = json.load(open(stats_file_path, "r"))
    else:
        document_count, total_document_size_chars, start_offsets = get_stats()
        stats = {
            "Data": "Pile statistics",
            "Document Count": document_count,
            "Total Pile Characters": total_document_size_chars,
            "File Start Offsets": start_offsets,
        }
        json.dump(stats, open(stats_file_path, "w"), indent=4)

    print(f"document_count: {stats['Document Count']}")
    print(f"total_chars: {stats['Total Pile Characters']}")
    print(f"start_offsets: {stats['File Start Offsets']}")
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <queue>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

bool is_whitespace(char ch) noexcept {
  // " \t\n\r\x0b\x0c" (python string.whitespace)
  return ch == 32 or (9 <= ch and ch <= 13);
  // return ch <= 32; // arguably too general, but slightly faster
}

bool is_punctuation(char c) noexcept {
  // '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~' ascii values: 33-47, 58-64, 91-96, 123-126
  return (33 <= c and c <= 47) or (58 <= c and c <= 64) or
         (91 <= c and c <= 96) or (123 <= c and c <= 126);
}
// Takes a string and makes ngrams of length N, splitting grams on whitespace
// and ignoring ignored characters. Returns a LARGE array of ngrams
std::vector<std::string> clean_ngram(std::string const &input,
                                     std::string const &ignore,
                                     size_t ngram_n) noexcept {

  size_t num_grams = 0;
  std::vector<std::string> ngram_list;
  std::vector<uint8_t> gram_lengths;
  std::string current_ngram;

  // Max gram length is set to 10 below.
  current_ngram.reserve(11 * ngram_n);
  gram_lengths.reserve(ngram_n);

  bool started_gram = false;
  gram_lengths.push_back(0);

  // for (size_t i=0; i<input.length(); i++) {
  // this is slightly faster, and we don't need the index in this one
  for (auto iter = input.begin(); iter != input.end(); iter++) {

    // If whitespace, end the current ngram and start the next
    // alternatively, (perhaps marginally) faster: if (is_whitespace(ch)) { ... }
    if (is_whitespace(*iter) || gram_lengths.back() > 10) {

      // Skip all whitespace
      while (++iter != input.end() && is_whitespace(*iter))
        ;
      iter--;

      if (started_gram) {
        num_grams += 1;

        // Building 1grams is a special case
        if (ngram_n == 1) {
          ngram_list.push_back(current_ngram);
          current_ngram = current_ngram.substr(gram_lengths.front());
          gram_lengths.back() = 0;

          // If there are enough grams to form an ngram, save
        } else if (num_grams >= ngram_n) {
          // Save the current ngram
          ngram_list.push_back(current_ngram);

          // Start the next ngram by dropping the first gram and its space from
          // the ngram
          current_ngram = current_ngram.substr(gram_lengths.front() + 1);
          current_ngram += ' ';

          // Drop the length of the first gram and prepare to record the length
          // of the new gram
          gram_lengths.erase(gram_lengths.begin());
          gram_lengths.push_back(0);

          // Otherwise, continue building
        } else {
          current_ngram += ' ';
          gram_lengths.push_back(0);
        }

        started_gram = false;
      }

      // Skip ignored characters
      // alternatively, (perhaps marginally) faster: if (is_punctuation(ch)) continue;
    } else if (ignore.find(*iter) != std::string::npos) {
      continue;
    }

    // If it is a non-ignored character, add it to the ngram and update the
    // last gram's length
    else {
      current_ngram += tolower(*iter);
      gram_lengths.back() += 1;
      started_gram = true;
    }
  }

  return ngram_list;
}
// Takes a string and makes ngrams of length N, splitting grams on whitespace
// and ignoring ignored characters. Returns a LARGE array of tuples of
// (ngram, start_idx, end_idx)
std::vector<std::tuple<std::string, size_t, size_t>>
clean_ngram_with_indices(std::string const &input, std::string const &ignore,
                         size_t ngram_n) noexcept {

  size_t num_grams = 0;
  std::vector<std::tuple<std::string, size_t, size_t>> ngram_list;
  std::vector<uint8_t> gram_lengths;
  std::vector<size_t> gram_start_indices;
  std::string current_ngram;

  // Max gram length is set to 10 below.
  current_ngram.reserve(11 * ngram_n);

  bool started_gram = false;
  gram_lengths.push_back(0);
  gram_start_indices.push_back(0);

  for (size_t i = 0; i < input.length(); i++) {
    char ch = input[i];

    // If whitespace, end the current ngram and start the next
    if (is_whitespace(ch) || gram_lengths.back() > 10) {

      // Skip all whitespace
      while (++i < input.length() && is_whitespace(input[i]))
        ;
      i--;

      if (started_gram) {
        num_grams += 1;

        // Building 1grams is a special case
        if (ngram_n == 1) {
          ngram_list.push_back(
              std::make_tuple(current_ngram, gram_start_indices.front(), i));
          current_ngram = current_ngram.substr(gram_lengths.front());
          gram_lengths.back() = 0;
          gram_start_indices.back() = i + 1;

          // If there are enough grams to form an ngram, save
        } else if (num_grams >= ngram_n) {

          // Save the current ngram
          ngram_list.push_back(
              std::make_tuple(current_ngram, gram_start_indices.front(), i));

          // Start the next ngram by dropping the first gram and its space from
          // the ngram
          current_ngram = current_ngram.substr(gram_lengths.front() + 1);
          current_ngram += ' ';

          // Drop the length of the first gram and prepare to record the length
          // of the new gram
          gram_lengths.erase(gram_lengths.begin());
          gram_lengths.push_back(0);
          gram_start_indices.erase(gram_start_indices.begin());
          gram_start_indices.push_back(i + 1);

          // Otherwise, continue building
        } else {
          current_ngram += ' ';
          gram_lengths.push_back(0);
          gram_start_indices.push_back(i + 1);
        }

        started_gram = false;
      }

      // Skip ignored characters
    } else if (ignore.find(ch) != std::string::npos) {
      continue;

      // If it is a non-ignored character, add it to the ngram and update the
      // last gram's length
    } else {
      current_ngram += tolower(ch);
      gram_lengths.back() += 1;
      started_gram = true;
    }
  }

  return ngram_list;
}
PYBIND11_MODULE(janitor_util, m) {
  m.doc() = "pybind11 example plugin"; // optional module docstring
  // m.def("add", &add, "A function which adds two numbers"); // example function
  m.def("clean_ngram", &clean_ngram,
        "Create ngrams of words, ignoring some characters");
  m.def("clean_ngram_with_indices", &clean_ngram_with_indices,
        "Create ngrams of words with indices, ignoring some characters");
}

// Example compile
// c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix)
// If python and gcc aren't linked, append to the above: -undefined dynamic_lookup
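
A hedged usage sketch of the resulting extension from Python (assumes janitor_util was built with the compile command above and is importable from the working directory):

import string
import janitor_util

text = "The quick brown fox, jumps over the lazy dog!"
# 2-grams of lowercased words, skipping punctuation characters
print(janitor_util.clean_ngram(text, string.punctuation, 2))
# Same grams, but as (ngram, start_idx, end_idx) tuples for slicing back into the text
print(janitor_util.clean_ngram_with_indices(text, string.punctuation, 2))
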
...@@ -27,25 +27,33 @@ from scripts.clean_training_data.archiver import TextReader, TextArchive
import logging
from tqdm_multiprocess.logger import setup_logger_tqdm

logger = logging.getLogger(__name__)


# Multiprocessed
def process_bucket(
    bucket_file_path, processed_directory, move_dir, tqdm_func, global_tqdm
):

    bucket_id = re.sub("\D", "", os.path.basename(bucket_file_path))  # noqa: W605
    done_file = os.path.join(
        processed_directory, f"ngram_bucket_processing_{bucket_id}.done"
    )
    if os.path.exists(done_file):
        logger.info(f"bucket {bucket_id} already processed, skipping")
        return

    # For managing tqdm
    file_size = os.path.getsize(bucket_file_path)
    bucket_progress = tqdm_func(
        total=file_size, dynamic_ncols=True, unit="byte", unit_scale=1
    )
    current_file_position = 0
    update_frequency = 100 * 1000000  # 100mb
    update_counter = 0

    # Iterate through and output ngrams which occur in more than 10 documents
    bucket = TextReader(bucket_file_path)

    output_file_path = bucket_file_path + ".processed"
...@@ -56,10 +64,12 @@ def process_bucket(bucket_file_path, processed_directory, move_dir, tqdm_func, g
    for line in bucket.read():
        [ngram, document_id] = line.rsplit(" ", 1)

        # Write ngram if more than 10 unique document occurrences
        if ngram != current_ngram:
            if len(current_ngram_document_ids) > 10:
                output_archive.add_data(
                    f"{current_ngram} {len(current_ngram_document_ids)}"
                )
            current_ngram = ngram
            current_ngram_document_ids = set()
...@@ -84,28 +94,38 @@ def process_bucket(bucket_file_path, processed_directory, move_dir, tqdm_func, g
    global_tqdm.update()


def process_sorted_buckets(working_directory, move_dir, process_count):
    bucket_file_paths = glob.glob(os.path.join(working_directory, f"*.bkt.txt.sorted"))

    processed_directory = os.path.join(working_directory, "processed")
    os.makedirs(processed_directory, exist_ok=True)

    pool = TqdmMultiProcessPool(process_count)
    tasks = [
        (process_bucket, (bucket_file, processed_directory, move_dir))
        for bucket_file in bucket_file_paths
    ]

    global_tqdm = tqdm(total=len(bucket_file_paths), dynamic_ncols=True, unit="bucket")

    def on_done(_):
        return None

    def on_error(_):
        return None

    _ = pool.map(global_tqdm, tasks, on_error, on_done)


parser = argparse.ArgumentParser(description="Process 13 grams from sorted buckets.")
parser.add_argument("-dir", "--working_directory", default="")
parser.add_argument("-move", "--move_dir", default="")
parser.add_argument("-procs", "--process_count", type=int, default=4)

if __name__ == "__main__":
    logfile_path = "process13grams.log"
    setup_logger_tqdm(logfile_path)

    args = parser.parse_args()
    process_sorted_buckets(args.working_directory, args.move_dir, args.process_count)
""" """
Iteratively runs gnu sort on each bucket, gnu handles the multiprocessing. Iteratively runs gnu sort on each bucket, uses up to 8 cores.
Arguments Arguments
--------- ---------
...@@ -11,48 +11,47 @@ Arguments ...@@ -11,48 +11,47 @@ Arguments
import glob import glob
import argparse import argparse
import os import os
from pathlib import Path
import signal import signal
from signal import SIGINT from signal import SIGINT
import re
import subprocess import subprocess
from tqdm import tqdm from tqdm import tqdm
import logging import logging
from tqdm_multiprocess.logger import setup_logger_tqdm from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
terminate = False terminate = False
def handler(signal_received, frame): def handler(signal_received, frame):
global terminate global terminate
terminate = True terminate = True
def sort_13_gram_buckets(working_directory): def sort_13_gram_buckets(working_directory):
bucket_file_paths = glob.glob(os.path.join(working_directory, f"*.bkt.txt")) bucket_file_paths = glob.glob(os.path.join(working_directory, f"*.bkt.txt"))
for bucket_file_path in tqdm(bucket_file_paths, dynamic_ncols=True): for bucket_file_path in tqdm(bucket_file_paths, dynamic_ncols=True):
bucket_id = re.sub("\D", "", os.path.basename(bucket_file_path))
done_file = os.path.join(working_directory, f"ngram_bucket_sorting_{bucket_id}.done")
if os.path.exists(done_file):
logger.info(f"bucket {bucket_id} already processed, skipping")
return
sorted_file_path = bucket_file_path + ".sorted" sorted_file_path = bucket_file_path + ".sorted"
command = f"sort {bucket_file_path} > {sorted_file_path}" command = f"sort {bucket_file_path} > {sorted_file_path}"
logger.info(command) logger.info(command)
subprocess.call(command, shell=True) subprocess.call(command, shell=True)
if terminate: if terminate:
return return
Path(done_file).touch()
os.remove(bucket_file_path) os.remove(bucket_file_path)
parser = argparse.ArgumentParser(description='sort 13gram buckets')
parser = argparse.ArgumentParser(description="sort 13gram buckets")
parser.add_argument("-dir", "--working_directory", default="") parser.add_argument("-dir", "--working_directory", default="")
if __name__ == '__main__': if __name__ == "__main__":
version = 1.00
print(f"Running version {version}")
# Handle sigint (ctrl-c) cleanly # Handle sigint (ctrl-c) cleanly
previous_signal_int = signal.signal(SIGINT, handler) previous_signal_int = signal.signal(SIGINT, handler)
...@@ -61,4 +60,4 @@ if __name__ == '__main__': ...@@ -61,4 +60,4 @@ if __name__ == '__main__':
setup_logger_tqdm(logfile_path) setup_logger_tqdm(logfile_path)
args = parser.parse_args() args = parser.parse_args()
sort_13_gram_buckets(args.working_directory) sort_13_gram_buckets(args.working_directory)
\ No newline at end of file
...@@ -7,7 +7,7 @@ from lm_eval.base import LM
class DryrunLM(LM):
    def __init__(self):
        self.tokencost = 0
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
        self.tokenizer.pad_token = "<|endoftext|>"

    @classmethod
...@@ -16,16 +16,16 @@ class DryrunLM(LM):
    def loglikelihood(self, requests):
        res = []

        for ctx, cont in requests:
            res.append((-random.random(), False))
            self.tokencost += len(self.tokenizer.tokenize(ctx + cont))

        return res

    def greedy_until(self, requests):
        res = []

        for ctx, until in requests:
            res.append("lol")
...@@ -33,11 +33,11 @@ class DryrunLM(LM):
            self.tokencost += len(self.tokenizer.tokenize(ctx)) + 256

        return res

    def loglikelihood_rolling(self, requests):
        res = []

        for (s,) in requests:
            # assume worst case: extra full context
            self.tokencost += len(self.tokenizer.tokenize(s)) + 2048
...@@ -46,7 +46,7 @@ class DryrunLM(LM):
def main():
    lm = DryrunLM()

    task_list = "arc_challenge,arc_easy,boolq,cola,copa,headqa,hellaswag,lambada,logiqa,mathqa,mc_taco,mrpc,multirc,openbookqa,piqa,prost,pubmedqa,qnli,qqp,race,record,rte,sciq,sst,triviaqa,webqs,wic,wikitext,winogrande,wnli,wsc"
    values = []
    for taskname in task_list.split(","):
...@@ -57,11 +57,20 @@ def main():
            num_fewshot=0,
            limit=None,
            bootstrap_iters=10,
            description_dict=None,
        )

        print(taskname, lm.tokencost)
        values.append(
            [
                taskname,
                lm.tokencost,
                lm.tokencost / 1000 * 0.0008,
                lm.tokencost / 1000 * 0.0012,
                lm.tokencost / 1000 * 0.006,
                lm.tokencost / 1000 * 0.06,
            ]
        )
    from pytablewriter import MarkdownTableWriter

    writer = MarkdownTableWriter()
...@@ -69,10 +78,21 @@ def main():
    values.sort(key=lambda x: -x[1])
    totcost = sum([x[1] for x in values])
    values.append(
        [
            "**Total**",
            totcost,
            totcost / 1000 * 0.0008,
            totcost / 1000 * 0.0012,
            totcost / 1000 * 0.006,
            totcost / 1000 * 0.06,
        ]
    )
    writer.value_matrix = values

    print(writer.dumps())


if __name__ == "__main__":
    main()
...@@ -3,16 +3,21 @@ from itertools import islice
ct = 3

for (
    tname,
    Task,
) in tasks.TASK_REGISTRY.items():  # [('record', tasks.superglue.ReCoRD)]:#
    task = Task()
    print("#", tname)
    docs = islice(
        task.validation_docs() if task.has_validation_docs() else task.test_docs(), ct
    )
    print()
    for i in range(ct):
        print()
        doc = next(docs)
        print("**Context**:", "\n```\n" + task.doc_to_text(doc) + "\n```\n")
        print()
        print("**Target**:", "\n```\n" + task.doc_to_target(doc) + "\n```\n")
        print()
...@@ -10,7 +10,7 @@ random.seed(42)
data = [
    "A multilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)",
    "The term MLP is used ambiguously, sometimes loosely to any feedforward ANN, sometimes strictly to refer to networks composed of multiple layers of perceptrons (with threshold activation); see § Terminology",
    'Multilayer perceptrons are sometimes colloquially referred to as "vanilla" neural networks, especially when they have a single hidden layer.[1]',
    "An MLP consists of at least three layers of nodes: an input layer, a hidden layer and an output layer. Except for the input nodes, each node is a neuron that uses a nonlinear activation function.",
    "MLP utilizes a supervised learning technique called backpropagation for training.[2][3] Its multiple layers and non-linear activation distinguish MLP from a linear perceptron. It can distinguish data that is not linearly separable.[4]",
    "Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by pre-training on a large corpus of text followed by fine-tuning on a specific task. While typically task-agnostic in architecture, this method still requires task-specific fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans can generally perform a new language task from only a few examples or from simple instructions - something which current NLP systems still largely struggle to do. Here we show that scaling up language models greatly improves task-agnostic, few-shot performance, sometimes even reaching competitiveness with prior state-of-the-art fine-tuning approaches. ",
...@@ -20,22 +20,28 @@ data = [
]

model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
tok = transformers.GPT2Tokenizer.from_pretrained("gpt2")

tgs = []

for dat in data:
    random.seed(dat)
    # print(model(tok.encode(dat, return_tensors="pt"))[0][0])
    toks = tok.encode(dat, return_tensors="pt")
    ind = random.randrange(len(toks[0]) - 1)
    logits = F.log_softmax(model(toks)[0], dim=-1)[:, :-1]  # [batch, seq, vocab]
    res = torch.gather(logits, 2, toks[:, 1:].unsqueeze(-1)).squeeze(-1)[0]

    tgs.append(float(res[ind:].sum()))
    print(
        r'("""'
        + tok.decode(toks[0, : ind + 1])
        + r'""", """'
        + tok.decode(toks[0, ind + 1 :])
        + r'"""), '
    )

print(tgs)
...@@ -2,23 +2,32 @@ from lm_eval import tasks
from pytablewriter import MarkdownTableWriter

writer = MarkdownTableWriter()
writer.headers = ["Task Name", "Train", "Val", "Test", "Val/Test Docs", "Metrics"]

values = []


def chk(tf):
    if tf:
        return "✓"
    else:
        return " "


for tname, Task in tasks.TASK_REGISTRY.items():
    task = Task()
    v = [
        tname,
        chk(task.has_training_docs()),
        chk(task.has_validation_docs()),
        chk(task.has_test_docs()),
        len(list(task.test_docs() if task.has_test_docs() else task.validation_docs())),
        ", ".join(task.aggregation().keys()),
    ]
    print(v)
    values.append(v)
writer.value_matrix = values
print(writer.dumps())
...@@ -11,14 +11,14 @@ EXAMPLE_DIVIDER = "!!@@##@@!! -- Example {i}\n"
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_base_path", required=True)
    parser.add_argument("--tasks", default="all_tasks")
    parser.add_argument("--provide_description", action="store_true")
    parser.add_argument("--sets", type=str, default="val")  # example: val,test
    parser.add_argument("--num_fewshot", type=int, default=1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_examples", type=int, default=1)
    parser.add_argument("--description_dict_path", default=None)
    return parser.parse_args()
...@@ -34,7 +34,7 @@ def main():
    description_dict = {}
    if args.description_dict_path:
        with open(args.description_dict_path, "r") as f:
            description_dict = json.load(f)

    os.makedirs(args.output_base_path, exist_ok=True)
...@@ -45,26 +45,34 @@ def main():
        iters = []

        for set in args.sets.split(","):
            if set == "train" and task.has_training_docs():
                docs = task.training_docs()
            if set == "val" and task.has_validation_docs():
                docs = task.validation_docs()
            if set == "test" and task.has_test_docs():
                docs = task.test_docs()
            iters.append(docs)

        docs = join_iters(iters)

        description = (
            description_dict[task_name]
            if description_dict and task_name in description_dict
            else ""
        )

        with open(os.path.join(args.output_base_path, task_name), "w") as f:
            for i, doc in (
                zip(range(args.num_examples), docs)
                if args.num_examples > 0
                else enumerate(docs)
            ):
                f.write(EXAMPLE_DIVIDER.format(i=i))
                ctx = task.fewshot_context(
                    doc=doc,
                    num_fewshot=args.num_fewshot,
                    rnd=rnd,
                    description=description,
                )
                f.write(ctx + "\n")
......
...@@ -18,9 +18,9 @@ setuptools.setup(
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    python_requires=">=3.6",
    install_requires=[
        "datasets>=2.0.0",
        "click>=7.1",
        "scikit-learn>=0.24.1",
        "torch>=1.7",
...@@ -40,10 +40,10 @@ setuptools.setup(
        "openai==0.6.4",
        "jieba==0.42.1",
        "nagisa==0.2.7",
        "bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
    ],
    dependency_links=[
        "https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt",
    ],
    extras_require={"dev": ["pytest", "black", "pre-commit"]},
)
# TODO: Remove all TODO comments once the implementation is complete.
"""
TODO: Add the Paper Title on this line.
TODO: Add the paper's PDF URL (preferably from arXiv) on this line.
TODO: Write a Short Description of the task.
......
# TODO: Remove all TODO comments once the implementation is complete.
"""
TODO: Add the Paper Title on this line.
TODO: Add the paper's PDF URL (preferably from arXiv) on this line.
TODO: Write a Short Description of the task.
...@@ -45,7 +45,7 @@ class NewTask(Task):
        if self._training_docs is None:
            # TODO: Return the training document generator from `self.dataset`.
            # If you need to process the data, `map` over the documents with
            # the custom processing function, `self._process_doc`. E.g.
            # `map(self._process_doc, self.dataset["validation"])`
            # In most cases you can leave this as is unless the dataset split is
            # named differently than the default `"train"`.
...@@ -56,7 +56,7 @@ class NewTask(Task):
        if self.has_validation_docs():
            # TODO: Return the validation document generator from `self.dataset`.
            # If you need to process the data, `map` over the documents with the
            # custom processing function, `self._process_doc`. E.g.
            # `map(self._process_doc, self.dataset["validation"])`
            # In most cases you can leave this as is unless the dataset split is
            # named differently than the default `"validation"`.
......
...@@ -10,23 +10,24 @@ import pytest
# TODO: more fine grained unit tests rather than this big honking integration
# test once we break evaluator into smaller, more manageable pieces


@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_evaluator(taskname, task_class):
    task_dict = tasks.get_task_dict([taskname])

    os.system("rm test_cache.db")
    lm = base.CachingLM(models.get_model("dummy")(), "test_cache.db")

    def ll_fn(reqs):
        for ctx, cont in reqs:
            if len(ctx) == 0:
                continue
            # space convention
            assert ctx[-1] != " "
            assert cont[0] == " " or ctx[-1] == "\n"

        res = []

        random.seed(42)
        for _ in reqs:
            res.append((-random.random(), False))
...@@ -34,7 +35,7 @@ def test_evaluator(taskname, task_class):
        return res

    def ll_perp_fn(reqs):
        for (string,) in reqs:
            assert isinstance(string, str)

        res = []
...@@ -49,20 +50,20 @@ def test_evaluator(taskname, task_class):
    limit = 10
    e1 = evaluator.evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=0,
        limit=limit,
        bootstrap_iters=10,
        description_dict=None,
    )
    e2 = evaluator.evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=0,
        limit=limit,
        bootstrap_iters=10,
        description_dict=None,
    )

    # check that caching is working
......
...@@ -3,27 +3,32 @@ from collections import Counter
import shutil
import glob

from lm_eval.decontamination.janitor import Janitor, word_ngrams
from scripts.clean_training_data.generate_13_grams import do_ngrams_in_buckets
from lm_eval.decontamination.archiver import Archive, TextReader

import logging

logger = logging.getLogger(__name__)


def test_generate_13_grams_1(caplog):
    data = """A goose (plural geese) is a bird of any of several waterfowl species in the family Anatidae.
This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese).
Some other birds, mostly related to the shelducks, have "goose" as part of their names.
More distantly related members of the family Anatidae are swans, most of which are larger
than true geese, and ducks, which are smaller. The term "goose" may refer to either a male
or female bird, but when paired with "gander", refers specifically to a female one (the latter referring
to a male). Young birds before fledging are called goslings. The collective noun for a group of
geese on the ground is a gaggle; when in flight, they are called a skein, a team, or a wedge; when
flying close together, they are called a plump."""

    data = data + data

    # Simple Generation
    print("simple generation")
    n = 13
    janitor = Janitor()
    ngrams = word_ngrams(janitor.normalize_string(data), n)
    comparison = list(ngrams)
    comparison_counter = Counter(comparison)
...@@ -31,35 +36,42 @@ def test_generate_13_grams_1():
    # print(comparison)

    # Generating into buckets
    print("bucket generation")
    test_working_directory = "test_generate_13_grams"
    try:
        shutil.rmtree(test_working_directory)
    except FileNotFoundError:
        pass
    os.makedirs(test_working_directory)

    assert not os.path.exists("pile")
    os.makedirs("pile")
    archive = Archive(os.path.join("pile", "test.jsonl.zst"))
    archive.add_data(data)
    archive.commit()

    bucket_count = 4
    do_ngrams_in_buckets(n, test_working_directory, bucket_count)

    # Rebuild from buckets
    print("rebuild")
    rebuilt_ngrams = []

    bucket_file_paths = glob.glob(
        os.path.join(test_working_directory, "output", f"*.bkt.txt")
    )
    for bucket_file_path in bucket_file_paths:
        reader = TextReader(bucket_file_path)
        for line in reader.read():
            [ngram, document_id] = line.rsplit(" ", 1)
            rebuilt_ngrams.append(ngram)

    # Compare
    print("compare")
    result_counter = Counter(rebuilt_ngrams)
    # print(len(result_counter))
    # print(len(comparison_counter))
    assert len(result_counter) == len(comparison_counter)
    # print(result_counter)
    # print(comparison_counter)
    assert comparison_counter == result_counter
@@ -12,40 +12,78 @@ def mock_completion(**kwargs):
    # Mock completion function
    # Loads from a cached+pickled response if it exists, otherwise it will actually try to ping
    os.makedirs("tests/testdata", exist_ok=True)

    hash = hashlib.sha256(
        json.dumps(kwargs, sort_keys=True).encode("utf-8")
    ).hexdigest()

    fname = f"tests/testdata/gpt3_test_{hash}.pkl"

    if os.path.exists(fname):
        with open(fname, "rb") as fh:
            return pickle.load(fh)
    ret = openai.Completion.create(**kwargs)
    ret.api_key = ""
    with open(fname, "wb") as fh:
        pickle.dump(ret, fh)
    return ret
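mock_completion keys its on-disk cache by a SHA-256 hash of the request kwargs, so the OpenAI-backed tests replay recorded responses instead of hitting the network on every run. A generic version of that record-and-replay pattern might look like the sketch below; cached_call and the cache directory name are illustrative assumptions, not part of the test suite.

# Generic record-and-replay sketch of the caching pattern used by mock_completion;
# cached_call is illustrative, not repository code.
import hashlib
import json
import os
import pickle


def cached_call(expensive_call, cache_dir="tests/testdata"):
    os.makedirs(cache_dir, exist_ok=True)

    def wrapper(**kwargs):
        key = hashlib.sha256(json.dumps(kwargs, sort_keys=True).encode("utf-8")).hexdigest()
        path = os.path.join(cache_dir, f"cache_{key}.pkl")
        if os.path.exists(path):
            with open(path, "rb") as fh:
                return pickle.load(fh)  # replay a previously recorded response
        result = expensive_call(**kwargs)
        with open(path, "wb") as fh:
            pickle.dump(result, fh)  # record the response for the next run
        return result

    return wrapper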
@mock.patch("lm_eval.models.gpt3.oa_completion", new=mock_completion) @mock.patch("lm_eval.models.gpt3.oa_completion", new=mock_completion)
def test_gpt3(): def test_gpt3():
if "OPENAI_API_SECRET_KEY" not in os.environ: os.environ["OPENAI_API_SECRET_KEY"] = "" if "OPENAI_API_SECRET_KEY" not in os.environ:
gpt3 = models.get_model('gpt3').create_from_arg_string("engine=ada") os.environ["OPENAI_API_SECRET_KEY"] = ""
(ll_dog, ig_dog), (ll_cat, ig_cat), (_, ll_max_0), (_, ll_max_1), (_, ll_max_2), *vals = gpt3.loglikelihood([ gpt3 = models.get_model("gpt3").create_from_arg_string("engine=ada")
('The quick brown fox jumps over the lazy', ' dog'), (
('The quick brown fox jumps over the lazy', ' cat'), (ll_dog, ig_dog),
('The quick brown fox jumps over the lazy', ', lazy dog'), (ll_cat, ig_cat),
('The quick brown fox jumps over the lazy', ', lazy fox'), (_, ll_max_0),
('The quick brown fox jumps over the lazy', ', lazy fox and they both fall to the ground'), (_, ll_max_1),
(_, ll_max_2),
("""A mult""", """ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)"""), *vals,
("""The term MLP is used ambiguously, sometimes loosely to any feedforward ANN, sometimes strictly to refer to networks composed of multiple layers of perceptrons""", """ (with threshold activation); see § Terminology"""), ) = gpt3.loglikelihood(
("""Multilayer perceptrons are sometimes coll""", """oquially referred to as "vanilla" neural networks, especially when they have a single hidden layer.[1]"""), [
("""An MLP consists of at least three layers of nodes: an input layer, a hidden layer and an output layer. Except for the input nodes, each node is a neuron that uses a nonlinear""", """ activation function."""), ("The quick brown fox jumps over the lazy", " dog"),
("""MLP utilizes a supervised""", """ learning technique called backpropagation for training.[2][3] Its multiple layers and non-linear activation distinguish MLP from a linear perceptron. It can distinguish data that is not linearly separable.[4]"""), ("The quick brown fox jumps over the lazy", " cat"),
("""Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by pre-training on a large corpus of text followed by fine-tuning on a specific task. While typically task-agnostic""", """ in architecture, this method still requires task-specific fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans can generally perform a new language task from only a few examples or from simple instructions - something which current NLP systems still largely struggle to do. Here we show that scaling up language models greatly improves task-agnostic, few-shot performance, sometimes even reaching competitiveness with prior state-of-the-art fine-tuning approaches. """), ("The quick brown fox jumps over the lazy", ", lazy dog"),
("""Specifically, we train GPT-3, an autoregressive language model with 175""", """ billion parameters, 10x more than any previous non-sparse language model, and test its performance in the few-shot setting. For all tasks, GPT-3 is applied without any gradient updates or fine-tuning, with tasks and few-shot demonstrations specified purely via text interaction with the model. GPT-3 achieves strong performance on many NLP datasets, including translation, question-answering, and cloze tasks, as well as several tasks that require on-the-fly reasoning or domain adaptation, such as unscrambling words, using a novel word in a sentence, or performing 3-digit arithmetic. At the same time, we also identify some datasets where GPT-3's few-shot learning still struggles, as well as some datasets where GPT-3 faces methodological issues related to training on large web corpora. Finally, we find that GPT-3 can generate samples of news articles which human evaluators have difficulty distinguishing from articles written by humans. We discuss broader societal impacts of this finding and of GPT-3 in general."""), ("The quick brown fox jumps over the lazy", ", lazy fox"),
("""A mult""", """ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)"""), (
("""Hello""", """ World"""), "The quick brown fox jumps over the lazy",
]) ", lazy fox and they both fall to the ground",
),
(
"""A mult""",
"""ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)""",
),
(
"""The term MLP is used ambiguously, sometimes loosely to any feedforward ANN, sometimes strictly to refer to networks composed of multiple layers of perceptrons""",
""" (with threshold activation); see § Terminology""",
),
(
"""Multilayer perceptrons are sometimes coll""",
"""oquially referred to as "vanilla" neural networks, especially when they have a single hidden layer.[1]""",
),
(
"""An MLP consists of at least three layers of nodes: an input layer, a hidden layer and an output layer. Except for the input nodes, each node is a neuron that uses a nonlinear""",
""" activation function.""",
),
(
"""MLP utilizes a supervised""",
""" learning technique called backpropagation for training.[2][3] Its multiple layers and non-linear activation distinguish MLP from a linear perceptron. It can distinguish data that is not linearly separable.[4]""",
),
(
"""Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by pre-training on a large corpus of text followed by fine-tuning on a specific task. While typically task-agnostic""",
""" in architecture, this method still requires task-specific fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans can generally perform a new language task from only a few examples or from simple instructions - something which current NLP systems still largely struggle to do. Here we show that scaling up language models greatly improves task-agnostic, few-shot performance, sometimes even reaching competitiveness with prior state-of-the-art fine-tuning approaches. """,
),
(
"""Specifically, we train GPT-3, an autoregressive language model with 175""",
""" billion parameters, 10x more than any previous non-sparse language model, and test its performance in the few-shot setting. For all tasks, GPT-3 is applied without any gradient updates or fine-tuning, with tasks and few-shot demonstrations specified purely via text interaction with the model. GPT-3 achieves strong performance on many NLP datasets, including translation, question-answering, and cloze tasks, as well as several tasks that require on-the-fly reasoning or domain adaptation, such as unscrambling words, using a novel word in a sentence, or performing 3-digit arithmetic. At the same time, we also identify some datasets where GPT-3's few-shot learning still struggles, as well as some datasets where GPT-3 faces methodological issues related to training on large web corpora. Finally, we find that GPT-3 can generate samples of news articles which human evaluators have difficulty distinguishing from articles written by humans. We discuss broader societal impacts of this finding and of GPT-3 in general.""",
),
(
"""A mult""",
"""ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)""",
),
("""Hello""", """ World"""),
]
)
assert ll_dog > ll_cat assert ll_dog > ll_cat
assert not ig_cat assert not ig_cat
@@ -56,19 +94,26 @@ def test_gpt3():
    assert not ll_max_2

    # test empty context
    gpt3.loglikelihood([("", "test")])

    (gen,) = gpt3.greedy_until(
        [("The quick brown fox jumps over the lazy", [".", "\n"])]
    )

    assert gen == " dog"

    print([x[0] for x in vals])

    targets = [
        -34.848301606999996,
        -47.148329679999996,
        -45.44380149599999,
        -5.285246016,
        -133.97821690686004,
        -321.2616693239001,
        -658.0299524401041,
        -34.848301606999996,
        -7.525115,
    ]

    for (pred, _), tgt in zip(vals, targets):
@@ -77,17 +122,20 @@ def test_gpt3():
@mock.patch("lm_eval.models.gpt3.oa_completion", new=mock_completion)
def test_gpt3_perplexity():
    if "OPENAI_API_SECRET_KEY" not in os.environ:
        os.environ["OPENAI_API_SECRET_KEY"] = ""
    gpt3 = models.get_model("gpt3").create_from_arg_string("engine=ada")
    test_string = "We study empirical scaling laws for language model performance on the cross-entropy loss."
    perplexity = gpt3.loglikelihood_rolling([(test_string,)])[0]
    tgt = -84.38819608
    assert perplexity == pytest.approx(tgt, rel=1e-3)

    # Hack: modify gpt3 to have shorter context length to induce rolling windows
    with mock.patch.object(
        models.gpt3.GPT3LM, "max_length", new_callable=mock.PropertyMock
    ) as mock_max_length:
        mock_max_length.return_value = 5
        gpt3 = models.get_model("gpt3").create_from_arg_string("engine=ada")
        perplexity = gpt3.loglikelihood_rolling([(test_string,)])[0]
        tgt = -101.81967209999999
        assert perplexity == pytest.approx(tgt, rel=1e-3)
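The max_length patch above is what forces rolling evaluation: loglikelihood_rolling scores a full string with no held-out context by splitting it into windows no longer than the model's maximum length and conditioning each window on the tokens just before it, so shrinking max_length to 5 produces several windows even for one short sentence. A toy illustration of that windowing follows; the helper is an assumption for illustration, not lm_eval's actual implementation.

# Toy sketch of rolling windows; not lm_eval's actual implementation.
def rolling_windows(tokens, max_len):
    """Yield (context, continuation) pairs covering `tokens` left to right."""
    windows = []
    start = 0
    while start < len(tokens):
        end = min(start + max_len, len(tokens))
        context = tokens[max(0, start - max_len):start]  # what the model conditions on
        windows.append((context, tokens[start:end]))
        start = end
    return windows


# With max_len=5, a 12-token sentence is scored in 3 chunks:
# rolling_windows(list(range(12)), 5) ->
# [([], [0, 1, 2, 3, 4]), ([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]), ([5, 6, 7, 8, 9], [10, 11])]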
import re
from collections import defaultdict

from lm_eval.decontamination.janitor import (
    Janitor,
    form_ngrams,
    word_ngrams,
    split_indices,
    word_ngrams_indices,
)


def simple_ngram(sequence, n):
    ngrams = list()
@@ -16,8 +23,10 @@ def simple_ngram(sequence, n):
def test_form_ngrams():
    sequence = (
        "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
        " more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
    )

    n_values = [1, 2, 3, 5, 13]
    for n in n_values:
@@ -26,9 +35,12 @@ def test_form_ngrams():
        assert len(comparison) == len(result_to_test)
        assert comparison == result_to_test


def test_word_ngrams():
    sequence = (
        "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
        " more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
    )

    words = sequence.split()
@@ -40,9 +52,12 @@ def test_word_ngrams():
        assert len(comparison) == len(result_to_test)
        assert result_to_test == comparison


def test_split_indices():
    sequence = (
        "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
        " more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
    )

    comparison = []
    current_word = ""
@@ -55,17 +70,22 @@ def test_split_indices():
            current_word = ""

    if current_word:
        comparison.append(
            (current_word, (len(sequence) - len(current_word), len(sequence) - 1))
        )
        current_word = ""

    result_to_test = list(split_indices(sequence))
    assert len(comparison) == len(result_to_test)
    assert comparison == result_to_test


def test_word_ngrams_indices():
    sequence = (
        "Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
        " more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
    )

    n_values = [1, 2, 3, 5, 13]
@@ -76,55 +96,62 @@ def test_word_ngrams_indices():
        for ngram in ngrams:
            while True:
                start = sequence.find(ngram, tracker[ngram])
                assert start != -1  # testing the test
                end = start + len(ngram) - 1
                tracker[ngram] = end + 1

                # ignore partial word matches
                if (start != 0 and sequence[start - 1] != " ") or (
                    end != len(sequence) - 1 and sequence[end + 1] != " "
                ):
                    pass
                else:
                    break

            comparison.append((ngram, (start, end)))

        result_to_test = list(word_ngrams_indices(sequence, n))
        assert len(result_to_test) == len(comparison)
        assert result_to_test == comparison


# Assumptions from GPT3 Paper:
# the 200 characters to remove include punctuation and is actually a half-window
# All tests below initially test without any registered contaminants, expecting the same sequence back.
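The janitor tests below exercise that half-window behaviour: each registered contaminant match removes roughly window_to_remove characters on both sides of the match, surviving slices are kept only if they are at least minimum_slice_length characters long, and a document with more than too_dirty_cutoff matches is dropped altogether. A character-level simplification of that rule is sketched here; the real Janitor matches normalized word n-grams, so the exact cut points in the tests differ from this toy version.

# Character-level simplification of the half-window rule; the real Janitor
# works on normalized word n-grams, so exact boundaries differ.
def clean_around_matches(text, match_spans, window=200, min_slice=200):
    """match_spans: sorted-able (start, end) character spans of contaminant matches."""
    slices = []
    cursor = 0
    for start, end in sorted(match_spans):
        left = text[cursor:max(cursor, start - window)]  # keep text up to the left half-window
        if len(left) >= min_slice:
            slices.append(left)
        cursor = end + window  # skip the match plus the right half-window
    tail = text[cursor:]
    if len(tail) >= min_slice:
        slices.append(tail)
    return slices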
def test_janitor1():
    # First test using a 1gram and expected the first block before the filth to have some remaining
    # characters, but the second block should be completely removed.
    sequence = (
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "FILTH. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
    )

    filth = "filth"

    expected_result = (
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing "
    )

    janitor = Janitor(
        ngram_n=1, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
    )
    result = janitor.clean_python(sequence)
    result = "".join(result)
    assert result == sequence
@@ -133,42 +160,47 @@ def test_janitor1():
    assert janitor.dirt_ngrams == {filth}
    result = janitor.clean_python(sequence)
    result = "".join(result)
    assert result == expected_result


def test_janitor2():
    # Second test using a 1gram and expecting the first block before the filth to have some remaining
    # characters; the second block is longer than 200 characters so it should also have some remaining.
    sequence = (
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "FILTH. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
    )

    filth = "filth"

    expected_result = (
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing "
        " characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
    )

    janitor = Janitor(
        ngram_n=1, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
    )
    result = janitor.clean_python(sequence)
    result = "".join(result)
    assert result == sequence
@@ -180,37 +212,43 @@ def test_janitor2():
    result = "".join(result)
    assert result == expected_result


def test_janitor3():
    # Same test as above but with a 6gram.
    sequence = (
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "FILTH. lots of dirty filtHy FIlTh "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
    )

    filth = "filth lots of dirty filthy filth"

    expected_result = (
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing "
        " characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
    )

    janitor = Janitor(
        ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
    )
    result = janitor.clean_python(sequence)
    result = "".join(result)
    assert result == sequence
@@ -222,45 +260,51 @@ def test_janitor3():
    result = "".join(result)
    assert result == expected_result


def test_janitor4():
    # This test adds another block to that from the previous. The middle block should be entirely
    # removed as the 200 characters are removed from each side.
    sequence = (
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "FILTH. lots of dirty filtHy FIlTh "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "FILTH. lots of dirty filtHy FIlTh "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
    )

    filth = "filth lots of dirty filthy filth"

    expected_result = (
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing "
        " characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
    )

    janitor = Janitor(
        ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
    )
    result = janitor.clean_python(sequence)
    result = "".join(result)
    assert result == sequence
@@ -272,49 +316,55 @@ def test_janitor4():
    result = "".join(result)
    assert result == expected_result


def test_janitor5():
    # Same as above but using multiple different filth 6grams.
    sequence = (
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "FILTH. lots of dirty filtHy FIlTh "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "FILTH. lots of filtHy dirty FIlTh "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
    )

    filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"]

    expected_result = (
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing "
        " characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
    )

    janitor = Janitor(
        ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
    )
    result = janitor.clean_python(sequence)
    result = "".join(result)
    assert result == sequence

    for filth in filths:
        janitor.register_contaminant(filth)
    assert janitor.dirt_ngrams == set(filths)
@@ -322,57 +372,63 @@ def test_janitor5():
    result = "".join(result)
    assert result == expected_result


def test_janitor6():
    # Same as above but now we add 10 filths and expect the same result, the following test does 11.
    sequence = (
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "FILTH. lots of dirty filtHy FIlTh "
        "FILTH. lots of dirty filtHy FIlTh "
        "FILTH. lots of dirty filtHy FIlTh "
        "FILTH. lots of dirty filtHy FIlTh "
        "FILTH. lots of dirty filtHy FIlTh "
        "FILTH. lots of dirty filtHy FIlTh "
        "FILTH. lots of dirty filtHy FIlTh "
        "FILTH. lots of dirty filtHy FIlTh "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "FILTH. lots of filtHy dirty FIlTh "
        "FILTH. lots of filtHy dirty FIlTh "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
    )

    filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"]

    expected_result = (
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing "
        " characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
    )

    janitor = Janitor(
        ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
    )
    result = janitor.clean_python(sequence)
    result = "".join(result)
    assert result == sequence

    for filth in filths:
        janitor.register_contaminant(filth)
    assert janitor.dirt_ngrams == set(filths)
@@ -380,51 +436,55 @@ def test_janitor6():
    result = "".join(result)
    assert result == expected_result


def test_janitor7():
    # Same as above but now with 11 filths, which exceeds the too_dirty_cutoff of 10,
    # so the entire document is expected to be removed.
    sequence = (
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "FILTH. lots of dirty filtHy FIlTh "
        "FILTH. lots of dirty filtHy FIlTh "
        "FILTH. lots of dirty filtHy FIlTh "
        "FILTH. lots of dirty filtHy FIlTh "
        "FILTH. lots of dirty filtHy FIlTh "
        "FILTH. lots of dirty filtHy FIlTh "
        "FILTH. lots of dirty filtHy FIlTh "
        "FILTH. lots of dirty filtHy FIlTh "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "FILTH. lots of filtHy dirty FIlTh "
        "FILTH. lots of filtHy dirty FIlTh "
        "FILTH. lots of filtHy dirty FIlTh "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
        "This is a @line #containing a certain number of characters, 76 to be exact. "
    )

    filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"]

    expected_result = ""

    janitor = Janitor(
        ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
    )
    result = janitor.clean_python(sequence)
    result = "".join(result)
    assert result == sequence

    for filth in filths:
        janitor.register_contaminant(filth)
    assert janitor.dirt_ngrams == set(filths)
@@ -453,23 +513,3 @@ def test_janitor8():
    # cleaned = " ".join(jan.clean(source))
    # for contam in jan.dirt_ngrams:
    #     assert contam not in cleaned, contam
@@ -4,24 +4,59 @@ import lm_eval.models as models
def test_gpt2():
    gpt2 = models.get_model("gpt2").create_from_arg_string("device=cpu")
    (
        (ll_dog, ig_dog),
        (ll_cat, ig_cat),
        (_, ll_max_0),
        (_, ll_max_1),
        (_, ll_max_2),
        *vals,
    ) = gpt2.loglikelihood(
        [
            ("The quick brown fox jumps over the lazy", " dog"),
            ("The quick brown fox jumps over the lazy", " cat"),
            ("The quick brown fox jumps over the lazy", ", lazy dog"),
            ("The quick brown fox jumps over the lazy", ", lazy fox"),
            (
                "The quick brown fox jumps over the lazy",
                ", lazy fox and they both fall to the ground",
            ),
            (
                """A mult""",
                """ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)""",
            ),
            (
                """The term MLP is used ambiguously, sometimes loosely to any feedforward ANN, sometimes strictly to refer to networks composed of multiple layers of perceptrons""",
                """ (with threshold activation); see § Terminology""",
            ),
            (
                """Multilayer perceptrons are sometimes coll""",
                """oquially referred to as "vanilla" neural networks, especially when they have a single hidden layer.[1]""",
            ),
            (
                """An MLP consists of at least three layers of nodes: an input layer, a hidden layer and an output layer. Except for the input nodes, each node is a neuron that uses a nonlinear""",
                """ activation function.""",
            ),
            (
                """MLP utilizes a supervised""",
                """ learning technique called backpropagation for training.[2][3] Its multiple layers and non-linear activation distinguish MLP from a linear perceptron. It can distinguish data that is not linearly separable.[4]""",
            ),
            (
                """Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by pre-training on a large corpus of text followed by fine-tuning on a specific task. While typically task-agnostic""",
                """ in architecture, this method still requires task-specific fine-tuning datasets of thousands or tens of thousands of examples. By contrast, humans can generally perform a new language task from only a few examples or from simple instructions - something which current NLP systems still largely struggle to do. Here we show that scaling up language models greatly improves task-agnostic, few-shot performance, sometimes even reaching competitiveness with prior state-of-the-art fine-tuning approaches. """,
            ),
            (
                """Specifically, we train GPT-3, an autoregressive language model with 175""",
                """ billion parameters, 10x more than any previous non-sparse language model, and test its performance in the few-shot setting. For all tasks, GPT-3 is applied without any gradient updates or fine-tuning, with tasks and few-shot demonstrations specified purely via text interaction with the model. GPT-3 achieves strong performance on many NLP datasets, including translation, question-answering, and cloze tasks, as well as several tasks that require on-the-fly reasoning or domain adaptation, such as unscrambling words, using a novel word in a sentence, or performing 3-digit arithmetic. At the same time, we also identify some datasets where GPT-3's few-shot learning still struggles, as well as some datasets where GPT-3 faces methodological issues related to training on large web corpora. Finally, we find that GPT-3 can generate samples of news articles which human evaluators have difficulty distinguishing from articles written by humans. We discuss broader societal impacts of this finding and of GPT-3 in general.""",
            ),
            (
                """A mult""",
                """ilayer perceptron (MLP) is a class of feedforward artificial neural network (ANN)""",
            ),
            ("""Hello""", """ World"""),
        ]
    )

    assert ll_dog > ll_cat
    assert not ig_cat
@@ -31,17 +66,24 @@ def test_gpt2():
    assert ll_max_2

    # test empty context
    gpt2.loglikelihood([("", "test")])

    (gen,) = gpt2.greedy_until(
        [("The quick brown fox jumps over the lazy", [".", "\n"])]
    )

    assert gen == ", lazy fox and they both fall to the ground"

    targets = [
        -61.60536193847656,
        -56.57843780517578,
        -62.131004333496094,
        -9.799489974975586,
        -153.96334838867188,
        -341.222900390625,
        -731.1475830078125,
        -61.60536193847656,
        -8.682319641113281,
    ]

    for (pred, _), tgt in zip(vals, targets):
...@@ -49,21 +91,57 @@ def test_gpt2(): ...@@ -49,21 +91,57 @@ def test_gpt2():
def test_gpt2_perplexity(): def test_gpt2_perplexity():
gpt2 = models.get_model('gpt2').create_from_arg_string("device=cpu") gpt2 = models.get_model("gpt2").create_from_arg_string("device=cpu")
test_string = "We study empirical scaling laws for language model performance on the cross-entropy loss." test_string = "We study empirical scaling laws for language model performance on the cross-entropy loss."
perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0] perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0]
tgt = sum([ tgt = sum(
-4.9599953, -8.069298, -8.308624, -10.178513, -8.906924, -1.9318912, -7.745445, -7.146077, -5.2072, [
-3.5882986, -1.9957212, -8.044922, -0.20841774, -5.1096807, -0.099879116, -8.888423, -4.6180487, -4.9599953,
]) -8.069298,
-8.308624,
-10.178513,
-8.906924,
-1.9318912,
-7.745445,
-7.146077,
-5.2072,
-3.5882986,
-1.9957212,
-8.044922,
-0.20841774,
-5.1096807,
-0.099879116,
-8.888423,
-4.6180487,
]
)
assert perplexity == pytest.approx(tgt, rel=1e-3) assert perplexity == pytest.approx(tgt, rel=1e-3)
with mock.patch.object(models.gpt2.HFLM, 'max_length', new_callable=mock.PropertyMock) as mock_max_length: with mock.patch.object(
models.gpt2.HFLM, "max_length", new_callable=mock.PropertyMock
) as mock_max_length:
mock_max_length.return_value = 5 mock_max_length.return_value = 5
gpt2 = models.get_model('gpt2').create_from_arg_string("device=cpu") gpt2 = models.get_model("gpt2").create_from_arg_string("device=cpu")
perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0] perplexity = gpt2.loglikelihood_rolling([(test_string,)])[0]
tgt = sum([ tgt = sum(
-4.96001, -8.069275, -8.308612, -10.178482, -8.90691, -4.037338, -8.09261, -11.662385, -10.206891, [
-4.425003, -2.2563353, -7.909143, -1.9304147, -7.3610134, -2.3120654, -7.3229, -2.1643813, -4.96001,
]) -8.069275,
-8.308612,
-10.178482,
-8.90691,
-4.037338,
-8.09261,
-11.662385,
-10.206891,
-4.425003,
-2.2563353,
-7.909143,
-1.9304147,
-7.3610134,
-2.3120654,
-7.3229,
-2.1643813,
]
)
assert perplexity == pytest.approx(tgt, rel=1e-3) assert perplexity == pytest.approx(tgt, rel=1e-3)
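Note: although the variable is named perplexity, loglikelihood_rolling here returns the summed per-token log-probabilities of the whole string, and the hard-coded targets above are those per-token values. A small sketch of that relationship, assuming this convention, with a conventional perplexity derived from the same numbers for reference:

# Hedged sketch: the test's target is the sum of per-token log-probabilities;
# a conventional perplexity would be exp of the negative mean of the same values.
import math

token_logprobs = [
    -4.9599953, -8.069298, -8.308624, -10.178513, -8.906924, -1.9318912,
    -7.745445, -7.146077, -5.2072, -3.5882986, -1.9957212, -8.044922,
    -0.20841774, -5.1096807, -0.099879116, -8.888423, -4.6180487,
]
total_loglikelihood = sum(token_logprobs)  # the scalar compared via pytest.approx
perplexity = math.exp(-total_loglikelihood / len(token_logprobs))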
...@@ -6,7 +6,7 @@ from itertools import islice ...@@ -6,7 +6,7 @@ from itertools import islice
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items()) @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_basic_interface(taskname, task_class): def test_basic_interface(taskname, task_class):
print('Evaluating task', taskname) print("Evaluating task", taskname)
# dl = task_class.download # dl = task_class.download
# task_class.download = MagicMock() # task_class.download = MagicMock()
task = task_class() task = task_class()
...@@ -42,7 +42,7 @@ def test_basic_interface(taskname, task_class): ...@@ -42,7 +42,7 @@ def test_basic_interface(taskname, task_class):
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr] reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2] reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2 assert reqs == reqs2
if task.has_test_docs(): if task.has_test_docs():
...@@ -53,7 +53,7 @@ def test_basic_interface(taskname, task_class): ...@@ -53,7 +53,7 @@ def test_basic_interface(taskname, task_class):
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr] reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2] reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2 assert reqs == reqs2
if task.has_training_docs(): if task.has_training_docs():
...@@ -64,13 +64,13 @@ def test_basic_interface(taskname, task_class): ...@@ -64,13 +64,13 @@ def test_basic_interface(taskname, task_class):
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr] reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2] reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2 assert reqs == reqs2
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items()) @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_documents_and_requests(taskname, task_class): def test_documents_and_requests(taskname, task_class):
print('Evaluating task', taskname) print("Evaluating task", taskname)
task = task_class() task = task_class()
fns = [] fns = []
if task.has_training_docs(): if task.has_training_docs():
...@@ -83,21 +83,21 @@ def test_documents_and_requests(taskname, task_class): ...@@ -83,21 +83,21 @@ def test_documents_and_requests(taskname, task_class):
for fn in fns: for fn in fns:
# print(list(islice(fn(), 10))) # print(list(islice(fn(), 10)))
for doc in islice(fn(), 10): for doc in islice(fn(), 10):
txt = task.doc_to_text(doc) txt = task.doc_to_text(doc)
tgt = task.doc_to_target(doc) tgt = task.doc_to_target(doc)
assert isinstance(txt, str) assert isinstance(txt, str)
assert isinstance(tgt, str) assert isinstance(tgt, str)
# space convention # space convention
# allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on # allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
if len(txt) != 0: if len(txt) != 0:
assert txt[-1] != ' ' assert txt[-1] != " "
assert tgt[0] == ' ' or txt[-1] == '\n' assert tgt[0] == " " or txt[-1] == "\n"
reqs = task.construct_requests(doc, txt) reqs = task.construct_requests(doc, txt)
# construct_requests can return just one request # construct_requests can return just one request
if not isinstance(reqs, (list, tuple)): if not isinstance(reqs, (list, tuple)):
reqs = [reqs] reqs = [reqs]
......
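Note: the final checks in this hunk encode the harness's whitespace convention: doc_to_text must not end with a space, and doc_to_target must start with one (unless the text ends with a newline), so that concatenating the two yields a well-formed prompt. A short illustration with made-up doc_to_text / doc_to_target outputs:

# Hedged illustration of the space convention (the strings are invented):
txt = "Question: What is the capital of France?\nAnswer:"
tgt = " Paris"

if len(txt) != 0:
    assert txt[-1] != " "                    # no trailing space on the prompt
    assert tgt[0] == " " or txt[-1] == "\n"  # the target carries the separator
prompt = txt + tgt                           # concatenation stays well-formed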
...@@ -5,8 +5,14 @@ from lm_eval.utils import get_rolling_token_windows, make_disjoint_window ...@@ -5,8 +5,14 @@ from lm_eval.utils import get_rolling_token_windows, make_disjoint_window
def test_get_rolling_token_windows_v1(): def test_get_rolling_token_windows_v1():
gold = [ gold = [
([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
([9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]), (
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]), [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
),
(
[19, 20, 21, 22, 23, 24, 25, 26, 27, 28],
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
),
([23, 24, 25, 26, 27, 28, 29, 30, 31, 32], [30, 31, 32, 33]), ([23, 24, 25, 26, 27, 28, 29, 30, 31, 32], [30, 31, 32, 33]),
] ]
x = list(range(34)) x = list(range(34))
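Note: the gold list above describes a 34-token sequence split into rolling windows of at most ten scored tokens each, with -100 standing in as the prefix token for the first window and the final short window reusing earlier tokens as extra context. A usage sketch, with keyword arguments assumed from the shape of the gold windows and from the (truncated) get_rolling_token_windows calls later in this file:

# Hedged usage sketch; max_seq_len / context_len values are assumptions,
# not visible in this hunk.
x = list(range(34))
windows = list(
    get_rolling_token_windows(
        token_list=x,        # the full token sequence
        prefix_token=-100,   # placeholder context for the first token
        max_seq_len=10,      # at most 10 tokens scored per window
        context_len=1,       # minimum context carried into each new window
    )
)
assert windows == gold                                   # (input_tokens, pred_tokens) pairs
assert sum(len(pred) for _, pred in windows) == len(x)   # every token scored exactly once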
...@@ -123,7 +129,6 @@ def test_get_rolling_token_windows_v4(): ...@@ -123,7 +129,6 @@ def test_get_rolling_token_windows_v4():
([17, 18, 19, 20, 21, 22, 23, 24, 25, 26], [27]), ([17, 18, 19, 20, 21, 22, 23, 24, 25, 26], [27]),
([18, 19, 20, 21, 22, 23, 24, 25, 26, 27], [28]), ([18, 19, 20, 21, 22, 23, 24, 25, 26, 27], [28]),
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [29]), ([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [29]),
] ]
x = list(range(30)) x = list(range(30))
generator = get_rolling_token_windows( generator = get_rolling_token_windows(
...@@ -145,8 +150,14 @@ def test_get_rolling_token_windows_v4(): ...@@ -145,8 +150,14 @@ def test_get_rolling_token_windows_v4():
def test_get_rolling_token_windows_v5(): def test_get_rolling_token_windows_v5():
gold = [ gold = [
([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
([9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]), (
([19, 20, 21, 22, 23, 24, 25, 26, 27, 28], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29]), [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
),
(
[19, 20, 21, 22, 23, 24, 25, 26, 27, 28],
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
),
] ]
x = list(range(30)) x = list(range(30))
generator = get_rolling_token_windows( generator = get_rolling_token_windows(
...@@ -203,5 +214,8 @@ def test_get_rolling_token_windows_empty(): ...@@ -203,5 +214,8 @@ def test_get_rolling_token_windows_empty():
def test_make_disjoint_window(): def test_make_disjoint_window():
assert make_disjoint_window(([1,2,3,4,5], [2,3,4,5,6])) == ([1], [2,3,4,5,6]) assert make_disjoint_window(([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])) == (
assert make_disjoint_window(([1,2,3,4,5], [4,5,6])) == ([1,2,3], [4,5,6]) [1],
\ No newline at end of file [2, 3, 4, 5, 6],
)
assert make_disjoint_window(([1, 2, 3, 4, 5], [4, 5, 6])) == ([1, 2, 3], [4, 5, 6])
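Note: both assertions show make_disjoint_window trimming the input window so it stops just before the first predicted token, guaranteeing no token is both conditioned on and scored within the same window. A sketch of logic consistent with these two cases (not necessarily the exact lm_eval.utils implementation):

# Hedged sketch consistent with the assertions above; the real implementation
# in lm_eval.utils may differ in detail.
def make_disjoint_window_sketch(pair):
    input_tokens, pred_tokens = pair
    keep = len(input_tokens) - len(pred_tokens) + 1  # tokens strictly before the prediction span
    return input_tokens[:keep], pred_tokens

assert make_disjoint_window_sketch(([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])) == ([1], [2, 3, 4, 5, 6])
assert make_disjoint_window_sketch(([1, 2, 3, 4, 5], [4, 5, 6])) == ([1, 2, 3], [4, 5, 6])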