Commit 0024a5c6 authored by zhuwenwen

Merge branch 'main' of https://github.com/NVIDIA/Megatron-LM

parents b004456b 3db2063b
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import torch.distributed as dist
from functools import partial
from megatron import get_args, get_timers, print_rank_0
from megatron.core.enums import ModelType
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.dino import DINOPretrainModel
from megatron.model.vision.knn_monitor import knn_predict, get_feature_bank
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group, unwrap_model
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
return DINOPretrainModel(pre_process=pre_process, post_process=post_process)
def get_batch(data_iterator):
"""Build the batch."""
data = next(data_iterator)
# only data parallelism; no need for broadcast
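    # With DINO-style multi-crop augmentation, data[0] is a list of
    # differently augmented views of the batch, so images may be either a
    # list of tensors or a single tensor (assumption based on the
    # isinstance check below).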
if isinstance(data[0], list):
images = [aug.cuda() for aug in data[0]]
else:
images = data[0].cuda()
labels = data[1].cuda()
return images, labels
def loss_func(model, labels, output_tensor, collect_data=False):
args = get_args()
model = unwrap_model(
model,
(torchDDP, LocalDDP, Float16Module)
)
if model.training:
student_output, teacher_output = output_tensor
loss = model.dino_loss(student_output, teacher_output, args.curr_iteration)
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"loss": averaged_loss[0]}
else:
_, teacher_feature = output_tensor
feature_bank, feature_labels, classes = get_feature_bank()
feature = F.normalize(teacher_feature.float(), dim=1)
knn_accs = []
for k in [10, 20, 100, 200]:
pred_labels = knn_predict(feature, feature_bank,
feature_labels, classes, k, 0.07)
knn_acc = (pred_labels[:, 0] == labels).float().mean()
knn_accs.append(knn_acc)
averaged_loss = average_losses_across_data_parallel_group(knn_accs)
return 0, {"knn_acc_10": averaged_loss[0],
"knn_acc_20": averaged_loss[1],
"knn_acc_100": averaged_loss[2],
"knn_acc_200": averaged_loss[3]}
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers("batch-generator", log_level=2).start()
(
images,
labels,
) = get_batch(data_iterator)
timers("batch-generator").stop()
return model(images), partial(loss_func, model, labels)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0(
"> building train, validation, and test datasets " "for VIT ..."
)
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, None
if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'dataloader_type': 'cyclic', 'vision_pretraining': True}
)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Pretrain VIT"""
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers, print_rank_0, print_rank_last
from megatron.core.enums import ModelType
from megatron.data.vit_dataset import build_train_valid_datasets
from megatron.model.vision.inpainting import VitInpaintingModel
from megatron.model.vision.inpainting import MitInpaintingModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from tasks.vision.metrics import SSIM, PSNR
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
if args.vision_backbone_type == 'vit':
model = VitInpaintingModel(pre_process=pre_process,
post_process=post_process)
elif args.vision_backbone_type == 'mit':
model = MitInpaintingModel(pre_process=pre_process,
post_process=post_process)
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return model
def get_batch(data_iterator):
"""Build the batch."""
data = next(data_iterator)
# only data parallelism; no need for broadcast
images = data[0][0].cuda()
masks = data[0][1].cuda()
return images, masks
def loss_func(images, masks, masked_images, outputs, collect_data=False):
outputs = outputs.contiguous().float()
masks_flip = 1-masks
flip_masked_outputs = outputs.masked_fill(masks_flip.bool(), 0)
flip_masked_images = images.masked_fill(masks_flip.bool(), 0)
ssim_fun = SSIM()
psnr_fun = PSNR()
if not collect_data:
mask_count = torch.count_nonzero(masks)
loss = F.mse_loss(
flip_masked_outputs,
flip_masked_images.float(),
reduction="sum"
)
loss = loss/mask_count
ssim = ssim_fun(flip_masked_outputs, flip_masked_images.float())
psnr = psnr_fun(flip_masked_outputs, flip_masked_images.float())
averaged_loss = average_losses_across_data_parallel_group(
[loss, psnr, ssim]
)
return loss, {"loss": averaged_loss[0],
"psnr": averaged_loss[1],
'ssim': averaged_loss[2]}
else:
synth_images = masked_images.float() + flip_masked_outputs
ssim = ssim_fun(synth_images, images.float())
psnr = psnr_fun(synth_images, images.float())
return torch.cat((images, masked_images, synth_images), dim=2), ssim, psnr
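# A toy walk-through of the masked-MSE normalization above (hypothetical
# numbers): with masks = [[1, 0], [0, 0]] exactly one pixel is masked, so
# for outputs = [[0.5, 0.9], [0.1, 0.2]] and images = [[1.0, 0.9], [0.1, 0.2]]
# only the (0, 0) entry survives the masked_fill, giving
# F.mse_loss(..., reduction="sum") = (1.0 - 0.5) ** 2 = 0.25 and
# loss = 0.25 / torch.count_nonzero(masks) = 0.25.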
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers("batch-generator", log_level=2).start()
(
images,
masks,
) = get_batch(data_iterator)
timers("batch-generator").stop()
masked_images = images.masked_fill(masks.bool(), 0)
outputs = model(masked_images)
# Forward mode
return outputs, partial(loss_func, images, masks, masked_images)
def process_non_loss_data(data, iteration, writer):
psnr_sum = 0
ssim_sum = 0
for (output_tb, ssim, psnr) in data:
output_tb[output_tb < 0] = 0
output_tb[output_tb > 1] = 1
writer.add_images("gt-input-output-vald", output_tb,
global_step=iteration, walltime=None,
dataformats='NCHW')
psnr_sum = psnr_sum + psnr.item()
ssim_sum = ssim_sum + ssim.item()
psnr = psnr_sum/len(data)
ssim = ssim_sum/len(data)
writer.add_scalar('PSNR generate value-validation', psnr, iteration)
writer.add_scalar('SSIM generate value-validation', ssim, iteration)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
print_rank_0(
"> building train, validation, and test datasets " "for VIT ..."
)
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
image_size=(args.img_h, args.img_w)
)
print_rank_0("> finished creating VIT datasets ...")
return train_ds, valid_ds, None
if __name__ == "__main__":
pretrain(
train_valid_test_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
process_non_loss_data,
args_defaults={'dataloader_type': 'cyclic', 'vision_pretraining': True}
)
from setuptools import setup, find_packages
setup(
name="megatron.core",
version="0.1",
description="Core components of Megatron.",
packages=find_packages(
        include=("megatron.core",)
)
)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
""" Tasks data utility."""
import re
import numpy as np
def clean_text(text):
"""Remove new lines and multiple spaces and adjust end of sentence dot."""
text = text.replace("\n", " ")
text = re.sub(r'\s+', ' ', text)
for _ in range(3):
text = text.replace(' . ', '. ')
return text
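# Example: clean_text("The  end .\nNew line") returns "The end. New line".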
def build_sample(ids, types, paddings, label, unique_id):
"""Convert to numpy and return a sample consumed by the batch producer."""
ids_np = np.array(ids, dtype=np.int64)
types_np = np.array(types, dtype=np.int64)
paddings_np = np.array(paddings, dtype=np.int64)
sample = ({'text': ids_np,
'types': types_np,
'padding_mask': paddings_np,
'label': int(label),
'uid': int(unique_id)})
return sample
def build_tokens_types_paddings_from_text(text_a, text_b,
tokenizer, max_seq_length):
"""Build token types and paddings, trim if needed, and pad if needed."""
text_a_ids = tokenizer.tokenize(text_a)
text_b_ids = None
if text_b is not None:
text_b_ids = tokenizer.tokenize(text_b)
return build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids,
max_seq_length, tokenizer.cls,
tokenizer.sep, tokenizer.pad)
def build_tokens_types_paddings_from_ids(text_a_ids, text_b_ids, max_seq_length,
cls_id, sep_id, pad_id):
"""Build token types and paddings, trim if needed, and pad if needed."""
ids = []
types = []
paddings = []
# [CLS].
ids.append(cls_id)
types.append(0)
paddings.append(1)
# A.
len_text_a = len(text_a_ids)
ids.extend(text_a_ids)
types.extend([0] * len_text_a)
paddings.extend([1] * len_text_a)
# [SEP].
ids.append(sep_id)
types.append(0)
paddings.append(1)
# B.
if text_b_ids is not None:
len_text_b = len(text_b_ids)
ids.extend(text_b_ids)
types.extend([1] * len_text_b)
paddings.extend([1] * len_text_b)
# Cap the size.
trimmed = False
if len(ids) >= max_seq_length:
max_seq_length_m1 = max_seq_length - 1
ids = ids[0:max_seq_length_m1]
types = types[0:max_seq_length_m1]
paddings = paddings[0:max_seq_length_m1]
trimmed = True
# [SEP].
if (text_b_ids is not None) or trimmed:
ids.append(sep_id)
if text_b_ids is None:
types.append(0)
else:
types.append(1)
paddings.append(1)
# Padding.
padding_length = max_seq_length - len(ids)
if padding_length > 0:
ids.extend([pad_id] * padding_length)
types.extend([pad_id] * padding_length)
paddings.extend([0] * padding_length)
return ids, types, paddings
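# Worked example with made-up ids (cls_id=101, sep_id=102, pad_id=0),
# max_seq_length=8, text_a_ids=[7, 8], text_b_ids=[9]:
#   ids      = [101, 7, 8, 102, 9, 102, 0, 0]
#   types    = [0,   0, 0, 0,   1, 1,   0, 0]
#   paddings = [1,   1, 1, 1,   1, 1,   0, 0]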
import os
import argparse
import collections
import numpy as np
import torch
def process_files(args):
all_predictions = collections.OrderedDict()
all_labels = collections.OrderedDict()
all_uid = collections.OrderedDict()
for path in args.paths:
path = os.path.join(path, args.prediction_name)
try:
data = torch.load(path)
for dataset in data:
name, d = dataset
predictions, labels, uid = d
if name not in all_predictions:
all_predictions[name] = np.array(predictions)
if args.labels is None:
args.labels = [i for i in range(all_predictions[name].shape[1])]
if args.eval:
all_labels[name] = np.array(labels)
all_uid[name] = np.array(uid)
else:
all_predictions[name] += np.array(predictions)
assert np.allclose(all_uid[name], np.array(uid))
except Exception as e:
print(e)
continue
return all_predictions, all_labels, all_uid
def get_threshold(all_predictions, all_labels, one_threshold=False):
if one_threshold:
        all_predictions = {'combined': np.concatenate(list(all_predictions.values()))}
        all_labels = {'combined': np.concatenate(list(all_labels.values()))}
out_thresh = []
for dataset in all_predictions:
preds = all_predictions[dataset]
labels = all_labels[dataset]
out_thresh.append(calc_threshold(preds, labels))
return out_thresh
def calc_threshold(p, l):
trials = [(i) * (1. / 100.) for i in range(100)]
best_acc = float('-inf')
best_thresh = 0
for t in trials:
acc = ((apply_threshold(p, t).argmax(-1) == l).astype(float)).mean()
if acc > best_acc:
best_acc = acc
best_thresh = t
return best_thresh
def apply_threshold(preds, t):
assert (np.allclose(preds.sum(-1), np.ones(preds.shape[0])))
prob = preds[:, -1]
thresholded = (prob >= t).astype(int)
preds = np.zeros_like(preds)
preds[np.arange(len(thresholded)), thresholded.reshape(-1)] = 1
return preds
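# Toy example for apply_threshold: with preds = [[0.8, 0.2], [0.6, 0.4]]
# and t = 0.3, the last-column probabilities [0.2, 0.4] threshold to [0, 1],
# so the result is the one-hot matrix [[1, 0], [0, 1]]; the second sample is
# assigned the positive class even though its argmax is class 0.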
def threshold_predictions(all_predictions, threshold):
if len(threshold) != len(all_predictions):
        threshold = threshold + [threshold[-1]] * (len(all_predictions) - len(threshold))
for i, dataset in enumerate(all_predictions):
thresh = threshold[i]
preds = all_predictions[dataset]
all_predictions[dataset] = apply_threshold(preds, thresh)
return all_predictions
def postprocess_predictions(all_predictions, all_labels, args):
for d in all_predictions:
all_predictions[d] = all_predictions[d] / len(args.paths)
if args.calc_threshold:
args.threshold = get_threshold(all_predictions, all_labels, args.one_threshold)
print('threshold', args.threshold)
if args.threshold is not None:
all_predictions = threshold_predictions(all_predictions, args.threshold)
return all_predictions, all_labels
def write_predictions(all_predictions, all_labels, all_uid, args):
all_correct = 0
count = 0
for dataset in all_predictions:
preds = all_predictions[dataset]
preds = np.argmax(preds, -1)
if args.eval:
correct = (preds == all_labels[dataset]).sum()
num = len(all_labels[dataset])
            accuracy = correct / num
            count += num
            all_correct += correct
print(accuracy)
if not os.path.exists(os.path.join(args.outdir, dataset)):
os.makedirs(os.path.join(args.outdir, dataset))
outpath = os.path.join(
args.outdir, dataset, os.path.splitext(
args.prediction_name)[0] + '.tsv')
with open(outpath, 'w') as f:
f.write('id\tlabel\n')
f.write('\n'.join(str(uid) + '\t' + str(args.labels[p])
for uid, p in zip(all_uid[dataset], preds.tolist())))
if args.eval:
print(all_correct / count)
def ensemble_predictions(args):
all_predictions, all_labels, all_uid = process_files(args)
all_predictions, all_labels = postprocess_predictions(all_predictions, all_labels, args)
write_predictions(all_predictions, all_labels, all_uid, args)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--paths', required=True, nargs='+',
help='paths to checkpoint directories used in ensemble')
parser.add_argument('--eval', action='store_true',
help='compute accuracy metrics against labels (dev set)')
parser.add_argument('--outdir',
help='directory to place ensembled predictions in')
parser.add_argument('--prediction-name', default='test_predictions.pt',
help='name of predictions in checkpoint directories')
parser.add_argument('--calc-threshold', action='store_true',
help='calculate threshold classification')
parser.add_argument('--one-threshold', action='store_true',
                        help='use one threshold for all subdatasets')
parser.add_argument('--threshold', nargs='+', default=None, type=float,
help='user supplied threshold for classification')
parser.add_argument('--labels', nargs='+', default=None,
help='whitespace separated list of label names')
args = parser.parse_args()
ensemble_predictions(args)
if __name__ == '__main__':
main()
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Evaluation utilities."""
import os
import time
from functools import partial
import torch
from megatron import get_args
from megatron import print_rank_last, is_last_rank
from megatron.core import mpu
from megatron.schedules import get_forward_backward_func
from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch
def accuracy_func_provider(single_dataset_provider):
"""Provide function that calculates accuracies."""
args = get_args()
# Build dataloaders.
datapaths = args.valid_data
dataloaders = []
for datapath in datapaths:
dataset = single_dataset_provider(datapath)
dataloader = build_data_loader(
dataset, args.orig_micro_batch_size, num_workers=args.num_workers,
drop_last=(mpu.get_data_parallel_world_size() > 1))
dataloaders.append((dataset.dataset_name, dataloader))
def metrics_func(model, epoch, output_predictions=False):
print_rank_last('calculating metrics ...')
correct = 0
total = 0
if output_predictions:
assert mpu.get_data_parallel_world_size() == 1
named_predictions = []
names = 'predictions'
for name, dataloader in dataloaders:
output = calculate_correct_answers(name, model, dataloader,
epoch, output_predictions)
if not output_predictions:
correct_ans, total_count = output
else:
correct_ans, total_count, predictions = output
named_predictions.append((name, predictions))
names += '_' + name
correct += correct_ans
total += total_count
if is_last_rank():
percent = float(correct) * 100.0 / float(total)
print(' >> |epoch: {}| overall: correct / total = {} / {} = '
'{:.4f} %'.format(epoch, correct, total, percent))
if output_predictions and is_last_rank():
assert args.load is not None
filename = os.path.join(args.load, names + '.pt')
torch.save(named_predictions, filename)
return metrics_func
def calculate_correct_answers(name, model, dataloader,
epoch, output_predictions):
"""Calculate correct over total answers and return prediction if the
`output_predictions` is true."""
args = get_args()
forward_backward_func = get_forward_backward_func()
start_time = time.time()
for m in model:
m.eval()
saved_micro_batch_size = args.micro_batch_size
saved_global_batch_size = args.global_batch_size
ds = dataloader.dataset
if hasattr(ds, 'sample_multiplier'):
        # If our dataset has a sample_multiplier attribute, it means
        # each "sample" from the dataset actually contains multiple samples
        # that will collapse into the batch dimension (for example, the
        # RACE dataset has several options per sample), so we need to
        # account for that when setting the micro batch size.
sample_multiplier = ds.sample_multiplier
else:
sample_multiplier = 1
micro_batch_size_times_data_parallel = args.orig_micro_batch_size * args.data_parallel_size
num_micro_batches = args.orig_global_batch_size // micro_batch_size_times_data_parallel
def loss_func(output_predictions, labels, output_tensor):
logits = output_tensor
loss_dict = {}
# Add output predictions.
if output_predictions:
assert False
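            # The assert above makes this branch unreachable as written; the
            # lines below document what would be collected (softmax scores,
            # labels, and uids) if prediction output were enabled. Note that
            # `batch` is not in scope inside this loss function.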
loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
logits.float()).data.cpu().numpy().tolist()
loss_dict['labels'] = labels.data.cpu().numpy().tolist()
loss_dict['ids'] = batch['uid'].cpu().numpy().tolist()
# Compute the correct answers.
predicted = torch.argmax(logits, dim=-1)
corrects = (predicted == labels)
# Add to the counters.
loss_dict['total'] = labels.size(0)
loss_dict['correct'] = corrects.sum().item()
return 0, loss_dict
# defined inside to capture output_predictions
def correct_answers_forward_step(batch, model):
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
tokens, types, labels, attention_mask = process_batch(batch_)
# Forward model.
args = get_args()
output_tensor = model(tokens, attention_mask, tokentype_ids=types)
return output_tensor, partial(loss_func, output_predictions, labels)
with torch.no_grad():
# For all the batches in the dataset.
total = 0
correct = 0
if output_predictions:
# This option is only possible when data parallel size is 1.
assert mpu.get_data_parallel_world_size() == 1
softmaxes = []
labels = []
ids = []
for _, batch in enumerate(dataloader):
            # In evaluation-only mode we use drop_last = False to get all
            # the samples, which means the last batch may not be full, so
            # adjust the batch size here to the actual batch size of the data
actual_batch_size = len(batch['label'])
# ... applying sample_multiplier if necessary
args.micro_batch_size = actual_batch_size * sample_multiplier
args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches
loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
optimizer=None, timers=None, forward_only=True)
for loss_dict in loss_dicts:
if output_predictions:
softmaxes.extend(loss_dict['softmaxes'])
labels.extend(loss_dict['labels'])
ids.extend(loss_dict['ids'])
total += loss_dict['total']
correct += loss_dict['correct']
for m in model:
m.train()
args.micro_batch_size = saved_micro_batch_size
args.global_batch_size = saved_global_batch_size
# Reduce.
if mpu.is_pipeline_last_stage():
unreduced = torch.cuda.LongTensor([correct, total])
torch.distributed.all_reduce(unreduced,
group=mpu.get_data_parallel_group())
# Print on screen.
correct_ans = unreduced[0].item()
total_count = unreduced[1].item()
percent = float(correct_ans) * 100.0 / float(total_count)
elapsed_time = time.time() - start_time
print_rank_last(' > |epoch: {}| metrics for {}: correct / total '
'= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
epoch, name, correct_ans, total_count,
percent, elapsed_time))
if output_predictions:
return correct_ans, total_count, (softmaxes, labels, ids)
return correct_ans, total_count
if output_predictions:
return 0, 0, ()
return 0, 0
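# Numeric walk-through of the batch-size bookkeeping above (hypothetical
# settings): with orig_micro_batch_size=8, data_parallel_size=2 and
# orig_global_batch_size=64, num_micro_batches = 64 // (8 * 2) = 4. If the
# last (non-dropped) batch has only 5 samples and sample_multiplier is 4,
# the loop sets micro_batch_size = 5 * 4 = 20 and
# global_batch_size = 5 * 4 * 4 = 80 before restoring the saved values.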
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Finetune utilities."""
from functools import partial
import sys
import torch
from megatron import get_args, get_num_microbatches
from megatron import print_rank_0
from megatron import get_timers
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import average_losses_across_data_parallel_group
from megatron.utils import calc_params_l2_norm
from megatron.utils import check_adlr_autoresume_termination
def process_batch(batch):
"""Process batch and produce inputs for the model."""
args = get_args()
tokens = batch['text'].long().cuda().contiguous()
types = batch['types'].long().cuda().contiguous()
labels = batch['label'].long().cuda().contiguous()
attention_mask = batch['padding_mask'].float().cuda().contiguous()
if args.fp16:
attention_mask = attention_mask.half()
return tokens, types, labels, attention_mask
def cross_entropy_loss_func(labels, output_tensor):
logits = output_tensor
# Cross-entropy loss.
loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(logits.contiguous().float(), labels)
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def _cross_entropy_forward_step(batch, model):
"""Simple forward step with cross-entropy loss."""
timers = get_timers()
# Get the batch.
timers('batch-generator', log_level=2).start()
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
tokens, types, labels, attention_mask = process_batch(batch_)
timers('batch-generator').stop()
# Forward model.
output_tensor = model(tokens, attention_mask, tokentype_ids=types)
return output_tensor, partial(cross_entropy_loss_func, labels)
def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
task_collate_fn=None):
"""Data loader. Note that batch-size is the local (per GPU) batch-size."""
# Sampler.
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=world_size, rank=rank)
# Data loader. Note that batch size is the per GPU batch size.
data_loader = torch.utils.data.DataLoader(dataset,
batch_size=micro_batch_size,
sampler=sampler,
shuffle=False,
num_workers=num_workers,
drop_last=drop_last,
pin_memory=True,
collate_fn=task_collate_fn)
return data_loader
def _build_infinite_size_dataloader(dataloader):
"""Build a looped dataloader with infinite size."""
iterator = dataloader.__iter__()
while True:
try:
yield iterator.__next__()
except StopIteration:
iterator = dataloader.__iter__()
def _build_train_valid_dataloaders(train_dataset, valid_dataset,
task_collate_fn=None):
"""Traing and validation dataloaders."""
args = get_args()
print_rank_0('building train and validation dataloaders ...')
# Training dataset.
train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last,
task_collate_fn)
# Set the training iterations.
args.train_iters_per_epoch = len(train_dataloader)
args.train_iters = args.epochs * args.train_iters_per_epoch
# Validation dataset. For this dataset, we do not need to set up
# shuffling so we can just use a simple infinite loop.
valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last,
task_collate_fn)
valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
# Now that we've built the data loaders, set batch_size arguments
# to the actual batch size the model will see for this dataset.
# This is necessary so pipeline transfers know what size they are
# and the LR schedule, which is based on samples seen, gets set
# correctly.
args.orig_micro_batch_size = args.micro_batch_size
args.orig_global_batch_size = args.global_batch_size
if hasattr(train_dataset, 'sample_multiplier'):
        # If our dataset has a sample_multiplier attribute, it means
        # each "sample" from the dataset actually contains multiple samples
        # that will collapse into the batch dimension (for example, the
        # RACE dataset has several options per sample), so we need to
        # account for that when setting the micro batch size.
args.micro_batch_size *= train_dataset.sample_multiplier
args.global_batch_size *= train_dataset.sample_multiplier
return train_dataloader, valid_dataloader
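# Concrete example of the sample_multiplier adjustment (hypothetical
# settings): for RACE each sample carries 4 answer options that collapse
# into the batch dimension, so with --micro-batch-size 8 the model actually
# consumes micro batches of 8 * 4 = 32 rows; global_batch_size is scaled by
# the same factor so the sample-based LR schedule stays consistent.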
def _train(model, optimizer, opt_param_scheduler, forward_step,
train_dataloader, valid_dataloader, end_of_epoch_callback):
"""Train the model."""
args = get_args()
timers = get_timers()
assert get_num_microbatches() == 1, "finetuning with gradient accumulation doesn't currently work"
# Turn on training mode which enables dropout.
for m in model:
m.train()
# Tracking loss.
losses_dict_sum = {}
# Starting epoch and iteration
start_epoch = args.iteration // args.train_iters_per_epoch
start_iteration = args.iteration % args.train_iters_per_epoch
iteration = args.iteration
# Memory reporting flag.
report_memory_flag = True
# For each remaining epoch
timers('interval-time', log_level=0).start(barrier=True)
for epoch in range(start_epoch, args.epochs):
print_rank_0('working on epoch {} ...'.format(epoch + 1))
# Set the data loader epoch to shuffle the index iterator.
train_dataloader.sampler.set_epoch(args.seed + epoch)
# For all the batches in the dataset.
for iteration_, batch in enumerate(train_dataloader):
# Ignore the iterations before starting value
if iteration_ < start_iteration:
continue
# Set to zero so the next epoch does not skip any batches.
start_iteration = 0
# Train for one step.
out = train_step(forward_step, batch, model, optimizer, opt_param_scheduler)
losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
iteration += 1
# Logging.
params_norm = None
if args.log_params_norm:
params_norm = calc_params_l2_norm(model)
report_memory_flag = training_log(losses_dict, losses_dict_sum,
optimizer.param_groups[0]['lr'],
iteration,
optimizer.get_loss_scale().item(),
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
# Autoresume
if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0):
check_adlr_autoresume_termination(iteration, model,
optimizer, opt_param_scheduler)
# Checkpointing
saved_checkpoint = False
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
saved_checkpoint = True
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0:
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step,
valid_dataloader, model,
iteration, None, False)
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint:
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
torch.distributed.barrier()
print_rank_0('exiting program at iteration {}'.format(iteration))
sys.exit()
# Checkpointing at the end of each epoch.
if args.save:
save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
# Callback at the end of each epoch.
if end_of_epoch_callback is not None:
end_of_epoch_callback(model, epoch)
def finetune(train_valid_datasets_provider, model_provider,
model_type=ModelType.encoder_or_decoder,
forward_step=_cross_entropy_forward_step,
end_of_epoch_callback_provider=None,
task_collate_fn=None):
"""Main finetune function used across all tasks."""
args = get_args()
timers = get_timers()
assert args.rampup_batch_size is None, \
'batch size scaling is not supported for finetuning'
# Train and validation data loaders.
    timers('train/valid/test dataset/dataloader', log_level=0).start()
if args.epochs > 0:
train_dataset, valid_dataset = train_valid_datasets_provider()
train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
train_dataset, valid_dataset, task_collate_fn)
else:
args.train_iters = 0
    timers('train/valid/test dataset/dataloader').stop()
    # Build callback function.
timers('callback function', log_level=0).start()
end_of_epoch_callback = None
if end_of_epoch_callback_provider is not None:
end_of_epoch_callback = end_of_epoch_callback_provider()
timers('callback function').stop()
# Build model, optimizer and learning rate scheduler.
timers('model and optimizer', log_level=0).start()
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider, model_type)
timers('model and optimizer').stop()
# If pretrained checkpoint is provided and we have not trained for
# any iteration (i.e., iteration is zero), then load the pretrained
# checkpoint.
timers('pretrained checkpoint', log_level=0).start(barrier=True)
if args.iteration == 0 and args.pretrained_checkpoint is not None:
original_load = args.load
args.load = args.pretrained_checkpoint
original_rng = args.no_load_rng
args.no_load_rng = True
_ = load_checkpoint(model, None, None)
args.load = original_load
args.no_load_rng = original_rng
# This is critical when only model is loaded. We should make sure
# main parameters are also updated.
optimizer.reload_model_params()
timers('pretrained checkpoint').stop()
# Print setup timing.
print_rank_0('done with setups ...')
    timers.log(['train/valid/test dataset/dataloader', 'callback function',
'model and optimizer', 'pretrained checkpoint'], barrier=True)
print_rank_0('training ...')
# Finetune the model.
if args.epochs > 0:
_train(model, optimizer, opt_param_scheduler, forward_step,
train_dataloader, valid_dataloader, end_of_epoch_callback)
# Or just evaluate.
else:
if end_of_epoch_callback is not None:
print_rank_0('evaluation only mode, setting epoch to -1')
end_of_epoch_callback(model, epoch=-1, output_predictions=True)
print_rank_0('done :-)')
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""GLUE dataset."""
from abc import ABC
from abc import abstractmethod
from torch.utils.data import Dataset
from megatron import print_rank_0
from tasks.data_utils import build_sample
from tasks.data_utils import build_tokens_types_paddings_from_text
class GLUEAbstractDataset(ABC, Dataset):
"""GLUE base dataset class."""
def __init__(self, task_name, dataset_name, datapaths,
tokenizer, max_seq_length):
# Store inputs.
self.task_name = task_name
self.dataset_name = dataset_name
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
self.dataset_name))
# Process the files.
string = ' > paths:'
for path in datapaths:
string += ' ' + path
print_rank_0(string)
self.samples = []
for datapath in datapaths:
self.samples.extend(self.process_samples_from_single_path(datapath))
print_rank_0(' >> total number of samples: {}'.format(
len(self.samples)))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
raw_sample = self.samples[idx]
ids, types, paddings = build_tokens_types_paddings_from_text(
raw_sample['text_a'], raw_sample['text_b'],
self.tokenizer, self.max_seq_length)
sample = build_sample(ids, types, paddings,
raw_sample['label'], raw_sample['uid'])
return sample
@abstractmethod
def process_samples_from_single_path(self, datapath):
"""Abstract method that takes a single path / filename and
returns a list of dataset samples, each sample being a dict of
{'text_a': string, 'text_b': string, 'label': int, 'uid': int}
"""
pass
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""GLUE finetuning/evaluation."""
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model.classification import Classification
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
def glue_classification(num_classes, Dataset,
name_from_datapath_func):
def train_valid_datasets_provider():
"""Build train and validation dataset."""
args = get_args()
tokenizer = get_tokenizer()
train_dataset = Dataset('training', args.train_data,
tokenizer, args.seq_length)
valid_dataset = Dataset('validation', args.valid_data,
tokenizer, args.seq_length)
return train_dataset, valid_dataset
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
print_rank_0('building classification model for {} ...'.format(
args.task))
model = Classification(num_classes=num_classes, num_tokentypes=2,
pre_process=pre_process, post_process=post_process)
return model
def metrics_func_provider():
"""Privde metrics callback function."""
def single_dataset_provider(datapath):
args = get_args()
tokenizer = get_tokenizer()
name = name_from_datapath_func(datapath)
return Dataset(name, [datapath], tokenizer, args.seq_length)
return accuracy_func_provider(single_dataset_provider)
"""Finetune/evaluate."""
finetune(train_valid_datasets_provider, model_provider,
end_of_epoch_callback_provider=metrics_func_provider)
def main():
args = get_args()
if args.task == 'MNLI':
num_classes = 3
from tasks.glue.mnli import MNLIDataset as Dataset
def name_from_datapath(datapath):
return datapath.split('MNLI')[-1].strip(
'.tsv').strip('/').replace('_', '-')
elif args.task == 'QQP':
num_classes = 2
from tasks.glue.qqp import QQPDataset as Dataset
def name_from_datapath(datapath):
return datapath.split('QQP')[-1].strip(
'.tsv').strip('/').replace('_', '-')
else:
raise NotImplementedError('GLUE task {} is not implemented.'.format(
args.task))
glue_classification(num_classes, Dataset, name_from_datapath)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""MNLI dataset."""
from megatron import print_rank_0
from tasks.data_utils import clean_text
from .data import GLUEAbstractDataset
LABELS = {'contradiction': 0, 'entailment': 1, 'neutral': 2}
class MNLIDataset(GLUEAbstractDataset):
def __init__(self, name, datapaths, tokenizer, max_seq_length,
test_label='contradiction'):
self.test_label = test_label
super().__init__('MNLI', name, datapaths,
tokenizer, max_seq_length)
def process_samples_from_single_path(self, filename):
""""Implement abstract method."""
print_rank_0(' > Processing {} ...'.format(filename))
samples = []
total = 0
first = True
is_test = False
with open(filename, 'r') as f:
for line in f:
row = line.strip().split('\t')
if first:
first = False
if len(row) == 10:
is_test = True
print_rank_0(
' reading {}, {} and {} columns and setting '
'labels to {}'.format(
row[0].strip(), row[8].strip(),
row[9].strip(), self.test_label))
else:
                        print_rank_0(' reading {}, {}, {}, and {} columns '
'...'.format(
row[0].strip(), row[8].strip(),
row[9].strip(), row[-1].strip()))
continue
text_a = clean_text(row[8].strip())
text_b = clean_text(row[9].strip())
unique_id = int(row[0].strip())
label = row[-1].strip()
if is_test:
label = self.test_label
assert len(text_a) > 0
assert len(text_b) > 0
assert label in LABELS
assert unique_id >= 0
sample = {'text_a': text_a,
'text_b': text_b,
'label': LABELS[label],
'uid': unique_id}
total += 1
samples.append(sample)
if total % 50000 == 0:
print_rank_0(' > processed {} so far ...'.format(total))
print_rank_0(' >> processed {} samples.'.format(len(samples)))
return samples
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""QQP dataset."""
from megatron import print_rank_0
from tasks.data_utils import clean_text
from .data import GLUEAbstractDataset
LABELS = [0, 1]
class QQPDataset(GLUEAbstractDataset):
def __init__(self, name, datapaths, tokenizer, max_seq_length,
test_label=0):
self.test_label = test_label
super().__init__('QQP', name, datapaths,
tokenizer, max_seq_length)
def process_samples_from_single_path(self, filename):
""""Implement abstract method."""
print_rank_0(' > Processing {} ...'.format(filename))
samples = []
total = 0
first = True
is_test = False
with open(filename, 'r') as f:
for line in f:
row = line.strip().split('\t')
if first:
first = False
if len(row) == 3:
is_test = True
print_rank_0(' reading {}, {}, and {} columns and '
'setting labels to {}'.format(
row[0].strip(), row[1].strip(),
row[2].strip(), self.test_label))
else:
assert len(row) == 6
print_rank_0(' reading {}, {}, {}, and {} columns'
' ...'.format(
row[0].strip(), row[3].strip(),
row[4].strip(), row[5].strip()))
continue
if is_test:
assert len(row) == 3, 'expected length 3: {}'.format(row)
uid = int(row[0].strip())
text_a = clean_text(row[1].strip())
text_b = clean_text(row[2].strip())
label = self.test_label
assert len(text_a) > 0
assert len(text_b) > 0
else:
if len(row) == 6:
uid = int(row[0].strip())
text_a = clean_text(row[3].strip())
text_b = clean_text(row[4].strip())
label = int(row[5].strip())
else:
print_rank_0('***WARNING*** index error, '
'skipping: {}'.format(row))
continue
if len(text_a) == 0:
print_rank_0('***WARNING*** zero length a, '
'skipping: {}'.format(row))
continue
if len(text_b) == 0:
print_rank_0('***WARNING*** zero length b, '
'skipping: {}'.format(row))
continue
assert label in LABELS
assert uid >= 0
sample = {'uid': uid,
'text_a': text_a,
'text_b': text_b,
'label': label}
total += 1
samples.append(sample)
if total % 50000 == 0:
print_rank_0(' > processed {} so far ...'.format(total))
print_rank_0(' >> processed {} samples.'.format(len(samples)))
return samples
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Main tasks functionality."""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
from megatron import get_args
from megatron.initialize import initialize_megatron
def get_tasks_args(parser):
"""Provide extra arguments required for tasks."""
group = parser.add_argument_group(title='tasks')
group.add_argument('--task', type=str, required=True,
help='Task name.')
group.add_argument('--epochs', type=int, default=None,
                       help='Number of finetuning epochs. Zero results in '
'evaluation only.')
group.add_argument('--pretrained-checkpoint', type=str, default=None,
                       help='Pretrained checkpoint used for finetuning.')
group.add_argument('--keep-last', action='store_true',
                       help='Keep the last batch (maybe incomplete) in '
                       'the data loader.')
group.add_argument('--train-data', nargs='+', default=None,
help='Whitespace separated paths or corpora names '
'for training.')
group.add_argument('--valid-data', nargs='*', default=None,
help='path(s) to the validation data.')
group.add_argument('--overlapping-eval', type=int, default=32,
help='Sliding window for overlapping evaluation.')
group.add_argument('--strict-lambada', action='store_true',
help='Use more difficult formulation of lambada.')
# Retriever args
group.add_argument('--qa-data-dev', type=str, default=None,
help='Path to the QA dataset dev file.')
group.add_argument('--qa-data-test', type=str, default=None,
help='Path to the QA dataset test file.')
# Faiss arguments for retriever
    group.add_argument('--faiss-use-gpu', action='store_true',
                       help='Whether to create the FaissMIPSIndex on GPU')
    group.add_argument('--faiss-match', type=str, default='string',
                       choices=['regex', 'string'],
                       help='Answer matching logic type')
group.add_argument('--faiss-topk-retrievals', type=int, default=100,
help='Number of blocks to use as top-k during retrieval')
# finetune for retriever
group.add_argument('--eval-micro-batch-size', type=int, default=None,
help='Eval Batch size per model instance (local batch '
'size). Global batch size is local batch size '
'times data parallel size.')
group.add_argument('--train-with-neg', action='store_true',
help='Whether to use negative examples during model '
'training')
group.add_argument('--train-hard-neg', type=int, default=0,
                       help='Number of hard negative examples to use during '
'training')
# parameters for Av.rank validation method
    # The following options/arguments have been taken directly from the DPR codebase.
group.add_argument('--val-av-rank-hard-neg', type=int, default=30,
help='Av.rank validation: how many hard negatives to'
' take from each question pool')
group.add_argument('--val-av-rank-other-neg', type=int, default=30,
help='Av.rank validation: how many other negatives to'
' take from each question pool')
return parser
if __name__ == '__main__':
initialize_megatron(extra_args_provider=get_tasks_args)
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
exit()
if args.task == 'RACE':
from race.finetune import main
elif args.task in ['MNLI', 'QQP']:
from glue.finetune import main
elif args.task in ['LAMBADA', 'WIKITEXT103']:
from zeroshot_gpt.evaluate import main
elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']:
from orqa.evaluate_orqa import main
elif args.task in ['RET-FINETUNE-NQ']:
from orqa.supervised.finetune import main
else:
raise NotImplementedError('Task {} is not implemented.'.format(
args.task))
main()
# Multi-Stage Prompting for Knowledgeable Dialogue Generation
Below we present the steps to run our multi-stage dialogue prompting (MSDP) framework.
## Multi-Stage Dialogue Prompting
### Data Preparation
1. Dataset Download: [Wizard of Wikipedia](https://parl.ai/projects/wizard_of_wikipedia/) and [Wizard of Internet](https://parl.ai/projects/sea/)
2. Data Processing: We provide the script to run the [`data processing`](../../examples/msdp/data_processing.sh) of the datasets.
### Stage-1: Prompting for Knowledge Generation
1. We provide the script to perform the [`first-stage prompting`](../../examples/msdp/prompt_knwl_gen.sh) for the knowledge generation.
2. We provide the [`evaluation script`](../../examples/msdp/eval_knwl_generation.sh) for the automatic evaluation (i.e., F1, BLEU, METEOR, and ROUGE-L) of the knowledge generation.
### Stage-2: Prompting for Response Generation
1. We provide the script to [`prepare the input file`](../../examples/msdp/prep_resp_gen.sh) for the response generation (based on the previously generated knowledge file).
2. We provide the script to perform the [`second-stage prompting`](../../examples/msdp/prompt_resp_gen.sh) for the response generation.
3. We provide the [`evaluation script`](../../examples/msdp/eval_resp_generation.sh) for the automatic evaluation (i.e., F1, KF1, BLEU, METEOR, and ROUGE-L) of the response generation.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model evaluation"""
from megatron import get_args
from megatron import print_rank_0
from tasks.msdp.metrics import F1Metric
from tqdm import tqdm
def evaluate_f1(guess_file, answer_file):
"""Evaluating F1 Score"""
guess_list = []
print_rank_0('reading %s' % guess_file)
with open(guess_file, "r") as f:
for i, line in enumerate(tqdm(f)):
line = line.strip()
if "<|endoftext|>" in line:
line = line.replace("<|endoftext|>", "")
guess_list.append(line)
answer_list = []
print_rank_0('reading %s' % answer_file)
with open(answer_file, "r") as f:
for i, line in enumerate(tqdm(f)):
line = line.strip()
if line == "no_passages_used":
line = ""
answer_list.append(line)
assert len(guess_list) == len(answer_list), \
"lengths of guess and answer are different!"
precision, recall, f1 = F1Metric.compute_all_pairs(guess_list, answer_list)
print_rank_0('Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1))
print_rank_0('done :-)')
def main():
args = get_args()
evaluate_f1(args.guess_file, args.answer_file)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Run multi-stage dialogue prompting (MSDP)."""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(
os.path.join(os.path.dirname(__file__), os.path.pardir), os.path.pardir)))
from megatron import get_args
from megatron.initialize import initialize_megatron
def get_tasks_args(parser):
"""Provide extra arguments required for tasks."""
group = parser.add_argument_group(title='tasks')
# parameters for the knowledgeable dialogue generation
group.add_argument('--task', type=str, required=True,
help='Task name.')
group.add_argument("--sample-input-file", type=str, default=None,
help='Get input from file instead of interactive mode, '
'each line is an input.')
group.add_argument("--sample-output-file", type=str, default=None,
                       help='Output file generated from --sample-input-file')
group.add_argument('--prompt-file', type=str, default=None,
help='prompting file')
group.add_argument('--prompt-type', type=str, default=None,
choices=['knowledge', 'response'],
help='prompt type (knowledge or response)')
group.add_argument('--num-prompt-examples', type=int, default=10,
help='number of prompt examples')
group.add_argument('--guess-file', type=str, default=None,
help='datapath for generated sentences')
group.add_argument('--answer-file', type=str, default=None,
help='datapath for golden sentences')
group.add_argument('--out-seq-length', type=int, default=100,
help='output sequence length')
group.add_argument('--api-prompt', default=False, action="store_true",
help='setup model api for prompting')
group.add_argument('--megatron-api-url', type=str, default=None,
help='url of the megatron api')
return parser
if __name__ == '__main__':
initialize_megatron(extra_args_provider=get_tasks_args)
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
exit()
if args.task == 'MSDP-PROMPT':
from tasks.msdp.prompt import main
elif args.task == 'MSDP-EVAL-F1':
from tasks.msdp.evaluate import main
else:
raise NotImplementedError('Task {} is not implemented.'.format(
args.task))
main()
# The following code is adapted from
# https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/metrics.py,
# which is licensed under the MIT license. More details on the license can be
# found at https://github.com/facebookresearch/ParlAI/blob/master/LICENSE.
"""Provides standard metric evaluations for dialog."""
from collections import Counter
from typing import List
import numpy as np
import re
re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
def normalize_answer(s):
"""
Lower text and remove punctuation, articles and extra whitespace.
"""
s = s.lower()
s = re_punc.sub(' ', s)
s = re_art.sub(' ', s)
s = ' '.join(s.split())
return s
class F1Metric:
"""
Helper class which computes token-level F1.
"""
@staticmethod
def _prec_recall_f1_score(pred_items, gold_items):
"""
Compute precision, recall and f1 given a set of gold and prediction items.
:param pred_items: iterable of predicted values
:param gold_items: iterable of gold values
:return: tuple (p, r, f1) for precision, recall, f1
"""
common = Counter(gold_items) & Counter(pred_items)
num_same = sum(common.values())
if num_same == 0:
return 0, 0, 0
precision = 1.0 * num_same / len(pred_items)
recall = 1.0 * num_same / len(gold_items)
f1 = (2 * precision * recall) / (precision + recall)
return precision, recall, f1
@staticmethod
def compute_each_pair(guess: str, answer: str):
if answer == "":
return None, None, None
if guess == "":
return 0, 0, 0
g_tokens = normalize_answer(guess).split()
a_tokens = normalize_answer(answer).split()
precision, recall, f1 = F1Metric._prec_recall_f1_score(g_tokens, a_tokens)
return precision, recall, f1
@staticmethod
def compute_all_pairs(guesses: List[str], answers: List[str]):
        # sanity check: guesses and answers must be aligned one-to-one
assert len(guesses) == len(answers)
precision_list, recall_list, f1_list = [], [], []
for guess, answer in zip(guesses, answers):
precision, recall, f1 = F1Metric.compute_each_pair(guess, answer)
if precision is None or recall is None or f1 is None:
continue
precision_list.append(precision)
recall_list.append(recall)
f1_list.append(f1)
return np.mean(precision_list), np.mean(recall_list), np.mean(f1_list)
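# Example usage (made-up strings): F1Metric.compute_all_pairs(
#     ["the answer is Paris ."], ["Paris is the answer"]) returns
# (1.0, 1.0, 1.0), since both sides normalize to the token multiset
# {"answer", "is", "paris"}; pairs whose reference answer is "" are skipped.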
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets"""
import torch
import argparse
from nltk import word_tokenize
from tqdm import tqdm
import numpy as np
import json
def get_args():
parser = argparse.ArgumentParser(description="Preprocessing")
parser.add_argument("--func", type=str, default=None,
help="choose to run which function")
parser.add_argument("--raw_file", type=str, default=None,
help="path of the input file")
parser.add_argument("--processed_file", type=str, default=None,
help="path of the output file")
parser.add_argument("--knwl_ref_file", type=str, default=None,
help="path of the knowledge reference file")
parser.add_argument("--resp_ref_file", type=str, default=None,
help="path of the knowledge reference file")
parser.add_argument("--knwl_gen_file", type=str, default=None,
help="path of the generated knowledge file")
parser.add_argument("--test_file", type=str, default=None,
help="path of the test file")
parser.add_argument("--train_file", type=str, default=None,
help="path of the train file")
parser.add_argument("--model_file", type=str, default=None,
help="path of the model file")
parser.add_argument("--data_type", type=str, default=None,
help="data types, choose one out of three types: \
wow_seen, wow_unseen, and woi")
parser.add_argument("--seed", type=int, default=1234,
help="random seed")
args = parser.parse_args()
return args
def process_wow_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
"""
    Process the Wizard of Wikipedia (WoW) dataset.
Expected processed format:
topic \t dialogue context \t golden knowledge \t golden response
"""
# loading the raw data
print("> Loading data from %s" % raw_file)
with open(raw_file, "r") as fr:
dialog_data = json.load(fr)
print("> Processing data ...")
fproc = open(processed_file, "w")
fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
fresp = open(resp_ref_file, "w") if resp_ref_file else None
for i, sample in enumerate(tqdm(dialog_data)):
# get all the dialog data for a single dialog sample
dialog = sample["dialog"]
turn_list = [] # collect the dialog history
# processing for each single dialog sample
for j, turn in enumerate(dialog):
# text of each turn
text = turn["text"]
if not (text.endswith("?") or text.endswith(".") or text.endswith("!")):
text = text + "."
if j == 0:
# first turn
turn_list.append(text)
continue
speaker = turn["speaker"].lower()
if "wizard" in speaker:
checked_sentence = list(turn["checked_sentence"].values()) # knowledge
checked_passage = list(turn["checked_passage"].values()) # topic
assert len(checked_sentence) <= 1
# get the ground truth knowledge
if len(checked_sentence) > 0:
checked_sentence = checked_sentence[0]
else:
checked_sentence = "no_passages_used"
if len(checked_passage) == 1:
checked_passage = checked_passage[0]
else:
checked_passage = "no_passages_used"
# get the topic
if checked_passage != "no_passages_used":
topic = checked_passage
else:
topic = sample["chosen_topic"]
dialog_context = " [SEP] ".join(turn_list)
knowledge = checked_sentence
response = text
# add the response into the dialog history
turn_list.append(response)
# write to the output files
fproc.write(topic + "\t" + dialog_context + "\t" + \
knowledge + "\t" + response + "\n")
if fknwl:
fknwl.write(knowledge + "\n")
if fresp:
# tokenize for evaluation
response = " ".join(word_tokenize(response))
fresp.write(response + "\n")
else:
assert "apprentice" in speaker
turn_list.append(text)
fproc.close()
if fknwl:
fknwl.close()
if fresp:
fresp.close()
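# A processed line then looks like (fabricated example, \t = tab):
#   Blue \t i love blue. [SEP] me too. \t Blue is a primary colour. \t yes, it is.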
def process_woi_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
"""
    Process the Wizard of Internet (WoI) dataset.
Expected processed format:
topic \t dialogue context \t golden knowledge \t golden response
"""
print("> Processing %s" % raw_file)
fproc = open(processed_file, "w")
fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
fresp = open(resp_ref_file, "w") if resp_ref_file else None
with open(raw_file, "r") as fr:
for i, line in tqdm(enumerate(fr)):
# read line by line, each line uses json format
line = line.strip()
item_dict = json.loads(line)
# item_dict is a dictionary
# its key is the data id, and its value contains all the data content
item_dict = item_dict.values()
item_dict = list(item_dict)[0] # len(item_dict) == 1
# get the whole dialog data for a single dialog sample
dialog_data = item_dict['dialog_history']
length = len(dialog_data)
turn_list = [] # collect the dialog history
search_text = ""
for i in range(length):
item = dialog_data[i]
action = item['action']
if action == "Wizard => SearchAgent":
search_text = item['text']
elif action == "Wizard => Apprentice":
if len(turn_list) == 0:
# first turn
turn = item['text']
turn_list.append(turn)
continue
# get the relevant content
contents = item["context"]["contents"]
selects = item["context"]["selected_contents"]
flag = selects[0][0]
selects = selects[1:]
assert len(selects) == len(contents)
# get the topic
if flag:
# no knowledge sentence is used for the response
topic = "no_topic"
knwl_sent = "no_passages_used"
else:
# we consider the search text as the topic
topic = search_text
# get the knowledge sentence
knwl_sent = ""
for content, select in zip(contents, selects):
content = content['content']
assert len(content) == len(select)
for c, s in zip(content, select):
if s:
knwl_sent = c
break
if knwl_sent == "":
# no knowledge is used for the response
topic = "no_topic"
knwl_sent = "no_passages_used"
# get dialogue context, knowledge, and response
dialog_context = " [SEP] ".join(turn_list)
response = item['text']
# processing
topic = topic.replace("\n", "").replace("\r", \
"").replace("\t", "")
dialog_context = dialog_context.replace("\n", "").replace("\r", \
"").replace("\t", "")
knwl_sent = knwl_sent.replace("\n", "").replace("\r", \
"").replace("\t", "")
response = response.replace("\n", "").replace("\r", \
"").replace("\t", "")
if topic != "no_topic":
                # write to the output files
fproc.write(topic + "\t" + dialog_context + "\t" + \
knwl_sent + "\t" + response + "\n")
if fknwl:
fknwl.write(knwl_sent + "\n")
if fresp:
# tokenize for evaluation
response = " ".join(word_tokenize(response))
fresp.write(response + "\n")
turn_list.append(response)
elif action == "Apprentice => Wizard":
turn = item['text']
turn_list.append(turn)
else:
assert action == "SearchAgent => Wizard", \
"Please check whether you have used the correct data!"
fproc.close()
if fknwl:
fknwl.close()
if fresp:
fresp.close()
def get_database(test_datapath, train_datapath, data_type):
"""Get the database by topics"""
assert data_type in ["wow_seen", "wow_unseen", "woi"], \
"Please input a correct data type!!"
# get test data topic dictionary
print("> reading test data from %s" % test_datapath)
test_topics = {}
with open(test_datapath, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
test_topics[topic] = True
print("> reading data from %s" % train_datapath)
train_data_by_topic = {}
dialog_data_by_topic = {}
dialog_examples = []
with open(train_datapath, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
turns = splits[1].split(" [SEP] ")[-3:]
knowledge = splits[2]
response = splits[3]
# filtering data samples
if knowledge == "no_passages_used":
# when no knowledge is used
continue
if data_type != "wow_seen" and ("(" in knowledge or ")" in knowledge):
# when bracket exists in the knowledge
continue
if data_type != "wow_seen" and topic not in knowledge:
# when topic does not exist in the knowledge
continue
# get the instance
last_turn = turns[-1]
instance = "( " + last_turn + " ) " + topic + " => " + knowledge
# construct dialog example
dialog_example = ""
if data_type != "wow_seen":
dialog_example += "( " + topic + " ) "
for i, turn in enumerate(turns):
if i != 0:
dialog_example += " "
dialog_example += turn
# check overlaps
if topic in test_topics:
if topic not in train_data_by_topic:
train_data_by_topic[topic] = [instance]
else:
train_data_by_topic[topic].append(instance)
if topic not in dialog_data_by_topic:
dialog_data_by_topic[topic] = [dialog_example]
else:
dialog_data_by_topic[topic].append(dialog_example)
else:
# filtering data samples
if len(knowledge.split()) > 20:
# knowledge is too long
continue
if knowledge.startswith("It") or knowledge.startswith("it") or \
knowledge.startswith("This") or knowledge.startswith("this"):
continue
# append all the data into dialogue examples list
dialog_examples.append((topic, dialog_example, instance))
return train_data_by_topic, dialog_data_by_topic, dialog_examples
emb_dict = {}
def select_prompts_based_on_similarity(
query, dialog_list, prompt_list, topic, tokenizer, encoder, topk):
"""Select samples based on the similarity"""
with torch.no_grad():
# get the query embeddings
query_ids = tokenizer.encode(query)
query_ids = torch.LongTensor([query_ids]).cuda()
query_emb = encoder(input_ids=query_ids).pooler_output
query_emb = query_emb[0]
# calculate embeddings for the samples in the database
if topic in emb_dict:
example_embeddings = emb_dict[topic]
example_embeddings = example_embeddings.cuda()
else:
# embed each example once, concatenate, and cache the result on CPU
emb_list = []
for example in dialog_list:
example_ids = tokenizer.encode(example)
example_ids = torch.LongTensor([example_ids]).cuda()
emb_list.append(encoder(input_ids=example_ids).pooler_output)
example_embeddings = torch.cat(emb_list, dim=0)
emb_dict[topic] = example_embeddings.cpu()
# compare the similarity and select the topk samples
similarity_list = example_embeddings.matmul(query_emb)
_, indices = torch.topk(similarity_list, k=topk)
indices = indices.tolist()
indices = indices[::-1] # reverse so the most similar example comes last
selected_prompts = []
for index in indices:
selected_prompts.append(prompt_list[index])
return selected_prompts
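# Prompt selection for knowledge generation distinguishes two cases per test
# sample: if the test topic also occurs in the training data, only that
# topic's examples are ranked (select_prompts_based_on_similarity above);
# otherwise all training dialog examples are ranked against the query and the
# 10 most similar ones are kept, at most one per topic. Either way the
# selected instances are ordered most-similar-last, so the closest example
# sits directly before the test input in the final prompt.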
def prompt_selection_for_knowledge_generation(
test_datapath, train_datapath, model_path, output_prompt_path, data_type):
"""Selecting prompts for the knowledge generation"""
print("> Selecting prompts for the knowledge generation")
train_data_by_topic, dialog_data_by_topic, dialog_examples = \
get_database(test_datapath, train_datapath, data_type)
from transformers import DPRQuestionEncoderTokenizer
print("> loading tokenizer and encoder")
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
'facebook/dpr-question_encoder-single-nq-base')
encoder = torch.load(model_path).cuda()
print("> getting dialog embeddings")
with torch.no_grad():
# collect embeddings in a list and concatenate once, instead of
# re-concatenating the growing tensor on every iteration
emb_list = []
for example in tqdm(dialog_examples):
dialog = example[1]
dialog_ids = tokenizer.encode(dialog)
dialog_ids = torch.LongTensor([dialog_ids]).cuda()
emb_list.append(encoder(input_ids=dialog_ids).pooler_output)
dialog_embeddings = torch.cat(emb_list, dim=0)
print("> reading test data from %s" % test_datapath)
prompt_list_for_each_sample = []
with open(test_datapath, "r") as f:
for i, line in tqdm(enumerate(f)):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
turns = splits[1].split(" [SEP] ")[-3:]
# get the query sentence; its format must match the dialog
# examples built in get_database (which key on "wow_seen")
query_sent = ""
if data_type != "wow_seen":
query_sent += "( " + topic + " ) "
for j, turn in enumerate(turns):
if j != 0:
query_sent += " "
query_sent += turn
if topic not in train_data_by_topic:
# get the query embedding
query_ids = tokenizer.encode(query_sent)
query_ids = torch.LongTensor([query_ids]).cuda()
query_emb = encoder(input_ids=query_ids).pooler_output
query_emb = query_emb[0]
# calculate the similarity and rank from most to least similar
similarity_list = dialog_embeddings.matmul(query_emb)
_, indices = torch.sort(similarity_list, descending=True)
indices = indices.tolist()
selected_topics = {}
selected_prompts = []
num_prompt = 0
for index in indices:
example = dialog_examples[index]
topic_temp = example[0]
if topic_temp not in selected_topics:
selected_topics[topic_temp] = True
selected_prompts.append(example[2])
num_prompt += 1
if num_prompt == 10:
break
# reverse so the most similar example comes last in the prompt
example_list = selected_prompts[::-1]
key = topic + " " + turns[-1]
prompt_list_for_each_sample.append({key: example_list})
else:
num_data_sample = min(len(train_data_by_topic[topic]), 10)
total_example_list = train_data_by_topic[topic]
dialog_list = dialog_data_by_topic[topic]
assert len(dialog_list) == len(train_data_by_topic[topic])
# calculate the similarity
example_list = select_prompts_based_on_similarity(
query_sent, dialog_list, total_example_list,
topic, tokenizer, encoder, topk=num_data_sample)
key = topic + " " + turns[-1]
prompt_list_for_each_sample.append({key: example_list})
print("writing to %s" % output_prompt_path)
with open(output_prompt_path, "w") as f:
for instance in tqdm(prompt_list_for_each_sample):
json.dump(instance, f)
f.write("\n")
def prompt_selection_for_response_generation(input_path, output_path, seed):
"""Selecting prompts for the response generation"""
print("> Selecting prompts for the response generation")
print("> set random seed")
np.random.seed(seed)
prompt_example_list = []
print("> reading data from %s" % input_path)
with open(input_path, "r") as f:
for i, line in tqdm(enumerate(f)):
line = line.strip()
splits = line.split("\t")
# get the topic, context, knowledge and response
topic = splits[0]
dialog_context = splits[1]
knowledge = splits[2]
response = splits[3]
turns = dialog_context.split(" [SEP] ")[-3:]
if knowledge == "no_passages_used":
continue
# calculate the overlap ratio between response and knowledge tokens
knowledge_sent_token_list = word_tokenize(knowledge)
knowledge_sent_token_dict = {token: True for token in knowledge_sent_token_list}
knowledge_len = len(knowledge_sent_token_list)
response_token_list = word_tokenize(response)
response_len = len(response_token_list)
num_overlap_token = 0
accumulator = 0
for token in response_token_list:
if token in knowledge_sent_token_dict:
accumulator += 1
else:
if accumulator >= 10:
num_overlap_token += accumulator
accumulator = 0
if accumulator >= 10:
num_overlap_token += accumulator
# filtering the data based on the ratio
if num_overlap_token > response_len * 0.9 or num_overlap_token < response_len * 0.6:
continue
if num_overlap_token < knowledge_len * 0.8:
continue
last_turn = " ".join(word_tokenize(turns[-1]))
knowledge = " ".join(word_tokenize(knowledge))
response = " ".join(word_tokenize(response))
prompt_example = ""
# add dialog context
prompt_example += "Topic: " + topic + ". "
prompt_example += "User says: " + last_turn + " "
prompt_example += "We know that: " + knowledge + " "
prompt_example += "System replies: " + response
prompt_example_list.append(prompt_example)
# shuffle the prompt examples
np.random.shuffle(prompt_example_list)
print("> writing to %s" % output_path)
with open(output_path, "w") as f:
# f.write("Generate the System's response based on the knowledge sentence:\n")
# write the first 20 shuffled examples (or all of them, if fewer exist)
for example in tqdm(prompt_example_list[:20]):
f.write(example + "\n")
def prepare_input_for_response_generation(test_file, knwl_gen_file, processed_file):
"""Preparing inputs for the response generation"""
print("> Reading knowledge file from %s" % knwl_gen_file)
# get the knowledge list
with open(knwl_gen_file, "r") as f:
knowledge_list = f.readlines()
print("> Processing ...")
with open(test_file, "r") as fr:
with open(processed_file, "w") as fw:
for line_num, line in enumerate(tqdm(fr)):
line = line.strip()
splits = line.split("\t")
# prepare topic, context, knowledge and response
topic = splits[0]
dialog_context = splits[1]
response = splits[3]
knowledge = knowledge_list[line_num]
knowledge = knowledge.strip()
if "<|endoftext|>" in knowledge:
knowledge = knowledge.replace("<|endoftext|>", "")
# write to the output file
fw.write(topic + "\t" + dialog_context + "\t" \
+ knowledge + "\t" + response + "\n")
if __name__ == "__main__":
args = get_args()
if args.func == "process_wow_dataset":
process_wow_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)
elif args.func == "process_woi_dataset":
process_woi_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)
elif args.func == "get_knwl_gen_prompts":
prompt_selection_for_knowledge_generation(
args.test_file, args.train_file, args.model_file,
args.processed_file, args.data_type)
elif args.func == "get_resp_gen_prompts":
prompt_selection_for_response_generation(
args.train_file, args.processed_file, args.seed)
elif args.func == "prepare_input":
prepare_input_for_response_generation(
args.test_file, args.knwl_gen_file, args.processed_file)
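# Example invocations (a sketch; the file path and flag spellings are
# assumptions based on the attribute names above, e.g. args.func <-> --func,
# args.raw_file <-> --raw-file):
#   python tasks/msdp/preprocessing.py --func process_wow_dataset \
#       --raw-file wow_raw.json --processed-file wow_processed.txt
#   python tasks/msdp/preprocessing.py --func get_resp_gen_prompts \
#       --train-file wow_processed.txt --processed-file resp_prompts.txt --seed 1234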
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Prompting the pretrained language model to generate knowledge/response"""
import json
import torch
import requests
from nltk import word_tokenize
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.core import mpu
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.text_generation import generate_and_post_process
def call_model_api(inputs, tokens_to_generate):
"""Calling the model api to get the output generations"""
args = get_args()
# The following is an example of using the Megatron API
# You can also implement your own API function to replace this part
headers = {'Content-Type': 'application/json; charset=UTF-8'}
data = {"prompts": [inputs], "tokens_to_generate": tokens_to_generate, "top_k": 1}
data_json = json.dumps(data)
outputs = requests.put(args.megatron_api_url, headers=headers, data=data_json).json()["text"][0]
input_len = len(inputs)
outputs = outputs[input_len:]
outputs = outputs.split("\n")[0].strip()
return outputs
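# A minimal sketch of the exchange above (the URL is whatever
# args.megatron_api_url points at):
#   PUT {"prompts": ["( <last turn> ) <topic> =>"],
#        "tokens_to_generate": 64, "top_k": 1}
# The server returns the prompt plus its greedy (top_k=1) continuation in
# "text"[0]; the code strips the echoed prompt and keeps only the first
# generated line.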
def read_prompts(prompt_path, prompt_type, n_example):
"""Read prompt data"""
if prompt_type == "knowledge":
# prompts for the knowledge generation
prompt_examples_dict = {}
# read prompt_path
with open(prompt_path, "r") as f:
for i, line in enumerate(f):
line = line.strip()
line_dict = json.loads(line)
key = list(line_dict.keys())[0]
if key not in prompt_examples_dict:
prompt_examples = line_dict[key]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
prompt_examples_dict[key] = prompt
return prompt_examples_dict
else:
# prompts for the response generation
# read prompt_path
prompt = ""
with open(prompt_path, "r") as f:
prompt_examples = f.readlines()
prompt_examples = prompt_examples[:n_example]
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
return prompt
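# Prompt file formats consumed above:
# - knowledge: one JSON object per line of the form
#   {"<topic> <last turn>": ["( <last turn> ) <topic> => <knowledge>", ...]};
#   the instances under each key are joined with " \n" into one prompt string.
# - response: plain text with one prompt example per line; the first
#   n_example lines are joined the same way into a single fixed prompt.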
def generate_samples_by_calling_api():
""" Generate outputs by calling"""
args = get_args()
assert args.prompt_type in ["knowledge", "response"], \
"Please input a correct prompt type!"
if args.prompt_type == "knowledge":
# read knowledge generation prompts
knwl_gen_prompt_dict = read_prompts(
args.prompt_file, args.prompt_type, args.num_prompt_examples)
else:
resp_gen_prompt = read_prompts(
args.prompt_file, args.prompt_type, args.num_prompt_examples)
# read the test data
fname = open(args.sample_input_file, "r")
test_sample_list = fname.readlines()
# create output file
fname_out = open(args.sample_output_file, "w")
# call the api to get the output generations
for test_sample in test_sample_list:
test_sample = test_sample.strip()
splits = test_sample.split("\t")
topic = splits[0]
# prepare the inputs for the api
if args.prompt_type == "knowledge":
# inputs = prompt + current test
# get the prompt
turns = splits[1].split(" [SEP] ")
last_turn = turns[-1]
key = topic + " " + last_turn
inputs = knwl_gen_prompt_dict[key]
# add current test
inputs += "( " + last_turn + " ) " + topic + " =>"
else:
# inputs = prompt + current test
# get the prompt
inputs = resp_gen_prompt
# add current test
turns = splits[1].split(" [SEP] ")
knowledge = splits[2]
last_turn = turns[-1]
last_turn = " ".join(word_tokenize(last_turn))
knowledge = " ".join(word_tokenize(knowledge))
knowledge = knowledge.strip()
last_turn = last_turn.strip()
inputs += "Topic: " + topic + ". "
inputs += "User says: " + last_turn + " "
inputs += "We know that: " + knowledge + " "
inputs += "System replies:"
# get the output generations from the api,
# and write to the output file
generations = call_model_api(inputs, args.out_seq_length)
fname_out.write(generations)
fname_out.write("\n")
fname.close()
fname_out.close()
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building GPT model ...')
model = GPTModel(
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process
)
return model
def generate_samples_by_prompting_input_from_file(model):
"""Prompt a pretrained language model to generate knowledge/response"""
# get tokenizer
args = get_args()
tokenizer = get_tokenizer()
# Read the sample file and open the output file.
assert args.sample_input_file is not None, \
'sample input file is not provided.'
if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines()
input_count = len(all_raw_text)
if args.sample_output_file is None:
sample_output_file = args.sample_input_file + ".out"
print('`sample-output-file` not specified, setting '
'it to {}'.format(sample_output_file))
else:
sample_output_file = args.sample_output_file
fname_out = open(sample_output_file, "w")
# only two prompt types (i.e., knowledge and response) are allowed
assert args.prompt_type in ["knowledge", "response"], \
"Please input a correct prompt type!"
# Read the prompt file
if args.prompt_type == "knowledge":
# read the prompts for the knowledge generation
prompt_examples_dict = {}
with open(args.prompt_file, "r") as f:
for i, line in enumerate(f):
line = line.strip()
line_dict = json.loads(line)
key = list(line_dict.keys())[0]
# get the prompt examples based on the key
if key not in prompt_examples_dict:
prompt_examples = line_dict[key]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
prompt_examples_dict[key] = prompt
else:
# read the prompts for the response generation
# prompts are fixed for all test samples
with open(args.prompt_file, "r") as f:
prompt_examples = f.readlines()
prompt_examples = prompt_examples[:args.num_prompt_examples]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
input_pos = 0
model.eval()
# perform prompting
with torch.no_grad():
while True:
raw_text_len = 0
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
input_str = all_raw_text[input_pos]
input_str = input_str.strip()
splits = input_str.split("\t")
topic = splits[0]
if args.prompt_type == "knowledge":
# first add the prompt into the raw_text
turns = splits[1].split(" [SEP] ")
last_turn = turns[-1]
key = topic + " " + last_turn
raw_text = prompt_examples_dict[key]
# construct inputs for knowledge generation
# then add the constructed inputs into the raw_text
raw_text += "( " + last_turn + " ) " + topic + " =>"
else:
# first add the prompt into the raw_text
raw_text = prompt
# construct inputs for response generation
# then add the constructed inputs into the raw_text
turns = splits[1].split(" [SEP] ")
knowledge = splits[2]
last_turn = turns[-1]
last_turn = " ".join(word_tokenize(last_turn))
knowledge = " ".join(word_tokenize(knowledge))
knowledge = knowledge.strip()
last_turn = last_turn.strip()
raw_text += "Topic: " + topic + ". "
raw_text += "User says: " + last_turn + " "
raw_text += "We know that: " + knowledge + " "
raw_text += "System replies:"
input_pos += 1
raw_text_len = len(raw_text)
else:
raw_text = "EMPTY TEXT"
if input_pos % 100 == 0:
print_rank_0("input_pos: %d" % input_pos)
outputs = generate_and_post_process(
model=model,
prompts=[raw_text],
tokens_to_generate=args.out_seq_length,
top_k_sampling=1)
prompts_plus_generations = outputs[0]
prompts_plus_generations = prompts_plus_generations[0]
# write the generated output to the output file
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage():
generations = prompts_plus_generations[raw_text_len:]
generations = generations.split("\n")[0]
generations = generations.strip()
fname_out.write(generations)
fname_out.write("\n")
raw_text = None
if input_pos == input_count:
return
def main():
args = get_args()
if args.api_prompt:
# obtain the generations by calling the api
generate_samples_by_calling_api()
return
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for text generation.")
exit()
# Set up model and load checkpoint.
model = get_model(model_provider, wrap_with_ddp=False)
if args.load is not None:
_ = load_checkpoint(model, None, None)
assert len(model) == 1, "Above condition should have caught this"
model = model[0]
# perform the prompting
generate_samples_by_prompting_input_from_file(model)
## End-to-End Training of Neural Retrievers for Open-Domain Question Answering
Below we present the steps to run unsupervised and supervised training and evaluation of the retriever for [open domain question answering](https://arxiv.org/abs/2101.00408).
## Retriever Training
#### Unsupervised pretraining
1. Use [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to preprocess the dataset for the Inverse Cloze Task (ICT), which we call unsupervised pretraining. The script takes a corpus in loose JSON format as input and creates fixed-size blocks of text as the fundamental units of data; for a corpus like Wikipedia, this means multiple sentences per block and multiple blocks per document. Pass the `--split-sentences` argument so that sentences become the basic unit of the indexed datasets. We construct two datasets, one with the title of every document and another with the body.
<pre>
python tools/preprocess_data.py \
--input /path/to/corpus.json \
--json-keys text title \
--split-sentences \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file /path/to/vocab.txt \
--output-prefix corpus_indexed \
--workers 10
</pre>
2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs single-GPU ICT retriever training for a 217M-parameter biencoder model. Single-GPU training is primarily intended for debugging, as the code is developed for distributed training. The script starts from a pretrained BERT model, and we use a total batch size of 4096 for ICT training; a launch sketch is shown after this list.
3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for [Google's Natural Questions Open dataset](https://arxiv.org/pdf/1906.00300.pdf).
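A minimal launch sketch for steps 2 and 3; the scripts contain data, vocab, and checkpoint paths that you will likely need to edit for your setup:
<pre>
bash examples/pretrain_ict.sh
bash examples/evaluate_retriever_nq.sh
</pre>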
#### Supervised finetuning
1. Finetune the pretrained ICT model from above on [Google's Natural Questions Open dataset](https://github.com/google-research/language/tree/master/language/orqa). The script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) shows how to perform the training. Our finetuning process adds retriever score scaling and longer training (80 epochs) on top of [DPR training](https://arxiv.org/abs/2004.04906).
2. Evaluate the finetuned model using the same evaluation script as mentioned above for the unsupervised model; a launch sketch follows this list.
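A minimal sketch, assuming the data and checkpoint paths inside the scripts have been adjusted to your setup:
<pre>
bash examples/finetune_retriever_distributed.sh
bash examples/evaluate_retriever_nq.sh
</pre>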
More details on the retriever are available in [our paper](https://arxiv.org/abs/2101.00408).
## Reader Training
The reader component will be available soon.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Main tasks functionality."""
from megatron import get_args, print_rank_0
from megatron.indexer import IndexBuilder
from tasks.orqa.evaluate_utils import ORQAEvaluator
def main():
"""
Main program
"""
args = get_args()
"""
Create a BlockData data structure by running an IndexBuilder over an
ICT Dataset and then evaluate on NQ task
"""
print_rank_0("Starting index builder!")
index_builder = IndexBuilder()
index_builder.build_and_save_index()
print_rank_0("Build and save indices: done!")
print_rank_0("Starting evaluations!")
# Set up the model and evaluator
evaluator = ORQAEvaluator()
# Run evaluation
if args.qa_data_dev is not None:
evaluator.evaluate(args.qa_data_dev, "DEV")
if args.qa_data_test is not None:
evaluator.evaluate(args.qa_data_test, "TEST")
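# A sketch of a typical invocation through the generic tasks entry point,
# assuming the usual argparse mapping of args.qa_data_dev / args.qa_data_test
# (the task name is a placeholder):
#   python tasks/main.py --task <retriever-eval-task> \
#       --qa-data-dev /path/to/nq-dev.csv \
#       --qa-data-test /path/to/nq-test.csv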