Unverified Commit e00d682f authored by Jonathan Tow, committed by GitHub

Merge pull request #261 from EleutherAI/researcher2

Update CLI options and introduce decontamination
parents eb8163e9 ab6883b1
......@@ -69,6 +69,12 @@ class SQuAD2(Task):
def doc_to_text(self, doc):
return 'Title: ' + doc['title'] + '\n\n' + 'Background: ' + doc['context'] + '\n\n' + 'Question: ' + doc['question'] + '\n\n' + 'Answer:'
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc['context']
def doc_to_target(self, doc):
answer_list = doc['answers']['text']
if len(answer_list) > 0:
......
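Every task touched by this merge gains the same two-method decontamination hook: should_decontaminate opts the task in, and doc_to_decontamination_query returns the text that gets checked for training-set overlap. A minimal sketch of how a caller could gather those queries, assuming (as in the harness tasks) that validation_docs() yields dict-like documents; the helper name is illustrative, not part of the PR:

def collect_decontamination_queries(task):
    # Sketch only: collect the per-document strings a task exposes for overlap
    # checking, or nothing at all if the task opts out.
    if not task.should_decontaminate():
        return []
    return [task.doc_to_decontamination_query(doc) for doc in task.validation_docs()]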
......@@ -72,6 +72,17 @@ class StoryCloze(Task):
doc["input_sentence_4"],
])
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return ' '.join([
doc["input_sentence_1"],
doc["input_sentence_2"],
doc["input_sentence_3"],
doc["input_sentence_4"],
])
def doc_to_target(self, doc):
clozes = [doc["sentence_quiz1"], doc["sentence_quiz2"]]
# `- 1` because the `answer_right_ending` index is 1-based.
......
......@@ -56,6 +56,12 @@ class BoolQ(Task):
def doc_to_text(self, doc):
return f"{doc['passage']}\nQuestion: {doc['question']}?\nAnswer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc['passage']
def doc_to_target(self, doc):
return " " + yesno(doc['label'])
......
......@@ -128,6 +128,12 @@ class GeneralTranslationTask(Task):
tar_lang = code_to_language(language_codes[1])
return f"{src_lang} phrase: " + doc["src"] + f"\n{tar_lang} phrase:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["src"]
def doc_to_target(self, doc):
# This shows a single target, though there may be multiple targets in a lang test
return " " + doc["ref"] if isinstance(doc["ref"], str) else doc["ref"][0]
......
......@@ -54,6 +54,12 @@ class TriviaQA(Task):
def doc_to_text(self, doc):
return f"Question: {doc['question']}\nAnswer:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc['question']
def doc_to_target(self, doc):
return " " + doc['answer']['value']
......
......@@ -84,6 +84,12 @@ class TruthfulQAMultipleChoice(Task):
def doc_to_text(self, doc):
return QA_PROMPT + "\n\nQ: " + doc['question'] + "\nA:"
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc['question']
def doc_to_target(self, doc):
return " "
......
......@@ -49,6 +49,12 @@ class WordUnscrambleTask(Task):
def doc_to_text(self, doc):
return doc["context"]
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["context"]
def doc_to_target(self, doc):
return doc["completion"]
......
......@@ -56,6 +56,12 @@ class WebQs(Task):
def doc_to_text(self, doc):
return "Question: " + doc['question'] + '\nAnswer:'
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc['question']
def doc_to_target(self, doc):
# this picks one answer to be the "correct" one, despite sometimes
# multiple correct answers being possible.
......
......@@ -90,6 +90,9 @@ class WikiText(PerplexityTask):
def doc_to_target(self, doc):
return wikitext_detokenizer(doc)
def should_decontaminate(self):
return True
def count_words(self, doc):
# count number of words in *original doc before detokenization*
return len(re.split(r"\s+", doc))
......@@ -56,6 +56,12 @@ class Winogrande(Task):
def doc_to_text(self, doc):
return self.partial_context(doc, doc["option" + doc["answer"]])
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["sentence"]
@classmethod
def partial_context(cls, doc, option):
# Substitute the pronoun in the sentence with the specified option
......
......@@ -85,6 +85,12 @@ class WinogradSchemaChallenge273(Task):
def doc_to_text(self, doc):
return self.partial_context(doc, doc["options"][doc["label"]])
def should_decontaminate(self):
return True
def doc_to_decontamination_query(self, doc):
return doc["text"]
@classmethod
def partial_context(cls, doc, option):
# Substitute the pronoun in the original text with the specified
......
import argparse
import json
import logging
import fnmatch
from lm_eval import tasks, evaluator
logging.getLogger("openai").setLevel(logging.WARNING)
class MultiChoice:
def __init__(self, choices):
self.choices = choices
# Simple wildcard support (linux filename patterns)
def __contains__(self, values):
for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0:
return False
return True
def __iter__(self):
for choice in self.choices:
yield choice
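MultiChoice can be passed directly as an argparse choices object because argparse only needs membership and iteration, which __contains__ and __iter__ provide; a brief sketch with illustrative task names, not the real registry:

# Sketch, not part of main.py.
demo_parser = argparse.ArgumentParser()
demo_parser.add_argument('--tasks', choices=MultiChoice(["lambada", "wikitext", "squad2"]))
demo_args = demo_parser.parse_args(["--tasks", "lambada,wiki*"])  # accepted: each comma-separated pattern matches a choice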
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True)
parser.add_argument('--model_args', default="")
parser.add_argument('--tasks', default="all_tasks")
parser.add_argument('--tasks', default=None, choices=MultiChoice(tasks.ALL_TASKS))
parser.add_argument('--provide_description', action="store_true")
parser.add_argument('--num_fewshot', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=None)
......@@ -19,22 +35,35 @@ def parse_args():
parser.add_argument('--output_path', default=None)
parser.add_argument('--limit', type=int, default=None)
parser.add_argument('--no_cache', action="store_true")
parser.add_argument('--description_dict_path', default=None)
parser.add_argument('--check_integrity', action="store_true")
parser.add_argument('--decontamination_ngrams_path', default=None)
parser.add_argument('--description_dict_path', default=None)
parser.add_argument('--check_integrity', action="store_true")
return parser.parse_args()
# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
task_names = set()
for pattern in patterns:
for matching in fnmatch.filter(source_list, pattern):
task_names.add(matching)
return list(task_names)
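With a toy registry, the wildcard expansion behaves as follows (names are illustrative):

# Sketch, not part of main.py: expanding --tasks wildcards into concrete names.
demo_registry = ["squad2", "triviaqa", "wmt20-en-de", "wmt20-de-en", "wikitext"]
print(sorted(pattern_match("wmt20*,wikitext".split(","), demo_registry)))
# ['wikitext', 'wmt20-de-en', 'wmt20-en-de']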
def main():
args = parse_args()
assert not args.provide_description # not implemented
if args.limit:
print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
if args.tasks == "all_tasks":
if args.tasks is None:
task_names = tasks.ALL_TASKS
else:
task_names = args.tasks.split(",")
task_names = pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
print(f"Selected Tasks: {task_names}")
description_dict = {}
if args.description_dict_path:
......@@ -51,11 +80,11 @@ def main():
no_cache=args.no_cache,
limit=args.limit,
description_dict=description_dict,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity
)
dumped = json.dumps(results, indent=2)
print(dumped)
if args.output_path:
......
{
"Data": "Pile statistics",
"Document Count": 210607728,
"Total Pile Characters": 421215456,
"File Start Offsets": [
0,
7021438,
14042822,
21066113,
28086515,
35106072,
42123306,
49145091,
56165817,
63185587,
70211208,
77234322,
84249267,
91267634,
98285983,
105305110,
112322489,
119342491,
126367373,
133389153,
140412039,
147432373,
154452516,
161470190,
168492733,
175512521,
182526939,
189547478,
196565318,
203583306
]
}
\ No newline at end of file
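The File Start Offsets above record the global index of each archive's first document. A sketch of how such offsets can be mapped back to a file index (offsets abbreviated, helper name hypothetical):

import bisect

start_offsets = [0, 7021438, 14042822]  # first document index of each *.jsonl.zst file

def file_index_for(document_index):
    # The last file whose first document index is <= the queried index.
    return bisect.bisect_right(start_offsets, document_index) - 1

print(file_index_for(7021437))  # 0 (still inside the first file)
print(file_index_for(7021438))  # 1 (first document of the second file)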
import glob
import argparse
import os
import subprocess
import shutil
from tqdm import tqdm
from tqdm_multiprocess import TqdmMultiProcessPool
import logging
from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)
def process_task(working_directory, output_directory, bucket_file_path, tqdm_func, global_tqdm):
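    # zstd writes "<bucket>.zst" alongside the input and keeps the original file
    # by default, hence the explicit move and os.remove below.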
command = f"zstd {bucket_file_path}"
logger.info(command)
subprocess.call(command, shell=True)
compressed_file = bucket_file_path + ".zst"
if output_directory:
shutil.move(compressed_file, output_directory)
os.remove(bucket_file_path)
global_tqdm.update()
def compress_and_move(working_directory, output_directory, process_count):
os.makedirs(output_directory, exist_ok=True)
original_info_file_path = os.path.join(working_directory, "info.json")
assert(os.path.exists(original_info_file_path))
tasks = []
bucket_file_paths = glob.glob(os.path.join(working_directory, "output", f"*.bkt.txt.sorted"))
for bucket_file_path in bucket_file_paths:
task = (process_task, (working_directory, output_directory, bucket_file_path))
tasks.append(task)
pool = TqdmMultiProcessPool(process_count)
on_done = lambda _ : None
on_error = lambda _ : None
global_progress = tqdm(total=len(bucket_file_paths), dynamic_ncols=True, unit="file")
_ = pool.map(global_progress, tasks, on_error, on_done)
shutil.copy(original_info_file_path, os.path.join(output_directory, "info.json"))
parser = argparse.ArgumentParser(description='Compress and package sorted 13-gram buckets')
parser.add_argument("-dir", "--working_directory", required=True)
parser.add_argument("-output", "--output_directory", required=True)
parser.add_argument("-procs", "--process_count", type=int, default=8)
if __name__ == '__main__':
version = 1.00
print(f"Running version {version}")
logfile_path = "compress_and_package.log"
setup_logger_tqdm(logfile_path)
args = parser.parse_args()
compress_and_move(args.working_directory, args.output_directory, args.process_count)
\ No newline at end of file
......@@ -21,8 +21,10 @@ Arguments
"""
import argparse
import json
import pickle
import os
import sys
from pathlib import Path
import glob
import signal
......@@ -30,32 +32,89 @@ from signal import SIGINT
from tqdm import tqdm
from scripts.clean_training_data.janitor import Janitor, word_ngrams
from scripts.clean_training_data.archiver import TextArchive, Reader
from lm_eval.decontamination.janitor import Janitor, word_ngrams
from lm_eval.decontamination.archiver import TextArchive, Reader
import logging
from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)
pile_document_count = 210607728
terminate = False
def handler(signal_received, frame):
global terminate
terminate = True
def get_pile(directory):
reader = Reader()
for file in glob.glob(os.path.join(directory, f"*.jsonl.zst*")):
for document in reader.read(file):
yield document
def close_buckets(buckets):
    for bucket in buckets:
        bucket.commit()
def yield_pile(start_offsets=None, checkpoint_offset=None):
    directory = "pile"
    if not os.path.exists(directory):
        print("We expect the pile archives to be in the 'pile' directory, but this was not found.")
        raise Exception("Pile directory not found.")
    files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))
    pile_global_offset = 0
    start_file = 0
    if checkpoint_offset:
        for file_i, start_offset in enumerate(start_offsets):
            if start_offset > checkpoint_offset:
                break
            start_file = file_i
            pile_global_offset = start_offset
for file_i, file in enumerate(files):
if file_i < start_file:
logger.info(f"Skipping file {file}")
continue
logger.info(f"Reading from pile file: {file}")
reader = Reader()
for document in reader.read(file):
yield (pile_global_offset, document)
pile_global_offset += 1
# Hash buckets backed by disk files. Supports file-position checkpointing and resuming.
# Allows you to write continuously and checkpoint intermittently; if a failure occurs,
# the buckets are simply truncated back to your last checkpoint.
class Buckets:
def __init__(self, directory, num_buckets):
self.bucket_files = [os.path.join(directory, f"ngrams_{i}.bkt.txt") for i in range(num_buckets)]
self.buckets = list(map(TextArchive, self.bucket_files))
self.checkpoint_file = os.path.join(directory, f"bucket_offsets.ckpt")
if os.path.exists(self.checkpoint_file):
self.bucket_offsets = pickle.load(open(self.checkpoint_file, "rb"))
else:
self.bucket_offsets = [0 for i in range(len(self.buckets))]
for i, offset in enumerate(self.bucket_offsets):
bucket = self.buckets[i]
bucket.fh.seek(offset)
bucket.fh.truncate()
def add_data(self, key, value):
i = hash(key) % len(self.buckets)
bucket = self.buckets[i]
bucket.add_data(value)
def save_checkpoint(self):
for bucket in self.buckets:
bucket.fh.flush()
bucket_offsets = [bucket.fh.tell() for bucket in self.buckets]
pickle.dump(bucket_offsets, open(self.checkpoint_file, "wb"))
def close_buckets(self):
for bucket in self.buckets:
bucket.commit()
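A brief usage sketch of the Buckets helper above, assuming TextArchive.add_data appends a line and commit finalizes the file (as the archiver interface suggests):

# Sketch, not part of the PR; assumes the "output" directory already exists.
demo_buckets = Buckets("output", num_buckets=4)
demo_buckets.add_data("some thirteen gram here", "some thirteen gram here 12345")
demo_buckets.save_checkpoint()  # record file positions; a later crash truncates back to here
demo_buckets.close_buckets()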
def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
pile_statistics = json.load(open("pile_statistics.json", "r"))
pile_document_count = pile_statistics["Document Count"]
start_offsets = pile_statistics["File Start Offsets"]
output_directory = os.path.join(working_directory, "output")
os.makedirs(output_directory, exist_ok=True)
......@@ -68,49 +127,56 @@ def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
return
# Checkpoint
checkpoint_file = os.path.join(output_directory, f"ngram_buckets.ckpt")
checkpoint_file = os.path.join(working_directory, f"pile_offset.ckpt")
if os.path.exists(checkpoint_file):
start_id = pickle.load(open(checkpoint_file,"rb"))
checkpoint_offset = pickle.load(open(checkpoint_file,"rb"))
iterate = True
else:
start_id = 0
checkpoint_offset = 0
iterate = False
logger.info(f"Starting at pile document index {start_id}")
bucket_files = [os.path.join(output_directory, f"ngrams_{i}.bkt.txt") for i in range(bucket_count)]
buckets = list(map(TextArchive, bucket_files))
logger.info(f"Starting at pile document index {checkpoint_offset}")
buckets = Buckets(output_directory, bucket_count)
janitor = Janitor()
current_id = 0
batch_size = 1000
batch_counter = 0
with tqdm(total=pile_document_count, dynamic_ncols=True, unit="docs") as progress:
for document in get_pile(working_directory):
if current_id < start_id:
if terminate:
close_buckets(buckets)
return
current_id += 1
with tqdm(total=checkpoint_offset, dynamic_ncols=True, unit="docs") as progress:
for offset, document in yield_pile(start_offsets, checkpoint_offset):
if iterate:
logger.info(f"Iterating to offset {checkpoint_offset} from {offset}")
progress.update(offset)
iterate = False
if offset < checkpoint_offset:
progress.update()
if terminate:
return
continue
if offset == checkpoint_offset:
progress.reset(total=pile_document_count)
progress.update(checkpoint_offset)
# Save checkpoint every "batch_size", only allow terminate after checkpoint
if batch_counter == batch_size:
progress.update(batch_size)
batch_counter = 0
pickle.dump(current_id, open(checkpoint_file,"wb"))
buckets.save_checkpoint()
pickle.dump(offset, open(checkpoint_file,"wb"))
if terminate:
close_buckets(buckets)
buckets.close_buckets()
return
ngrams = word_ngrams(janitor.normalize_string(document), n_value)
for ngram in ngrams:
bucket = hash(ngram) % len(buckets)
buckets[bucket].add_data(f"{ngram} {current_id}")
buckets.add_data(ngram, f"{ngram} {offset}")
batch_counter += 1
current_id += 1
close_buckets(buckets)
buckets.close_buckets()
Path(done_file).touch()
......@@ -120,6 +186,12 @@ parser.add_argument("-n", "--n_value", type=int, default=13)
parser.add_argument("-buckets", "--bucket_count", type=int, default=500)
if __name__ == '__main__':
version = 1.00
print(f"Running version {version}")
if "PYTHONHASHSEED" not in os.environ or os.environ["PYTHONHASHSEED"] != "0":
print("Please run 'export PYTHONHASHSEED=0' before running generate.")
sys.exit()
# Handle sigint (ctrl-c) cleanly
previous_signal_int = signal.signal(SIGINT, handler)
......@@ -128,4 +200,8 @@ if __name__ == '__main__':
setup_logger_tqdm(logfile_path)
args = parser.parse_args()
do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)
\ No newline at end of file
do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)
info_dict = {"title": "dataset ngrams", "ngram_size": 13}
info_dict_path = os.path.join(args.working_directory, "info.json")
json.dump(info_dict, open(info_dict_path, "w"))
\ No newline at end of file
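Each bucket line written above pairs a normalized 13-gram with the global document offset it came from. A sketch of how a downstream consumer might split such a line, assuming the offset is always the final space-separated token (the helper is hypothetical):

def parse_bucket_line(line):
    # "the quick brown ... lazy dog 12345" -> ("the quick brown ... lazy dog", 12345)
    ngram, offset = line.rstrip("\n").rsplit(" ", 1)
    return ngram, int(offset)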
from lm_eval.decontamination.archiver import Reader
import os
import json
from functools import reduce
import glob
import tqdm
from tqdm_multiprocess import TqdmMultiProcessPool
def get_file_stats(file_path, tqdm_func, global_tqdm):
reader = Reader()
total_documents = 0
total_size = 0
update_frequency = 10000
current_file_position = 0
with tqdm_func(total=os.path.getsize(file_path), dynamic_ncols=True, unit="byte", unit_scale=1) as progress:
for document in reader.read(file_path, get_meta=True):
total_size += len(document)
total_documents += 1
if total_documents % update_frequency == 0:
new_file_pos = reader.fh.tell()
bytes_read = new_file_pos - current_file_position
current_file_position = new_file_pos
progress.update(bytes_read)
global_tqdm.update(bytes_read)
return (total_documents, total_size)
def get_files():
directory = "pile"
files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))
print(files)
return files
def get_stats():
files = get_files()
total_size_bytes = sum(map(lambda x: os.path.getsize(x), files))
pool = TqdmMultiProcessPool(4)
global_tqdm = tqdm.tqdm(total=total_size_bytes, dynamic_ncols=True, unit="byte", unit_scale=1)
# Gather per-file document counts and sizes with the pool
tasks = [(get_file_stats, (file,)) for file in files]
on_done = lambda _ : None
on_error = lambda _ : None
results = pool.map(global_tqdm, tasks, on_error, on_done)
total_documents, total_size = reduce(lambda x, y: (x[0]+y[0],x[1]+y[1]), results)
start_offsets = []
current_offset = 0
for file_document_count, _ in results:
start_offsets.append(current_offset)
current_offset += file_document_count
return (total_documents, total_size, start_offsets)
if __name__ == '__main__':
version = 1.01
print(f"Running version {version}")
stats_file_path = "pile_statistics.json"
if os.path.exists(stats_file_path):
stats = json.load(open(stats_file_path, "r"))
else:
document_count, total_document_size_chars, start_offsets = get_stats()
stats = {"Data": "Pile statistics",
"Document Count": document_count,
"Total Pile Characters": total_document_size_chars,
"File Start Offsets": start_offsets
}
json.dump(stats, open(stats_file_path, "w"), indent=4)
print(f"document_count: {stats['Document Count']}")
print(f"total_chars: {stats['Total Pile Characters']}")
print(f"start_offsets: {stats['File Start Offsets']}")
"""
Iteratively runs gnu sort on each bucket, gnu handles the multiprocessing.
Iteratively runs gnu sort on each bucket, uses up to 8 cores.
Arguments
---------
......@@ -11,10 +11,8 @@ Arguments
import glob
import argparse
import os
from pathlib import Path
import signal
from signal import SIGINT
import re
import subprocess
from tqdm import tqdm
......@@ -32,12 +30,6 @@ def sort_13_gram_buckets(working_directory):
bucket_file_paths = glob.glob(os.path.join(working_directory, f"*.bkt.txt"))
for bucket_file_path in tqdm(bucket_file_paths, dynamic_ncols=True):
bucket_id = re.sub("\D", "", os.path.basename(bucket_file_path))
done_file = os.path.join(working_directory, f"ngram_bucket_sorting_{bucket_id}.done")
if os.path.exists(done_file):
logger.info(f"bucket {bucket_id} already processed, skipping")
return
sorted_file_path = bucket_file_path + ".sorted"
command = f"sort {bucket_file_path} > {sorted_file_path}"
logger.info(command)
......@@ -46,7 +38,6 @@ def sort_13_gram_buckets(working_directory):
if terminate:
return
Path(done_file).touch()
os.remove(bucket_file_path)
parser = argparse.ArgumentParser(description='sort 13gram buckets')
......@@ -54,6 +45,9 @@ parser.add_argument("-dir", "--working_directory", default="")
if __name__ == '__main__':
version = 1.00
print(f"Running version {version}")
# Handle sigint (ctrl-c) cleanly
previous_signal_int = signal.signal(SIGINT, handler)
......@@ -61,4 +55,4 @@ if __name__ == '__main__':
setup_logger_tqdm(logfile_path)
args = parser.parse_args()
sort_13_gram_buckets(args.working_directory)
\ No newline at end of file
sort_13_gram_buckets(args.working_directory)
......@@ -3,12 +3,14 @@ from collections import Counter
import shutil
import glob
from scripts.clean_training_data.janitor import *
from lm_eval.decontamination.janitor import *
from scripts.clean_training_data.generate_13_grams import do_ngrams_in_buckets
from scripts.clean_training_data.archiver import Archive, TextReader
from lm_eval.decontamination.archiver import Archive, TextReader
import logging
logger = logging.getLogger(__name__)
def test_generate_13_grams_1():
def test_generate_13_grams_1(caplog):
data = """A goose (plural geese) is a bird of any of several waterfowl species in the family Anatidae.
This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese).
Some other birds, mostly related to the shelducks, have "goose" as part of their names.
......@@ -22,6 +24,7 @@ def test_generate_13_grams_1():
data = data + data
# Simple Generation
print("simple generation")
n = 13
janitor = Janitor()
ngrams = word_ngrams(janitor.normalize_string(data), n)
......@@ -31,22 +34,26 @@ def test_generate_13_grams_1():
# print(comparison)
# Generating into buckets
print("bucket generation")
test_working_directory = "test_generate_13_grams"
output_directory = os.path.join(test_working_directory, "output")
try:
shutil.rmtree(output_directory)
shutil.rmtree(test_working_directory)
except FileNotFoundError:
pass
os.makedirs(test_working_directory, exist_ok=True)
archive = Archive(os.path.join(test_working_directory, "test.jsonl.zst"))
os.makedirs(test_working_directory)
assert(not os.path.exists("pile"))
os.makedirs("pile")
archive = Archive(os.path.join("pile", "test.jsonl.zst"))
archive.add_data(data)
archive.commit()
bucket_count = 4
do_ngrams_in_buckets(n, test_working_directory, bucket_count)
# Rebuild from buckets
print("rebuild")
rebuilt_ngrams = []
bucket_file_paths = glob.glob(os.path.join(test_working_directory, "output", f"*.bkt.txt"))
for bucket_file_path in bucket_file_paths:
reader = TextReader(bucket_file_path)
......@@ -56,6 +63,7 @@ def test_generate_13_grams_1():
rebuilt_ngrams.append(ngram)
# Compare
print("compare")
result_counter = Counter(rebuilt_ngrams)
# print(len(result_counter))
# print(len(comparison_counter))
......
import re
from collections import defaultdict
from scripts.clean_training_data.janitor import *
from lm_eval.decontamination.janitor import *
def simple_ngram(sequence, n):
ngrams = list()
......