Commit 0211193c authored by zhuwenwen's avatar zhuwenwen
Browse files

initial llama

parents
Pipeline #509 failed with stages
in 0 seconds
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# The following code has been taken from
# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
# licensed as of now. More details on the license can be found
# at https://github.com/facebookresearch/DPR/blob/master/LICENSE
"""
Most of the tokenizers code here is copied from DrQA codebase to avoid adding extra dependency
"""
import copy
import logging
import regex
import spacy
logger = logging.getLogger(__name__)
class Tokens(object):
"""A class to represent a list of tokenized text."""
TEXT = 0
TEXT_WS = 1
SPAN = 2
POS = 3
LEMMA = 4
NER = 5
def __init__(self, data, annotators, opts=None):
self.data = data
self.annotators = annotators
self.opts = opts or {}
def __len__(self):
"""The number of tokens."""
return len(self.data)
def slice(self, i=None, j=None):
"""Return a view of the list of tokens from [i, j)."""
new_tokens = copy.copy(self)
new_tokens.data = self.data[i: j]
return new_tokens
def untokenize(self):
"""Returns the original text (with whitespace reinserted)."""
return ''.join([t[self.TEXT_WS] for t in self.data]).strip()
def words(self, uncased=False):
"""Returns a list of the text of each token
Args:
uncased: lower cases text
"""
if uncased:
return [t[self.TEXT].lower() for t in self.data]
else:
return [t[self.TEXT] for t in self.data]
def offsets(self):
"""Returns a list of [start, end) character offsets of each token."""
return [t[self.SPAN] for t in self.data]
def pos(self):
"""Returns a list of part-of-speech tags of each token.
Returns None if this annotation was not included.
"""
if 'pos' not in self.annotators:
return None
return [t[self.POS] for t in self.data]
def lemmas(self):
"""Returns a list of the lemmatized text of each token.
Returns None if this annotation was not included.
"""
if 'lemma' not in self.annotators:
return None
return [t[self.LEMMA] for t in self.data]
def entities(self):
"""Returns a list of named-entity-recognition tags of each token.
Returns None if this annotation was not included.
"""
if 'ner' not in self.annotators:
return None
return [t[self.NER] for t in self.data]
def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
"""Returns a list of all ngrams from length 1 to n.
Args:
n: upper limit of ngram length
uncased: lower cases text
filter_fn: user function that takes in an ngram list and returns
True or False to keep or not keep the ngram
as_string: return the ngram as a string vs list
"""
def _skip(gram):
if not filter_fn:
return False
return filter_fn(gram)
words = self.words(uncased)
ngrams = [(s, e + 1)
for s in range(len(words))
for e in range(s, min(s + n, len(words)))
if not _skip(words[s:e + 1])]
# Concatenate into strings
if as_strings:
ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]
return ngrams
def entity_groups(self):
"""Group consecutive entity tokens with the same NER tag."""
entities = self.entities()
if not entities:
return None
non_ent = self.opts.get('non_ent', 'O')
groups = []
idx = 0
while idx < len(entities):
ner_tag = entities[idx]
# Check for entity tag
if ner_tag != non_ent:
# Chomp the sequence
start = idx
while (idx < len(entities) and entities[idx] == ner_tag):
idx += 1
groups.append((self.slice(start, idx).untokenize(), ner_tag))
else:
idx += 1
return groups
class Tokenizer(object):
"""Base tokenizer class.
Tokenizers implement tokenize, which should return a Tokens class.
"""
def tokenize(self, text):
raise NotImplementedError
def shutdown(self):
pass
def __del__(self):
self.shutdown()
class SimpleTokenizer(Tokenizer):
ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
NON_WS = r'[^\p{Z}\p{C}]'
def __init__(self, **kwargs):
"""
Args:
annotators: None or empty set (only tokenizes).
"""
self._regexp = regex.compile(
'(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
)
if len(kwargs.get('annotators', {})) > 0:
logger.warning('%s only tokenizes! Skipping annotators: %s' %
(type(self).__name__, kwargs.get('annotators')))
self.annotators = set()
def tokenize(self, text):
data = []
matches = [m for m in self._regexp.finditer(text)]
for i in range(len(matches)):
# Get text
token = matches[i].group()
# Get whitespace
span = matches[i].span()
start_ws = span[0]
if i + 1 < len(matches):
end_ws = matches[i + 1].span()[0]
else:
end_ws = span[1]
# Format data
data.append((
token,
text[start_ws: end_ws],
span,
))
return Tokens(data, self.annotators)
class SpacyTokenizer(Tokenizer):
def __init__(self, **kwargs):
"""
Args:
annotators: set that can include pos, lemma, and ner.
model: spaCy model to use (either path, or keyword like 'en').
"""
model = kwargs.get('model', 'en')
self.annotators = copy.deepcopy(kwargs.get('annotators', set()))
nlp_kwargs = {'parser': False}
if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
nlp_kwargs['tagger'] = False
if 'ner' not in self.annotators:
nlp_kwargs['entity'] = False
self.nlp = spacy.load(model, **nlp_kwargs)
def tokenize(self, text):
# We don't treat new lines as tokens.
clean_text = text.replace('\n', ' ')
tokens = self.nlp.tokenizer(clean_text)
if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
self.nlp.tagger(tokens)
if 'ner' in self.annotators:
self.nlp.entity(tokens)
data = []
for i in range(len(tokens)):
# Get whitespace
start_ws = tokens[i].idx
if i + 1 < len(tokens):
end_ws = tokens[i + 1].idx
else:
end_ws = tokens[i].idx + len(tokens[i].text)
data.append((
tokens[i].text,
text[start_ws: end_ws],
(tokens[i].idx, tokens[i].idx + len(tokens[i].text)),
tokens[i].tag_,
tokens[i].lemma_,
tokens[i].ent_type_,
))
# Set special option for non-entity tag: '' vs 'O' in spaCy
return Tokens(data, self.annotators, opts={'non_ent': ''})
import glob
import json
import os
import time
from torch.utils.data import Dataset
from megatron import print_rank_0
from tasks.data_utils import build_sample
from tasks.data_utils import build_tokens_types_paddings_from_ids
from tasks.data_utils import clean_text
NUM_CHOICES = 4
MAX_QA_LENGTH = 128
class RaceDataset(Dataset):
def __init__(self, dataset_name, datapaths, tokenizer, max_seq_length,
max_qa_length=MAX_QA_LENGTH):
self.dataset_name = dataset_name
print_rank_0(' > building RACE dataset for {}:'.format(
self.dataset_name))
string = ' > paths:'
for path in datapaths:
string += ' ' + path
print_rank_0(string)
self.samples = []
for datapath in datapaths:
self.samples.extend(process_single_datapath(datapath, tokenizer,
max_qa_length,
max_seq_length))
print_rank_0(' >> total number of samples: {}'.format(
len(self.samples)))
# This indicates that each "sample" has multiple samples that
# will collapse into batch dimension
self.sample_multiplier = NUM_CHOICES
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
return self.samples[idx]
def process_single_datapath(datapath, tokenizer, max_qa_length, max_seq_length):
"""Read in RACE files, combine, clean-up, tokenize, and convert to
samples."""
print_rank_0(' > working on {}'.format(datapath))
start_time = time.time()
# Get list of files.
filenames = glob.glob(os.path.join(datapath, '*.txt'))
samples = []
num_docs = 0
num_questions = 0
num_samples = 0
# Load all the files
for filename in filenames:
with open(filename, 'r') as f:
for line in f:
data = json.loads(line)
num_docs += 1
context = data["article"]
questions = data["questions"]
choices = data["options"]
answers = data["answers"]
# Check the length.
assert len(questions) == len(answers)
assert len(questions) == len(choices)
# Context: clean up and convert to ids.
context = clean_text(context)
context_ids = tokenizer.tokenize(context)
# Loop over questions.
for qi, question in enumerate(questions):
num_questions += 1
# Label.
label = ord(answers[qi]) - ord("A")
assert label >= 0
assert label < NUM_CHOICES
assert len(choices[qi]) == NUM_CHOICES
# For each question, build num-choices samples.
ids_list = []
types_list = []
paddings_list = []
for ci in range(NUM_CHOICES):
choice = choices[qi][ci]
# Merge with choice.
if "_" in question:
qa = question.replace("_", choice)
else:
qa = " ".join([question, choice])
# Clean QA.
qa = clean_text(qa)
# Tokenize.
qa_ids = tokenizer.tokenize(qa)
# Trim if needed.
if len(qa_ids) > max_qa_length:
qa_ids = qa_ids[0:max_qa_length]
# Build the sample.
ids, types, paddings \
= build_tokens_types_paddings_from_ids(
qa_ids, context_ids, max_seq_length,
tokenizer.cls, tokenizer.sep, tokenizer.pad)
ids_list.append(ids)
types_list.append(types)
paddings_list.append(paddings)
# Convert to numpy and add to samples
samples.append(build_sample(ids_list, types_list,
paddings_list, label,
num_samples))
num_samples += 1
elapsed_time = time.time() - start_time
print_rank_0(' > processed {} document, {} questions, and {} samples'
' in {:.2f} seconds'.format(num_docs, num_questions,
num_samples, elapsed_time))
return samples
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Race."""
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.model.multiple_choice import MultipleChoice
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
from tasks.race.data import RaceDataset
def train_valid_datasets_provider():
"""Provide train and validation datasets."""
args = get_args()
tokenizer = get_tokenizer()
train_dataset = RaceDataset('training', args.train_data,
tokenizer, args.seq_length)
valid_dataset = RaceDataset('validation', args.valid_data,
tokenizer, args.seq_length)
return train_dataset, valid_dataset
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building multichoice model for RACE ...')
model = MultipleChoice(num_tokentypes=2,
pre_process=pre_process,
post_process=post_process)
return model
def metrics_func_provider():
"""Privde metrics callback function."""
args = get_args()
tokenizer = get_tokenizer()
def single_dataset_provider(datapath):
name = datapath.split('RACE')[-1].strip('/').replace('/', '-')
return RaceDataset(name, [datapath], tokenizer, args.seq_length)
return accuracy_func_provider(single_dataset_provider)
def main():
finetune(train_valid_datasets_provider, model_provider,
end_of_epoch_callback_provider=metrics_func_provider)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision-classification finetuning/evaluation."""
from megatron import get_args
from megatron import print_rank_0
from megatron.model.vit_model import VitModel
from megatron.data.vit_dataset import build_train_valid_datasets
from tasks.vision.eval_utils import accuracy_func_provider
from tasks.vision.finetune_utils import finetune
def classification():
def train_valid_datasets_provider():
"""Build train and validation dataset."""
args = get_args()
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
crop_size=args.img_dim,
)
return train_ds, valid_ds
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
print_rank_0("building classification model for ImageNet ...")
return VitModel(num_classes=args.num_classes, finetune=True,
pre_process=pre_process, post_process=post_process)
"""Finetune/evaluate."""
finetune(
train_valid_datasets_provider,
model_provider,
end_of_epoch_callback_provider=accuracy_func_provider,
)
def main():
classification()
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation utilities."""
import os
from functools import partial
import torch
from megatron import get_args
from megatron import print_rank_0, print_rank_last
from megatron import mpu
from megatron.schedules import get_forward_backward_func
from tasks.vision.finetune_utils import build_data_loader
from tasks.vision.finetune_utils import process_batch
from torchvision import datasets, transforms
def accuracy_func_provider():
"""Provide function that calculates accuracies."""
args = get_args()
data_path = args.data_path
crop_size = args.img_dim
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
# Build dataloaders.
val_data_path = os.path.join(data_path[0], "val")
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform_val = transforms.Compose(
[
transforms.Resize(crop_size),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
normalize,
]
)
dataset = datasets.ImageFolder(root=val_data_path, transform=transform_val)
dataloader = build_data_loader(
dataset,
args.micro_batch_size,
num_workers=args.num_workers,
drop_last=(mpu.get_data_parallel_world_size() > 1),
)
def metrics_func(model, epoch):
print_rank_0("calculating metrics ...")
correct, total = calculate_correct_answers(model, dataloader, epoch)
percent = float(correct) * 100.0 / float(total)
print_rank_last(
" >> |epoch: {}| overall: correct / total = {} / {} = "
"{:.4f} %".format(epoch, correct, total, percent)
)
return metrics_func
def calculate_correct_answers(model, dataloader, epoch):
"""Calculate correct over total answers"""
args = get_args()
forward_backward_func = get_forward_backward_func()
for m in model:
m.eval()
def loss_func(labels, output_tensor):
logits = output_tensor
loss_dict = {}
# Compute the correct answers.
predicted = torch.argmax(logits, dim=-1)
corrects = (predicted == labels).float()
# Add to the counters.
loss_dict['total'] = labels.size(0)
loss_dict['correct'] = corrects.sum().item()
return 0, loss_dict
#defined inside to capture output_predictions
def correct_answers_forward_step(batch, model):
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
images, labels = process_batch(batch_)
# Forward model.
args = get_args()
output_tensor = model(images)
return output_tensor, partial(loss_func, labels)
with torch.no_grad():
# For all the batches in the dataset.
total = 0
correct = 0
for _, batch in enumerate(dataloader):
loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
optimizer=None, timers=None, forward_only=True)
for loss_dict in loss_dicts:
total += loss_dict['total']
correct += loss_dict['correct']
for m in model:
m.train()
# Reduce.
if mpu.is_pipeline_last_stage():
unreduced = torch.cuda.LongTensor([correct, total])
torch.distributed.all_reduce(unreduced,
group=mpu.get_data_parallel_group())
# Print on screen.
correct_ans = unreduced[0].item()
total_count = unreduced[1].item()
return correct_ans, total_count
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetune utilities."""
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import average_losses_across_data_parallel_group
def process_batch(batch):
"""Process batch and produce inputs for the model."""
images = batch[0].cuda().contiguous()
labels = batch[1].cuda().contiguous()
return images, labels
def cross_entropy_loss_func(labels, output_tensor):
logits = output_tensor
# Cross-entropy loss.
loss = F.cross_entropy(logits.contiguous().float(), labels)
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def _cross_entropy_forward_step(batch, model):
"""Simple forward step with cross-entropy loss."""
timers = get_timers()
# Get the batch.
timers("batch generator").start()
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
images, labels = process_batch(batch_)
timers("batch generator").stop()
# Forward model.
output_tensor = model(images)
return output_tensor, partial(cross_entropy_loss_func, labels)
def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
"""Data loader. Note that batch-size is the local (per GPU) batch-size."""
# Sampler.
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=world_size, rank=rank
)
# Data loader. Note that batch size is the per GPU batch size.
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=micro_batch_size,
sampler=sampler,
shuffle=False,
num_workers=num_workers,
drop_last=drop_last,
pin_memory=True,
)
return data_loader
def _build_infinite_size_dataloader(dataloader):
"""Build a looped dataloader with infinite size."""
iterator = dataloader.__iter__()
while True:
try:
yield iterator.__next__()
except StopIteration:
iterator = dataloader.__iter__()
def _build_train_valid_dataloaders(train_dataset, valid_dataset):
"""Traing and validation dataloaders."""
args = get_args()
print_rank_0('building train and validation dataloaders ...')
# Training dataset.
train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last)
# Set the training iterations.
args.train_iters_per_epoch = len(train_dataloader)
args.train_iters = args.epochs * args.train_iters_per_epoch
# Validation dataset. For this dataset, we do not need to set up
# shuffling so we can just use a simple infinite loop.
valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last)
valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
# Now that we've built the data loaders, set batch_size arguments
# to the actual batch size the model will see for this dataset.
# This is necessary so pipeline transfers know what size they are
# and the LR schedule, which is based on samples seen, gets set
# correctly.
args.orig_micro_batch_size = args.micro_batch_size
args.orig_global_batch_size = args.global_batch_size
return train_dataloader, valid_dataloader
def _train(
model,
optimizer,
lr_scheduler,
forward_step,
train_dataloader,
valid_dataloader,
end_of_epoch_callback,
):
"""Train the model."""
args = get_args()
timers = get_timers()
# Turn on training mode which enables dropout.
for m in model:
m.train()
# Tracking loss.
losses_dict_sum = {}
# Starting epoch and iteration
start_epoch = args.iteration // args.train_iters_per_epoch
start_iteration = args.iteration % args.train_iters_per_epoch
iteration = args.iteration
# Memory reporting flag.
report_memory_flag = True
# For each remaining epoch
timers("interval-time").start()
for epoch in range(start_epoch, args.epochs):
print_rank_0("working on epoch {} ...".format(epoch + 1))
# Set the data loader epoch to shuffle the index iterator.
train_dataloader.sampler.set_epoch(args.seed + epoch)
# For all the batches in the dataset.
for iteration_, batch in enumerate(train_dataloader):
# Ignore the iterations before starting value
if iteration_ < start_iteration:
continue
# Set to zero so the next epoch does not skip any batches.
start_iteration = 0
# Train for one step.
losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
forward_step, batch, model, optimizer, lr_scheduler
)
iteration += 1
# Logging.
params_norm = None
if args.log_params_norm:
params_norm = calc_params_l2_norm(model)
report_memory_flag = training_log(
losses_dict,
losses_dict_sum,
optimizer.param_groups[0]["lr"],
iteration,
optimizer.get_loss_scale().item(),
report_memory_flag,
skipped_iter,
grad_norm,
params_norm,
num_zeros_in_grad
)
# Autoresume
if args.adlr_autoresume and (
iteration % args.adlr_autoresume_interval == 0
):
check_adlr_autoresume_termination(
iteration, model, optimizer, lr_scheduler
)
# Checkpointing
if (
args.save
and args.save_interval
and iteration % args.save_interval == 0
):
save_checkpoint(iteration, model, optimizer, lr_scheduler)
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0:
prefix = "iteration {}".format(iteration)
evaluate_and_print_results(
prefix,
forward_step,
valid_dataloader,
model,
iteration,
False,
)
# Checkpointing at the end of each epoch.
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
# Callback at the end of each epoch.
if end_of_epoch_callback is not None:
end_of_epoch_callback(model, epoch)
def finetune(
train_valid_datasets_provider,
model_provider,
forward_step=_cross_entropy_forward_step,
end_of_epoch_callback_provider=None,
):
"""Main finetune function used across all tasks."""
args = get_args()
timers = get_timers()
# Train and validation data loaders.
timers("train/valid/test dataset/dataloder").start()
if args.epochs > 0:
train_dataset, valid_dataset = train_valid_datasets_provider()
train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
train_dataset, valid_dataset
)
timers("train/valid/test dataset/dataloder").stop()
# Build calback function.
timers("callback function").start()
end_of_epoch_callback = None
if end_of_epoch_callback_provider is not None:
end_of_epoch_callback = end_of_epoch_callback_provider()
timers("callback function").stop()
# Build model, optimizer and learning rate scheduler.
timers("model and optimizer").start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
timers("model and optimizer").stop()
# If pretrained checkpoint is provided and we have not trained for
# any iteration (i.e., iteration is zero), then load the pretrained
# checkpoint.
timers("pretrained checkpoint").start()
if args.iteration == 0 and args.pretrained_checkpoint is not None:
original_load = args.load
args.load = args.pretrained_checkpoint
_ = load_checkpoint(model, None, None, strict=False)
args.load = original_load
# This is critical when only model is loaded. We should make sure
# master parameters are also updated.
optimizer.reload_model_params()
timers("pretrained checkpoint").stop()
# Print setup timing.
print_rank_0("done with setups ...")
timers.log(
[
"train/valid/test dataset/dataloder",
"callback function",
"model and optimizer",
"pretrained checkpoint",
]
)
print_rank_0("training ...")
# Finetune the model.
if args.epochs > 0:
_train(
model,
optimizer,
lr_scheduler,
forward_step,
train_dataloader,
valid_dataloader,
end_of_epoch_callback,
)
# Or just evaluate.
else:
if end_of_epoch_callback is not None:
print_rank_0("evaluation only mode, setting epoch to -1")
end_of_epoch_callback(model, epoch=-1, output_predictions=True)
print_rank_0("done :-)")
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality."""
import os
import sys
sys.path.append(
os.path.abspath(
os.path.join(
os.path.join(os.path.dirname(__file__), os.path.pardir),
os.path.pardir,
)
)
)
from megatron import get_args
from megatron.initialize import initialize_megatron
from classification import main
def get_tasks_args(parser):
"""Provide extra arguments required for tasks."""
group = parser.add_argument_group(title="tasks")
group.add_argument(
"--epochs",
type=int,
default=None,
help="Number of finetunning epochs. Zero results in "
"evaluation only.",
)
group.add_argument(
"--pretrained-checkpoint",
type=str,
default=None,
help="Pretrained checkpoint used for finetunning.",
)
group.add_argument(
"--keep-last",
action="store_true",
help="Keep the last batch (maybe incomplete) in" "the data loader",
)
return parser
if __name__ == "__main__":
initialize_megatron(extra_args_provider=get_tasks_args)
args = get_args()
main()
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Zero-shot datasets."""
import json
import math
import numpy as np
import torch
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from .detokenizer import get_detokenizer
def build_dataset(task):
"""Helper function to select and build dataset."""
if task == 'LAMBADA':
return _build_lambada_dataset()
if task == 'WIKITEXT103':
return _build_wikitext103_dataset()
raise NotImplementedError('dataset for {} task is not '
'implemented.'.format(task))
class _LMDataset(torch.utils.data.Dataset):
def __init__(self, tokens, seq_len, pad_idx, num_original_tokens,
num_tokenized_tokens, overalapping_eval=None):
self.tokens = tokens
self.seq_len = seq_len
self.pad_idx = pad_idx
self.overalapping_eval = overalapping_eval
if self.overalapping_eval is None:
self.overalapping_eval = self.seq_len
self.overalapping_eval = max(1, self.overalapping_eval)
self.num_original_tokens = num_original_tokens
self.num_tokenized_tokens = num_tokenized_tokens
self.total_targets = len(self.tokens) - 1
# remove first sequence tokens
targets = max(self.total_targets - self.overalapping_eval, 0)
self.total_sequences = max(
math.ceil(targets / self.overalapping_eval) + 1, 1)
def __len__(self):
return self.total_sequences
def __getitem__(self, idx):
start_idx = idx * self.overalapping_eval
end_idx = start_idx + self.seq_len
tokens = self.tokens[start_idx:end_idx + 1]
num_tokens = len(tokens)
pad_mask = [1] * num_tokens
if num_tokens < self.seq_len + 1:
num_pad = (self.seq_len + 1 - num_tokens)
pad_mask += [0] * (num_pad)
tokens += [self.pad_idx] * num_pad
pad_mask = np.array(pad_mask[1:])
if self.overalapping_eval != self.seq_len and idx != 0:
pad_mask[:-self.overalapping_eval] *= 0
return {'text': np.array(tokens), 'pad_mask': pad_mask}
class _LambadaDataset(torch.utils.data.Dataset):
def __init__(self, path, pad_idx, tokenizer, seq_len, strict=False):
print_rank_0('> building lambada dataset from {} ...'.format(path))
self.seq_len = seq_len
self.pad_idx = pad_idx
self.tokenizer = tokenizer
self.strict = strict
self.tokens = []
self.labels = []
with open(path, 'r') as f:
for line in f.readlines():
text = json.loads(line)['text']
tokens, labels = self.get_tokens(text)
self.tokens.append(tokens)
self.labels.append(labels)
def get_tokens(self, text):
if not self.strict:
tokens = self.tokenizer.tokenize(text)
return tokens[:-1], [tokens[-1]]
last_token = text.split()[-1]
start_idx = text.rfind(last_token)
beginning_tokens = self.tokenizer.tokenize(text[:start_idx].strip())
last_token = self.tokenizer.tokenize(' ' + last_token)
return beginning_tokens, last_token
def __len__(self):
return len(self.tokens)
def __getitem__(self, idx):
tokens = self.tokens[idx]
num_tokens = len(tokens)
pad_mask = [0] * num_tokens
labels = self.labels[idx]
pad_mask += [1] * len(labels)
tokens = tokens + labels
num_tokens = len(tokens)
if num_tokens < self.seq_len + 1:
num_pad = (self.seq_len + 1 - num_tokens)
pad_mask += [0] * (num_pad)
tokens += [self.pad_idx] * num_pad
pad_mask = np.array(pad_mask[1:])
return {'text': np.array(tokens), 'pad_mask': pad_mask}
def _build_lambada_dataset():
"""Build lambada dataset."""
args = get_args()
tokenizer = get_tokenizer()
assert len(args.valid_data) == 1
val_dataset = _LambadaDataset(args.valid_data[0], tokenizer.eod, tokenizer,
args.seq_length, args.strict_lambada)
print_rank_0(' > found {} samples.'.format(len(val_dataset)))
return val_dataset
def _build_wikitext103_dataset():
""""""
args = get_args()
tokenizer = get_tokenizer()
assert len(args.valid_data) == 1
with open(args.valid_data[0], "rb") as reader:
entire_data = reader.read().decode('utf-8')
num_original_tokens = len(entire_data.strip().split(" "))
entire_data = get_detokenizer(args.valid_data[0])(entire_data)
tokenized_data = tokenizer.tokenize(entire_data)
num_tokenized_tokens = len(tokenized_data)
val_dataset = _LMDataset(tokenized_data, args.seq_length, tokenizer.eod,
num_original_tokens, num_tokenized_tokens,
args.overlapping_eval)
print_rank_0(' > number of original tokens: {}, number of detokenized '
'tokens: {}'.format(num_original_tokens, num_tokenized_tokens))
return val_dataset
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Detokenization."""
import re
def ptb_detokenizer(string):
string = string.replace(" '", "'")
string = string.replace(" \n", "\n")
string = string.replace("\n ", "\n")
string = string.replace(" n't", "n't")
string = string.replace(" N ", "1 ")
string = string.replace("$ 1", "$1")
string = string.replace("# 1", "#1")
return string
def wikitext_detokenizer(string):
# contractions
string = string.replace("s '", "s'")
string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
# number separators
string = string.replace(" @-@ ", "-")
string = string.replace(" @,@ ", ",")
string = string.replace(" @.@ ", ".")
# punctuation
string = string.replace(" : ", ": ")
string = string.replace(" ; ", "; ")
string = string.replace(" . ", ". ")
string = string.replace(" ! ", "! ")
string = string.replace(" ? ", "? ")
string = string.replace(" , ", ", ")
# double brackets
string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
# miscellaneous
string = string.replace("= = = =", "====")
string = string.replace("= = =", "===")
string = string.replace("= =", "==")
string = string.replace(" " + chr(176) + " ", chr(176))
string = string.replace(" \n", "\n")
string = string.replace("\n ", "\n")
string = string.replace(" N ", " 1 ")
string = string.replace(" 's", "'s")
return string
def lambada_detokenizer(string):
return string
_DETOKENIZERS = {
'ptb': ptb_detokenizer,
'wiki': wikitext_detokenizer,
'lambada': lambada_detokenizer,
}
def get_detokenizer(path):
for key in _DETOKENIZERS.keys():
if key in path:
return _DETOKENIZERS[key]
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT zero-shot evaluation."""
import math
import torch
from megatron import get_args
from megatron import print_rank_0, is_last_rank
from megatron import get_tokenizer
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward
from tasks.finetune_utils import build_data_loader
from .datasets import build_dataset
# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
def get_model_provider(eval_metric):
"""Based on evaluation metric set the parallel-output flag and
return the model provider."""
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
if eval_metric == 'loss':
parallel_output = True
elif eval_metric == 'accuracy':
parallel_output = False
else:
raise NotImplementedError('output type for {} evaluation metric '
'is not supported.'.format(eval_metric))
print_rank_0('building GPT model ...')
model = GPTModel(num_tokentypes=0, parallel_output=parallel_output,
pre_process=pre_process, post_process=post_process)
return model
return model_provider
def process_batch(batch):
"""Process batch and produce inputs for the model."""
args = get_args()
tokenizer = get_tokenizer()
loss_mask = batch['pad_mask'].long().cuda().contiguous().byte()
tokens_ = batch['text'].long().cuda().contiguous()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and postition ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss)
return tokens, labels, attention_mask, position_ids, loss_mask
def forward_step(batch, model, eval_metric):
"""Forward step."""
# Get the batch.
tokens, labels, attention_mask, position_ids, loss_mask = process_batch(
batch)
# Tell the model what our actual batch size will be
args = get_args()
args.micro_batch_size = len(labels)
input_tensor = recv_forward()
# Forward pass through the model.
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
unwrapped_model.set_input_tensor(input_tensor)
output = model(tokens, position_ids, attention_mask)
send_forward(output)
if mpu.is_pipeline_last_stage():
# For loss, return the unreduced loss.
if eval_metric == 'loss':
losses = mpu.vocab_parallel_cross_entropy(
output.contiguous().float(), labels.contiguous())
loss = torch.sum(
losses.view(-1) * loss_mask.contiguous().view(-1).float())
return loss
# For accuracy, return the number of correctly predicted samples.
if eval_metric == 'accuracy':
outputs = torch.argmax(output, -1)
correct = (outputs == labels).float()
correct[(1 - loss_mask).bool()] = 1
correct = correct.prod(-1)
return correct.sum()
raise NotImplementedError('forward method for evaluation metric {} '
'is not implemented.'.format(eval_metric))
return None
def evaluate(data_loader, model, eval_metric):
"""Evaluation."""
args = get_args()
# Turn on evaluation mode which disables dropout.
model.eval()
total_output = 0.0
with torch.no_grad():
# For all the batches in the dataset.
for iteration, batch in enumerate(data_loader):
if iteration % args.log_interval == 0:
print_rank_0('> working on iteration: {}'.format(iteration))
# Forward evaluation.
output = forward_step(batch, model, eval_metric)
# Reduce across processes.
if mpu.is_pipeline_last_stage():
torch.distributed.all_reduce(output,
group=mpu.get_data_parallel_group())
total_output += output
return total_output
def evaluate_and_print_results(task, data_loader, model, eval_metric):
"""Evaluate and print results on screen."""
# Evaluate and get results.
output = evaluate(data_loader, model, eval_metric)
string = ' validation results on {} | '.format(task)
if is_last_rank():
if eval_metric == 'loss':
num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens
num_original_tokens = data_loader.dataset.num_original_tokens
val_loss = output / (num_tokenized_tokens - 1)
ppl = math.exp(min(20, val_loss))
token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1)
adjusted_ppl = math.exp(min(20, val_loss * token_ratio))
string += 'avg loss: {:.4E} | '.format(val_loss)
string += 'ppl: {:.4E} | '.format(ppl)
string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl)
string += 'token ratio: {} |'.format(token_ratio)
elif eval_metric == 'accuracy':
num_examples = len(data_loader.dataset)
acc = output / num_examples
string += 'number correct: {:.4E} | '.format(output)
string += 'total examples: {:.4E} | '.format(num_examples)
string += 'avg accuracy: {:.4E}'.format(acc)
else:
raise NotImplementedError('evaluation method for {} metric is not '
'implemented yet.'.format(eval_metric))
length = len(string) + 1
print('-' * length)
print(string)
print('-' * length)
def main():
"""Main program."""
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for text generation.")
exit()
if args.task == 'LAMBADA':
eval_metric = 'accuracy'
elif args.task == 'WIKITEXT103':
eval_metric = 'loss'
else:
raise NotImplementedError('{} task is not implemented.'.format(
args.task))
# Set up model and load checkpoint.
model = get_model(get_model_provider(eval_metric), wrap_with_ddp=False)
if args.load is not None:
_ = load_checkpoint(model, None, None)
assert len(model) == 1, "Above condition should have caught this"
model = model[0]
# Data stuff.
dataset = build_dataset(args.task)
dataloader = build_data_loader(dataset, args.micro_batch_size,
args.num_workers, drop_last=False)
# Run evaluation.
evaluate_and_print_results(args.task, dataloader, model, eval_metric)
print_rank_0('done :-)')
def test_import():
import megatron
import os
import os.path as osp
import pathlib
import subprocess
def recursively_lint_files():
"""Recursively lint all python files in chosen subdirectories of megatron-lm"""
try:
import autopep8
except ModuleNotFoundError:
print("Please first install autopep8 via `pip install autopep8`")
return
# get all python file paths from top level directory
file_dir = str(pathlib.Path(__file__).parent.absolute())
working_dir = osp.join(file_dir, os.pardir)
all_py_paths = set(os.path.join(working_dir, fname)
for fname in os.listdir(working_dir) if ".py" in fname)
# get all python file paths from chosen subdirectories
check_dirs = ['docker', 'megatron', 'openwebtext', 'scripts', 'tasks']
for sub_dir in check_dirs:
for path, _, fnames in os.walk(osp.join(working_dir, sub_dir)):
all_py_paths.update(set(osp.join(path, fname) for fname in fnames if ".py" in fname))
print("Linting the following: ")
for py_path in all_py_paths:
print(py_path)
command = 'autopep8 --max-line-length 100 --aggressive --in-place {}'.format(py_path)
subprocess.check_call(command)
if __name__ == "__main__":
recursively_lint_files()
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Merge model parallel partitions."""
import os
import re
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
import torch
from megatron import mpu
from megatron.checkpointing import load_checkpoint, save_checkpoint
from megatron.checkpointing import ensure_directory_exists
from megatron.checkpointing import get_checkpoint_name
from megatron.checkpointing import get_checkpoint_version
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.global_vars import set_global_variables, get_args
from megatron.global_vars import rebuild_tokenizer
def split_into_partitions(tensor, num_partitions, partition_dim, stride):
per_partition_size = mpu.utils.divide(tensor.size(partition_dim),
num_partitions)
per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
partitions_list = torch.split(tensor,
per_partition_per_stride_size,
dim=partition_dim)
partitions = []
for i in range(num_partitions):
partition = torch.cat(partitions_list[i::num_partitions],
dim=partition_dim)
partitions.append(partition)
return partitions
def merge_partitions(merged, partitions, partition_dim, stride):
# Number and size of each partition.
num_partitions = len(partitions)
per_partition_size = None
for partition in partitions:
if per_partition_size is None:
per_partition_size = partition.size(partition_dim)
else:
assert per_partition_size == partition.size(partition_dim)
def concat_partitions(partitions_):
with torch.no_grad():
if (per_partition_size * num_partitions) == merged.size(
partition_dim):
torch.cat(partitions_, dim=partition_dim, out=merged)
else:
print(' ***WARNING*** sizes do not match. Will cut '
'the merged partitions by {} along dimension {} '
'to reduce the size from {} to {} ...'.format(
(per_partition_size * num_partitions) - \
merged.size(partition_dim), partition_dim,
per_partition_size * num_partitions,
merged.size(partition_dim)))
merged_ = torch.cat(partitions_, dim=partition_dim)
merged_split = torch.split(merged_, merged.size(partition_dim),
dim=partition_dim)
merged_ = merged_split[0]
assert merged_.size(partition_dim) == merged.size(partition_dim)
merged.data.copy_(merged_.data)
# If stride is 1, then do simple concatination.
if stride == 1:
concat_partitions(partitions)
return
# For none unity strides, first split based on stride and then group.
per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
# Chunk and build a list.
chunks = None
for i, partition in enumerate(partitions):
chunk = torch.split(partition,
per_partition_per_stride_size,
dim=partition_dim)
if chunks is None:
chunks = [0]*(num_partitions*len(chunk))
chunks[i::num_partitions] = chunk
# Concatinate.
concat_partitions(chunks)
return
def get_model(model_type):
if model_type == 'BERT':
from pretrain_bert import model_provider
elif model_type == 'GPT':
from pretrain_gpt import model_provider
elif model_type == 'RACE':
from tasks.race.finetune import model_provider
elif model_type == ['MNLI', 'QQP']:
num_classes = 2
if model_type == 'MNLI':
num_classes = 3
from megatron.model.classification import Classification
def model_provider():
return Classification(num_classes=num_classes, num_tokentypes=2)
else:
raise Exception('unrecognized model type: {}'.format(model_type))
model = model_provider()
model = model.half()
return model
def get_parallel_checkpoint_name(path):
tracker_filename = get_checkpoint_tracker_filename(path)
iteration = 0
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
iteration = int(metastring)
assert iteration > 0
checkpoint_name = get_checkpoint_name(path, iteration)
return checkpoint_name, iteration
def test_split_merge():
print('testing split and merge ...')
#[QKV.ROW-COL]
tensor = torch.FloatTensor([[1.11, 1.12, 1.13, 1.14, 1.15],
[1.21, 1.22, 1.23, 1.24, 1.25],
[1.31, 1.32, 1.33, 1.34, 1.35],
[1.41, 1.42, 1.43, 1.44, 1.45],
[2.11, 2.12, 2.13, 2.14, 2.15],
[2.21, 2.22, 2.23, 2.24, 2.25],
[2.31, 2.32, 2.33, 2.34, 2.35],
[2.41, 2.42, 2.43, 2.44, 2.45],
[3.11, 3.12, 3.13, 3.14, 3.15],
[3.21, 3.22, 3.23, 3.24, 3.25],
[3.31, 3.32, 3.33, 3.34, 3.35],
[3.41, 3.42, 3.43, 3.44, 3.45]])
num_partitions = 2
partition_dim = 0
stride = 3
partitions = split_into_partitions(tensor, num_partitions,
partition_dim, stride)
merged = torch.zeros_like(tensor)
merge_partitions(merged, partitions, partition_dim, stride)
max_error = (merged - tensor).abs().max()
print(' > max error (should be zero): {}'.format(max_error))
def get_mp_merge_args(parser):
"""Provide extra arguments required for merging."""
group = parser.add_argument_group(title='mp merge')
group.add_argument('--model-type', type=str, required=True,
choices=['BERT', 'GPT', 'RACE', 'MNLI', 'QQP'],
help='Type of the mdoel.')
group.add_argument('--target-pipeline-model-parallel-size', type=int, default=1,
help='Degree of pipeline model parallelism in output model.')
return parser
def main():
# Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking we are plenty of processes
os.environ["WORLD_SIZE"] = f'{2**31}'
# Args
set_global_variables(extra_args_provider=get_mp_merge_args,
args_defaults = {'use_cpu_initialization': True,
'micro_batch_size': 1,
'no_load_optim': True,
'no_load_rng': True,
'no_save_optim': True,
'no_save_rng': True,
'save_interval': 1})
args = get_args()
if args.pipeline_model_parallel_size > 1:
print("Checkpoints with pipeline model parallelism are not currently supported.")
exit()
model_type = args.model_type
orig_tensor_model_parallel_size = args.tensor_model_parallel_size
args.tensor_model_parallel_size = 1
tokenizer = rebuild_tokenizer(args)
print('\n merging model parallel partitions ...')
print(' > number of partitions: {}'.format(orig_tensor_model_parallel_size))
print(' > checkpoint path: {}'.format(args.load))
print(' > model parameters:')
print(' number of tokens ................ {} '.format(
tokenizer.vocab_size))
print(' number of layers ................ {}'.format(args.num_layers))
print(' hidden size ..................... {}'.format(args.hidden_size))
print(' number of attention heads ....... {}'.format(
args.num_attention_heads))
print(' maximum position embeddings ..... {}'.format(
args.max_position_embeddings))
# Full model.
print('> building the full model ...')
mpu.initialize.set_tensor_model_parallel_world_size(1)
mpu.initialize.set_tensor_model_parallel_rank(0)
mpu.initialize.set_pipeline_model_parallel_world_size(1)
mpu.initialize.set_pipeline_model_parallel_rank(0)
merged_model = get_model(model_type)
# Build and load partitions.
partitions = []
iteration = 0
args.tensor_model_parallel_size = orig_tensor_model_parallel_size
tokenizer = rebuild_tokenizer(args)
mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
for rank in range(args.tensor_model_parallel_size):
# Reset these since load_checkpoint asserts they are 0, but we are loading
# multiple checkpoints in the same process and they get set each time
args.consumed_train_samples = 0
args.consumed_valid_samples = 0
mpu.initialize.set_tensor_model_parallel_rank(rank)
checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
model_ = get_model(model_type)
print(f'> loading {checkpoint_name} ...')
load_checkpoint(model_, None, None)
print(f'> checkpoint version {get_checkpoint_version()}')
partitions.append(model_)
# Parameter generators so we can loop through them semiltaneouly.
merged_params_gen = merged_model.named_parameters()
partitions_params_gen = [partition.named_parameters()
for partition in partitions]
while True:
try:
# Get the params and check names.
name, merged_param = next(merged_params_gen)
print(' > working on {} ...'.format(name))
print(' merged type: {}, size: {}'.format(
merged_param.dtype, list(merged_param.size())))
partitions_param = []
for rank, partition_params_gen in enumerate(partitions_params_gen):
partition_name, partition_param = next(partition_params_gen)
assert partition_name == name
partitions_param.append(partition_param)
print(' partition {} type: {}, size: {}'.format(
rank, partition_param.dtype, list(partition_param.size())))
# For the non-parallel parameters, simply copy the rank 0 values.
if not hasattr(merged_param, 'tensor_model_parallel'):
print(' none-parallel parameter, simple copy from rank 0')
with torch.no_grad():
merged_param.data.copy_(partitions_param[0].data)
# For parallel parameters, merge the values
else:
dim = merged_param.partition_dim
stride = merged_param.partition_stride
print(f' parallel parameter merge with stride {stride} along '
f'dimention {dim}')
merge_partitions(merged_param,
partitions_param,
dim,
stride)
except StopIteration:
break
partitions = []
args.tensor_model_parallel_size = 1
args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size
assert args.num_layers % args.pipeline_model_parallel_size == 0, \
'num_layers must be divisible by target pipeline model parallel size'
layers_per_part = args.num_layers // args.pipeline_model_parallel_size
tokenizer = rebuild_tokenizer(args)
mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
mpu.initialize.set_tensor_model_parallel_rank(0)
mpu.initialize.set_pipeline_model_parallel_world_size(args.pipeline_model_parallel_size)
# regex to parse out layer number from param name
layer_re = re.compile('layers\.([0-9]+)')
if args.pipeline_model_parallel_size > 1:
merged_params = {}
for name, merged_param in merged_model.named_parameters():
merged_params[name] = merged_param
for rank in range(args.pipeline_model_parallel_size):
mpu.initialize.set_pipeline_model_parallel_rank(rank)
model = get_model(model_type)
def update_layer_num(m):
# TODO! This assumes no interleaved pipeline execution
layer = int(m.group(1))
layer += rank * layers_per_part
return f'layers.{layer}'
for dst_name, partition_param in model.named_parameters():
if dst_name == "word_embeddings.weight":
# See comment in MegatronModule.initialize_word_embeddings()
src_name = "language_model.embedding.word_embeddings.weight"
else:
# Translate destination layer number (0-N for each partition)
# to source layer number (single-model layer number)
src_name = re.sub(layer_re, update_layer_num, dst_name)
print(f" > copying {src_name} to {dst_name} in rank {rank}'s model")
partition_param.data.copy_(merged_params[src_name].data)
partitions.append(model)
else:
partitions = [merged_model]
for rank, model in enumerate(partitions):
mpu.initialize.set_pipeline_model_parallel_rank(rank)
print(f"> saving rank {rank}'s model")
save_checkpoint(iteration, model, None, None)
print('done :-)')
if __name__ == '__main__':
main()
The following steps show how to prepare training dataset to train the mode.
# Libraries to install
```
pip install ftfy langdetect numpy torch pandas nltk sentencepiece boto3 tqdm regex bs4 newspaper3k htmlmin tldextract
git clone https://github.com/mattilyra/LSH
cd LSH
python setup.py install
```
# Download the dataset
1. Download the deduplicated URLs from [jcpeterson](https://mega.nz/#F!EZZD0YwJ!9_PlEQzdMVLaNdKv_ICNVQ!cc4RgQQZ)
2. Remove blacklisted URLs.
```
python blacklist_urls.py <path to the dowloaded deduplicated URLs> <filename for clean urls. e.g. clean_urls.txt>
```
3. Download the content from the clean urls with [openwebtext's utilities](https://github.com/eukaryote31/openwebtext/blob/master/download.py).
4. Merge the contents into one loose json file with 1 json per newline of the format `{'text': text, 'url': unique_url}`. It is important for the url to be unique.
# Prepare the data for GPT training:
1. Perform ftfy, english detection and remove documents with less than 128 tokens. This step can be sharded and run on shards.
```
python cleanup_dataset.py <input data file> <output cleaned data filename>
```
Additional cleanup (e.g. remove documents less than 512 characters or dataset specific cleaning like stories, realnews datasets) can be done using `cleanup_fix_dataset.py`. More details can be found by running `python cleanup_fix_dataset.py --help`.
2. Using LSH, find possible duplicates and store then in a file for later processing. The code supports saving and loading fingerprints for recurrent deduplications, and is also multithreaded for faster processing. More details are can be found by `python find_duplicate.py --help`.
```
python find_duplicates.py --inputs <pairlist list of input cleaned data files and keys, e.g. cc.json cc_id news.json news_id> --output <output possible duplicate urls filename>
```
3. Based on similarity measure defind inside function `is_similar` (default: 0.9), group urls that are similar. Basically, for each group, only one url we should keep and remove the rest.
```
python group_duplicate_urls.py <possible duplicate urls file> <output file containing similar urls>
```
4. Remove similar documents that were detected in the last step.
```
python remove_group_duplicates.py <file containing simialr documents> <cleaned data file> <outputfile containing deduplicate data>
```
5. Shuffle the dataset.
```
shuf <cleaned deduped data file> -o train_data.json
```
# Deduplicating ngrams
To deduplicate the downstream tasks (e.g. lambada, squad) from the training dataset, we run the following command.
```
python filter_ngrams.py --tasks <name of the task, e.g. lambada, squad> --dedup-dataset <training dataset to deduplicate> <json key> --output <output training dataset>
```
We use 13-grams by default for the deduplication. When we find a 13-gram match in a training document, we split the document into two pieces and remove the 13-gram along with 200 characters from the both side of the 13-gram. We also remove any splitted document with less than 200 characters or if a document got splitted more than 10 times. These parameters can be changed using corresponding arguments.
Only for the lambada task, we need to provide the path, `--lambada-path <path of the lambada test data>`.
Several other features (e.g. save and load dictionary) have been added, look at `python filter_ngrams.py --help` for details.
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import time
"""
This code adds id to each json object in a json file. User can add prefix
to the ids.
"""
if __name__ == '__main__':
print('parsing the arguments ...')
parser = argparse.ArgumentParser()
parser.add_argument('--input-file', type=str, default=None, help='Input'\
' json file where id needs to be added')
parser.add_argument('--output-file', type=str, default=None, help=\
'Output file name with id')
parser.add_argument('--id-prefix', type=str, default=None, help=\
'Id prefix')
parser.add_argument('--log-interval', type=int, default=100,
help='Log interval')
args = parser.parse_args()
print('Adding ids to dataset ...')
f_input = open(args.input_file, 'r', encoding='utf-8')
f_output = open(args.output_file, 'wb')
unique_ids = 1
start_time = time.time()
for row in f_input:
each_row = json.loads(row)
adlr_id_string = args.id_prefix + '-{:010d}'.format(int(unique_ids))
each_row['adlr_id'] = adlr_id_string
myjson = json.dumps(each_row, ensure_ascii=False)
f_output.write(myjson.encode('utf-8'))
f_output.write('\n'.encode('utf-8'))
if unique_ids % args.log_interval == 0:
print(' processed {:9d} documents in {:.2f} seconds ...'.format( \
unique_ids, time.time() - start_time), flush=True)
unique_ids += 1
# Close the file.
f_input.close()
f_output.close()
print('done :-)', flush=True)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import re
import time
import tldextract
import sys
# List of the domains to blacklist.
domain_blacklist = set([
'500px',
'aapks',
'akamaihd',
'amazon',
'apple',
'artifactfire',
'artstation',
'awwni',
'bandcamp',
'battleforthenet',
'coinscalendar',
'dailymotion',
'deviantart',
'discord',
'discordapp',
'dlapkandroid',
'dropbox',
'e621',
'ebay',
'edealinfo',
'erome',
'eroshare',
'explosm',
'facebook',
'fbcdn',
'flickr',
'furaffinity',
'futhead',
'gatopardo',
'gfycat',
'gifsound',
'gifsoup',
'giphy',
'github',
'google',
'gunprime',
'gyazo',
'hotdealstar',
'imagefap',
'imageshack',
'imgflip',
'imgur',
'instagram',
'karmadecay',
'kryptocal',
'kym-cdn',
'liveleak',
'livememe',
'lmgtfy',
'magaimg',
'memegenerator',
'minorplanetcenter',
'minus',
'mobafire',
'morejpeg',
'nocookie',
'pcpartpicker',
'photobucket',
'pinimg',
'pinterest',
'pixiv',
'pornhub',
'prntscr',
'puu',
'qkme',
'quickmeme',
'radd',
'redd',
'reddit',
'reddit-stream',
'redditlog',
'redditmedia',
'reddituploads',
'redtube',
'reupp',
'reverb',
'roanoke',
'rollingstone',
'sli',
'soundcloud',
'soundgasm',
'spankbang',
'spotify',
'strawpoll',
'streamable',
'timeanddate',
'tinypic',
'touhouradio',
'tumblr',
'twimg',
'twitch',
'twitter',
'vid',
'vimeo',
'vine',
'vkaao',
'vocaroo',
'voyagefusion',
'walmart',
'wciu',
'wikimedia',
'wikipedia',
'xhamster',
'xkcd',
'xvideos',
'youtu',
'youtube',
'youtubedoubler',
'ytimg',
'zillexplorer',
])
def domain_is_in_blacklist(url):
domain = tldextract.extract(url).domain
return domain in domain_blacklist
# List of extentions to blacklist.
extentions_blacklist = (
'.3gp',
'.7z'
'.ai',
'.aif',
'.apk',
'.app',
'.avi',
'.bin',
'.bmp',
'.bz2',
'.css',
'.csv',
'.dat',
'.deb',
'.dmg',
'.doc',
'.docx',
'.exe',
'.gif',
'.gifv',
'.gz',
'.iso',
'.jar',
'.jpeg',
'.jpg',
'.js',
'.log',
'.mid',
'.midi',
'.mkv',
'.mov',
'.mp3',
'.mp4',
'.mpeg',
'.mpg',
'.ogg',
'.ogv',
'.otf',
'.pdf',
'.pkg',
'.png',
'.pps',
'.ppt',
'.pptx',
'.psd',
'.py',
'.qt',
'.ram',
'.rar',
'.sql',
'.svg',
'.swf',
'.tar.gz',
'.tar',
'.tgz',
'.tiff',
'.ttf',
'.txt',
'.wav',
'.webm',
'.wma',
'.wmv',
'.xls',
'.xlsx',
'.xml',
'.xz',
'.zip',
)
def extention_is_in_blacklist(url):
if url.split('?')[0].lower().endswith(extentions_blacklist):
return True
return False
# Malformed urls.
# This function is adapted from:
# https://stackoverflow.com/questions/7160737/python-how-to-validate-a-url-in-python-malformed-or-not
url_regex = re.compile(
r'^(?:http)s?://' # http:// or https://
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
r'(?::\d+)?' # optional port
r'(?:/?|[/?]\S+)$', re.IGNORECASE)
def url_is_malformed(url):
return re.match(url_regex, url) is None
def print_progress(prefix, start_time, urls_counter,
domain_blacklist_counter,
extention_blacklist_counter,
short_url_counter, malformed_url_counter,
duplicate_url_counter):
string = prefix + ' | '
string += 'time elapsed (s): {:.2f} | '.format(time.time() - start_time)
string += 'number of urls: {} | '.format(urls_counter)
string += 'domain blacklisted: {} | '.format(domain_blacklist_counter)
string += 'extention blacklisted: {} | '.format(extention_blacklist_counter)
string += 'short urls (<=8): {} | '.format(short_url_counter)
string += 'malformed urls: {} | '.format(malformed_url_counter)
string += 'duplicate urls: {}'.format(duplicate_url_counter)
print(string, flush=True)
if __name__ == '__main__':
print('remove blacklisted urls ..')
# Path to the url files.
path = sys.argv[1]
# Output url file.
output = sys.argv[2]
# Get the list of url files.
files = glob.glob(path + '/*.txt')
print('> found {} files'.format(len(files)))
urls = set()
urls_counter = 0
domain_blacklist_counter = 0
extention_blacklist_counter = 0
short_url_counter = 0
malformed_url_counter = 0
duplicate_url_counter = 0
start_time = time.time()
for filename in files:
with open(filename, 'r') as f:
for line in f:
url = line.strip()
urls_counter += 1
if domain_is_in_blacklist(url):
print('[DOMAIN BLACKLIST]: {}'.format(url), flush=True)
domain_blacklist_counter += 1
elif extention_is_in_blacklist(url):
print('[EXTENTION BLACKLIST]: {}'.format(url), flush=True)
extention_blacklist_counter += 1
elif len(url) <= 8:
print('[SHORT URL]: {}'.format(url), flush=True)
short_url_counter += 1
elif url_is_malformed(url):
print('[MALFORMED URL]: {}'.format(url), flush=True)
malformed_url_counter += 1
elif url in urls:
print('[DUPLICATE URL]: {}'.format(url), flush=True)
duplicate_url_counter += 1
else:
urls.add(url)
if urls_counter % 100000 == 0:
print_progress('PROGRESS', start_time, urls_counter,
domain_blacklist_counter,
extention_blacklist_counter,
short_url_counter, malformed_url_counter,
duplicate_url_counter)
print_progress('FINAL', start_time, urls_counter,
domain_blacklist_counter,
extention_blacklist_counter,
short_url_counter, malformed_url_counter,
duplicate_url_counter)
# Write the final set of urls.
print('> writing cleaned up url list to {}'.format(output))
with open(output, 'w') as f:
for url in urls:
f.write(url + '\n')
print('done :-)')
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ftfy
import json
from langdetect import detect
import numpy as np
import time
import os
import sys
from tokenizer import Tokenizer
MIN_DOCUMENT_LENGHT = 128
def print_progress(prefix, start_time, num_docs, num_fixed_text,
num_non_english_docs, chars_non_english_docs,
num_small_docs, chars_small_docs):
string = prefix + ' | '
string += 'elapsed time: {:.2f} | '.format(time.time() - start_time)
string += 'documents: {} | '.format(num_docs)
string += 'fixed text: {} | '.format(num_fixed_text)
string += 'non-english: {} | '.format(num_non_english_docs)
string += 'non-english chars: {} | '.format(chars_non_english_docs)
string += 'small docs: {} | '.format(num_small_docs)
string += 'small docs chars: {}'.format(chars_small_docs)
print(string, flush=True)
def filter_corpus(filename, out_filename, print_interval=10000):
print(' > filtering {}'.format(filename))
tokenizer = Tokenizer(cache_dir='./cache')
num_docs = 0
num_written_docs = 0
num_small_docs = 0
num_fixed_text = 0
num_non_english_docs = 0
chars_non_english_docs = 0
chars_small_docs = 0
start_time = time.time()
with open(out_filename, 'wb') as f:
with open(filename, 'r') as fin:
for line in fin:
try:
num_docs += 1
myjson = json.loads(line)
# Fix text
text = ftfy.fix_text(myjson['text'])
if text != myjson['text']:
num_fixed_text += 1
myjson['text'] = text
# Detect language.
if detect(text) != 'en':
print('[non-english text]', myjson)
num_non_english_docs += 1
chars_non_english_docs += len(text)
continue
# On average each token is 5 characters so 8 is an
# upper bound.
if len(text) < (8 * MIN_DOCUMENT_LENGHT):
tokens = tokenizer.tokenize_document(text)
if len(tokens) < MIN_DOCUMENT_LENGHT:
print('[small document, skipping]:', myjson)
num_small_docs += 1
chars_small_docs += len(text)
continue
myjson = json.dumps(myjson, ensure_ascii=False)
f.write(myjson.encode('utf-8'))
f.write('\n'.encode('utf-8'))
num_written_docs += 1
if num_docs % print_interval == 0:
print_progress('[PROGRESS]', start_time, num_docs,
num_fixed_text, num_non_english_docs,
chars_non_english_docs,
num_small_docs, chars_small_docs)
except Exception as e:
print(' skipping ', line, e)
print_progress('[FINAL]', start_time, num_docs,
num_fixed_text, num_non_english_docs,
chars_non_english_docs,
num_small_docs, chars_small_docs)
if __name__ == '__main__':
print('building gpt2 dataset ...')
input_filename = sys.argv[1]
output_filename = sys.argv[2]
print('will be reading {}'.format(input_filename))
print('and will write the results to {}'.format(output_filename))
filter_corpus(input_filename, output_filename)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Filter and clean documents:
Capable to clean docs with less than 512 characters, less than
256 characters and contains javascript, fix text and dataset specific
cleaning like stories and realnews datasets.
Program arguments have the details.
"""
import argparse
from functools import partial
import glob
import ftfy
import json
from langdetect import detect
import multiprocessing
import os
from pathlib import Path
import re
import time
def process_doc(json_line, args):
# Read the line.
document = json.loads(json_line)
text = document['text']
output = {'remove_512': False, 'remove_256_javascript': False, \
'remove_512_non_english': False, 'ftfy_fix_text': False, \
'general_cleaning': False}
try:
# Reomove all docs with less than 512 characters
if "remove_512" in args.tasks:
if len(text) < 512:
output['remove_512'] = True
return output, text, document, True
# Remove docs if less than 256 character length and contains Javascript
if "remove_256_javascript" in args.tasks:
if len(text) < 256 and 'javascript' in text.lower():
output['remove_256_javascript'] = True
return output, text, document, True
# Remove docs < 512 and nonenglish
if "remove_512_non_english" in args.tasks:
if len(text) < 512 and detect(text) != 'en':
output['remove_512_non_english'] = True
return output, text, document, True
# Fix the text using ftfy, don't remove the text, hence return False
if "ftfy_fix_text" in args.tasks:
fixed_text = ftfy.fix_text(text)
output['ftfy_fix_text'] = True
return output, fixed_text, document, False
# Cleaning extra spaces and newlines
if "general_cleaning" in args.tasks:
cleaned_text = re.sub(r" +|\b\n+ |\b\n+", " ", text)
#cleaned_text = re.sub(r"\n\n+", "\n\n", text) # used this for Gutenberg dataset
#cleaned_text = re.sub(r"\n", "\n\n", text) # Used this for realnews
# stories datasets
#cleaned_text = re.sub(r" \'", "'", text)
#cleaned_text = re.sub(r" \!", "!", cleaned_text)
#cleaned_text = re.sub(r" \.", ".", cleaned_text)
#cleaned_text = re.sub(r" \?", "?", cleaned_text)
#cleaned_text = re.sub(r" - ", "-", cleaned_text)
##cleaned_text = re.sub(r"\" ", "\"", cleaned_text)
#cleaned_text = re.sub(r" @ ", "@", cleaned_text)
output['general_cleaning'] = True
return output, cleaned_text, document, False
except Exception as e:
print('Error: *************************\n{}\ntext: {}'.format(e, \
text), flush=True)
return output, text, document, True
# don't remove
return output, text, document, False
def process_set(args, input_file, output_f_cleaned, output_f_filtered):
print(' > working on {} ...'.format(input_file), flush=True)
num_docs = num_remove_512 = num_remove_java = num_remove_512_non_english \
= num_ftfy_fix_text = num_general_cleaning = 0
# Output file and counters.
output_cleaned = open(output_f_cleaned, 'wb')
output_filtered = open(output_f_filtered, 'wb')
start_time = time.time()
# Setup multi-processing.
num_workers = 40
fin = open(input_file, 'r', encoding='utf-8')
pool = multiprocessing.Pool(num_workers)
process_doc_partial = partial(process_doc, args=args)
processed_docs = pool.imap(process_doc_partial, fin, 500)
# Process documents.
for output, text, document, to_filter in processed_docs:
num_docs += 1
num_remove_512 += 1 if output['remove_512'] else 0
num_remove_java += 1 if output['remove_256_javascript'] else 0
num_remove_512_non_english += 1 if output['remove_512_non_english'] \
else 0
num_ftfy_fix_text += 1 if output['ftfy_fix_text'] else 0
num_general_cleaning += 1 if output['general_cleaning'] else 0
document['text'] = text
myjson = json.dumps(document, ensure_ascii=False)
if to_filter:
output_filtered.write(myjson.encode('utf-8'))
output_filtered.write('\n'.encode('utf-8'))
else:
output_cleaned.write(myjson.encode('utf-8'))
output_cleaned.write('\n'.encode('utf-8'))
if num_docs % args.log_interval == 0:
print(' processed {:9d} documents in {:.2f} seconds ...'.format(
num_docs, time.time() - start_time), flush=True)
# Close the file.
output_cleaned.close()
output_filtered.close()
fin.close()
# Print stats.
print(' >> total docs: {} remove_512 {} remove_256_javascript {} '\
'remove_512_non_english {} ftfy_fix_text {} general_cleaning {}'.\
format(num_docs, num_remove_512, num_remove_java,\
num_remove_512_non_english, num_ftfy_fix_text, \
num_general_cleaning), flush=True)
if __name__ == '__main__':
print('parsing the arguments ...')
parser = argparse.ArgumentParser()
parser.add_argument('--input-files', nargs = '*', required=True, default=\
None, help = 'Input json files that needs to be'\
' cleaned')
parser.add_argument('--tasks', nargs = '*', required=True, default=None,\
help = 'Tasks to perform on the input files, ' \
'such as remove_512, remove_256_javascript, ' \
'remove_512_non_english, ftfy_fix_text, and ' \
'general_cleaning. 256 or 512 means the number' \
' of characters.')
parser.add_argument('--output-path', type=str, default=None,
help='Directory where the output should go')
parser.add_argument('--log-interval', type=int, default=100,
help='Log interval')
args = parser.parse_args()
print('cleanup dataset ...')
for input_file in args.input_files:
input_filename, input_filename_ext = os.path.splitext(Path(input_file)\
.name)
output_f_cleaned = os.path.join(args.output_path, input_filename + \
"_cleaned" + input_filename_ext)
output_f_filtered = os.path.join(args.output_path, input_filename + \
"_filtered" + input_filename_ext)
process_set(args, input_file, output_f_cleaned, output_f_filtered)
print('done :-)', flush=True)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Deduplicate downstream tasks from training dataset. 13-grams have been used.
All split documents with less than 200 characters got filtered. Any document
with more than 10 splits got filtered as well.
"""
import argparse
from functools import partial
import json
import multiprocessing
import nltk
import pickle
import re
import string
import sys
import time
def get_words(text):
# get all the lowercase words from text
words, positions = [], []
for match in re.finditer(r'\w+', text.lower()):
words.append(match.group(0))
positions.append(match.start())
return words, positions
# splits the text
def split_text(text, start_position, remove_char_each_side, seq):
# first part of the text
punctuations = ".!?"
pos = start_position - remove_char_each_side
text_first = ""
while pos > 0 and not text[pos] in punctuations:
pos -= 1
if pos > 0:
text_first = text[0:pos+1]
# add length of seq and remove_char_each_side
pos = start_position + len(seq) + remove_char_each_side
# last part of the text
text_second = ""
while pos < len(text) and not text[pos] in punctuations:
pos += 1
if pos + 1 < len(text):
text_second = text[pos+1:len(text)]
return text_first, text_second
def check_and_clean_text(args, words, ngrams, text, start_position, \
text_buf_ngram_free, text_buf, local_ngram):
seq = " ".join(words)
if seq in ngrams:
print(" [matched]: {}".format(seq), flush=True)
if args.get_ngram_freq_only:
# increase freq of this seq and then only consider the later part
# of the text for further processing
if seq in local_ngram:
local_ngram[seq] += 1
else:
local_ngram[seq] = 1
#print(" [increased]: {} {}".format(seq, ngrams[seq]), flush=True)
if (start_position + len(seq) + 1) < len(text):
text_buf.append(text[start_position + len(seq) + 1:len(text)])
return False
# split the text
text_first, text_second = split_text(text, start_position, \
args.remove_char_each_side, seq)
# first part of ngrams free
if len(text_first) > args.filter_text_char_len:
text_buf_ngram_free.append(text_first)
# add second part for further processing
if len(text_second) > args.filter_text_char_len:
text_buf.append(text_second)
return False # not ngram free
# ngram free
return True
def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
# remove all the ngrams
try:
myjson = json.loads(line)
text_buf = [myjson[key]]
except Exception as e:
print("Error: {}".format(e), flush=True)
text_buf = []
text_buf_ngram_free = []
local_ngram = {}
while len(text_buf) > 0:
# get the first one from the buffer
text = text_buf.pop(0)
words, positions = get_words(text)
ngram_free = True
# find each max n-grams and check dictionary
for i in range(len(words) - args.max_ngram_size + 1):
check_ngram_free = check_and_clean_text(args, words[i:\
i+args.max_ngram_size], ngrams, text, positions[i], \
text_buf_ngram_free, text_buf, local_ngram)
# the seq is ngram free? if yes, break
if not check_ngram_free:
ngram_free = False
break
# if max ngrams doesn't match, check if any other lower n-grams
# within max ngram macthes
for ngram_len, _ in ngrams_freq_sorted:
check_ngram_free = check_and_clean_text(args, words[i:\
i+ngram_len], ngrams, text, positions[i], \
text_buf_ngram_free, text_buf, local_ngram)
# same check as above
if not check_ngram_free:
ngram_free = False
break
# check break from lower than max ngram loop above
if not ngram_free:
break
# for the last max n-gram, check all the lower ngrams in it
if ngram_free and len(words) - args.max_ngram_size > 0:
# get the last words of the lax max ngram
last_seq_words = words[(len(words)-args.max_ngram_size):len(words)]
last_seq_start_position = len(words) - args.max_ngram_size
# check all n-grams lower than the max
for pos, (ngram_len, _) in enumerate(ngrams_freq_sorted):
# ignore the max ngram as has been considered already
if ngram_len == args.max_ngram_size:
continue
# find each ngram of ngram_len in max n-grams and check
for i in range(len(last_seq_words) - ngram_len + 1):
check_ngram_free = check_and_clean_text(args, \
last_seq_words[i:i+ngram_len], ngrams, text,\
positions[last_seq_start_position+i], \
text_buf_ngram_free, text_buf, local_ngram)
if not check_ngram_free:
ngram_free = False
break
if not ngram_free:
break
# texts are ngram free
if ngram_free and not args.get_ngram_freq_only:
text_buf_ngram_free.append(text)
# check if the text has only been trimmed
trimmed = 0
if not args.get_ngram_freq_only and len(text_buf_ngram_free) == 1 and \
len(text_buf_ngram_free[0]) < len(myjson[key]):
trimmed = 1
return text_buf_ngram_free, trimmed, myjson, local_ngram
# insert word sequence into dictionary
def insert_dict(words, ngrams, pos):
seq = " ".join(words)
if seq not in ngrams:
ngrams[seq] = 0
#ngrams[seq] = pos
# insert each ngram from text into the ngrams dictionary
def compute_ngrams_insert_dict(args, text, ngrams):
words, positions = get_words(text)
if len(words) < args.min_ngram_size:
return
if len(words) < args.max_ngram_size:
insert_dict(words, ngrams, positions[0])
for i in range(len(words) - args.max_ngram_size+1):
insert_dict(words[i:i+args.max_ngram_size], ngrams, positions[i])
# Build ngrams for the lambada dataset
def process_task_lambda(args, task_file, ngrams):
print(' reading from {} and computing ngrams'.format(task_file))
with open(task_file, 'r') as f:
for line in f:
try:
myjson = json.loads(line)
text = myjson['text']
compute_ngrams_insert_dict(args, text, ngrams)
except Exception as e:
print('Error:', e)
print(" Entities in ngrams {}".format(len(ngrams)), flush=True)
# Build ngrams for the dataset of the given task
def process_task(args, task_name, ngrams):
print(' reading from {} and computing ngrams'.format('import datasets'))
print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)
# using validation/test data from datasets
from datasets import load_dataset
entities_in_ngrams = len(ngrams)
# load the dataset
if task_name == 'squad':
dataset = load_dataset('squad_v2', split='validation')
elif task_name == 'natural_questions':
dataset = load_dataset('natural_questions', split='validation')
elif task_name == 'triviaqa':
dataset = load_dataset('trivia_qa', 'unfiltered', split='test')
elif task_name == 'webqa':
dataset = load_dataset('web_questions', split='test')
elif task_name == 'race':
dataset = load_dataset('race', 'all', split='test')
elif task_name == 'drop':
dataset = load_dataset('drop', split='validation')
elif task_name == 'coqa':
dataset = load_dataset('coqa', split='validation')
elif task_name == 'piqa':
dataset = load_dataset('piqa', split='test')
else:
print("Invalid task name: {}".format(task_name), flush=True)
return
# read the dataset and add to ngrams
for line in dataset:
try:
if task_name in ['squad', 'triviaqa', 'webqa', 'race', 'drop']:
text = line['question']
compute_ngrams_insert_dict(args, text, ngrams)
elif task_name == 'natural_questions':
text = line['question']['text']
compute_ngrams_insert_dict(args, text, ngrams)
elif task_name == 'coqa':
all_questions = line['questions']
for question in all_questions:
compute_ngrams_insert_dict(args, question, ngrams)
elif task_name == 'piqa':
text = line['goal']
compute_ngrams_insert_dict(args, text, ngrams)
except Exception as e:
print('Error:', e)
print(" After task {} entities in ngrams {}, added {}".format(task_name, \
len(ngrams), len(ngrams) - entities_in_ngrams), flush=True)
def compute_tasks_ngrams(args, ngrams):
start_time = time.time()
for _, task_name in enumerate(args.tasks):
print('Task: {}'.format(task_name), flush=True)
if task_name == 'lambada':
assert args.lambada_path is not None
process_task_lambda(args, args.lambada_path, ngrams)
else:
process_task(args, task_name, ngrams)
print(" Taken time to compute ngrams {:.2f}".format(time.time() - \
start_time), flush=True)
def compute_ngram_freq_sorted(args, ngrams):
ngrams_freq = {}
for ngram_key in ngrams.keys():
length = len(ngram_key.split())
ngrams_freq[length] = ngrams_freq[length] + 1 if length in \
ngrams_freq else 1
ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[0])
print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True)
print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(\
len(ngrams), ngrams_freq_sorted[0][0], ngrams_freq_sorted[len(\
ngrams_freq_sorted) -1 ][0]), flush=True)
return ngrams_freq_sorted
def get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold, \
dedup_file, dedup_key, ngrams_freq_sorted):
start_time = time.time()
# get the ngrams frequency
args.get_ngram_freq_only = True
# Open the large file to process in parallel
num_workers = args.num_threads
pool = multiprocessing.Pool(num_workers)
fin = open(dedup_file, 'r', encoding='utf-8')
free_ngram_abt_partial=partial(free_ngram, args=args, key=dedup_key, \
ngrams=ngrams, ngrams_freq_sorted=ngrams_freq_sorted)
free_ngrams_abt = pool.imap(free_ngram_abt_partial, fin, 500)
counter = 0
for _, _, _, local_ngram in free_ngrams_abt:
counter += 1
if counter % 1000 == 0:
print(' [compute_stat]> processed {} documents in {:.2f} seconds ...'.
format(counter, time.time() - start_time), flush=True)
for local_key in local_ngram:
if local_key in ngrams:
ngrams[local_key] += 1
local_ngram = {}
print(' Time taken to compute statistics {:.2f} seconds'.format(time.time() - \
start_time), flush=True)
pool.close()
pool.join()
start_time = time.time()
counter_threshold = 0
# Get ngram below theadhold
for local_key, local_val in ngrams.items():
if ngrams[local_key] < args.key_threshold:
print(" [threshold] {} {}".format(local_key, local_val), flush=True)
counter_threshold += 1
ngrams_below_threshold[local_key] = 1
print(' Ngrams below threshold {}'.format(counter_threshold), flush=True)
fin.close()
def clean_ngrams_below_threshold(args, ngrams_below_threshold, dedup_file, \
dedup_key):
start_time = time.time()
# Now actually filter the dataset
args.get_ngram_freq_only = False
#id_prefix = '-'.join(args.tasks[::2])
id_prefix = '-'.join(args.tasks[::1])
# get the range of the size of the ngrams
ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams_below_threshold)
# Open the large file to process in parallel
counter = splitted = ignored = split_mt_thld = trimmed_count = 0
num_workers = args.num_threads
pool = multiprocessing.Pool(num_workers)
fin = open(dedup_file, 'r', encoding='utf-8')
free_ngram_clean_partial=partial(free_ngram, args=args, key=dedup_key, \
ngrams=ngrams_below_threshold, ngrams_freq_sorted=ngrams_freq_sorted)
free_ngrams_clean = pool.imap(free_ngram_clean_partial, fin, 500)
out_f = open(args.output, 'wb')
for text_buf_ngram_free, trimmed, myjson, _ in free_ngrams_clean:
counter += 1
try:
trimmed_count += trimmed
if len(text_buf_ngram_free) > 1:
splitted += 1
if len(text_buf_ngram_free) == 0:
ignored += 1
# more than 10 splits ignored
if len(text_buf_ngram_free) > args.splits_count:
text_buf_ngram_free = []
split_mt_thld += 1
if args.output is not None:
if "split_id" in myjson:
use_prefix = myjson["split_id"] + "-"
else:
use_prefix = ""
for i in range(len(text_buf_ngram_free)):
split_id_string = id_prefix + '-{:010d}'.format(int(\
counter)) + '-{:04d}'.format(int(i))
myjson[dedup_key] = text_buf_ngram_free[i]
myjson["split_id"] = use_prefix + split_id_string
outjson = json.dumps(myjson, ensure_ascii=False)
#outjson = json.dumps({"text":text_buf_ngram_free[i],
# id_prefix+"_split_id":split_id_string},
# ensure_ascii=False)
out_f.write(outjson.encode('utf-8'))
out_f.write('\n'.encode('utf-8'))
if counter % 1000 == 0:
print(' [final]> processed {} documents in {:.2f} seconds ...'.
format(counter, time.time() - start_time), flush=True)
except Exception as e:
print('Error:', e)
print(' [final]> processed {} documents in {:.2f} seconds ...'.
format(counter, time.time() - start_time), flush=True)
print(' Total docs {} splitted {} ignored {} splits > theshold {} trimmed'\
' {}'.format(counter, splitted, ignored, split_mt_thld, trimmed_count)\
, flush=True)
pool.close()
pool.join()
out_f.close()
fin.close()
if __name__ == '__main__':
# we use 13-grams, any text less than 200 characters got removed
# any text splitted more than 10 got removed as well
print('parsing the arguments ...')
parser = argparse.ArgumentParser()
parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
help = 'Tasks to use for deduplication: currently '
' suuport [lambada, squad, natural_questions,'
' triviaqa, webqa, race, drop, coqa, and piqa]')
parser.add_argument('--lambada-path', type=str, default=None,
help='Only Lambada task needs the path')
parser.add_argument('--dedup-dataset', nargs = '*', default=None,
help='Dataset to deduplicate with the key to use'
' e.g. cc.json text')
parser.add_argument('--output', type=str, default=None,
help='Output file name to save dedup dataset')
parser.add_argument('--num-threads', type=int, default=40,
help='Number of threads to use')
# Default dedup values
parser.add_argument('--max-ngram-size', type=int, default=13,
help='Maximum size of ngram to use.')
parser.add_argument('--min-ngram-size', type=int, default=8,
help='Minimum size of ngram to use.')
parser.add_argument('--filter-text-char-len', type=int, default=200,
help='Remove any text below this length.')
parser.add_argument('--key-threshold', type=int, default=10,
help='Number of keys to consider as threshold')
parser.add_argument('--save-dictionary', type=str, default=None,
help='Save the dictionary')
parser.add_argument('--load-dictionary', type=str, default=None,
help='Load the dictionary')
parser.add_argument('--splits-count', type=int, default=10,
help='Remove any documents more than this many splits')
parser.add_argument('--remove-char-each-side', type=int, default=200,
help='Maximum size of ngram to use.')
args = parser.parse_args()
assert len(args.dedup_dataset) == 2
dedup_file = args.dedup_dataset[0]
dedup_key = args.dedup_dataset[1]
# Setup multi-processing
num_workers = args.num_threads
if args.load_dictionary is None:
# Build ngrams
ngrams = {}
compute_tasks_ngrams(args, ngrams)
# get the range of the size of the ngrams
ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams)
# get ngram freq from large file in parallel
# get ngrams below threshold
ngrams_below_threshold = {}
get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold, \
dedup_file, dedup_key, ngrams_freq_sorted)
# save the dictionary if needed
if args.save_dictionary is not None:
with open(args.save_dictionary, 'wb') as save_dict_handle:
pickle.dump(ngrams_below_threshold, save_dict_handle)
else:
with open(args.load_dictionary, 'rb') as load_dict_handle:
ngrams_below_threshold = pickle.load(load_dict_handle)
# filter the large file
if args.output is not None:
clean_ngrams_below_threshold(args, ngrams_below_threshold, \
dedup_file, dedup_key)
print('done :-)')
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from functools import partial
import itertools
import json
from lsh import cache, minhash
import multiprocessing
import numpy as np
import time
import pickle
import sys
import os
# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def shingles(text, char_ngram=5):
return set(text[head:head + char_ngram]
for head in range(0, len(text) - char_ngram))
# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def jaccard(set_a, set_b, args):
if len(set_a) < 1 or len(set_b) < 1:
return 0.0
intersection = set_a & set_b
union = set_a | set_b
if args.jaccard == 'min':
return len(intersection) / min(len(set_a), len(set_b))
elif args.jaccard == 'max':
return len(intersection) / max(len(set_a), len(set_b))
else:
return len(intersection) / len(union)
def compute_fingerprint(line, key):
try:
myjson = json.loads(line)
url = myjson[key]
text = myjson['text']
fingerprint = hasher.fingerprint(text)
except Exception as e:
print('Error:', e)
return None, None, None, False
return url, text, fingerprint, True
def url_pairs_to_remove(args, bucket_urls, url_doc):
remove_urls_list = []
deduped_local, counter_local = 0, 0
iteration = 0
while len(bucket_urls) > 1:
if args.heuristic_iter != -1 and \
iteration == args.heuristic_iter:
break
items = list(bucket_urls)
remove_urls = []
main_url = items[np.random.randint(0, len(items))]
main_dhingles = shingles(url_doc[main_url])
for i in range(0, len(items)):
counter_local += 1
other_url = items[i]
if other_url == main_url:
continue
other_shingles = shingles(url_doc[other_url])
try:
jaccard_sim = jaccard(main_dhingles, other_shingles, args)
except Exception as e:
print('Error:', e)
jaccard_sim = 0.0
if jaccard_sim > 0.5:
remove_urls.append({other_url: jaccard_sim})
deduped_local += 1
bucket_urls.remove(other_url)
bucket_urls.remove(main_url)
if len(remove_urls) > 0:
remove_urls_list.append({main_url: remove_urls})
iteration += 1
return remove_urls_list, deduped_local, counter_local
def write_remove_urls_list(remove_urls_list, f_out):
if len(remove_urls_list) > 0:
for each_url_remove in remove_urls_list:
myjson = json.dumps(each_url_remove, ensure_ascii=False)
f_out.write(myjson.encode('utf-8'))
f_out.write('\n'.encode('utf-8'))
def compute_jaccard(each_bin, num_bins, start_time_local):
remove_urls_list = []
deduped_local, counter_local, bucket_local = 0, 0, 0
for bucket_id in each_bin:
bucket_local += 1
if os.getpid() % num_bins == 0 and bucket_local % 100000 == 0:
print("Counter {}, progress {:.2f} time {:.2f}".\
format(bucket_local, float(bucket_local)/float(len(each_bin)),\
time.time() - start_time_local), flush=True)
if len(each_bin[bucket_id]) <= 1:
continue
bucket_urls = each_bin[bucket_id].copy()
remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
url_pairs_to_remove(args, bucket_urls, url_doc)
deduped_local += deduped_local_sub
counter_local += counter_local_sub
if len(remove_urls_list_sub) > 0:
remove_urls_list.extend(remove_urls_list_sub)
return remove_urls_list, deduped_local, counter_local
def find_pair_urls_parallel(args, lshcache, url_doc):
start_time = time.time()
f_out = open(args.output, 'wb')
deduped, counter = 0, 0
# compute jaccards of buckets in bin in parallel (parallelism
# limited to # of bins)
num_bins = len(lshcache.bins)
pool = multiprocessing.Pool(num_bins)
compute_jaccard_partial = partial(compute_jaccard, num_bins=num_bins, \
start_time_local=start_time)
# don't need to pass args and url_doc as they are already shared
compute_jaccard_iter = pool.imap(compute_jaccard_partial, lshcache.bins)
print("multiprocessing init took {:.2f}".format(time.time() - start_time),\
flush=True)
for remove_urls_list, deduped_local, counter_local in compute_jaccard_iter:
deduped += deduped_local
counter += counter_local
write_remove_urls_list(remove_urls_list, f_out)
print(' [write]> processed {} documents in {:.2f} '
'seoncds and deduped {} documents ...'.format(counter, time.time()\
- start_time, deduped), flush=True)
pool.close()
pool.join()
f_out.close()
print(' Taken time for jaccard similariries {:.2f} seconds'.format(\
time.time() - start_time), flush=True)
def find_pair_urls_sequential(args, lshcache, url_doc):
start_time = time.time()
f_out = open(args.output, 'wb')
deduped, counter = 0, 0
for b in lshcache.bins:
for bucket_id in b:
if len(b[bucket_id]) <= 1:
continue
bucket_urls = b[bucket_id].copy()
remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
url_pairs_to_remove(args, bucket_urls, url_doc)
deduped += deduped_local_sub
counter += counter_local_sub
write_remove_urls_list(remove_urls_list_sub, f_out)
if counter % 10000 == 0:
print(' [write]> processed {} documents in {:.2f} '
'seoncds and deduped {} documents ...'.
format(counter, time.time() - start_time,
deduped), flush=True)
f_out.close()
print(' [write]> processed {} documents in {:.2f} '
'seoncds and deduped {} documents ...'.
format(counter, time.time() - start_time,
deduped), flush=True)
if __name__ == '__main__':
print('parsing the arguments ...')
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=1234,
help='Random seed used for python, numpy')
parser.add_argument('--inputs', nargs = '*', default=None, help = \
'Pairwise list of the input files and keys, '
'e.g. --inputs cc.json cc_id news.json news_id')
parser.add_argument('--load-fingerprints', nargs = '*', default=None,
help='Load fingerprints from a list of pickle files,'
' e.g. cc.pkl news.pkl')
parser.add_argument('--save-fingerprints', type=str, default=None,
help='Save the fingerprints of the inputs.')
parser.add_argument('--output', type=str, default=None,
help='Output file name that consists of all ids'
' with matching similarities')
parser.add_argument('--jaccard', type=str, default='union',
choices=['union', 'min', 'max'], help='Jaccard'\
' similarity computation')
parser.add_argument('--heuristic-iter', type=int, default=1,
help='Number of iterations to run the heuristics'
': use -1 for exact')
parser.add_argument('--num-bands', type=int, default=10,
help='Number of bands to use in cache')
parser.add_argument('--num-seeds', type=int, default=100,
help='Number of seeds to use for minhash. Note that'
' this value should be divisible by num-bands')
parser.add_argument('--jaccard-parallel', action='store_true',
help='Use this to process large number of documents.')
args = parser.parse_args()
print('finding possible duplicate content ...')
# set seed and get an array of seeds of 100 integers
np.random.seed(args.seed)
seeds = np.random.randint(0, 1e6, size=args.num_seeds)
# initialize minhash and lsh cache
hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)
url_doc = {}
# load fingerprints from pickle file if needed
if args.load_fingerprints is not None:
for count_fp, fp_file_name in enumerate(args.load_fingerprints):
print("Loading fingerprints from pickle file {}".format(
fp_file_name), flush=True)
fp = open(fp_file_name, "rb")
if count_fp == 0:
# assign directory for the first pkl
lshcache = pickle.load(fp)
url_doc = pickle.load(fp)
else:
# append these to lshcache and url_doc
local_lshcache = pickle.load(fp)
local_url_doc = pickle.load(fp)
for url in local_lshcache.fingerprints.keys():
url_doc[url] = local_url_doc[url]
lshcache.add_fingerprint(local_lshcache.fingerprints[url], url)
fp.close()
counter = 0
start_time = time.time()
# compute finger prints of the inputs if any
# input file and the key to use as id
if args.inputs is not None:
print("Computing fingerprints", flush=True)
assert len(args.inputs) % 2 == 0
for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
print(' document processing {} with key {}'.format(input_file, key),
flush=True)
# compute fingerprints in parallel
num_workers = 40
pool = multiprocessing.Pool(num_workers)
fin = open(input_file, 'r', encoding='utf-8')
compute_fingerprint_partial = partial(compute_fingerprint, key=key)
compute_fingerprint_iter = pool.imap(compute_fingerprint_partial,
fin, 512)
# traverse all the texts and add fingerprints
for url, text, fingerprint, flag in compute_fingerprint_iter:
counter += 1
if flag:
url_doc[url] = text
lshcache.add_fingerprint(fingerprint, url)
if counter % 10000 == 0:
print(' [read]> processed {} documents in {:.2f} '
'seconds ...'.format(counter, time.time() - \
start_time), flush=True)
fin.close()
pool.close()
pool.join()
# Save the fingerprints if needed
if args.save_fingerprints is not None:
print("Saving fingerprints to pickle file {}".format(
args.save_fingerprints), flush=True)
with open(args.save_fingerprints, 'wb') as f_save:
pickle.dump(lshcache, f_save)
pickle.dump(url_doc, f_save)
# compute jaccard index of the input texts and write to file if needed
if args.output is not None:
print("Compute jaccard similarity", flush=True)
if args.jaccard_parallel:
find_pair_urls_parallel(args, lshcache, url_doc)
else:
find_pair_urls_sequential(args, lshcache, url_doc)
print('done :-)')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment