Commit 0024a5c6 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'main' of https://github.com/NVIDIA/Megatron-LM

parents b004456b 3db2063b
Pipeline #229 failed with stages
in 0 seconds
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from .embed import BertEmbedder, DiskDataParallelBertEmbedder
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import numpy as np
import torch
from megatron import get_args, get_tokenizer
from megatron.data.bert_dataset import build_training_sample
class BertEmbeddingDataset(torch.utils.data.Dataset):
    '''Wrap a text dataset so each item becomes a Bert sample.

    Every item of 'text_dataset' must be a mapping with a "text" entry. The
    text is tokenized with the global tokenizer, truncated so that cls+sep
    still fit in 'max_seq_length', and handed to 'build_training_sample()'.
    '''

    def __init__(self, text_dataset, max_seq_length):

        super().__init__()

        megatron_args = get_args()
        tokenizer = get_tokenizer()

        # Source data & tokenizer.
        self.text_dataset = text_dataset
        self.bert_tokenizer = tokenizer

        # Values read from the global args / caller.
        self.max_seq_length = max_seq_length
        self.seed = megatron_args.seed
        self.masked_lm_prob = megatron_args.mask_prob

        # Vocab lookups and special token ids.
        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_dict = tokenizer.inv_vocab
        self.cls_id = tokenizer.cls
        self.sep_id = tokenizer.sep
        self.mask_id = tokenizer.mask
        self.pad_id = tokenizer.pad

    def __len__(self):
        '''One Bert sample per text sample.'''
        return len(self.text_dataset)

    def __getitem__(self, idx):
        '''Tokenize text sample 'idx' and build its Bert sample dict.'''

        # Raw text, with the end-of-text marker stripped.
        raw_text = self.text_dataset[idx]["text"].replace("<|endoftext|>", "")

        # Wordpiece token ids, truncated to leave two slots for cls+sep.
        token_budget = self.max_seq_length - 2
        token_ids = self.bert_tokenizer.tokenize(raw_text)[:token_budget]
        if not token_ids:
            # Hack for empty sequences: substitute a single pad token.
            token_ids = [ self.bert_tokenizer.pad_id ]
        n_tokens = len(token_ids)

        # The rng must be numpy's rather than python's, since python's
        # randint is inclusive whereas numpy's is exclusive. The seed is
        # taken % 2**32 because numpy requires it to be in [0, 2**32 - 1].
        rng_seed = (self.seed + idx) % 2**32
        np_rng = np.random.RandomState(seed=rng_seed)

        # Build the sample; the target length adds two for cls+sep.
        sample = build_training_sample(
            [token_ids],
            n_tokens,
            n_tokens + 2,
            self.vocab_id_list,
            self.vocab_id_to_token_dict,
            self.cls_id, self.sep_id,
            self.mask_id, self.pad_id,
            self.masked_lm_prob, np_rng,
            binary_head=False)
        sample["seq_length"] = len(sample["text"])

        return sample
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from functools import partial
import numpy as np
import os
import time
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, Subset
from torch.utils.data._utils.collate import default_collate
from tqdm import tqdm
from megatron import get_args, get_tokenizer, print_rank_0
from megatron import core
from megatron.core.enums import ModelType
from megatron.model import BertModel
from megatron.schedules import get_forward_backward_func
from megatron.training import setup_model_and_optimizer
from .dataset import BertEmbeddingDataset
from .external_libs import h5py
from .huggingface import HuggingfaceEmbedder
from .utils import get_missing_blocks_by_rank
def model_provider(pre_process=True, post_process=True):
    """Construct the Bert model used for embedding.

    The binary head (and the two token types that go with it) is enabled
    only when 'args.bert_binary_head' is set.
    """

    print_rank_0(" > build Bert model.")

    use_binary_head = get_args().bert_binary_head

    return BertModel(
        num_tokentypes=2 if use_binary_head else 0,
        add_binary_head=use_binary_head,
        parallel_output=True,
        pre_process=pre_process,
        post_process=post_process)
def get_batch(data_iterator):
    """Fetch the next batch, broadcast it, and unpack it into tensors.

    Returns the tuple (tokens, types, sentence_order, loss_mask, lm_labels,
    padding_mask, seq_lengths). When 'data_iterator' is None, no data is
    read locally and None is passed to the broadcast.
    """

    # Items and their type.
    keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask',
            'seq_length']
    datatype = torch.int64

    # Broadcast data.
    data = next(data_iterator) if data_iterator is not None else None
    data_b = core.tensor_parallel.broadcast_data(keys, data, datatype)

    # Unpack; only the loss mask is converted to float, the rest to long.
    return (
        data_b['text'].long(),
        data_b['types'].long(),
        data_b['is_random'].long(),
        data_b['loss_mask'].float(),
        data_b['labels'].long(),
        data_b['padding_mask'].long(),
        data_b['seq_length'].long(),
    )
def loss_func(loss_mask, sentence_order, seq_lengths,
              output_tensor, non_loss_data):
    """Pass-through 'loss' function used when embedding.

    No loss is computed: the sequence lengths (for progress print-outs) and
    the model output are returned as-is. 'non_loss_data' must be truthy;
    'loss_mask' and 'sentence_order' are accepted but not used.
    """
    assert non_loss_data
    result = (seq_lengths, output_tensor)
    return result
def forward_step(data_iterator, model):
    """Run one forward pass of 'model' on the next batch.

    Returns the model output together with 'loss_func' partially applied to
    the batch's loss mask, sentence order and sequence lengths.
    """

    args = get_args()

    # Get the batch.
    (tokens, types, sentence_order, loss_mask,
     lm_labels, padding_mask, seq_lengths) = get_batch(data_iterator)

    # Token types are only passed when the binary head is in use.
    tokentype_ids = types if args.bert_binary_head else None

    # Forward pass through the model.
    output_tensor = model(tokens, padding_mask,
                          tokentype_ids=tokentype_ids,
                          lm_labels=lm_labels)

    bound_loss_func = partial(loss_func, loss_mask, sentence_order,
                              seq_lengths)
    return output_tensor, bound_loss_func
def collate_batch(samples):
    """Collate samples whose arrays have differing lengths.

    Numpy-array values are right-padded to the longest array seen for their
    key: 'text' arrays with the tokenizer's pad_id, all other arrays with 0.
    Non-array values are passed through untouched. The padded samples are
    then merged with torch's 'default_collate'.
    """

    keys = list(samples[0].keys())
    tokenizer = get_tokenizer()

    # Longest array per key (None for keys holding non-array values).
    longest = dict.fromkeys(keys, 0)
    for sample in samples:
        for key in keys:
            value = sample[key]
            if isinstance(value, np.ndarray):
                longest[key] = max(longest[key], len(value))
            else:
                longest[key] = None

    def _pad(key, value):
        # Right-pad arrays up to the longest length; leave the rest alone.
        if not isinstance(value, np.ndarray):
            return value
        return np.pad(
            value,
            (0, longest[key] - len(value)),
            mode="constant",
            constant_values=tokenizer.pad_id if key == "text" else 0,
        )

    padded_samples = [
        {key: _pad(key, sample[key]) for key in keys}
        for sample in samples
    ]

    # Build batch with padded samples.
    return default_collate(padded_samples)
def get_data_loader(dataset, batch_size):
    """Build a sequential, batched data loader over 'dataset'.

    Samples are visited in order, the final partial batch is kept, and
    batches are assembled with 'collate_batch'. The worker count comes from
    'args.num_workers'.
    """

    num_workers = get_args().num_workers

    # Sequential sampling, grouped into batches (last batch may be short).
    sequential_batches = BatchSampler(
        sampler=SequentialSampler(dataset),
        batch_size=batch_size,
        drop_last=False,
    )

    return DataLoader(
        dataset,
        batch_sampler=sequential_batches,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=collate_batch,
    )
def embed_data_loader(models, data_loader):
    '''Run every batch of 'data_loader' through models[0] and stack outputs.

    All models are switched to eval mode first. Each batch's output is moved
    to the cpu as a numpy array, and the per-batch arrays are concatenated
    along axis 0.
    '''

    # forward_step is called directly, so model parallelism is not allowed.
    args = get_args()
    assert args.tensor_model_parallel_size == 1 and \
        args.pipeline_model_parallel_size == 1, \
        "since we call forward_step directly, only tp == pp == 1 allowed."

    batch_iterator = iter(data_loader)

    # Eval mode.
    for model in models:
        model.eval()

    # Embed, one forward step per batch, without tracking gradients.
    per_batch_embeddings = []
    with torch.no_grad():
        for _ in tqdm(range(len(data_loader)), "mt embed"):
            step_result = forward_step(batch_iterator, models[0])
            per_batch_embeddings.append(
                step_result[0].detach().cpu().numpy())

    return np.concatenate(per_batch_embeddings, axis=0)
class BertEmbedder:
    '''Embed text datasets with Bert (Megatron or Huggingface backend).'''

    def __init__(self, batch_size, max_bert_seq_length, embedder_type):
        '''Set up the Megatron model and, if requested, the Huggingface one.

        'embedder_type' must be "megatron" or "huggingface"; anything else
        raises an Exception.
        '''

        args = get_args()
        assert args.output_bert_embeddings

        # Only the model list is kept; optimizer & scheduler are unused.
        self.models, _optimizer, _opt_param_scheduler = \
            setup_model_and_optimizer(model_provider,
                                      ModelType.encoder_or_decoder)
        self.batch_size = batch_size
        self.max_bert_seq_length = max_bert_seq_length

        # Init Huggingface, if in use.
        if embedder_type not in ("megatron", "huggingface"):
            raise Exception("specialize for embedder type '%s'." % embedder_type)
        self.huggingface_embedder = (
            HuggingfaceEmbedder(batch_size, max_bert_seq_length)
            if embedder_type == "huggingface" else None
        )

    def embed_text_dataset(self, text_dataset):
        '''Embed a text dataset.'''

        # Delegate to Huggingface when that backend was selected.
        if self.huggingface_embedder:
            return self.huggingface_embedder.embed_text_dataset(text_dataset)

        # Otherwise tokenize via BertEmbeddingDataset and run Megatron Bert.
        bert_dataset = BertEmbeddingDataset(text_dataset,
                                            self.max_bert_seq_length)
        data_loader = get_data_loader(bert_dataset, self.batch_size)
        return embed_data_loader(self.models, data_loader)

    def embed_text(self, text):
        '''Embed a single text string.

        Convenience wrapper for on-the-fly embeddings (analysis, debugging).
        For large scale, use 'embed_text_dataset()'.
        '''

        class SingleTextDataset(torch.utils.data.Dataset):
            '''Dataset of length one, holding a single string.'''
            def __init__(self, text):
                assert isinstance(text, str)
                self.text = text
            def __len__(self):
                return 1
            def __getitem__(self, i):
                return {"text": self.text}

        # Embed the one-sample dataset and return its only row.
        embeddings = self.embed_text_dataset(SingleTextDataset(text))
        return embeddings[0]
class DiskDataParallelBertEmbedder:
    '''Process embeddings in blocks & save to disk.'''

    def __init__(self, batch_size, max_bert_seq_length, block_size,
                 embedder_type):
        '''Create the underlying BertEmbedder and remember the block size.'''
        self.embedder = BertEmbedder(batch_size, max_bert_seq_length,
                                     embedder_type)
        self.block_size = block_size

    def embed_text_blocks(self, name, workdir, text_dataset,
                          missing_embedding_blocks):
        '''Embed each missing block and write it to its hdf5 file.

        'missing_embedding_blocks' is this rank's list of block dicts (with
        "range" and "path" entries), possibly padded with None. A barrier is
        executed once per list entry, including the None entries.
        '''

        n_blocks = len(missing_embedding_blocks)

        # Iterate blocks.
        for block_index, block_info in enumerate(missing_embedding_blocks):

            # Missing block lists are extended with None to have equal-length
            # lists. Skip the Nones.
            if block_info is not None:

                # Progress. (*note*: move world progress to here.)
                print_rank_0("embed '%s' block %d / %d ... %s." % (
                    name,
                    block_index,
                    n_blocks,
                    block_info["path"],
                ))

                # Embed block.
                sub_dataset = Subset(text_dataset, range(*block_info["range"]))
                embeddings = self.embedder.embed_text_dataset(sub_dataset)

                # Save embeddings. The context manager closes the file even
                # if writing the dataset raises.
                with h5py.File(block_info["path"], "w") as f:
                    f.create_dataset("data", data=embeddings)

            # Synchronize progress across all ranks. (for easier observation)
            print_rank_0(" > waiting for other ranks to finish block.")
            torch.distributed.barrier()

    def embed_text_dataset(self, name, workdir, text_dataset):
        '''Embed a text dataset, writing only the blocks missing on disk.'''

        # Dataset workdir.
        os.makedirs(workdir, exist_ok=True)

        # Missing embedding blocks (stored on disk). Existing block files
        # are validated by checking the second dimension of "data".
        def validate(f):
            assert f["data"].shape[1] == 1024
        _n_missing_world, missing_embedding_blocks = get_missing_blocks_by_rank(
            workdir,
            len(text_dataset),
            self.block_size,
            validate=validate)

        # Prevent missing file race condition.
        torch.distributed.barrier()

        # Embed batches.
        self.embed_text_blocks(name, workdir, text_dataset,
                               missing_embedding_blocks)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import importlib
# Third-party packages needed for Bert embedding. Each one is imported
# dynamically and published as a module-level name, so that sibling modules
# can do 'from .external_libs import h5py'.
required_libs = [
    "h5py",
    "transformers", # for huggingface bert
]

for lib in required_libs:
    try:
        globals()[lib] = importlib.import_module(lib)
    except ImportError as e:
        # Chain the original error so the real import failure stays visible.
        raise Exception(f"Missing one or more packages required for Bert embedding: {required_libs}. Tried importing '{lib}'.") from e
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import numpy as np
import torch
from tqdm import tqdm
from .external_libs import transformers
class IterableTextDataset(torch.utils.data.IterableDataset):
    '''Iterable view of a text dataset, yielding plain strings.'''

    def __init__(self, text_dataset):
        self.text_dataset = text_dataset

    def __iter__(self):
        '''Yield each sample's text with the 'endoftext' string removed.'''
        n_samples = len(self.text_dataset)
        yield from (
            self.text_dataset[i]["text"].replace("<|endoftext|>", "")
            for i in range(n_samples)
        )
class MyFeatureExtractionPipeline(transformers.FeatureExtractionPipeline):
    '''Feature extraction pipeline that mean-pools token embeddings.'''

    def _forward(self, model_inputs):
        '''Embed a batch and average each sample's inner token embeddings.

        For each sample, the first output of the model is sliced to
        [1 : n - 1], where n is the sum of the sample's attention mask, and
        averaged over that slice.
        '''

        # Embed inputs.
        token_embeddings = self.model(**model_inputs)[0]

        # Per-sample count of attended positions.
        n_attended = torch.sum(model_inputs['attention_mask'], dim=1)

        # Collect embeddings & check for nan.
        pooled = []
        for sample_embedding, n in zip(token_embeddings, n_attended):
            mean_embedding = torch.mean(sample_embedding[1: n - 1], dim=0)

            # Nans due to empty input sequences; so only check first element.
            if torch.isnan(mean_embedding.view(-1)[0]).any():
                mean_embedding.zero_()

            pooled.append(mean_embedding)

        return {
            "input" : model_inputs["input_ids"],
            "output" : pooled,
        }

    def postprocess(self, model_outputs):
        '''Convert the input ids and the pooled output to numpy arrays.'''
        # The input is returned alongside the output, for analysis.
        input_ids = model_outputs["input"].numpy()
        embedding = model_outputs["output"].numpy()
        return {
            "input" : input_ids,
            "output" : embedding,
        }
class HuggingfaceEmbedder:
    '''Compute embeddings with Huggingface's "bert-large-cased" model.'''

    def __init__(self, batch_size, max_seq_length):
        '''Load the model & tokenizer and build the extraction pipeline.'''

        # Model, tokenizer.
        self.model = transformers.BertModel.from_pretrained("bert-large-cased")
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            "bert-large-cased", model_max_length=max_seq_length)

        # Feature extraction pipeline.
        self.pipe = MyFeatureExtractionPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            device=torch.cuda.current_device(),
            truncation=True,
            max_length=max_seq_length,
        )

        self.batch_size = batch_size

    def embed_text_dataset(self, text_dataset, verbose=True):
        '''Embed a text dataset.

        Returns a float32 array of shape (len(text_dataset), 1024), with one
        row per pipeline output. When 'verbose' is set, progress is shown
        with tqdm.
        '''

        # Wrap dataset in iterable.
        dataset = IterableTextDataset(text_dataset)

        # Allocate output array.
        n_samples = len(text_dataset)
        embeddings = np.zeros((n_samples, 1024), dtype="f4")

        # Wrap iterator in tqdm for verbose output.
        _iter = self.pipe(dataset, batch_size=self.batch_size)
        if verbose:
            _iter = tqdm(_iter, "hf embed", total=n_samples)

        # Embed dataset; the enumerate index is the output row.
        for idx, out_dict in enumerate(_iter):
            embeddings[idx] = out_dict["output"]

        return embeddings

    def embed_text(self, text):
        '''Embed a single text string.

        Primarily used for on-the-fly embeddings, particularly during
        analysis or debugging. For large scale, use 'embed_text_dataset()'.
        '''

        class SingleTextDataset(torch.utils.data.Dataset):
            '''Dataset that holds single string.'''
            def __init__(self, text):
                assert isinstance(text, str)
                self.text = text
            def __len__(self):
                return 1
            def __getitem__(self, i):
                return {"text": self.text}

        # Embed text.
        text_ds = SingleTextDataset(text)
        embed = self.embed_text_dataset(text_ds, verbose=False)[0]

        return embed
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from collections import defaultdict
import glob
import numpy as np
import os
import torch
from tqdm import tqdm
from megatron import print_rank_0
from megatron.core import parallel_state
from .external_libs import h5py
def save_data(data_map, *args):
    '''Save map of numpy arrays to hdf5 file.

    The destination is given either as a single path, or as a
    (directory, file name) pair that is joined. If a file already exists at
    the destination, nothing is written. Returns the destination path.

    Raises an Exception when 'args' has neither one nor two entries.
    '''

    # Parse args.
    if len(args) == 1:
        path = args[0]
    elif len(args) == 2:
        dir_path, file_name = args
        path = os.path.join(dir_path, file_name)
    else:
        raise Exception("specialize for len(args) == %d." % len(args))

    # Save data. The context manager closes the file even if a dataset
    # fails to write.
    if not os.path.isfile(path):
        with h5py.File(path, "w") as f:
            for k, v in data_map.items():
                f.create_dataset(k, data=v)

    return path
def load_data(paths):
    '''Load multiple hdf5 files to single numpy array.

    For every dataset key, the files' arrays are concatenated along the
    first axis, in the order of 'paths', into one float32 array. Returns a
    dict mapping each key to its concatenated array.
    '''

    # Read data shapes: (total rows so far, second dimension) per key.
    shape_map = defaultdict(lambda : (0, None))
    for p in paths:
        with h5py.File(p, "r") as f:
            for k in f.keys():
                shape = tuple(f[k].shape)
                shape_map[k] = (shape_map[k][0] + shape[0], shape[1])

    # Allocate output array.
    data_map = { k : np.empty(s, dtype="f4") for k, s in shape_map.items() }
    start_map = { k : 0 for k in shape_map }

    # Load files, copying each into its slice of the output. The context
    # manager closes each file even if a copy raises.
    for p in tqdm(paths, "load data"):
        with h5py.File(p, "r") as f:
            for k in f.keys():
                i0 = start_map[k]
                i1 = i0 + len(f[k])
                data_map[k][i0:i1] = f[k]
                start_map[k] = i1

    return data_map
def get_missing_blocks(workdir, n_samples, block_size,
                       validate=lambda f : None):
    '''Divide range [0, num_samples) to sequence of block ranges.

    This is a core method within the concept of block processing. The idea
    is to divide a range (size n_samples) into a sequence of blocks. Each
    block corresponds to a file within 'workdir' with name
    '{start_idx}-{end_idx}.hdf5'. This method checks for the existence of
    these files, and returns a list of the ones that are missing.

    On rank 0, every existing block file is opened and passed to
    'validate'; a file that cannot be opened, or for which 'validate'
    raises, is deleted and therefore reported as missing. All ranks then
    synchronize on a barrier before listing the missing files.

    Returns a list of dicts with "range" ((start, end) tuple) and "path".
    '''

    # Block ranges.
    block_start_idxs = list(range(0, n_samples, block_size))
    block_end_idxs = [ min(n_samples, i + block_size) for i in block_start_idxs ]
    block_ranges = list(zip(block_start_idxs, block_end_idxs))

    # All block files (existing + missing). The max() guards n_samples == 0,
    # where log(0) = -inf cannot be converted to int.
    n_digits = int(np.ceil(np.log(max(n_samples, 1)) / np.log(10)) + 1)
    all_blocks = [{
        "range" : r,
        "path" : os.path.join(
            workdir,
            "%s-%s.hdf5" % tuple([ str(i).zfill(n_digits) for i in r ]),
        )
    } for r in block_ranges]
    all_block_path_set = set(block["path"] for block in all_blocks)

    # Delete corrupt files.
    if torch.distributed.get_rank() == 0:
        existing_block_paths = [block["path"]
                                for block in all_blocks
                                if os.path.exists(block["path"])]
        for path in tqdm(existing_block_paths, "validating block."):

            assert path in all_block_path_set, "unexpected filename, '%s'." % path

            # An unreadable or invalid file is removed so that it gets
            # regenerated. The file is closed before removal, and
            # KeyboardInterrupt / SystemExit are left to propagate.
            try:
                with h5py.File(path, "r") as f:
                    validate(f)
            except Exception:
                os.remove(path)

    # Wait for files to be deleted.
    torch.distributed.barrier()

    # Filter missing files.
    missing_blocks = [block
                      for block in all_blocks
                      if not os.path.exists(block["path"])]

    return missing_blocks
def get_missing_blocks_by_rank(workdir, n_samples, block_size,
                               validate=lambda f : None):
    '''Split the missing blocks across data parallel ranks.

    The missing blocks found by 'get_missing_blocks()' are interleaved
    across ranks, so each rank gets a roughly equal share. Returns the pair
    (total number of missing blocks, this rank's blocks padded with None).
    '''

    all_missing = get_missing_blocks(workdir, n_samples, block_size,
                                     validate)

    # This rank's share: every world_size-th block, starting at its rank.
    rank = parallel_state.get_data_parallel_rank()
    world_size = parallel_state.get_data_parallel_world_size()
    rank_missing_blocks = all_missing[rank::world_size]

    # Pad with None up to the largest per-rank count, so that all ranks
    # hold lists of equal length. This eases tracking of global progress.
    count_tensor = torch.cuda.LongTensor([len(rank_missing_blocks)])
    torch.distributed.all_reduce(count_tensor,
                                 op=torch.distributed.ReduceOp.MAX)
    n_padding = count_tensor.item() - len(rank_missing_blocks)
    rank_missing_blocks.extend([None] * n_padding)

    return len(all_missing), rank_missing_blocks
class IdPathMap:
    '''Maps indexes to the containing block path.

    Each index is stored against the position of its path in 'paths' rather
    than against the path string itself, so many indexes can share one
    (long) path string. For example, with block_size 1M, this class stores
    1/1M as many path strings.
    '''

    def __init__(self, paths):
        self.paths = paths
        # path -> position in 'paths'
        self.path_index_map = {
            path: position for position, path in enumerate(paths)
        }
        # id -> position in 'paths'
        self.id_index_map = {}

    def __str__(self):
        n_paths = len(self.paths)
        n_ids = len(self.id_index_map)
        return "%d paths; %d ids" % (n_paths, n_ids)

    def add(self, id, path):
        '''Record that index 'id' lives in block file 'path'.'''
        position = self.path_index_map[path]
        self.id_index_map[id] = position

    def __contains__(self, idx):
        '''Has this index been added?'''
        return idx in self.id_index_map

    def __getitem__(self, idx):
        '''Look up the path that was registered for an index.'''
        position = self.id_index_map[idx]
        return self.paths[position]
def path_to_range(path):
    '''Parse the integer indexes encoded in a block file name.

    The directory and extension are dropped, and the remaining name is
    split on "-" (e.g., 00010-00011.hdf5 -> (10, 11)).
    '''
    stem, _extension = os.path.splitext(os.path.basename(path))
    return tuple(int(part) for part in stem.split("-"))
def get_index_path_map(_dir):
    '''Build an IdPathMap over all '*.hdf5' block files in a directory.

    Every index in a file's [start, end) range, as parsed from its name,
    is mapped to that file's path.
    '''

    block_paths = sorted(glob.glob(_dir + "/*.hdf5"))

    # Register each index of each block against the block's path.
    index_to_path = IdPathMap(block_paths)
    for block_path in block_paths:
        first_idx, stop_idx = path_to_range(block_path)
        idx = first_idx
        while idx < stop_idx:
            index_to_path.add(idx, block_path)
            idx += 1

    return index_to_path
import json
import os
import sys
import types
import torch
def add_arguments(parser):
    """Register the Megatron loader's command line options on 'parser'.

    Adds --true-vocab-size, --vocab-file and --megatron-path, all optional
    and defaulting to None.
    """
    group = parser.add_argument_group(title='Megatron loader')

    group.add_argument('--true-vocab-size', type=int, default=None,
                       help='original size of vocab, if specified will trim padding from embedding table.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file. If specified will use this to get vocab size and '
                       'trim padding from the embedding table.')
    # This path is inserted into sys.path to import Megatron, so the help
    # text names the Megatron repository.
    group.add_argument('--megatron-path', type=str, default=None,
                       help='Base directory of Megatron repository')
def _load_checkpoint(queue, args):
# Search in directory above this
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__),
os.path.pardir)))
if args.megatron_path is not None:
sys.path.insert(0, args.megatron_path)
try:
from megatron.arguments import parse_args, validate_args
from megatron.global_vars import set_args, set_global_variables
from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
from megatron.model import module
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron import fused_kernels
except ModuleNotFoundError:
print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
queue.put("exit")
exit(1)
# We want all arguments to come from us
sys.argv = ['script.py',
'--no-masked-softmax-fusion',
'--no-bias-gelu-fusion',
'--no-bias-dropout-fusion',
'--no-async-tensor-model-parallel-allreduce',
'--use-cpu-initialization',
'--micro-batch-size', '1',
'--no-load-optim',
'--no-load-rng',
'--no-save-optim',
'--no-save-rng',
'--no-initialization',
'--load', args.load_dir
]
margs = parse_args()
margs = load_args_from_checkpoint(margs)
# Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking we are plenty of processes
margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size
margs = validate_args(margs)
def check_for_arg(arg_name):
if getattr(margs, arg_name, None) is None:
print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
print(f"Arguments: {margs}")
queue.put("exit")
exit(1)
check_for_arg('tensor_model_parallel_size')
check_for_arg('pipeline_model_parallel_size')
check_for_arg('num_layers')
check_for_arg('hidden_size')
check_for_arg('seq_length')
check_for_arg('num_attention_heads')
check_for_arg('max_position_embeddings')
check_for_arg('tokenizer_type')
check_for_arg('iteration')
check_for_arg('bert_binary_head')
check_for_arg('params_dtype')
# Determine how to make our models
if args.model_type == 'GPT':
from pretrain_gpt import model_provider
margs.model_type = ModelType.encoder_or_decoder
elif args.model_type == 'BERT':
from pretrain_bert import model_provider
margs.model_type = ModelType.encoder_or_decoder
else:
raise Exception(f'unrecognized model type: {args.model_type}')
# supress warning about torch.distributed not being initialized
module.MegatronModule.embedding_warning_printed = True
consumed_train_samples = None
consumed_valid_samples = None
def get_models(count, dtype, pre_process, post_process):
nonlocal consumed_train_samples
nonlocal consumed_valid_samples
models = []
for rank in range(count):
mpu.set_tensor_model_parallel_rank(rank)
model_ = [model_provider(pre_process, post_process).to(dtype)]
margs.consumed_train_samples = 0
margs.consumed_valid_samples = 0
load_checkpoint(model_, None, None)
assert(len(model_) == 1)
model_ = model_[0]
if consumed_train_samples is not None:
assert(margs.consumed_train_samples == consumed_train_samples)
else:
consumed_train_samples = margs.consumed_train_samples
if consumed_valid_samples is not None:
assert(margs.consumed_valid_samples == consumed_valid_samples)
else:
consumed_valid_samples = margs.consumed_valid_samples
models.append(model_)
return models
if margs.num_layers_per_virtual_pipeline_stage is not None:
print("Model with an interleaved pipeline schedule are not yet supported.")
queue.put("exit")
exit(1)
set_global_variables(margs)
mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
fused_kernels.load(margs)
# Get true (non-padded) vocab size
if args.true_vocab_size is not None:
true_vocab_size = args.true_vocab_size
elif args.vocab_file is not None:
vocab = json.load(open(args.vocab_file))
true_vocab_size = len(vocab)
if args.true_vocab_size is not None and true_vocab_size != args.true_vocab_size:
print("Both --true-vocab-size and --vocab-file specified and the vocab size does not match, aborting.")
queue.put("exit")
exit(1)
else:
true_vocab_size = None
# short aliases
tp_size = margs.tensor_model_parallel_size
pp_size = margs.pipeline_model_parallel_size
# metadata
md = types.SimpleNamespace()
md.model_type = args.model_type
md.num_layers = margs.num_layers
md.hidden_size = margs.hidden_size
md.seq_length = margs.seq_length
md.num_attention_heads = margs.num_attention_heads
md.max_position_embeddings = margs.max_position_embeddings
md.tokenizer_type = margs.tokenizer_type
md.iteration = margs.iteration
md.params_dtype = margs.params_dtype
md.bert_binary_head = margs.bert_binary_head
md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
md.true_vocab_size = true_vocab_size
md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
# Get first pipe stage
mpu.set_pipeline_model_parallel_rank(0)
post_process = pp_size == 1
models = get_models(tp_size, md.params_dtype, True, post_process)
md.consumed_train_samples = consumed_train_samples
md.consumed_valid_samples = consumed_valid_samples
queue.put(md)
def queue_put(name, msg):
print(f"sending {name}")
msg["name"] = name
queue.put(msg)
# Send embeddings
message = {
"position embeddings": models[0].language_model.embedding.position_embeddings.weight.data,
"word embeddings": torch.cat(
[models[tp_rank].language_model.embedding.word_embeddings.weight.data for tp_rank in range(tp_size)],
dim = 0)
}
queue_put("embeddings", message)
total_layer_num = 0
for pp_rank in range(pp_size):
if pp_rank > 0:
mpu.set_pipeline_model_parallel_rank(pp_rank)
post_process = pp_rank == pp_size - 1
models = get_models(tp_size, md.params_dtype, False, post_process)
for layer_num in range(len(models[0].language_model.encoder.layers)):
message = {}
# Get non-parallel tensors from tp_rank 0
layer = models[0].language_model.encoder.layers[layer_num]
message["input layernorm weight"] = layer.input_layernorm.weight.data
message["input layernorm bias"] = layer.input_layernorm.bias.data
message["dense bias"] = layer.self_attention.dense.bias.data
message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
# Grab all parallel tensors for this layer
qkv_weight = []
qkv_bias = []
dense_weight = []
mlp_l0_weight = []
mlp_l0_bias = []
mlp_l1_weight = []
for tp_rank, model in enumerate(models):
layer = model.language_model.encoder.layers[layer_num]
qkv_weight.append(layer.self_attention.query_key_value.weight.data)
qkv_bias.append(layer.self_attention.query_key_value.bias.data)
dense_weight.append(layer.self_attention.dense.weight.data)
mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
# concat them
message["qkv weight"] = torch.cat(qkv_weight, dim=0)
message["qkv bias"] = torch.cat(qkv_bias, dim=0)
message["dense weight"] = torch.cat(dense_weight, dim=1)
message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
queue_put(f"transformer layer {total_layer_num}", message)
total_layer_num = total_layer_num + 1
# Send final layernorm from tp_rank 0
message = {
"weight": models[0].language_model.encoder.final_layernorm.weight.data,
"bias": models[0].language_model.encoder.final_layernorm.bias.data
}
queue_put("final layernorm", message)
# Send BERT lm head and binary head if it exists
if md.model_type == 'BERT':
message = {
"weight": models[0].language_model.pooler.dense.weight.data,
"bias": models[0].language_model.pooler.dense.bias.data
}
queue_put("pooler", message)
message = {
"dense weight": models[0].lm_head.dense.weight.data,
"dense bias": models[0].lm_head.dense.bias.data,
"layernorm weight": models[0].lm_head.layernorm.weight.data,
"layernorm bias": models[0].lm_head.layernorm.bias.data
}
queue_put("lm head", message)
if md.bert_binary_head:
message = {
"weight": models[0].binary_head.weight.data,
"bias": models[0].binary_head.bias.data
}
queue_put("binary head", message)
queue.put("done")
def load_checkpoint(queue, args):
    """Entry point of the loader: run '_load_checkpoint' under a guard.

    If anything is raised while loading, "exit" is put on the queue before
    the exception is re-raised.
    """
    try:
        _load_checkpoint(queue, args)
    except BaseException:
        # BaseException is used so that exit() is covered as well.
        queue.put("exit")
        raise
import argparse
from collections.abc import Mapping
import concurrent.futures
import os
import sys
import torch
def add_arguments(parser):
    """Register the Megatron saver's command line options on 'parser'.

    Adds --megatron-path, --target-tensor-parallel-size and
    --target-pipeline-parallel-size, all optional and defaulting to None.
    """
    group = parser.add_argument_group(title='Megatron saver')

    group.add_argument('--megatron-path', type=str, default=None,
                       help='Base directory of Megatron repository')

    group.add_argument('--target-tensor-parallel-size', type=int,
                       help='Target tensor model parallel size, defaults to the tensor parallel size '
                       'in the input checkpoint if provided by the loader, otherwise to 1')
    group.add_argument('--target-pipeline-parallel-size', type=int,
                       help='Target pipeline model parallel size, defaults to the pipeline parallel size '
                       'in the input checkpoint if provided by the loader, otherwise to 1')
def save_checkpoint(queue, args):
# Search in directory above this
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__),
os.path.pardir)))
if args.megatron_path is not None:
sys.path.insert(0, args.megatron_path)
try:
from megatron.arguments import (parse_args, validate_args)
from megatron.checkpointing import save_checkpoint
from megatron.global_vars import set_global_variables, get_args
from megatron.core.enums import ModelType
from megatron.tokenizer.tokenizer import _vocab_size_with_padding
from megatron import fused_kernels
from megatron.core import mpu
except ModuleNotFoundError:
print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
exit(1)
def queue_get(name=None):
val = queue.get()
if val == "exit":
print("Loader exited, exiting saver")
exit(1)
if name is not None and args.checking and val["name"] != name:
val_name = val["name"]
print(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.')
exit(1)
if name is not None:
print(f"received {name}")
return val
def check_message(msg):
if not args.checking:
return
msg_name = msg.pop("name")
if len(msg.keys()) > 0:
print(f"Unexpected values in {msg_name}:")
for key in msg.keys():
print(f" {key}")
print(f"Exiting. If you want to ignore this, use the argument --no-checking.")
exit(1)
md = queue_get()
if args.target_tensor_parallel_size is None:
if hasattr(md, 'previous_tensor_parallel_size'):
args.target_tensor_parallel_size = md.previous_tensor_parallel_size
else:
print("loader did not provide a tensor parallel size and --target-tensor-parallel-size not provided on command line. "
"Default to 1.")
args.target_tensor_parallel_size = 1
if args.target_pipeline_parallel_size is None:
if hasattr(md, 'previous_pipeline_parallel_size'):
args.target_pipeline_parallel_size = md.previous_pipeline_parallel_size
else:
print("loader did not provide a pipeline parallel size and --target-pipeline-parallel-size not provided on command line. "
"Default to 1.")
args.target_pipeline_parallel_size = 1
# Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking we are plenty of processes
if args.target_tensor_parallel_size is not None and args.target_pipeline_parallel_size is not None:
os.environ["WORLD_SIZE"] = f'{args.target_tensor_parallel_size * args.target_pipeline_parallel_size}'
# We want all arguments to come from us
sys.argv = ['script.py',
'--num-layers', str(md.num_layers),
'--hidden-size', str(md.hidden_size),
'--seq-length', str(md.seq_length),
'--num-attention-heads', str(md.num_attention_heads),
'--max-position-embeddings', str(md.max_position_embeddings),
'--tokenizer-type', str(md.tokenizer_type),
'--tensor-model-parallel-size', str(args.target_tensor_parallel_size),
'--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size),
'--no-masked-softmax-fusion',
'--no-bias-gelu-fusion',
'--no-bias-dropout-fusion',
'--no-async-tensor-model-parallel-allreduce',
'--use-cpu-initialization',
'--micro-batch-size', '1',
'--no-load-optim',
'--no-load-rng',
'--no-save-optim',
'--no-save-rng',
'--no-initialization',
'--save-interval', '1',
'--save', args.save_dir
]
if md.make_vocab_size_divisible_by is not None:
sys.argv.extend(['--make-vocab-size-divisible-by', str(md.make_vocab_size_divisible_by)])
if md.params_dtype == torch.float16:
sys.argv.append('--fp16')
elif md.params_dtype == torch.bfloat16:
sys.argv.append('--bf16')
if md.model_type == 'BERT' and not md.bert_binary_head:
sys.argv.append('--bert-no-binary-head')
margs = parse_args()
validate_args(margs)
set_global_variables(margs)
# margs = megatron args
margs = get_args()
if hasattr(md, 'consumed_train_samples'):
margs.consumed_train_samples = md.consumed_train_samples
margs.consumed_valid_samples = md.consumed_valid_samples
print(f"Setting consumed_train_samples to {margs.consumed_train_samples}"
f" and consumed_valid_samples to {margs.consumed_valid_samples}")
else:
print("consumed_train_samples not provided.")
# Determine how to make our models
if md.model_type == 'GPT':
from pretrain_gpt import model_provider
margs.model_type = ModelType.encoder_or_decoder
elif md.model_type == 'BERT':
from pretrain_bert import model_provider
margs.model_type = ModelType.encoder_or_decoder
else:
raise Exception(f'unrecognized model type: {args.model_type}')
def get_models(count, dtype, pre_process, post_process):
models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]
return models
# fake initializing distributed
mpu.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
mpu.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
mpu.set_tensor_model_parallel_rank(0)
mpu.set_pipeline_model_parallel_rank(0)
fused_kernels.load(margs)
# Embeddings
#-----------
embeddings_msg = queue_get("embeddings")
pos_embed = embeddings_msg.pop("position embeddings")
orig_word_embed = embeddings_msg.pop("word embeddings")
check_message(embeddings_msg)
# Deal with padding
if md.true_vocab_size is not None:
# figure out what our padded vocab size is
orig_vocab_size = orig_word_embed.shape[0]
margs.padded_vocab_size = _vocab_size_with_padding(md.true_vocab_size, margs)
# Cut out extra padding we don't need
if orig_vocab_size > margs.padded_vocab_size:
full_word_embed = orig_word_embed[0:margs.padded_vocab_size,:]
# Expanding embedding to larger size by replicating final entry
elif orig_vocab_size < margs.padded_vocab_size:
padding_size = margs.padded_vocab_size - orig_vocab_size
full_word_embed = torch.cat((
orig_word_embed,
orig_word_embed[-1].unsqueeze(0).expand(padding_size, -1)))
# Same size!
else:
full_word_embed = orig_word_embed
else:
print("Original vocab size not specified, leaving embedding table as-is. "
"If you've changed the tensor parallel size this could cause problems.")
margs.padded_vocab_size = orig_word_embed.shape[0]
full_word_embed = orig_word_embed
# Split into new tensor model parallel sizes
out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)
# Make models for first pipeline stage and fill in embeddings
mpu.set_pipeline_model_parallel_rank(0)
post_process = args.target_pipeline_parallel_size == 1
models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
for tp_rank, model in enumerate(models):
print(f"word embeddings shape {model.language_model.embedding.word_embeddings.weight.shape}")
model.language_model.embedding.word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed)
# Transformer layers
#-------------------
total_layer_num = 0
for pp_rank in range(args.target_pipeline_parallel_size):
# For later pipeline parallel ranks, make the new models
if pp_rank > 0:
mpu.set_pipeline_model_parallel_rank(pp_rank)
post_process = pp_rank == args.target_pipeline_parallel_size - 1
models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)
for layer in range(len(models[0].language_model.encoder.layers)):
msg = queue_get(f"transformer layer {total_layer_num}")
# duplicated tensors
input_layernorm_weight = msg.pop("input layernorm weight")
input_layernorm_bias = msg.pop("input layernorm bias")
dense_bias = msg.pop("dense bias")
post_layernorm_weight = msg.pop("post layernorm weight")
post_layernorm_bias = msg.pop("post layernorm bias")
mlp_l1_bias = msg.pop("mlp l1 bias")
# Split up the parallel tensors
qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1)
# Save them to the model
for tp_rank in range(args.target_tensor_parallel_size):
l = models[tp_rank].language_model.encoder.layers[layer]
l.input_layernorm.weight.data.copy_(input_layernorm_weight)
l.input_layernorm.bias.data.copy_(input_layernorm_bias)
l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
l.self_attention.dense.bias.data.copy_(dense_bias)
l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank])
l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias)
total_layer_num = total_layer_num + 1
check_message(msg)
if post_process:
msg = queue_get("final layernorm")
final_layernorm_weight = msg.pop("weight")
final_layernorm_bias = msg.pop("bias")
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
if pp_rank != 0:
# Copy word embeddings to final pipeline rank
models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
del final_layernorm_weight
del final_layernorm_bias
check_message(msg)
msg = queue_get()
if msg != "done" and msg["name"] == "pooler":
if not hasattr(models[0].language_model, 'pooler'):
print("ERROR: got a pooler, but model does not have one")
exit(1)
print("received pooler")
pooler_weight = msg.pop("weight")
pooler_bias = msg.pop("bias")
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].language_model.pooler.dense.weight.data.copy_(pooler_weight)
models[tp_rank].language_model.pooler.dense.bias.data.copy_(pooler_bias)
del pooler_weight
del pooler_bias
check_message(msg)
msg = queue_get()
if msg != "done" and msg["name"] == "lm head":
if not hasattr(models[0], 'lm_head'):
print("ERROR: got an lm head, but model does not have one")
exit(1)
print("received lm head")
lm_head_dense_weight = msg.pop("dense weight")
lm_head_dense_bias = msg.pop("dense bias")
lm_head_layernorm_weight = msg.pop("layernorm weight")
lm_head_layernorm_bias = msg.pop("layernorm bias")
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight)
models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias)
models[tp_rank].lm_head.layernorm.weight.data.copy_(lm_head_layernorm_weight)
models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
check_message(msg)
msg = queue_get()
if msg != "done" and msg["name"] == "binary head":
if not hasattr(models[0], 'binary_head'):
print("ERROR: got a binary head, but model does not have one")
exit(1)
print("received binary head")
binary_head_weight = msg.pop("weight")
binary_head_bias = msg.pop("bias")
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].binary_head.weight.data.copy_(binary_head_weight)
models[tp_rank].binary_head.bias.data.copy_(binary_head_bias)
check_message(msg)
msg = queue_get()
if msg != "done":
print("ERROR: got some more data but was expecting to be done")
for tp_rank in range(args.target_tensor_parallel_size):
mpu.set_tensor_model_parallel_rank(tp_rank)
save_checkpoint(md.iteration, [models[tp_rank]], None, None)
print("Done!")
import argparse
import importlib
import torch.multiprocessing as mp
import os
import sys
# A loader is a python file with at least two functions
# - add_arguments - takes in a parser and adds any arguments needed
# - load_checkpoint - takes in the queue and parsed arguments
# A saver is similar but has save_checkpoint instead of
# load_checkpoint
# The loader and saver process are each given a queue, the loader
# should load the checkpoint and send the weights in messages in the
# following order, the saver should receive them in this order and
# save the checkpoints. A message consists of a python dictionary with
# a "name" for error checking and an entry for each tensor as
# indicated below. Note that the weights sent over the queue are the
# full model weights, nothing split.
# If the loader ever sends "exit" to the queue, that means something
# went wrong and it is exiting.
# - Metadata Namespace with the following attributes:
# model_type - GPT, BERT, T5, etc. (Part of protocol to allow this to be deduced later instead of given on command line)
# num_layers - Number of transformer layers
# hidden_size
# seq_length
# num_attention_heads
# max_position_embeddings
# tokenizer_type
# iteration
# params_dtype
# bert_binary_head - Used only if model_type is BERT
# previous_tensor_parallel_size - Optional
# previous_pipeline_parallel_size - Optional
# true_vocab_size
# make_vocab_size_divisible_by
# consumed_train_samples
# consumed_valid_samples
# messages
# {
# "name": "embeddings"
# "position embeddings"
# "word embeddings"
# }
# (for each transformer layer):
# {
# "name": "transformer layer N"
# "input layernorm weight"
# "input layernorm bias"
# "qkv weight"
# "qkv bias"
# "dense weight"
# "dense bias"
# "post layernorm weight"
# "post layernorm bias"
# "mlp l0 weight"
# "mlp l0 bias"
# "mlp l1 weight"
# "mlp l1 bias"
# }
# {
# "name": "final layernorm"
# "weight"
# "bias"
# }
# if present (i.e. for BERT):
# {
# "name": "pooler"
# "weight"
# "bias"
# }
# {
# "name": "lm head"
# "dense weight"
# "dense bias"
# "layernorm weight"
# "layernorm bias"
# }
# {
# "name": "binary head"
# "weight"
# "bias"
# }
# - "done"
def load_plugin(plugin_type, name):
    """Import and return the plugin module identified by ``name``.

    The conventional module name ``checkpoint_<plugin_type>_<name>`` is tried
    first; when no such module exists, ``name`` itself is imported instead.

    The process exits through ``sys.exit`` when neither module can be found,
    or when the imported module has no ``add_arguments`` attribute.
    """
    for module_name in (f"checkpoint_{plugin_type}_{name}", name):
        try:
            plugin = importlib.import_module(module_name)
        except ModuleNotFoundError:
            # Fall through to the next candidate name, if any.
            continue
        break
    else:
        # Both candidates were missing.
        sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.")

    if not hasattr(plugin, 'add_arguments'):
        sys.exit(f"{module_name} module is not a plugin. Exiting.")

    print(f"Loaded {module_name} as the {plugin_type}.")
    return plugin
def main():
    """Convert a checkpoint by streaming it from a loader plugin to a saver plugin.

    The common command line options are parsed first so the loader and saver
    plugins can be imported; each plugin then registers its own options on the
    same parser before the final parse. The saver runs in a child process and
    receives messages through a bounded queue that the loader fills from this
    process.

    The process exits with an error when the saver process ends with a
    non-zero exit code, so callers can detect a failed conversion.
    """
    parser = argparse.ArgumentParser(description="Megatron Checkpoint Utility Arguments",
                                     allow_abbrev=False, conflict_handler='resolve')

    parser.add_argument('--model-type', type=str, required=True,
                        choices=['GPT', 'BERT'],
                        help='Type of the model')
    parser.add_argument('--loader', type=str, default='megatron',
                        help='Module name to load checkpoint, should be on python path')
    parser.add_argument('--saver', type=str, default='megatron',
                        help='Module name to save checkpoint, should be on python path')
    parser.add_argument('--load-dir', type=str, required=True,
                        help='Directory to load model checkpoint from')
    parser.add_argument('--save-dir', type=str, required=True,
                        help='Directory to save model checkpoint to')
    parser.add_argument('--max-queue-size', type=int, default=50,
                        help='Maximum number of tensors in the queue')
    parser.add_argument('--no-checking', action='store_false',
                        help='Do not perform checking on the name and ordering of weights',
                        dest='checking')

    # First pass: only the options above are known, which is enough to find
    # out which plugins to import.
    known_args, _ = parser.parse_known_args()
    loader = load_plugin('loader', known_args.loader)
    saver = load_plugin('saver', known_args.saver)

    # Second pass: the plugins have added their own options.
    loader.add_arguments(parser)
    saver.add_arguments(parser)

    args = parser.parse_args()

    queue = mp.Queue(maxsize=args.max_queue_size)

    print("Starting saver...")
    saver_proc = mp.Process(target=saver.save_checkpoint, args=(queue, args))
    saver_proc.start()

    print("Starting loader...")
    try:
        loader.load_checkpoint(queue, args)
    except BaseException:
        # If the loader dies, the saver may be blocked on the queue forever.
        # Give it a moment to finish on its own, then stop it so that this
        # process does not hang while exiting.
        saver_proc.join(timeout=10)
        if saver_proc.is_alive():
            saver_proc.terminate()
            saver_proc.join()
        raise

    print("Waiting for saver to complete...")
    saver_proc.join()

    # Surface a failed saver instead of silently reporting success.
    if saver_proc.exitcode != 0:
        sys.exit(f"Saver process exited with code {saver_proc.exitcode}.")


if __name__ == '__main__':
    main()
import os
import os.path as osp
import pathlib
import subprocess
def recursively_lint_files():
    """Recursively lint all python files in chosen subdirectories of megatron-lm.

    Runs ``autopep8 --in-place`` on every ``*.py`` file found directly in the
    parent directory of this file and, recursively, in a fixed list of its
    subdirectories. Prints an installation hint and returns without linting
    when ``autopep8`` cannot be imported.

    Raises:
        subprocess.CalledProcessError: if an ``autopep8`` invocation fails.
    """
    try:
        import autopep8  # noqa: F401 -- imported only to check that it is installed
    except ModuleNotFoundError:
        print("Please first install autopep8 via `pip install autopep8`")
        return

    # get all python file paths from top level directory
    file_dir = str(pathlib.Path(__file__).parent.absolute())
    working_dir = osp.join(file_dir, os.pardir)
    # endswith() rather than a substring test, so that e.g. ".pyc" files and
    # names that merely contain ".py" are not handed to the linter.
    all_py_paths = set(os.path.join(working_dir, fname)
                       for fname in os.listdir(working_dir) if fname.endswith(".py"))

    # get all python file paths from chosen subdirectories
    check_dirs = ['docker', 'megatron', 'openwebtext', 'scripts', 'tasks']
    for sub_dir in check_dirs:
        for path, _, fnames in os.walk(osp.join(working_dir, sub_dir)):
            all_py_paths.update(osp.join(path, fname)
                                for fname in fnames if fname.endswith(".py"))

    print("Linting the following: ")
    for py_path in sorted(all_py_paths):
        print(py_path)
        # Pass the command as an argument list: without a shell, a single
        # string is treated as the program name and fails to execute on POSIX,
        # and a list is also safe for paths that contain spaces.
        command = ['autopep8', '--max-line-length', '100',
                   '--aggressive', '--in-place', py_path]
        subprocess.check_call(command)


if __name__ == "__main__":
    recursively_lint_files()
import os
import sys
import json
import argparse
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
from megatron.data import indexed_dataset
def main(args):
    """Merge every indexed dataset found in ``args.input`` into a single dataset.

    Each regular file in ``args.input`` is expected to be one half of a
    ``<prefix>.bin`` / ``<prefix>.idx`` pair. The pairs are merged in sorted
    prefix order into ``args.output_prefix + '.bin'`` and
    ``args.output_prefix + '.idx'``. The builder type (and, for memory-mapped
    datasets, the dtype) is taken from the first dataset.

    Args:
        args: namespace with ``input`` (directory to scan) and
            ``output_prefix`` (output path without suffix).

    Raises:
        AssertionError: if a file's ``.bin``/``.idx`` counterpart is missing.
        ValueError: if ``args.input`` contains no dataset files at all.
    """
    prefixes = set()
    for basename in os.listdir(args.input):
        prefix, ext = os.path.splitext(basename)

        # The second file of an already registered pair.
        if prefix in prefixes:
            continue

        # Sub-directories are ignored.
        if not os.path.isfile(os.path.join(args.input, basename)):
            continue

        ext_pair = '.bin' if ext == '.idx' else '.idx'
        assert os.path.isfile(os.path.join(args.input, prefix) + ext_pair), \
            f'ERROR: {ext_pair} file not provided for {os.path.join(args.input, prefix)}'

        prefixes.add(prefix)

    # Without this check an empty directory fails much later with an
    # unhelpful AttributeError on the never-created builder.
    if not prefixes:
        raise ValueError(f'ERROR: no dataset files found in {args.input}')

    builder = None
    for prefix in sorted(prefixes):
        if builder is None:
            # The first dataset decides which builder implementation is used.
            dataset = indexed_dataset.make_dataset(os.path.join(args.input, prefix), 'infer')

            if isinstance(dataset, indexed_dataset.MMapIndexedDataset):
                builder = indexed_dataset.MMapIndexedDatasetBuilder(
                    args.output_prefix + '.bin', dtype=dataset._index.dtype)
            else:
                builder = indexed_dataset.IndexedDatasetBuilder(args.output_prefix + '.bin')

            del dataset

        builder.merge_file_(os.path.join(args.input, prefix))

    builder.finalize(args.output_prefix + '.idx')
if __name__ == '__main__':
    # Command line entry point: validate the paths, then merge the datasets.
    parser = argparse.ArgumentParser()

    group = parser.add_argument_group(title='input data')
    group.add_argument('--input', type=str, required=True,
                       help='Path to directory containing all document files to merge')

    group = parser.add_argument_group(title='output data')
    group.add_argument('--output-prefix', type=str, required=True,
                       help='Path to binary output file without suffix')

    args = parser.parse_args()

    assert os.path.isdir(args.input), \
        f'ERROR: {args.input} is not a directory or does not exist'

    # A bare prefix such as "merged" has an empty dirname, which isdir()
    # rejects; such a prefix refers to the current directory.
    output_dir = os.path.dirname(args.output_prefix) or os.curdir
    assert os.path.isdir(output_dir), \
        f'ERROR: {output_dir} is not a directory or does not exist'

    main(args)
The following steps show how to prepare the training dataset used to train the model.
# Libraries to install
```
pip install ftfy langdetect numpy torch pandas nltk sentencepiece boto3 tqdm regex bs4 newspaper3k htmlmin tldextract
git clone https://github.com/mattilyra/LSH
cd LSH
python setup.py install
```
# Download the dataset
1. Download the deduplicated URLs from [jcpeterson](https://mega.nz/#F!EZZD0YwJ!9_PlEQzdMVLaNdKv_ICNVQ!cc4RgQQZ)
2. Remove blacklisted URLs.
```
python blacklist_urls.py <path to the downloaded deduplicated URLs> <filename for clean urls. e.g. clean_urls.txt>
```
3. Download the content from the clean urls with [openwebtext's utilities](https://github.com/eukaryote31/openwebtext/blob/master/download.py).
4. Merge the contents into one loose json file with 1 json per newline of the format `{'text': text, 'url': unique_url}`. It is important for the url to be unique.
# Prepare the data for GPT training:
1. Perform ftfy, english detection and remove documents with less than 128 tokens. This step can be sharded and run on shards.
```
python cleanup_dataset.py <input data file> <output cleaned data filename>
```
Additional cleanup (e.g. remove documents less than 512 characters or dataset specific cleaning like stories, realnews datasets) can be done using `cleanup_fix_dataset.py`. More details can be found by running `python cleanup_fix_dataset.py --help`.
2. Using LSH, find possible duplicates and store them in a file for later processing. The code supports saving and loading fingerprints for recurrent deduplications, and is also multithreaded for faster processing. More details can be found by running `python find_duplicates.py --help`.
```
python find_duplicates.py --inputs <pairlist list of input cleaned data files and keys, e.g. cc.json cc_id news.json news_id> --output <output possible duplicate urls filename>
```
3. Based on the similarity measure defined inside the function `is_similar` (default: 0.9), group urls that are similar. Basically, for each group, we should keep only one url and remove the rest.
```
python group_duplicate_urls.py <possible duplicate urls file> <output file containing similar urls>
```
4. Remove similar documents that were detected in the last step.
```
python remove_group_duplicates.py <file containing similar documents> <cleaned data file> <output file containing deduplicated data>
```
5. Shuffle the dataset.
```
shuf <cleaned deduped data file> -o train_data.json
```
# Deduplicating ngrams
To deduplicate the downstream tasks (e.g. lambada, squad) from the training dataset, we run the following command.
```
python filter_ngrams.py --tasks <name of the task, e.g. lambada, squad> --dedup-dataset <training dataset to deduplicate> <json key> --output <output training dataset>
```
We use 13-grams by default for the deduplication. When we find a 13-gram match in a training document, we split the document into two pieces and remove the 13-gram along with 200 characters from both sides of the 13-gram. We also remove any split document with less than 200 characters, as well as any document that got split more than 10 times. These parameters can be changed using the corresponding arguments.
Only for the lambada task, we need to provide the path, `--lambada-path <path of the lambada test data>`.
Several other features (e.g. save and load dictionary) have been added, look at `python filter_ngrams.py --help` for details.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import argparse
import json
import os
import time
"""
This code adds id to each json object in a json file. User can add prefix
to the ids.
"""
if __name__ == '__main__':
    # Reads a loose json file (one json object per line) and writes it back
    # out with an 'adlr_id' field added to every object. Ids are consecutive
    # numbers starting at 1, zero padded to 10 digits and, when --id-prefix is
    # given, preceded by '<prefix>-'.

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-file', type=str, default=None, help='Input'\
        ' json file where id needs to be added')
    parser.add_argument('--output-file', type=str, default=None, help=\
        'Output file name with id')
    parser.add_argument('--id-prefix', type=str, default=None, help=\
        'Id prefix')
    parser.add_argument('--log-interval', type=int, default=100,
                        help='Log interval')
    args = parser.parse_args()

    # Fail with a usage message rather than a TypeError from open(None).
    if args.input_file is None or args.output_file is None:
        parser.error('both --input-file and --output-file must be given')

    print('Adding ids to dataset ...')

    # The prefix is optional; without one the id is just the padded number.
    id_prefix = '' if args.id_prefix is None else args.id_prefix + '-'

    start_time = time.time()
    # Context managers close both files even when a row fails to parse.
    with open(args.input_file, 'r', encoding='utf-8') as f_input, \
            open(args.output_file, 'wb') as f_output:
        for unique_ids, row in enumerate(f_input, start=1):
            each_row = json.loads(row)
            each_row['adlr_id'] = id_prefix + '{:010d}'.format(unique_ids)
            myjson = json.dumps(each_row, ensure_ascii=False)

            f_output.write(myjson.encode('utf-8'))
            f_output.write('\n'.encode('utf-8'))

            if unique_ids % args.log_interval == 0:
                print('    processed {:9d} documents in {:.2f} seconds ...'.format( \
                    unique_ids, time.time() - start_time), flush=True)

    print('done :-)', flush=True)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import glob
import re
import time
import tldextract
import sys
# Registered domain names (without subdomain or suffix) whose urls are dropped.
domain_blacklist = {
    '500px',
    'aapks',
    'akamaihd',
    'amazon',
    'apple',
    'artifactfire',
    'artstation',
    'awwni',
    'bandcamp',
    'battleforthenet',
    'coinscalendar',
    'dailymotion',
    'deviantart',
    'discord',
    'discordapp',
    'dlapkandroid',
    'dropbox',
    'e621',
    'ebay',
    'edealinfo',
    'erome',
    'eroshare',
    'explosm',
    'facebook',
    'fbcdn',
    'flickr',
    'furaffinity',
    'futhead',
    'gatopardo',
    'gfycat',
    'gifsound',
    'gifsoup',
    'giphy',
    'github',
    'google',
    'gunprime',
    'gyazo',
    'hotdealstar',
    'imagefap',
    'imageshack',
    'imgflip',
    'imgur',
    'instagram',
    'karmadecay',
    'kryptocal',
    'kym-cdn',
    'liveleak',
    'livememe',
    'lmgtfy',
    'magaimg',
    'memegenerator',
    'minorplanetcenter',
    'minus',
    'mobafire',
    'morejpeg',
    'nocookie',
    'pcpartpicker',
    'photobucket',
    'pinimg',
    'pinterest',
    'pixiv',
    'pornhub',
    'prntscr',
    'puu',
    'qkme',
    'quickmeme',
    'radd',
    'redd',
    'reddit',
    'reddit-stream',
    'redditlog',
    'redditmedia',
    'reddituploads',
    'redtube',
    'reupp',
    'reverb',
    'roanoke',
    'rollingstone',
    'sli',
    'soundcloud',
    'soundgasm',
    'spankbang',
    'spotify',
    'strawpoll',
    'streamable',
    'timeanddate',
    'tinypic',
    'touhouradio',
    'tumblr',
    'twimg',
    'twitch',
    'twitter',
    'vid',
    'vimeo',
    'vine',
    'vkaao',
    'vocaroo',
    'voyagefusion',
    'walmart',
    'wciu',
    'wikimedia',
    'wikipedia',
    'xhamster',
    'xkcd',
    'xvideos',
    'youtu',
    'youtube',
    'youtubedoubler',
    'ytimg',
    'zillexplorer',
}
def domain_is_in_blacklist(url):
    """Return True when the domain tldextract finds in `url` is in `domain_blacklist`."""
    return tldextract.extract(url).domain in domain_blacklist
# List of extentions to blacklist.
# Kept as a tuple because it is passed directly to str.endswith().
# NOTE: every entry needs its trailing comma; two adjacent string literals
# are silently concatenated into one (wrong) extension.
extentions_blacklist = (
    '.3gp',
    '.7z',
    '.ai',
    '.aif',
    '.apk',
    '.app',
    '.avi',
    '.bin',
    '.bmp',
    '.bz2',
    '.css',
    '.csv',
    '.dat',
    '.deb',
    '.dmg',
    '.doc',
    '.docx',
    '.exe',
    '.gif',
    '.gifv',
    '.gz',
    '.iso',
    '.jar',
    '.jpeg',
    '.jpg',
    '.js',
    '.log',
    '.mid',
    '.midi',
    '.mkv',
    '.mov',
    '.mp3',
    '.mp4',
    '.mpeg',
    '.mpg',
    '.ogg',
    '.ogv',
    '.otf',
    '.pdf',
    '.pkg',
    '.png',
    '.pps',
    '.ppt',
    '.pptx',
    '.psd',
    '.py',
    '.qt',
    '.ram',
    '.rar',
    '.sql',
    '.svg',
    '.swf',
    '.tar.gz',
    '.tar',
    '.tgz',
    '.tiff',
    '.ttf',
    '.txt',
    '.wav',
    '.webm',
    '.wma',
    '.wmv',
    '.xls',
    '.xlsx',
    '.xml',
    '.xz',
    '.zip',
)
def extention_is_in_blacklist(url):
    """Return True when `url`, minus its query string, ends in a blacklisted extension."""
    # Only the part before the first '?' is checked, case-insensitively.
    without_query = url.split('?')[0]
    return without_query.lower().endswith(extentions_blacklist)
# Malformed urls.
# The pattern is adapted from:
# https://stackoverflow.com/questions/7160737/python-how-to-validate-a-url-in-python-malformed-or-not
url_regex = re.compile(
    r'^(?:http)s?://' # http:// or https://
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
    r'(?::\d+)?' # optional port
    r'(?:/?|[/?]\S+)$', re.IGNORECASE)


def url_is_malformed(url):
    """Return True when `url` does not match `url_regex`."""
    return url_regex.match(url) is None
def print_progress(prefix, start_time, urls_counter,
                   domain_blacklist_counter,
                   extention_blacklist_counter,
                   short_url_counter, malformed_url_counter,
                   duplicate_url_counter):
    """Print a single ' | '-separated status line with the url filtering counters.

    The line starts with `prefix`, followed by the seconds elapsed since
    `start_time` and then each counter with its label.
    """
    fields = (
        prefix,
        'time elapsed (s): {:.2f}'.format(time.time() - start_time),
        'number of urls: {}'.format(urls_counter),
        'domain blacklisted: {}'.format(domain_blacklist_counter),
        'extention blacklisted: {}'.format(extention_blacklist_counter),
        'short urls (<=8): {}'.format(short_url_counter),
        'malformed urls: {}'.format(malformed_url_counter),
        'duplicate urls: {}'.format(duplicate_url_counter),
    )
    print(' | '.join(fields), flush=True)
if __name__ == '__main__':
    # Reads every *.txt url list in a directory, drops blacklisted, short,
    # malformed and duplicate urls, and writes the surviving urls to a file.

    print('remove blacklisted urls ..')

    # Positional arguments: directory with the url files, then the output file.
    path = sys.argv[1]
    output = sys.argv[2]

    # Collect the url files.
    files = glob.glob(path + '/*.txt')
    print('> found {} files'.format(len(files)))

    urls = set()
    urls_counter = 0
    domain_blacklist_counter = 0
    extention_blacklist_counter = 0
    short_url_counter = 0
    malformed_url_counter = 0
    duplicate_url_counter = 0
    start_time = time.time()

    for filename in files:
        with open(filename, 'r') as f:
            for url in map(str.strip, f):
                urls_counter += 1

                # The first rule that matches decides why a url is rejected.
                if domain_is_in_blacklist(url):
                    print('[DOMAIN BLACKLIST]: {}'.format(url), flush=True)
                    domain_blacklist_counter += 1
                elif extention_is_in_blacklist(url):
                    print('[EXTENTION BLACKLIST]: {}'.format(url), flush=True)
                    extention_blacklist_counter += 1
                elif len(url) <= 8:
                    print('[SHORT URL]: {}'.format(url), flush=True)
                    short_url_counter += 1
                elif url_is_malformed(url):
                    print('[MALFORMED URL]: {}'.format(url), flush=True)
                    malformed_url_counter += 1
                elif url in urls:
                    print('[DUPLICATE URL]: {}'.format(url), flush=True)
                    duplicate_url_counter += 1
                else:
                    urls.add(url)

                # Periodic status line, every 100000 urls read.
                if urls_counter % 100000 == 0:
                    print_progress('PROGRESS', start_time, urls_counter,
                                   domain_blacklist_counter,
                                   extention_blacklist_counter,
                                   short_url_counter, malformed_url_counter,
                                   duplicate_url_counter)

    print_progress('FINAL', start_time, urls_counter,
                   domain_blacklist_counter,
                   extention_blacklist_counter,
                   short_url_counter, malformed_url_counter,
                   duplicate_url_counter)

    # Write the surviving urls, one per line.
    print('> writing cleaned up url list to {}'.format(output))
    with open(output, 'w') as f:
        f.writelines(kept_url + '\n' for kept_url in urls)

    print('done :-)')
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import ftfy
import json
from langdetect import detect
import numpy as np
import time
import os
import sys
from tokenizer import Tokenizer
# Minimum number of tokens a document must have to be kept: filter_corpus
# skips documents whose tokenized length is below this value.
MIN_DOCUMENT_LENGHT = 128
def print_progress(prefix, start_time, num_docs, num_fixed_text,
                   num_non_english_docs, chars_non_english_docs,
                   num_small_docs, chars_small_docs):
    """Print a single ' | '-separated status line with the corpus filtering counters.

    The line starts with `prefix`, followed by the seconds elapsed since
    `start_time` and then each counter with its label.
    """
    fields = (
        prefix,
        'elapsed time: {:.2f}'.format(time.time() - start_time),
        'documents: {}'.format(num_docs),
        'fixed text: {}'.format(num_fixed_text),
        'non-english: {}'.format(num_non_english_docs),
        'non-english chars: {}'.format(chars_non_english_docs),
        'small docs: {}'.format(num_small_docs),
        'small docs chars: {}'.format(chars_small_docs),
    )
    print(' | '.join(fields), flush=True)
def filter_corpus(filename, out_filename, print_interval=10000):
    """Filter a loose json corpus and write the surviving documents as UTF-8.

    Every line of `filename` is parsed as a json object with a 'text' field.
    The text is repaired with ftfy; documents that are not detected as
    English, or that tokenize to fewer than MIN_DOCUMENT_LENGHT tokens, are
    skipped. All other documents are written, one json object per line, to
    `out_filename`.

    Processing is best effort: a line that raises is reported and skipped.

    Args:
        filename: path of the input loose json file (read as UTF-8).
        out_filename: path of the output file.
        print_interval: a progress line is printed when a document is written
            and the number of documents read is a multiple of this value.
    """
    print(' > filtering {}'.format(filename))

    tokenizer = Tokenizer(cache_dir='./cache')

    num_docs = 0
    num_small_docs = 0
    num_fixed_text = 0
    num_non_english_docs = 0
    chars_non_english_docs = 0
    chars_small_docs = 0
    start_time = time.time()
    # The input is decoded explicitly as UTF-8 to match the UTF-8 output,
    # instead of relying on the locale-dependent default encoding.
    with open(out_filename, 'wb') as f, \
            open(filename, 'r', encoding='utf-8') as fin:
        for line in fin:
            try:
                num_docs += 1
                myjson = json.loads(line)
                # Fix text
                text = ftfy.fix_text(myjson['text'])
                if text != myjson['text']:
                    num_fixed_text += 1
                myjson['text'] = text
                # Detect language.
                if detect(text) != 'en':
                    print('[non-english text]', myjson)
                    num_non_english_docs += 1
                    chars_non_english_docs += len(text)
                    continue
                # On average each token is 5 characters so 8 is an
                # upper bound. Only documents short enough to possibly fall
                # below the token limit are actually tokenized.
                if len(text) < (8 * MIN_DOCUMENT_LENGHT):
                    tokens = tokenizer.tokenize_document(text)
                    if len(tokens) < MIN_DOCUMENT_LENGHT:
                        print('[small document, skipping]:', myjson)
                        num_small_docs += 1
                        chars_small_docs += len(text)
                        continue
                myjson = json.dumps(myjson, ensure_ascii=False)
                f.write(myjson.encode('utf-8'))
                f.write('\n'.encode('utf-8'))
                if num_docs % print_interval == 0:
                    print_progress('[PROGRESS]', start_time, num_docs,
                                   num_fixed_text, num_non_english_docs,
                                   chars_non_english_docs,
                                   num_small_docs, chars_small_docs)
            except Exception as e:
                # Deliberate best effort: report the bad line and carry on.
                print('    skipping ', line, e)

    print_progress('[FINAL]', start_time, num_docs,
                   num_fixed_text, num_non_english_docs,
                   chars_non_english_docs,
                   num_small_docs, chars_small_docs)
if __name__ == '__main__':
    # Usage: <input loose json file> <output filtered file>

    print('building gpt2 dataset ...')

    input_filename, output_filename = sys.argv[1], sys.argv[2]

    for message, name in (('will be reading {}', input_filename),
                          ('and will write the results to {}', output_filename)):
        print(message.format(name))

    filter_corpus(input_filename, output_filename)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""
Filter and clean documents:
Capable to clean docs with less than 512 characters, less than
256 characters and contains javascript, fix text and dataset specific
cleaning like stories and realnews datasets.
Program arguments have the details.
"""
import argparse
from functools import partial
import glob
import ftfy
import json
from langdetect import detect
import multiprocessing
import os
from pathlib import Path
import re
import time
def process_doc(json_line, args):
    """Apply the first applicable cleanup task in `args.tasks` to one json document.

    Returns a tuple ``(output, text, document, to_filter)``:
      * ``output``: dict of task name -> bool, True for the task that fired;
      * ``text``: the (possibly fixed or cleaned) document text;
      * ``document``: the parsed json object, left unmodified;
      * ``to_filter``: True when the document should be removed.

    Tasks are checked in a fixed order and the first one that applies
    returns immediately. Any exception raised while checking marks the
    document for removal.
    """
    document = json.loads(json_line)
    text = document['text']

    output = dict.fromkeys(
        ('remove_512', 'remove_256_javascript', 'remove_512_non_english',
         'ftfy_fix_text', 'general_cleaning'),
        False)

    try:
        tasks = args.tasks

        # Remove all docs with less than 512 characters.
        if "remove_512" in tasks and len(text) < 512:
            output['remove_512'] = True
            return output, text, document, True

        # Remove docs shorter than 256 characters that mention javascript.
        if "remove_256_javascript" in tasks and len(text) < 256 \
                and 'javascript' in text.lower():
            output['remove_256_javascript'] = True
            return output, text, document, True

        # Remove docs shorter than 512 characters that are not english.
        if "remove_512_non_english" in tasks and len(text) < 512 \
                and detect(text) != 'en':
            output['remove_512_non_english'] = True
            return output, text, document, True

        # Fix the text using ftfy; the document is kept, hence False.
        if "ftfy_fix_text" in tasks:
            fixed_text = ftfy.fix_text(text)
            output['ftfy_fix_text'] = True
            return output, fixed_text, document, False

        # Collapse runs of spaces and newlines that follow a word into one space.
        if "general_cleaning" in tasks:
            cleaned_text = re.sub(r"  +|\b\n+ |\b\n+", " ", text)
            # Dataset specific variants used in the past:
            #   Gutenberg: re.sub(r"\n\n+", "\n\n", text)
            #   realnews:  re.sub(r"\n", "\n\n", text)
            #   stories:   remove the space before ' ! . ? and around - and @
            output['general_cleaning'] = True
            return output, cleaned_text, document, False

    except Exception as e:
        print('Error: *************************\n{}\ntext: {}'.format(e, \
            text), flush=True)
        return output, text, document, True

    # No task applied: keep the document as it is.
    return output, text, document, False
def process_set(args, input_file, output_f_cleaned, output_f_filtered):
    """Run `process_doc` over every line of `input_file` using a worker pool.

    Documents that `process_doc` marks for removal are written to
    `output_f_filtered`; all others are written, with their processed text,
    to `output_f_cleaned`. Both outputs are UTF-8 loose json, one object per
    line. Per-task counts are printed once the input is exhausted.

    Args:
        args: parsed arguments; `tasks` is consumed by `process_doc` and
            `log_interval` controls how often progress is printed.
        input_file: path of the UTF-8 loose json input file.
        output_f_cleaned: path for the documents that are kept.
        output_f_filtered: path for the documents that are removed.
    """
    print(' > working on {} ...'.format(input_file), flush=True)

    num_docs = num_remove_512 = num_remove_java = num_remove_512_non_english \
        = num_ftfy_fix_text = num_general_cleaning = 0

    start_time = time.time()

    # Setup multi-processing.
    num_workers = 40
    process_doc_partial = partial(process_doc, args=args)

    # The context managers close the files and shut the pool down even when
    # processing fails; previously the worker processes were never released,
    # so every input file leaked a full pool.
    with open(output_f_cleaned, 'wb') as output_cleaned, \
            open(output_f_filtered, 'wb') as output_filtered, \
            open(input_file, 'r', encoding='utf-8') as fin, \
            multiprocessing.Pool(num_workers) as pool:

        # Lines are handed to the workers in chunks of 500; results come
        # back in input order.
        processed_docs = pool.imap(process_doc_partial, fin, 500)

        # Process documents.
        for output, text, document, to_filter in processed_docs:
            num_docs += 1

            num_remove_512 += 1 if output['remove_512'] else 0
            num_remove_java += 1 if output['remove_256_javascript'] else 0
            num_remove_512_non_english += 1 if output['remove_512_non_english'] \
                else 0
            num_ftfy_fix_text += 1 if output['ftfy_fix_text'] else 0
            num_general_cleaning += 1 if output['general_cleaning'] else 0

            document['text'] = text
            myjson = json.dumps(document, ensure_ascii=False)

            destination = output_filtered if to_filter else output_cleaned
            destination.write(myjson.encode('utf-8'))
            destination.write('\n'.encode('utf-8'))

            if num_docs % args.log_interval == 0:
                print('    processed {:9d} documents in {:.2f} seconds ...'.format(
                    num_docs, time.time() - start_time), flush=True)

    # Print stats.
    print('  >> total docs: {} remove_512 {} remove_256_javascript {} '\
        'remove_512_non_english {} ftfy_fix_text {} general_cleaning {}'.\
        format(num_docs, num_remove_512, num_remove_java,\
        num_remove_512_non_english, num_ftfy_fix_text, \
        num_general_cleaning), flush=True)
if __name__ == '__main__':
    # For every input file <name><ext>, writes <name>_cleaned<ext> and
    # <name>_filtered<ext> into the output directory.

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-files', nargs = '*', required=True, default=\
                        None, help = 'Input json files that needs to be'\
                        ' cleaned')
    parser.add_argument('--tasks', nargs = '*', required=True, default=None,\
                        help = 'Tasks to perform on the input files, ' \
                        'such as remove_512, remove_256_javascript, ' \
                        'remove_512_non_english, ftfy_fix_text, and ' \
                        'general_cleaning. 256 or 512 means the number' \
                        ' of characters.')
    parser.add_argument('--output-path', type=str, default=None,
                        help='Directory where the output should go')
    parser.add_argument('--log-interval', type=int, default=100,
                        help='Log interval')

    args = parser.parse_args()

    # Without an output directory os.path.join() below fails with a
    # TypeError; report the missing option as a usage error instead.
    if args.output_path is None:
        parser.error('--output-path must be given')

    print('cleanup dataset ...')

    for input_file in args.input_files:
        input_filename, input_filename_ext = os.path.splitext(Path(input_file)\
            .name)

        output_f_cleaned = os.path.join(args.output_path, input_filename + \
            "_cleaned" + input_filename_ext)
        output_f_filtered = os.path.join(args.output_path, input_filename + \
            "_filtered" + input_filename_ext)

        process_set(args, input_file, output_f_cleaned, output_f_filtered)

    print('done :-)', flush=True)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""
Deduplicate downstream tasks from training dataset. 13-grams have been used.
All split documents with less than 200 characters got filtered. Any document
with more than 10 splits got filtered as well.
"""
import argparse
from functools import partial
import json
import multiprocessing
import nltk
import pickle
import re
import string
import sys
import time
def get_words(text):
    """Return the lowercased words of ``text`` and where each one starts.

    A word is a maximal run of regex word characters in the lowercased
    text. The result is a pair of parallel lists: the words, and the
    character offset at which each word begins.
    """
    matches = list(re.finditer(r'\w+', text.lower()))
    words = [found.group(0) for found in matches]
    positions = [found.start() for found in matches]
    return words, positions
# splits the text
def split_text(text, start_position, remove_char_each_side, seq):
    """Cut a matched sequence, plus a margin, out of ``text``.

    Starting ``remove_char_each_side`` characters before the match, scan
    left to the nearest '.', '!' or '?'; everything up to and including it
    is the first part. Starting the same margin after the end of ``seq``,
    scan right to the next such character; everything after it is the
    second part. Either part is "" when no boundary is found.

    Returns:
        The tuple ``(text_first, text_second)``.
    """
    sentence_enders = ".!?"

    # Left side: walk backwards until a sentence end or the text start.
    left = start_position - remove_char_each_side
    while left > 0 and text[left] not in sentence_enders:
        left -= 1
    head = text[:left + 1] if left > 0 else ""

    # Right side: walk forwards until a sentence end or the text end.
    right = start_position + len(seq) + remove_char_each_side
    while right < len(text) and text[right] not in sentence_enders:
        right += 1
    tail = text[right + 1:] if right + 1 < len(text) else ""

    return head, tail
def check_and_clean_text(args, words, ngrams, text, start_position,
                         text_buf_ngram_free, text_buf, local_ngram):
    """Check one word sequence against ``ngrams`` and react to a match.

    Returns True when the space-joined ``words`` are not in ``ngrams``
    (the sequence is ngram free). On a match it returns False after either:

    * ``args.get_ngram_freq_only`` set: bumping the sequence's count in
      ``local_ngram`` and queueing the text after the match on ``text_buf``;
    * otherwise: splitting ``text`` around the match with ``split_text``,
      appending the leading part to ``text_buf_ngram_free`` and queueing the
      trailing part on ``text_buf``, each only if it is longer than
      ``args.filter_text_char_len``.
    """
    seq = " ".join(words)

    # Guard clause: unknown sequences need no further work.
    if seq not in ngrams:
        return True

    print(" [matched]: {}".format(seq), flush=True)

    if args.get_ngram_freq_only:
        # Count the hit and only keep looking in the text after it.
        local_ngram[seq] = local_ngram.get(seq, 0) + 1
        resume_at = start_position + len(seq) + 1
        if resume_at < len(text):
            text_buf.append(text[resume_at:])
        return False

    before, after = split_text(text, start_position,
                               args.remove_char_each_side, seq)
    # The leading part has already been scanned, so it is ngram free.
    if len(before) > args.filter_text_char_len:
        text_buf_ngram_free.append(before)
    # The trailing part still has to be scanned.
    if len(after) > args.filter_text_char_len:
        text_buf.append(after)

    return False
def free_ngram(line, args, key, ngrams, ngrams_freq_sorted):
    """Remove (or only count) the known ngrams found in one json line.

    The text stored under ``key`` is put on a work buffer. Each buffered
    text is scanned with a sliding window of ``args.max_ngram_size`` words
    and, at every window start, with every ngram length listed in
    ``ngrams_freq_sorted``. ``check_and_clean_text`` handles a match and may
    push the remaining text back on the buffer for another pass.

    Args:
        line: One json-encoded document.
        args: Namespace providing ``max_ngram_size`` and
            ``get_ngram_freq_only`` (plus whatever
            ``check_and_clean_text`` reads).
        key: Name of the json field holding the text.
        ngrams: Dictionary keyed by space-joined word sequences.
        ngrams_freq_sorted: Iterable of ``(ngram_length, count)`` pairs;
            only the lengths are used here.

    Returns:
        ``(text_buf_ngram_free, trimmed, myjson, local_ngram)``: the ngram
        free text pieces, 1 if exactly one piece shorter than the original
        text was kept (else 0), the parsed json document, and the per-line
        match counts. If the line cannot be parsed, ``myjson`` is ``{}``
        and no text is returned.
    """
    # Default value so the return below is valid even when json.loads
    # fails; without it an unparsable line raised UnboundLocalError.
    myjson = {}
    try:
        myjson = json.loads(line)
        text_buf = [myjson[key]]
    except Exception as e:
        print("Error: {}".format(e), flush=True)
        text_buf = []

    text_buf_ngram_free = []
    local_ngram = {}
    while len(text_buf) > 0:

        # get the first one from the buffer
        text = text_buf.pop(0)
        words, positions = get_words(text)

        ngram_free = True
        # find each max n-grams and check dictionary
        for i in range(len(words) - args.max_ngram_size + 1):
            check_ngram_free = check_and_clean_text(
                args, words[i:i + args.max_ngram_size], ngrams, text,
                positions[i], text_buf_ngram_free, text_buf, local_ngram)

            # a match was found and handled: stop scanning this text
            if not check_ngram_free:
                ngram_free = False
                break

            # if max ngrams doesn't match, check if any other lower n-grams
            # within max ngram matches
            for ngram_len, _ in ngrams_freq_sorted:
                check_ngram_free = check_and_clean_text(
                    args, words[i:i + ngram_len], ngrams, text,
                    positions[i], text_buf_ngram_free, text_buf,
                    local_ngram)

                # same check as above
                if not check_ngram_free:
                    ngram_free = False
                    break

            # propagate the break out of the lower-ngram loop above
            if not ngram_free:
                break

        # for the last max n-gram, check all the lower ngrams in it: the
        # loop above only tested windows starting at each position, so
        # shorter ngrams starting inside the final window are still open
        if ngram_free and len(words) - args.max_ngram_size > 0:
            # get the words of the last max ngram
            last_seq_words = words[(len(words) - args.max_ngram_size):]
            last_seq_start_position = len(words) - args.max_ngram_size

            # check all n-grams lower than the max
            for ngram_len, _ in ngrams_freq_sorted:

                # ignore the max ngram as has been considered already
                if ngram_len == args.max_ngram_size:
                    continue

                # find each ngram of ngram_len in max n-grams and check
                for i in range(len(last_seq_words) - ngram_len + 1):
                    check_ngram_free = check_and_clean_text(
                        args, last_seq_words[i:i + ngram_len], ngrams, text,
                        positions[last_seq_start_position + i],
                        text_buf_ngram_free, text_buf, local_ngram)

                    if not check_ngram_free:
                        ngram_free = False
                        break

                if not ngram_free:
                    break

        # texts are ngram free
        if ngram_free and not args.get_ngram_freq_only:
            text_buf_ngram_free.append(text)

    # check if the text has only been trimmed
    trimmed = 0
    if not args.get_ngram_freq_only and len(text_buf_ngram_free) == 1 and \
            len(text_buf_ngram_free[0]) < len(myjson[key]):
        trimmed = 1

    return text_buf_ngram_free, trimmed, myjson, local_ngram
# insert word sequence into dictionary
def insert_dict(words, ngrams, pos):
    """Register the space-joined ``words`` in ``ngrams`` with a count of 0.

    An entry that already exists keeps its current value. ``pos`` is
    accepted but not used.
    """
    ngrams.setdefault(" ".join(words), 0)
# insert each ngram from text into the ngrams dictionary
def compute_ngrams_insert_dict(args, text, ngrams):
    """Add the ngrams of ``text`` to the ``ngrams`` dictionary.

    Texts with fewer than ``args.min_ngram_size`` words are ignored. A text
    shorter than ``args.max_ngram_size`` words is inserted whole; otherwise
    every window of ``args.max_ngram_size`` consecutive words is inserted.
    """
    words, positions = get_words(text)
    num_words = len(words)
    window = args.max_ngram_size

    if num_words < args.min_ngram_size:
        return

    if num_words < window:
        insert_dict(words, ngrams, positions[0])

    # Empty range when the text is shorter than one full window.
    for start in range(num_words - window + 1):
        insert_dict(words[start:start + window], ngrams, positions[start])
# Build ngrams for the lambada dataset
def process_task_lambda(args, task_file, ngrams):
    """Read a json-lines file and add the ngrams of each 'text' field.

    A line that fails to parse or to process is reported with an
    'Error:' message and skipped.
    """
    print(' reading from {} and computing ngrams'.format(task_file))
    with open(task_file, 'r') as task_lines:
        for raw_line in task_lines:
            try:
                text = json.loads(raw_line)['text']
                compute_ngrams_insert_dict(args, text, ngrams)
            except Exception as err:
                print('Error:', err)
    print(" Entities in ngrams {}".format(len(ngrams)), flush=True)
# Build ngrams for the dataset of the given task
def process_task(args, task_name, ngrams):
    """Load the evaluation split of ``task_name`` and add its ngrams.

    The dataset is fetched with ``datasets.load_dataset``. The question
    text of every record (the 'goal' field for piqa) is passed to
    ``compute_ngrams_insert_dict``. An unknown task name is reported and
    nothing is added; a record that fails is reported and skipped.
    """
    print(' reading from {} and computing ngrams'.format('import datasets'))
    print(" Current entities in ngrams {}".format(len(ngrams)), flush=True)
    # using validation/test data from datasets
    from datasets import load_dataset

    entities_before = len(ngrams)

    # task name -> (positional arguments, split) handed to load_dataset
    dataset_specs = {
        'squad': (('squad_v2',), 'validation'),
        'natural_questions': (('natural_questions',), 'validation'),
        'triviaqa': (('trivia_qa', 'unfiltered'), 'test'),
        'webqa': (('web_questions',), 'test'),
        'race': (('race', 'all'), 'test'),
        'drop': (('drop',), 'validation'),
        'coqa': (('coqa',), 'validation'),
        'piqa': (('piqa',), 'test'),
    }
    if task_name not in dataset_specs:
        print("Invalid task name: {}".format(task_name), flush=True)
        return

    dataset_args, split = dataset_specs[task_name]
    dataset = load_dataset(*dataset_args, split=split)

    # tasks whose records carry the question directly under 'question'
    plain_question_tasks = ('squad', 'triviaqa', 'webqa', 'race', 'drop')

    # read the dataset and add to ngrams
    for record in dataset:
        try:
            if task_name in plain_question_tasks:
                compute_ngrams_insert_dict(args, record['question'], ngrams)
            elif task_name == 'natural_questions':
                compute_ngrams_insert_dict(
                    args, record['question']['text'], ngrams)
            elif task_name == 'coqa':
                for question in record['questions']:
                    compute_ngrams_insert_dict(args, question, ngrams)
            elif task_name == 'piqa':
                compute_ngrams_insert_dict(args, record['goal'], ngrams)
        except Exception as err:
            print('Error:', err)

    print(" After task {} entities in ngrams {}, added {}".format(
        task_name, len(ngrams), len(ngrams) - entities_before), flush=True)
def compute_tasks_ngrams(args, ngrams):
    """Fill ``ngrams`` with the ngrams of every task in ``args.tasks``.

    The 'lambada' task is read from the local file ``args.lambada_path``
    (which must be set); every other task goes through ``process_task``.
    The elapsed time is printed at the end.
    """
    began = time.time()
    for task_name in args.tasks:
        print('Task: {}'.format(task_name), flush=True)
        if task_name != 'lambada':
            process_task(args, task_name, ngrams)
            continue
        assert args.lambada_path is not None
        process_task_lambda(args, args.lambada_path, ngrams)

    elapsed = time.time() - began
    print(" Taken time to compute ngrams {:.2f}".format(elapsed), flush=True)
def compute_ngram_freq_sorted(args, ngrams):
    """Count the dictionary's ngrams by their length in words.

    Args:
        args: Unused; kept for the callers' signature.
        ngrams: Dictionary keyed by whitespace-separated word sequences.

    Returns:
        A list of ``(ngram_length, number_of_ngrams)`` pairs sorted by
        length; an empty list for an empty dictionary.
    """
    ngrams_freq = {}
    for ngram_key in ngrams:
        length = len(ngram_key.split())
        ngrams_freq[length] = ngrams_freq.get(length, 0) + 1

    ngrams_freq_sorted = sorted(ngrams_freq.items(), key=lambda item: item[0])
    print(" Ngram frequencies: {}".format(ngrams_freq_sorted), flush=True)

    # An empty dictionary has no smallest or largest ngram; report None
    # instead of failing with IndexError on ngrams_freq_sorted[0].
    if ngrams_freq_sorted:
        min_size = ngrams_freq_sorted[0][0]
        max_size = ngrams_freq_sorted[-1][0]
    else:
        min_size = max_size = None
    print(" Entities in ngrams {} min_ngram_size {} max_ngram_size {}".format(
        len(ngrams), min_size, max_size), flush=True)
    return ngrams_freq_sorted
def get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold,
                               dedup_file, dedup_key, ngrams_freq_sorted):
    """Count in how many documents each ngram occurs and keep the rare ones.

    Every line of ``dedup_file`` is run through ``free_ngram`` in
    frequency-only mode on a pool of ``args.num_threads`` workers. Each
    ngram reported for a document increments its entry in ``ngrams`` by one.
    Afterwards every ngram whose count is below ``args.key_threshold`` is
    stored in ``ngrams_below_threshold`` with the value 1.

    Side effects: sets ``args.get_ngram_freq_only`` to True and updates
    ``ngrams`` and ``ngrams_below_threshold`` in place.
    """
    start_time = time.time()
    # get the ngrams frequency
    args.get_ngram_freq_only = True

    free_ngram_abt_partial = partial(
        free_ngram, args=args, key=dedup_key, ngrams=ngrams,
        ngrams_freq_sorted=ngrams_freq_sorted)

    # Open the large file to process in parallel. The context managers
    # release the file and the workers even if processing fails; all imap
    # results are consumed before the pool is shut down.
    counter = 0
    with open(dedup_file, 'r', encoding='utf-8') as fin, \
            multiprocessing.Pool(args.num_threads) as pool:
        free_ngrams_abt = pool.imap(free_ngram_abt_partial, fin, 500)
        for _, _, _, local_ngram in free_ngrams_abt:
            counter += 1
            if counter % 1000 == 0:
                print(' [compute_stat]> processed {} documents in {:.2f} seconds ...'.
                      format(counter, time.time() - start_time), flush=True)
            # one increment per document, however often the ngram matched
            for local_key in local_ngram:
                if local_key in ngrams:
                    ngrams[local_key] += 1

    print(' Time taken to compute statistics {:.2f} seconds'.format(
        time.time() - start_time), flush=True)

    counter_threshold = 0
    # Get ngrams below threshold
    for local_key, local_val in ngrams.items():
        if local_val < args.key_threshold:
            print(" [threshold] {} {}".format(local_key, local_val), flush=True)
            counter_threshold += 1
            ngrams_below_threshold[local_key] = 1

    print(' Ngrams below threshold {}'.format(counter_threshold), flush=True)
def clean_ngrams_below_threshold(args, ngrams_below_threshold, dedup_file,
                                 dedup_key):
    """Write ``dedup_file`` to ``args.output`` with the given ngrams removed.

    Every line is run through ``free_ngram`` on a pool of
    ``args.num_threads`` workers. Each ngram free piece of a document is
    written as its own json line: the original document with ``dedup_key``
    replaced by the piece and a "split_id" built from the task names, the
    document counter and the piece index (prefixed by an existing
    "split_id" if the document had one). Documents split into more than
    ``args.splits_count`` pieces are dropped.

    Side effects: sets ``args.get_ngram_freq_only`` to False and prints
    progress and summary statistics.
    """
    start_time = time.time()
    # Now actually filter the dataset
    args.get_ngram_freq_only = False
    id_prefix = '-'.join(args.tasks[::1])

    # get the range of the size of the ngrams
    ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams_below_threshold)

    counter = splitted = ignored = split_mt_thld = trimmed_count = 0
    free_ngram_clean_partial = partial(
        free_ngram, args=args, key=dedup_key, ngrams=ngrams_below_threshold,
        ngrams_freq_sorted=ngrams_freq_sorted)

    # Open the large file to process in parallel. The context managers
    # release both files and the workers even if processing fails; all imap
    # results are consumed before the pool is shut down.
    with open(dedup_file, 'r', encoding='utf-8') as fin, \
            open(args.output, 'wb') as out_f, \
            multiprocessing.Pool(args.num_threads) as pool:
        free_ngrams_clean = pool.imap(free_ngram_clean_partial, fin, 500)

        for text_buf_ngram_free, trimmed, myjson, _ in free_ngrams_clean:
            counter += 1
            try:
                trimmed_count += trimmed

                if len(text_buf_ngram_free) > 1:
                    splitted += 1
                if len(text_buf_ngram_free) == 0:
                    ignored += 1
                # documents with too many splits are dropped entirely
                if len(text_buf_ngram_free) > args.splits_count:
                    text_buf_ngram_free = []
                    split_mt_thld += 1

                if args.output is not None:
                    # read the incoming split_id once, before it is
                    # overwritten for the first piece below
                    if "split_id" in myjson:
                        use_prefix = myjson["split_id"] + "-"
                    else:
                        use_prefix = ""

                    for i, piece in enumerate(text_buf_ngram_free):
                        split_id_string = id_prefix + \
                            '-{:010d}'.format(int(counter)) + \
                            '-{:04d}'.format(int(i))
                        myjson[dedup_key] = piece
                        myjson["split_id"] = use_prefix + split_id_string
                        outjson = json.dumps(myjson, ensure_ascii=False)
                        out_f.write(outjson.encode('utf-8'))
                        out_f.write('\n'.encode('utf-8'))

                if counter % 1000 == 0:
                    print(' [final]> processed {} documents in {:.2f} seconds ...'.
                          format(counter, time.time() - start_time), flush=True)
            except Exception as e:
                # best effort: report the document and keep going
                print('Error:', e)

        print(' [final]> processed {} documents in {:.2f} seconds ...'.
              format(counter, time.time() - start_time), flush=True)

        print(' Total docs {} splitted {} ignored {} splits > theshold {} trimmed'
              ' {}'.format(counter, splitted, ignored, split_mt_thld,
                           trimmed_count), flush=True)
if __name__ == '__main__':
    # Command-line entry point. Either builds the dictionary of task ngrams
    # that are rare in the dataset (optionally saving it) or loads a saved
    # one, then filters the dataset if --output is given.

    # we use 13-grams, any text less than 200 characters got removed
    # any text splitted more than 10 got removed as well

    print('parsing the arguments ...')

    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks', nargs = '*', required=True, default=None, \
                        help = 'Tasks to use for deduplication: currently '
                        ' suuport [lambada, squad, natural_questions,'
                        ' triviaqa, webqa, race, drop, coqa, and piqa]')
    parser.add_argument('--lambada-path', type=str, default=None,
                        help='Only Lambada task needs the path')
    parser.add_argument('--dedup-dataset', nargs = '*', default=None,
                        help='Dataset to deduplicate with the key to use'
                        ' e.g. cc.json text')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name to save dedup dataset')
    parser.add_argument('--num-threads', type=int, default=40,
                        help='Number of threads to use')
    # Default dedup values
    parser.add_argument('--max-ngram-size', type=int, default=13,
                        help='Maximum size of ngram to use.')
    parser.add_argument('--min-ngram-size', type=int, default=8,
                        help='Minimum size of ngram to use.')
    parser.add_argument('--filter-text-char-len', type=int, default=200,
                        help='Remove any text below this length.')
    parser.add_argument('--key-threshold', type=int, default=10,
                        help='Number of keys to consider as threshold')
    parser.add_argument('--save-dictionary', type=str, default=None,
                        help='Save the dictionary')
    parser.add_argument('--load-dictionary', type=str, default=None,
                        help='Load the dictionary')
    parser.add_argument('--splits-count', type=int, default=10,
                        help='Remove any documents more than this many splits')
    # The help text used to be a copy of the --max-ngram-size one.
    parser.add_argument('--remove-char-each-side', type=int, default=200,
                        help='Number of characters to skip on each side of a'
                        ' matched ngram before cutting the text at the'
                        ' nearest sentence boundary.')

    args = parser.parse_args()

    # --dedup-dataset defaults to None, so check for that explicitly rather
    # than failing with a TypeError inside an assert (which is also
    # stripped under -O).
    if args.dedup_dataset is None or len(args.dedup_dataset) != 2:
        parser.error('--dedup-dataset requires exactly two values: the '
                     'dataset file and the key to use, e.g. cc.json text')
    dedup_file = args.dedup_dataset[0]
    dedup_key = args.dedup_dataset[1]

    # Setup multi-processing
    num_workers = args.num_threads
    if args.load_dictionary is None:

        # Build ngrams
        ngrams = {}
        compute_tasks_ngrams(args, ngrams)

        # get the range of the size of the ngrams
        ngrams_freq_sorted = compute_ngram_freq_sorted(args, ngrams)

        # get ngram freq from large file in parallel
        # get ngrams below threshold
        ngrams_below_threshold = {}
        get_ngrams_below_threshold(args, ngrams, ngrams_below_threshold, \
            dedup_file, dedup_key, ngrams_freq_sorted)

        # save the dictionary if needed
        if args.save_dictionary is not None:
            with open(args.save_dictionary, 'wb') as save_dict_handle:
                pickle.dump(ngrams_below_threshold, save_dict_handle)
    else:
        # NOTE(review): pickle.load can execute arbitrary code; only pass
        # dictionary files from a trusted source to --load-dictionary.
        with open(args.load_dictionary, 'rb') as load_dict_handle:
            ngrams_below_threshold = pickle.load(load_dict_handle)

    # filter the large file
    if args.output is not None:
        clean_ngrams_below_threshold(args, ngrams_below_threshold, \
            dedup_file, dedup_key)

    print('done :-)')
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import argparse
from functools import partial
import itertools
import json
from lsh import cache, minhash
import multiprocessing
import numpy as np
import time
import pickle
import sys
import os
# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def shingles(text, char_ngram=5):
    """Return the set of ``char_ngram``-character substrings of ``text``.

    Window starts run from 0 up to, but not including,
    ``len(text) - char_ngram``, so the window ending exactly at the end of
    the text is not part of the result and a text of at most
    ``char_ngram`` characters yields an empty set.
    """
    num_starts = len(text) - char_ngram
    return {text[start:start + char_ngram] for start in range(num_starts)}
# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def jaccard(set_a, set_b, args):
    """Return the overlap of two sets as a fraction.

    The size of the intersection is divided by a denominator chosen by
    ``args.jaccard``: 'min' uses the smaller set's size, 'max' the larger
    set's size, anything else the size of the union. The result is 0.0 if
    either set is empty.
    """
    if not set_a or not set_b:
        return 0.0

    shared = len(set_a & set_b)
    mode = args.jaccard
    if mode == 'min':
        denominator = min(len(set_a), len(set_b))
    elif mode == 'max':
        denominator = max(len(set_a), len(set_b))
    else:
        denominator = len(set_a | set_b)
    return shared / denominator
def compute_fingerprint(line, key):
    """Parse one json line and fingerprint its 'text' field.

    Uses the module-level ``hasher``. Returns
    ``(id, text, fingerprint, True)`` where the id is the value stored
    under ``key``, or ``(None, None, None, False)`` after printing the
    error if anything goes wrong.
    """
    try:
        record = json.loads(line)
        url, text = record[key], record['text']
        fingerprint = hasher.fingerprint(text)
    except Exception as err:
        print('Error:', err)
        return None, None, None, False

    return url, text, fingerprint, True
def url_pairs_to_remove(args, bucket_urls, url_doc):
    """Find near-duplicate documents inside one LSH bucket.

    Repeatedly picks a random "main" url from ``bucket_urls``, compares its
    shingles with those of every other url still in the bucket, and removes
    from the bucket the main url plus every url whose similarity exceeds
    0.5. Stops when at most one url is left or after
    ``args.heuristic_iter`` rounds (-1 means no round limit).

    ``bucket_urls`` is modified in place.

    Returns:
        ``(remove_urls_list, deduped_local, counter_local)``: a list of
        ``{main_url: [{other_url: similarity}, ...]}`` entries, the number
        of urls marked as duplicates, and the number of urls visited.
    """
    remove_urls_list = []
    deduped_local = counter_local = 0
    rounds_done = 0

    while len(bucket_urls) > 1:
        if args.heuristic_iter != -1 and rounds_done == args.heuristic_iter:
            break

        candidates = list(bucket_urls)
        main_url = candidates[np.random.randint(0, len(candidates))]
        main_shingles = shingles(url_doc[main_url])
        similar = []

        for other_url in candidates:
            counter_local += 1
            if other_url == main_url:
                continue
            other_shingles = shingles(url_doc[other_url])
            try:
                similarity = jaccard(main_shingles, other_shingles, args)
            except Exception as err:
                print('Error:', err)
                similarity = 0.0
            if similarity > 0.5:
                similar.append({other_url: similarity})
                deduped_local += 1
                bucket_urls.remove(other_url)

        bucket_urls.remove(main_url)
        if similar:
            remove_urls_list.append({main_url: similar})
        rounds_done += 1

    return remove_urls_list, deduped_local, counter_local
def write_remove_urls_list(remove_urls_list, f_out):
    """Write each entry of ``remove_urls_list`` to ``f_out`` as one json line.

    ``f_out`` must accept bytes; entries are encoded as UTF-8 without
    escaping non-ASCII characters. An empty list writes nothing.
    """
    newline = '\n'.encode('utf-8')
    for entry in remove_urls_list:
        f_out.write(json.dumps(entry, ensure_ascii=False).encode('utf-8'))
        f_out.write(newline)
def compute_jaccard(each_bin, num_bins, start_time_local):
    """Run ``url_pairs_to_remove`` on every bucket of one LSH bin.

    Reads the module-level ``args`` and ``url_doc``. Buckets with at most
    one url are skipped; the others are copied before processing so the
    bin itself is not modified. Progress is printed every 100000 buckets,
    and only by processes whose pid is a multiple of ``num_bins``.

    Returns:
        ``(remove_urls_list, deduped_local, counter_local)`` accumulated
        over all buckets of the bin.
    """
    remove_urls_list = []
    deduped_local = counter_local = 0

    for bucket_local, bucket_id in enumerate(each_bin, start=1):
        if os.getpid() % num_bins == 0 and bucket_local % 100000 == 0:
            progress = float(bucket_local) / float(len(each_bin))
            print("Counter {}, progress {:.2f} time {:.2f}".format(
                bucket_local, progress, time.time() - start_time_local),
                flush=True)

        bucket = each_bin[bucket_id]
        if len(bucket) <= 1:
            continue

        found, deduped_sub, counter_sub = url_pairs_to_remove(
            args, bucket.copy(), url_doc)
        deduped_local += deduped_sub
        counter_local += counter_sub
        remove_urls_list.extend(found)

    return remove_urls_list, deduped_local, counter_local
def find_pair_urls_parallel(args, lshcache, url_doc):
    """Find near-duplicate urls with one worker process per LSH bin.

    Each bin of ``lshcache.bins`` is handed to ``compute_jaccard`` in a
    worker; the returned duplicate lists are written to ``args.output`` as
    json lines and running totals are printed after every bin.

    ``args`` and ``url_doc`` are not sent to the workers: ``compute_jaccard``
    reads them as module-level names.
    """
    start_time = time.time()
    deduped, counter = 0, 0

    # compute jaccards of buckets in bin in parallel (parallelism
    # limited to # of bins)
    num_bins = len(lshcache.bins)
    compute_jaccard_partial = partial(compute_jaccard, num_bins=num_bins,
                                      start_time_local=start_time)

    # The context managers close the output file and shut the pool down
    # even if a worker or a write fails; all imap results are consumed
    # before the pool is shut down.
    with open(args.output, 'wb') as f_out, \
            multiprocessing.Pool(num_bins) as pool:
        compute_jaccard_iter = pool.imap(compute_jaccard_partial,
                                         lshcache.bins)

        print("multiprocessing init took {:.2f}".format(
            time.time() - start_time), flush=True)
        for remove_urls_list, deduped_local, counter_local in \
                compute_jaccard_iter:
            deduped += deduped_local
            counter += counter_local
            write_remove_urls_list(remove_urls_list, f_out)
            print(' [write]> processed {} documents in {:.2f} '
                  'seoncds and deduped {} documents ...'.format(
                      counter, time.time() - start_time, deduped),
                  flush=True)

    print(' Taken time for jaccard similariries {:.2f} seconds'.format(
        time.time() - start_time), flush=True)
def find_pair_urls_sequential(args, lshcache, url_doc):
    """Find near-duplicate urls bin by bin in the current process.

    Every bucket with more than one url in every bin of ``lshcache.bins``
    is copied and passed to ``url_pairs_to_remove``; the duplicates found
    are written to ``args.output`` as json lines. Progress is printed
    whenever the running document counter is a multiple of 10000, and once
    more at the end.
    """
    start_time = time.time()
    deduped, counter = 0, 0

    # "with" closes the output file even if a bucket fails to process.
    with open(args.output, 'wb') as f_out:
        for b in lshcache.bins:
            for bucket_id in b:
                if len(b[bucket_id]) <= 1:
                    continue

                # work on a copy: url_pairs_to_remove empties its bucket
                bucket_urls = b[bucket_id].copy()
                remove_urls_list_sub, deduped_local_sub, counter_local_sub = \
                    url_pairs_to_remove(args, bucket_urls, url_doc)

                deduped += deduped_local_sub
                counter += counter_local_sub
                write_remove_urls_list(remove_urls_list_sub, f_out)
                if counter % 10000 == 0:
                    print(' [write]> processed {} documents in {:.2f} '
                          'seoncds and deduped {} documents ...'.
                          format(counter, time.time() - start_time,
                                 deduped), flush=True)

    print(' [write]> processed {} documents in {:.2f} '
          'seoncds and deduped {} documents ...'.
          format(counter, time.time() - start_time,
                 deduped), flush=True)
if __name__ == '__main__':
    # Command-line entry point of the duplicate finder. Steps:
    #   1. optionally load previously saved fingerprints (pickle files),
    #   2. optionally fingerprint new input files in parallel,
    #   3. optionally save the combined fingerprints,
    #   4. optionally compare the documents sharing an LSH bucket and write
    #      the similar url pairs to --output.
    # The names `args`, `hasher` and `url_doc` bound here are module-level
    # globals; compute_fingerprint reads `hasher`, and compute_jaccard reads
    # `args` and `url_doc`, without receiving them as parameters.
    print('parsing the arguments ...')
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1234,
                        help='Random seed used for python, numpy')
    parser.add_argument('--inputs', nargs = '*', default=None, help = \
                        'Pairwise list of the input files and keys, '
                        'e.g. --inputs cc.json cc_id news.json news_id')
    parser.add_argument('--load-fingerprints', nargs = '*', default=None,
                        help='Load fingerprints from a list of pickle files,'
                        ' e.g. cc.pkl news.pkl')
    parser.add_argument('--save-fingerprints', type=str, default=None,
                        help='Save the fingerprints of the inputs.')
    parser.add_argument('--output', type=str, default=None,
                        help='Output file name that consists of all ids'
                        ' with matching similarities')
    parser.add_argument('--jaccard', type=str, default='union',
                        choices=['union', 'min', 'max'], help='Jaccard'\
                        ' similarity computation')
    parser.add_argument('--heuristic-iter', type=int, default=1,
                        help='Number of iterations to run the heuristics'
                        ': use -1 for exact')
    parser.add_argument('--num-bands', type=int, default=10,
                        help='Number of bands to use in cache')
    parser.add_argument('--num-seeds', type=int, default=100,
                        help='Number of seeds to use for minhash. Note that'
                        ' this value should be divisible by num-bands')
    parser.add_argument('--jaccard-parallel', action='store_true',
                        help='Use this to process large number of documents.')
    args = parser.parse_args()
    print('finding possible duplicate content ...')
    # set the seed and draw args.num_seeds integers (100 by default) to
    # seed the minhash functions
    np.random.seed(args.seed)
    seeds = np.random.randint(0, 1e6, size=args.num_seeds)
    # initialize minhash and lsh cache
    hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
    lshcache = cache.Cache(num_bands=args.num_bands, hasher=hasher)
    # maps each document id (url) to its text
    url_doc = {}
    # load fingerprints from pickle file if needed
    # NOTE(review): pickle.load can execute arbitrary code; only load
    # fingerprint files from a trusted source.
    if args.load_fingerprints is not None:
        for count_fp, fp_file_name in enumerate(args.load_fingerprints):
            print("Loading fingerprints from pickle file {}".format(
                fp_file_name), flush=True)
            fp = open(fp_file_name, "rb")
            if count_fp == 0:
                # the first pickle replaces the freshly built cache and
                # url_doc outright
                lshcache = pickle.load(fp)
                url_doc = pickle.load(fp)
            else:
                # append these to lshcache and url_doc
                local_lshcache = pickle.load(fp)
                local_url_doc = pickle.load(fp)
                for url in local_lshcache.fingerprints.keys():
                    url_doc[url] = local_url_doc[url]
                    lshcache.add_fingerprint(local_lshcache.fingerprints[url], url)
            fp.close()
    # counter and start_time are shared by all input files below
    counter = 0
    start_time = time.time()
    # compute finger prints of the inputs if any
    # input file and the key to use as id
    if args.inputs is not None:
        print("Computing fingerprints", flush=True)
        # --inputs is a flat list of (file, key) pairs
        assert len(args.inputs) % 2 == 0
        for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
            print(' document processing {} with key {}'.format(input_file, key),
                flush=True)
            # compute fingerprints in parallel
            num_workers = 40
            pool = multiprocessing.Pool(num_workers)
            fin = open(input_file, 'r', encoding='utf-8')
            compute_fingerprint_partial = partial(compute_fingerprint, key=key)
            compute_fingerprint_iter = pool.imap(compute_fingerprint_partial,
                                                    fin, 512)
            # traverse all the texts and add fingerprints
            for url, text, fingerprint, flag in compute_fingerprint_iter:
                counter += 1
                # flag is False when the line could not be fingerprinted
                if flag:
                    url_doc[url] = text
                    lshcache.add_fingerprint(fingerprint, url)
                if counter % 10000 == 0:
                    print(' [read]> processed {} documents in {:.2f} '
                        'seconds ...'.format(counter, time.time() - \
                        start_time), flush=True)
            fin.close()
            pool.close()
            pool.join()
    # Save the fingerprints if needed
    if args.save_fingerprints is not None:
        print("Saving fingerprints to pickle file {}".format(
            args.save_fingerprints), flush=True)
        # two consecutive pickles in one file: the cache, then url_doc
        # (the loader above reads them back in the same order)
        with open(args.save_fingerprints, 'wb') as f_save:
            pickle.dump(lshcache, f_save)
            pickle.dump(url_doc, f_save)
    # compute jaccard index of the input texts and write to file if needed
    if args.output is not None:
        print("Compute jaccard similarity", flush=True)
        if args.jaccard_parallel:
            find_pair_urls_parallel(args, lshcache, url_doc)
        else:
            find_pair_urls_sequential(args, lshcache, url_doc)
    print('done :-)')
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import json
import time
import sys
if __name__ == '__main__':
    # Groups urls that are transitively similar.
    # Usage: <script> INPUT OUTPUT [JACCARD_THRESHOLD]
    # Each input line is a json object mapping a main url to a list of
    # {other_url: similarity} objects. Each output line is a json object
    # {group_index: [urls]} for every group with more than one url.
    print('grouping duplicate urls ...')
    # NOTE: `input` rebinds the builtin of the same name at module level.
    input = sys.argv[1]
    output = sys.argv[2]
    if len(sys.argv) > 3:
        jaccard_similarity_threshold = float(sys.argv[3])
    else:
        jaccard_similarity_threshold = 0.7
    # url -> index of the group (in index_to_urls) the url belongs to
    url_to_index = {}
    # group index -> set of urls; a slot becomes None once its group has
    # been merged into another one
    index_to_urls = []
    counter = 0
    start_time = time.time()
    with open(input, 'r') as f:
        for line in f:
            counter += 1
            myjson = json.loads(line)
            # collect the main url(s) of this line plus every other url
            # whose similarity reaches the threshold
            urls = []
            for main_url in myjson.keys():
                urls.append(main_url)
                for value in myjson[main_url]:
                    for other_url, js in value.items():
                        if js >= jaccard_similarity_threshold:
                            urls.append(other_url)
            # find the groups these urls already belong to: the first one
            # found becomes the target, any further ones get merged into it
            current_index = -1
            other_indices = set()
            for url in urls:
                if url in url_to_index:
                    if current_index == -1:
                        current_index = url_to_index[url]
                    elif current_index != url_to_index[url]:
                        other_indices.add(url_to_index[url])
            # none of the urls has been seen before: start a new group
            if current_index == -1:
                current_index = len(index_to_urls)
                index_to_urls.append(set())
            for url in urls:
                url_to_index[url] = current_index
                index_to_urls[current_index].add(url)
            # merge the other groups into the target and retire them
            for index in other_indices:
                for url in index_to_urls[index]:
                    index_to_urls[current_index].add(url)
                    url_to_index[url] = current_index
                index_to_urls[index] = None
            if counter % 100000 == 0:
                print(' > processed {} lines in {} seconds ...'.format(
                    counter, time.time() - start_time))
    # statistics: only groups with more than one url are counted; each
    # such group keeps one url and the rest are removable
    total_remove = 0
    total_remain = 0
    for urls in index_to_urls:
        if urls is not None:
            if len(urls) > 1:
                total_remove += (len(urls) - 1)
                total_remain += 1
    print('out of {} urls, only {} are unique and {} should be removed'.format(
        total_remove+total_remain, total_remain, total_remove))
    # write one json line per group that has more than one url
    with open(output, 'wb') as f:
        for i, urls in enumerate(index_to_urls):
            if urls is not None:
                if len(urls) > 1:
                    myjson = json.dumps({str(i): list(urls)},
                                        ensure_ascii=False)
                    f.write(myjson.encode('utf-8'))
                    f.write('\n'.encode('utf-8'))
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import glob
import sys
import json
import argparse
def merge_json_files(json_files, out_file):
    """Concatenate json-lines files into ``out_file``, one row per line.

    Every row is parsed with ``json.loads`` purely as validation, so a
    malformed row raises instead of being copied. A final row that lacks a
    trailing newline gets one, so it cannot fuse with the first row of the
    next file. ``out_file`` itself is skipped if it appears among the
    inputs, so the merge never reads the file it is writing.

    Args:
        json_files: Paths of the files to merge, in merge order.
        out_file: Path of the merged file; it is overwritten.

    Returns:
        The number of input files merged.
    """
    import os  # local import: this script does not import os at file level

    out_path = os.path.abspath(out_file)
    counter = 0
    with open(out_file, 'w') as outfile:
        for fname in json_files:
            if os.path.abspath(fname) == out_path:
                continue
            counter += 1
            if counter % 1024 == 0:
                print("Merging at ", counter, flush=True)
            with open(fname, 'r') as infile:
                for row in infile:
                    # parse only to validate the row
                    json.loads(row)
                    outfile.write(row if row.endswith('\n') else row + '\n')
    return counter


if __name__ == '__main__':
    # Command-line entry point: merge every *.json file of a directory
    # into a single json-lines file.
    parser = argparse.ArgumentParser()
    parser.add_argument("--json_path", type=str, default=".",
                        help="path where all the json files are located")
    parser.add_argument("--output_file", type=str, default="merged_output.json",
                        help="filename where the merged json should go")
    args = parser.parse_args()

    json_path = args.json_path
    out_file = args.output_file

    # glob returns files in arbitrary order; sort for a reproducible merge
    json_files = sorted(glob.glob(json_path + '/*.json'))

    merge_json_files(json_files, out_file)

    print("Merged file", out_file, flush=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment