"""
Outputs all 13-grams found in The Pile.

Loops through all documents and uses the logic found in janitor.py to extract 13-grams.
We bucket each 13-gram by hash into separate file buckets to allow easy parallel processing in the
next stage. We also include the current pile document_id with each ngram instance to allow the
filtering to exclude 13-grams that match more than 10 unique documents (done further down the pipeline).

We didn't use lm_dataformat to output as it increases time 4x (slow jsonify) and makes
resuming hard (and we had the storage).

Arguments
---------
--working_directory (-dir)
    Directory containing the pile distribution. An "output" subdirectory will be created underneath
    to store the bucketed 13-grams, checkpoint and done files. Default: current directory
--n_value (-n)
    n value in n-gram, added for later use if ever needed. Default: 13
--bucket_count (-buckets)
    Number of file buckets to use when generating 13-grams. Default: 500
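
Example invocation (a sketch; the working directory path is a placeholder). The script
expects to be run from a directory containing the "pile" archive folder and
pile_statistics.json, and requires a fixed PYTHONHASHSEED so that bucketing by hash()
is reproducible:

    export PYTHONHASHSEED=0
    python generate_13_grams.py -dir /path/to/working_directory -n 13 -buckets 500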
"""

import argparse
import glob
import json
import logging
import os
import pickle
import signal
import sys
from pathlib import Path
from signal import SIGINT

from tqdm import tqdm
from tqdm_multiprocess.logger import setup_logger_tqdm

from lm_eval.decontamination.archiver import Reader, TextArchive
from lm_eval.decontamination.janitor import Janitor, word_ngrams


logger = logging.getLogger(__name__)

terminate = False


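# SIGINT handler: just set a flag. The main loop checks it only at safe points (after a
# checkpoint has been saved, or while skipping already-processed documents), so an
# interrupted run can be resumed cleanly.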
def handler(signal_received, frame):
    global terminate
    terminate = True


def yield_pile(start_offsets=None, checkpoint_offset=None):
    directory = "pile"

    if not os.path.exists(directory):
        print(
            "We expect the pile archives to be in the 'pile' directory, but this was not found."
        )
        raise Exception("Pile directory not found.")

    files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))

    pile_global_offset = 0
    start_file = 0
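    # When resuming, use the per-file document start offsets (from pile_statistics.json)
    # to find the pile shard containing the checkpointed document, so earlier shards can
    # be skipped without reading them.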
    if checkpoint_offset:
        for file_i, start_offset in enumerate(start_offsets):
            if start_offset > checkpoint_offset:
                break

            start_file = file_i
            pile_global_offset = start_offset

    for file_i, file in enumerate(files):
        if file_i < start_file:
            logger.info(f"Skipping file {file}")
            continue
        logger.info(f"Reading from pile file: {file}")
        reader = Reader()
        for document in reader.read(file):
            yield (pile_global_offset, document)
            pile_global_offset += 1


# Hash buckets backed by on-disk files. Supports file position checkpointing and resuming.
# Allows you to write continuously and checkpoint intermittently. If a failure occurs
# the buckets are simply truncated at your last checkpoint.
class Buckets:
    def __init__(self, directory, num_buckets):
        self.bucket_files = [
            os.path.join(directory, f"ngrams_{i}.bkt.txt") for i in range(num_buckets)
        ]
        self.buckets = list(map(TextArchive, self.bucket_files))
        self.checkpoint_file = os.path.join(directory, "bucket_offsets.ckpt")

        if os.path.exists(self.checkpoint_file):
            self.bucket_offsets = pickle.load(open(self.checkpoint_file, "rb"))
        else:
            self.bucket_offsets = [0 for i in range(len(self.buckets))]

        # Roll each bucket file back to its last checkpointed position; anything written
        # after that checkpoint (e.g. by a run that crashed) is discarded.
        for i, offset in enumerate(self.bucket_offsets):
            bucket = self.buckets[i]
            bucket.fh.seek(offset)
            bucket.fh.truncate()

    def add_data(self, key, value):
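        # Bucket selection relies on Python's built-in hash(), which is only stable across
        # runs when PYTHONHASHSEED is fixed (the __main__ block requires it to be 0).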
        i = hash(key) % len(self.buckets)
        bucket = self.buckets[i]
        bucket.add_data(value)

    def save_checkpoint(self):
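        # Flush buffered writes, then record each bucket's current end-of-file position.
        # On restart, __init__ truncates the buckets back to these saved offsets.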
        for bucket in self.buckets:
            bucket.fh.flush()

        bucket_offsets = [bucket.fh.tell() for bucket in self.buckets]
        pickle.dump(bucket_offsets, open(self.checkpoint_file, "wb"))

    def close_buckets(self):
        for bucket in self.buckets:
            bucket.commit()


def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
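    # pile_statistics.json provides the total document count (for the progress bar) and
    # the global start offset of each pile shard (used to skip shards when resuming).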
    pile_statistics = json.load(open("pile_statistics.json", "r"))
    pile_document_count = pile_statistics["Document Count"]
    start_offsets = pile_statistics["File Start Offsets"]

    output_directory = os.path.join(working_directory, "output")
    os.makedirs(output_directory, exist_ok=True)

    logger.info(f"Generating {n_value}-grams and bucketing.")

    # Done file
    done_file = os.path.join(output_directory, "ngram_buckets.done")
    if os.path.exists(done_file):
        logger.info("ngrams already generated and bucketed, skipping")
        return

    # Checkpoint
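    # pile_offset.ckpt records the pile document offset to resume from; everything before
    # it has already been flushed to the bucket files.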
    checkpoint_file = os.path.join(working_directory, "pile_offset.ckpt")
    if os.path.exists(checkpoint_file):
        checkpoint_offset = pickle.load(open(checkpoint_file, "rb"))
        iterate = True
    else:
        checkpoint_offset = 0
        iterate = False

    logger.info(f"Starting at pile document index {checkpoint_offset}")
    buckets = Buckets(output_directory, bucket_count)

    janitor = Janitor()
    batch_size = 1000
    batch_counter = 0

    with tqdm(total=checkpoint_offset, dynamic_ncols=True, unit="docs") as progress:
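        # Stream every pile document once. When resuming, documents before the checkpoint
        # only advance the progress bar; ngram extraction restarts at the checkpointed offset.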
        for offset, document in yield_pile(start_offsets, checkpoint_offset):
            if iterate:
                logger.info(f"Iterating to offset {checkpoint_offset} from {offset}")
                progress.update(offset)
                iterate = False

            if offset < checkpoint_offset:
                progress.update()

                if terminate:
                    return
                continue

            if offset == checkpoint_offset:
                progress.reset(total=pile_document_count)
                progress.update(checkpoint_offset)

            # Save checkpoint every "batch_size", only allow terminate after checkpoint
            if batch_counter == batch_size:
                progress.update(batch_size)
                batch_counter = 0
                buckets.save_checkpoint()
                pickle.dump(offset, open(checkpoint_file, "wb"))
                if terminate:
                    buckets.close_buckets()
                    return

            # Normalize the document, extract n-grams, and append each one to its hash
            # bucket as "<ngram> <pile document offset>" for the next pipeline stage.
            ngrams = word_ngrams(janitor.normalize_string(document), n_value)
            for ngram in ngrams:
                buckets.add_data(ngram, f"{ngram} {offset}")

            batch_counter += 1

    buckets.close_buckets()
    Path(done_file).touch()


parser = argparse.ArgumentParser(description="Generate 13 grams from Pile.")
parser.add_argument("-dir", "--working_directory", default="")
parser.add_argument("-n", "--n_value", type=int, default=13)
parser.add_argument("-buckets", "--bucket_count", type=int, default=500)

if __name__ == "__main__":
    version = 1.00
    print(f"Running version {version}")

    if "PYTHONHASHSEED" not in os.environ or os.environ["PYTHONHASHSEED"] != "0":
        print("Please run 'export PYTHONHASHSEED=0' before running generate.")
        sys.exit()

    # Handle sigint (ctrl-c) cleanly
    previous_signal_int = signal.signal(SIGINT, handler)

    logfile_path = "ngrams.log"
    setup_logger_tqdm(logfile_path)

    args = parser.parse_args()
    do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)

    info_dict = {"title": "dataset ngrams", "ngram_size": 13}
    info_dict_path = os.path.join(args.working_directory, "info.json")
    json.dump(info_dict, open(info_dict_path, "w"))