"vscode:/vscode.git/clone" did not exist on "910404764ef380738f5cc81a00f170ab43615559"
generate_13_grams.py 7.23 KB
Newer Older
researcher2's avatar
researcher2 committed
1
2
3
"""
Outputs all 13-grams found in The Pile.

Loops through all documents and uses the logic found in janitor.py to extract 13-grams.
We bucket each 13-gram by hash into separate file buckets to allow easy parallel processing in the
next stage. We also include the current pile document_id with each ngram instance to allow the
filtering to exclude 13-grams that match more than 10 unique documents (done further down the pipeline).

We didn't use lm_dataformat to output as it increases time 4x (slow jsonify) and makes
resuming hard (and we had the storage).

Arguments
---------
--working_directory (-dir)
    Directory containing the pile distribution. An "output" subdirectory will be created underneath
    to store the bucketed 13-grams, checkpoint and done files. Default: current directory
--n_value (-n)
    n value in n-gram, added for later use if ever needed. Default: 13
--bucket_count (-buckets)
    Number of file buckets to use when generating 13-grams. Default: 500
"""

import argparse
import json
import pickle
import os
import sys
from pathlib import Path
import glob
import signal
from signal import SIGINT

from tqdm import tqdm

from lm_eval.decontamination.janitor import Janitor, word_ngrams
from lm_eval.decontamination.archiver import TextArchive, Reader

import logging
from tqdm_multiprocess.logger import setup_logger_tqdm

logger = logging.getLogger(__name__)

terminate = False


def handler(signal_received, frame):
    global terminate
    terminate = True


def yield_pile(start_offsets=None, checkpoint_offset=None):
    directory = "pile"

    if not os.path.exists(directory):
        print(
            "We expect the pile archives to be in the 'pile' directory, but this was not found."
        )
        raise Exception("Pile directory not found.")

    files = list(sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))))

    pile_global_offset = 0
    start_file = 0
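    # When resuming, use the per-file start offsets from pile_statistics.json to jump
    # to the archive containing the checkpointed document and skip earlier archives.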
    if checkpoint_offset:
        for file_i, start_offset in enumerate(start_offsets):
            if start_offset > checkpoint_offset:
                break

            start_file = file_i
            pile_global_offset = start_offset

    for file_i, file in enumerate(files):
        if file_i < start_file:
            logger.info(f"Skipping file {file}")
            continue
        logger.info(f"Reading from pile file: {file}")
        reader = Reader()
        for document in reader.read(file):
            yield (pile_global_offset, document)
            pile_global_offset += 1


# Hash-bucketed, disk-backed files. Supports file-position checkpointing and resuming.
# Allows you to write continuously and checkpoint intermittently. If a failure occurs,
# the buckets are simply truncated at your last checkpoint.
class Buckets:
    def __init__(self, directory, num_buckets):
        self.bucket_files = [
            os.path.join(directory, f"ngrams_{i}.bkt.txt") for i in range(num_buckets)
        ]
        self.buckets = list(map(TextArchive, self.bucket_files))
        self.checkpoint_file = os.path.join(directory, "bucket_offsets.ckpt")

        if os.path.exists(self.checkpoint_file):
            self.bucket_offsets = pickle.load(open(self.checkpoint_file, "rb"))
        else:
            self.bucket_offsets = [0 for i in range(len(self.buckets))]

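        # Truncate each bucket back to its last checkpointed offset so anything written
        # after the most recent checkpoint is discarded when resuming.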
        for i, offset in enumerate(self.bucket_offsets):
            bucket = self.buckets[i]
            bucket.fh.seek(offset)
            bucket.fh.truncate()

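    # Bucket assignment uses Python's built-in hash() on the ngram, so PYTHONHASHSEED
    # must be pinned (the __main__ block requires PYTHONHASHSEED=0) for an ngram to
    # land in the same bucket file across runs and resumes.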
    def add_data(self, key, value):
        i = hash(key) % len(self.buckets)
        bucket = self.buckets[i]
        bucket.add_data(value)

    def save_checkpoint(self):
        for bucket in self.buckets:
            bucket.fh.flush()

        bucket_offsets = [bucket.fh.tell() for bucket in self.buckets]
        pickle.dump(bucket_offsets, open(self.checkpoint_file, "wb"))

    def close_buckets(self):
        for bucket in self.buckets:
            bucket.commit()
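

# Illustrative sketch only (not part of this script and not called anywhere): the next
# pipeline stage reads each bucket file and keeps only 13-grams that occur in at most
# 10 unique pile documents. This assumes each bucket line has the form
# "<ngram> <document offset>", as written by Buckets.add_data above; the actual
# downstream script may differ.
def count_documents_per_ngram(bucket_file_path):
    from collections import defaultdict

    documents_by_ngram = defaultdict(set)
    with open(bucket_file_path, "r", encoding="utf-8") as fh:
        for line in fh:
            # The ngram itself contains spaces, so only split off the trailing offset.
            ngram, document_offset = line.rstrip("\n").rsplit(" ", 1)
            documents_by_ngram[ngram].add(document_offset)
    return {ngram: len(offsets) for ngram, offsets in documents_by_ngram.items()}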


def do_ngrams_in_buckets(n_value, working_directory, bucket_count):

    pile_statistics = json.load(open("pile_statistics.json", "r"))
    pile_document_count = pile_statistics["Document Count"]
    start_offsets = pile_statistics["File Start Offsets"]

    output_directory = os.path.join(working_directory, "output")
    os.makedirs(output_directory, exist_ok=True)

    logger.info(f"Generating {n_value}-grams and bucketing.")

    # Done file
    done_file = os.path.join(output_directory, "ngram_buckets.done")
    if os.path.exists(done_file):
        logger.info("ngrams already generated and bucketed, skipping")
        return

    # Checkpoint
    checkpoint_file = os.path.join(working_directory, "pile_offset.ckpt")
    if os.path.exists(checkpoint_file):
        checkpoint_offset = pickle.load(open(checkpoint_file, "rb"))
        iterate = True
    else:
        checkpoint_offset = 0
        iterate = False

    logger.info(f"Starting at pile document index {checkpoint_offset}")
    buckets = Buckets(output_directory, bucket_count)

    janitor = Janitor()
    batch_size = 1000
    batch_counter = 0

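    # The progress bar first counts up to the checkpoint while fast-forwarding; once
    # the checkpointed document is reached it is reset to the full pile document count.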
    with tqdm(total=checkpoint_offset, dynamic_ncols=True, unit="docs") as progress:
        for offset, document in yield_pile(start_offsets, checkpoint_offset):
            if iterate:
                logger.info(f"Iterating to offset {checkpoint_offset} from {offset}")
                progress.update(offset)
                iterate = False

            if offset < checkpoint_offset:
                progress.update()

                if terminate:
                    return
                continue

            if offset == checkpoint_offset:
                progress.reset(total=pile_document_count)
                progress.update(checkpoint_offset)

            # Save checkpoint every "batch_size", only allow terminate after checkpoint
            if batch_counter == batch_size:
                progress.update(batch_size)
                batch_counter = 0
                buckets.save_checkpoint()
                pickle.dump(offset, open(checkpoint_file, "wb"))
                if terminate:
                    buckets.close_buckets()
                    return

            ngrams = word_ngrams(janitor.normalize_string(document), n_value)
            for ngram in ngrams:
                buckets.add_data(ngram, f"{ngram} {offset}")

            batch_counter += 1

    buckets.close_buckets()
    Path(done_file).touch()


parser = argparse.ArgumentParser(description="Generate 13 grams from Pile.")
parser.add_argument("-dir", "--working_directory", default="")
parser.add_argument("-n", "--n_value", type=int, default=13)
parser.add_argument("-buckets", "--bucket_count", type=int, default=500)

if __name__ == "__main__":
    version = 1.00
    print(f"Running version {version}")

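    # String hashes are randomized per process unless PYTHONHASHSEED is pinned, and
    # Buckets.add_data relies on hash() for stable bucket assignment, so refuse to run
    # without PYTHONHASHSEED=0.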
    if "PYTHONHASHSEED" not in os.environ or os.environ["PYTHONHASHSEED"] != "0":
        print("Please run 'export PYTHONHASHSEED=0' before running generate.")
        sys.exit()

    # Handle sigint (ctrl-c) cleanly
    previous_signal_int = signal.signal(SIGINT, handler)

    logfile_path = "ngrams.log"
    setup_logger_tqdm(logfile_path)

    args = parser.parse_args()
    do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)

    info_dict = {"title": "dataset ngrams", "ngram_size": 13}
    info_dict_path = os.path.join(args.working_directory, "info.json")
    json.dump(info_dict, open(info_dict_path, "w"))