import mmap
import string
import time
import random
import pickle
import glob
import os
import collections
import re

import tqdm

try:
    import janitor_util
    JANITOR_CPP = True
except Exception:
    print("WARNING: C++ module could not be loaded. Janitor running in python mode")
    JANITOR_CPP = False

# Was used for testing the evaluator decoupled from the full logic below
def get_train_overlap_stub(docs, ngrams_path, ngrams_n_size):
    simulated_overlap = 0.1
    contaminated = int(len(docs) * simulated_overlap)
    return random.sample(range(len(docs)), contaminated)


# Returns a dictionary containing all overlapping documents in each
# task, based on any 13-gram being found in the training set.
# ngrams_path is the parent directory containing the "ngrams_{x}.bkt.txt.sorted.zst"
# files built by the other scripts "generate_13_grams.py" and "sort_13_gram_buckets.py".
# ngrams_n_size is expected to be 13, but we made it a parameter for generality.
# The task set is only included for caching purposes.

# Algorithm:
# 1. Build lookups for each dataset {ngram: list(document_ids)}
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
# 3. Full scan the 13-grams from the training set against the merged lookup,
#    saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
# 4. Strip the task_set from the dictionary keys and return
#
# We cache the task+set lookups as well as the overlaps.
#
# Currently we calculate some per-file ngram stats for interest; these might be removed before merging into main.
def get_train_overlap(docs_by_task_set, ngrams_path, ngrams_n_size, limit):
    # return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)

    janitor = Janitor()

    # Build lookup for each dataset first in case we use different task combinations later
    print("Building Lookups...")
    start = time.perf_counter()

    def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit):
        return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"

    lookups = {}
    duplicates = {}  # {(task_name, task_set): set(doc_ids)}
    sets_to_decontaminate = len(docs_by_task_set.keys())

    for (task_name, task_set), docs in docs_by_task_set.items():
        os.makedirs(f"data/{task_name}", exist_ok=True)  # also creates the parent "data" directory if missing
        # Check if we've decontaminated this set before
        overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
        if os.path.exists(overlaps_dump_path):
            with open(overlaps_dump_path, "rb") as fp:
                duplicates[(task_name, task_set)] = pickle.load(fp)
            sets_to_decontaminate -= 1
            continue
        else:
            duplicates[(task_name, task_set)] = set() # No defaultdict, we want to dump empty sets too later

        # Build/load the task lookup {ngram: documents}.
        task_set_lookup_path = f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
        if os.path.exists(task_set_lookup_path):
            print(f"{task_set_lookup_path} available, loading...")        
            lookups[(task_name, task_set)] = pickle.load(open(task_set_lookup_path, "rb"))
        else:
            print(f"{task_set_lookup_path} not available, building...")
            lookup = collections.defaultdict(set)

            for doc_id, document in enumerate(docs):
                ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
                for ngram in ngrams:
                    lookup[ngram].add(doc_id)

            with open(task_set_lookup_path, "wb") as fp:
                pickle.dump(lookup, fp)
            lookups[(task_name, task_set)] = lookup

    elapsed = time.perf_counter() - start
    print(f"Building lookups took {elapsed:0.5f} seconds.")

    matched_ngrams = []

    if sets_to_decontaminate > 0:
        print("Merging lookups...")
        start = time.perf_counter()    
        merged_lookup = collections.defaultdict(list)
        for (task_name, task_set), lookup in lookups.items():
            for ngram, doc_ids in lookup.items():
                merged_lookup[ngram].append((task_name, task_set, doc_ids))

        elapsed = time.perf_counter() - start
        print(f"Merging lookups took {elapsed:0.5f} seconds.")

        print(f"13 grams files found in {ngrams_path}:")
        files = glob.glob(os.path.join(ngrams_path, f"*.sorted.zst"))
        print(files)

        for file in files:
            start = time.perf_counter()
            print(f"Scanning {file}")
            reader = ZStdTextReader(file)
            total_ngrams = 0
            unique_ngrams = 0
            matching_unique = 0
            non_matching_unique = 0

            current_ngram = ""
            for line in reader.read_tqdm(): # Scan training set ngrams file
                total_ngrams += 1
                ngram, document_id = line.rsplit(" ", 1)
                if ngram != current_ngram: # Only need to match the ngram once in training set
                    unique_ngrams += 1
                    current_ngram = ngram
                    if ngram in merged_lookup:
                        matched_ngrams.append(ngram) # For logging
                        matching_unique += 1
                        for task_name, task_set, doc_ids in merged_lookup[ngram]:
                            task_doc_set = duplicates[(task_name, task_set)]
                            for doc_id in doc_ids: # Record contamination across all relevant task/set combos
                                task_doc_set.add(doc_id)
                        del merged_lookup[ngram] # No point matching again
                    else:
                        non_matching_unique += 1

            print(f"Total Ngrams: {total_ngrams}")
            print(f"Unique Ngrams: {unique_ngrams}")
            print(f"Unique Matching: {matching_unique}")
            print(f"Unique Non Matching: {non_matching_unique}")
            print("Matched ngrams:")
            for ngram in matched_ngrams:
                print(ngram)

            elapsed = time.perf_counter() - start
            print(f"Read took {elapsed:0.5f} seconds.")
            print(f"Speed: {(os.path.getsize(file)/1000000.0)/elapsed}MB/second")

        print(duplicates)

        # Dump overlaps separately    
        for (task_name, task_set), doc_ids in duplicates.items():
            overlaps_dump_path = get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit)
            with open(overlaps_dump_path, "wb") as fp:
                pickle.dump(doc_ids, fp)

    # Strip task set and return
    return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
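
# A minimal usage sketch (paths and task names here are hypothetical; assumes
# the sorted ngram files were produced by generate_13_grams.py and
# sort_13_gram_buckets.py):
#
#   docs_by_task_set = {("lambada", "test"): ["first document ...", "second document ..."]}
#   overlaps = get_train_overlap(docs_by_task_set, "path/to/ngrams", 13, limit=500)
#   # -> {"lambada": {0, 1}}  ids of docs sharing any 13-gram with the training set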


# Implementation from nltk source
# https://www.nltk.org/_modules/nltk/util.html
def form_ngrams(sequence, n):
    history = []
    while n > 1:
        # PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
        try:
            next_item = next(sequence)
        except StopIteration:
            # no more data, terminate the generator
            return
        history.append(next_item)
        n -= 1
    for item in sequence:
        history.append(item)
        yield tuple(history)
        del history[0]
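
# Example: list(form_ngrams(iter(["a", "b", "c", "d"]), 2))
#          -> [("a", "b"), ("b", "c"), ("c", "d")]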


def word_ngrams(s, n):
    """Splits a string into ngram words"""
    tokens = s.split()  # not a generator :(
    ngram_seqs = form_ngrams(iter(tokens), n)
    return (" ".join(ngram) for ngram in ngram_seqs)


# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
def split_indices(s):
    """Splits a string on whitespaces and records the indices of each in the original string.
    @:return generator((word, (start_idx, end_idx)), ...)
    """
    return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r'\S+', s))
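
# Example: list(split_indices("a bc")) -> [("a", (0, 0)), ("bc", (2, 3))]
# (end indices are inclusive)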


def word_ngrams_indices(s, n):
    """Splits a string into pairs of (ngram words, their start/end indices)"""
    tokens_with_indices = split_indices(s)

    # Generator of ngrams of (word, idx_pairs)
    # (
    #   [(word, (start,end)), (word, (start, end))...],
    #   [(word, (start, end)), ...],
    #   ...
    # )
    ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n)

    # Generator of pairs of word and index ngrams
    # (
    #   ([word, word, ...], [(start,end), (start,end), ...]),
    #   ...
    # )
    ngram_indices_pairs = (zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices)

    # Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
    return ((" ".join(ngram_seq), (indices[0][0], indices[-1][1])) for ngram_seq, indices in ngram_indices_pairs)


class Janitor:

    # FIXME delete_chars: Should anything else go here? Special chars?
    def __init__(
            self,
            ngram_n=13,
            window_to_remove=200,
            too_dirty_cutoff=10,
            minimum_slice_length=200,
            delete_chars=string.punctuation
    ):
        self.ngram_n = ngram_n
        self.window_to_remove = window_to_remove
        self.too_dirty_cutoff = too_dirty_cutoff
        self.minimum_slice_length = minimum_slice_length
        self.delete_chars = delete_chars

        self.dirt_ngrams = set()

        # If in python, we'll translate uppercase to lowercase and delete naughty characters.
        # This is fast by python standards
        # https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
        self.translation_table = str.maketrans(
            string.ascii_lowercase + string.ascii_uppercase,  # These characters
            string.ascii_lowercase * 2,  # Become these characters
            self.delete_chars  # These are deleted
        )

    ##############
    # I/O for saving contamination ngrams
    ##############

    def save_contamination_ngrams(self, filename):
        with open(filename, 'wb') as fp:
            pickle.dump(self.dirt_ngrams, fp)

    def load_contamination_ngrams(self, filename):
        with open(filename, 'rb') as fp:
            self.dirt_ngrams = pickle.load(fp)


    ##############
    # Call these :)
    ##############

    def register_contaminant(self, dirt_string):
        """Register a string as contamination to be removed, e.g. a test set
        This breaks the dirt_string into ngrams to store for future cleaning"""
        if JANITOR_CPP:
            return self.register_contaminant_cpp(dirt_string)
        else:
            print("WARNING: Janitor running in python mode")
            return self.register_contaminant_python(dirt_string)

    def clean(self, dirty_string):
        """Clean a string (e.g. a training set) by removing all ngrams previously
        registered as contaminants. Returns a list of clean chunks, or empty if
        the string was too dirty"""
        if JANITOR_CPP:
            return self.clean_cpp(dirty_string)
        else:
            print("WARNING: Janitor running in python mode")
            return self.clean_python(dirty_string)
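
    # Typical flow (a minimal sketch; the inputs are hypothetical and must
    # contain at least ngram_n words for any ngram to register):
    #   janitor = Janitor()
    #   janitor.register_contaminant(test_set_text)
    #   clean_chunks = janitor.clean(training_document_text)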

    def _split_chunks(self, dirty_string, dirty_parts):
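        """Split dirty_string into chunks that fall outside all dirty regions.

        Each dirty region (ngram, start, end) is widened by window_to_remove
        characters on both sides; resulting chunks shorter than
        minimum_slice_length are dropped, and the whole document is discarded
        (returns []) once more than too_dirty_cutoff dirty ngrams are seen.
        """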
        clean_chunks = []
        splice_idx = 0
        end = -1
        for i, (ngram, start, end) in enumerate(dirty_parts):
            if i >= self.too_dirty_cutoff:
                return []
            start = max(0, start - self.window_to_remove)
            end = min(len(dirty_string), end + self.window_to_remove)

            if start - splice_idx > self.minimum_slice_length:
                clean_chunks.append(dirty_string[splice_idx: start])
            splice_idx = end

        if end < len(dirty_string) - self.minimum_slice_length:
            clean_chunks.append(dirty_string[end+1:])

        return clean_chunks
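
    # Worked example with the default parameters (window_to_remove=200,
    # minimum_slice_length=200): one dirty ngram spanning indices 1000-1020 of
    # a 5000-character string widens to [800, 1220], so _split_chunks returns
    # [dirty_string[0:800], dirty_string[1221:]].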

    ##############
    # Fast C++
    ##############

    def register_contaminant_cpp(self, dirt_string):
        self.dirt_ngrams.update(janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n))

    def clean_cpp(self, dirty_string):
        contamination_indices = janitor_util.clean_ngram_with_indices(dirty_string, self.delete_chars, self.ngram_n)
        return self._split_chunks(dirty_string, contamination_indices)

    ##############
    # Slow python
    ##############

    def normalize_string(self, s):
        return s.translate(self.translation_table)
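
    # Example: Janitor().normalize_string("Hello, World!") -> "hello world"
    # (uppercase folded to lowercase, punctuation deleted, whitespace preserved)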

    def register_contaminant_python(self, dirt_string):
        self.dirt_ngrams.update(word_ngrams(self.normalize_string(dirt_string), self.ngram_n))

    def clean_python(self, dirty_string):
        contamination_indices = (
            (None, *idx_pair)
            for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
            if self.normalize_string(dirty_ngram) in self.dirt_ngrams
        )
        return self._split_chunks(dirty_string, contamination_indices)



# Simple text reader and writer with same interface as above
class TextArchive:
    def __init__(self, file_path, mode="ab"):
        self.file_path = file_path
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)    
        self.fh = open(self.file_path, mode)      
    
    def add_data(self, data, meta=None):
        self.fh.write(data.encode('UTF-8') + b'\n')
    
    def commit(self):
        self.fh.flush()
        self.fh.close()
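
# Usage sketch (hypothetical path); writes newline-terminated UTF-8 lines:
#   archive = TextArchive("data/output.txt")
#   archive.add_data("a line of text")
#   archive.commit()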

class TextReader:
    def __init__(self, file_path):
        self.file_path = file_path

    # Optimized mmap read with infrequent tqdm updates to maintain speed
    # Tested up to 250MB/s. 
    def read_tqdm(self, update_frequency=10000):
        current_file_position = 0
        line_counter = 0
        with open(self.file_path, 'r') as fh, \
            tqdm.tqdm(total=os.path.getsize(self.file_path), dynamic_ncols=True, 
                unit="byte", unit_scale=1) as progress:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")
                    line_counter += 1
                    if line_counter == update_frequency:
                        new_file_pos = mmap_obj.tell() 
                        bytes_read = new_file_pos - current_file_position
                        current_file_position = new_file_pos
                        progress.update(bytes_read)
                        line_counter = 0
                    yield line[:-1]

    def read_and_tell(self):
        current_file_position = 0
        with open(self.file_path, 'r', encoding="utf8") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")                    
                    new_file_pos = mmap_obj.tell() 
                    raw_bytes_read = new_file_pos - current_file_position
                    current_file_position = new_file_pos
                    yield line[:-1], raw_bytes_read

    def read(self):
        with open(self.file_path, 'r', encoding="utf8") as fh:
            with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                for line in iter(mmap_obj.readline, b""):
                    line = line.decode("utf-8")                    
                    yield line[:-1]

    def read_slow(self):
        with open(self.file_path, 'r', encoding="utf8") as fh:
            while True:
                line = fh.readline()
                if line == "":  # readline() returns "" at EOF
                    break
                yield line[:-1]
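
# Usage sketch (hypothetical path):
#   reader = TextReader("data/file.txt")
#   for line in reader.read_tqdm():  # yields newline-stripped lines
#       ...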

# Optimized for speed. Decompresses the archive in shell before
# using the mmap'd TextReader.
class ZStdTextReader:
    def __init__(self, file):
        self.file = file        

    def read_tqdm(self):       
        decompressed_file = self.file[:-4]
        print("Decompressing file, please wait...")
        os.system(f"zstd -d {self.file}") # linux decompress is faster        
        reader = TextReader(decompressed_file)
        yield from reader.read_tqdm()
        os.remove(decompressed_file)
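
# Usage sketch (assumes the zstd binary is on PATH and the filename ends in
# ".zst", e.g. the "ngrams_{x}.bkt.txt.sorted.zst" files described above):
#   for line in ZStdTextReader("ngrams_0.bkt.txt.sorted.zst").read_tqdm():
#       ngram, doc_id = line.rsplit(" ", 1)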