import time
import random
import pickle
import glob
import os
import collections

from scripts.clean_training_data.janitor import Janitor, word_ngrams
from scripts.clean_training_data.archiver import ZStdTextReader

# This stub was used to test the evaluator decoupled from the full logic below.
def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
    simulated_overlap = 0.1
    contaminated = int(len(docs) * simulated_overlap)
    return random.sample(range(len(docs)), contaminated)


# Returns a dictionary containing all overlapping documents in each
# task, based on any 13-gram being found in the training set.
# ngrams_path is the parent directory containing the "ngrams_{x}.bkt.txt.sorted.zst"
# files built by the other scripts "generate_13_grams.py" and "sort_13_gram_buckets.py".
# ngrams_n_size is expected to be 13, but we made it a parameter for generality.
# The task set is only included for caching purposes.

# Algorithm:
# 1. Build lookups for each dataset {ngram: list(document_ids)}
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
# 3. Full scan the 13-grams from the training set against the merged lookup,
#    saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
# 4. Strip the task_set from the dictionary keys and return
#
# We cache the task+set lookups as well as the overlaps.
#
# We currently calculate some per-file ngram stats out of interest; these might be removed before merging into main.
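#
# For illustration only, a rough sketch of how this function might be called.
# The task names, path, and limit below are hypothetical placeholders, not
# values taken from the rest of the codebase:
#
#   docs_by_task_set = {
#       ("lambada", "test"): ["first document text", "second document text"],
#       ("hellaswag", "validation"): ["another document text"],
#   }
#   overlaps = get_train_overlap(docs_by_task_set, "output/ngrams", 13, limit=1000)
#   # overlaps might look like {"lambada": {0}, "hellaswag": set()}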
def get_train_overlap(docs_by_task_set, ngrams_path, ngrams_n_size, limit):
    # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)

    janitor = Janitor()

    # Build a lookup for each dataset first, in case we use different task combinations later
    print("Building Lookups...")
    start = time.perf_counter()

    def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit):
        return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"

    lookups = {}
    duplicates = {}  # {(task_name, task_set): set(doc_ids)}
    sets_to_decontaminate = len(docs_by_task_set)

    for (task_name, task_set), docs in docs_by_task_set.items():
        # Make sure the per-task cache directory exists ("data" is created too if missing)
        os.makedirs(f"data/{task_name}", exist_ok=True)

        # Check if we've decontaminated this set before
        overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
        if os.path.exists(overlaps_dump_path):
            with open(overlaps_dump_path, "rb") as f:
                duplicates[(task_name, task_set)] = pickle.load(f)
            sets_to_decontaminate -= 1
            continue
        else:
            duplicates[(task_name, task_set)] = set() # No defaultdict, we want to dump empty sets too later

        # Build/load the task lookup {ngram: set(doc_ids)}.
        task_set_lookup_path = f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
        if os.path.exists(task_set_lookup_path):
            print(f"{task_set_lookup_path} available, loading...")
            with open(task_set_lookup_path, "rb") as f:
                lookups[(task_name, task_set)] = pickle.load(f)
        else:
            print(f"{task_set_lookup_path} not available, building...")
            lookup = collections.defaultdict(set)

            for doc_id, document in enumerate(docs):
                ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
                for ngram in ngrams:
                    lookup[ngram].add(doc_id)

            with open(task_set_lookup_path, "wb") as f:
                pickle.dump(lookup, f)
            lookups[(task_name, task_set)] = lookup

    elapsed = time.perf_counter() - start
    print(f"Building lookups took {elapsed:0.5f} seconds.")

    matched_ngrams = []

    if sets_to_decontaminate > 0:
        print("Merging lookups...")
        start = time.perf_counter()
        merged_lookup = collections.defaultdict(list)
        for (task_name, task_set), lookup in lookups.items():
            for ngram, doc_ids in lookup.items():
                merged_lookup[ngram].append((task_name, task_set, doc_ids))

        elapsed = time.perf_counter() - start
        print(f"Merging lookups took {elapsed:0.5f} seconds.")

        print(f"{ngrams_n_size}-gram files found in {ngrams_path}:")
        files = glob.glob(os.path.join(ngrams_path, "*.sorted.zst"))
        print(files)

        for file in files:
            start = time.perf_counter()
            print(f"Scanning {file}")
            reader = ZStdTextReader(file)
            total_ngrams = 0
            unique_ngrams = 0
            matching_unique = 0
            non_matching_unique = 0

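            # The bucket files are sorted (see sort_13_gram_buckets.py), so identical
            # ngrams arrive on consecutive lines; we only need to consult the merged
            # lookup when the ngram changes.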
            current_ngram = ""
            for line in reader.read_tqdm():
                total_ngrams += 1
                [ngram, document_id] = line.rsplit(" ", 1)
                if ngram != current_ngram:
                    unique_ngrams += 1
                    current_ngram = ngram
                    if ngram in merged_lookup:
                        matched_ngrams.append(ngram)
                        matching_unique += 1
                        for task_name, task_set, doc_ids in merged_lookup[ngram]:
                            duplicates[(task_name, task_set)].update(doc_ids)
                        del merged_lookup[ngram] # No point matching again
                    else:
                        non_matching_unique += 1

            print(f"Total Ngrams: {total_ngrams}")
            print(f"Unique Ngrams: {unique_ngrams}")
            print(f"Unique Matching: {matching_unique}")
            print(f"Unique Non Matching: {non_matching_unique}")
            print("Matched ngrams:")
            for ngram in matched_ngrams:
                print(ngram)

            elapsed = time.perf_counter() - start
            print(f"Read took {elapsed:0.5f} seconds.")
            print(f"Speed: {(os.path.getsize(file) / 1000000.0) / elapsed:0.2f} MB/second")

        print(duplicates)

        # Dump overlaps separately
        for (task_name, task_set), doc_ids in duplicates.items():
            overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
            with open(overlaps_dump_path, "wb") as f:
                pickle.dump(doc_ids, f)

    # Strip task set and return
    return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
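

# A downstream consumer could use the returned mapping to drop contaminated
# documents before evaluation. This is only a hypothetical sketch; the variable
# names below are illustrative and not defined elsewhere in this module:
#
#   overlaps = get_train_overlap(docs_by_task_set, ngrams_path, 13, limit)
#   for (task_name, task_set), docs in docs_by_task_set.items():
#       contaminated = overlaps.get(task_name, set())
#       clean_docs = [doc for i, doc in enumerate(docs) if i not in contaminated]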