import collections
import glob
import json
import os
import pickle
import random
import time

from .archiver import ZStdTextReader
from .janitor import Janitor, word_ngrams


# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs: dict, ngrams_path: str, ngrams_n_size: int):
    simulated_overlap = 0.1
    contaminated = int(len(docs) * simulated_overlap)
    return random.sample(range(len(docs)), contaminated)


# Returns a dictionary containing all overlapping documents in each
# task. In the standard use case, an overlap occurs when any of the 13-grams
# found in the task document exist in the training set documents.
#
# To generate 13-grams for the Pile see scripts/clean_training_data. The final output of these
# scripts is an info.json file containing the ngram_size (13) and a set of "ngrams_{x}.bkt.txt.sorted.zst"
# files. These should exist in the "ngrams_path" provided to this function.
#
# Algorithm:
# 1. Build lookups for each dataset {ngram: set(doc_ids)}
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
# 3. Full scan the 13-grams from the training set against the merged lookup,
#    saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
# 4. Strip the task_set from the dictionary keys and return
#
# We cache the task+set lookups as well as the overlaps.
def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> dict:
    # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)

    # info.json is produced by scripts/clean_training_data; the only key read here
    # is "ngram_size" (13 in the standard Pile setup).
    info_dict_path = os.path.join(ngrams_path, "info.json")
    with open(info_dict_path, "r") as info_file:
        info_dict = json.load(info_file)
    ngrams_n_size = info_dict["ngram_size"]

    janitor = Janitor()

    # Build lookup for each dataset first in case we use different task combinations later
    print("Building Lookups...")
    start = time.perf_counter()

    def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit) -> str:
        return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"

    lookups = {}
    duplicates = {}  # {(task_name, task_set): set(doc_ids)}
    sets_to_decontaminate = len(docs_by_task_set)
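    # Pairs served entirely from a cached overlap dump are subtracted in the loop below;
    # if none remain, the expensive scan of the training-set ngram files is skipped.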

    for (task_name, task_set), docs in docs_by_task_set.items():
        if not os.path.exists(f"data/{task_name}"):
            os.mkdir(f"data/{task_name}")

        # Check if we've decontaminated this combination before
        overlaps_dump_path = get_overlaps_dump_path(
            task_name, task_set, ngrams_n_size, limit
        )
        if os.path.exists(overlaps_dump_path):
            with open(overlaps_dump_path, "rb") as overlaps_file:
                duplicates[(task_name, task_set)] = pickle.load(overlaps_file)
            sets_to_decontaminate -= 1
            continue
        else:
            duplicates[(task_name, task_set)] = set()

        # Build/load the task lookup {ngram: set(documents)}.
        task_set_lookup_path = (
            f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
        )
        if os.path.exists(task_set_lookup_path):
            print(f"{task_set_lookup_path} available, loading...")
            lookups[(task_name, task_set)] = pickle.load(
                open(task_set_lookup_path, "rb")
            )
        else:
            print(f"{task_set_lookup_path} not available, building...")
            lookup = collections.defaultdict(set)

            for doc_id, document in enumerate(docs):
                ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
                for ngram in ngrams:
                    lookup[ngram].add(doc_id)

            # Cache the lookup so later runs with the same task, set, and limit skip the rebuild
            with open(task_set_lookup_path, "wb") as lookup_file:
                pickle.dump(lookup, lookup_file)
            lookups[(task_name, task_set)] = lookup

    elapsed = time.perf_counter() - start
    print(f"Building lookups took {elapsed:0.5f} seconds.")

    matched_ngrams = []

    if sets_to_decontaminate > 0:
        print("Merging lookups...")
        start = time.perf_counter()
        merged_lookup = collections.defaultdict(list)
        for (task_name, task_set), lookup in lookups.items():
            for ngram, doc_ids in lookup.items():
                merged_lookup[ngram].append((task_name, task_set, doc_ids))
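        # merged_lookup now maps each evaluation ngram to every (task_name, task_set,
        # doc_ids) combination that contains it, e.g. (hypothetical values):
        #   "a b c d e f g h i j k l m" -> [("task_a", "test", {0, 4}), ("task_b", "val", {2})]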

        elapsed = time.perf_counter() - start
        print(f"Merging lookups took {elapsed:0.5f} seconds.")

        print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
        files = glob.glob(os.path.join(ngrams_path, "*.sorted.zst"))
        print(files)

        for file in files:
            start = time.perf_counter()
            print(f"Scanning {file}")
            reader = ZStdTextReader(file)
            total_ngrams = 0
            unique_ngrams = 0
            matching_unique = 0
            non_matching_unique = 0

            current_ngram = ""
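            # Each line of the sorted ngram file is "<ngram> <document_id>". Because the file
            # is sorted, repeats of the same ngram arrive consecutively, so tracking
            # current_ngram lets us test each unique ngram against the lookup only once.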
            for line in reader.read_tqdm():  # Scan training set ngrams file
                total_ngrams += 1
                ngram, document_id = line.rsplit(" ", 1)
                # Only need to match the ngram once in the training set
                if ngram != current_ngram:
                    unique_ngrams += 1
                    current_ngram = ngram
                    if ngram in merged_lookup:
                        matched_ngrams.append(ngram)  # For logging
                        matching_unique += 1
                        for task_name, task_set, doc_ids in merged_lookup[ngram]:
                            task_doc_set = duplicates[(task_name, task_set)]
                            for doc_id in doc_ids:  # Record contamination across all relevant task/set combos
                                task_doc_set.add(doc_id)
                        del merged_lookup[ngram]  # No point matching again
                    else:
                        non_matching_unique += 1

            print(f"Total Ngrams: {total_ngrams}")
            print(f"Unique Ngrams: {unique_ngrams}")
            print(f"Unique Matching: {matching_unique}")
            print(f"Unique Non Matching: {non_matching_unique}")
            print("Matched ngrams:")
            for ngram in matched_ngrams:
                print(ngram)

            elapsed = time.perf_counter() - start
            print(f"Read took {elapsed:0.5f} seconds.")
            print(f"Speed: {(os.path.getsize(file)/1000000.0)/elapsed}MB/second")

        print(duplicates)

        # Dump overlaps separately
        for (task_name, task_set), doc_ids in duplicates.items():
            overlaps_dump_path = get_overlaps_dump_path(
                task_name, task_set, ngrams_n_size, limit
            )
            with open(overlaps_dump_path, "wb") as overlaps_file:
                pickle.dump(doc_ids, overlaps_file)

    # Strip task set and return
    return {
        task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()
    }
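

# A minimal usage sketch (hypothetical task names, paths, and documents; in practice
# this function is driven by the evaluator mentioned at the top of this file):
#
#   docs_by_task_set = {
#       ("task_a", "test"): ["first evaluation document", "second evaluation document"],
#   }
#   overlaps = get_train_overlap(docs_by_task_set, "path/to/pile_13grams", limit=1000)
#   # overlaps maps task_name -> set of contaminated doc indices, e.g. {"task_a": {1}}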