Commit 41a64613 authored by Rewon Child

Merge main

parents 8676baca 83d26f03
This diff is collapsed.
@@ -18,6 +18,7 @@

import sys
import torch
+from torch.nn.parallel import DistributedDataParallel as torchDDP
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

@@ -26,11 +27,25 @@ from megatron import get_args

from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
+from megatron.checkpointing import save_checkpoint
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
+
+
+def unwrap_model(model, module_instances=(torchDDP)):
+    return_list = True
+    if not isinstance(model, list):
+        model = [model]
+        return_list = False
+    unwrapped_model = []
+    for model_module in model:
+        while isinstance(model_module, module_instances):
+            model_module = model_module.module
+        unwrapped_model.append(model_module)
+    if not return_list:
+        return unwrapped_model[0]
+    return unwrapped_model


def calc_params_l2_norm(model):
    """Calculate l2 norm of parameters """
    # Remove duplicate params.

@@ -106,6 +121,8 @@ def print_params_min_max_norm(optimizer, iteration):

def check_adlr_autoresume_termination(iteration, model,
                                      optimizer, lr_scheduler):
    """Check for autoresume signal and exit if it is received."""
+    from megatron.checkpointing import save_checkpoint
    args = get_args()
    autoresume = get_adlr_autoresume()
    # Add barrier to ensure consistency.
...
@@ -38,7 +38,7 @@ def model_provider():

    args = get_args()
    num_tokentypes = 2 if args.bert_binary_head else 0
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
+    def model_provider_pipelined():
        # Determine model based on position of stage in pipeline.
        if mpu.is_pipeline_first_stage():
            model = BertModelFirstStage(

@@ -51,6 +51,17 @@ def model_provider():

        else:
            model = BertModelIntermediateStage(
                num_tokentypes=num_tokentypes)
+        return model
+
+    args = get_args()
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        if args.virtual_pipeline_model_parallel_size is not None:
+            model = []
+            for i in range(args.virtual_pipeline_model_parallel_size):
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                model.append(model_provider_pipelined())
+        else:
+            model = model_provider_pipelined()
    else:
        model = BertModel(
            num_tokentypes=num_tokentypes,

@@ -92,8 +103,8 @@ def forward_step(data_iterator, model, input_tensor):

    # Get the batch.
    timers('batch-generator').start()
-    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
-        = get_batch(data_iterator)
+    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = get_batch(
+        data_iterator)
    timers('batch-generator').stop()

    if not args.bert_binary_head:
...
@@ -35,8 +35,8 @@ def model_provider():

    """Build the model."""

    print_rank_0('building GPT model ...')
-    args = get_args()
-    if mpu.get_pipeline_model_parallel_world_size() > 1:
+    def model_provider_pipelined():
        # Determine model based on position of stage in pipeline.
        if mpu.is_pipeline_first_stage():
            model = GPTModelFirstStage(num_tokentypes=0)

@@ -46,6 +46,17 @@ def model_provider():

        else:
            model = GPTModelIntermediateStage(
                num_tokentypes=0)
+        return model
+
+    args = get_args()
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        if args.virtual_pipeline_model_parallel_size is not None:
+            model = []
+            for i in range(args.virtual_pipeline_model_parallel_size):
+                mpu.set_virtual_pipeline_model_parallel_rank(i)
+                model.append(model_provider_pipelined())
+        else:
+            model = model_provider_pipelined()
    else:
        model = GPTModel(num_tokentypes=0, parallel_output=True)
...
@@ -14,6 +14,7 @@

# limitations under the License.

"""Pretrain BERT for Inverse Cloze Task"""
+import math

import torch
import torch.distributed as dist

@@ -23,17 +24,21 @@ from megatron import get_args

from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
+from megatron.data.biencoder_dataset_utils import get_ict_batch
from megatron.data.dataset_utils import build_train_valid_test_datasets
+from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
-from megatron.model.realm_model import general_ict_model_provider
-from megatron.data.realm_dataset_utils import get_ict_batch


def pretrain_ict_model_provider():
    args = get_args()
-    return general_ict_model_provider(False, False)
+    model = biencoder_model_provider(
+        only_context_model=False,
+        only_query_model=False,
+        biencoder_shared_query_context_model=\
+            args.biencoder_shared_query_context_model)
+    return model


def get_group_world_size_rank():

@@ -72,7 +77,6 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):

        output = output_list[rank].contiguous()
        return output


def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()

@@ -80,37 +84,57 @@ def forward_step(data_iterator, model, input_tensor):

    # Get the batch.
    timers('batch-generator').start()
-    query_tokens, query_pad_mask, \
-    block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
+    query_tokens, query_mask, \
+    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
    timers('batch-generator').stop()

+    # Query and Context Types
+    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
+    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)

    # Forward model.
-    query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
+    query_logits, context_logits = model(query_tokens, query_mask,
+                                         query_types, context_tokens,
+                                         context_mask, context_types)

    micro_batch_size = query_logits.shape[0]
-    global_batch_size = dist.get_world_size() * micro_batch_size # recall we assert that tensor_model_parallel_size == 1
+    # recall we assert that tensor_model_parallel_size == 1
+    assert mpu.get_tensor_model_parallel_world_size() == 1, \
+        "Model parallel size > 1 not supported for ICT"
+    global_batch_size = dist.get_world_size() * micro_batch_size
    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
-    all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
+    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)

-    # scores are inner products between query and block embeddings
-    retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
-    softmaxed = F.softmax(retrieval_scores, dim=1)
-    sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)
+    # scores are inner products between query and context embeddings
+    retrieval_scores = torch.matmul(all_query_logits,
+                                    torch.transpose(all_context_logits, 0, 1))
+    # scaling the retriever scores
+    if args.retriever_score_scaling:
+        retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)
+
+    softmax_scores = F.log_softmax(retrieval_scores, dim=1)
+    sorted_vals, sorted_indices = torch.topk(softmax_scores,
+                                             k=softmax_scores.shape[1], sorted=True)

    def topk_accuracy(k):
-        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size])
+        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
+            for i in range(global_batch_size)]) / global_batch_size])

-    topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
+    topk_accs = [topk_accuracy(int(k)) for k in args.retriever_report_topk_accuracies]

-    retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
-    retrieval_loss = retrieval_loss.float()
-    averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])
-    # create stats_dict with retrieval loss and all specified top-k accuracies
-    topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, averaged_losses[1:])}
-    stats_dict = dict(retrieval_loss=averaged_losses[0], **topk_acc_dict)
-    return retrieval_loss, stats_dict
+    labels = torch.arange(global_batch_size).long().cuda()
+    loss = F.nll_loss(softmax_scores, labels, reduction='mean')
+    reduced_losses = average_losses_across_data_parallel_group([loss, *topk_accs])
+
+    # Scale the retrieval loss
+    loss = loss * mpu.get_data_parallel_world_size()
+
+    # create stats_dict with retrieval loss and all specified top-k accuracies
+    topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
+        zip(args.retriever_report_topk_accuracies, reduced_losses[1:])}
+    stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
+
+    return loss, stats_dict


def train_valid_test_datasets_provider(train_val_test_num_samples):

@@ -129,6 +153,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):

        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
+        binary_head=False,
        dataset_type='ict')

    print_rank_0("> finished creating BERT ICT datasets ...")

@@ -136,5 +161,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):

if __name__ == "__main__":

-    pretrain(train_valid_test_datasets_provider, pretrain_ict_model_provider, forward_step,
+    pretrain(train_valid_test_datasets_provider,
+             pretrain_ict_model_provider,
+             forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
+import os
import sys
-sys.path.append('../')
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
+                                             os.path.pardir)))

+from megatron import print_rank_0
from megatron.indexer import IndexBuilder
from megatron.initialize import initialize_megatron

@@ -23,7 +26,7 @@ def main():

        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()
+    print_rank_0("Build and save indices: done!")


if __name__ == "__main__":
    main()
...
@@ -26,9 +26,9 @@ python blacklist_urls.py <path to the downloaded deduplicated URLs> <filename for

```
python cleanup_dataset.py <input data file> <output cleaned data filename>
```
-2. Using LSH, find possible duplicates and store them in a file for later processing. This step can NOT be sharded and usually takes 12 to 24 hours for the OpenWebText dataset.
+2. Using LSH, find possible duplicates and store them in a file for later processing. This step can NOT be sharded and usually takes 12 to 24 hours for the OpenWebText dataset. The code supports saving and loading fingerprints for recurrent deduplications (see the example invocations below).
```
-python find_duplicates.py <input cleaned data file> <output possible duplicate urls filename>
+python find_duplicates.py --inputs <pairwise list of input cleaned data files and keys, e.g. cc.json cc_id news.json news_id> --output <output possible duplicate urls filename>
```
3. Based on the similarity measure defined inside the `is_similar` function (default: 0.9), group urls that are similar. Basically, for each group, we should keep only one url and remove the rest.
```
@@ -44,3 +44,12 @@ python remove_group_duplicates.py <file containing similar documents> <cleaned d

shuf <cleaned deduped data file> -o train_data.json
```
+# Deduplicating ngrams
+
+To deduplicate the downstream tasks from the training dataset, we run the following command.
+```
+python filter_ngrams.py <downstream task dataset> <training dataset to deduplicate> <output training dataset>
+```
+We use 13-grams for the deduplication. When we find a 13-gram match in a training document, we split the document into two pieces and remove the 13-gram along with 200 characters from both sides of the 13-gram. We also remove any split document with less than 200 characters, or any document that got split more than 10 times.
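The following invocations illustrate the new fingerprint options of `find_duplicates.py` and the `filter_ngrams.py` step described above; the file names are only placeholders.
```
# Fingerprint two cleaned shards, save the fingerprints, and record candidate duplicates.
python find_duplicates.py --inputs cc.json cc_id news.json news_id \
    --save-fingerprints fingerprints.pkl --output possible_duplicate_urls.json
# A later run can reuse the saved fingerprints instead of recomputing them.
python find_duplicates.py --load-fingerprints fingerprints.pkl \
    --inputs books.json book_id --output possible_duplicate_urls.json
# Remove 13-grams of a downstream task (here LAMBADA) from the shuffled training data.
python filter_ngrams.py lambada_test.json train_data.json train_data_filtered.json
```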
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Deduplicate downstream tasks from training dataset. 13-grams have been used.
All split documents with less than 200 characters got filtered. Any document
with more than 10 splits got filtered as well.
"""
from functools import partial
import json
import multiprocessing
import nltk
import re
import string
import sys
import time


def get_words(text):
    # get all the lowercase words from text
    words, positions = [], []
    for match in re.finditer(r'\w+', text.lower()):
        words.append(match.group(0))
        positions.append(match.start())
    return words, positions


def free_ngram(line, ngrams, ngram_size, filter_text_len,
               splits_count, split_window_each_size):
    # remove all the ngrams
    try:
        myjson = json.loads(line)
        text_buf = [myjson['text']]
    except Exception as e:
        print("Error: {}".format(e), flush=True)
        text_buf = []

    text_buf_ngram_free = []
    while len(text_buf) > 0:
        # get the first one from the buffer
        text = text_buf.pop(0)
        words, positions = get_words(text)

        not_ngram_free = True
        punctuations = ".!?"
        # find n-grams
        for i in range(len(words) - ngram_size + 1):
            seq = " ".join(words[i:i+ngram_size])
            if seq in ngrams:
                # splits the text
                # first part of the text
                pos = positions[i] - split_window_each_size
                text_first = ""
                while pos > 0 and not text[pos] in punctuations:
                    pos -= 1
                if pos > 0:
                    text_first = text[0:pos+1]
                pos = positions[i] + split_window_each_size
                # last part of the text
                text_second = ""
                while pos < len(text) and not text[pos] in punctuations:
                    pos += 1
                if pos + 1 < len(text):
                    text_second = text[pos+1:len(text)]

                # first part of ngrams free
                if len(text_first) > filter_text_len:
                    text_buf_ngram_free.append(text_first)
                # add second part for further processing
                if len(text_second) > filter_text_len:
                    text_buf.append(text_second)

                not_ngram_free = False
                break

        # text are ngram free
        if not_ngram_free:
            text_buf_ngram_free.append(text)

    return text_buf_ngram_free


if __name__ == '__main__':

    print('finding possible duplicate content ...')
    main_file = sys.argv[1]    # lambada file
    dedup_file = sys.argv[2]   # Book corpus
    output_file = sys.argv[3]  # Filtered book corpus

    ngrams = {}
    id_prefix = "lambada"

    # we use 13-grams, any text less than 200 characters got removed
    # any text splitted more than 10 got removed as well
    ngram_size = 13
    filter_text_len = 200
    splits_count = 10
    split_window_each_size = 200

    print('Reading file {} and computing ngrams'.format(main_file))
    with open(main_file, 'r') as f:
        for line in f:
            try:
                myjson = json.loads(line)
                words, positions = get_words(myjson['text'])
                for i in range(len(words) - ngram_size + 1):
                    seq = " ".join(words[i:i+ngram_size])
                    if seq not in ngrams:
                        ngrams[seq] = positions[i]
            except Exception as e:
                print('Error:', e)
    print("ngrams size {}".format(len(ngrams)))

    print('Reading file {} and deduping n-grams'.format(dedup_file))
    counter = 0
    start_time = time.time()
    out_f = open(output_file, 'wb')
    splitted, ignored, split_mt_thld = 0, 0, 0

    # Setup multi-processing.
    num_workers = 40
    fin = open(dedup_file, 'r', encoding='utf-8')
    pool = multiprocessing.Pool(num_workers)
    free_ngram_x = partial(free_ngram, ngrams=ngrams, ngram_size=ngram_size,
                           filter_text_len=filter_text_len,
                           splits_count=splits_count,
                           split_window_each_size=split_window_each_size)
    free_ngrams = pool.imap(free_ngram_x, fin, 25)

    for text_buf_ngram_free in free_ngrams:
        counter += 1
        try:
            if len(text_buf_ngram_free) > 1:
                splitted += (len(text_buf_ngram_free) - 1)
            if len(text_buf_ngram_free) == 0:
                ignored += 1
            # more than 10 splits ignored
            if len(text_buf_ngram_free) > splits_count:
                text_buf_ngram_free = []
                split_mt_thld += 1

            for i in range(len(text_buf_ngram_free)):
                split_id_string = id_prefix + '-{:010d}'.format(int(counter)) \
                    + '-{:010d}'.format(int(i))
                outjson = json.dumps({"text": text_buf_ngram_free[i],
                                      id_prefix + "_split_id": split_id_string},
                                     ensure_ascii=False)
                out_f.write(outjson.encode('utf-8'))
                out_f.write('\n'.encode('utf-8'))

            if counter % 1000 == 0:
                print(' [search]> processed {} documents in {:.2f} seconds ...'.
                      format(counter, time.time() - start_time), flush=True)
        except Exception as e:
            print('Error:', e)

    print("Deduped file written to: {}".format(output_file), flush=True)
    print("Total docs {} splitted {} ignored {} docs with many splits {}".\
          format(counter, splitted, ignored, split_mt_thld), flush=True)
    print('done :-)')
@@ -13,14 +13,15 @@

# See the License for the specific language governing permissions and
# limitations under the License.

+import argparse
import itertools
import json
from lsh import cache, minhash
+import numpy as np
import time
+import pickle
import sys


# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def shingles(text, char_ngram=5):

@@ -38,36 +39,98 @@ def jaccard(set_a, set_b):

if __name__ == '__main__':

+    print('parsing the arguments ...')
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--seed', type=int, default=1234,
+                        help='Random seed used for python, numpy')
+    parser.add_argument('--inputs', nargs='*', default=None,
+                        help='Pairwise list of the input files and keys, '
+                        'e.g. --inputs cc.json cc_id news.json news_id')
+    parser.add_argument('--load-fingerprints', nargs='*', default=None,
+                        help='Load fingerprints from a list of pickle files,'
+                        ' e.g. cc.pkl news.pkl')
+    parser.add_argument('--save-fingerprints', type=str, default=None,
+                        help='Save the fingerprints of the inputs.')
+    parser.add_argument('--output', type=str, default=None,
+                        help='Output file name that consists of all ids'
+                        ' with matching similarities')
+    args = parser.parse_args()

    print('finding possible duplicate content ...')

-    input = sys.argv[1]
-    output = sys.argv[2]
+    # set seed and get an array of seeds of 100 integers
+    np.random.seed(args.seed)
+    seeds = np.random.randint(0, 1e6, size=100)

-    hasher = minhash.MinHasher(seeds=100, char_ngram=5, hashbytes=4)
+    # initialize minhash and lsh cache
+    hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
    lshcache = cache.Cache(bands=10, hasher=hasher)

-    counter = 0
    url_doc = {}

+    # load fingerprints from pickle file if needed
+    if args.load_fingerprints is not None:
+        for count_fp, fp_file_name in enumerate(args.load_fingerprints):
+            print("Loading fingerprints from pickle file {}".format(
+                fp_file_name), flush=True)
+            fp = open(fp_file_name, "rb")
+            if count_fp == 0:
+                # assign directory for the first pkl
+                lshcache = pickle.load(fp)
+                url_doc = pickle.load(fp)
+            else:
+                # append these to lshcache and url_doc
+                local_lshcache = pickle.load(fp)
+                local_url_doc = pickle.load(fp)
+                for url in local_lshcache.fingerprints.keys():
+                    url_doc[url] = local_url_doc[url]
+                    lshcache.add_fingerprint(local_lshcache.fingerprints[url], url)
+            fp.close()
+
+    counter = 0
    start_time = time.time()

-    with open(input, 'r') as f:
-        for line in f:
-            try:
-                myjson = json.loads(line)
-                url = myjson['url']
-                text = myjson['text']
-                counter += 1
-                url_doc[url] = text
-                lshcache.add_fingerprint(hasher.fingerprint(text), url)
-            except Exception as e:
-                print('Error:', e)
-            if counter % 10000 == 0:
-                print(' [read]> processed {} documents in {:.2f} seconds ...'.
-                      format(counter, time.time() - start_time), flush=True)
+    print("Computing fingerprints", flush=True)
+
+    # compute finger prints of the inputs if any
+    # input file and the key to use as id
+    if args.inputs is not None:
+        assert len(args.inputs) % 2 == 0
+        for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
+            print(' document processing {} with key {}'.format(input_file, key),
+                  flush=True)
+            # traverse all the texts and add fingerprints
+            with open(input_file, 'r') as f_input:
+                for line in f_input:
+                    try:
+                        myjson = json.loads(line)
+                        url = myjson[key]
+                        text = myjson['text']
+                        counter += 1
+                        url_doc[url] = text
+                        lshcache.add_fingerprint(hasher.fingerprint(text), url)
+                    except Exception as e:
+                        print('Error:', e)
+                    if counter % 10000 == 0:
+                        print(' [read]> processed {} documents in {:.2f} '
+                              'seconds ...'.format(counter, time.time() - \
+                              start_time), flush=True)
+
+    # Save the fingerprints if needed
+    if args.save_fingerprints is not None:
+        print("Saving fingerprints to pickle file {}".format(
+            args.save_fingerprints), flush=True)
+        with open(args.save_fingerprints, 'wb') as f_save:
+            pickle.dump(lshcache, f_save)
+            pickle.dump(url_doc, f_save)

    counter = 0
    start_time = time.time()
    deduped = 0

-    with open(output, 'wb') as f:
+    # compute jaccard index of the input texts and write to file if needed
+    if args.output is not None:
+        f_out = open(args.output, 'wb')
        for b in lshcache.bins:
            for bucket_id in b:
                if len(b[bucket_id]) > 1:

@@ -94,7 +157,8 @@ if __name__ == '__main__':

                    if len(remove_urls) > 0:
                        myjson = json.dumps({main_url: remove_urls},
                                            ensure_ascii=False)
-                        f.write(myjson.encode('utf-8'))
-                        f.write('\n'.encode('utf-8'))
+                        f_out.write(myjson.encode('utf-8'))
+                        f_out.write('\n'.encode('utf-8'))
+        f_out.close()

    print('done :-)')
@@ -19,7 +19,7 @@ import sys

def is_similar(jaccard_similarity):
-    return (js >= 0.9)
+    return (jaccard_similarity >= 0.7)


if __name__ == '__main__':
...