import collections
import glob
import json
import os
import pickle
import random
import time

from .archiver import ZStdTextReader
from .janitor import Janitor, word_ngrams


# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs: dict, ngrams_path: str, ngrams_n_size: int):
    simulated_overlap = 0.1
    contaminated = int(len(docs) * simulated_overlap)
    return random.sample(range(len(docs)), contaminated)


# Returns a dictionary containing all overlapping documents in each
# task. In the standard use case, an overlap occurs when any of the 13-grams
# found in a task document exists in the training set documents.
#
# To generate 13-grams for the Pile, see scripts/clean_training_data. The final output of those
# scripts is an info.json file containing the ngram_size (13) and a set of "ngrams_{x}.bkt.txt.sorted.zst"
# files. These should all exist in the "ngrams_path" provided to this function.
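#
# As a rough illustration (the bucket count shown is arbitrary), "ngrams_path"
# is expected to contain something like:
#
#   ngrams_path/
#       info.json                      e.g. {"ngram_size": 13}
#       ngrams_0.bkt.txt.sorted.zst
#       ngrams_1.bkt.txt.sorted.zst
#       ...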


# Algorithm:
# 1. Build lookups for each dataset {ngram: list(document_ids)}
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
# 3. Full scan the 13-grams from the training set against the merged lookup,
#    saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
# 4. Strip the task_set from the dictionary keys and return
#
# We cache the task+set lookups as well as the overlaps.
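#
# To make the intermediate shapes concrete, a single (hypothetical) task/set
# with one overlapping ngram would flow through the structures roughly as
# follows; the ngram keys are the normalized ngram strings produced by
# janitor.normalize_string and word_ngrams below:
#
#   lookups[("my_task", "test")] = {"one normalized thirteen word ngram": {0, 4}}
#   merged_lookup["one normalized thirteen word ngram"] = [("my_task", "test", {0, 4})]
#   duplicates[("my_task", "test")] = {0, 4}   # filled once the ngram is seen in training data
#   returned: {"my_task": {0, 4}}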
def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> dict:
    # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)

    info_dict_path = os.path.join(ngrams_path, "info.json")
    with open(info_dict_path, "r", encoding="utf-8") as fh:
        info_dict = json.load(fh)
    ngrams_n_size = info_dict["ngram_size"]

    janitor = Janitor()

    # Build lookup for each dataset first in case we use different task combinations later
    print("Building Lookups...")
    start = time.perf_counter()

    def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit) -> str:
        return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"

    lookups = {}
    duplicates = {}  # {(task_name, task_set): set(doc_ids)}
    sets_to_decontaminate = len(docs_by_task_set.keys())

    for (task_name, task_set), docs in docs_by_task_set.items():
        if not os.path.exists(f"data/{task_name}"):
            os.mkdir(f"data/{task_name}")

        # Check if we've decontaminated this combination before
        overlaps_dump_path = get_overlaps_dump_path(
            task_name, task_set, ngrams_n_size, limit
        )
        if os.path.exists(overlaps_dump_path):
            with open(overlaps_dump_path, "rb") as fh:
                duplicates[(task_name, task_set)] = pickle.load(fh)
            sets_to_decontaminate -= 1
            continue
        else:
            duplicates[(task_name, task_set)] = set()

        # Build/load the task lookup {ngram: set(doc_ids)}.
        task_set_lookup_path = (
            f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
        )
        if os.path.exists(task_set_lookup_path):
            print(f"{task_set_lookup_path} available, loading...")
            with open(task_set_lookup_path, "rb") as fh:
                lookups[(task_name, task_set)] = pickle.load(fh)
        else:
            print(f"{task_set_lookup_path} not available, building...")
            lookup = collections.defaultdict(set)

            for doc_id, document in enumerate(docs):
                ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
                for ngram in ngrams:
                    lookup[ngram].add(doc_id)

            with open(task_set_lookup_path, "wb") as fh:
                pickle.dump(lookup, fh)
            lookups[(task_name, task_set)] = lookup

    elapsed = time.perf_counter() - start
    print(f"Building lookups took {elapsed:0.5f} seconds.")

    matched_ngrams = []

    if sets_to_decontaminate > 0:
        print("Merging lookups...")
        start = time.perf_counter()
        merged_lookup = collections.defaultdict(list)
        for (task_name, task_set), lookup in lookups.items():
            for ngram, doc_ids in lookup.items():
                merged_lookup[ngram].append((task_name, task_set, doc_ids))

        elapsed = time.perf_counter() - start
        print(f"Merging lookups took {elapsed:0.5f} seconds.")

        print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
        files = glob.glob(os.path.join(ngrams_path, "*.sorted.zst"))
        print(files)

        for file in files:
            start = time.perf_counter()
            print(f"Scanning {file}")
            reader = ZStdTextReader(file)
            total_ngrams = 0
            unique_ngrams = 0
            matching_unique = 0
            non_matching_unique = 0

            current_ngram = ""
            for line in reader.read_tqdm():  # Scan training set ngrams file
                total_ngrams += 1
                [ngram, document_id] = line.rsplit(" ", 1)
                if (
                    ngram != current_ngram
                ):  # Only need to match the ngram once in training set
                    unique_ngrams += 1
                    current_ngram = ngram
                    if ngram in merged_lookup:
                        matched_ngrams.append(ngram)  # For logging
                        matching_unique += 1
                        for task_name, task_set, doc_ids in merged_lookup[ngram]:
                            task_doc_set = duplicates[(task_name, task_set)]
                            for doc_id in doc_ids:  # Record contamination across all relevant task/set combos
                                task_doc_set.add(doc_id)
                        del merged_lookup[ngram]  # No point matching again
                    else:
                        non_matching_unique += 1

            print(f"Total Ngrams: {total_ngrams}")
            print(f"Unique Ngrams: {unique_ngrams}")
            print(f"Unique Matching: {matching_unique}")
            print(f"Unique Non Matching: {non_matching_unique}")
            print("Matched ngrams:")
            for ngram in matched_ngrams:
                print(ngram)

            elapsed = time.perf_counter() - start
            print(f"Read took {elapsed:0.5f} seconds.")
            print(f"Speed: {(os.path.getsize(file)/1000000.0)/elapsed}MB/second")

        print(duplicates)

        # Dump overlaps separately
        for (task_name, task_set), doc_ids in duplicates.items():
            overlaps_dump_path = get_overlaps_dump_path(
                task_name, task_set, ngrams_n_size, limit
            )
            with open(overlaps_dump_path, "wb") as fh:
                pickle.dump(doc_ids, fh)

    # Strip task set and return
    return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
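

# A minimal usage sketch (task name, documents, path and limit below are all
# hypothetical, shown only to illustrate the expected input/output shapes):
#
#   docs_by_task_set = {
#       ("my_task", "test"): ["first eval document text", "second eval document text"],
#   }
#   overlaps = get_train_overlap(docs_by_task_set, "path/to/pile_13grams", limit=1000)
#   # -> {"my_task": {0}}   ids of contaminated eval documents, keyed by task name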