Commit f495bfb4 authored by researcher2

Refactor for PR

parent 0f283a9c
@@ -6,6 +6,7 @@ import io
import datetime
import mmap
import tqdm
from pathlib import Path

def json_serial(obj):
    """JSON serializer for objects not serializable by default json code"""

@@ -61,16 +62,19 @@ class Reader:
        else:
            yield text

# Simple text reader and writer with same interface as above
class TextArchive:
-    def __init__(self, file_path, mode="ab"):
    def __init__(self, file_path, mode="rb+"):
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        if not os.path.exists(file_path):
            Path(file_path).touch()

        self.fh = open(self.file_path, mode)

-    def add_data(self, data, meta={}):
    def add_data(self, data):
        self.fh.write(data.encode('UTF-8') + b'\n')

    def commit(self):
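    # Note on the mode change above (inferred from how the bucket checkpointing code in
    # scripts/clean_training_data/generate_13_grams.py uses these handles): unlike "ab",
    # "rb+" raises if the file does not exist yet, hence the Path(file_path).touch() guard;
    # in exchange the handle supports seek() and truncate() when resuming from a checkpoint.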
import mmap
import string
import time
import random
import pickle
import json
import glob
import os
import collections
import re
-import tqdm
from .janitor import Janitor, word_ngrams
from .archiver import ZStdTextReader

try:
    import janitor_util
    JANITOR_CPP = True
except Exception as e:
    # print("WARNING: C++ module could not be loaded. Janitor running in python mode")
    JANITOR_CPP = False
# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):

@@ -25,11 +17,12 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
# Returns a dictionary containing all overlapping documents in each
-# task based on any 13gram being found in the training set.
-# ngrams_path is the parent directory containing the "ngrams_{x}.bkt.txt.sorted.zst"
-# files built by the other scripts "generate_13_grams.py" and "sort_13_gram_buckets.py.
-# ngrams_n_size is expected to be 13 but we made it a parameter for generality.
-# The task set is only included for caching purposes.
# task. In the standard use case, an overlap occurs when any of the 13-grams
# found in the task document exist in the training set documents.
#
# To generate 13-grams for the pile see scripts/clean_training_data. The final output of these
# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
# files. These should exist in the "ngrams_path" provided to this function.

# Algorithm:
# 1. Build lookups for each dataset {ngram: list(document_ids)}

@@ -39,12 +32,12 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
# 4. Strip the task_set from the dictionary keys and return
#
# We cache the task+set lookups as well as the overlaps.
-#
-# Currently calculating some per file ngram stats for interest, might remove before merging into main
def get_train_overlap(docs_by_task_set, ngrams_path, limit):
    # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)
-    # TODO: infer ngrams_n_size from ngrams_path
    info_dict_path = os.path.join(ngrams_path, "info.json")
    info_dict = json.load(open(info_dict_path, "r"))
    ngrams_n_size = info_dict["ngram_size"]

    janitor = Janitor()

@@ -62,16 +55,17 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
    for (task_name, task_set), docs in docs_by_task_set.items():
        if not os.path.exists(f"data/{task_name}"):
            os.mkdir(f"data/{task_name}")

-        # Check if we've decontaminated this set before
        # Check if we've decontaminated this combination before
        overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
        if os.path.exists(overlaps_dump_path):
            duplicates[(task_name, task_set)] = pickle.load(open(overlaps_dump_path, "rb"))
            sets_to_decontaminate -= 1
            continue
        else:
-            duplicates[(task_name, task_set)] = set() # No defaultdict, we want to dump empty sets too later
            duplicates[(task_name, task_set)] = set()

-        # Build/load the task lookup {ngram: documents}.
        # Build/load the task lookup {ngram: set(documents)}.
        task_set_lookup_path = f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
        if os.path.exists(task_set_lookup_path):
            print(f"{task_set_lookup_path} available, loading...")

@@ -104,7 +98,7 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
    elapsed = time.perf_counter() - start
    print(f"Merging lookups took {elapsed:0.5f} seconds.")

-    print(f"13 grams files found in {ngrams_path}:")
    print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
    files = glob.glob(os.path.join(ngrams_path, f"*.sorted.zst"))
    print(files)

@@ -157,281 +151,3 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
    # Strip task set and return
    return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
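For orientation, a sketch of the call shape, mirroring the evaluator.py change later in this commit; the task name, documents and path are hypothetical, and the exact per-task document container is not shown in this hunk:

    docs_by_task_set = {("task_a", "test"): ["task document 0 ...", "task document 1 ..."]}
    overlaps = get_train_overlap(docs_by_task_set, "path/to/ngrams", limit=None)
    # -> {"task_a": {...ids of task documents that share a 13-gram with the training set...}}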
# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence, n):
history = []
while n > 1:
# PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
try:
next_item = next(sequence)
except StopIteration:
# no more data, terminate the generator
return
history.append(next_item)
n -= 1
for item in sequence:
history.append(item)
yield tuple(history)
del history[0]
def word_ngrams(s, n):
"""Splits a string into ngram words"""
tokens = s.split() # not a generator :(
ngram_seqs = form_ngrams(iter(tokens), n)
return (" ".join(ngram) for ngram in ngram_seqs)
# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
def split_indices(s):
"""Splits a string on whitespaces and records the indices of each in the original string.
@:return generator((word, (start_idx, end_idx)), ...)
"""
return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r'\S+', s))
def word_ngrams_indices(s, n):
"""Splits a string into pairs of (ngram words, their start/end indices)"""
tokens_with_indices = split_indices(s)
# Generator of ngrams of (word, idx_pairs)
# (
# [(word, (start,end)), (word, (start, end))...],
# [(word, (start, end)), ...],
# ...
# )
ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n)
# Generator of pairs of word and index ngrams
# (
# ([word, word, ...], [(start,end), (start,end), ...]),
# ...
# )
ngram_indices_pairs = (zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices)
# Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
return ((" ".join(ngram_seq), (indices[0][0], indices[-1][1])) for ngram_seq, indices in ngram_indices_pairs)
class Janitor:
# FIXME delete_chars: Should anything else go here? Special chars?
def __init__(
self,
ngram_n=13,
window_to_remove=200,
too_dirty_cutoff=10,
minimum_slice_length=200,
delete_chars=string.punctuation
):
self.ngram_n = ngram_n
self.window_to_remove = window_to_remove
self.too_dirty_cutoff = too_dirty_cutoff
self.minimum_slice_length = minimum_slice_length
self.delete_chars = delete_chars
self.dirt_ngrams = set()
# If in python, we'll translate uppercase to lowercase and delete naughty characters.
# This is fast by python standards
# https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
self.translation_table = str.maketrans(
string.ascii_lowercase + string.ascii_uppercase, # These characters
string.ascii_lowercase * 2, # Become these characters
self.delete_chars # These are deleted
)
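With the default delete_chars, the translation table lowercases letters and strips punctuation, which is what normalize_string() below applies. A quick illustration of the same table built standalone:

    table = str.maketrans(
        string.ascii_lowercase + string.ascii_uppercase,
        string.ascii_lowercase * 2,
        string.punctuation,
    )
    "Hello, World!".translate(table)  # -> 'hello world'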
##############
# I/O for saving contamination ngrams
##############
def save_contamination_ngrams(self, filename):
with open(filename, 'wb') as fp:
pickle.dump(filename, fp)
def load_contamination_ngrams(self, filename):
with open(filename, 'rb') as fp:
self.dirt_ngrams = pickle.load(fp)
##############
# Call these :)
##############
def register_contaminant(self, dirt_string):
"""Register a string as contamination to be removed, e.g. a test set
This breaks the dirt_string into ngrams to store for future cleaning"""
if JANITOR_CPP:
return self.register_contaminant_cpp(dirt_string)
else:
print("WARNING: Janitor running in python mode")
return self.register_contaminant_python(dirt_string)
def clean(self, dirty_string):
"""Clean a string (e.g. a training set) by removing all ngrams previously
registered as contaminants. Returns a list of clean chunks, or empty if
the string was too dirty"""
if JANITOR_CPP:
return self.clean_cpp(dirty_string)
else:
print("WARNING: Janitor running in python mode")
return self.clean_python(dirty_string)
def _split_chunks(self, dirty_string, dirty_parts):
clean_chunks = []
splice_idx = 0
end = -1
for i, (ngram, start, end) in enumerate(dirty_parts):
if i >= self.too_dirty_cutoff:
return []
start = max(0, start - self.window_to_remove)
end = min(len(dirty_string), end + self.window_to_remove)
if start - splice_idx > self.minimum_slice_length:
clean_chunks.append(dirty_string[splice_idx: start])
splice_idx = end
if end < len(dirty_string) - self.minimum_slice_length:
clean_chunks.append(dirty_string[end+1:])
return clean_chunks
##############
# Fast C++
##############
def register_contaminant_cpp(self, dirt_string):
self.dirt_ngrams.update(janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n))
def clean_cpp(self, dirty_string):
contamination_indices = janitor_util.clean_ngram_with_indices(dirty_string, self.delete_chars, self.ngram_n)
return self._split_chunks(dirty_string, contamination_indices)
##############
# Slow python
##############
def normalize_string(self, s):
return s.translate(self.translation_table)
def register_contaminant_python(self, dirt_string):
self.dirt_ngrams.update(word_ngrams(self.normalize_string(dirt_string), self.ngram_n))
def clean_python(self, dirty_string):
contamination_indices = (
(None, *idx_pair)
for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
if self.normalize_string(dirty_ngram) in self.dirt_ngrams
)
return self._split_chunks(dirty_string, contamination_indices)
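For orientation, a minimal sketch of the intended Janitor API; the tiny ngram_n and thresholds are purely illustrative (the decontamination pipeline uses the defaults: 13-grams and 200-character windows):

    jan = Janitor(ngram_n=2, window_to_remove=0, minimum_slice_length=1)
    jan.register_contaminant("the quick brown fox")   # stores its normalized 2-grams
    chunks = jan.clean("training text that mentions the quick brown fox and then continues")
    # `chunks` is a list of the substrings left after every registered 2-gram (plus the
    # surrounding window) is cut out; it is [] if more than too_dirty_cutoff matches hit.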
# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence, n):
history = []
while n > 1:
# PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
try:
next_item = next(sequence)
except StopIteration:
# no more data, terminate the generator
return
history.append(next_item)
n -= 1
for item in sequence:
history.append(item)
yield tuple(history)
del history[0]
def word_ngrams(s, n):
"""Splits a string into ngram words"""
tokens = s.split() # not a generator :(
ngram_seqs = form_ngrams(iter(tokens), n)
return (" ".join(ngram) for ngram in ngram_seqs)
# Simple text reader and writer with same interface as above
class TextArchive:
def __init__(self, file_path, mode="ab"):
self.file_path = file_path
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
self.fh = open(self.file_path, mode)
def add_data(self, data, meta={}):
self.fh.write(data.encode('UTF-8') + b'\n')
def commit(self):
self.fh.flush()
self.fh.close()
class TextReader:
def __init__(self, file_path):
self.file_path = file_path
# Optimized mmap read with infrequent tqdm updates to maintain speed
# Tested up to 250MB/s.
def read_tqdm(self, update_frequency=10000):
current_file_position = 0
line_counter = 0
with open(self.file_path, 'r') as fh, \
tqdm.tqdm(total=os.path.getsize(self.file_path), dynamic_ncols=True,
unit="byte", unit_scale=1) as progress:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
line_counter += 1
if line_counter == update_frequency:
new_file_pos = mmap_obj.tell()
bytes_read = new_file_pos - current_file_position
current_file_position = new_file_pos
progress.update(bytes_read)
line_counter = 0
yield line[:-1]
def read_and_tell(self):
current_file_position = 0
with open(self.file_path, 'r', encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
new_file_pos = mmap_obj.tell()
raw_bytes_read = new_file_pos - current_file_position
current_file_position = new_file_pos
yield line[:-1], raw_bytes_read
def read(self):
with open(self.file_path, 'r', encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
yield line[:-1]
def read_slow(self):
with open(self.file_path, 'r', encoding="utf8") as fh:
while True:
line = fh.readline()
if line == -1 or line == "":
break
else:
yield line[:-1]
# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
def __init__(self, file):
self.file = file
def read_tqdm(self):
decompressed_file = self.file[:-4]
print("Decompressing file, please wait...")
os.system(f"zstd -d {self.file}") # linux decompress is faster
reader = TextReader(decompressed_file)
yield from reader.read_tqdm()
os.remove(decompressed_file)
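A minimal round-trip sketch of the plain-text classes above; the file name here is arbitrary:

    archive = TextArchive("output/example.txt")
    archive.add_data("first line")
    archive.add_data("second line")
    archive.commit()

    for line in TextReader("output/example.txt").read():
        print(line)   # prints "first line", then "second line"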
@@ -4,7 +4,6 @@ import timeit
import pickle
import traceback
from pprint import pprint
-from lm_eval.decontamination import word_ngrams

# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
@@ -8,6 +8,7 @@ import lm_eval.base
import lm_eval.decontamination
import numpy as np
from lm_eval.utils import positional_deprecated
from lm_eval.decontamination.decontaminate import get_train_overlap

@positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[],

@@ -189,7 +190,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
    # Compare all tasks/sets at once to ensure a single training set scan
    if decontaminate:
        print("Finding train/test overlap, please wait...")
-        overlaps = lm_eval.decontamination.get_train_overlap(docs_for_decontamination, decontamination_ngrams_path, limit)
        overlaps = get_train_overlap(docs_for_decontamination, decontamination_ngrams_path, limit)

    # all responses for each (task, doc)
    process_res_queue = collections.defaultdict(list)
@@ -35,25 +35,11 @@ def parse_args():
    parser.add_argument('--output_path', default=None)
    parser.add_argument('--limit', type=int, default=None)
    parser.add_argument('--no_cache', action="store_true")
-    parser.add_argument('--decontaminate', action="store_true")
-    parser.add_argument('--decontaminate_ngrams_path', default=None)
-    parser.add_argument('--decontaminate_ngrams_n_size', type=int, default=None)
    parser.add_argument('--decontamination_ngrams_path', default=None)
    parser.add_argument('--description_dict_path', default=None)
    return parser.parse_args()

-def ensure_correct_decontamination_params(args):
-    valid = True
-    if args.decontaminate:
-        if not args.decontaminate_ngrams_n_size:
-            print("Please specify n size of training set n-grams. (--ngrams_n_size)")
-            valid = False
-        if not args.decontaminate_ngrams_path:
-            print("Please specify path containing training set n-grams. (--ngrams_path)")
-            valid = False
-    return valid

# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):

@@ -65,8 +51,6 @@ def pattern_match(patterns, source_list):
def main():
    args = parse_args()
-    if not ensure_correct_decontamination_params(args):
-        return

    assert not args.provide_description # not implemented

@@ -95,9 +79,7 @@ def main():
        no_cache=args.no_cache,
        limit=args.limit,
        description_dict=description_dict,
-        decontaminate=args.decontaminate,
-        decontaminate_ngrams_path=args.decontaminate_ngrams_path,
-        decontaminate_ngrams_n_size=args.decontaminate_ngrams_n_size
        decontamination_ngrams_path=args.decontamination_ngrams_path
    )

    dumped = json.dumps(results, indent=2)
{
"Data": "Pile statistics",
"Document Count": 210607728,
"Total Pile Characters": 421215456,
"File Start Offsets": [
0,
7021438,
14042822,
21066113,
28086515,
35106072,
42123306,
49145091,
56165817,
63185587,
70211208,
77234322,
84249267,
91267634,
98285983,
105305110,
112322489,
119342491,
126367373,
133389153,
140412039,
147432373,
154452516,
161470190,
168492733,
175512521,
182526939,
189547478,
196565318,
203583306
]
}
\ No newline at end of file
@@ -21,8 +21,10 @@ Arguments
"""

import argparse
import json
import pickle
import os
import sys
from pathlib import Path
import glob
import signal

@@ -30,32 +32,89 @@ from signal import SIGINT
from tqdm import tqdm

-from scripts.clean_training_data.janitor import Janitor, word_ngrams
-from scripts.clean_training_data.archiver import TextArchive, Reader
from lm_eval.decontamination.janitor import Janitor, word_ngrams
from lm_eval.decontamination.archiver import TextArchive, Reader

import logging
from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)

-pile_document_count = 210607728

terminate = False
def handler(signal_received, frame):
    global terminate
    terminate = True

-def get_pile(directory):
-    reader = Reader()
-    for file in glob.glob(os.path.join(directory, f"*.jsonl.zst*")):
-        for document in reader.read(file):
-            yield document

-def close_buckets(buckets):
-    for bucket in buckets:
-        bucket.commit()

def yield_pile(start_offsets=None, checkpoint_offset=None):
    directory = "pile"

    if not os.path.exists(directory):
        print("We expect the pile archives to be in the 'pile' directory, but this was not found.")
        raise Exception("Pile directory not found.")

    files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))

    pile_global_offset = 0
    start_file = 0
    if checkpoint_offset:
        for file_i, start_offset in enumerate(start_offsets):
            if start_offset > checkpoint_offset:
                break

            start_file = file_i
            pile_global_offset = start_offset

    for file_i, file in enumerate(files):
        if file_i < start_file:
            logger.info(f"Skipping file {file}")
            continue

        logger.info(f"Reading from pile file: {file}")
        reader = Reader()
        for document in reader.read(file):
            yield (pile_global_offset, document)
            pile_global_offset += 1
# Hash buckets written to disk-backed files. Supports file position checkpointing and resuming
# Allows you to write continuously and checkpoint intermittently. If a failure occurs
# the buckets are simply truncated at your last checkpoint.
class Buckets:
def __init__(self, directory, num_buckets):
self.bucket_files = [os.path.join(directory, f"ngrams_{i}.bkt.txt") for i in range(num_buckets)]
self.buckets = list(map(TextArchive, self.bucket_files))
self.checkpoint_file = os.path.join(directory, f"bucket_offsets.ckpt")
if os.path.exists(self.checkpoint_file):
self.bucket_offsets = pickle.load(open(self.checkpoint_file, "rb"))
else:
self.bucket_offsets = [0 for i in range(len(self.buckets))]
for i, offset in enumerate(self.bucket_offsets):
bucket = self.buckets[i]
bucket.fh.seek(offset)
bucket.fh.truncate()
def add_data(self, key, value):
i = hash(key) % len(self.buckets)
bucket = self.buckets[i]
bucket.add_data(value)
def save_checkpoint(self):
for bucket in self.buckets:
bucket.fh.flush()
bucket_offsets = [bucket.fh.tell() for bucket in self.buckets]
pickle.dump(bucket_offsets, open(self.checkpoint_file, "wb"))
def close_buckets(self):
for bucket in self.buckets:
bucket.commit()
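A short sketch of how the Buckets class is driven; directory name and values here are illustrative, and do_ngrams_in_buckets below is the real caller (the output directory must already exist):

    buckets = Buckets("output", num_buckets=4)        # opens/creates ngrams_0..3.bkt.txt
    buckets.add_data("some ngram", "some ngram 17")   # routed via hash(key) % num_buckets
    buckets.save_checkpoint()                         # flush and record current file offsets
    buckets.close_buckets()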
def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
    pile_statistics = json.load(open("pile_statistics.json", "r"))
    pile_document_count = pile_statistics["Document Count"]
    start_offsets = pile_statistics["File Start Offsets"]

    output_directory = os.path.join(working_directory, "output")
    os.makedirs(output_directory, exist_ok=True)

@@ -68,49 +127,56 @@ def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
        return

    # Checkpoint
-    checkpoint_file = os.path.join(output_directory, f"ngram_buckets.ckpt")
    checkpoint_file = os.path.join(working_directory, f"pile_offset.ckpt")
    if os.path.exists(checkpoint_file):
-        start_id = pickle.load(open(checkpoint_file,"rb"))
        checkpoint_offset = pickle.load(open(checkpoint_file,"rb"))
        iterate = True
    else:
-        start_id = 0
        checkpoint_offset = 0
        iterate = False

-    logger.info(f"Starting at pile document index {start_id}")
    logger.info(f"Starting at pile document index {checkpoint_offset}")

-    bucket_files = [os.path.join(output_directory, f"ngrams_{i}.bkt.txt") for i in range(bucket_count)]
-    buckets = list(map(TextArchive, bucket_files))
    buckets = Buckets(output_directory, bucket_count)

    janitor = Janitor()
-    current_id = 0
    batch_size = 1000
    batch_counter = 0

-    with tqdm(total=pile_document_count, dynamic_ncols=True, unit="docs") as progress:
-        for document in get_pile(working_directory):
-            if current_id < start_id:
-                if terminate:
-                    close_buckets(buckets)
-                    return
-                current_id += 1

    with tqdm(total=checkpoint_offset, dynamic_ncols=True, unit="docs") as progress:
        for offset, document in yield_pile(start_offsets, checkpoint_offset):
            if iterate:
                logger.info(f"Iterating to offset {checkpoint_offset} from {offset}")
                progress.update(offset)
                iterate = False

            if offset < checkpoint_offset:
                progress.update()
                if terminate:
                    return
                continue

            if offset == checkpoint_offset:
                progress.reset(total=pile_document_count)
                progress.update(checkpoint_offset)

            # Save checkpoint every "batch_size", only allow terminate after checkpoint
            if batch_counter == batch_size:
                progress.update(batch_size)
                batch_counter = 0
-                pickle.dump(current_id, open(checkpoint_file,"wb"))
                buckets.save_checkpoint()
                pickle.dump(offset, open(checkpoint_file,"wb"))
                if terminate:
-                    close_buckets(buckets)
                    buckets.close_buckets()
                    return

            ngrams = word_ngrams(janitor.normalize_string(document), n_value)
            for ngram in ngrams:
-                bucket = hash(ngram) % len(buckets)
-                buckets[bucket].add_data(f"{ngram} {current_id}")
                buckets.add_data(ngram, f"{ngram} {offset}")

            batch_counter += 1
-            current_id += 1

-    close_buckets(buckets)
    buckets.close_buckets()
    Path(done_file).touch()
@@ -120,6 +186,12 @@ parser.add_argument("-n", "--n_value", type=int, default=13)
parser.add_argument("-buckets", "--bucket_count", type=int, default=500)

if __name__ == '__main__':
version = 1.00
print(f"Running version {version}")
if "PYTHONHASHSEED" not in os.environ or os.environ["PYTHONHASHSEED"] != "0":
print("Please run 'export PYTHONHASHSEED=0' before running generate.")
sys.exit()
    # Handle sigint (ctrl-c) cleanly
    previous_signal_int = signal.signal(SIGINT, handler)

@@ -128,4 +200,8 @@ if __name__ == '__main__':
    setup_logger_tqdm(logfile_path)

    args = parser.parse_args()
    do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)
\ No newline at end of file
info_dict = {"title": "dataset ngrams", "ngram_size": 13}
info_dict_path = os.path.join(args.working_directory, "info.json")
json.dump(info_dict, open(info_dict_path, "w"))
\ No newline at end of file
from lm_eval.decontamination.archiver import Reader
import os
import json
from functools import reduce
import glob
import tqdm
from tqdm_multiprocess import TqdmMultiProcessPool
def get_file_stats(file_path, tqdm_func, global_tqdm):
reader = Reader()
total_documents = 0
total_size = 0
update_frequency = 10000
current_file_position = 0
with tqdm_func(total=os.path.getsize(file_path), dynamic_ncols=True, unit="byte", unit_scale=1) as progress:
for document in reader.read(file_path, get_meta=True):
total_size += len(document)
total_documents += 1
if total_documents % update_frequency == 0:
new_file_pos = reader.fh.tell()
bytes_read = new_file_pos - current_file_position
current_file_position = new_file_pos
progress.update(bytes_read)
global_tqdm.update(bytes_read)
return (total_documents, total_size)
def get_files():
directory = "pile"
files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))
print(files)
return files
def get_stats():
files = get_files()
total_size_bytes = sum(map(lambda x: os.path.getsize(x), files))
pool = TqdmMultiProcessPool(4)
global_tqdm = tqdm.tqdm(total=total_size_bytes, dynamic_ncols=True, unit="byte", unit_scale=1)
# Compute per-file document counts and sizes with the process pool
tasks = [(get_file_stats, (file,)) for file in files]
on_done = lambda _ : None
on_error = lambda _ : None
results = pool.map(global_tqdm, tasks, on_error, on_done)
total_documents, total_size = reduce(lambda x, y: (x[0]+y[0],x[1]+y[1]), results)
start_offsets = []
current_offset = 0
for file_document_count, _ in results:
start_offsets.append(current_offset)
current_offset += file_document_count
return (total_documents, total_size, start_offsets)
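    # Illustrative only: three pile files holding 3, 5 and 4 documents would yield
    # start_offsets == [0, 3, 8], i.e. the global index of each file's first document.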
if __name__ == '__main__':
version = 1.01
print(f"Running version {version}")
stats_file_path = "pile_statistics.json"
if os.path.exists(stats_file_path):
stats = json.load(open(stats_file_path, "r"))
else:
document_count, total_document_size_chars, start_offsets = get_stats()
stats = {"Data": "Pile statistics",
"Document Count": document_count,
"Total Pile Characters": total_document_size_chars,
"File Start Offsets": start_offsets
}
json.dump(stats, open(stats_file_path, "w"), indent=4)
print(f"document_count: {stats['Document Count']}")
print(f"total_chars: {stats['Total Pile Characters']}")
print(f"start_offsets: {stats['File Start Offsets']}")
@@ -3,12 +3,14 @@ from collections import Counter
import shutil
import glob

-from scripts.clean_training_data.janitor import *
from lm_eval.decontamination.janitor import *
from scripts.clean_training_data.generate_13_grams import do_ngrams_in_buckets
-from scripts.clean_training_data.archiver import Archive, TextReader
from lm_eval.decontamination.archiver import Archive, TextReader

import logging
logger = logging.getLogger(__name__)

-def test_generate_13_grams_1():
def test_generate_13_grams_1(caplog):
    data = """A goose (plural geese) is a bird of any of several waterfowl species in the family Anatidae.
This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese).
Some other birds, mostly related to the shelducks, have "goose" as part of their names.

@@ -22,6 +24,7 @@ def test_generate_13_grams_1():
    data = data + data

    # Simple Generation
    print("simple generation")
    n = 13
    janitor = Janitor()
    ngrams = word_ngrams(janitor.normalize_string(data), n)

@@ -31,22 +34,26 @@ def test_generate_13_grams_1():
    # print(comparison)

    # Generating into buckets
    print("bucket generation")
    test_working_directory = "test_generate_13_grams"
-    output_directory = os.path.join(test_working_directory, "output")
    try:
-        shutil.rmtree(output_directory)
        shutil.rmtree(test_working_directory)
    except FileNotFoundError:
        pass
-    os.makedirs(test_working_directory, exist_ok=True)
    os.makedirs(test_working_directory)

-    archive = Archive(os.path.join(test_working_directory, "test.jsonl.zst"))
    assert(not os.path.exists("pile"))
    os.makedirs("pile")
    archive = Archive(os.path.join("pile", "test.jsonl.zst"))
    archive.add_data(data)
    archive.commit()

    bucket_count = 4
    do_ngrams_in_buckets(n, test_working_directory, bucket_count)

    # Rebuild from buckets
    print("rebuild")
    rebuilt_ngrams = []
    bucket_file_paths = glob.glob(os.path.join(test_working_directory, "output", f"*.bkt.txt"))
    for bucket_file_path in bucket_file_paths:
        reader = TextReader(bucket_file_path)

@@ -56,6 +63,7 @@ def test_generate_13_grams_1():
            rebuilt_ngrams.append(ngram)

    # Compare
    print("compare")
    result_counter = Counter(rebuilt_ngrams)
    # print(len(result_counter))
    # print(len(comparison_counter))
import re
from collections import defaultdict

-from scripts.clean_training_data.janitor import *
from lm_eval.decontamination.janitor import *

def simple_ngram(sequence, n):
    ngrams = list()