"""
Outputs all 13-grams found in The Pile.

Loops through all documents and uses the logic found in janitor.py to extract 13-grams.
We bucket each 13-gram by hash into one of a fixed number of file buckets to allow easy parallel
processing in the next stage. We also include the current pile document_id with each ngram instance
so that filtering further down the pipeline can exclude 13-grams that match more than 10 unique documents.

We didn't use lm_dataformat for the output as it increases runtime 4x (slow jsonify) and makes
resuming hard (and we had the storage to spare).

Arguments
---------
--working_directory (-dir)
    Directory containing the pile distribution. An "output" subdirectory will be created underneath
    to store the bucketed 13-grams, checkpoint and done files. Default: current directory
--n_value (-n)
    n value in n-gram, added for later use if ever needed. Default: 13
--bucket_count (-buckets)
    Number of file buckets to use when generating 13-grams. Default: 500
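
Example usage (illustrative; run from a directory containing the "pile" folder of
*.jsonl.zst archives expected by yield_pile below, with PYTHONHASHSEED=0 as checked in __main__):

    export PYTHONHASHSEED=0
    python generate_13_grams.py -dir . -n 13 -buckets 500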
"""

import argparse
import json
import pickle
import os
import sys
from pathlib import Path
import glob
import signal
from signal import SIGINT

from tqdm import tqdm

from lm_eval.decontamination.janitor import Janitor, word_ngrams
from lm_eval.decontamination.archiver import TextArchive, Reader

import logging
from tqdm_multiprocess.logger import setup_logger_tqdm
logger = logging.getLogger(__name__)

terminate = False
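# The SIGINT handler below sets this flag; the main loop checks it at safe points
# (right after a checkpoint has been written) and exits without corrupting the buckets.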
def handler(signal_received, frame):
    global terminate
    terminate = True

def yield_pile(start_offsets=None, checkpoint_offset=None):
    directory = "pile"

    if not os.path.exists(directory):
        print("We expect the pile archives to be in the 'pile' directory, but this was not found.")
        raise Exception("Pile directory not found.")

    files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))

    pile_global_offset = 0
    start_file = 0
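    # When resuming, find the last pile file whose starting global document offset is at or
    # before the checkpoint so that all earlier files can be skipped entirely.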
    if checkpoint_offset:
        for file_i, start_offset in enumerate(start_offsets):
            if start_offset > checkpoint_offset:
                break

            start_file = file_i
            pile_global_offset = start_offset

    
    for file_i, file in enumerate(files):
        if file_i < start_file:
            logger.info(f"Skipping file {file}")            
            continue
        logger.info(f"Reading from pile file: {file}")
        reader = Reader()
        for document in reader.read(file):
            yield (pile_global_offset, document)
            pile_global_offset += 1

# Hash buckets backed by on-disk files. Supports file-position checkpointing and resuming.
# Allows you to write continuously and checkpoint intermittently. If a failure occurs
# the buckets are simply truncated at your last checkpoint.
class Buckets:
    def __init__(self, directory, num_buckets):
        self.bucket_files = [os.path.join(directory, f"ngrams_{i}.bkt.txt") for i in range(num_buckets)]
        self.buckets = list(map(TextArchive, self.bucket_files))
        self.checkpoint_file = os.path.join(directory, "bucket_offsets.ckpt")

        if os.path.exists(self.checkpoint_file):
            self.bucket_offsets = pickle.load(open(self.checkpoint_file, "rb"))
        else:
            self.bucket_offsets = [0 for i in range(len(self.buckets))]

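        # Roll each bucket file back to its last checkpointed position so that anything written
        # after the most recent checkpoint (e.g. following a crash) is discarded.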
        for i, offset in enumerate(self.bucket_offsets):
            bucket = self.buckets[i]
            bucket.fh.seek(offset)
            bucket.fh.truncate()

    def add_data(self, key, value):
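        # Bucket selection relies on Python's built-in hash(); PYTHONHASHSEED is pinned to 0
        # (checked in __main__) so the same ngram always maps to the same bucket across runs.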
        i = hash(key) % len(self.buckets)
        bucket = self.buckets[i]
        bucket.add_data(value)

    def save_checkpoint(self):
        for bucket in self.buckets:
            bucket.fh.flush()

        bucket_offsets = [bucket.fh.tell() for bucket in self.buckets]
        pickle.dump(bucket_offsets, open(self.checkpoint_file, "wb"))

    def close_buckets(self):
        for bucket in self.buckets:
            bucket.commit()

def do_ngrams_in_buckets(n_value, working_directory, bucket_count):

    pile_statistics = json.load(open("pile_statistics.json", "r"))
    pile_document_count = pile_statistics["Document Count"]
    start_offsets = pile_statistics["File Start Offsets"]

    output_directory = os.path.join(working_directory, "output")
    os.makedirs(output_directory, exist_ok=True)

    logger.info(f"Generating {n_value}-grams and bucketing.")

    # Done file
    done_file = os.path.join(output_directory, "ngram_buckets.done")
    if os.path.exists(done_file):
        logger.info("ngrams already generated and bucketed, skipping")
        return

    # Checkpoint
    checkpoint_file = os.path.join(working_directory, "pile_offset.ckpt")
    if os.path.exists(checkpoint_file):
        checkpoint_offset = pickle.load(open(checkpoint_file,"rb"))
        iterate = True
    else:
        checkpoint_offset = 0
        iterate = False

    logger.info(f"Starting at pile document index {checkpoint_offset}")
    buckets = Buckets(output_directory, bucket_count)

    janitor = Janitor()
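    # Flush the buckets and checkpoint the current pile offset every batch_size documents
    # so an interrupted run can be resumed from the last checkpoint.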
    batch_size = 1000
    batch_counter = 0

    with tqdm(total=checkpoint_offset, dynamic_ncols=True, unit="docs") as progress:
        for offset, document in yield_pile(start_offsets, checkpoint_offset):
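            # When resuming mid-file, yield_pile restarts at the beginning of the checkpointed
            # file; fast-forward the progress bar once, then skip documents until the
            # checkpointed offset is reached.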
            if iterate:
                logger.info(f"Iterating to offset {checkpoint_offset} from {offset}")                
                progress.update(offset)
                iterate = False

            if offset < checkpoint_offset:
                progress.update()

                if terminate:
                    return
                continue

            if offset == checkpoint_offset:
                progress.reset(total=pile_document_count)
                progress.update(checkpoint_offset)

            # Save checkpoint every "batch_size", only allow terminate after checkpoint
            if batch_counter == batch_size:
                progress.update(batch_size)
                batch_counter = 0
                buckets.save_checkpoint()
                pickle.dump(offset, open(checkpoint_file,"wb"))
                if terminate:
                    buckets.close_buckets()
                    return

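            # Normalize the document and append one "<ngram> <pile document offset>" line per
            # ngram to the bucket chosen by hashing the ngram.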
            ngrams = word_ngrams(janitor.normalize_string(document), n_value)
            for ngram in ngrams:
                buckets.add_data(ngram, f"{ngram} {offset}")

            batch_counter += 1
    
    buckets.close_buckets()
    Path(done_file).touch()


parser = argparse.ArgumentParser(description='Generate 13-grams from The Pile.')
parser.add_argument("-dir", "--working_directory", default="")
parser.add_argument("-n", "--n_value", type=int, default=13)
parser.add_argument("-buckets", "--bucket_count", type=int, default=500)

if __name__ == '__main__':
    version = 1.00
    print(f"Running version {version}")

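    # Bucket assignment hashes ngrams with Python's hash(), so the hash seed must be fixed
    # for bucket contents to line up across runs (and therefore across resumes).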
    if "PYTHONHASHSEED" not in os.environ or os.environ["PYTHONHASHSEED"] != "0":
        print("Please run 'export PYTHONHASHSEED=0' before running generate.")
        sys.exit()

    # Handle sigint (ctrl-c) cleanly
    previous_signal_int = signal.signal(SIGINT, handler)

    logfile_path = "ngrams.log"
    setup_logger_tqdm(logfile_path)

    args = parser.parse_args()
    do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)

    info_dict = {"title": "dataset ngrams", "ngram_size": args.n_value}
    info_dict_path = os.path.join(args.working_directory, "info.json")
    json.dump(info_dict, open(info_dict_path, "w"))