# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
import argparse
17
from functools import partial
18
19
20
import itertools
import json
from lsh import cache, minhash
21
import multiprocessing
22
import numpy as np
23
import time
24
import pickle
25
import sys
26
import os
27
28
29
30
31
32
33
34
35
36

# This function is adapted from:
#   https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def shingles(text, char_ngram=5):
    """Return the set of character n-grams (shingles) found in ``text``.

    Args:
        text: String to split into overlapping character windows.
        char_ngram: Length, in characters, of every shingle.

    Returns:
        A set holding each contiguous ``char_ngram``-character substring of
        ``text``. The set is empty when ``text`` is shorter than
        ``char_ngram``.
    """
    # "+ 1" so that the window starting at len(text) - char_ngram, i.e. the
    # last full-length shingle, is included.
    return {text[head:head + char_ngram]
            for head in range(len(text) - char_ngram + 1)}


# This function is adapted from:
#   https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def jaccard(set_a, set_b, args):
    """Return a similarity score for two sets.

    Args:
        set_a: First set.
        set_b: Second set.
        args: Object with a ``jaccard`` attribute selecting the denominator:
            ``'min'`` divides the intersection size by the smaller set size,
            ``'max'`` by the larger set size, and any other value by the
            size of the union.

    Returns:
        The similarity as a float, or ``0.0`` when either set is empty.
    """
    if len(set_a) < 1 or len(set_b) < 1:
        return 0.0

    overlap = len(set_a & set_b)

    if args.jaccard == 'min':
        return overlap / min(len(set_a), len(set_b))
    if args.jaccard == 'max':
        return overlap / max(len(set_a), len(set_b))
    # Only the default mode needs the union, so it is built lazily here.
    return overlap / len(set_a | set_b)

def compute_fingerprint(line, key):
    """Parse one JSON line and fingerprint its ``'text'`` field.

    Uses the module-level ``hasher`` object (its ``fingerprint`` method) to
    compute the fingerprint.

    Args:
        line: A string containing one JSON document.
        key: Name of the JSON field whose value identifies the document.

    Returns:
        ``(id, text, fingerprint, True)`` on success. If anything raises
        while parsing or hashing, the error is printed and
        ``(None, None, None, False)`` is returned instead.
    """
    try:
        record = json.loads(line)
        doc_id, body = record[key], record['text']
        signature = hasher.fingerprint(body)
    except Exception as err:
        # Best effort: a bad line is reported and skipped by the caller.
        print('Error:', err)
        return None, None, None, False
    else:
        return doc_id, body, signature, True
62

63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
def url_pairs_to_remove(args, bucket_urls, url_doc):
    """Find near-duplicate documents inside one LSH bucket.

    Each pass picks a random pivot id from the bucket, compares its shingles
    with every other id still in the bucket, and records those whose
    similarity exceeds 0.5. The pivot and its matches are removed from
    ``bucket_urls`` before the next pass.

    Args:
        args: Object providing ``heuristic_iter`` (maximum number of passes,
            ``-1`` for no limit) and whatever ``jaccard`` reads from it.
        bucket_urls: Collection of document ids; it is modified in place.
        url_doc: Mapping from document id to document text.

    Returns:
        A tuple ``(remove_urls_list, deduped_local, counter_local)``:
        a list of ``{pivot_id: [{other_id: similarity}, ...]}`` entries, the
        number of ids flagged as duplicates, and the number of bucket items
        visited (the pivot itself is counted on every pass).
    """
    remove_urls_list = []
    deduped_local = 0
    counter_local = 0
    iteration = 0

    while len(bucket_urls) > 1:
        reached_limit = (args.heuristic_iter != -1
                         and iteration == args.heuristic_iter)
        if reached_limit:
            break

        candidates = list(bucket_urls)
        matches = []
        pivot_url = candidates[np.random.randint(0, len(candidates))]
        pivot_shingles = shingles(url_doc[pivot_url])

        for candidate_url in candidates:
            counter_local += 1
            if candidate_url == pivot_url:
                continue
            candidate_shingles = shingles(url_doc[candidate_url])
            try:
                similarity = jaccard(pivot_shingles, candidate_shingles, args)
            except Exception as err:
                # A failed comparison is treated as "not similar".
                print('Error:', err)
                similarity = 0.0
            if similarity > 0.5:
                matches.append({candidate_url: similarity})
                deduped_local += 1
                bucket_urls.remove(candidate_url)

        bucket_urls.remove(pivot_url)
        if matches:
            remove_urls_list.append({pivot_url: matches})
        iteration += 1

    return remove_urls_list, deduped_local, counter_local

def write_remove_urls_list(remove_urls_list, f_out):
    """Write each entry of ``remove_urls_list`` as one UTF-8 JSON line.

    Args:
        remove_urls_list: Sized sequence of JSON-serialisable entries.
        f_out: Writable binary file-like object.
    """
    if len(remove_urls_list) == 0:
        return

    line_break = '\n'.encode('utf-8')
    for entry in remove_urls_list:
        serialized = json.dumps(entry, ensure_ascii=False)
        f_out.write(serialized.encode('utf-8'))
        f_out.write(line_break)

def compute_jaccard(each_bin, num_bins, start_time_local):
    """Run duplicate detection over every bucket of one LSH bin.

    Reads the module-level ``args`` and ``url_doc`` objects rather than
    taking them as parameters.

    Args:
        each_bin: Mapping from bucket id to the collection of document ids
            that fell into that bucket.
        num_bins: Total number of bins; used only to decide which process
            prints progress messages.
        start_time_local: Reference timestamp (``time.time()`` value) for
            the elapsed time shown in progress messages.

    Returns:
        A tuple ``(remove_urls_list, deduped_local, counter_local)``
        accumulated over all buckets of the bin.
    """
    remove_urls_list = []
    deduped_local = 0
    counter_local = 0

    for bucket_local, urls_in_bucket in enumerate(each_bin.values(), start=1):
        # Progress is printed only by processes whose pid is a multiple of
        # num_bins, and only every 100000 buckets.
        should_report = (os.getpid() % num_bins == 0
                         and bucket_local % 100000 == 0)
        if should_report:
            fraction_done = float(bucket_local) / float(len(each_bin))
            elapsed = time.time() - start_time_local
            print("Counter {}, progress {:.2f} time {:.2f}".format(
                bucket_local, fraction_done, elapsed), flush=True)

        # A bucket with zero or one id cannot contain duplicates.
        if len(urls_in_bucket) <= 1:
            continue

        # Work on a copy: url_pairs_to_remove removes ids as it goes.
        sub_list, sub_deduped, sub_counter = url_pairs_to_remove(
            args, urls_in_bucket.copy(), url_doc)

        deduped_local += sub_deduped
        counter_local += sub_counter
        remove_urls_list.extend(sub_list)

    return remove_urls_list, deduped_local, counter_local

def find_pair_urls_parallel(args, lshcache, url_doc):
    """Compute duplicates for every LSH bin in parallel and write them out.

    One worker process is created per bin and each bin is handed to
    ``compute_jaccard`` as a single task, so parallelism is limited to the
    number of bins. Results are written to ``args.output`` as they arrive.

    Args:
        args: Parsed options; ``args.output`` is the path of the result file.
        lshcache: Object whose ``bins`` attribute is the sequence of bins.
        url_doc: Not used directly here; ``compute_jaccard`` reads the
            module-level ``args`` and ``url_doc`` instead.
    """
    start_time = time.time()
    deduped, counter = 0, 0

    num_bins = len(lshcache.bins)
    compute_jaccard_partial = partial(compute_jaccard, num_bins=num_bins,
                                      start_time_local=start_time)

    # The with-statement guarantees that the output file is closed and the
    # workers are torn down even if a task raises part-way through.
    with open(args.output, 'wb') as f_out, \
            multiprocessing.Pool(num_bins) as pool:
        # don't need to pass args and url_doc as they are already shared
        compute_jaccard_iter = pool.imap(compute_jaccard_partial,
                                         lshcache.bins)

        print("multiprocessing init took {:.2f}".format(
            time.time() - start_time), flush=True)

        for remove_urls_list, deduped_local, counter_local in \
                compute_jaccard_iter:
            deduped += deduped_local
            counter += counter_local
            write_remove_urls_list(remove_urls_list, f_out)
            print(' [write]> processed {} documents in {:.2f} '
                  'seoncds and deduped {} documents ...'.format(
                      counter, time.time() - start_time, deduped),
                  flush=True)

        # Normal completion: let the workers exit cleanly before the
        # context manager's terminate() runs.
        pool.close()
        pool.join()

    print(' Taken time for jaccard similariries {:.2f} seconds'.format(
        time.time() - start_time), flush=True)

def find_pair_urls_sequential(args, lshcache, url_doc):
    """Compute duplicates bucket by bucket in this process and write them.

    Args:
        args: Parsed options; ``args.output`` is the path of the result file
            and the object is passed through to ``url_pairs_to_remove``.
        lshcache: Object whose ``bins`` attribute is a sequence of mappings
            from bucket id to a collection of document ids.
        url_doc: Mapping from document id to document text.
    """
    start_time = time.time()
    deduped, counter = 0, 0

    # `counter` grows by a variable amount per bucket, so progress is
    # reported whenever it crosses the next multiple of `report_every`
    # rather than only when it lands exactly on one.
    report_every = 10000
    next_report = report_every

    # The with-statement closes the output file even if a bucket raises.
    with open(args.output, 'wb') as f_out:
        for each_bin in lshcache.bins:
            for urls_in_bucket in each_bin.values():
                # A bucket with zero or one id cannot contain duplicates.
                if len(urls_in_bucket) <= 1:
                    continue

                # Work on a copy: url_pairs_to_remove removes ids in place.
                bucket_urls = urls_in_bucket.copy()
                remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
                    url_pairs_to_remove(args, bucket_urls, url_doc)

                deduped += deduped_local_sub
                counter += counter_local_sub
                write_remove_urls_list(remove_urls_list_sub, f_out)

                if counter >= next_report:
                    print(' [write]> processed {} documents in {:.2f} '
                          'seoncds and deduped {} documents ...'.format(
                              counter, time.time() - start_time, deduped),
                          flush=True)
                    next_report = \
                        (counter // report_every + 1) * report_every

    print(' [write]> processed {} documents in {:.2f} '
          'seoncds and deduped {} documents ...'.format(
              counter, time.time() - start_time, deduped), flush=True)

190
191
if __name__ == '__main__':
    # Script entry point: parse options, build or load MinHash fingerprints,
    # optionally save them, and optionally write the near-duplicate groups.
    #
    # NOTE: `args`, `hasher` and `url_doc` are deliberately left as
    # module-level names. compute_fingerprint() reads `hasher` and
    # compute_jaccard() reads `args` and `url_doc` as globals, so the pool
    # workers must be able to see them (presumably via the fork start
    # method - confirm before running on a platform that uses spawn).

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1234,
                        help='Random seed used for python, numpy')
    parser.add_argument('--inputs', nargs='*', default=None,
                        help='Pairwise list of the input files and keys, '
                        'e.g. --inputs cc.json cc_id news.json news_id')
    parser.add_argument('--load-fingerprints', nargs='*', default=None,
                        help='Load fingerprints from a list of pickle files,'
                        ' e.g. cc.pkl news.pkl')
    parser.add_argument('--save-fingerprints', type=str, default=None,
                        help='Save the fingerprints of the inputs.')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name that consists of all ids'
                        ' with matching similarities')
    parser.add_argument('--jaccard', type=str, default='union',
                        choices=['union', 'min', 'max'], help='Jaccard'
                        ' similarity computation')
    parser.add_argument('--heuristic-iter', type=int, default=1,
                        help='Number of iterations to run the heuristics'
                        ': use -1 for exact')
    parser.add_argument('--num-bands', type=int, default=10,
                        help='Number of bands to use in cache')
    parser.add_argument('--num-seeds', type=int, default=100,
                        help='Number of seeds to use for minhash. Note that'
                        ' this value should be divisible by num-bands')
    parser.add_argument('--jaccard-parallel', action='store_true',
                        help='Use this to process large number of documents.')
    args = parser.parse_args()

    # Validate up front (and not with `assert`, which is stripped under -O)
    # so a malformed command line fails before any expensive work starts.
    if args.inputs is not None and len(args.inputs) % 2 != 0:
        parser.error('--inputs expects pairs of <file> <key>')

    print('finding possible duplicate content ...')

    # set seed and draw one random integer seed per minhash function
    np.random.seed(args.seed)
    seeds = np.random.randint(0, 1e6, size=args.num_seeds)

    # initialize minhash and lsh cache
    hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
    lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)

    url_doc = {}

    # load fingerprints from pickle file if needed
    if args.load_fingerprints is not None:
        for count_fp, fp_file_name in enumerate(args.load_fingerprints):
            print("Loading fingerprints from pickle file {}".format(
                fp_file_name), flush=True)
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load fingerprint files that come from a trusted source.
            with open(fp_file_name, "rb") as fp:
                if count_fp == 0:
                    # the first pickle replaces the freshly built cache
                    lshcache = pickle.load(fp)
                    url_doc = pickle.load(fp)
                else:
                    # append these to lshcache and url_doc
                    local_lshcache = pickle.load(fp)
                    local_url_doc = pickle.load(fp)
                    for url in local_lshcache.fingerprints.keys():
                        url_doc[url] = local_url_doc[url]
                        lshcache.add_fingerprint(
                            local_lshcache.fingerprints[url], url)

    counter = 0
    start_time = time.time()

    # compute finger prints of the inputs if any
    # input file and the key to use as id
    if args.inputs is not None:
        print("Computing fingerprints", flush=True)
        for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
            print(' document processing {} with key {}'.format(
                input_file, key), flush=True)

            # compute fingerprints in parallel
            num_workers = 40
            compute_fingerprint_partial = partial(compute_fingerprint,
                                                  key=key)
            # The with-statement releases the pool and the input file even
            # if reading or fingerprinting raises.
            with multiprocessing.Pool(num_workers) as pool, \
                    open(input_file, 'r', encoding='utf-8') as fin:
                compute_fingerprint_iter = pool.imap(
                    compute_fingerprint_partial, fin, 512)

                # traverse all the texts and add fingerprints
                for url, text, fingerprint, flag in compute_fingerprint_iter:
                    counter += 1
                    if flag:
                        url_doc[url] = text
                        lshcache.add_fingerprint(fingerprint, url)
                    if counter % 10000 == 0:
                        print(' [read]> processed {} documents in {:.2f} '
                              'seconds ...'.format(
                                  counter, time.time() - start_time),
                              flush=True)

                # Normal completion: let workers exit cleanly before the
                # context manager's terminate() runs.
                pool.close()
                pool.join()

    # Save the fingerprints if needed
    if args.save_fingerprints is not None:
        print("Saving fingerprints to pickle file {}".format(
            args.save_fingerprints), flush=True)
        with open(args.save_fingerprints, 'wb') as f_save:
            pickle.dump(lshcache, f_save)
            pickle.dump(url_doc, f_save)

    # compute jaccard index of the input texts and write to file if needed
    if args.output is not None:
        print("Compute jaccard similarity", flush=True)
        if args.jaccard_parallel:
            find_pair_urls_parallel(args, lshcache, url_doc)
        else:
            find_pair_urls_sequential(args, lshcache, url_doc)

    print('done :-)')
305