Commit a1f4032b authored by Jared Casper

Merge branch 'main_dedup' into 'main'

Storing and loading fingerprints in deduplication

See merge request ADLR/megatron-lm!236
parents 90623142 ebc4a408
......@@ -26,9 +26,9 @@ python blacklist_urls.py <path to the downloaded deduplicated URLs> <filename for
```
python cleanup_dataset.py <input data file> <output cleaned data filename>
```
2. Using LSH, find possible duplicates and store them in a file for later processing. This step can NOT be sharded and usually takes 12 to 24 hours for the OpenWebText dataset.
2. Using LSH, find possible duplicates and store them in a file for later processing. This step can NOT be sharded and usually takes 12 to 24 hours for the OpenWebText dataset. The code supports saving and loading fingerprints for recurrent deduplication; an example invocation is sketched after the command below.
```
python find_duplicates.py <input cleaned data file> <output possible duplicate urls filename>
python find_duplicates.py --inputs <pairwise list of input cleaned data files and keys, e.g. cc.json cc_id news.json news_id> --output <output possible duplicate urls filename>
```
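For example, a recurring deduplication workflow might fingerprint one corpus and save the result with `--save-fingerprints`, then reuse those fingerprints via `--load-fingerprints` when deduplicating a newly added corpus against it. The file and key names below are illustrative (taken from the argument help text), not fixed names:
```
python find_duplicates.py --inputs cc.json cc_id --save-fingerprints cc.pkl --output duplicate_urls_cc.json
python find_duplicates.py --load-fingerprints cc.pkl --inputs news.json news_id --output duplicate_urls_cc_news.json
```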
3. Based on the similarity measure defined inside the function `is_similar` (default: 0.9), group urls that are similar. Basically, for each group, we should keep only one url and remove the rest (a minimal sketch of this similarity test appears after this step).
```
......
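Below is a minimal sketch of how that similarity test works, mirroring the `shingles` and `jaccard` helpers in `find_duplicates.py`; the `is_similar` wrapper shown here is illustrative and not the exact repository implementation:
```
# Illustrative sketch: shingle each document into character 5-grams,
# compute the Jaccard index of the two shingle sets, and treat documents
# whose index exceeds the threshold (default 0.9) as duplicates.

def shingles(text, char_ngram=5):
    """Set of overlapping character n-grams of `text`."""
    return set(text[i:i + char_ngram] for i in range(len(text) - char_ngram + 1))

def jaccard(set_a, set_b):
    """Jaccard index: size of the intersection over size of the union."""
    union = set_a | set_b
    return len(set_a & set_b) / len(union) if union else 0.0

def is_similar(text_a, text_b, threshold=0.9):
    """Decide whether two documents belong in the same duplicate group."""
    return jaccard(shingles(text_a), shingles(text_b)) > threshold
```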
......@@ -13,14 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import itertools
import json
from lsh import cache, minhash
import numpy as np
import time
import pickle
import sys
# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def shingles(text, char_ngram=5):
......@@ -38,36 +39,98 @@ def jaccard(set_a, set_b):
if __name__ == '__main__':
print('parsing the arguments ...')
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=1234,
help='Random seed used for python, numpy')
parser.add_argument('--inputs', nargs = '*', default=None, help = \
'Pairwise list of the input files and keys, '
'e.g. --inputs cc.json cc_id news.json news_id')
parser.add_argument('--load-fingerprints', nargs = '*', default=None,
help='Load fingerprints from a list of pickle files,'
' e.g. cc.pkl news.pkl')
parser.add_argument('--save-fingerprints', type=str, default=None,
help='Save the fingerprints of the inputs.')
parser.add_argument('--output', type=str, default=None,
help='Output file name that consists of all ids'
' with matching similarities')
args = parser.parse_args()
print('finding possible duplicate content ...')
input = sys.argv[1]
output = sys.argv[2]
# set seed and get an array of seeds of 100 integers
np.random.seed(args.seed)
seeds = np.random.randint(0, 1e6, size=100)
hasher = minhash.MinHasher(seeds=100, char_ngram=5, hashbytes=4)
# initialize minhash and lsh cache
hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
lshcache = cache.Cache(bands=10, hasher=hasher)
counter = 0
url_doc = {}
# load fingerprints from pickle file if needed
if args.load_fingerprints is not None:
for count_fp, fp_file_name in enumerate(args.load_fingerprints):
print("Loading fingerprints from pickle file {}".format(
fp_file_name), flush=True)
fp = open(fp_file_name, "rb")
if count_fp == 0:
# assign directory for the first pkl
lshcache = pickle.load(fp)
url_doc = pickle.load(fp)
else:
# append these to lshcache and url_doc
local_lshcache = pickle.load(fp)
local_url_doc = pickle.load(fp)
for url in local_lshcache.fingerprints.keys():
url_doc[url] = local_url_doc[url]
lshcache.add_fingerprint(local_lshcache.fingerprints[url], url)
fp.close()
counter = 0
start_time = time.time()
with open(input, 'r') as f:
for line in f:
try:
myjson = json.loads(line)
url = myjson['url']
text = myjson['text']
counter += 1
url_doc[url] = text
lshcache.add_fingerprint(hasher.fingerprint(text), url)
except Exception as e:
print('Error:', e)
if counter % 10000 == 0:
print(' [read]> processed {} documents in {:.2f} seconds ...'.
format(counter, time.time() - start_time), flush=True)
print("Computing fingerprints", flush=True)
# compute finger prints of the inputs if any
# input file and the key to use as id
if args.inputs is not None:
assert len(args.inputs) % 2 == 0
for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
print(' document processing {} with key {}'.format(input_file, key),
flush=True)
# traverse all the texts and add fingerprints
with open(input_file, 'r') as f_input:
for line in f_input:
try:
myjson = json.loads(line)
url = myjson[key]
text = myjson['text']
counter += 1
url_doc[url] = text
lshcache.add_fingerprint(hasher.fingerprint(text), url)
except Exception as e:
print('Error:', e)
if counter % 10000 == 0:
print(' [read]> processed {} documents in {:.2f} '
'seconds ...'.format(counter, time.time() - \
start_time), flush=True)
# Save the fingerprints if needed
if args.save_fingerprints is not None:
print("Saving fingerprints to pickle file {}".format(
args.save_fingerprints), flush=True)
with open(args.save_fingerprints, 'wb') as f_save:
pickle.dump(lshcache, f_save)
pickle.dump(url_doc, f_save)
counter = 0
start_time = time.time()
deduped = 0
with open(output, 'wb') as f:
# compute jaccard index of the input texts and write to file if needed
if args.output is not None:
f_out = open(args.output, 'wb')
for b in lshcache.bins:
for bucket_id in b:
if len(b[bucket_id]) > 1:
......@@ -94,7 +157,8 @@ if __name__ == '__main__':
if len(remove_urls) > 0:
myjson = json.dumps({main_url: remove_urls},
ensure_ascii=False)
f.write(myjson.encode('utf-8'))
f.write('\n'.encode('utf-8'))
f_out.write(myjson.encode('utf-8'))
f_out.write('\n'.encode('utf-8'))
f_out.close()
print('done :-)')