Commit 882683dc authored by Mostofa Patwary

Added more features to train data deduplication

parent 47e7eb80
import argparse
import json
import os
import time
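
# Read a loose-json file (one json document per line), attach a unique,
# monotonically increasing id under the 'adlr_id' key of every document,
# and write the result back out one json document per line.
# Illustrative invocation (script name, paths and prefix are placeholders,
# not from the commit):
#   python add_id.py --input-file in.json --output-file out.json \
#       --id-prefix openwebtext --log-interval 10000
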
if __name__ == '__main__':
    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-file', type=str, default=None, help='Input'
                        ' json file where id needs to be added')
    parser.add_argument('--output-file', type=str, default=None,
                        help='Output file name with id')
    parser.add_argument('--id-prefix', type=str, default=None,
                        help='Id prefix')
    parser.add_argument('--log-interval', type=int, default=100,
                        help='Log interval')
    args = parser.parse_args()

    print('Adding ids to dataset ...')

    f_input = open(args.input_file, 'r', encoding='utf-8')
    f_output = open(args.output_file, 'wb')

    unique_ids = 1
    start_time = time.time()
    for row in f_input:
        each_row = json.loads(row)
        adlr_id_string = args.id_prefix + '-{:010d}'.format(int(unique_ids))
        each_row['adlr_id'] = adlr_id_string
        myjson = json.dumps(each_row, ensure_ascii=False)

        f_output.write(myjson.encode('utf-8'))
        f_output.write('\n'.encode('utf-8'))

        if unique_ids % args.log_interval == 0:
            print(' processed {:9d} documents in {:.2f} seconds ...'.format(
                unique_ids, time.time() - start_time), flush=True)

        unique_ids += 1

    # Close the file.
    f_input.close()
    f_output.close()

    print('done :-)', flush=True)
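
For illustration, a minimal sketch of the id format produced by the loop above (the 'openwebtext' prefix and the document number are made-up values, not from the commit):

# With --id-prefix openwebtext, the first document would get the id below.
assert 'openwebtext' + '-{:010d}'.format(1) == 'openwebtext-0000000001'
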
@@ -14,9 +14,11 @@
 # limitations under the License.

 import argparse
+from functools import partial
 import itertools
 import json
 from lsh import cache, minhash
+import multiprocessing
 import numpy as np
 import time
 import pickle
@@ -31,11 +33,31 @@ def shingles(text, char_ngram=5):
 # This function is adapted from:
 # https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
-def jaccard(set_a, set_b):
+def jaccard(set_a, set_b, args):
+    if len(set_a) < 1 or len(set_b) < 1:
+        return 0.0
+
     intersection = set_a & set_b
     union = set_a | set_b
-    return len(intersection) / len(union)
+
+    if args.jaccard == 'min':
+        return len(intersection) / min(len(set_a), len(set_b))
+    elif args.jaccard == 'max':
+        return len(intersection) / max(len(set_a), len(set_b))
+    else:
+        return len(intersection) / len(union)
+
+
+def compute_fingerprint(line, key):
+    try:
+        myjson = json.loads(line)
+        url = myjson[key]
+        text = myjson['text']
+        fingerprint = hasher.fingerprint(text)
+    except Exception as e:
+        print('Error:', e)
+        return None, None, None, False
+
+    return url, text, fingerprint, True
+

 if __name__ == '__main__':
@@ -55,17 +77,29 @@ if __name__ == '__main__':
     parser.add_argument('--output', type=str, default=None,
                         help='Output file name that consists of all ids'
                         ' with matching similarities')
+    parser.add_argument('--jaccard', type=str, default='union',
+                        choices=['union', 'min', 'max'], help='Jaccard'
+                        ' similarity computation')
+    parser.add_argument('--heuristic-iter', type=int, default=1,
+                        help='Number of iterations to run the heuristics'
+                        ': use -1 for exact')
+    parser.add_argument('--num-bands', type=int, default=10,
+                        help='Number of bands to use in cache')
+    parser.add_argument('--num-seeds', type=int, default=100,
+                        help='Number of seeds to use for minhash. Note that'
+                        ' this value should be divisible by num-bands')
+
     args = parser.parse_args()

     print('finding possible duplicate content ...')

     # set seed and get an array of seeds of 100 integers
     np.random.seed(args.seed)
-    seeds = np.random.randint(0, 1e6, size=100)
+    seeds = np.random.randint(0, 1e6, size=args.num_seeds)

     # initialize minhash and lsh cache
     hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
-    lshcache = cache.Cache(bands=10, hasher=hasher)
+    lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)

     url_doc = {}
@@ -100,22 +134,28 @@ if __name__ == '__main__':
     for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
         print(' document processing {} with key {}'.format(input_file, key),
               flush=True)

+        # compute fingerprints in parallel
+        num_workers = 20
+        pool = multiprocessing.Pool(num_workers)
+        fin = open(input_file, 'r', encoding='utf-8')
+        compute_fingerprint_partial = partial(compute_fingerprint, key=key)
+        compute_fingerprint_iter = pool.imap(compute_fingerprint_partial,
+                                             fin, 500)
+
         # traverse all the texts and add fingerprints
-        with open(input_file, 'r') as f_input:
-            for line in f_input:
-                try:
-                    myjson = json.loads(line)
-                    url = myjson[key]
-                    text = myjson['text']
-                    counter += 1
-                    url_doc[url] = text
-                    lshcache.add_fingerprint(hasher.fingerprint(text), url)
-                except Exception as e:
-                    print('Error:', e)
-                if counter % 10000 == 0:
-                    print(' [read]> processed {} documents in {:.2f} '
-                          'seconds ...'.format(counter, time.time() - \
-                                               start_time), flush=True)
+        for url, text, fingerprint, flag in compute_fingerprint_iter:
+            counter += 1
+            if flag:
+                url_doc[url] = text
+                lshcache.add_fingerprint(fingerprint, url)
+            if counter % 10000 == 0:
+                print(' [read]> processed {} documents in {:.2f} '
+                      'seconds ...'.format(counter, time.time() - \
+                                           start_time), flush=True)
+
+        fin.close()
+        pool.close()
+        pool.join()

     # Save the fingerprints if needed
     if args.save_fingerprints is not None:
@@ -133,32 +173,52 @@ if __name__ == '__main__':
     f_out = open(args.output, 'wb')
     for b in lshcache.bins:
         for bucket_id in b:
-            if len(b[bucket_id]) > 1:
-                items = list(b[bucket_id])
-                main_url = items[0]
-                main_dhingles = shingles(url_doc[main_url])
+            if len(b[bucket_id]) <= 1:
+                continue
+
+            bucket_urls = b[bucket_id].copy()
+            iteration = 0
+            while len(bucket_urls) > 1:
+                if args.heuristic_iter != -1 and \
+                    iteration == args.heuristic_iter:
+                    break
+
+                items = list(bucket_urls)
                 remove_urls = []
-                for i in range(1, len(items)):
+                main_url = items[np.random.randint(0, len(items))]
+                main_dhingles = shingles(url_doc[main_url])
+                for i in range(0, len(items)):
                     counter += 1
                     other_url = items[i]
+                    if other_url == main_url:
+                        continue
                     other_shingles = shingles(url_doc[other_url])
                     try:
-                        jaccard_sim = jaccard(main_dhingles, other_shingles)
+                        jaccard_sim = jaccard(main_dhingles, other_shingles,
+                                              args)
                     except Exception as e:
                         print('Error:', e)
+                        jaccard_sim = 0.0
                     if jaccard_sim > 0.5:
                         remove_urls.append({other_url: jaccard_sim})
                         deduped += 1
+                        bucket_urls.remove(other_url)
                     if counter % 10000 == 0:
                         print(' [write]> processed {} documents in {:.2f} '
                               'seconds and deduped {} documents ...'.
                               format(counter, time.time() - start_time,
                                      deduped), flush=True)
+
+                bucket_urls.remove(main_url)
                 if len(remove_urls) > 0:
                     myjson = json.dumps({main_url: remove_urls},
                                         ensure_ascii=False)
                     f_out.write(myjson.encode('utf-8'))
                     f_out.write('\n'.encode('utf-8'))
+                iteration += 1

     f_out.close()
     print('done :-)')
@@ -18,10 +18,6 @@ import time
 import sys


-def is_similar(jaccard_similarity):
-    return (jaccard_similarity >= 0.7)
-
-
 if __name__ == '__main__':
@@ -29,6 +25,10 @@ if __name__ == '__main__':
     input = sys.argv[1]
     output = sys.argv[2]
+    if len(sys.argv) > 3:
+        jaccard_similarity_threshold = float(sys.argv[3])
+    else:
+        jaccard_similarity_threshold = 0.7

     url_to_index = {}
     index_to_urls = []
@@ -43,7 +43,7 @@
                 urls.append(main_url)
                 for value in myjson[main_url]:
                     for other_url, js in value.items():
-                        if is_similar(js):
+                        if js >= jaccard_similarity_threshold:
                             urls.append(other_url)
             current_index = -1
             other_indices = set()