Commit f495bfb4 authored by researcher2

Refactor for PR

parent 0f283a9c
......@@ -6,6 +6,7 @@ import io
import datetime
import mmap
import tqdm
from pathlib import Path
def json_serial(obj):
"""JSON serializer for objects not serializable by default json code"""
......@@ -61,16 +62,19 @@ class Reader:
else:
yield text
# Simple text reader and writer with same interface as above
class TextArchive:
def __init__(self, file_path, mode="ab"):
def __init__(self, file_path, mode="rb+"):
self.file_path = file_path
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
self.fh = open(self.file_path, mode)
if not os.path.exists(file_path):
Path(file_path).touch()
self.fh = open(self.file_path, mode)
def add_data(self, data, meta={}):
def add_data(self, data):
self.fh.write(data.encode('UTF-8') + b'\n')
def commit(self):
......
import mmap
import string
import time
import random
import pickle
import json
import glob
import os
import collections
import re
import tqdm
try:
import janitor_util
JANITOR_CPP = True
except Exception as e:
# print("WARNING: C++ module could not be loaded. Janitor running in python mode")
JANITOR_CPP = False
from .janitor import Janitor, word_ngrams
from .archiver import ZStdTextReader
# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
......@@ -25,11 +17,12 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
# Returns a dictionary containing all overlapping documents in each
# task based on any 13-gram being found in the training set.
# ngrams_path is the parent directory containing the "ngrams_{x}.bkt.txt.sorted.zst"
# files built by the other scripts "generate_13_grams.py" and "sort_13_gram_buckets.py".
# ngrams_n_size is expected to be 13 but we made it a parameter for generality.
# The task set is only included for caching purposes.
# task. In the standard use case, an overlap occurs when any of the 13-grams
# found in the task document exist in the training set documents.
#
# To generate 13-grams for the pile see scripts/clean_training_data. The final output of these
# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
# files. These should exist in the "ngrams_path" provided to this function.
# Algorithm:
# 1. Build lookups for each dataset {ngram: list(document_ids)}
......@@ -39,12 +32,12 @@ def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
# 4. Strip the task_set from the dictionary keys and return
#
# We cache the task+set lookups as well as the overlaps.
#
# Currently calculating some per file ngram stats for interest, might remove before merging into main
def get_train_overlap(docs_by_task_set, ngrams_path, limit):
# return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)
# TODO: infer ngrams_n_size from ngrams_path
info_dict_path = os.path.join(ngrams_path, "info.json")
info_dict = json.load(open(info_dict_path, "r"))
ngrams_n_size = info_dict["ngram_size"]
janitor = Janitor()
......@@ -62,16 +55,17 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
for (task_name, task_set), docs in docs_by_task_set.items():
if not os.path.exists(f"data/{task_name}"):
os.mkdir(f"data/{task_name}")
# Check if we've decontaminated this set before
# Check if we've decontaminated this combination before
overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
if os.path.exists(overlaps_dump_path):
duplicates[(task_name, task_set)] = pickle.load(open(overlaps_dump_path, "rb"))
sets_to_decontaminate -= 1
continue
else:
duplicates[(task_name, task_set)] = set() # No defaultdict, we want to dump empty sets too later
duplicates[(task_name, task_set)] = set()
# Build/load the task lookup {ngram: documents}.
# Build/load the task lookup {ngram: set(documents)}.
task_set_lookup_path = f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
if os.path.exists(task_set_lookup_path):
print(f"{task_set_lookup_path} available, loading...")
......@@ -104,7 +98,7 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
elapsed = time.perf_counter() - start
print(f"Merging lookups took {elapsed:0.5f} seconds.")
print(f"13 grams files found in {ngrams_path}:")
print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
files = glob.glob(os.path.join(ngrams_path, f"*.sorted.zst"))
print(files)
......@@ -157,281 +151,3 @@ def get_train_overlap(docs_by_task_set, ngrams_path, limit):
# Strip task set and return
return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
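# A minimal usage sketch, not part of this commit: docs_by_task_set maps
# (task_name, task_set) to that set's documents, ngrams_path points at the directory
# holding info.json and the sorted ngram archives, and the return value maps
# task_name to the ids of overlapping documents. The path, task name and document
# representation below are hypothetical.
def _demo_get_train_overlap():
    docs_by_task_set = {("lambada", "test"): ["first test document", "second test document"]}
    overlaps = get_train_overlap(docs_by_task_set, "path/to/ngrams", limit=None)
    print(overlaps)  # e.g. {"lambada": {0}} if document 0 shared an n-gram with the training set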
# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence, n):
history = []
while n > 1:
# PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
try:
next_item = next(sequence)
except StopIteration:
# no more data, terminate the generator
return
history.append(next_item)
n -= 1
for item in sequence:
history.append(item)
yield tuple(history)
del history[0]
def word_ngrams(s, n):
"""Splits a string into ngram words"""
tokens = s.split() # not a generator :(
ngram_seqs = form_ngrams(iter(tokens), n)
return (" ".join(ngram) for ngram in ngram_seqs)
# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
def split_indices(s):
"""Splits a string on whitespaces and records the indices of each in the original string.
@:return generator((word, (start_idx, end_idx)), ...)
"""
return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r'\S+', s))
def word_ngrams_indices(s, n):
"""Splits a string into pairs of (ngram words, their start/end indices)"""
tokens_with_indices = split_indices(s)
# Generator of ngrams of (word, idx_pairs)
# (
# [(word, (start,end)), (word, (start, end))...],
# [(word, (start, end)), ...],
# ...
# )
ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n)
# Generator of pairs of word and index ngrams
# (
# ([word, word, ...], [(start,end), (start,end), ...]),
# ...
# )
ngram_indices_pairs = (zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices)
# Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
return ((" ".join(ngram_seq), (indices[0][0], indices[-1][1])) for ngram_seq, indices in ngram_indices_pairs)
class Janitor:
# FIXME delete_chars: Should anything else go here? Special chars?
def __init__(
self,
ngram_n=13,
window_to_remove=200,
too_dirty_cutoff=10,
minimum_slice_length=200,
delete_chars=string.punctuation
):
self.ngram_n = ngram_n
self.window_to_remove = window_to_remove
self.too_dirty_cutoff = too_dirty_cutoff
self.minimum_slice_length = minimum_slice_length
self.delete_chars = delete_chars
self.dirt_ngrams = set()
# If in python, we'll translate uppercase to lowercase and delete naughty characters.
# This is fast by python standards
# https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
self.translation_table = str.maketrans(
string.ascii_lowercase + string.ascii_uppercase, # These characters
string.ascii_lowercase * 2, # Become these characters
self.delete_chars # These are deleted
)
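# For example, "Hello, World!".translate(self.translation_table) returns
# "hello world": uppercase is lowered, punctuation is deleted, whitespace is kept.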
##############
# I/O for saving contamination ngrams
##############
def save_contamination_ngrams(self, filename):
with open(filename, 'wb') as fp:
pickle.dump(self.dirt_ngrams, fp)  # dump the ngram set itself, not the filename
def load_contamination_ngrams(self, filename):
with open(filename, 'rb') as fp:
self.dirt_ngrams = pickle.load(fp)
##############
# Call these :)
##############
def register_contaminant(self, dirt_string):
"""Register a string as contamination to be removed, e.g. a test set
This breaks the dirt_string into ngrams to store for future cleaning"""
if JANITOR_CPP:
return self.register_contaminant_cpp(dirt_string)
else:
print("WARNING: Janitor running in python mode")
return self.register_contaminant_python(dirt_string)
def clean(self, dirty_string):
"""Clean a string (e.g. a training set) by removing all ngrams previously
registered as contaminants. Returns a list of clean chunks, or empty if
the string was too dirty"""
if JANITOR_CPP:
return self.clean_cpp(dirty_string)
else:
print("WARNING: Janitor running in python mode")
return self.clean_python(dirty_string)
def _split_chunks(self, dirty_string, dirty_parts):
clean_chunks = []
splice_idx = 0
end = -1
for i, (ngram, start, end) in enumerate(dirty_parts):
if i >= self.too_dirty_cutoff:
return []
start = max(0, start - self.window_to_remove)
end = min(len(dirty_string), end + self.window_to_remove)
if start - splice_idx > self.minimum_slice_length:
clean_chunks.append(dirty_string[splice_idx: start])
splice_idx = end
if end < len(dirty_string) - self.minimum_slice_length:
clean_chunks.append(dirty_string[end+1:])
return clean_chunks
##############
# Fast C++
##############
def register_contaminant_cpp(self, dirt_string):
self.dirt_ngrams.update(janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n))
def clean_cpp(self, dirty_string):
contamination_indices = janitor_util.clean_ngram_with_indices(dirty_string, self.delete_chars, self.ngram_n)
return self._split_chunks(dirty_string, contamination_indices)
##############
# Slow python
##############
def normalize_string(self, s):
return s.translate(self.translation_table)
def register_contaminant_python(self, dirt_string):
self.dirt_ngrams.update(word_ngrams(self.normalize_string(dirt_string), self.ngram_n))
def clean_python(self, dirty_string):
contamination_indices = (
(None, *idx_pair)
for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
if self.normalize_string(dirty_ngram) in self.dirt_ngrams
)
return self._split_chunks(dirty_string, contamination_indices)
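# A minimal usage sketch, not part of this commit: register a test document as
# contamination, then clean training text. The python-mode methods are called
# directly so the sketch does not depend on the C++ module; the strings and
# parameter values are hypothetical.
def _demo_janitor():
    janitor = Janitor(ngram_n=3, window_to_remove=5, minimum_slice_length=5)
    janitor.register_contaminant_python("one two three")
    chunks = janitor.clean_python("some long training text ... one two three ... more training text")
    assert len(chunks) == 2  # text on either side of the removed window survives
    assert all("one two three" not in chunk for chunk in chunks)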
# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence, n):
history = []
while n > 1:
# PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
try:
next_item = next(sequence)
except StopIteration:
# no more data, terminate the generator
return
history.append(next_item)
n -= 1
for item in sequence:
history.append(item)
yield tuple(history)
del history[0]
def word_ngrams(s, n):
"""Splits a string into ngram words"""
tokens = s.split() # not a generator :(
ngram_seqs = form_ngrams(iter(tokens), n)
return (" ".join(ngram) for ngram in ngram_seqs)
# Simple text reader and writer with same interface as above
class TextArchive:
def __init__(self, file_path, mode="ab"):
self.file_path = file_path
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
self.fh = open(self.file_path, mode)
def add_data(self, data, meta={}):
self.fh.write(data.encode('UTF-8') + b'\n')
def commit(self):
self.fh.flush()
self.fh.close()
class TextReader:
def __init__(self, file_path):
self.file_path = file_path
# Optimized mmap read with infrequent tqdm updates to maintain speed
# Tested up to 250MB/s.
def read_tqdm(self, update_frequency=10000):
current_file_position = 0
line_counter = 0
with open(self.file_path, 'r') as fh, \
tqdm.tqdm(total=os.path.getsize(self.file_path), dynamic_ncols=True,
unit="byte", unit_scale=1) as progress:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
line_counter += 1
if line_counter == update_frequency:
new_file_pos = mmap_obj.tell()
bytes_read = new_file_pos - current_file_position
current_file_position = new_file_pos
progress.update(bytes_read)
line_counter = 0
yield line[:-1]
def read_and_tell(self):
current_file_position = 0
with open(self.file_path, 'r', encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
new_file_pos = mmap_obj.tell()
raw_bytes_read = new_file_pos - current_file_position
current_file_position = new_file_pos
yield line[:-1], raw_bytes_read
def read(self):
with open(self.file_path, 'r', encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
yield line[:-1]
def read_slow(self):
with open(self.file_path, 'r', encoding="utf8") as fh:
while True:
line = fh.readline()
if line == "":  # readline() returns "" at EOF (it never returns -1)
break
else:
yield line[:-1]
# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
def __init__(self, file):
self.file = file
def read_tqdm(self):
decompressed_file = self.file[:-4]
print("Decompressing file, please wait...")
os.system(f"zstd -d {self.file}") # linux decompress is faster
reader = TextReader(decompressed_file)
yield from reader.read_tqdm()
os.remove(decompressed_file)
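# A minimal usage sketch, not part of this commit: the readers share the same
# generator interface, so iterating a zstd-compressed bucket looks the same as
# iterating a plain text file. The file names below are hypothetical.
def _demo_readers():
    for line in TextReader("ngrams_0.bkt.txt").read():
        print(line)
    for line in ZStdTextReader("ngrams_0.bkt.txt.sorted.zst").read_tqdm():
        print(line)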
......@@ -4,7 +4,6 @@ import timeit
import pickle
import traceback
from pprint import pprint
from lm_eval.decontamination import word_ngrams
# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
......
......@@ -8,6 +8,7 @@ import lm_eval.base
import lm_eval.decontamination
import numpy as np
from lm_eval.utils import positional_deprecated
from lm_eval.decontamination.decontaminate import get_train_overlap
@positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[],
......@@ -189,7 +190,7 @@ def evaluate(lm, task_dict, provide_description=None, num_fewshot=0, limit=None,
# Compare all tasks/sets at once to ensure a single training set scan
if decontaminate:
print("Finding train/test overlap, please wait...")
overlaps = lm_eval.decontamination.get_train_overlap(docs_for_decontamination, decontamination_ngrams_path, limit)
overlaps = get_train_overlap(docs_for_decontamination, decontamination_ngrams_path, limit)
# all responses for each (task, doc)
process_res_queue = collections.defaultdict(list)
......
......@@ -35,25 +35,11 @@ def parse_args():
parser.add_argument('--output_path', default=None)
parser.add_argument('--limit', type=int, default=None)
parser.add_argument('--no_cache', action="store_true")
parser.add_argument('--decontaminate', action="store_true")
parser.add_argument('--decontaminate_ngrams_path', default=None)
parser.add_argument('--decontaminate_ngrams_n_size', type=int, default=None)
parser.add_argument('--decontamination_ngrams_path', default=None)
parser.add_argument('--description_dict_path', default=None)
return parser.parse_args()
def ensure_correct_decontamination_params(args):
valid = True
if args.decontaminate:
if not args.decontaminate_ngrams_n_size:
print("Please specify n size of training set n-grams. (--ngrams_n_size)")
valid = False
if not args.decontaminate_ngrams_path:
print("Please specify path containing training set n-grams. (--ngrams_path)")
valid = False
return valid
# Returns a list containing all values of the source_list that
# match at least one of the patterns
def pattern_match(patterns, source_list):
......@@ -65,8 +51,6 @@ def pattern_match(patterns, source_list):
def main():
args = parse_args()
if not ensure_correct_decontamination_params(args):
return
assert not args.provide_description # not implemented
......@@ -95,9 +79,7 @@ def main():
no_cache=args.no_cache,
limit=args.limit,
description_dict=description_dict,
decontaminate=args.decontaminate,
decontaminate_ngrams_path=args.decontaminate_ngrams_path,
decontaminate_ngrams_n_size=args.decontaminate_ngrams_n_size
decontamination_ngrams_path=args.decontamination_ngrams_path
)
dumped = json.dumps(results, indent=2)
......
{
"Data": "Pile statistics",
"Document Count": 210607728,
"Total Pile Characters": 421215456,
"File Start Offsets": [
0,
7021438,
14042822,
21066113,
28086515,
35106072,
42123306,
49145091,
56165817,
63185587,
70211208,
77234322,
84249267,
91267634,
98285983,
105305110,
112322489,
119342491,
126367373,
133389153,
140412039,
147432373,
154452516,
161470190,
168492733,
175512521,
182526939,
189547478,
196565318,
203583306
]
}
\ No newline at end of file
......@@ -21,8 +21,10 @@ Arguments
"""
import argparse
import json
import pickle
import os
import sys
from pathlib import Path
import glob
import signal
......@@ -30,32 +32,89 @@ from signal import SIGINT
from tqdm import tqdm
from scripts.clean_training_data.janitor import Janitor, word_ngrams
from scripts.clean_training_data.archiver import TextArchive, Reader
from lm_eval.decontamination.janitor import Janitor, word_ngrams
from lm_eval.decontamination.archiver import TextArchive, Reader
import logging
from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)
pile_document_count = 210607728
terminate = False
def handler(signal_received, frame):
global terminate
terminate = True
def get_pile(directory):
reader = Reader()
for file in glob.glob(os.path.join(directory, f"*.jsonl.zst*")):
for document in reader.read(file):
yield document
def yield_pile(start_offsets=None, checkpoint_offset=None):
directory = "pile"
if not os.path.exists(directory):
print("We expect the pile archives to be in the 'pile' directory, but this was not found.")
raise Exception("Pile directory not found.")
files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))
pile_global_offset = 0
start_file = 0
if checkpoint_offset:
for file_i, start_offset in enumerate(start_offsets):
if start_offset > checkpoint_offset:
break
start_file = file_i
pile_global_offset = start_offset
def close_buckets(buckets):
for bucket in buckets:
bucket.commit()
for file_i, file in enumerate(files):
if file_i < start_file:
logger.info(f"Skipping file {file}")
continue
logger.info(f"Reading from pile file: {file}")
reader = Reader()
for document in reader.read(file):
yield (pile_global_offset, document)
pile_global_offset += 1
# Hash buckets backed by on-disk files. Supports file position checkpointing and resuming
# Allows you to write continuously and checkpoint intermittently. If a failure occurs
# the buckets are simply truncated at your last checkpoint.
class Buckets:
def __init__(self, directory, num_buckets):
self.bucket_files = [os.path.join(directory, f"ngrams_{i}.bkt.txt") for i in range(num_buckets)]
self.buckets = list(map(TextArchive, self.bucket_files))
self.checkpoint_file = os.path.join(directory, f"bucket_offsets.ckpt")
if os.path.exists(self.checkpoint_file):
self.bucket_offsets = pickle.load(open(self.checkpoint_file, "rb"))
else:
self.bucket_offsets = [0 for i in range(len(self.buckets))]
for i, offset in enumerate(self.bucket_offsets):
bucket = self.buckets[i]
bucket.fh.seek(offset)
bucket.fh.truncate()
def add_data(self, key, value):
i = hash(key) % len(self.buckets)
bucket = self.buckets[i]
bucket.add_data(value)
def save_checkpoint(self):
for bucket in self.buckets:
bucket.fh.flush()
bucket_offsets = [bucket.fh.tell() for bucket in self.buckets]
pickle.dump(bucket_offsets, open(self.checkpoint_file, "wb"))
def close_buckets(self):
for bucket in self.buckets:
bucket.commit()
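# A minimal usage sketch, not part of this commit, showing a write/checkpoint cycle.
# On restart, Buckets reloads bucket_offsets.ckpt and truncates each bucket file back
# to its last checkpointed position, discarding partially written lines. The directory
# name and checkpoint interval below are hypothetical.
def _demo_buckets(ngrams):
    buckets = Buckets("output", 4)
    for i, ngram in enumerate(ngrams):
        buckets.add_data(ngram, f"{ngram} {i}")
        if i % 1000 == 0:
            buckets.save_checkpoint()
    buckets.close_buckets()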
def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
pile_statistics = json.load(open("pile_statistics.json", "r"))
pile_document_count = pile_statistics["Document Count"]
start_offsets = pile_statistics["File Start Offsets"]
output_directory = os.path.join(working_directory, "output")
os.makedirs(output_directory, exist_ok=True)
......@@ -68,49 +127,56 @@ def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
return
# Checkpoint
checkpoint_file = os.path.join(output_directory, f"ngram_buckets.ckpt")
checkpoint_file = os.path.join(working_directory, f"pile_offset.ckpt")
if os.path.exists(checkpoint_file):
start_id = pickle.load(open(checkpoint_file,"rb"))
checkpoint_offset = pickle.load(open(checkpoint_file,"rb"))
iterate = True
else:
start_id = 0
checkpoint_offset = 0
iterate = False
logger.info(f"Starting at pile document index {start_id}")
bucket_files = [os.path.join(output_directory, f"ngrams_{i}.bkt.txt") for i in range(bucket_count)]
buckets = list(map(TextArchive, bucket_files))
logger.info(f"Starting at pile document index {checkpoint_offset}")
buckets = Buckets(output_directory, bucket_count)
janitor = Janitor()
current_id = 0
batch_size = 1000
batch_counter = 0
with tqdm(total=pile_document_count, dynamic_ncols=True, unit="docs") as progress:
for document in get_pile(working_directory):
if current_id < start_id:
if terminate:
close_buckets(buckets)
return
current_id += 1
with tqdm(total=checkpoint_offset, dynamic_ncols=True, unit="docs") as progress:
for offset, document in yield_pile(start_offsets, checkpoint_offset):
if iterate:
logger.info(f"Iterating to offset {checkpoint_offset} from {offset}")
progress.update(offset)
iterate = False
if offset < checkpoint_offset:
progress.update()
if terminate:
return
continue
if offset == checkpoint_offset:
progress.reset(total=pile_document_count)
progress.update(checkpoint_offset)
# Save checkpoint every "batch_size", only allow terminate after checkpoint
if batch_counter == batch_size:
progress.update(batch_size)
batch_counter = 0
pickle.dump(current_id, open(checkpoint_file,"wb"))
buckets.save_checkpoint()
pickle.dump(offset, open(checkpoint_file,"wb"))
if terminate:
close_buckets(buckets)
buckets.close_buckets()
return
ngrams = word_ngrams(janitor.normalize_string(document), n_value)
for ngram in ngrams:
bucket = hash(ngram) % len(buckets)
buckets[bucket].add_data(f"{ngram} {current_id}")
buckets.add_data(ngram, f"{ngram} {offset}")
batch_counter += 1
current_id += 1
close_buckets(buckets)
buckets.close_buckets()
Path(done_file).touch()
......@@ -120,6 +186,12 @@ parser.add_argument("-n", "--n_value", type=int, default=13)
parser.add_argument("-buckets", "--bucket_count", type=int, default=500)
if __name__ == '__main__':
version = 1.00
print(f"Running version {version}")
if "PYTHONHASHSEED" not in os.environ or os.environ["PYTHONHASHSEED"] != "0":
print("Please run 'export PYTHONHASHSEED=0' before running generate.")
sys.exit()
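# Background for this check (an inference from the code above): bucket assignment
# uses Python's built-in hash() on ngram strings (hash(ngram) % bucket_count), and
# str hashing is salted per process unless PYTHONHASHSEED is fixed. Without a fixed
# seed, a resumed run or a later process reading the buckets would map the same
# ngram to a different bucket, so the seed is pinned before generation, e.g.:
#   export PYTHONHASHSEED=0
#   python scripts/clean_training_data/generate_13_grams.py ...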
# Handle sigint (ctrl-c) cleanly
previous_signal_int = signal.signal(SIGINT, handler)
......@@ -128,4 +200,8 @@ if __name__ == '__main__':
setup_logger_tqdm(logfile_path)
args = parser.parse_args()
do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)
\ No newline at end of file
do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)
info_dict = {"title": "dataset ngrams", "ngram_size": 13}
info_dict_path = os.path.join(args.working_directory, "info.json")
json.dump(info_dict, open(info_dict_path, "w"))
\ No newline at end of file
from lm_eval.decontamination.archiver import Reader
import os
import json
from functools import reduce
import glob
import tqdm
from tqdm_multiprocess import TqdmMultiProcessPool
def get_file_stats(file_path, tqdm_func, global_tqdm):
reader = Reader()
total_documents = 0
total_size = 0
update_frequency = 10000
current_file_position = 0
with tqdm_func(total=os.path.getsize(file_path), dynamic_ncols=True, unit="byte", unit_scale=1) as progress:
for document in reader.read(file_path, get_meta=True):
total_size += len(document)
total_documents += 1
if total_documents % update_frequency == 0:
new_file_pos = reader.fh.tell()
bytes_read = new_file_pos - current_file_position
current_file_position = new_file_pos
progress.update(bytes_read)
global_tqdm.update(bytes_read)
return (total_documents, total_size)
def get_files():
directory = "pile"
files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))
print(files)
return files
def get_stats():
files = get_files()
total_size_bytes = sum(map(lambda x: os.path.getsize(x), files))
pool = TqdmMultiProcessPool(4)
global_tqdm = tqdm.tqdm(total=total_size_bytes, dynamic_ncols=True, unit="byte", unit_scale=1)
# Generate minhashes with pool
tasks = [(get_file_stats, (file,)) for file in files]
on_done = lambda _ : None
on_error = lambda _ : None
results = pool.map(global_tqdm, tasks, on_error, on_done)
total_documents, total_size = reduce(lambda x, y: (x[0]+y[0],x[1]+y[1]), results)
start_offsets = []
current_offset = 0
for file_document_count, _ in results:
start_offsets.append(current_offset)
current_offset += file_document_count
return (total_documents, total_size, start_offsets)
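# Note: "File Start Offsets" are cumulative document counts, i.e. the global index of
# the first document in each pile file; yield_pile in generate_13_grams.py compares
# them against the saved checkpoint offset to skip whole files when resuming.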
if __name__ == '__main__':
version = 1.01
print(f"Running version {version}")
stats_file_path = "pile_statistics.json"
if os.path.exists(stats_file_path):
stats = json.load(open(stats_file_path, "r"))
else:
document_count, total_document_size_chars, start_offsets = get_stats()
stats = {"Data": "Pile statistics",
"Document Count": document_count,
"Total Pile Characters": total_document_size_chars,
"File Start Offsets": start_offsets
}
json.dump(stats, open(stats_file_path, "w"), indent=4)
print(f"document_count: {stats['Document Count']}")
print(f"total_chars: {stats['Total Pile Characters']}")
print(f"start_offsets: {stats['File Start Offsets']}")
......@@ -3,12 +3,14 @@ from collections import Counter
import shutil
import glob
from scripts.clean_training_data.janitor import *
from lm_eval.decontamination.janitor import *
from scripts.clean_training_data.generate_13_grams import do_ngrams_in_buckets
from scripts.clean_training_data.archiver import Archive, TextReader
from lm_eval.decontamination.archiver import Archive, TextReader
import logging
logger = logging.getLogger(__name__)
def test_generate_13_grams_1():
def test_generate_13_grams_1(caplog):
data = """A goose (plural geese) is a bird of any of several waterfowl species in the family Anatidae.
This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese).
Some other birds, mostly related to the shelducks, have "goose" as part of their names.
......@@ -22,6 +24,7 @@ def test_generate_13_grams_1():
data = data + data
# Simple Generation
print("simple generation")
n = 13
janitor = Janitor()
ngrams = word_ngrams(janitor.normalize_string(data), n)
......@@ -31,22 +34,26 @@ def test_generate_13_grams_1():
# print(comparison)
# Generating into buckets
print("bucket generation")
test_working_directory = "test_generate_13_grams"
output_directory = os.path.join(test_working_directory, "output")
try:
shutil.rmtree(output_directory)
shutil.rmtree(test_working_directory)
except FileNotFoundError:
pass
os.makedirs(test_working_directory, exist_ok=True)
archive = Archive(os.path.join(test_working_directory, "test.jsonl.zst"))
os.makedirs(test_working_directory)
assert(not os.path.exists("pile"))
os.makedirs("pile")
archive = Archive(os.path.join("pile", "test.jsonl.zst"))
archive.add_data(data)
archive.commit()
bucket_count = 4
do_ngrams_in_buckets(n, test_working_directory, bucket_count)
# Rebuild from buckets
print("rebuild")
rebuilt_ngrams = []
bucket_file_paths = glob.glob(os.path.join(test_working_directory, "output", f"*.bkt.txt"))
for bucket_file_path in bucket_file_paths:
reader = TextReader(bucket_file_path)
......@@ -56,6 +63,7 @@ def test_generate_13_grams_1():
rebuilt_ngrams.append(ngram)
# Compare
print("compare")
result_counter = Counter(rebuilt_ngrams)
# print(len(result_counter))
# print(len(comparison_counter))
......
import re
from collections import defaultdict
from scripts.clean_training_data.janitor import *
from lm_eval.decontamination.janitor import *
def simple_ngram(sequence, n):
ngrams = list()
......