# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import itertools
import json
from lsh import cache, minhash
import numpy as np
import time
import pickle
import sys

# This function is adapted from:
#   https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def shingles(text, char_ngram=5):
    # Overlapping character n-grams of the text; the +1 upper bound
    # includes the final n-gram.
    return set(text[head:head + char_ngram]
               for head in range(0, len(text) - char_ngram + 1))
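
# For illustration (hypothetical input, assuming the default char_ngram=5):
#   shingles('abcdefg') -> {'abcde', 'bcdef', 'cdefg'}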


# This function is adapted from:
#  https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def jaccard(set_a, set_b):
    intersection = set_a & set_b
    union = set_a | set_b
    return len(intersection) / len(union)
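
# Quick check (illustrative): jaccard({'a', 'b', 'c'}, {'b', 'c', 'd'})
# = 2 / 4 = 0.5, which sits exactly at the 0.5 threshold used below.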


if __name__ == '__main__':

    print('parsing the inputs ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--inputs', nargs='*', default=None,
                        help='Pairwise list of the input files and keys, '
                        'e.g. --inputs cc.json cc_id news.json news_id')
    parser.add_argument('--load-fingerprints', nargs='*', default=None,
                        help='Load fingerprints from a list of pickle files,'
                        ' e.g. cc.pkl news.pkl')
    parser.add_argument('--save-fingerprints', type=str, default=None,
                        help='Save the fingerprints of the inputs.')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name that consists of all ids'
                        ' with matching similarities')
    args = parser.parse_args()
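
    # Example invocation (hypothetical file names, mirroring the help text):
    #   python find_duplicates.py --inputs cc.json cc_id news.json news_id \
    #       --save-fingerprints fingerprints.pkl --output duplicates.json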

    print('finding possible duplicate content ...')

    # set seed and get an array of seeds of 100 integers
    np.random.seed(1234)
    seeds = np.random.randint(0, 1e6, size=100)
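
    # Note: a fixed seed keeps the 100 minhash permutations identical across
    # runs, so fingerprints saved via --save-fingerprints remain comparable
    # when reloaded later via --load-fingerprints.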

    # initialize minhash and lsh cache
    hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
    lshcache = cache.Cache(bands=10, hasher=hasher)
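    # 100 hash values split into 10 bands gives 10 hashes per band; two
    # documents become bucket candidates when they agree on every hash in
    # at least one band.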

    url_doc = {}

    # load fingerprints from pickle file if needed
    if args.load_fingerprints is not None:
        for count_fp, fp_file_name in enumerate(args.load_fingerprints):
            print("Loading fingerprints from pickle file {}".format(
                fp_file_name), flush=True)
            with open(fp_file_name, "rb") as fp:
                if count_fp == 0:
                    # take the first pickle's cache and url dictionary as-is
                    lshcache = pickle.load(fp)
                    url_doc = pickle.load(fp)
                else:
                    # merge the remaining pickles into lshcache and url_doc
                    local_lshcache = pickle.load(fp)
                    local_url_doc = pickle.load(fp)
                    for url in local_lshcache.fingerprints.keys():
                        url_doc[url] = local_url_doc[url]
                        lshcache.add_fingerprint(
                            local_lshcache.fingerprints[url], url)

    counter = 0
    start_time = time.time()

    print("Computing fingerprints", flush=True)

    # compute fingerprints of the inputs, if any
    input_pairs = 0 if args.inputs is None else len(args.inputs) // 2
    for input_pair in range(input_pairs):
        # input file and the key to use as id
        input_file = args.inputs[2 * input_pair]
        key = args.inputs[2 * input_pair + 1]
        print(' processing {} with key {}'.format(input_file, key),
              flush=True)
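
        # Each input line is expected to be a JSON object carrying the
        # document under 'text' and its id under the given key, e.g.
        # (hypothetical): {"cc_id": "doc-0", "text": "some document ..."}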
        # traverse all the texts and add fingerprints
        with open(input_file, 'r') as f_input:
            for line in f_input:
                try:
                    myjson = json.loads(line)
                    url = myjson[key]
                    text = myjson['text']
                    counter += 1
                    url_doc[url] = text
                    lshcache.add_fingerprint(hasher.fingerprint(text), url)
                except Exception as e:
                    print('Error:', e)
                if counter % 10000 == 0:
                    print(' [read]> processed {} documents in {:.2f} '
                          'seconds ...'.format(counter, time.time() -
                                               start_time), flush=True)

    # Save the fingerprints if needed
    if args.save_fingerprints is not None:
        print("Saving fingerprints to pickle file {}".format(
            args.save_fingerprints), flush=True)
        with open(args.save_fingerprints, 'wb') as f_save:
            pickle.dump(lshcache, f_save)
            pickle.dump(url_doc, f_save)
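            # Note: the dump order here (lshcache first, then url_doc)
            # matches the load order expected by --load-fingerprints above.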

    counter = 0
    start_time = time.time()
    deduped = 0
    # compute jaccard index of the input texts and write to file if needed
    if args.output is not None:
        f_out = open(args.output, 'wb')
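        # Each output line maps a bucket's first url to the urls judged
        # near-duplicates of it, with their Jaccard scores, e.g.
        # (illustrative): {"doc-0": [{"doc-7": 0.83}, {"doc-9": 0.61}]}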
        for b in lshcache.bins:
            for bucket_id in b:
                if len(b[bucket_id]) > 1:
                    items = list(b[bucket_id])
                    main_url = items[0]
                    main_shingles = shingles(url_doc[main_url])
                    remove_urls = []
                    for i in range(1, len(items)):
                        counter += 1
                        other_url = items[i]
                        other_shingles = shingles(url_doc[other_url])
                        try:
                            jaccard_sim = jaccard(main_shingles, other_shingles)
                        except Exception as e:
                            print('Error:', e)
                            # fall back to 0.0 so jaccard_sim is always defined
                            jaccard_sim = 0.0
                        if jaccard_sim > 0.5:
                            remove_urls.append({other_url: jaccard_sim})
                            deduped += 1
                        if counter % 10000 == 0:
                            print(' [write]> processed {} documents in {:.2f} '
                                  'seconds and deduped {} documents ...'.
                                  format(counter, time.time() - start_time,
                                         deduped), flush=True)
                    if len(remove_urls) > 0:
                        myjson = json.dumps({main_url: remove_urls},
                                            ensure_ascii=False)
                        f_out.write(myjson.encode('utf-8'))
                        f_out.write('\n'.encode('utf-8'))
        f_out.close()

    print('done :-)')