import time
import random
import pickle
import json
import glob
import os
import collections

from .janitor import Janitor, word_ngrams
from .archiver import ZStdTextReader


# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
    simulated_overlap = 0.1
    contaminated = int(len(docs) * simulated_overlap)
    return random.sample(range(len(docs)), contaminated)


# Returns a dictionary mapping each task to the set of overlapping document ids.
# In the standard use case, an overlap occurs when any of the 13-grams found in
# a task document exist in the training set documents.
#
# To generate 13-grams for the Pile see scripts/clean_training_data. The final output of these
# scripts is an info.json file containing the ngram_size (13) and a set of "ngrams_{x}.bkt.txt.sorted.zst"
# files. These should exist in the "ngrams_path" provided to this function.

# Algorithm:
# 1. Build lookups for each dataset {ngram: list(document_ids)}
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
# 3. Scan all 13-grams from the training set against the merged lookup,
#    saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
# 4. Strip the task_set from the dictionary keys and return
#
# We cache the task+set lookups as well as the overlaps.
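#
# A rough usage sketch (the task name, path, and doc texts below are hypothetical,
# for illustration only; "limit" only affects the cache file names used here):
#
#   docs_by_task_set = {("lambada", "test"): ["first doc text", "second doc text"]}
#   overlapping = get_train_overlap(docs_by_task_set, "path/to/pile_ngrams", limit=1000)
#   # e.g. {"lambada": {1}} if only the second doc shares a 13-gram with training data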
def get_train_overlap(docs_by_task_set, ngrams_path, limit):
    # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)

    info_dict_path = os.path.join(ngrams_path, "info.json")
    with open(info_dict_path, "r") as f:
        info_dict = json.load(f)
    ngrams_n_size = info_dict["ngram_size"]

    janitor = Janitor()

    # Build lookup for each dataset first in case we use different task combinations later
    print("Building Lookups...")
    start = time.perf_counter()

    def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit):
        return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"

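    # Per (task_name, task_set) lookups: {ngram: set(doc_ids)}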
    lookups = {}
    duplicates = {}  # {(task_name, task_set): set(doc_ids)}
    sets_to_decontaminate = len(docs_by_task_set)

    for (task_name, task_set), docs in docs_by_task_set.items():
        os.makedirs(f"data/{task_name}", exist_ok=True)

        # Check if we've decontaminated this combination before
        overlaps_dump_path = get_overlaps_dump_path(
            task_name, task_set, ngrams_n_size, limit
        )
        if os.path.exists(overlaps_dump_path):
            with open(overlaps_dump_path, "rb") as f:
                duplicates[(task_name, task_set)] = pickle.load(f)
            sets_to_decontaminate -= 1
            continue
        else:
            duplicates[(task_name, task_set)] = set()

        # Build/load the task lookup {ngram: set(documents)}.
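        # Ngrams are space-joined word strings, the same text format as the
        # lines in the sorted training set ngram files scanned below.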
        task_set_lookup_path = (
            f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
        )
        if os.path.exists(task_set_lookup_path):
            print(f"{task_set_lookup_path} available, loading...")
            with open(task_set_lookup_path, "rb") as f:
                lookups[(task_name, task_set)] = pickle.load(f)
        else:
            print(f"{task_set_lookup_path} not available, building...")
            lookup = collections.defaultdict(set)

            for doc_id, document in enumerate(docs):
                ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
                for ngram in ngrams:
                    lookup[ngram].add(doc_id)

            with open(task_set_lookup_path, "wb") as f:
                pickle.dump(lookup, f)
            lookups[(task_name, task_set)] = lookup

    elapsed = time.perf_counter() - start
    print(f"Building lookups took {elapsed:0.5f} seconds.")

    matched_ngrams = []

    if sets_to_decontaminate > 0:
        print("Merging lookups...")
        start = time.perf_counter()
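        # Merge the per-task-set lookups into one:
        # {ngram: [(task_name, task_set, doc_ids), ...]}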
        merged_lookup = collections.defaultdict(list)
        for (task_name, task_set), lookup in lookups.items():
            for ngram, doc_ids in lookup.items():
                merged_lookup[ngram].append((task_name, task_set, doc_ids))

        elapsed = time.perf_counter() - start
        print(f"Merging lookups took {elapsed:0.5f} seconds.")

        print(f"{ngrams_n_size}-gram files found in {ngrams_path}:")
        files = glob.glob(os.path.join(ngrams_path, "*.sorted.zst"))
        print(files)

        for file in files:
            start = time.perf_counter()
            print(f"Scanning {file}")
            reader = ZStdTextReader(file)
            total_ngrams = 0
            unique_ngrams = 0
            matching_unique = 0
            non_matching_unique = 0

            current_ngram = ""
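            # Each line of the sorted ngram file is "<ngram> <doc_id>"; since the
            # file is sorted, repeats of an ngram are adjacent, so tracking the
            # current ngram is enough to match each unique ngram exactly once.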
            for line in reader.read_tqdm():  # Scan training set ngrams file
                total_ngrams += 1
                ngram, document_id = line.rsplit(" ", 1)
                # Only need to match the ngram once in the training set
                if ngram != current_ngram:
                    unique_ngrams += 1
                    current_ngram = ngram
                    if ngram in merged_lookup:
                        matched_ngrams.append(ngram)  # For logging
                        matching_unique += 1
                        for task_name, task_set, doc_ids in merged_lookup[ngram]:
                            task_doc_set = duplicates[(task_name, task_set)]
                            # Record contamination across all relevant task/set combos
                            for doc_id in doc_ids:
                                task_doc_set.add(doc_id)
                        del merged_lookup[ngram]  # No point matching again
                    else:
                        non_matching_unique += 1

            print(f"Total Ngrams: {total_ngrams}")
            print(f"Unique Ngrams: {unique_ngrams}")
            print(f"Unique Matching: {matching_unique}")
            print(f"Unique Non Matching: {non_matching_unique}")
            print("Matched ngrams:")
            for ngram in matched_ngrams:
                print(ngram)

            elapsed = time.perf_counter() - start
            print(f"Read took {elapsed:0.5f} seconds.")
            print(f"Speed: {(os.path.getsize(file) / 1000000.0) / elapsed:0.2f} MB/second")

        print(duplicates)

        # Dump overlaps separately
        for (task_name, task_set), doc_ids in duplicates.items():
            overlaps_dump_path = get_overlaps_dump_path(
                task_name, task_set, ngrams_n_size, limit
            )
            with open(overlaps_dump_path, "wb") as f:
                pickle.dump(doc_ids, f)

    # Strip task set and return
    return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}