# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
import argparse
17
from functools import partial
18
19
20
import itertools
import json
from lsh import cache, minhash
21
import multiprocessing
22
import numpy as np
23
import time
24
import pickle
25
26
27
28
29
30
31
32
33
34
35
import sys

# This function is adapted from:
#   https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def shingles(text, char_ngram=5):
    """Return the set of character n-grams (shingles) contained in *text*.

    Args:
        text: The string to split into overlapping character windows.
        char_ngram: Width of each window, in characters.

    Returns:
        A set with every distinct substring of length ``char_ngram``.
        The set is empty when ``text`` is shorter than ``char_ngram``.
    """
    # A string of length n holds n - k + 1 windows of width k; the "+ 1"
    # keeps the final window, which a bare ``n - k`` bound would drop.
    return {text[head:head + char_ngram]
            for head in range(len(text) - char_ngram + 1)}


# This function is adapted from:
#  https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
36
37
38
39
def jaccard(set_a, set_b, args):
    """Compute a Jaccard-style similarity between two shingle sets.

    Args:
        set_a: First set of shingles.
        set_b: Second set of shingles.
        args: Object exposing a ``jaccard`` attribute that selects the
            denominator: ``'min'`` divides by the smaller set size,
            ``'max'`` by the larger set size, and any other value by the
            size of the union (the classic Jaccard index).

    Returns:
        The similarity as a float, or ``0.0`` when either set is empty.
    """
    # An empty set can never overlap anything, and returning early also
    # keeps the divisions below away from a zero denominator.
    if not set_a or not set_b:
        return 0.0

    shared = len(set_a & set_b)

    if args.jaccard == 'min':
        return shared / min(len(set_a), len(set_b))
    if args.jaccard == 'max':
        return shared / max(len(set_a), len(set_b))
    # The union is only needed for the classic index, so build it here
    # instead of unconditionally.
    return shared / len(set_a | set_b)

def compute_fingerprint(line, key):
    """Parse one JSON line and fingerprint its ``'text'`` field.

    Args:
        line: A string holding a single JSON object.
        key: Name of the JSON field whose value identifies the document.

    Returns:
        ``(identifier, text, fingerprint, True)`` on success, or
        ``(None, None, None, False)`` if anything raised while parsing,
        looking up the fields, or fingerprinting. The error is printed.
    """
    try:
        record = json.loads(line)
        doc_id, doc_text = record[key], record['text']
        # `hasher` is a module-level name created by the script body.
        doc_fingerprint = hasher.fingerprint(doc_text)
    except Exception as err:
        # Best effort: report the bad record and let the caller skip it.
        print('Error:', err)
        return None, None, None, False
    else:
        return doc_id, doc_text, doc_fingerprint, True
61
62
63

if __name__ == '__main__':
    # Script entry point. Steps, in order:
    #   1. parse the command line,
    #   2. build (or load from pickle) an LSH cache of document fingerprints,
    #   3. fingerprint any --inputs files in parallel and add them to the cache,
    #   4. optionally pickle the cache (--save-fingerprints),
    #   5. optionally walk every LSH bucket, compare shingle sets with
    #      jaccard(), and write the matches as JSON lines (--output).

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1234,
                        help='Random seed used for python, numpy')
    parser.add_argument('--inputs', nargs='*', default=None,
                        help='Pairwise list of the input files and keys, '
                        'e.g. --inputs cc.json cc_id news.json news_id')
    parser.add_argument('--load-fingerprints', nargs='*', default=None,
                        help='Load fingerprints from a list of pickle files,'
                        ' e.g. cc.pkl news.pkl')
    parser.add_argument('--save-fingerprints', type=str, default=None,
                        help='Save the fingerprints of the inputs.')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name that consists of all ids'
                        ' with matching similarities')
    parser.add_argument('--jaccard', type=str, default='union',
                        choices=['union', 'min', 'max'],
                        help='Jaccard similarity computation')
    parser.add_argument('--heuristic-iter', type=int, default=1,
                        help='Number of iterations to run the heuristics'
                        ': use -1 for exact')
    parser.add_argument('--num-bands', type=int, default=10,
                        help='Number of bands to use in cache')
    parser.add_argument('--num-seeds', type=int, default=100,
                        help='Number of seeds to use for minhash. Note that'
                        ' this value should be divisible by num-bands')

    args = parser.parse_args()

    # --inputs is consumed as (file, key) pairs below. Reject an odd count
    # with a usage error; an assert would vanish under `python -O`.
    if args.inputs is not None and len(args.inputs) % 2 != 0:
        parser.error('--inputs expects pairs: <file> <key> [<file> <key> ...]')

    print('finding possible duplicate content ...')

    # Seed numpy's global RNG and draw args.num_seeds integers for minhash.
    # The same RNG later picks the pivot document inside each bucket.
    np.random.seed(args.seed)
    seeds = np.random.randint(0, 1e6, size=args.num_seeds)

    # Initialize minhash and the LSH cache. `hasher` has to stay a
    # module-level name because compute_fingerprint() reads it as a global.
    hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
    lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)

    # Maps document identifier -> document text.
    url_doc = {}

    # Load fingerprints from pickle files if requested.
    if args.load_fingerprints is not None:
        for count_fp, fp_file_name in enumerate(args.load_fingerprints):
            print("Loading fingerprints from pickle file {}".format(
                fp_file_name), flush=True)
            # SECURITY NOTE: pickle.load can execute arbitrary code embedded
            # in the file. Only load fingerprint files from a trusted source.
            with open(fp_file_name, "rb") as fp:
                loaded_lshcache = pickle.load(fp)
                loaded_url_doc = pickle.load(fp)

            if count_fp == 0:
                # The first pickle replaces the freshly built cache outright.
                lshcache = loaded_lshcache
                url_doc = loaded_url_doc
            else:
                # Later pickles are merged into the cache entry by entry.
                for url, fingerprint in loaded_lshcache.fingerprints.items():
                    url_doc[url] = loaded_url_doc[url]
                    lshcache.add_fingerprint(fingerprint, url)

    counter = 0
    start_time = time.time()

    print("Computing fingerprints", flush=True)

    # Compute fingerprints of the inputs, if any. Each input is a file
    # followed by the JSON key to use as the document identifier.
    if args.inputs is not None:
        for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
            print(' document processing {} with key {}'.format(input_file, key),
                  flush=True)

            # Compute fingerprints in parallel.
            num_workers = 20
            compute_fingerprint_partial = partial(compute_fingerprint, key=key)
            # The `with` releases the file and the pool if anything below
            # raises; on the normal path the pool is closed and joined first.
            with open(input_file, 'r', encoding='utf-8') as fin, \
                    multiprocessing.Pool(num_workers) as pool:
                compute_fingerprint_iter = pool.imap(
                    compute_fingerprint_partial, fin, 500)

                # Traverse all the texts and add fingerprints.
                for url, text, fingerprint, flag in compute_fingerprint_iter:
                    counter += 1
                    if flag:
                        url_doc[url] = text
                        lshcache.add_fingerprint(fingerprint, url)
                    if counter % 10000 == 0:
                        print(' [read]> processed {} documents in {:.2f} '
                              'seconds ...'.format(
                                  counter, time.time() - start_time),
                              flush=True)

                pool.close()
                pool.join()

    # Save the fingerprints if requested. The cache and the text map are
    # written as two consecutive pickles, matching the load order above.
    if args.save_fingerprints is not None:
        print("Saving fingerprints to pickle file {}".format(
            args.save_fingerprints), flush=True)
        with open(args.save_fingerprints, 'wb') as f_save:
            pickle.dump(lshcache, f_save)
            pickle.dump(url_doc, f_save)

    counter = 0
    start_time = time.time()
    deduped = 0

    # Compute the jaccard index of the input texts and write to file if
    # requested.
    if args.output is not None:
        with open(args.output, 'wb') as f_out:
            for b in lshcache.bins:
                for bucket_id in b:
                    # A bucket with a single document has nothing to match.
                    if len(b[bucket_id]) <= 1:
                        continue

                    # Work on a copy so the cache's own bucket is untouched.
                    bucket_urls = b[bucket_id].copy()
                    iteration = 0
                    while len(bucket_urls) > 1:
                        # --heuristic-iter caps the pivot rounds per bucket;
                        # -1 keeps going until the bucket is exhausted.
                        if args.heuristic_iter != -1 and \
                                iteration == args.heuristic_iter:
                            break

                        items = list(bucket_urls)
                        remove_urls = []
                        # Pick a random pivot and compare every other
                        # document in the bucket against it.
                        main_url = items[np.random.randint(0, len(items))]
                        main_shingles = shingles(url_doc[main_url])

                        for other_url in items:
                            counter += 1
                            if other_url == main_url:
                                continue

                            other_shingles = shingles(url_doc[other_url])
                            try:
                                jaccard_sim = jaccard(
                                    main_shingles, other_shingles, args)
                            except Exception as e:
                                print('Error:', e)
                                jaccard_sim = 0.0

                            if jaccard_sim > 0.5:
                                remove_urls.append({other_url: jaccard_sim})
                                deduped += 1
                                # `items` is a snapshot, so shrinking the
                                # live set during this loop is safe.
                                bucket_urls.remove(other_url)

                            if counter % 10000 == 0:
                                print(' [write]> processed {} documents in '
                                      '{:.2f} seconds and deduped {} '
                                      'documents ...'.format(
                                          counter, time.time() - start_time,
                                          deduped), flush=True)

                        # The pivot is done whether or not it matched.
                        bucket_urls.remove(main_url)
                        if remove_urls:
                            myjson = json.dumps({main_url: remove_urls},
                                                ensure_ascii=False)
                            f_out.write(myjson.encode('utf-8'))
                            f_out.write('\n'.encode('utf-8'))
                        iteration += 1

    print('done :-)')
224