Commit 8ec5d678 authored by hepj987

GPT-2 based on Megatron-DeepSpeed

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# This script rescales scalar values in TensorBoard log files.
# It does the modification in-place (so make backups!).
#
# Example:
#
# find . -name "*.tfevents*" -exec tb-rescale-scalars.py {} "iteration-time/samples per second" 1000 \;
#
# More than one tag can be rescaled – use ";" as a separator:
#
# tb-rescale-scalars.py events.out.tfevents.1 "training loss;validation loss" 1e-2
#
# By default, BigScience GPT throughput values will be fixed up according to
# https://github.com/bigscience-workshop/Megatron-DeepSpeed/issues/236,
# i.e. the rescaling fixes values wrongly logged as "seconds" when they are
# actually milliseconds.
#
# This script is derived from https://stackoverflow.com/a/60080531/9201239
# and https://gist.github.com/stas00/4cd1651d1c8f01196ea322c733bde46c.
import os
import sys
# Use this if you want to avoid using the GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
from tensorflow.core.util.event_pb2 import Event
def rescale_scalars(input_file, tags, rescale_factor):
new_file = input_file + '.new'
# Make a record writer
with tf.io.TFRecordWriter(new_file) as writer:
# Iterate event records
for rec in tf.data.TFRecordDataset([input_file]):
# Read event
ev = Event()
ev.MergeFromString(rec.numpy())
# Check if it is a summary
if ev.summary:
# Iterate summary values
for v in ev.summary.value:
# Check if the tag should be rescaled
if v.tag in tags:
v.simple_value *= rescale_factor
writer.write(ev.SerializeToString())
os.rename(new_file, input_file)
if __name__ == '__main__':
if len(sys.argv) < 2:
print(f'{sys.argv[0]} <input file> [<tags> [<rescale factor>]]',
file=sys.stderr)
sys.exit(1)
if len(sys.argv) < 3:
sys.argv.append(';'.join([
'iteration-time/samples per second',
'iteration-time/samples per second per replica',
'iteration-time/tokens per second',
'iteration-time/tokens per second per replica',
]))
if len(sys.argv) < 4:
sys.argv.append('1000')
input_file, tags, rescale_factor = sys.argv[1:]
tags = tags.split(';')
rescale_factor = float(rescale_factor)
rescale_scalars(input_file, tags, rescale_factor)
print('Done')
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Merge model parallel partitions."""
import os
import re
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
import torch
from megatron import mpu
from megatron.checkpointing import load_checkpoint, save_checkpoint
from megatron.checkpointing import ensure_directory_exists
from megatron.checkpointing import get_checkpoint_name
from megatron.checkpointing import get_checkpoint_version
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.global_vars import set_global_variables, get_args
from megatron.global_vars import rebuild_tokenizer
def split_into_partitions(tensor, num_partitions, partition_dim, stride):
per_partition_size = mpu.utils.divide(tensor.size(partition_dim),
num_partitions)
per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
partitions_list = torch.split(tensor,
per_partition_per_stride_size,
dim=partition_dim)
partitions = []
for i in range(num_partitions):
partition = torch.cat(partitions_list[i::num_partitions],
dim=partition_dim)
partitions.append(partition)
return partitions
def merge_partitions(merged, partitions, partition_dim, stride):
# Number and size of each partition.
num_partitions = len(partitions)
per_partition_size = None
for partition in partitions:
if per_partition_size is None:
per_partition_size = partition.size(partition_dim)
else:
assert per_partition_size == partition.size(partition_dim)
def concat_partitions(partitions_):
with torch.no_grad():
if (per_partition_size * num_partitions) == merged.size(
partition_dim):
torch.cat(partitions_, dim=partition_dim, out=merged)
else:
print(' ***WARNING*** sizes do not match. Will cut '
'the merged partitions by {} along dimension {} '
'to reduce the size from {} to {} ...'.format(
(per_partition_size * num_partitions) - \
merged.size(partition_dim), partition_dim,
per_partition_size * num_partitions,
merged.size(partition_dim)))
merged_ = torch.cat(partitions_, dim=partition_dim)
merged_split = torch.split(merged_, merged.size(partition_dim),
dim=partition_dim)
merged_ = merged_split[0]
assert merged_.size(partition_dim) == merged.size(partition_dim)
merged.data.copy_(merged_.data)
    # If stride is 1, then do a simple concatenation.
if stride == 1:
concat_partitions(partitions)
return
    # For non-unity strides, first split based on stride and then group.
per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
# Chunk and build a list.
chunks = None
for i, partition in enumerate(partitions):
chunk = torch.split(partition,
per_partition_per_stride_size,
dim=partition_dim)
if chunks is None:
chunks = [0]*(num_partitions*len(chunk))
chunks[i::num_partitions] = chunk
    # Concatenate.
concat_partitions(chunks)
return
def get_model(model_type):
if model_type == 'BERT':
from pretrain_bert import model_provider
elif model_type == 'GPT':
from pretrain_gpt import model_provider
elif model_type == 'RACE':
from tasks.race.finetune import model_provider
    elif model_type in ['MNLI', 'QQP']:
num_classes = 2
if model_type == 'MNLI':
num_classes = 3
from megatron.model.classification import Classification
def model_provider():
return Classification(num_classes=num_classes, num_tokentypes=2)
else:
raise Exception('unrecognized model type: {}'.format(model_type))
model = model_provider()
model = model.half()
return model
def get_parallel_checkpoint_name(path):
tracker_filename = get_checkpoint_tracker_filename(path)
iteration = 0
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
iteration = int(metastring)
assert iteration > 0
checkpoint_name = get_checkpoint_name(path, iteration)
return checkpoint_name, iteration
def test_split_merge():
print('testing split and merge ...')
#[QKV.ROW-COL]
tensor = torch.FloatTensor([[1.11, 1.12, 1.13, 1.14, 1.15],
[1.21, 1.22, 1.23, 1.24, 1.25],
[1.31, 1.32, 1.33, 1.34, 1.35],
[1.41, 1.42, 1.43, 1.44, 1.45],
[2.11, 2.12, 2.13, 2.14, 2.15],
[2.21, 2.22, 2.23, 2.24, 2.25],
[2.31, 2.32, 2.33, 2.34, 2.35],
[2.41, 2.42, 2.43, 2.44, 2.45],
[3.11, 3.12, 3.13, 3.14, 3.15],
[3.21, 3.22, 3.23, 3.24, 3.25],
[3.31, 3.32, 3.33, 3.34, 3.35],
[3.41, 3.42, 3.43, 3.44, 3.45]])
num_partitions = 2
partition_dim = 0
stride = 3
partitions = split_into_partitions(tensor, num_partitions,
partition_dim, stride)
merged = torch.zeros_like(tensor)
merge_partitions(merged, partitions, partition_dim, stride)
max_error = (merged - tensor).abs().max()
print(' > max error (should be zero): {}'.format(max_error))
def get_mp_merge_args(parser):
"""Provide extra arguments required for merging."""
group = parser.add_argument_group(title='mp merge')
group.add_argument('--model-type', type=str, required=True,
choices=['BERT', 'GPT', 'RACE', 'MNLI', 'QQP'],
                       help='Type of the model.')
group.add_argument('--target-pipeline-model-parallel-size', type=int, default=1,
help='Degree of pipeline model parallelism in output model.')
return parser
def main():
    # Arguments do sanity checks on the world size, but we don't care,
    # so trick them into thinking there are plenty of processes.
os.environ["WORLD_SIZE"] = f'{2**31}'
# Args
set_global_variables(extra_args_provider=get_mp_merge_args,
args_defaults = {'use_cpu_initialization': True,
'micro_batch_size': 1,
'no_load_optim': True,
'no_load_rng': True,
'no_save_optim': True,
'no_save_rng': True,
'save_interval': 1})
args = get_args()
if args.pipeline_model_parallel_size > 1:
print("Checkpoints with pipeline model parallelism are not currently supported.")
exit()
model_type = args.model_type
orig_tensor_model_parallel_size = args.tensor_model_parallel_size
args.tensor_model_parallel_size = 1
tokenizer = rebuild_tokenizer(args)
print('\n merging model parallel partitions ...')
print(' > number of partitions: {}'.format(orig_tensor_model_parallel_size))
print(' > checkpoint path: {}'.format(args.load))
print(' > model parameters:')
print(' number of tokens ................ {} '.format(
tokenizer.vocab_size))
print(' number of layers ................ {}'.format(args.num_layers))
print(' hidden size ..................... {}'.format(args.hidden_size))
print(' number of attention heads ....... {}'.format(
args.num_attention_heads))
print(' maximum position embeddings ..... {}'.format(
args.max_position_embeddings))
# Full model.
print('> building the full model ...')
mpu.initialize.set_tensor_model_parallel_world_size(1)
mpu.initialize.set_tensor_model_parallel_rank(0)
mpu.initialize.set_pipeline_model_parallel_world_size(1)
mpu.initialize.set_pipeline_model_parallel_rank(0)
merged_model = get_model(model_type)
# Build and load partitions.
partitions = []
iteration = 0
args.tensor_model_parallel_size = orig_tensor_model_parallel_size
tokenizer = rebuild_tokenizer(args)
mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
for rank in range(args.tensor_model_parallel_size):
# Reset these since load_checkpoint asserts they are 0, but we are loading
# multiple checkpoints in the same process and they get set each time
args.consumed_train_samples = 0
args.consumed_valid_samples = 0
mpu.initialize.set_tensor_model_parallel_rank(rank)
checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
model_ = get_model(model_type)
print(f'> loading {checkpoint_name} ...')
load_checkpoint(model_, None, None)
print(f'> checkpoint version {get_checkpoint_version()}')
partitions.append(model_)
    # Parameter generators so we can loop through them simultaneously.
merged_params_gen = merged_model.named_parameters()
partitions_params_gen = [partition.named_parameters()
for partition in partitions]
while True:
try:
# Get the params and check names.
name, merged_param = next(merged_params_gen)
print(' > working on {} ...'.format(name))
print(' merged type: {}, size: {}'.format(
merged_param.dtype, list(merged_param.size())))
partitions_param = []
for rank, partition_params_gen in enumerate(partitions_params_gen):
partition_name, partition_param = next(partition_params_gen)
assert partition_name == name
partitions_param.append(partition_param)
print(' partition {} type: {}, size: {}'.format(
rank, partition_param.dtype, list(partition_param.size())))
# For the non-parallel parameters, simply copy the rank 0 values.
if not hasattr(merged_param, 'tensor_model_parallel'):
                print(' non-parallel parameter, simple copy from rank 0')
with torch.no_grad():
merged_param.data.copy_(partitions_param[0].data)
# For parallel parameters, merge the values
else:
dim = merged_param.partition_dim
stride = merged_param.partition_stride
print(f' parallel parameter merge with stride {stride} along '
                      f'dimension {dim}')
merge_partitions(merged_param,
partitions_param,
dim,
stride)
except StopIteration:
break
partitions = []
args.tensor_model_parallel_size = 1
args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size
assert args.num_layers % args.pipeline_model_parallel_size == 0, \
'num_layers must be divisible by target pipeline model parallel size'
layers_per_part = args.num_layers // args.pipeline_model_parallel_size
tokenizer = rebuild_tokenizer(args)
mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
mpu.initialize.set_tensor_model_parallel_rank(0)
mpu.initialize.set_pipeline_model_parallel_world_size(args.pipeline_model_parallel_size)
# regex to parse out layer number from param name
    layer_re = re.compile(r'layers\.([0-9]+)')
if args.pipeline_model_parallel_size > 1:
merged_params = {}
for name, merged_param in merged_model.named_parameters():
merged_params[name] = merged_param
for rank in range(args.pipeline_model_parallel_size):
mpu.initialize.set_pipeline_model_parallel_rank(rank)
model = get_model(model_type)
def update_layer_num(m):
# TODO! This assumes no interleaved pipeline execution
layer = int(m.group(1))
layer += rank * layers_per_part
return f'layers.{layer}'
for dst_name, partition_param in model.named_parameters():
if dst_name == "word_embeddings.weight":
# See comment in MegatronModule.initialize_word_embeddings()
src_name = "language_model.embedding.word_embeddings.weight"
else:
# Translate destination layer number (0-N for each partition)
# to source layer number (single-model layer number)
src_name = re.sub(layer_re, update_layer_num, dst_name)
print(f" > copying {src_name} to {dst_name} in rank {rank}'s model")
partition_param.data.copy_(merged_params[src_name].data)
partitions.append(model)
else:
partitions = [merged_model]
for rank, model in enumerate(partitions):
mpu.initialize.set_pipeline_model_parallel_rank(rank)
print(f"> saving rank {rank}'s model")
save_checkpoint(iteration, model, None, None)
print('done :-)')
if __name__ == '__main__':
main()
"""Merge a list of indexed datasets into a single indexed dataset.
This script can run in two modes: a serial mode in which a single
process merges all datasets, and a distributed parallel mode in
which a set of processes in a torch.distributed environment
collectively merge datasets into a single file.
The serial mode is simpler to use.
Provided that the file system permits it, the parallel mode
can improve performance when merging many dataset files.
The distributed mode requires one to write the output dataset to
a POSIX-compliant file system that supports shared parallel
access to the file as different processes write to different
regions of the output file simultaneously.
To run in serial mode:
python tools/merge_preprocessed_data.py \
--datasets \
meg-gpt2-oscar-en-500-p1_text_document \
meg-gpt2-oscar-en-500-p2_text_document \
meg-gpt2-oscar-en-500-p3_text_document \
--output-prefix meg-gpt2_oscar_text_document
To run in distributed mode:
MASTER_ADDR="localhost"
MASTER_PORT=12345
python -m torch.distributed.launch \
--nproc_per_node 40 \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT \
tools/merge_preprocessed_data.py \
--merge distributed \
--datasets \
meg-gpt2-oscar-en-500-p1_text_document \
meg-gpt2-oscar-en-500-p2_text_document \
meg-gpt2-oscar-en-500-p3_text_document \
--output-prefix meg-gpt2_oscar_text_document
"""
import argparse
import time
from megatron import print_rank_0
from megatron.data import indexed_dataset
from megatron.data.indexed_dataset import infer_dataset_impl, MMapIndexedDataset, data_file_path, index_file_path, merge_files_dist
from megatron.data.distdata import DistData
def get_args():
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='input data')
group.add_argument('--datasets', nargs='+', default=None,
help='Paths to one or more input datasets to merge')
group = parser.add_argument_group(title='output data')
group.add_argument('--output-prefix', type=str, required=True,
help='Path to binary output file without suffix')
group = parser.add_argument_group(title='runtime')
group.add_argument('--merge', type=str, default='serial', choices=['serial', 'distributed'],
help='Whether to use a serial merge with a single process or a distributed parallel merge.')
group.add_argument('--torch-backend', type=str, default=None, choices=['gloo', 'mpi'],
help='Select torch.distributed backend.')
group.add_argument('--local_rank', type=int, default=None,
help='Local rank of calling process on its node (from torch.distributed.launch).')
args = parser.parse_args()
# initialize distributed environment if distributed merge requested
if args.merge == 'distributed':
if args.torch_backend is None:
print_rank_0("Distributed merge using --torch-backend=gloo as default")
args.torch_backend = 'gloo'
args.distctx = DistData(backend=args.torch_backend)
if args.merge == 'serial' and args.torch_backend is not None:
print_rank_0("Ignoring setting for --torch-backend since using a serial merge")
return args
def main():
"""
    Merge multiple datasets generated by the preprocess_data script into a single dataset.
"""
args = get_args()
startup_start = time.time()
print_rank_0(f"Merging {args.datasets}")
print_rank_0(f"Output prefix: {args.output_prefix}")
if args.merge == 'distributed':
if args.distctx.numranks > len(args.datasets):
print_rank_0(f"Using more ranks {args.distctx.numranks} than datasets {len(args.datasets)}")
merge_files_dist(args.output_prefix, args.datasets, args.distctx)
else:
# We use the first dataset to infer the dataset implementation common to all datasets.
dataset_impl = infer_dataset_impl(args.datasets[0])
assert dataset_impl is not None
        # Ensure that all datasets use the same implementation.
for ds in args.datasets:
ds_impl = infer_dataset_impl(ds)
assert ds_impl == dataset_impl, f"Dataset type '{ds_impl}' in file '{ds}' does not match type '{dataset_impl}' from file '{args.datasets[0]}'"
# We use the first dataset to infer the dtype common to all datasets.
first_dataset = indexed_dataset.make_dataset(args.datasets[0], dataset_impl)
dtype = first_dataset.dtype if isinstance(first_dataset, MMapIndexedDataset) else None
output_filename = args.output_prefix
output_bin_file = data_file_path(output_filename)
output_idx_file = index_file_path(output_filename)
builder = indexed_dataset.make_builder(output_bin_file,
impl=dataset_impl,
dtype=dtype)
for dataset in args.datasets:
builder.merge_file_(dataset)
builder.finalize(output_idx_file)
startup_end = time.time()
print_rank_0(f"Time to merge: {startup_end - startup_start}")
print_rank_0(f"Merged {len(args.datasets)} datasets to {args.output_prefix}")
if __name__ == "__main__":
main()
The following steps show how to prepare a training dataset to train the model.
# Libraries to install
```
pip install ftfy langdetect numpy torch pandas nltk sentencepiece boto3 tqdm regex bs4 newspaper3k htmlmin tldextract
git clone https://github.com/mattilyra/LSH
cd LSH
python setup.py install
```
# Download the dataset
1. Download the deduplicated URLs from [jcpeterson](https://mega.nz/#F!EZZD0YwJ!9_PlEQzdMVLaNdKv_ICNVQ!cc4RgQQZ)
2. Remove blacklisted URLs.
```
python blacklist_urls.py <path to the downloaded deduplicated URLs> <filename for clean urls, e.g. clean_urls.txt>
```
3. Download the content from the clean urls with [openwebtext's utilities](https://github.com/eukaryote31/openwebtext/blob/master/download.py).
4. Merge the contents into one loose JSON file with one JSON object per line, in the format `{'text': text, 'url': unique_url}`. It is important for the URL to be unique; a minimal sketch of this merge step is shown below.
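A minimal sketch of this merge step, assuming the downloader from step 3 left one `.txt` file per document plus a tab-separated `urls.tsv` file mapping document ids to URLs; the layout and the helper name `merge_to_loose_json` are illustrative assumptions, not part of this repository.
```
import glob
import json
import os

def merge_to_loose_json(scraped_dir, url_map_file, output_file):
    # url_map_file is assumed to contain one "<doc_id>\t<url>" pair per line.
    id_to_url = {}
    with open(url_map_file, 'r', encoding='utf-8') as f:
        for line in f:
            doc_id, url = line.rstrip('\n').split('\t', 1)
            id_to_url[doc_id] = url

    seen_urls = set()
    with open(output_file, 'w', encoding='utf-8') as out:
        for path in sorted(glob.glob(os.path.join(scraped_dir, '*.txt'))):
            doc_id = os.path.splitext(os.path.basename(path))[0]
            url = id_to_url.get(doc_id)
            # The url must be unique, so skip documents whose url repeats.
            if url is None or url in seen_urls:
                continue
            seen_urls.add(url)
            with open(path, 'r', encoding='utf-8', errors='ignore') as fin:
                text = fin.read()
            out.write(json.dumps({'text': text, 'url': url},
                                 ensure_ascii=False) + '\n')

if __name__ == '__main__':
    merge_to_loose_json('scraped', 'urls.tsv', 'openwebtext_merged.json')
```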
# Prepare the data for GPT-2 training:
1. Run ftfy and English detection, and remove documents with fewer than 128 tokens. This step can be sharded and run on shards.
```
python cleanup_dataset.py <input data file> <output cleaned data filename>
```
Additional cleanup (e.g. removing documents with fewer than 512 characters, or dataset-specific cleaning for the stories and realnews datasets) can be done using `cleanup_fix_dataset.py`. More details can be found by running `python cleanup_fix_dataset.py --help`.
2. Using LSH, find possible duplicates and store them in a file for later processing. The code supports saving and loading fingerprints for repeated deduplication runs, and is also multithreaded for faster processing. More details can be found by running `python find_duplicates.py --help`.
```
python find_duplicates.py --inputs <pairwise list of input cleaned data files and keys, e.g. cc.json cc_id news.json news_id> --output <output possible duplicate urls filename>
```
3. Based on the similarity measure defined inside the function `is_similar` (default threshold: 0.9), group URLs that are similar. For each group, we keep only one URL and remove the rest (see the sketch after this list).
```
python group_duplicate_urls.py <possible duplicate urls file> <output file containing similar urls>
```
4. Remove similar documents that were detected in the last step.
```
python remove_group_duplicates.py <file containing similar documents> <cleaned data file> <output file containing deduplicated data>
```
5. Shuffle the dataset.
```
shuf <cleaned deduped data file> -o train_data.json
```
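As referenced in step 3, grouping similar URLs can be sketched with a small union-find pass over the candidate pairs. The 0.9 similarity threshold comes from the description above; the input format (one tab-separated `url1  url2  similarity` line per candidate pair) and the function names are assumptions, since the actual logic lives in `group_duplicate_urls.py`.
```
import sys

def find(parent, x):
    # Find the group representative with path halving.
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x

def group_similar(pairs, threshold=0.9):
    # Union URLs whose similarity is at or above the threshold.
    parent = {}
    for url1, url2, score in pairs:
        parent.setdefault(url1, url1)
        parent.setdefault(url2, url2)
        if score >= threshold:
            parent[find(parent, url1)] = find(parent, url2)
    # Collect groups and mark everything except one representative for removal.
    groups = {}
    for url in parent:
        groups.setdefault(find(parent, url), set()).add(url)
    to_remove = set()
    for representative, members in groups.items():
        to_remove.update(members - {representative})
    return to_remove

if __name__ == '__main__':
    # Expects lines of "url1<TAB>url2<TAB>similarity".
    pairs = []
    with open(sys.argv[1], 'r', encoding='utf-8') as f:
        for line in f:
            url1, url2, score = line.rstrip('\n').split('\t')
            pairs.append((url1, url2, float(score)))
    for url in sorted(group_similar(pairs)):
        print(url)
```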
# Deduplicating ngrams
To deduplicate the downstream tasks (e.g. lambada, squad) from the training dataset, we run the following command.
```
python filter_ngrams.py --tasks <name of the task, e.g. lambada, squad> --dedup-dataset <training dataset to deduplicate> <json key> --output <output training dataset>
```
We use 13-grams by default for the deduplication. When we find a 13-gram match in a training document, we split the document into two pieces and remove the 13-gram along with 200 characters from both sides of the 13-gram. We also remove any split piece with fewer than 200 characters, and drop a document entirely if it gets split more than 10 times. These parameters can be changed using the corresponding arguments.
For the lambada task only, we also need to provide the path to its test data with `--lambada-path <path of the lambada test data>`.
Several other features (e.g. saving and loading the dictionary) have been added; see `python filter_ngrams.py --help` for details.
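The splitting behaviour described above can be sketched as follows. The 13-gram size, the 200-character windows, and the limits on split size and split count follow the description; the helper name and the assumption that matches arrive as `(start, length)` character spans are illustrative only, since the real implementation lives in `filter_ngrams.py`.
```
MAX_NGRAM_SIZE = 13           # n-gram size used for matching (matching not shown here)
REMOVE_CHARS_EACH_SIDE = 200  # characters dropped on both sides of a match
MIN_SPLIT_CHARS = 200         # split pieces shorter than this are discarded
MAX_SPLITS = 10               # documents split more often than this are dropped

def dedup_document(text, matches):
    """matches: list of (start, length) character spans where a 13-gram from a
    downstream task (e.g. lambada, squad) was found in `text`. Returns the
    surviving pieces, or an empty list if the whole document is dropped."""
    if len(matches) > MAX_SPLITS:
        return []
    # Cut out each match plus 200 characters on both sides.
    cuts = [(max(0, start - REMOVE_CHARS_EACH_SIDE),
             min(len(text), start + length + REMOVE_CHARS_EACH_SIDE))
            for start, length in sorted(matches)]
    pieces, pos = [], 0
    for cut_start, cut_end in cuts:
        piece = text[pos:cut_start]
        if len(piece) >= MIN_SPLIT_CHARS:
            pieces.append(piece)
        pos = max(pos, cut_end)
    tail = text[pos:]
    if len(tail) >= MIN_SPLIT_CHARS:
        pieces.append(tail)
    return pieces

if __name__ == '__main__':
    doc = 'A' * 1000 + ' the quick brown fox ... ' + 'B' * 1000
    # Pretend one 13-gram match was found covering the middle span.
    print([len(piece) for piece in dedup_document(doc, [(1000, 25)])])
```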
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import time
"""
This script adds an id to each JSON object in a JSON-lines file. The user can
add a prefix to the ids.
"""
if __name__ == '__main__':
print('parsing the arguments ...')
parser = argparse.ArgumentParser()
parser.add_argument('--input-file', type=str, default=None, help='Input'\
' json file where id needs to be added')
parser.add_argument('--output-file', type=str, default=None, help=\
'Output file name with id')
parser.add_argument('--id-prefix', type=str, default=None, help=\
'Id prefix')
parser.add_argument('--log-interval', type=int, default=100,
help='Log interval')
args = parser.parse_args()
print('Adding ids to dataset ...')
f_input = open(args.input_file, 'r', encoding='utf-8')
f_output = open(args.output_file, 'wb')
unique_ids = 1
start_time = time.time()
for row in f_input:
each_row = json.loads(row)
adlr_id_string = args.id_prefix + '-{:010d}'.format(int(unique_ids))
each_row['adlr_id'] = adlr_id_string
myjson = json.dumps(each_row, ensure_ascii=False)
f_output.write(myjson.encode('utf-8'))
f_output.write('\n'.encode('utf-8'))
if unique_ids % args.log_interval == 0:
print(' processed {:9d} documents in {:.2f} seconds ...'.format( \
unique_ids, time.time() - start_time), flush=True)
unique_ids += 1
# Close the file.
f_input.close()
f_output.close()
print('done :-)', flush=True)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import re
import time
import tldextract
import sys
# List of the domains to blacklist.
domain_blacklist = set([
'500px',
'aapks',
'akamaihd',
'amazon',
'apple',
'artifactfire',
'artstation',
'awwni',
'bandcamp',
'battleforthenet',
'coinscalendar',
'dailymotion',
'deviantart',
'discord',
'discordapp',
'dlapkandroid',
'dropbox',
'e621',
'ebay',
'edealinfo',
'erome',
'eroshare',
'explosm',
'facebook',
'fbcdn',
'flickr',
'furaffinity',
'futhead',
'gatopardo',
'gfycat',
'gifsound',
'gifsoup',
'giphy',
'github',
'google',
'gunprime',
'gyazo',
'hotdealstar',
'imagefap',
'imageshack',
'imgflip',
'imgur',
'instagram',
'karmadecay',
'kryptocal',
'kym-cdn',
'liveleak',
'livememe',
'lmgtfy',
'magaimg',
'memegenerator',
'minorplanetcenter',
'minus',
'mobafire',
'morejpeg',
'nocookie',
'pcpartpicker',
'photobucket',
'pinimg',
'pinterest',
'pixiv',
'pornhub',
'prntscr',
'puu',
'qkme',
'quickmeme',
'radd',
'redd',
'reddit',
'reddit-stream',
'redditlog',
'redditmedia',
'reddituploads',
'redtube',
'reupp',
'reverb',
'roanoke',
'rollingstone',
'sli',
'soundcloud',
'soundgasm',
'spankbang',
'spotify',
'strawpoll',
'streamable',
'timeanddate',
'tinypic',
'touhouradio',
'tumblr',
'twimg',
'twitch',
'twitter',
'vid',
'vimeo',
'vine',
'vkaao',
'vocaroo',
'voyagefusion',
'walmart',
'wciu',
'wikimedia',
'wikipedia',
'xhamster',
'xkcd',
'xvideos',
'youtu',
'youtube',
'youtubedoubler',
'ytimg',
'zillexplorer',
])
def domain_is_in_blacklist(url):
domain = tldextract.extract(url).domain
return domain in domain_blacklist
# List of extentions to blacklist.
extentions_blacklist = (
'.3gp',
    '.7z',
    '.ai',
'.aif',
'.apk',
'.app',
'.avi',
'.bin',
'.bmp',
'.bz2',
'.css',
'.csv',
'.dat',
'.deb',
'.dmg',
'.doc',
'.docx',
'.exe',
'.gif',
'.gifv',
'.gz',
'.iso',
'.jar',
'.jpeg',
'.jpg',
'.js',
'.log',
'.mid',
'.midi',
'.mkv',
'.mov',
'.mp3',
'.mp4',
'.mpeg',
'.mpg',
'.ogg',
'.ogv',
'.otf',
'.pdf',
'.pkg',
'.png',
'.pps',
'.ppt',
'.pptx',
'.psd',
'.py',
'.qt',
'.ram',
'.rar',
'.sql',
'.svg',
'.swf',
'.tar.gz',
'.tar',
'.tgz',
'.tiff',
'.ttf',
'.txt',
'.wav',
'.webm',
'.wma',
'.wmv',
'.xls',
'.xlsx',
'.xml',
'.xz',
'.zip',
)
def extention_is_in_blacklist(url):
if url.split('?')[0].lower().endswith(extentions_blacklist):
return True
return False
# Malformed urls.
# This function is adapted from:
# https://stackoverflow.com/questions/7160737/python-how-to-validate-a-url-in-python-malformed-or-not
url_regex = re.compile(
r'^(?:http)s?://' # http:// or https://
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
r'(?::\d+)?' # optional port
r'(?:/?|[/?]\S+)$', re.IGNORECASE)
def url_is_malformed(url):
return re.match(url_regex, url) is None
def print_progress(prefix, start_time, urls_counter,
domain_blacklist_counter,
extention_blacklist_counter,
short_url_counter, malformed_url_counter,
duplicate_url_counter):
string = prefix + ' | '
string += 'time elapsed (s): {:.2f} | '.format(time.time() - start_time)
string += 'number of urls: {} | '.format(urls_counter)
string += 'domain blacklisted: {} | '.format(domain_blacklist_counter)
string += 'extention blacklisted: {} | '.format(extention_blacklist_counter)
string += 'short urls (<=8): {} | '.format(short_url_counter)
string += 'malformed urls: {} | '.format(malformed_url_counter)
string += 'duplicate urls: {}'.format(duplicate_url_counter)
print(string, flush=True)
if __name__ == '__main__':
print('remove blacklisted urls ..')
# Path to the url files.
path = sys.argv[1]
# Output url file.
output = sys.argv[2]
# Get the list of url files.
files = glob.glob(path + '/*.txt')
print('> found {} files'.format(len(files)))
urls = set()
urls_counter = 0
domain_blacklist_counter = 0
extention_blacklist_counter = 0
short_url_counter = 0
malformed_url_counter = 0
duplicate_url_counter = 0
start_time = time.time()
for filename in files:
with open(filename, 'r') as f:
for line in f:
url = line.strip()
urls_counter += 1
if domain_is_in_blacklist(url):
print('[DOMAIN BLACKLIST]: {}'.format(url), flush=True)
domain_blacklist_counter += 1
elif extention_is_in_blacklist(url):
print('[EXTENTION BLACKLIST]: {}'.format(url), flush=True)
extention_blacklist_counter += 1
elif len(url) <= 8:
print('[SHORT URL]: {}'.format(url), flush=True)
short_url_counter += 1
elif url_is_malformed(url):
print('[MALFORMED URL]: {}'.format(url), flush=True)
malformed_url_counter += 1
elif url in urls:
print('[DUPLICATE URL]: {}'.format(url), flush=True)
duplicate_url_counter += 1
else:
urls.add(url)
if urls_counter % 100000 == 0:
print_progress('PROGRESS', start_time, urls_counter,
domain_blacklist_counter,
extention_blacklist_counter,
short_url_counter, malformed_url_counter,
duplicate_url_counter)
print_progress('FINAL', start_time, urls_counter,
domain_blacklist_counter,
extention_blacklist_counter,
short_url_counter, malformed_url_counter,
duplicate_url_counter)
# Write the final set of urls.
print('> writing cleaned up url list to {}'.format(output))
with open(output, 'w') as f:
for url in urls:
f.write(url + '\n')
print('done :-)')
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ftfy
import json
from langdetect import detect
import numpy as np
import time
import os
import sys
from tokenizer import Tokenizer
MIN_DOCUMENT_LENGHT = 128
def print_progress(prefix, start_time, num_docs, num_fixed_text,
num_non_english_docs, chars_non_english_docs,
num_small_docs, chars_small_docs):
string = prefix + ' | '
string += 'elapsed time: {:.2f} | '.format(time.time() - start_time)
string += 'documents: {} | '.format(num_docs)
string += 'fixed text: {} | '.format(num_fixed_text)
string += 'non-english: {} | '.format(num_non_english_docs)
string += 'non-english chars: {} | '.format(chars_non_english_docs)
string += 'small docs: {} | '.format(num_small_docs)
string += 'small docs chars: {}'.format(chars_small_docs)
print(string, flush=True)
def filter_corpus(filename, out_filename, print_interval=10000):
print(' > filtering {}'.format(filename))
tokenizer = Tokenizer(cache_dir='./cache')
num_docs = 0
num_written_docs = 0
num_small_docs = 0
num_fixed_text = 0
num_non_english_docs = 0
chars_non_english_docs = 0
chars_small_docs = 0
start_time = time.time()
with open(out_filename, 'wb') as f:
with open(filename, 'r') as fin:
for line in fin:
try:
num_docs += 1
myjson = json.loads(line)
# Fix text
text = ftfy.fix_text(myjson['text'])
if text != myjson['text']:
num_fixed_text += 1
myjson['text'] = text
# Detect language.
if detect(text) != 'en':
print('[non-english text]', myjson)
num_non_english_docs += 1
chars_non_english_docs += len(text)
continue
# On average each token is 5 characters so 8 is an
# upper bound.
if len(text) < (8 * MIN_DOCUMENT_LENGHT):
tokens = tokenizer.tokenize_document(text)
if len(tokens) < MIN_DOCUMENT_LENGHT:
print('[small document, skipping]:', myjson)
num_small_docs += 1
chars_small_docs += len(text)
continue
myjson = json.dumps(myjson, ensure_ascii=False)
f.write(myjson.encode('utf-8'))
f.write('\n'.encode('utf-8'))
num_written_docs += 1
if num_docs % print_interval == 0:
print_progress('[PROGRESS]', start_time, num_docs,
num_fixed_text, num_non_english_docs,
chars_non_english_docs,
num_small_docs, chars_small_docs)
except Exception as e:
print(' skipping ', line, e)
print_progress('[FINAL]', start_time, num_docs,
num_fixed_text, num_non_english_docs,
chars_non_english_docs,
num_small_docs, chars_small_docs)
if __name__ == '__main__':
print('building gpt2 dataset ...')
input_filename = sys.argv[1]
output_filename = sys.argv[2]
print('will be reading {}'.format(input_filename))
print('and will write the results to {}'.format(output_filename))
filter_corpus(input_filename, output_filename)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Filter and clean documents:
Capable of removing docs with fewer than 512 characters, docs with fewer than
256 characters that contain javascript, fixing text with ftfy, and dataset-specific
cleaning for the stories and realnews datasets.
The program arguments have the details.
"""
import argparse
from functools import partial
import glob
import ftfy
import json
from langdetect import detect
import multiprocessing
import os
from pathlib import Path
import re
import time
def process_doc(json_line, args):
# Read the line.
document = json.loads(json_line)
text = document['text']
output = {'remove_512': False, 'remove_256_javascript': False, \
'remove_512_non_english': False, 'ftfy_fix_text': False, \
'general_cleaning': False}
try:
        # Remove all docs with fewer than 512 characters
if "remove_512" in args.tasks:
if len(text) < 512:
output['remove_512'] = True
return output, text, document, True
        # Remove docs with fewer than 256 characters that contain javascript
if "remove_256_javascript" in args.tasks:
if len(text) < 256 and 'javascript' in text.lower():
output['remove_256_javascript'] = True
return output, text, document, True
        # Remove docs with fewer than 512 characters that are non-English
if "remove_512_non_english" in args.tasks:
if len(text) < 512 and detect(text) != 'en':
output['remove_512_non_english'] = True
return output, text, document, True
# Fix the text using ftfy, don't remove the text, hence return False
if "ftfy_fix_text" in args.tasks:
fixed_text = ftfy.fix_text(text)
output['ftfy_fix_text'] = True
return output, fixed_text, document, False
# Cleaning extra spaces and newlines
if "general_cleaning" in args.tasks:
cleaned_text = re.sub(r" +|\b\n+ |\b\n+", " ", text)
#cleaned_text = re.sub(r"\n\n+", "\n\n", text) # used this for Gutenberg dataset
#cleaned_text = re.sub(r"\n", "\n\n", text) # Used this for realnews
# stories datasets
#cleaned_text = re.sub(r" \'", "'", text)
#cleaned_text = re.sub(r" \!", "!", cleaned_text)
#cleaned_text = re.sub(r" \.", ".", cleaned_text)
#cleaned_text = re.sub(r" \?", "?", cleaned_text)
#cleaned_text = re.sub(r" - ", "-", cleaned_text)
##cleaned_text = re.sub(r"\" ", "\"", cleaned_text)
#cleaned_text = re.sub(r" @ ", "@", cleaned_text)
output['general_cleaning'] = True
return output, cleaned_text, document, False
except Exception as e:
print('Error: *************************\n{}\ntext: {}'.format(e, \
text), flush=True)
return output, text, document, True
# don't remove
return output, text, document, False
def process_set(args, input_file, output_f_cleaned, output_f_filtered):
print(' > working on {} ...'.format(input_file), flush=True)
num_docs = num_remove_512 = num_remove_java = num_remove_512_non_english \
= num_ftfy_fix_text = num_general_cleaning = 0
# Output file and counters.
output_cleaned = open(output_f_cleaned, 'wb')
output_filtered = open(output_f_filtered, 'wb')
start_time = time.time()
# Setup multi-processing.
num_workers = 40
fin = open(input_file, 'r', encoding='utf-8')
pool = multiprocessing.Pool(num_workers)
process_doc_partial = partial(process_doc, args=args)
processed_docs = pool.imap(process_doc_partial, fin, 500)
# Process documents.
for output, text, document, to_filter in processed_docs:
num_docs += 1
num_remove_512 += 1 if output['remove_512'] else 0
num_remove_java += 1 if output['remove_256_javascript'] else 0
num_remove_512_non_english += 1 if output['remove_512_non_english'] \
else 0
num_ftfy_fix_text += 1 if output['ftfy_fix_text'] else 0
num_general_cleaning += 1 if output['general_cleaning'] else 0
document['text'] = text
myjson = json.dumps(document, ensure_ascii=False)
if to_filter:
output_filtered.write(myjson.encode('utf-8'))
output_filtered.write('\n'.encode('utf-8'))
else:
output_cleaned.write(myjson.encode('utf-8'))
output_cleaned.write('\n'.encode('utf-8'))
if num_docs % args.log_interval == 0:
print(' processed {:9d} documents in {:.2f} seconds ...'.format(
num_docs, time.time() - start_time), flush=True)
# Close the file.
output_cleaned.close()
output_filtered.close()
fin.close()
# Print stats.
print(' >> total docs: {} remove_512 {} remove_256_javascript {} '\
'remove_512_non_english {} ftfy_fix_text {} general_cleaning {}'.\
format(num_docs, num_remove_512, num_remove_java,\
num_remove_512_non_english, num_ftfy_fix_text, \
num_general_cleaning), flush=True)
if __name__ == '__main__':
print('parsing the arguments ...')
parser = argparse.ArgumentParser()
parser.add_argument('--input-files', nargs = '*', required=True, default=\
                        None, help = 'Input json files that need to be'\
' cleaned')
parser.add_argument('--tasks', nargs = '*', required=True, default=None,\
help = 'Tasks to perform on the input files, ' \
'such as remove_512, remove_256_javascript, ' \
'remove_512_non_english, ftfy_fix_text, and ' \
'general_cleaning. 256 or 512 means the number' \
' of characters.')
parser.add_argument('--output-path', type=str, default=None,
help='Directory where the output should go')
parser.add_argument('--log-interval', type=int, default=100,
help='Log interval')
args = parser.parse_args()
print('cleanup dataset ...')
for input_file in args.input_files:
input_filename, input_filename_ext = os.path.splitext(Path(input_file)\
.name)
output_f_cleaned = os.path.join(args.output_path, input_filename + \
"_cleaned" + input_filename_ext)
output_f_filtered = os.path.join(args.output_path, input_filename + \
"_filtered" + input_filename_ext)
process_set(args, input_file, output_f_cleaned, output_f_filtered)
print('done :-)', flush=True)