"""
Outputs all 13-grams found in The Pile.

Loops through all documents and uses the logic found in janitor.py to extract 13-grams.
We bucket each 13-gram by hash into separate file buckets to allow easy parallel processing in the
next stage. We also include the current pile document_id with each ngram instance to allow the
filtering to exclude 13-grams that match more than 10 unique documents (done further down the pipeline).

We didn't use lm_dataformat for the output as it increases runtime roughly 4x (slow jsonify) and makes
resuming hard (and we had the storage).

Arguments
---------
--working_directory (-dir)
    Directory containing the pile distribution. An "output" subdirectory will be created underneath
    to store the bucketed 13-grams, checkpoint and done files. Default: current directory
--n_value (-n)
    n value in n-gram, added for later use if ever needed. Default: 13
--bucket_count (-buckets)
    Number of file buckets to use when generating 13grams. Default: 500
"""

import argparse
import json
import pickle
import os
import sys
from pathlib import Path
import glob
import signal
from signal import SIGINT

from tqdm import tqdm

from lm_eval.decontamination.janitor import Janitor, word_ngrams
from lm_eval.decontamination.archiver import TextArchive, Reader

import logging
from tqdm_multiprocess.logger import setup_logger_tqdm

logger = logging.getLogger(__name__)

# Ctrl-C / SIGINT only sets a flag; the main loop acts on it while skipping
# already-processed documents or right after a checkpoint has been written,
# so an interrupted run can resume without corrupting the bucket files.
terminate = False
def handler(signal_received, frame):
    global terminate
    terminate = True

def get_pile(directory):
    # Stream documents from every pile shard in the directory. Files are
    # sorted so the global document order (and therefore the checkpoint
    # offsets) stays deterministic across runs.
    reader = Reader()
    for file in sorted(glob.glob(os.path.join(directory, "*.jsonl.zst*"))):
        for document in reader.read(file):
            yield document

# Disk-backed hash buckets for ngrams: write continuously, checkpoint
# intermittently; after a crash each bucket is truncated back to its last
# checkpointed offset before writing resumes. Note: the bucket file naming and
# the TextArchive details used here (add_data, commit, the fh handle) are
# assumptions based on how Buckets is used in do_ngrams_in_buckets below.
class Buckets:
    def __init__(self, directory, num_buckets):
        self.bucket_files = [
            os.path.join(directory, f"ngrams_{i}.bkt.txt") for i in range(num_buckets)
        ]
        self.buckets = list(map(TextArchive, self.bucket_files))
        self.checkpoint_file = os.path.join(directory, "bucket_offsets.ckpt")

        if os.path.exists(self.checkpoint_file):
            self.bucket_offsets = pickle.load(open(self.checkpoint_file, "rb"))
        else:
            self.bucket_offsets = [0 for _ in range(len(self.buckets))]

        # Truncate each bucket back to its checkpointed position so a restart
        # never replays partially written batches.
        for i, offset in enumerate(self.bucket_offsets):
            bucket = self.buckets[i]
            bucket.fh.seek(offset)
            bucket.fh.truncate()

    def add_data(self, key, value):
        # hash() is stable here because PYTHONHASHSEED is pinned to 0 (see __main__).
        i = hash(key) % len(self.buckets)
        self.buckets[i].add_data(value)

    def save_checkpoint(self):
        for bucket in self.buckets:
            bucket.fh.flush()
        bucket_offsets = [bucket.fh.tell() for bucket in self.buckets]
        pickle.dump(bucket_offsets, open(self.checkpoint_file, "wb"))

    def close_buckets(self):
        for bucket in self.buckets:
            bucket.commit()

def do_ngrams_in_buckets(n_value, working_directory, bucket_count):

    output_directory = os.path.join(working_directory, "output")
    os.makedirs(output_directory, exist_ok=True)

    logger.info(f"Generating {n_value}-grams and bucketing.")

    # Done file
    done_file = os.path.join(output_directory, "ngram_buckets.done")
    if os.path.exists(done_file):
        logger.info("ngrams already generated and bucketed, skipping")
        return

    # Checkpoint
    checkpoint_file = os.path.join(working_directory, "pile_offset.ckpt")
    if os.path.exists(checkpoint_file):
        checkpoint_offset = pickle.load(open(checkpoint_file, "rb"))
        iterate = True
    else:
        checkpoint_offset = 0
        iterate = False

    logger.info(f"Starting at pile document index {checkpoint_offset}")
    buckets = Buckets(output_directory, bucket_count)

    janitor = Janitor()
    batch_size = 1000
    batch_counter = 0

    # The total document count isn't known up front (get_pile streams shards
    # from disk), so the progress bar simply counts documents as they pass.
    with tqdm(dynamic_ncols=True, unit="docs") as progress:
        for offset, document in enumerate(get_pile(working_directory)):
            if iterate:
                logger.info(f"Iterating to offset {checkpoint_offset} from {offset}")
                progress.update(offset)
                iterate = False

            if offset < checkpoint_offset:
                progress.update()

                if terminate:
                    return
                continue

            # Save checkpoint every "batch_size", only allow terminate after checkpoint
            if batch_counter == batch_size:
                progress.update(batch_size)
                batch_counter = 0
                buckets.save_checkpoint()
                pickle.dump(offset, open(checkpoint_file, "wb"))
                if terminate:
                    buckets.close_buckets()
                    return

            # Normalize the document, extract its 13-grams, and record each one
            # as "<ngram> <document offset>" in its hash bucket.
            ngrams = word_ngrams(janitor.normalize_string(document), n_value)
            for ngram in ngrams:
                buckets.add_data(ngram, f"{ngram} {offset}")

            batch_counter += 1

    buckets.close_buckets()
    Path(done_file).touch()


parser = argparse.ArgumentParser(description="Generate 13 grams from Pile.")
parser.add_argument("-dir", "--working_directory", default="")
parser.add_argument("-sdir", "--save_directory", default="")
parser.add_argument("-n", "--n_value", type=int, default=13)
parser.add_argument("-buckets", "--bucket_count", type=int, default=500)

if __name__ == "__main__":
    version = 1.00
    print(f"Running version {version}")

    # Bucket assignment relies on Python's built-in hash(), which is salted per
    # process unless PYTHONHASHSEED is pinned, so every stage of the pipeline
    # must run with the same fixed seed.
    if "PYTHONHASHSEED" not in os.environ or os.environ["PYTHONHASHSEED"] != "0":
        print("Please run 'export PYTHONHASHSEED=0' before running generate.")
        sys.exit()

    # Handle sigint (ctrl-c) cleanly
    previous_signal_int = signal.signal(SIGINT, handler)

    logfile_path = "ngrams.log"
    setup_logger_tqdm(logfile_path)

    args = parser.parse_args()
    do_ngrams_in_buckets(args.n_value, args.working_directory, args.bucket_count)