Commit 882683dc authored by Mostofa Patwary

Added more features in train data deduplication

parent 47e7eb80
import argparse
import json
import os
import time


if __name__ == '__main__':

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-file', type=str, default=None,
                        help='Input json file where id needs to be added')
    parser.add_argument('--output-file', type=str, default=None,
                        help='Output file name with id')
    parser.add_argument('--id-prefix', type=str, default=None,
                        help='Id prefix')
    parser.add_argument('--log-interval', type=int, default=100,
                        help='Log interval')
    args = parser.parse_args()

    print('Adding ids to dataset ...')

    f_input = open(args.input_file, 'r', encoding='utf-8')
    f_output = open(args.output_file, 'wb')

    unique_ids = 1
    start_time = time.time()
    for row in f_input:
        each_row = json.loads(row)
        # Build a zero-padded sequential id such as <prefix>-0000000001.
        adlr_id_string = args.id_prefix + '-{:010d}'.format(int(unique_ids))
        each_row['adlr_id'] = adlr_id_string
        myjson = json.dumps(each_row, ensure_ascii=False)

        f_output.write(myjson.encode('utf-8'))
        f_output.write('\n'.encode('utf-8'))

        if unique_ids % args.log_interval == 0:
            print('    processed {:9d} documents in {:.2f} seconds ...'.format(
                unique_ids, time.time() - start_time), flush=True)

        unique_ids += 1

    # Close the files.
    f_input.close()
    f_output.close()

    print('done :-)', flush=True)
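
For reference, a single output row of this id-adding pass looks like the following sketch; the record contents and the prefix are made up, only the adlr_id format string comes from the code above.

import json

row = {'text': 'some document text'}                   # hypothetical input row
row['adlr_id'] = 'openwebtext' + '-{:010d}'.format(1)  # e.g. --id-prefix openwebtext
print(json.dumps(row, ensure_ascii=False))
# {"text": "some document text", "adlr_id": "openwebtext-0000000001"}
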
@@ -14,9 +14,11 @@
 # limitations under the License.

 import argparse
+from functools import partial
 import itertools
 import json
 from lsh import cache, minhash
+import multiprocessing
 import numpy as np
 import time
 import pickle
@@ -31,11 +33,31 @@ def shingles(text, char_ngram=5):

 # This function is adapted from:
 # https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
-def jaccard(set_a, set_b):
+def jaccard(set_a, set_b, args):
     if len(set_a) < 1 or len(set_b) < 1:
         return 0.0

     intersection = set_a & set_b
     union = set_a | set_b

-    return len(intersection) / len(union)
+    if args.jaccard == 'min':
+        return len(intersection) / min(len(set_a), len(set_b))
+    elif args.jaccard == 'max':
+        return len(intersection) / max(len(set_a), len(set_b))
+    else:
+        return len(intersection) / len(union)
+
+
+def compute_fingerprint(line, key):
+    try:
+        myjson = json.loads(line)
+        url = myjson[key]
+        text = myjson['text']
+        fingerprint = hasher.fingerprint(text)
+    except Exception as e:
+        print('Error:', e)
+        return None, None, None, False
+
+    return url, text, fingerprint, True
+

 if __name__ == '__main__':
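
A tiny worked example of the three --jaccard options added above; the shingle sets are invented, only the formulas come from the function.

set_a = {'abcde', 'bcdef', 'cdefg', 'defgh'}   # 4 shingles
set_b = {'abcde', 'bcdef', 'xyzuv'}            # 3 shingles, 2 shared with set_a

intersection = set_a & set_b                   # size 2
union = set_a | set_b                          # size 5

print(len(intersection) / min(len(set_a), len(set_b)))   # 'min'   -> 0.67
print(len(intersection) / max(len(set_a), len(set_b)))   # 'max'   -> 0.5
print(len(intersection) / len(union))                    # 'union' -> 0.4

Since min(|A|, |B|) <= max(|A|, |B|) <= |A ∪ B|, the 'min' option always gives the largest similarity and therefore flags the most pairs at the fixed 0.5 threshold used later; 'union' is the standard Jaccard index.
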
@@ -55,17 +77,29 @@ if __name__ == '__main__':

     parser.add_argument('--output', type=str, default=None,
                         help='Output file name that consists of all ids'
                         ' with matching similarities')
+    parser.add_argument('--jaccard', type=str, default='union',
+                        choices=['union', 'min', 'max'],
+                        help='Jaccard similarity computation')
+    parser.add_argument('--heuristic-iter', type=int, default=1,
+                        help='Number of iterations to run the heuristics'
+                        ': use -1 for exact')
+    parser.add_argument('--num-bands', type=int, default=10,
+                        help='Number of bands to use in cache')
+    parser.add_argument('--num-seeds', type=int, default=100,
+                        help='Number of seeds to use for minhash. Note that'
+                        ' this value should be divisible by num-bands')
     args = parser.parse_args()

     print('finding possible duplicate content ...')

     # set seed and get an array of seeds of 100 integers
     np.random.seed(args.seed)
-    seeds = np.random.randint(0, 1e6, size=100)
+    seeds = np.random.randint(0, 1e6, size=args.num_seeds)

     # initialize minhash and lsh cache
     hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
-    lshcache = cache.Cache(bands=10, hasher=hasher)
+    lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)

     url_doc = {}
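
The --num-seeds / --num-bands pairing above follows the usual LSH banding scheme: the minhash signature has one value per seed and is split into equal-sized bands, which is why the help text asks for num-seeds to be divisible by num-bands. A small sanity check with the defaults from the arguments above:

num_seeds = 100                      # length of each minhash signature
num_bands = 10                       # bands used by the LSH cache
assert num_seeds % num_bands == 0
rows_per_band = num_seeds // num_bands
print(rows_per_band)                 # 10 hashes per band with the defaults
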
@@ -100,22 +134,28 @@ if __name__ == '__main__':

     for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
         print(' document processing {} with key {}'.format(input_file, key),
               flush=True)

+        # compute fingerprints in parallel
+        num_workers = 20
+        pool = multiprocessing.Pool(num_workers)
+        fin = open(input_file, 'r', encoding='utf-8')
+        compute_fingerprint_partial = partial(compute_fingerprint, key=key)
+        compute_fingerprint_iter = pool.imap(compute_fingerprint_partial,
+                                             fin, 500)

         # traverse all the texts and add fingerprints
-        with open(input_file, 'r') as f_input:
-            for line in f_input:
-                try:
-                    myjson = json.loads(line)
-                    url = myjson[key]
-                    text = myjson['text']
-                    counter += 1
-                    url_doc[url] = text
-                    lshcache.add_fingerprint(hasher.fingerprint(text), url)
-                except Exception as e:
-                    print('Error:', e)
-                if counter % 10000 == 0:
-                    print(' [read]> processed {} documents in {:.2f} '
-                          'seconds ...'.format(counter, time.time() - \
-                          start_time), flush=True)
+        for url, text, fingerprint, flag in compute_fingerprint_iter:
+            counter += 1
+            if flag:
+                url_doc[url] = text
+                lshcache.add_fingerprint(fingerprint, url)
+            if counter % 10000 == 0:
+                print(' [read]> processed {} documents in {:.2f} '
+                      'seconds ...'.format(counter, time.time() - \
+                      start_time), flush=True)
+
+        fin.close()
+        pool.close()
+        pool.join()

     # Save the fingerprints if needed
     if args.save_fingerprints is not None:
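
The fingerprinting loop above uses functools.partial to bind the json key so that pool.imap can stream one-argument calls over the input file with a chunk size of 500. A self-contained sketch of the same pattern with a toy worker in place of compute_fingerprint (names and data are illustrative):

import multiprocessing
from functools import partial

def tag(line, key):
    # stand-in for compute_fingerprint: return the bound key and the line
    return key, line.strip()

if __name__ == '__main__':
    lines = ['first\n', 'second\n', 'third\n']
    with multiprocessing.Pool(2) as pool:
        tag_partial = partial(tag, key='adlr_id')   # bind the json key once
        for key, text in pool.imap(tag_partial, lines, 2):
            print(key, text)
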
@@ -133,32 +173,52 @@ if __name__ == '__main__':

     f_out = open(args.output, 'wb')
     for b in lshcache.bins:
         for bucket_id in b:
-            if len(b[bucket_id]) > 1:
-                items = list(b[bucket_id])
-                main_url = items[0]
-                main_dhingles = shingles(url_doc[main_url])
+            if len(b[bucket_id]) <= 1:
+                continue
+
+            bucket_urls = b[bucket_id].copy()
+            iteration = 0
+            while len(bucket_urls) > 1:
+
+                if args.heuristic_iter != -1 and \
+                    iteration == args.heuristic_iter:
+                    break
+
+                items = list(bucket_urls)
                 remove_urls = []
-                for i in range(1, len(items)):
+                main_url = items[np.random.randint(0, len(items))]
+                main_dhingles = shingles(url_doc[main_url])
+                for i in range(0, len(items)):
                     counter += 1
                     other_url = items[i]
+                    if other_url == main_url:
+                        continue
                     other_shingles = shingles(url_doc[other_url])
                     try:
-                        jaccard_sim = jaccard(main_dhingles, other_shingles)
+                        jaccard_sim = jaccard(main_dhingles, other_shingles,
+                                              args)
                     except Exception as e:
                         print('Error:', e)
                         jaccard_sim = 0.0
                     if jaccard_sim > 0.5:
                         remove_urls.append({other_url: jaccard_sim})
                         deduped += 1
+                        bucket_urls.remove(other_url)
                     if counter % 10000 == 0:
                         print(' [write]> processed {} documents in {:.2f} '
                               'seconds and deduped {} documents ...'.
                               format(counter, time.time() - start_time,
                                      deduped), flush=True)
+
+                bucket_urls.remove(main_url)
                 if len(remove_urls) > 0:
                     myjson = json.dumps({main_url: remove_urls},
                                         ensure_ascii=False)
                     f_out.write(myjson.encode('utf-8'))
                     f_out.write('\n'.encode('utf-8'))
+                iteration += 1

     f_out.close()
     print('done :-)')
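
For reference, each line written to --output by the loop above is a json object mapping the randomly chosen main_url of an LSH bucket to the near-duplicates found for it. An invented example line, using the adlr_id format from the first script purely for illustration:

import json

example = {'openwebtext-0000000007': [{'openwebtext-0000000042': 0.81},
                                      {'openwebtext-0000001313': 0.66}]}
print(json.dumps(example, ensure_ascii=False))
# {"openwebtext-0000000007": [{"openwebtext-0000000042": 0.81}, {"openwebtext-0000001313": 0.66}]}
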
@@ -18,10 +18,6 @@ import time
 import sys

-
-def is_similar(jaccard_similarity):
-    return (jaccard_similarity >= 0.7)
-

 if __name__ == '__main__':
@@ -29,6 +25,10 @@ if __name__ == '__main__':

     input = sys.argv[1]
     output = sys.argv[2]
+    if len(sys.argv) > 3:
+        jaccard_similarity_threshold = float(sys.argv[3])
+    else:
+        jaccard_similarity_threshold = 0.7

     url_to_index = {}
     index_to_urls = []
@@ -43,7 +43,7 @@ if __name__ == '__main__':

             urls.append(main_url)
             for value in myjson[main_url]:
                 for other_url, js in value.items():
-                    if is_similar(js):
+                    if js >= jaccard_similarity_threshold:
                         urls.append(other_url)

         current_index = -1
         other_indices = set()
......
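
To show how the pieces fit together, a minimal sketch of the reading pass in this last script, assuming a duplicates file in the format written by the previous script; the file name and threshold are placeholders.

import json

jaccard_similarity_threshold = 0.7
with open('duplicates.json', 'r', encoding='utf-8') as f:
    for line in f:
        myjson = json.loads(line)
        for main_url in myjson:
            urls = [main_url]
            for value in myjson[main_url]:
                for other_url, js in value.items():
                    if js >= jaccard_similarity_threshold:
                        urls.append(other_url)
            # each 'urls' list is one candidate group of near-duplicates
            print(urls)
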