Commit aed2f75e authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main' into github-main

parents 8aa4619f f32a638d
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision-classification finetuning/evaluation."""
from megatron import get_args
from megatron import print_rank_0
from megatron.model.vit_model import VitModel
from megatron.data.vit_dataset import build_train_valid_datasets
from tasks.vision.eval_utils import accuracy_func_provider
from tasks.vision.finetune_utils import finetune
def classification():
def train_valid_datasets_provider():
"""Build train and validation dataset."""
args = get_args()
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
crop_size=args.img_dim,
)
return train_ds, valid_ds
def model_provider():
"""Build the model."""
args = get_args()
print_rank_0("building classification model for ImageNet ...")
return VitModel(num_classes=args.num_classes, finetune=True)
"""Finetune/evaluate."""
finetune(
train_valid_datasets_provider,
model_provider,
end_of_epoch_callback_provider=accuracy_func_provider,
)
def main():
classification()
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation utilities."""
import os
import torch
from megatron import get_args
from megatron import print_rank_0
from megatron import mpu
from tasks.vision.finetune_utils import build_data_loader
from tasks.vision.finetune_utils import process_batch
from torchvision import datasets, transforms
def accuracy_func_provider():
"""Provide function that calculates accuracies."""
args = get_args()
data_path = args.data_path
crop_size = args.img_dim
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
# Build dataloaders.
val_data_path = os.path.join(data_path[0], "val")
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform_val = transforms.Compose(
[
transforms.Resize(crop_size),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
normalize,
]
)
dataset = datasets.ImageFolder(root=val_data_path, transform=transform_val)
dataloader = build_data_loader(
dataset,
args.micro_batch_size,
num_workers=args.num_workers,
drop_last=(mpu.get_data_parallel_world_size() > 1),
)
def metrics_func(model, epoch):
print_rank_0("calculating metrics ...")
correct, total = calculate_correct_answers(model, dataloader, epoch)
percent = float(correct) * 100.0 / float(total)
print_rank_0(
" >> |epoch: {}| overall: correct / total = {} / {} = "
"{:.4f} %".format(epoch, correct, total, percent)
)
return metrics_func
def calculate_correct_answers(model, dataloader, epoch):
"""Calculate correct over total answers"""
model.eval()
with torch.no_grad():
# For all the batches in the dataset.
total = 0
correct = 0
for _, batch in enumerate(dataloader):
# Run the model forward.
images, labels = process_batch(batch)
logits = model(images).contiguous().float()
# Add output predictions.
# Compute the correct answers.
predicted = torch.argmax(logits, dim=-1)
corrects = (predicted == labels).float()
# Add to the counters.
total += labels.size(0)
correct += corrects.sum().item()
model.train()
# Reduce.
unreduced = torch.cuda.LongTensor([correct, total])
torch.distributed.all_reduce(unreduced, group=mpu.get_data_parallel_group())
# Print on screen.
correct_ans = unreduced[0].item()
total_count = unreduced[1].item()
return correct_ans, total_count
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetune utilities."""
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import average_losses_across_data_parallel_group
def process_batch(batch):
"""Process batch and produce inputs for the model."""
images = batch[0].cuda().contiguous()
labels = batch[1].cuda().contiguous()
return images, labels
def _cross_entropy_forward_step(batch, model, input_tensor):
"""Simple forward step with cross-entropy loss."""
timers = get_timers()
assert input_tensor is None
# Get the batch.
timers("batch generator").start()
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
images, labels = process_batch(batch_)
timers("batch generator").stop()
# Forward model.
logits = model(images).contiguous().float()
# Cross-entropy loss.
loss = F.cross_entropy(logits, labels)
# Reduce loss for logging.
average_loss = average_losses_across_data_parallel_group([loss])
return loss, {"lm loss": average_loss[0]}
def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
"""Data loader. Note that batch-size is the local (per GPU) batch-size."""
# Sampler.
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=world_size, rank=rank
)
# Data loader. Note that batch size is the per GPU batch size.
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=micro_batch_size,
sampler=sampler,
shuffle=False,
num_workers=num_workers,
drop_last=drop_last,
pin_memory=True,
)
return data_loader
def _build_infinite_size_dataloader(dataloader):
"""Build a looped dataloader with infinite size."""
iterator = dataloader.__iter__()
while True:
try:
yield iterator.__next__()
except StopIteration:
iterator = dataloader.__iter__()
def _build_train_valid_dataloaders(train_dataset, valid_dataset):
"""Traing and validation dataloaders."""
args = get_args()
print_rank_0("building train and validation dataloaders ...")
# Training dataset.
train_dataloader = build_data_loader(
train_dataset, args.micro_batch_size, args.num_workers, not args.keep_last
)
# Set the training iterations.
args.train_iters_per_epoch = len(train_dataloader)
args.train_iters = args.epochs * args.train_iters_per_epoch
# Validation dataset. For this dataset, we do not need to set up
# shuffling so we can just use a simple infinite loop.
valid_dataloader_ = build_data_loader(
valid_dataset, args.micro_batch_size, args.num_workers, not args.keep_last
)
valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
return train_dataloader, valid_dataloader
def _train(
model,
optimizer,
lr_scheduler,
forward_step,
train_dataloader,
valid_dataloader,
end_of_epoch_callback,
):
"""Train the model."""
args = get_args()
timers = get_timers()
# Turn on training mode which enables dropout.
model.train()
# Tracking loss.
losses_dict_sum = {}
# Starting epoch and iteration
start_epoch = args.iteration // args.train_iters_per_epoch
start_iteration = args.iteration % args.train_iters_per_epoch
iteration = args.iteration
# Memory reporting flag.
report_memory_flag = True
# For each remaining epoch
timers("interval-time").start()
for epoch in range(start_epoch, args.epochs):
print_rank_0("working on epoch {} ...".format(epoch + 1))
# Set the data loader epoch to shuffle the index iterator.
train_dataloader.sampler.set_epoch(args.seed + epoch)
# For all the batches in the dataset.
for iteration_, batch in enumerate(train_dataloader):
# Ignore the iterations before starting value
if iteration_ < start_iteration:
continue
# Set to zero so the next epoch does not skip any batches.
start_iteration = 0
# Train for one step.
losses_dict, skipped_iter = train_step(
forward_step, batch, model, optimizer, lr_scheduler
)
iteration += 1
# Logging.
report_memory_flag = training_log(
losses_dict,
losses_dict_sum,
optimizer.param_groups[0]["lr"],
iteration,
optimizer.get_loss_scale().item(),
report_memory_flag,
skipped_iter,
)
# Autoresume
if args.adlr_autoresume and (
iteration % args.adlr_autoresume_interval == 0
):
check_adlr_autoresume_termination(
iteration, model, optimizer, lr_scheduler
)
# Checkpointing
if (
args.save
and args.save_interval
and iteration % args.save_interval == 0
):
save_checkpoint(iteration, model, optimizer, lr_scheduler)
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0:
prefix = "iteration {}".format(iteration)
evaluate_and_print_results(
prefix,
forward_step,
valid_dataloader,
model,
iteration,
False,
)
# Checkpointing at the end of each epoch.
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
# Callback at the end of each epoch.
if end_of_epoch_callback is not None:
end_of_epoch_callback(model, epoch)
def finetune(
train_valid_datasets_provider,
model_provider,
forward_step=_cross_entropy_forward_step,
end_of_epoch_callback_provider=None,
):
"""Main finetune function used across all tasks."""
args = get_args()
timers = get_timers()
# Train and validation data loaders.
timers("train/valid/test dataset/dataloder").start()
if args.epochs > 0:
train_dataset, valid_dataset = train_valid_datasets_provider()
train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
train_dataset, valid_dataset
)
timers("train/valid/test dataset/dataloder").stop()
# Build calback function.
timers("callback function").start()
end_of_epoch_callback = None
if end_of_epoch_callback_provider is not None:
end_of_epoch_callback = end_of_epoch_callback_provider()
timers("callback function").stop()
# Build model, optimizer and learning rate scheduler.
timers("model and optimizer").start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
timers("model and optimizer").stop()
# If pretrained checkpoint is provided and we have not trained for
# any iteration (i.e., iteration is zero), then load the pretrained
# checkpoint.
timers("pretrained checkpoint").start()
if args.iteration == 0 and args.pretrained_checkpoint is not None:
original_load = args.load
args.load = args.pretrained_checkpoint
_ = load_checkpoint(model, None, None, strict=False)
args.load = original_load
# This is critical when only model is loaded. We should make sure
# master parameters are also updated.
optimizer.reload_model_params()
timers("pretrained checkpoint").stop()
# Print setup timing.
print_rank_0("done with setups ...")
timers.log(
[
"train/valid/test dataset/dataloder",
"callback function",
"model and optimizer",
"pretrained checkpoint",
]
)
print_rank_0("training ...")
# Finetune the model.
if args.epochs > 0:
_train(
model,
optimizer,
lr_scheduler,
forward_step,
train_dataloader,
valid_dataloader,
end_of_epoch_callback,
)
# Or just evaluate.
else:
if end_of_epoch_callback is not None:
print_rank_0("evaluation only mode, setting epoch to -1")
end_of_epoch_callback(model, epoch=-1, output_predictions=True)
print_rank_0("done :-)")
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality."""
import os
import sys
sys.path.append(
os.path.abspath(
os.path.join(
os.path.join(os.path.dirname(__file__), os.path.pardir),
os.path.pardir,
)
)
)
from megatron import get_args
from megatron.initialize import initialize_megatron
from classification import main
def get_tasks_args(parser):
"""Provide extra arguments required for tasks."""
group = parser.add_argument_group(title="tasks")
group.add_argument(
"--epochs",
type=int,
default=None,
help="Number of finetunning epochs. Zero results in "
"evaluation only.",
)
group.add_argument(
"--pretrained-checkpoint",
type=str,
default=None,
help="Pretrained checkpoint used for finetunning.",
)
group.add_argument(
"--keep-last",
action="store_true",
help="Keep the last batch (maybe incomplete) in" "the data loader",
)
return parser
if __name__ == "__main__":
initialize_megatron(extra_args_provider=get_tasks_args)
args = get_args()
main()
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT2 zero-shot evaluation."""
"""GPT zero-shot evaluation."""
import math
......@@ -24,19 +24,24 @@ from megatron import print_rank_0, is_last_rank
from megatron import get_tokenizer
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.model import GPT2Model, GPT2ModelFirstStage, GPT2ModelLastStage, GPT2ModelIntermediateStage
from megatron.training import get_model, communicate
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward
from tasks.finetune_utils import build_data_loader
from .datasets import build_dataset
# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
def get_model_provider(eval_metric):
"""Based on evaluation metric set the parallel-output flag and
return the model provider."""
def model_provider():
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
if eval_metric == 'loss':
......@@ -47,18 +52,9 @@ def get_model_provider(eval_metric):
raise NotImplementedError('output type for {} evaluation metric '
'is not supported.'.format(eval_metric))
print_rank_0('building GPT2 model ...')
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = GPT2ModelFirstStage(num_tokentypes=0)
elif mpu.is_pipeline_last_stage():
model = GPT2ModelLastStage(
parallel_output=parallel_output, num_tokentypes=0)
else:
model = GPT2ModelIntermediateStage(num_tokentypes=0)
else:
model = GPT2Model(num_tokentypes=0, parallel_output=parallel_output)
print_rank_0('building GPT model ...')
model = GPTModel(num_tokentypes=0, parallel_output=parallel_output,
pre_process=pre_process, post_process=post_process)
return model
......@@ -97,33 +93,15 @@ def forward_step(batch, model, eval_metric):
args = get_args()
args.micro_batch_size = len(labels)
# Forward model.
if not mpu.is_pipeline_first_stage():
input_tensor, _ = communicate(
tensor_send_next=None,
tensor_send_prev=None,
recv_forward=True,
recv_backward=False)
else:
input_tensor = None
input_tensor = recv_forward()
# Forward pass through the model.
if mpu.is_pipeline_first_stage():
assert input_tensor is None
if mpu.is_pipeline_last_stage():
output = model(tokens, position_ids, attention_mask)
else:
output = model(tokens, position_ids, attention_mask)
else:
assert input_tensor is not None
output = model(input_tensor, attention_mask)
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
unwrapped_model.set_input_tensor(input_tensor)
output = model(tokens, position_ids, attention_mask)
if not mpu.is_pipeline_last_stage():
communicate(tensor_send_next=output,
tensor_send_prev=None,
recv_forward=False,
recv_backward=False)
return None
send_forward(output)
if mpu.is_pipeline_last_stage():
# For loss, return the unreduced loss.
......@@ -214,6 +192,10 @@ def main():
"""Main program."""
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for text generation.")
exit()
if args.task == 'LAMBADA':
eval_metric = 'accuracy'
elif args.task == 'WIKITEXT103':
......@@ -227,6 +209,9 @@ def main():
if args.load is not None:
_ = load_checkpoint(model, None, None)
assert len(model) == 1, "Above condition should have caught this"
model = model[0]
# Data stuff.
dataset = build_dataset(args.task)
dataloader = build_data_loader(dataset, args.micro_batch_size,
......
def test_import():
import megatron
import os
import sys
sys.path.append('../')
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
from megatron import print_rank_0
from megatron.indexer import IndexBuilder
from megatron.initialize import initialize_megatron
......@@ -23,7 +26,7 @@ def main():
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
index_builder = IndexBuilder()
index_builder.build_and_save_index()
print_rank_0("Build and save indices: done!")
if __name__ == "__main__":
main()
......
......@@ -26,33 +26,19 @@ from megatron import get_tokenizer
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import (GPTModel,
GPTModelFirstStage,
GPTModelLastStage,
GPTModelIntermediateStage)
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_utils import generate_and_write_samples_unconditional
from megatron.text_generation_utils import generate_samples_input_from_file
from megatron.text_generation_utils import generate_samples_interactive
def model_provider():
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building GPT model ...')
args = get_args()
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = GPTModelFirstStage(num_tokentypes=0)
elif mpu.is_pipeline_last_stage():
model = GPTModelLastStage(
num_tokentypes=0, parallel_output=False)
else:
model = GPTModelIntermediateStage(
num_tokentypes=0)
else:
model = GPTModel(num_tokentypes=0, parallel_output=False)
model = GPTModel(num_tokentypes=0, parallel_output=False,
pre_process=pre_process, post_process=post_process)
return model
......@@ -92,14 +78,24 @@ def main():
"""Main program."""
initialize_megatron(extra_args_provider=add_text_generate_args,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
'no_load_rng': True,
'no_load_optim': True})
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for text generation.")
exit()
# Set up model and load checkpoint.
model = get_model(model_provider)
args = get_args()
if args.load is not None:
_ = load_checkpoint(model, None, None)
assert len(model) == 1, "Above condition should have caught this"
model = model[0]
# Generate samples.
if args.num_samples == 0:
args.micro_batch_size = 1
......
......@@ -16,6 +16,7 @@
"""Merge model parallel partitions."""
import os
import re
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
......@@ -23,11 +24,13 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
import torch
from megatron import mpu
from megatron.checkpointing import load_checkpoint, save_checkpoint
from megatron.checkpointing import ensure_directory_exists
from megatron.checkpointing import get_checkpoint_name
from megatron.checkpointing import get_checkpoint_version
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.global_vars import set_global_variables, get_args
from megatron.global_vars import rebuild_tokenizer
from megatron.global_vars import _parse_args
def split_into_partitions(tensor, num_partitions, partition_dim, stride):
......@@ -179,14 +182,33 @@ def get_mp_merge_args(parser):
group.add_argument('--model-type', type=str, required=True,
choices=['BERT', 'GPT', 'RACE', 'MNLI', 'QQP'],
help='Type of the mdoel.')
group.add_argument('--target-pipeline-model-parallel-size', type=int, default=1,
help='Degree of pipeline model parallelism in output model.')
return parser
def main():
# Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking we are plenty of processes
os.environ["WORLD_SIZE"] = f'{2**31}'
# Args
args = _parse_args(extra_args_provider=get_mp_merge_args)
set_global_variables(extra_args_provider=get_mp_merge_args,
args_defaults = {'use_cpu_initialization': True,
'micro_batch_size': 1,
'no_load_optim': True,
'no_load_rng': True,
'no_save_optim': True,
'no_save_rng': True,
'save_interval': 1})
args = get_args()
if args.pipeline_model_parallel_size > 1:
print("Checkpoints with pipeline model parallelism are not currently supported.")
exit()
model_type = args.model_type
orig_tensor_model_parallel_size = args.tensor_model_parallel_size
args.tensor_model_parallel_size = 1
......@@ -209,6 +231,8 @@ def main():
print('> building the full model ...')
mpu.initialize.set_tensor_model_parallel_world_size(1)
mpu.initialize.set_tensor_model_parallel_rank(0)
mpu.initialize.set_pipeline_model_parallel_world_size(1)
mpu.initialize.set_pipeline_model_parallel_rank(0)
merged_model = get_model(model_type)
# Build and load partitions.
......@@ -218,15 +242,19 @@ def main():
tokenizer = rebuild_tokenizer(args)
mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
for rank in range(args.tensor_model_parallel_size):
# Reset these since load_checkpoint asserts they are 0, but we are loading
# multiple checkpoints in the same process and they get set each time
args.consumed_train_samples = 0
args.consumed_valid_samples = 0
mpu.initialize.set_tensor_model_parallel_rank(rank)
checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
print('> loading {} ...'.format(checkpoint_name))
model_ = get_model(model_type)
sd = torch.load(checkpoint_name, map_location='cpu')
model_.load_state_dict(sd['model'])
print(f'> loading {checkpoint_name} ...')
load_checkpoint(model_, None, None)
print(f'> checkpoint version {get_checkpoint_version()}')
partitions.append(model_)
# Parameter generators so we can loop through them semiltaneouly.
merged_params_gen = merged_model.named_parameters()
partitions_params_gen = [partition.named_parameters()
......@@ -254,29 +282,67 @@ def main():
merged_param.data.copy_(partitions_param[0].data)
# For parallel parameters, merge the values
else:
print(' parallel parameter merge with stride {} along '
'dimention {}'.format(merged_param.stride,
merged_param.partition_dim))
dim = merged_param.partition_dim
stride = merged_param.partition_stride
print(f' parallel parameter merge with stride {stride} along '
f'dimention {dim}')
merge_partitions(merged_param,
partitions_param,
merged_param.partition_dim,
merged_param.stride)
dim,
stride)
except StopIteration:
break
# Save the model.
partitions = []
args.tensor_model_parallel_size = 1
args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size
assert args.num_layers % args.pipeline_model_parallel_size == 0, \
'num_layers must be divisible by target pipeline model parallel size'
layers_per_part = args.num_layers // args.pipeline_model_parallel_size
tokenizer = rebuild_tokenizer(args)
mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
mpu.initialize.set_tensor_model_parallel_rank(0)
sd = {}
sd['model'] = merged_model.state_dict_for_save_checkpoint()
sd['iteration'] = iteration
merged_path = os.path.join(args.load, 'merged')
checkpoint_name = get_checkpoint_name(merged_path, iteration)
ensure_directory_exists(checkpoint_name)
print('> saving merged model to {}'.format(checkpoint_name))
torch.save(sd, checkpoint_name)
mpu.initialize.set_pipeline_model_parallel_world_size(args.pipeline_model_parallel_size)
# regex to parse out layer number from param name
layer_re = re.compile('layers\.([0-9]+)')
if args.pipeline_model_parallel_size > 1:
merged_params = {}
for name, merged_param in merged_model.named_parameters():
merged_params[name] = merged_param
for rank in range(args.pipeline_model_parallel_size):
mpu.initialize.set_pipeline_model_parallel_rank(rank)
model = get_model(model_type)
def update_layer_num(m):
# TODO! This assumes no interleaved pipeline execution
layer = int(m.group(1))
layer += rank * layers_per_part
return f'layers.{layer}'
for dst_name, partition_param in model.named_parameters():
if dst_name == "word_embeddings.weight":
# See comment in MegatronModule.initialize_word_embeddings()
src_name = "language_model.embedding.word_embeddings.weight"
else:
# Translate destination layer number (0-N for each partition)
# to source layer number (single-model layer number)
src_name = re.sub(layer_re, update_layer_num, dst_name)
print(f" > copying {src_name} to {dst_name} in rank {rank}'s model")
partition_param.data.copy_(merged_params[src_name].data)
partitions.append(model)
else:
partitions = [merged_model]
for rank, model in enumerate(partitions):
mpu.initialize.set_pipeline_model_parallel_rank(rank)
print(f"> saving rank {rank}'s model")
save_checkpoint(iteration, model, None, None)
print('done :-)')
......
......@@ -26,9 +26,9 @@ python blacklist_urls.py <path to the dowloaded deduplicated URLs> <filename for
```
python cleanup_dataset.py <input data file> <output cleaned data filename>
```
2. Using LSH, find possible duplicates and store then in a file for later processing. This step can NOT be sharded and usually takes 12 to 24 hours for OpenWebText dataset.
2. Using LSH, find possible duplicates and store then in a file for later processing. This step can NOT be sharded and usually takes 12 to 24 hours for OpenWebText dataset. The code supports saving and loading fingerprints for recurrent deduplications.
```
python find_duplicates.py <input cleaned data file> <output possible duplicate urls filename>
python find_duplicates.py --inputs <pairlist list of input cleaned data files and keys, e.g. cc.json cc_id news.json news_id> --output <output possible duplicate urls filename>
```
3. Based on similarity measure defind inside function `is_similar` (default: 0.9), group urls that are similar. Basically, for each group, only one url we should keep and remove the rest.
```
......@@ -44,3 +44,12 @@ python remove_group_duplicates.py <file containing simialr documents> <cleaned d
shuf <cleaned deduped data file> -o train_data.json
```
# Deduplicating ngrams
To deduplicate the downstream tasks from the training dataset, we run the following command.
```
python filter_ngrams.py <down stream task dataset> <training dataset to deduplicate> <output training dataset>
```
We use 13-grams for the deduplication. When we find a 13-gram match in a training document, we split the document into two pieces and remove the 13-gram along with 200 characters from the both side of the 13-gram. We also remove any splitted document with less than 200 characters or if a document got splitted more than 10 times.
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Deduplicate downstream tasks from training dataset. 13-grams have been used.
All split documents with less than 200 characters got filtered. Any document
with more than 10 splits got filtered as well.
"""
from functools import partial
import json
import multiprocessing
import nltk
import re
import string
import sys
import time
def get_words(text):
# get all the lowercase words from text
words, positions = [], []
for match in re.finditer(r'\w+', text.lower()):
words.append(match.group(0))
positions.append(match.start())
return words, positions
def free_ngram(line, ngrams, ngram_size, filter_text_len,
splits_count, split_window_each_size):
# remove all the ngrams
try:
myjson = json.loads(line)
text_buf = [myjson['text']]
except Exception as e:
print("Error: {}".format(e), flush=True)
text_buf = []
text_buf_ngram_free = []
while len(text_buf) > 0:
# get the first one from the buffer
text = text_buf.pop(0)
words, positions = get_words(text)
not_ngram_free = True
punctuations = ".!?"
# find n-grams
for i in range(len(words) - ngram_size + 1):
seq = " ".join(words[i:i+ngram_size])
if seq in ngrams:
# splits the text
# first part of the text
pos = positions[i] - split_window_each_size
text_first = ""
while pos > 0 and not text[pos] in punctuations:
pos -= 1
if pos > 0:
text_first = text[0:pos+1]
pos = positions[i] + split_window_each_size
# last part of the text
text_second = ""
while pos < len(text) and not text[pos] in punctuations:
pos += 1
if pos + 1 < len(text):
text_second = text[pos+1:len(text)]
# first part of ngrams free
if len(text_first) > filter_text_len:
text_buf_ngram_free.append(text_first)
# add second part for further processing
if len(text_second) > filter_text_len:
text_buf.append(text_second)
not_ngram_free = False
break
# text are ngram free
if not_ngram_free:
text_buf_ngram_free.append(text)
return text_buf_ngram_free
if __name__ == '__main__':
print('finding possible duplicate content ...')
main_file = sys.argv[1] # lambada file
dedup_file = sys.argv[2] # Book corpus
output_file = sys.argv[3] #Filtered book corpus
ngrams = {}
id_prefix = "lambada"
# we use 13-grams, any text less than 200 characters got removed
# any text splitted more than 10 got removed as well
ngram_size = 13
filter_text_len = 200
splits_count = 10
split_window_each_size = 200
print('Reading file {} and computing ngrams'.format(main_file))
with open(main_file, 'r') as f:
for line in f:
try:
myjson = json.loads(line)
words, positions = get_words(myjson['text'])
for i in range(len(words) - ngram_size+1):
seq = " ".join(words[i:i+ngram_size])
if seq not in ngrams:
ngrams[seq] = positions[i]
except Exception as e:
print('Error:', e)
print("ngrams size {}".format(len(ngrams)))
print('Reading file {} and deduping n-grams'.format(dedup_file))
counter = 0
start_time = time.time()
out_f = open(output_file, 'wb')
splitted, ignored, split_mt_thld = 0, 0, 0
# Setup multi-processing.
num_workers = 40
fin = open(dedup_file, 'r', encoding='utf-8')
pool = multiprocessing.Pool(num_workers)
free_ngram_x=partial(free_ngram, ngrams=ngrams, ngram_size=ngram_size,
filter_text_len=filter_text_len, splits_count=splits_count,
split_window_each_size=split_window_each_size)
free_ngrams = pool.imap(free_ngram_x, fin, 25)
for text_buf_ngram_free in free_ngrams:
counter += 1
try:
if len(text_buf_ngram_free) > 1:
splitted += (len(text_buf_ngram_free) - 1)
if len(text_buf_ngram_free) == 0:
ignored += 1
# more than 10 splits ignored
if len(text_buf_ngram_free) > splits_count:
text_buf_ngram_free = []
split_mt_thld += 1
for i in range(len(text_buf_ngram_free)):
split_id_string = id_prefix + '-{:010d}'.format(int(counter)) \
+ '-{:010d}'.format(int(i))
outjson = json.dumps({"text":text_buf_ngram_free[i],
id_prefix+"_split_id":split_id_string},
ensure_ascii=False)
out_f.write(outjson.encode('utf-8'))
out_f.write('\n'.encode('utf-8'))
if counter % 1000 == 0:
print(' [search]> processed {} documents in {:.2f} seconds ...'.
format(counter, time.time() - start_time), flush=True)
except Exception as e:
print('Error:', e)
print("Deduped file written to: {}".format(output_file), flush=True)
print("Total docs {} splitted {} ignored {} docs with many splits {}".\
format(counter, splitted, ignored, split_mt_thld), flush=True)
print('done :-)')
......@@ -13,14 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import itertools
import json
from lsh import cache, minhash
import numpy as np
import time
import pickle
import sys
# This function is adapted from:
# https://github.com/mattilyra/LSH/blob/master/examples/Introduction.ipynb
def shingles(text, char_ngram=5):
......@@ -38,36 +39,98 @@ def jaccard(set_a, set_b):
if __name__ == '__main__':
print('parsing the arguments ...')
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=1234,
help='Random seed used for python, numpy')
parser.add_argument('--inputs', nargs = '*', default=None, help = \
'Pairwise list of the input files and keys, '
'e.g. --inputs cc.json cc_id news.json news_id')
parser.add_argument('--load-fingerprints', nargs = '*', default=None,
help='Load fingerprints from a list of pickle files,'
' e.g. cc.pkl news.pkl')
parser.add_argument('--save-fingerprints', type=str, default=None,
help='Save the fingerprints of the inputs.')
parser.add_argument('--output', type=str, default=None,
help='Output file name that consists of all ids'
' with matching similarities')
args = parser.parse_args()
print('finding possible duplicate content ...')
input = sys.argv[1]
output = sys.argv[2]
# set seed and get an array of seeds of 100 integers
np.random.seed(args.seed)
seeds = np.random.randint(0, 1e6, size=100)
hasher = minhash.MinHasher(seeds=100, char_ngram=5, hashbytes=4)
# initialize minhash and lsh cache
hasher = minhash.MinHasher(seeds=seeds, char_ngram=5, hashbytes=4)
lshcache = cache.Cache(bands=10, hasher=hasher)
counter = 0
url_doc = {}
# load fingerprints from pickle file if needed
if args.load_fingerprints is not None:
for count_fp, fp_file_name in enumerate(args.load_fingerprints):
print("Loading fingerprints from pickle file {}".format(
fp_file_name), flush=True)
fp = open(fp_file_name, "rb")
if count_fp == 0:
# assign directory for the first pkl
lshcache = pickle.load(fp)
url_doc = pickle.load(fp)
else:
# append these to lshcache and url_doc
local_lshcache = pickle.load(fp)
local_url_doc = pickle.load(fp)
for url in local_lshcache.fingerprints.keys():
url_doc[url] = local_url_doc[url]
lshcache.add_fingerprint(local_lshcache.fingerprints[url], url)
fp.close()
counter = 0
start_time = time.time()
with open(input, 'r') as f:
for line in f:
try:
myjson = json.loads(line)
url = myjson['url']
text = myjson['text']
counter += 1
url_doc[url] = text
lshcache.add_fingerprint(hasher.fingerprint(text), url)
except Exception as e:
print('Error:', e)
if counter % 10000 == 0:
print(' [read]> processed {} documents in {:.2f} seconds ...'.
format(counter, time.time() - start_time), flush=True)
print("Computing fingerprints", flush=True)
# compute finger prints of the inputs if any
# input file and the key to use as id
if args.inputs is not None:
assert len(args.inputs) % 2 == 0
for input_file, key in zip(args.inputs[::2], args.inputs[1::2]):
print(' document processing {} with key {}'.format(input_file, key),
flush=True)
# traverse all the texts and add fingerprints
with open(input_file, 'r') as f_input:
for line in f_input:
try:
myjson = json.loads(line)
url = myjson[key]
text = myjson['text']
counter += 1
url_doc[url] = text
lshcache.add_fingerprint(hasher.fingerprint(text), url)
except Exception as e:
print('Error:', e)
if counter % 10000 == 0:
print(' [read]> processed {} documents in {:.2f} '
'seconds ...'.format(counter, time.time() - \
start_time), flush=True)
# Save the fingerprints if needed
if args.save_fingerprints is not None:
print("Saving fingerprints to pickle file {}".format(
args.save_fingerprints), flush=True)
with open(args.save_fingerprints, 'wb') as f_save:
pickle.dump(lshcache, f_save)
pickle.dump(url_doc, f_save)
counter = 0
start_time = time.time()
deduped = 0
with open(output, 'wb') as f:
# compute jaccard index of the input texts and write to file if needed
if args.output is not None:
f_out = open(args.output, 'wb')
for b in lshcache.bins:
for bucket_id in b:
if len(b[bucket_id]) > 1:
......@@ -94,7 +157,8 @@ if __name__ == '__main__':
if len(remove_urls) > 0:
myjson = json.dumps({main_url: remove_urls},
ensure_ascii=False)
f.write(myjson.encode('utf-8'))
f.write('\n'.encode('utf-8'))
f_out.write(myjson.encode('utf-8'))
f_out.write('\n'.encode('utf-8'))
f_out.close()
print('done :-)')
......@@ -19,7 +19,7 @@ import sys
def is_similar(jaccard_similarity):
return (js >= 0.9)
return (jaccard_similarity >= 0.7)
if __name__ == '__main__':
......
......@@ -85,7 +85,7 @@ class Encoder(object):
sentence_ids = Encoder.tokenizer.tokenize(sentence)
if len(sentence_ids) > 0:
doc_ids.append(sentence_ids)
if self.args.append_eod:
if len(doc_ids) > 0 and self.args.append_eod:
doc_ids[-1].append(Encoder.tokenizer.eod)
ids[key] = doc_ids
return ids, len(json_line)
......@@ -182,6 +182,8 @@ def main():
for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
total_bytes_processed += bytes_processed
for key, sentences in doc.items():
if len(sentences) == 0:
continue
for sentence in sentences:
builders[key].add_item(torch.IntTensor(sentence))
builders[key].end_document()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment