Commit 60c95ab6 authored by Mostofa Patwary

Merge branch 'main' into main_retriver_merge_ict_eval

parents bcd605f8 a1f4032b

@@ -123,9 +123,6 @@ def parse_args(extra_args_provider=None, defaults={},
         args.virtual_pipeline_model_parallel_size = \
             (args.num_layers // args.pipeline_model_parallel_size) // \
             args.num_layers_per_virtual_pipeline_stage
-        assert args.global_batch_size % args.pipeline_model_parallel_size == 0, \
-            'global batch size is not divisible by pipeline parallel size when ' \
-            'using interleaved schedule'
     else:
         args.virtual_pipeline_model_parallel_size = None

@@ -338,6 +338,9 @@ def train_step(forward_step_func, data_iterator,
     if mpu.get_pipeline_model_parallel_world_size() > 1:
         if args.virtual_pipeline_model_parallel_size is not None:
             forward_backward_func = forward_backward_pipelining_with_interleaving
+            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
+                'number of microbatches is not divisible by pipeline-parallel ' \
+                'size when using interleaved schedule'
         else:
             forward_backward_func = forward_backward_pipelining_without_interleaving
     else:
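
For context: the check removed from argument parsing above reappears here, at the point where the interleaved schedule is actually selected, and it now constrains the number of microbatches rather than the global batch size. A minimal sketch of that constraint, using a hypothetical standalone helper rather than Megatron's own get_num_microbatches():

```
def check_interleaved_schedule(num_microbatches, pipeline_model_parallel_size,
                               virtual_pipeline_model_parallel_size):
    # Hypothetical helper for illustration only; it mirrors the assert added
    # in train_step() above, which applies only when the interleaved
    # (virtual-pipeline) schedule is in use.
    if virtual_pipeline_model_parallel_size is not None:
        assert num_microbatches % pipeline_model_parallel_size == 0, (
            'number of microbatches is not divisible by pipeline-parallel '
            'size when using interleaved schedule')


# 8 microbatches across 4 pipeline-parallel ranks divides evenly, so this passes;
# 6 microbatches across 4 ranks would trip the assertion.
check_interleaved_schedule(8, 4, virtual_pipeline_model_parallel_size=2)
```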

@@ -26,9 +26,9 @@ python blacklist_urls.py <path to the downloaded deduplicated URLs> <filename for
 ```
 python cleanup_dataset.py <input data file> <output cleaned data filename>
 ```
-2. Using LSH, find possible duplicates and store them in a file for later processing. This step can NOT be sharded and usually takes 12 to 24 hours for the OpenWebText dataset.
+2. Using LSH, find possible duplicates and store them in a file for later processing. This step can NOT be sharded and usually takes 12 to 24 hours for the OpenWebText dataset. The code supports saving and loading fingerprints for recurrent deduplications.
 ```
-python find_duplicates.py <input cleaned data file> <output possible duplicate urls filename>
+python find_duplicates.py --inputs <pairwise list of input cleaned data files and keys, e.g. cc.json cc_id news.json news_id> --output <output possible duplicate urls filename>
 ```
 3. Based on the similarity measure defined inside the function `is_similar` (default: 0.9), group urls that are similar. Basically, for each group we should keep only one url and remove the rest.
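
Note that step 2's find_duplicates.py invocation changes from positional arguments to a pairwise --inputs list that alternates a JSON-lines file with the field to use as the document id. A minimal sketch of how such a pairwise list is consumed, mirroring the zip pattern used in the script below; the file names and keys here are illustrative only:

```
# Illustrative (file, key) pairs in the same alternating layout as --inputs.
inputs = ['cc.json', 'cc_id', 'news.json', 'news_id']

# The script requires an even-length list and walks it two items at a time,
# as in zip(args.inputs[::2], args.inputs[1::2]).
assert len(inputs) % 2 == 0, 'expected alternating <file> <json key> pairs'
for input_file, key in zip(inputs[::2], inputs[1::2]):
    print('would fingerprint {} using field {!r} as the document id'.format(
        input_file, key))
```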

@@ -13,14 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import argparse
 import itertools
 import json
 from lsh import cache, minhash
+import numpy as np
 import time
+import pickle
 import sys
 
 
 # This function is adapted from:
 # https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
 def shingles(text, char_ngram=5):
@@ -38,36 +39,98 @@ def jaccard(set_a, set_b):
 
 if __name__ == '__main__':
+
+    print('parsing the arguments ...')
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--seed', type=int, default=1234,
+                        help='Random seed used for python, numpy')
+    parser.add_argument('--inputs', nargs = '*', default=None, help = \
+                        'Pairwise list of the input files and keys, '
+                        'e.g. --inputs cc.json cc_id news.json news_id')
+    parser.add_argument('--load-fingerprints', nargs = '*', default=None,
+                        help='Load fingerprints from a list of pickle files,'
+                        ' e.g. cc.pkl news.pkl')
+    parser.add_argument('--save-fingerprints', type=str, default=None,
+                        help='Save the fingerprints of the inputs.')
+    parser.add_argument('--output', type=str, default=None,
+                        help='Output file name that consists of all ids'
+                        ' with matching similarities')
+    args = parser.parse_args()
+
     print('finding possible duplicate content ...')
-    input = sys.argv[1]
-    output = sys.argv[2]
 
-    hasher = minhash.MinHasher(seeds=100, char_ngram=5, hashbytes=4)
+    # set seed and get an array of seeds of 100 integers
+    np.random.seed(args.seed)
+    seeds = np.random.randint(0, 1e6, size=100)
+
+    # initialize minhash and lsh cache
+    hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
     lshcache = cache.Cache(bands=10, hasher=hasher)
 
-    counter = 0
     url_doc = {}
 
+    # load fingerprints from pickle file if needed
+    if args.load_fingerprints is not None:
+        for count_fp, fp_file_name in enumerate(args.load_fingerprints):
+            print("Loading fingerprints from pickle file {}".format(
+                fp_file_name), flush=True)
+            fp = open(fp_file_name, "rb")
+            if count_fp == 0:
+                # assign directory for the first pkl
+                lshcache = pickle.load(fp)
+                url_doc = pickle.load(fp)
+            else:
+                # append these to lshcache and url_doc
+                local_lshcache = pickle.load(fp)
+                local_url_doc = pickle.load(fp)
+                for url in local_lshcache.fingerprints.keys():
+                    url_doc[url] = local_url_doc[url]
+                    lshcache.add_fingerprint(local_lshcache.fingerprints[url], url)
+            fp.close()
+
+    counter = 0
     start_time = time.time()
-    with open(input, 'r') as f:
-        for line in f:
-            try:
-                myjson = json.loads(line)
-                url = myjson['url']
-                text = myjson['text']
-                counter += 1
-                url_doc[url] = text
-                lshcache.add_fingerprint(hasher.fingerprint(text), url)
-            except Exception as e:
-                print('Error:', e)
-            if counter % 10000 == 0:
-                print(' [read]> processed {} documents in {:.2f} seconds ...'.
-                      format(counter, time.time() - start_time), flush=True)
+
+    print("Computing fingerprints", flush=True)
+
+    # compute finger prints of the inputs if any
+    # input file and the key to use as id
+    if args.inputs is not None:
+        assert len(args.inputs) % 2 == 0
+        for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
+            print(' document processing {} with key {}'.format(input_file, key),
+                  flush=True)
+            # traverse all the texts and add fingerprints
+            with open(input_file, 'r') as f_input:
+                for line in f_input:
+                    try:
+                        myjson = json.loads(line)
+                        url = myjson[key]
+                        text = myjson['text']
+                        counter += 1
+                        url_doc[url] = text
+                        lshcache.add_fingerprint(hasher.fingerprint(text), url)
+                    except Exception as e:
+                        print('Error:', e)
+                    if counter % 10000 == 0:
+                        print(' [read]> processed {} documents in {:.2f} '
+                              'seconds ...'.format(counter, time.time() - \
+                                                   start_time), flush=True)
+
+    # Save the fingerprints if needed
+    if args.save_fingerprints is not None:
+        print("Saving fingerprints to pickle file {}".format(
+            args.save_fingerprints), flush=True)
+        with open(args.save_fingerprints, 'wb') as f_save:
+            pickle.dump(lshcache, f_save)
+            pickle.dump(url_doc, f_save)
 
     counter = 0
     start_time = time.time()
     deduped = 0
-    with open(output, 'wb') as f:
+
+    # compute jaccard index of the input texts and write to file if needed
+    if args.output is not None:
+        f_out = open(args.output, 'wb')
         for b in lshcache.bins:
             for bucket_id in b:
                 if len(b[bucket_id]) > 1:
@@ -94,7 +157,8 @@ if __name__ == '__main__':
                     if len(remove_urls) > 0:
                         myjson = json.dumps({main_url: remove_urls},
                                             ensure_ascii=False)
-                        f.write(myjson.encode('utf-8'))
-                        f.write('\n'.encode('utf-8'))
+                        f_out.write(myjson.encode('utf-8'))
+                        f_out.write('\n'.encode('utf-8'))
+        f_out.close()
 
     print('done :-)')
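
As a usage note, the --output file written above is line-delimited JSON: each record maps one kept url to the urls that matched it with high similarity, which step 3 of the README then groups further. A minimal sketch of reading that file back, assuming a hypothetical output file name:

```
import json

# 'duplicate_urls.json' is a hypothetical placeholder for the --output file.
urls_to_remove = set()
with open('duplicate_urls.json', 'r', encoding='utf-8') as f:
    for line in f:
        record = json.loads(line)
        for main_url, similar_urls in record.items():
            urls_to_remove.update(similar_urls)

print('{} urls flagged as possible duplicates'.format(len(urls_to_remove)))
```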