# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import argparse
import json
import multiprocessing
import os
import pickle
import time
from functools import partial

import numpy as np

from lsh import cache, minhash
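
# This script finds near-duplicate documents across json-lines inputs:
# every document is minhash-fingerprinted, grouped into LSH buckets, and
# bucket members are compared with a jaccard-similarity heuristic.
# Example invocation (file names are illustrative only):
#
#   python find_duplicates.py --inputs cc.json cc_id news.json news_id \
#       --output duplicate_urls.json --jaccard-parallel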

# This function is adapted from:
#   https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def shingles(text, char_ngram=5):
    # note the `+ 1`: without it the final character n-gram is dropped
    return set(text[head:head + char_ngram]
               for head in range(0, len(text) - char_ngram + 1))
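
# For intuition: shingles('abcdefg') -> {'abcde', 'bcdef', 'cdefg'}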


# This function is adapted from:
#  https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def jaccard(set_a, set_b, args):
    if len(set_a) < 1 or len(set_b) < 1:
        return 0.0

    intersection = set_a & set_b
    union = set_a | set_b

    if args.jaccard == 'min':
        # overlap coefficient: intersection over the smaller set
        return len(intersection) / min(len(set_a), len(set_b))
    elif args.jaccard == 'max':
        # intersection over the larger set
        return len(intersection) / max(len(set_a), len(set_b))
    else:
        # standard jaccard index: intersection over union
        return len(intersection) / len(union)

def compute_fingerprint(line, key):
    # relies on the module-level `hasher` defined in __main__; pool workers
    # inherit it when the process is forked
    try:
        myjson = json.loads(line)
        url = myjson[key]
        text = myjson['text']
        fingerprint = hasher.fingerprint(text)
    except Exception as e:
        print('Error:', e)
        return None, None, None, False

    return url, text, fingerprint, True

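# Greedy near-duplicate filtering within one LSH bucket: repeatedly pick a
# random "main" document, mark every other member whose similarity to it
# exceeds 0.5 for removal, then drop the main document from consideration.
# With --heuristic-iter -1 this repeats until the bucket is exhausted.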
def url_pairs_to_remove(args, bucket_urls, url_doc):
    remove_urls_list = []
    deduped_local, counter_local = 0, 0
    iteration = 0
    while len(bucket_urls) > 1:
        if args.heuristic_iter != -1 and \
            iteration == args.heuristic_iter:
            break

        items = list(bucket_urls)
        remove_urls = []
        main_url = items[np.random.randint(0, len(items))]
        main_shingles = shingles(url_doc[main_url])

        for other_url in items:
            counter_local += 1
            if other_url == main_url:
                continue
            other_shingles = shingles(url_doc[other_url])
            try:
                jaccard_sim = jaccard(main_shingles, other_shingles, args)
            except Exception as e:
                print('Error:', e)
                jaccard_sim = 0.0
            if jaccard_sim > 0.5:
                remove_urls.append({other_url: jaccard_sim})
                deduped_local += 1
                bucket_urls.remove(other_url)

        bucket_urls.remove(main_url)
        if len(remove_urls) > 0:
            remove_urls_list.append({main_url: remove_urls})
        iteration += 1
    return remove_urls_list, deduped_local, counter_local
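
# Each output line is a json object of the form
# {main_url: [{removed_url: similarity}, ...]}.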

def write_remove_urls_list(remove_urls_list, f_out):
    if len(remove_urls_list) > 0:
        for each_url_remove in remove_urls_list:
            myjson = json.dumps(each_url_remove, ensure_ascii=False)
            f_out.write(myjson.encode('utf-8'))
            f_out.write('\n'.encode('utf-8'))
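
# Worker for the parallel path below; `args` and `url_doc` are module-level
# globals inherited by the pool workers when the process is forked.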

def compute_jaccard(each_bin, num_bins, start_time_local):

    remove_urls_list = []
    deduped_local, counter_local, bucket_local = 0, 0, 0

    for bucket_id in each_bin:
        bucket_local += 1
        # let roughly one worker in the pool report progress
        if os.getpid() % num_bins == 0 and bucket_local % 100000 == 0:
            print("Counter {}, progress {:.2f} time {:.2f}".\
                format(bucket_local, float(bucket_local)/float(len(each_bin)),\
                time.time() - start_time_local), flush=True)

        if len(each_bin[bucket_id]) <= 1:
            continue

        bucket_urls = each_bin[bucket_id].copy()
        remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
            url_pairs_to_remove(args, bucket_urls, url_doc)

        deduped_local += deduped_local_sub
        counter_local += counter_local_sub
        if len(remove_urls_list_sub) > 0:
            remove_urls_list.extend(remove_urls_list_sub)

    return remove_urls_list, deduped_local, counter_local
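
# Parallel path: one worker per LSH bin; each worker filters its own buckets
# while the parent streams the removal lists to the output file.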

def find_pair_urls_parallel(args, lshcache, url_doc):
    start_time = time.time()
    f_out = open(args.output, 'wb')
    deduped, counter = 0, 0

    # compute jaccard similarities for the buckets of each bin in parallel
    # (parallelism is limited to the number of bins)
    num_bins = len(lshcache.bins)
    pool = multiprocessing.Pool(num_bins)
    compute_jaccard_partial = partial(compute_jaccard, num_bins=num_bins, \
        start_time_local=start_time)
    # don't need to pass args and url_doc as they are already shared
    compute_jaccard_iter = pool.imap(compute_jaccard_partial, lshcache.bins)

    print("multiprocessing init took {:.2f}".format(time.time() - start_time),\
        flush=True)
    for remove_urls_list, deduped_local, counter_local in compute_jaccard_iter:
        deduped += deduped_local
        counter += counter_local
        write_remove_urls_list(remove_urls_list, f_out)
        print(' [write]> processed {} documents in {:.2f} '
            'seconds and deduped {} documents ...'.format(counter, time.time()\
            - start_time, deduped), flush=True)

    pool.close()
    pool.join()
    f_out.close()

    print(' Time taken for jaccard similarities {:.2f} seconds'.format(\
        time.time() - start_time), flush=True)
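
# Sequential fallback: identical filtering to the parallel path, run in a
# single process.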

def find_pair_urls_sequential(args, lshcache, url_doc):
    start_time = time.time()
    f_out = open(args.output, 'wb')
    deduped, counter = 0, 0
    for b in lshcache.bins:
        for bucket_id in b:
            if len(b[bucket_id]) <= 1:
                continue

            bucket_urls = b[bucket_id].copy()
            remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
                url_pairs_to_remove(args, bucket_urls, url_doc)

            deduped += deduped_local_sub
            counter += counter_local_sub
            write_remove_urls_list(remove_urls_list_sub, f_out)
            if counter % 10000 == 0:
                print(' [write]> processed {} documents in {:.2f} '
                    'seconds and deduped {} documents ...'.
                    format(counter, time.time() - start_time,
                    deduped), flush=True)
    f_out.close()
    print(' [write]> processed {} documents in {:.2f} '
        'seconds and deduped {} documents ...'.
        format(counter, time.time() - start_time,
        deduped), flush=True)

if __name__ == '__main__':

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1234,
                       help='Random seed used for python, numpy')
    parser.add_argument('--inputs', nargs='*', default=None,
                        help='Pairwise list of the input files and keys, '
                        'e.g. --inputs cc.json cc_id news.json news_id')
    parser.add_argument('--load-fingerprints', nargs='*', default=None,
                        help='Load fingerprints from a list of pickle files,'
                        ' e.g. cc.pkl news.pkl')
    parser.add_argument('--save-fingerprints', type=str, default=None,
                       help='Save the fingerprints of the inputs.')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name that will contain all ids'
                        ' with matching similarities')
    parser.add_argument('--jaccard', type=str, default='union',
                        choices=['union', 'min', 'max'], help='Jaccard'\
                        ' similarity computation')
    parser.add_argument('--heuristic-iter', type=int, default=1,
                       help='Number of iterations to run the heuristics'
                        ': use -1 for exact')
    parser.add_argument('--num-bands', type=int, default=10,
                       help='Number of bands to use in cache')
    parser.add_argument('--num-seeds', type=int, default=100,
                       help='Number of seeds to use for minhash. Note that'
                        ' this value should be divisible by num-bands')
    parser.add_argument('--jaccard-parallel', action='store_true',
                        help='Use this to process a large number of documents.')
    args = parser.parse_args()

    print('finding possible duplicate content ...')

    # set the random seed and draw an array of --num-seeds integer seeds
    np.random.seed(args.seed)
    seeds = np.random.randint(0, 1e6, size=args.num_seeds)

    # initialize minhash and lsh cache; char_ngram here should match the
    # shingles() default above
    hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
    lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)

    url_doc = {}

    # load fingerprints from pickle file if needed
    if args.load_fingerprints is not None:
        for count_fp, fp_file_name in enumerate(args.load_fingerprints):
            print("Loading fingerprints from pickle file {}".format(
                fp_file_name), flush=True)
            fp = open(fp_file_name, "rb")
            if count_fp == 0:
                # assign directory for the first pkl
                lshcache = pickle.load(fp)
                url_doc = pickle.load(fp)
            else:
                # append these to lshcache and url_doc
                local_lshcache = pickle.load(fp)
                local_url_doc = pickle.load(fp)
                for url in local_lshcache.fingerprints.keys():
                    url_doc[url] = local_url_doc[url]
                    lshcache.add_fingerprint(local_lshcache.fingerprints[url], url)
            fp.close()

    counter = 0
    start_time = time.time()

    # compute fingerprints of the inputs, if any
    # input file and the key to use as id
    if args.inputs is not None:
        print("Computing fingerprints", flush=True)
        assert len(args.inputs) % 2 == 0, \
            '--inputs must be pairs of input file and key'
        for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
            print(' document processing {} with key {}'.format(input_file, key),
                flush=True)

            # compute fingerprints in parallel
            # number of fingerprinting processes (hardcoded)
            num_workers = 40
            pool = multiprocessing.Pool(num_workers)
            fin = open(input_file, 'r', encoding='utf-8')
            compute_fingerprint_partial = partial(compute_fingerprint, key=key)
            compute_fingerprint_iter = pool.imap(compute_fingerprint_partial,
                                                 fin, 512)
            # traverse all the texts and add fingerprints
            for url, text, fingerprint, flag in compute_fingerprint_iter:
                counter += 1
                if flag:
                    url_doc[url] = text
                    lshcache.add_fingerprint(fingerprint, url)
                if counter % 10000 == 0:
                    print(' [read]> processed {} documents in {:.2f} '
                        'seconds ...'.format(counter, time.time() - \
                        start_time), flush=True)

            fin.close()
            pool.close()
            pool.join()

    # Save the fingerprints if needed
    if args.save_fingerprints is not None:
        print("Saving fingerprints to pickle file {}".format(
            args.save_fingerprints), flush=True)
        with open(args.save_fingerprints, 'wb') as f_save:
            pickle.dump(lshcache, f_save)
            pickle.dump(url_doc, f_save)
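            # note: --load-fingerprints reads these two objects back in
            # this order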

    # compute jaccard index of the input texts and write to file if needed
    if args.output is not None:
        print("Computing jaccard similarity", flush=True)
        if args.jaccard_parallel:
            find_pair_urls_parallel(args, lshcache, url_doc)
        else:
            find_pair_urls_sequential(args, lshcache, url_doc)

    print('done :-)')