"""
Outputs all 13-grams found in The Pile.

Loops through all documents and uses the logic found in janitor.py to extract 13-grams.
We bucket each 13-gram by hash into separate file buckets to allow easy parallel processing in the
next stage. We also include the current pile document_id with each ngram instance to allow the
filtering to exclude 13-grams that match more than 10 unique documents (done further down the pipeline).

We didn't use lm_dataformat for the output as it increases runtime 4x (slow jsonify) and makes
resuming harder (and we had the storage).

Arguments
---------
--working_directory (-dir)
    Directory containing the pile distribution. An "output" subdirectory will be created underneath
    to store the bucketed 13-grams, checkpoint and done files. Default: current directory
--n_value (-n)
    The n value in n-gram; kept configurable in case other sizes are ever needed. Default: 13
--bucket_count (-buckets)
    Number of file buckets to use when generating 13-grams. Default: 500
"""

import argparse
import json
import pickle
import os
import sys
from pathlib import Path
import glob
import signal
from signal import SIGINT

from tqdm import tqdm

from lm_eval.decontamination.janitor import Janitor, word_ngrams
from lm_eval.decontamination.archiver import TextArchive, Reader

import logging
from tqdm_multiprocess.logger import setup_logger_tqdm

logger = logging.getLogger(__name__)

terminate = False


def handler(signal_received, frame):
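    # Only set a flag on SIGINT; the main loop exits at the next checkpoint
    # (or while skipping already-processed documents), leaving the bucket
    # files in a resumable state.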
    global terminate
    terminate = True


def yield_pile(start_offsets=None, checkpoint_offset=None):
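    # Yields (global_document_index, document) for every document in the pile
    # shards under "pile/". start_offsets holds the global index of the first
    # document in each shard, so when resuming we can skip whole shards that
    # end before checkpoint_offset.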
    directory = "pile"

    if not os.path.exists(directory):
        print(
            "We expect the pile archives to be in the 'pile' directory, but this was not found."
        )
        raise Exception("Pile directory not found.")

    files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))

    pile_global_offset = 0
    start_file = 0
    if checkpoint_offset:
        for file_i, start_offset in enumerate(start_offsets):
            if start_offset > checkpoint_offset:
                break

            start_file = file_i
            pile_global_offset = start_offset

    for file_i, file in enumerate(files):
        if file_i < start_file:
            logger.info(f"Skipping file {file}")
            continue
        logger.info(f"Reading from pile file: {file}")
        reader = Reader()
        for document in reader.read(file):
            yield (pile_global_offset, document)
            pile_global_offset += 1


# Hash buckets > disk backed files. Supports file position checkpointing and resuming
# Allows you to write continuously and checkpoint intermittently. If a failure occurs
# the buckets are simply truncated at your last checkpoint.
class Buckets:
    def __init__(self, directory, num_buckets):
        self.bucket_files = [
            os.path.join(directory, f"ngrams_{i}.bkt.txt") for i in range(num_buckets)
        ]
        self.buckets = list(map(TextArchive, self.bucket_files))
        self.checkpoint_file = os.path.join(directory, "bucket_offsets.ckpt")

        if os.path.exists(self.checkpoint_file):
            self.bucket_offsets = pickle.load(open(self.checkpoint_file, "rb"))
        else:
            self.bucket_offsets = [0 for i in range(len(self.buckets))]

        for i, offset in enumerate(self.bucket_offsets):
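            # Roll this bucket back to its last checkpointed position,
            # discarding anything written after the previous checkpoint.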
            bucket = self.buckets[i]
            bucket.fh.seek(offset)
            bucket.fh.truncate()

    def add_data(self, key, value):
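        # Route the ngram to a bucket file via Python's built-in hash(); with
        # PYTHONHASHSEED fixed (enforced in __main__) the mapping is stable
        # across runs, so a given ngram always lands in the same bucket.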
        i = hash(key) % len(self.buckets)
        bucket = self.buckets[i]
        bucket.add_data(value)

    def save_checkpoint(self):
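        # Flush all buckets first so the recorded file positions match what is
        # actually on disk; __init__ truncates back to these offsets on resume.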
        for bucket in self.buckets:
            bucket.fh.flush()

        bucket_offsets = [bucket.fh.tell() for bucket in self.buckets]
        pickle.dump(bucket_offsets, open(self.checkpoint_file, "wb"))

    def close_buckets(self):
        for bucket in self.buckets:
            bucket.commit()


def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
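    # Stream every pile document, normalize it with Janitor, emit its n-grams
    # into hash buckets, and checkpoint both the pile offset and the bucket
    # file positions every batch_size documents so the job can be resumed.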

    pile_statistics = json.load(open("pile_statistics.json", "r"))
    pile_document_count = pile_statistics["Document Count"]
    start_offsets = pile_statistics["File Start Offsets"]

    output_directory = os.path.join(working_directory, "output")
    os.makedirs(output_directory, exist_ok=True)

    logger.info(f"Generating {n_value}-grams and bucketing.")

    # Done file
    done_file = os.path.join(output_directory, "ngram_buckets.done")
    if os.path.exists(done_file):
        logger.info("ngrams already generated and bucketed, skipping")
        return

    # Checkpoint
    checkpoint_file = os.path.join(working_directory, "pile_offset.ckpt")
    if os.path.exists(checkpoint_file):
        checkpoint_offset = pickle.load(open(checkpoint_file, "rb"))
        iterate = True
    else:
145
        checkpoint_offset = 0
        iterate = False

    logger.info(f"Starting at pile document index {checkpoint_offset}")
    buckets = Buckets(output_directory, bucket_count)

    janitor = Janitor()
    batch_size = 1000
    batch_counter = 0

    with tqdm(total=checkpoint_offset, dynamic_ncols=True, unit="docs") as progress:
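        # The bar initially covers only the fast skip to the checkpoint; once
        # the checkpointed offset is reached it is reset to the full pile size.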
        for offset, document in yield_pile(start_offsets, checkpoint_offset):
            if iterate:
                logger.info(f"Iterating to offset {checkpoint_offset} from {offset}")
                progress.update(offset)
                iterate = False

            if offset < checkpoint_offset:
                progress.update()

                if terminate:
                    return
                continue

            if offset == checkpoint_offset:
                progress.reset(total=pile_document_count)
                progress.update(checkpoint_offset)

            # Save checkpoint every "batch_size", only allow terminate after checkpoint
            if batch_counter == batch_size:
                progress.update(batch_size)
                batch_counter = 0
                buckets.save_checkpoint()
                pickle.dump(offset, open(checkpoint_file, "wb"))
                if terminate:
                    buckets.close_buckets()
                    return

            ngrams = word_ngrams(janitor.normalize_string(document), n_value)
            for ngram in ngrams:
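                # Each record is "<ngram> <pile_document_offset>", letting the
                # next stage count how many unique documents contain an ngram.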
                buckets.add_data(ngram, f"{ngram} {offset}")

            batch_counter += 1

    buckets.close_buckets()
    Path(done_file).touch()


parser = argparse.ArgumentParser(description="Generate 13 grams from Pile.")
parser.add_argument("-dir", "--working_directory", default="")
parser.add_argument("-n", "--n_value", type=int, default=13)
parser.add_argument("-buckets", "--bucket_count", type=int, default=500)

if __name__ == "__main__":
    version = 1.00
    print(f"Running version {version}")

    if "PYTHONHASHSEED" not in os.environ or os.environ["PYTHONHASHSEED"] != "0":
        print("Please run 'export PYTHONHASHSEED=0' before running generate.")
        sys.exit()

    # Handle sigint (ctrl-c) cleanly
    previous_signal_int = signal.signal(SIGINT, handler)

    logfile_path = "ngrams.log"
    setup_logger_tqdm(logfile_path)

    args = parser.parse_args()
    do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)

    info_dict = {"title": "dataset ngrams", "ngram_size": 13}
    info_dict_path = os.path.join(args.working_directory, "info.json")
    json.dump(info_dict, open(info_dict_path, "w"))