Commit 2a11129d authored by researcher2

Add pile 13gram processing.

parent ab8da1f0
import os
import zstandard
import json
import jsonlines
import io
import datetime
def json_serial(obj):
"""JSON serializer for objects not serializable by default json code"""
if isinstance(obj, (datetime.datetime,)):
return obj.isoformat()
    raise TypeError("Type %s not serializable" % type(obj))
# Modified version of lm_dataformat Archive for single file.
class Archive:
def __init__(self, file_path, compression_level=3):
self.file_path = file_path
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
self.fh = open(self.file_path, 'wb')
self.cctx = zstandard.ZstdCompressor(level=compression_level)
self.compressor = self.cctx.stream_writer(self.fh)
    def add_data(self, data, meta=None):
        self.compressor.write(json.dumps({'text': data, 'meta': meta or {}}, default=json_serial).encode('UTF-8') + b'\n')
def commit(self):
self.compressor.flush(zstandard.FLUSH_FRAME)
self.fh.flush()
self.fh.close()
# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
class Reader:
def __init__(self):
pass
def read(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n'):
with open(file, 'rb') as fh:
self.fh = fh
cctx = zstandard.ZstdDecompressor()
reader = io.BufferedReader(cctx.stream_reader(fh))
rdr = jsonlines.Reader(reader)
for ob in rdr:
# naive jsonl where each object is just the string itself, with no meta. For legacy compatibility.
if isinstance(ob, str):
assert not get_meta
yield ob
continue
text = ob['text']
if autojoin_paragraphs and isinstance(text, list):
text = para_joiner.join(text)
if get_meta:
                    yield text, ob.get('meta', {})
else:
yield text
# Simple text reader and writer with same interface as above
class TextArchive:
def __init__(self, file_path, mode="ab"):
self.file_path = file_path
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
self.fh = open(self.file_path, mode)
    def add_data(self, data, meta=None):
        self.fh.write(data.encode('UTF-8') + b'\n')
def commit(self):
self.fh.flush()
self.fh.close()
class TextReader:
def __init__(self, file_path):
self.file_path = file_path
def read(self):
with open(self.file_path, 'r', encoding="utf8") as fh:
self.fh = fh
while True:
line = self.fh.readline()
                if not line:  # readline() returns "" at EOF, never -1
                    break
                yield line[:-1]
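# A minimal round-trip sketch (illustration only, not part of the original module; the file
# name "example.jsonl.zst" and the meta contents are hypothetical):
#
#   archive = Archive("example.jsonl.zst")
#   archive.add_data("some document text", meta={"source": "example"})
#   archive.commit()
#
#   reader = Reader()
#   for text in reader.read("example.jsonl.zst"):
#       print(text)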
"""
Outputs all 13-grams found in The Pile.
Loops through all documents and uses the logic found in janitor.py to extract 13-grams.
We bucket each 13-gram by hash into separate file buckets to allow easy parallel processing in the
next stage. We also include the current pile document_id with each ngram instance to allow the
filtering to exclude 13-grams that match more than 10 unique documents (done further down the pipeline).
We don't use lm_dataformat for the output, as it increases runtime roughly 4x (slow jsonify) and makes
resuming harder (and we had the storage).
Arguments
---------
--working_directory (-dir)
Directory containing the pile distribution. An "output" subdirectory will be created underneath
to store the bucketed 13-grams, checkpoint and done files. Default: current directory
--n_value (-n)
    n value for the n-grams; made configurable in case it is ever needed. Default: 13
--bucket_count (-buckets)
    Number of file buckets to use when generating 13-grams. Default: 500
"""
import argparse
import pickle
import os
from pathlib import Path
import glob
import signal
from signal import SIGINT
from tqdm import tqdm
from scripts.clean_training_data.janitor import Janitor, word_ngrams
from scripts.clean_training_data.archiver import TextArchive, Reader
import logging
from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)
pile_document_count = 210607728
terminate = False
def handler(signal_received, frame):
global terminate
terminate = True
def get_pile(directory):
reader = Reader()
for file in glob.glob(os.path.join(directory, f"*.jsonl.zst*")):
for document in reader.read(file):
yield document
def close_buckets(buckets):
for bucket in buckets:
bucket.commit()
def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
output_directory = os.path.join(working_directory, "output")
os.makedirs(output_directory, exist_ok=True)
logger.info(f"Generating {n_value}-grams and bucketing.")
# Done file
done_file = os.path.join(output_directory, f"ngram_buckets.done")
if os.path.exists(done_file):
logger.info("ngrams already generated and bucketed, skipping")
return
# Checkpoint
checkpoint_file = os.path.join(output_directory, f"ngram_buckets.ckpt")
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, "rb") as checkpoint:
            start_id = pickle.load(checkpoint)
    else:
        start_id = 0
logger.info(f"Starting at pile document index {start_id}")
bucket_files = [os.path.join(output_directory, f"ngrams_{i}.bkt.txt") for i in range(bucket_count)]
buckets = list(map(TextArchive, bucket_files))
janitor = Janitor()
current_id = 0
batch_size = 1000
batch_counter = 0
with tqdm(total=pile_document_count, dynamic_ncols=True, unit="docs") as progress:
for document in get_pile(working_directory):
if current_id < start_id:
if terminate:
close_buckets(buckets)
return
current_id += 1
progress.update()
continue
# Save checkpoint every "batch_size", only allow terminate after checkpoint
            if batch_counter == batch_size:
                progress.update(batch_size)
                batch_counter = 0
                with open(checkpoint_file, "wb") as checkpoint:
                    pickle.dump(current_id, checkpoint)
                if terminate:
                    close_buckets(buckets)
                    return
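            # NOTE: Python's built-in hash() of strings is salted per interpreter run unless
            # PYTHONHASHSEED is fixed, so bucket assignment is only stable within a single run.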
ngrams = word_ngrams(janitor.normalize_string(document), n_value)
for ngram in ngrams:
bucket = hash(ngram) % len(buckets)
buckets[bucket].add_data(f"{ngram} {current_id}")
batch_counter += 1
current_id += 1
close_buckets(buckets)
Path(done_file).touch()
parser = argparse.ArgumentParser(description='Generate 13 grams from Pile.')
parser.add_argument("-dir", "--working_directory", default="")
parser.add_argument("-n", "--n_value", type=int, default=13)
parser.add_argument("-buckets", "--bucket_count", type=int, default=500)
if __name__ == '__main__':
# Handle sigint (ctrl-c) cleanly
previous_signal_int = signal.signal(SIGINT, handler)
logfile_path = "ngrams.log"
setup_logger_tqdm(logfile_path)
args = parser.parse_args()
do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)
"""
Processes each sorted bucket, creating a new file listing all ngrams that matched more than 10
unique documents, along with their unique document counts. Uses multiprocessing and very little
memory as we stream from the presorted buckets. It will use a lot of disk, though.
Arguments
---------
--working_directory (-dir)
Directory containing the sorted buckets, processed files will be deposited here. Default: current directory
--move_dir (-move)
    Directory to move processed 13-grams to. Default: do nothing
--process_count (-procs)
Number of processes to use. Default: 4
"""
import argparse
import glob
import os
from pathlib import Path
import re
import shutil
from tqdm import tqdm
from tqdm_multiprocess import TqdmMultiProcessPool
from scripts.clean_training_data.archiver import TextReader, TextArchive
import logging
from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)
# Multiprocessed
def process_bucket(bucket_file_path, processed_directory, move_dir, tqdm_func, global_tqdm):
    bucket_id = re.sub(r"\D", "", os.path.basename(bucket_file_path))
done_file = os.path.join(processed_directory, f"ngram_bucket_processing_{bucket_id}.done")
if os.path.exists(done_file):
logger.info(f"bucket {bucket_id} already processed, skipping")
return
# For managing tqdm
file_size = os.path.getsize(bucket_file_path)
bucket_progress = tqdm_func(total=file_size, dynamic_ncols=True, unit="byte", unit_scale=1)
current_file_position = 0
update_frequency = 100 * 1000000 # 100mb
update_counter = 0
    # Iterate through and output ngrams which occur in more than 10 documents
bucket = TextReader(bucket_file_path)
output_file_path = bucket_file_path + ".processed"
output_archive = TextArchive(output_file_path, mode="wb")
current_ngram = ""
current_ngram_document_ids = set()
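    # The bucket file is sorted, so all lines for a given ngram are adjacent: accumulate the
    # unique document ids for the current ngram and flush it once the ngram changes.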
for line in bucket.read():
[ngram, document_id] = line.rsplit(" ", 1)
        # Write out the previous ngram if it occurred in more than 10 unique documents
if ngram != current_ngram:
if len(current_ngram_document_ids) > 10:
output_archive.add_data(f"{current_ngram} {len(current_ngram_document_ids)}")
current_ngram = ngram
current_ngram_document_ids = set()
current_ngram_document_ids.add(document_id)
# Update tqdm
update_counter += bucket.fh.tell() - current_file_position
current_file_position = bucket.fh.tell()
if update_counter > update_frequency:
bucket_progress.update(update_counter)
update_counter = 0
# Remainder
if len(current_ngram_document_ids) > 10:
output_archive.add_data(f"{current_ngram} {len(current_ngram_document_ids)}")
output_archive.commit()
Path(done_file).touch()
if move_dir:
shutil.move(output_file_path, move_dir)
global_tqdm.update()
def process_sorted_buckets(working_directory, move_dir, process_count):
bucket_file_paths = glob.glob(os.path.join(working_directory, f"*.bkt.txt.sorted"))
processed_directory = os.path.join(working_directory, "processed")
os.makedirs(processed_directory, exist_ok=True)
pool = TqdmMultiProcessPool(process_count)
tasks = [(process_bucket, (bucket_file, processed_directory, move_dir)) for bucket_file in bucket_file_paths]
global_tqdm = tqdm(total=len(bucket_file_paths), dynamic_ncols=True, unit="bucket")
on_done = lambda _ : None
on_error = lambda _ : None
_ = pool.map(global_tqdm, tasks, on_error, on_done)
parser = argparse.ArgumentParser(description='Process 13 grams from sorted buckets.')
parser.add_argument("-dir", "--working_directory", default="")
parser.add_argument("-move", "--move_dir", default="")
parser.add_argument("-procs", "--process_count", type=int, default=4)
if __name__ == '__main__':
logfile_path = "process13grams.log"
setup_logger_tqdm(logfile_path)
args = parser.parse_args()
process_sorted_buckets(args.working_directory, args.move_dir, args.process_count)
"""
Iteratively runs GNU sort on each bucket; GNU sort itself handles the multiprocessing.
Arguments
---------
--working_directory (-dir)
    Directory containing the bucketed 13-grams. Sorted buckets will be written to the same
    directory and the unsorted buckets are removed afterwards. Default: current directory
"""
import glob
import argparse
import os
from pathlib import Path
import signal
from signal import SIGINT
import re
import subprocess
from tqdm import tqdm
import logging
from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)
terminate = False
def handler(signal_received, frame):
global terminate
terminate = True
def sort_13_gram_buckets(working_directory):
bucket_file_paths = glob.glob(os.path.join(working_directory, f"*.bkt.txt"))
for bucket_file_path in tqdm(bucket_file_paths, dynamic_ncols=True):
        bucket_id = re.sub(r"\D", "", os.path.basename(bucket_file_path))
        done_file = os.path.join(working_directory, f"ngram_bucket_sorting_{bucket_id}.done")
        if os.path.exists(done_file):
            logger.info(f"bucket {bucket_id} already sorted, skipping")
            continue
sorted_file_path = bucket_file_path + ".sorted"
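        # GNU sort performs an external merge sort (spilling to temporary files) and, on recent
        # coreutils, uses multiple threads by default, so memory stays bounded per bucket.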
command = f"sort {bucket_file_path} > {sorted_file_path}"
logger.info(command)
subprocess.call(command, shell=True)
if terminate:
return
Path(done_file).touch()
os.remove(bucket_file_path)
parser = argparse.ArgumentParser(description='sort 13gram buckets')
parser.add_argument("-dir", "--working_directory", default="")
if __name__ == '__main__':
# Handle sigint (ctrl-c) cleanly
previous_signal_int = signal.signal(SIGINT, handler)
logfile_path = "sort13grambuckets.log"
setup_logger_tqdm(logfile_path)
args = parser.parse_args()
sort_13_gram_buckets(args.working_directory)
@@ -33,5 +33,10 @@ setuptools.setup(
"pycountry==20.7.3",
"numexpr==2.7.2",
"lm_dataformat==0.0.19",
"pytest==6.2.3",
"pybind11==2.6.2",
"tqdm-multiprocess==0.0.11",
"zstandard==0.15.2",
"jsonlines==2.0.0"
]
)
import os
from collections import Counter
import shutil
import glob
from scripts.clean_training_data.janitor import Janitor, word_ngrams
from scripts.clean_training_data.generate_13_grams import do_ngrams_in_buckets
from scripts.clean_training_data.archiver import Archive, TextReader
def test_generate_13_grams_1():
data = """A goose (plural geese) is a bird of any of several waterfowl species in the family Anatidae.
This group comprises the genera Anser (the grey geese and white geese) and Branta (the black geese).
Some other birds, mostly related to the shelducks, have "goose" as part of their names.
More distantly related members of the family Anatidae are swans, most of which are larger
than true geese, and ducks, which are smaller. The term "goose" may refer to either a male
or female bird, but when paired with "gander", refers specifically to a female one (the latter referring
to a male). Young birds before fledging are called goslings. The collective noun for a group of
geese on the ground is a gaggle; when in flight, they are called a skein, a team, or a wedge; when
flying close together, they are called a plump."""
data = data + data
# Simple Generation
n = 13
janitor = Janitor()
ngrams = word_ngrams(janitor.normalize_string(data), n)
comparison = list(ngrams)
comparison_counter = Counter(comparison)
print(len(comparison))
# print(comparison)
# Generating into buckets
test_working_directory = "test_generate_13_grams"
output_directory = os.path.join(test_working_directory, "output")
try:
shutil.rmtree(output_directory)
except FileNotFoundError:
pass
os.makedirs(test_working_directory, exist_ok=True)
archive = Archive(os.path.join(test_working_directory, "test.jsonl.zst"))
archive.add_data(data)
archive.commit()
bucket_count = 4
do_ngrams_in_buckets(n, test_working_directory, bucket_count)
# Rebuild from buckets
rebuilt_ngrams = []
bucket_file_paths = glob.glob(os.path.join(test_working_directory, "output", f"*.bkt.txt"))
for bucket_file_path in bucket_file_paths:
reader = TextReader(bucket_file_path)
for line in reader.read():
[ngram, document_id] = line.rsplit(" ", 1)
rebuilt_ngrams.append(ngram)
# Compare
result_counter = Counter(rebuilt_ngrams)
# print(len(result_counter))
# print(len(comparison_counter))
    assert len(result_counter) == len(comparison_counter)
# print(result_counter)
# print(comparison_counter)
    assert comparison_counter == result_counter