generate_13_grams.py 4.54 KB
Newer Older
researcher2's avatar
researcher2 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
"""
Outputs all 13-grams found in The Pile.

Loops through all documents and uses the logic found in janitor.py to extract 13-grams. 
We bucket each 13-gram by hash into separate file buckets to allow easy parallel processing in the 
next stage. We also include the current pile document_id with each ngram instance to allow the 
filtering to exclude 13-grams that match more than 10 unique documents (done further down the pipeline).

We didn't use lm_dataformat to output as it increases time 4x (slow jsonify) and makes
resuming hard (and we had the storage).

Arguments
---------
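--working_directory (-dir)
    Directory containing the pile distribution. Its subdirectories are searched for "*.jsonl"
    files. Default: "" (empty string)
--save_directory (-sdir)
    Directory under which an "output" subdirectory will be created to store the bucketed
    13-grams, checkpoint and done files. Default: "" (so "output" is created in the current
    directory)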
--n_value (-n)
    n value in n-gram, added for later use if ever needed. Default: 13
--bucket_count (-buckets)
    Number of file buckets to use when generating 13grams. Default: 500
"""

import argparse
import pickle
import os
from pathlib import Path
import glob
import signal
from signal import SIGINT

from tqdm import tqdm

from scripts.clean_training_data.janitor import Janitor, word_ngrams
from scripts.clean_training_data.archiver import TextArchive, Reader

import logging
from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)

pile_document_count = 210607728

terminate = False
cardy20's avatar
cardy20 committed
43

researcher2's avatar
researcher2 committed
44
45
46
47
48
def handler(signal_received, frame):
    """Signal callback that requests a clean shutdown.

    It only raises the module-level ``terminate`` flag. The processing loop in
    ``do_ngrams_in_buckets`` polls that flag and stops at a point of its own
    choosing, so both arguments supplied by the ``signal`` module are ignored.
    """
    global terminate
    terminate = True

def get_pile(directory):
    """Yield every document found in the pile distribution under ``directory``.

    Each entry of ``directory`` is searched one level deep for ``*.jsonl``
    files, and every document that ``Reader.read`` produces for those files is
    yielded in turn.

    Entries and files are visited in sorted order. ``os.listdir`` and
    ``glob.glob`` return names in arbitrary order, but a resumed run skips
    already-processed documents purely by position, so the iteration order
    must be identical on every run.

    Parameters
    ----------
    directory : str
        Root of the pile distribution. An empty string is treated as the
        current directory (``os.listdir("")`` would raise otherwise).

    Yields
    ------
    The documents produced by ``Reader.read`` for each matching file.
    """
    root = directory or "."
    reader = Reader()
    for entry in sorted(os.listdir(root)):
        # os.path.join supplies the separator that plain string concatenation
        # silently dropped whenever ``directory`` had no trailing slash.
        pattern = os.path.join(root, entry, "*.jsonl")
        for file_path in sorted(glob.glob(pattern)):
            yield from reader.read(file_path)
researcher2's avatar
researcher2 committed
55
56
57
58
59

def close_buckets(buckets):
    """Call ``commit()`` once on each bucket, in iteration order."""
    for archive in buckets:
        archive.commit()

cardy20's avatar
cardy20 committed
60
def do_ngrams_in_buckets(n_value, working_directory, sdir, bucket_count):
    """Extract n-grams from every pile document and append them to bucket files.

    Every n-gram is written, together with the index of the document it came
    from, to one of ``bucket_count`` bucket files chosen by a hash of the
    n-gram text. A checkpoint holding the next document index is saved every
    1000 processed documents so an interrupted run can resume, and a done file
    marks full completion.

    Parameters
    ----------
    n_value : int
        Size of the n-grams to generate.
    working_directory : str
        Directory handed to ``get_pile`` to locate the pile documents.
    sdir : str
        Directory under which the "output" subdirectory (buckets, checkpoint
        and done file) is created.
    bucket_count : int
        Number of bucket files to spread the n-grams over.

    Returns
    -------
    None
    """
    # Local import: only needed for the stable bucket hash below.
    import zlib

    output_directory = os.path.join(sdir, "output")
    os.makedirs(output_directory, exist_ok=True)

    logger.info(f"Generating {n_value}-grams and bucketing.")

    # Done file
    done_file = os.path.join(output_directory, "ngram_buckets.done")
    if os.path.exists(done_file):
        logger.info("ngrams already generated and bucketed, skipping")
        return

    # Checkpoint
    checkpoint_file = os.path.join(output_directory, "ngram_buckets.ckpt")
    if os.path.exists(checkpoint_file):
        # NOTE(review): pickle is only safe on trusted files; this one is
        # written by this script itself further down.
        with open(checkpoint_file, "rb") as checkpoint_fh:
            start_id = pickle.load(checkpoint_fh)
    else:
        start_id = 0

    logger.info(f"Starting at pile document index {start_id}")
    bucket_files = [os.path.join(output_directory, f"ngrams_{i}.bkt.txt") for i in range(bucket_count)]
    buckets = list(map(TextArchive, bucket_files))

    janitor = Janitor()
    current_id = 0
    batch_size = 1000
    batch_counter = 0
    with tqdm(total=pile_document_count, dynamic_ncols=True, unit="docs") as progress:
        for document in get_pile(working_directory):
            # Fast-forward over documents handled before the checkpoint.
            if current_id < start_id:
                if terminate:
                    close_buckets(buckets)
                    return

                current_id += 1
                progress.update()
                continue

            # Save checkpoint every "batch_size", only allow terminate after checkpoint
            if batch_counter == batch_size:
                progress.update(batch_size)
                batch_counter = 0
                # Close the file explicitly so the checkpoint is fully written
                # before the run is allowed to terminate.
                with open(checkpoint_file, "wb") as checkpoint_fh:
                    pickle.dump(current_id, checkpoint_fh)
                if terminate:
                    close_buckets(buckets)
                    return

            ngrams = word_ngrams(janitor.normalize_string(document), n_value)
            for ngram in ngrams:
                # The builtin hash() of a str is salted per process, so a
                # resumed run would send the same n-gram to a different
                # bucket. CRC32 of the text is stable across runs.
                # "surrogatepass" keeps lone surrogates from raising here.
                digest = zlib.crc32(str(ngram).encode("utf-8", "surrogatepass"))
                bucket = digest % len(buckets)
                buckets[bucket].add_data(f"{ngram} {current_id}")

            batch_counter += 1
            current_id += 1

        # Count the final partial batch so the bar reflects every document.
        progress.update(batch_counter)

    close_buckets(buckets)
    Path(done_file).touch()


# Command-line interface. Option names, types and defaults are unchanged; the
# help strings make "--help" describe every option, including -sdir.
parser = argparse.ArgumentParser(description='Generate 13 grams from Pile.')
parser.add_argument(
    "-dir",
    "--working_directory",
    default="",
    help="Directory containing the pile distribution.",
)
parser.add_argument(
    "-sdir",
    "--save_directory",
    default="",
    help=(
        "Directory under which an 'output' subdirectory is created for the "
        "bucketed n-grams, checkpoint and done files."
    ),
)
parser.add_argument(
    "-n",
    "--n_value",
    type=int,
    default=13,
    help="n value in n-gram. Default: 13",
)
parser.add_argument(
    "-buckets",
    "--bucket_count",
    type=int,
    default=500,
    help="Number of file buckets to use when generating n-grams. Default: 500",
)

if __name__ == "__main__":

    # Route ctrl-c (SIGINT) to handler() so an interrupt only raises the
    # terminate flag; the handler that was installed before is returned and
    # stored in previous_signal_int.
    previous_signal_int = signal.signal(SIGINT, handler)

    # Send log output to a file named "ngrams.log".
    logfile_path = "ngrams.log"
    setup_logger_tqdm(logfile_path)

    args = parser.parse_args()
    do_ngrams_in_buckets(
        args.n_value,
        args.working_directory,
        args.save_directory,
        args.bucket_count,
    )