Commit 0211193c authored by zhuwenwen

initial llama
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetune utilities."""
from functools import partial
import sys
import torch
from megatron import get_args, get_num_microbatches
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.model import ModelType
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import average_losses_across_data_parallel_group
from megatron.utils import calc_params_l2_norm
from megatron.utils import check_adlr_autoresume_termination
def process_batch(batch):
"""Process batch and produce inputs for the model."""
args = get_args()
tokens = batch['text'].long().cuda().contiguous()
types = batch['types'].long().cuda().contiguous()
labels = batch['label'].long().cuda().contiguous()
attention_mask = batch['padding_mask'].float().cuda().contiguous()
if args.fp16:
attention_mask = attention_mask.half()
return tokens, types, labels, attention_mask
def cross_entropy_loss_func(labels, output_tensor):
logits = output_tensor
# Cross-entropy loss.
loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(logits.contiguous().float(), labels)
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def _cross_entropy_forward_step(batch, model):
"""Simple forward step with cross-entropy loss."""
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
tokens, types, labels, attention_mask = process_batch(batch_)
timers('batch-generator').stop()
# Forward model.
output_tensor = model(tokens, attention_mask, tokentype_ids=types)
return output_tensor, partial(cross_entropy_loss_func, labels)
def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
task_collate_fn=None):
"""Data loader. Note that batch-size is the local (per GPU) batch-size."""
# Sampler.
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=world_size, rank=rank)
# Data loader. Note that batch size is the per GPU batch size.
data_loader = torch.utils.data.DataLoader(dataset,
batch_size=micro_batch_size,
sampler=sampler,
shuffle=False,
num_workers=num_workers,
drop_last=drop_last,
pin_memory=True,
collate_fn=task_collate_fn)
return data_loader
def _build_infinite_size_dataloader(dataloader):
"""Build a looped dataloader with infinite size."""
iterator = dataloader.__iter__()
while True:
try:
yield iterator.__next__()
except StopIteration:
iterator = dataloader.__iter__()
def _build_train_valid_dataloaders(train_dataset, valid_dataset,
task_collate_fn=None):
"""Traing and validation dataloaders."""
args = get_args()
print_rank_0('building train and validation dataloaders ...')
# Training dataset.
train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last,
task_collate_fn)
# Set the training iterations.
args.train_iters_per_epoch = len(train_dataloader)
args.train_iters = args.epochs * args.train_iters_per_epoch
# Validation dataset. For this dataset, we do not need to set up
# shuffling so we can just use a simple infinite loop.
valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last,
task_collate_fn)
valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
# Now that we've built the data loaders, set batch_size arguments
# to the actual batch size the model will see for this dataset.
# This is necessary so pipeline transfers know what size they are
# and the LR schedule, which is based on samples seen, gets set
# correctly.
args.orig_micro_batch_size = args.micro_batch_size
args.orig_global_batch_size = args.global_batch_size
if hasattr(train_dataset, 'sample_multiplier'):
        # If our dataset has a sample_multiplier attribute, each "sample"
        # from the dataset actually contains multiple samples that will
        # collapse into the batch dimension (for example, the RACE dataset
        # has several options per question), so we need to account for that
        # when setting the micro batch size.
args.micro_batch_size *= train_dataset.sample_multiplier
args.global_batch_size *= train_dataset.sample_multiplier
return train_dataloader, valid_dataloader
def _train(model, optimizer, lr_scheduler, forward_step,
train_dataloader, valid_dataloader, end_of_epoch_callback):
"""Train the model."""
args = get_args()
timers = get_timers()
assert get_num_microbatches() == 1, "finetuning with gradient accumulation doesn't currently work"
# Turn on training mode which enables dropout.
for m in model:
m.train()
# Tracking loss.
losses_dict_sum = {}
# Starting epoch and iteration
start_epoch = args.iteration // args.train_iters_per_epoch
start_iteration = args.iteration % args.train_iters_per_epoch
iteration = args.iteration
# Memory reporting flag.
report_memory_flag = True
# For each remaining epoch
timers('interval-time').start()
for epoch in range(start_epoch, args.epochs):
print_rank_0('working on epoch {} ...'.format(epoch + 1))
# Set the data loader epoch to shuffle the index iterator.
train_dataloader.sampler.set_epoch(args.seed + epoch)
# For all the batches in the dataset.
for iteration_, batch in enumerate(train_dataloader):
# Ignore the iterations before starting value
if iteration_ < start_iteration:
continue
# Set to zero so the next epoch does not skip any batches.
start_iteration = 0
# Train for one step.
out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
iteration += 1
# Logging.
params_norm = None
if args.log_params_norm:
params_norm = calc_params_l2_norm(model)
report_memory_flag = training_log(losses_dict, losses_dict_sum,
optimizer.param_groups[0]['lr'],
iteration,
optimizer.get_loss_scale().item(),
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
# Autoresume
if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0):
check_adlr_autoresume_termination(iteration, model,
optimizer, lr_scheduler)
# Checkpointing
saved_checkpoint = False
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
saved_checkpoint = True
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0:
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step,
valid_dataloader, model,
iteration, False)
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
torch.distributed.barrier()
print_rank_0('exiting program at iteration {}'.format(iteration))
sys.exit()
# Checkpointing at the end of each epoch.
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
# Callback at the end of each epoch.
if end_of_epoch_callback is not None:
end_of_epoch_callback(model, epoch)
def finetune(train_valid_datasets_provider, model_provider,
model_type=ModelType.encoder_or_decoder,
forward_step=_cross_entropy_forward_step,
end_of_epoch_callback_provider=None,
task_collate_fn=None):
"""Main finetune function used across all tasks."""
args = get_args()
timers = get_timers()
assert args.rampup_batch_size is None, \
'batch size scaling is not supported for finetuning'
# Train and validation data loaders.
    timers('train/valid/test dataset/dataloader').start()
if args.epochs > 0:
train_dataset, valid_dataset = train_valid_datasets_provider()
train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
train_dataset, valid_dataset, task_collate_fn)
else:
args.train_iters = 0
    timers('train/valid/test dataset/dataloader').stop()
    # Build callback function.
timers('callback function').start()
end_of_epoch_callback = None
if end_of_epoch_callback_provider is not None:
end_of_epoch_callback = end_of_epoch_callback_provider()
timers('callback function').stop()
# Build model, optimizer and learning rate scheduler.
timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider, model_type)
timers('model and optimizer').stop()
# If pretrained checkpoint is provided and we have not trained for
# any iteration (i.e., iteration is zero), then load the pretrained
# checkpoint.
timers('pretrained checkpoint').start()
if args.iteration == 0 and args.pretrained_checkpoint is not None:
original_load = args.load
args.load = args.pretrained_checkpoint
original_rng = args.no_load_rng
args.no_load_rng = True
_ = load_checkpoint(model, None, None)
args.load = original_load
args.no_load_rng = original_rng
# This is critical when only model is loaded. We should make sure
# main parameters are also updated.
optimizer.reload_model_params()
timers('pretrained checkpoint').stop()
# Print setup timing.
print_rank_0('done with setups ...')
    timers.log(['train/valid/test dataset/dataloader', 'callback function',
                'model and optimizer', 'pretrained checkpoint'])
print_rank_0('training ...')
# Finetune the model.
if args.epochs > 0:
_train(model, optimizer, lr_scheduler, forward_step,
train_dataloader, valid_dataloader, end_of_epoch_callback)
# Or just evaluate.
else:
if end_of_epoch_callback is not None:
print_rank_0('evaluation only mode, setting epoch to -1')
end_of_epoch_callback(model, epoch=-1, output_predictions=True)
print_rank_0('done :-)')
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GLUE dataset."""
from abc import ABC
from abc import abstractmethod
from torch.utils.data import Dataset
from megatron import print_rank_0
from tasks.data_utils import build_sample
from tasks.data_utils import build_tokens_types_paddings_from_text
class GLUEAbstractDataset(ABC, Dataset):
"""GLUE base dataset class."""
def __init__(self, task_name, dataset_name, datapaths,
tokenizer, max_seq_length):
# Store inputs.
self.task_name = task_name
self.dataset_name = dataset_name
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
self.dataset_name))
# Process the files.
string = ' > paths:'
for path in datapaths:
string += ' ' + path
print_rank_0(string)
self.samples = []
for datapath in datapaths:
self.samples.extend(self.process_samples_from_single_path(datapath))
print_rank_0(' >> total number of samples: {}'.format(
len(self.samples)))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
raw_sample = self.samples[idx]
ids, types, paddings = build_tokens_types_paddings_from_text(
raw_sample['text_a'], raw_sample['text_b'],
self.tokenizer, self.max_seq_length)
sample = build_sample(ids, types, paddings,
raw_sample['label'], raw_sample['uid'])
return sample
@abstractmethod
def process_samples_from_single_path(self, datapath):
"""Abstract method that takes a single path / filename and
returns a list of dataset samples, each sample being a dict of
{'text_a': string, 'text_b': string, 'label': int, 'uid': int}
"""
pass
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GLUE finetuning/evaluation."""
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.model.classification import Classification
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
def glue_classification(num_classes, Dataset,
name_from_datapath_func):
def train_valid_datasets_provider():
"""Build train and validation dataset."""
args = get_args()
tokenizer = get_tokenizer()
train_dataset = Dataset('training', args.train_data,
tokenizer, args.seq_length)
valid_dataset = Dataset('validation', args.valid_data,
tokenizer, args.seq_length)
return train_dataset, valid_dataset
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
print_rank_0('building classification model for {} ...'.format(
args.task))
model = Classification(num_classes=num_classes, num_tokentypes=2,
pre_process=pre_process, post_process=post_process)
return model
def metrics_func_provider():
"""Privde metrics callback function."""
def single_dataset_provider(datapath):
args = get_args()
tokenizer = get_tokenizer()
name = name_from_datapath_func(datapath)
return Dataset(name, [datapath], tokenizer, args.seq_length)
return accuracy_func_provider(single_dataset_provider)
"""Finetune/evaluate."""
finetune(train_valid_datasets_provider, model_provider,
end_of_epoch_callback_provider=metrics_func_provider)
def main():
args = get_args()
if args.task == 'MNLI':
num_classes = 3
from tasks.glue.mnli import MNLIDataset as Dataset
def name_from_datapath(datapath):
return datapath.split('MNLI')[-1].strip(
'.tsv').strip('/').replace('_', '-')
elif args.task == 'QQP':
num_classes = 2
from tasks.glue.qqp import QQPDataset as Dataset
def name_from_datapath(datapath):
return datapath.split('QQP')[-1].strip(
'.tsv').strip('/').replace('_', '-')
else:
raise NotImplementedError('GLUE task {} is not implemented.'.format(
args.task))
glue_classification(num_classes, Dataset, name_from_datapath)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MNLI dataset."""
from megatron import print_rank_0
from tasks.data_utils import clean_text
from .data import GLUEAbstractDataset
LABELS = {'contradiction': 0, 'entailment': 1, 'neutral': 2}
class MNLIDataset(GLUEAbstractDataset):
def __init__(self, name, datapaths, tokenizer, max_seq_length,
test_label='contradiction'):
self.test_label = test_label
super().__init__('MNLI', name, datapaths,
tokenizer, max_seq_length)
def process_samples_from_single_path(self, filename):
""""Implement abstract method."""
print_rank_0(' > Processing {} ...'.format(filename))
samples = []
total = 0
first = True
is_test = False
with open(filename, 'r') as f:
for line in f:
row = line.strip().split('\t')
if first:
first = False
if len(row) == 10:
is_test = True
print_rank_0(
' reading {}, {} and {} columns and setting '
'labels to {}'.format(
row[0].strip(), row[8].strip(),
row[9].strip(), self.test_label))
else:
                        print_rank_0(' reading {}, {}, {}, and {} columns '
'...'.format(
row[0].strip(), row[8].strip(),
row[9].strip(), row[-1].strip()))
continue
text_a = clean_text(row[8].strip())
text_b = clean_text(row[9].strip())
unique_id = int(row[0].strip())
label = row[-1].strip()
if is_test:
label = self.test_label
assert len(text_a) > 0
assert len(text_b) > 0
assert label in LABELS
assert unique_id >= 0
sample = {'text_a': text_a,
'text_b': text_b,
'label': LABELS[label],
'uid': unique_id}
total += 1
samples.append(sample)
if total % 50000 == 0:
print_rank_0(' > processed {} so far ...'.format(total))
print_rank_0(' >> processed {} samples.'.format(len(samples)))
return samples
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""QQP dataset."""
from megatron import print_rank_0
from tasks.data_utils import clean_text
from .data import GLUEAbstractDataset
LABELS = [0, 1]
class QQPDataset(GLUEAbstractDataset):
def __init__(self, name, datapaths, tokenizer, max_seq_length,
test_label=0):
self.test_label = test_label
super().__init__('QQP', name, datapaths,
tokenizer, max_seq_length)
def process_samples_from_single_path(self, filename):
""""Implement abstract method."""
print_rank_0(' > Processing {} ...'.format(filename))
samples = []
total = 0
first = True
is_test = False
with open(filename, 'r') as f:
for line in f:
row = line.strip().split('\t')
if first:
first = False
if len(row) == 3:
is_test = True
print_rank_0(' reading {}, {}, and {} columns and '
'setting labels to {}'.format(
row[0].strip(), row[1].strip(),
row[2].strip(), self.test_label))
else:
assert len(row) == 6
print_rank_0(' reading {}, {}, {}, and {} columns'
' ...'.format(
row[0].strip(), row[3].strip(),
row[4].strip(), row[5].strip()))
continue
if is_test:
assert len(row) == 3, 'expected length 3: {}'.format(row)
uid = int(row[0].strip())
text_a = clean_text(row[1].strip())
text_b = clean_text(row[2].strip())
label = self.test_label
assert len(text_a) > 0
assert len(text_b) > 0
else:
if len(row) == 6:
uid = int(row[0].strip())
text_a = clean_text(row[3].strip())
text_b = clean_text(row[4].strip())
label = int(row[5].strip())
else:
print_rank_0('***WARNING*** index error, '
'skipping: {}'.format(row))
continue
if len(text_a) == 0:
print_rank_0('***WARNING*** zero length a, '
'skipping: {}'.format(row))
continue
if len(text_b) == 0:
print_rank_0('***WARNING*** zero length b, '
'skipping: {}'.format(row))
continue
assert label in LABELS
assert uid >= 0
sample = {'uid': uid,
'text_a': text_a,
'text_b': text_b,
'label': label}
total += 1
samples.append(sample)
if total % 50000 == 0:
print_rank_0(' > processed {} so far ...'.format(total))
print_rank_0(' >> processed {} samples.'.format(len(samples)))
return samples
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality."""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
from megatron import get_args
from megatron.initialize import initialize_megatron
def get_tasks_args(parser):
"""Provide extra arguments required for tasks."""
group = parser.add_argument_group(title='tasks')
group.add_argument('--task', type=str, required=True,
help='Task name.')
group.add_argument('--epochs', type=int, default=None,
                       help='Number of finetuning epochs. Zero results in '
'evaluation only.')
group.add_argument('--pretrained-checkpoint', type=str, default=None,
                       help='Pretrained checkpoint used for finetuning.')
group.add_argument('--keep-last', action='store_true',
                       help='Keep the last batch (maybe incomplete) in '
                       'the data loader.')
group.add_argument('--train-data', nargs='+', default=None,
help='Whitespace separated paths or corpora names '
'for training.')
group.add_argument('--valid-data', nargs='*', default=None,
help='path(s) to the validation data.')
group.add_argument('--overlapping-eval', type=int, default=32,
help='Sliding window for overlapping evaluation.')
group.add_argument('--strict-lambada', action='store_true',
help='Use more difficult formulation of lambada.')
# Retriever args
group.add_argument('--qa-data-dev', type=str, default=None,
help='Path to the QA dataset dev file.')
group.add_argument('--qa-data-test', type=str, default=None,
help='Path to the QA dataset test file.')
# Faiss arguments for retriever
group.add_argument('--faiss-use-gpu', action='store_true',
                       help='Whether to create the FaissMIPSIndex on GPU.')
    group.add_argument('--faiss-match', type=str, default='string',
                       choices=['regex', 'string'],
                       help='Answer matching logic type.')
group.add_argument('--faiss-topk-retrievals', type=int, default=100,
help='Number of blocks to use as top-k during retrieval')
# finetune for retriever
group.add_argument('--eval-micro-batch-size', type=int, default=None,
help='Eval Batch size per model instance (local batch '
'size). Global batch size is local batch size '
'times data parallel size.')
group.add_argument('--train-with-neg', action='store_true',
help='Whether to use negative examples during model '
'training')
group.add_argument('--train-hard-neg', type=int, default=0,
                       help='Number of hard negative examples to use during '
'training')
# parameters for Av.rank validation method
# Following options/arguments have been taken directly from DPR codebase
group.add_argument('--val-av-rank-hard-neg', type=int, default=30,
help='Av.rank validation: how many hard negatives to'
' take from each question pool')
group.add_argument('--val-av-rank-other-neg', type=int, default=30,
help='Av.rank validation: how many other negatives to'
' take from each question pool')
return parser
if __name__ == '__main__':
initialize_megatron(extra_args_provider=get_tasks_args)
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
exit()
if args.task == 'RACE':
from race.finetune import main
elif args.task in ['MNLI', 'QQP']:
from glue.finetune import main
elif args.task in ['LAMBADA', 'WIKITEXT103']:
from zeroshot_gpt.evaluate import main
elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']:
from orqa.evaluate_orqa import main
elif args.task in ['RET-FINETUNE-NQ']:
from orqa.supervised.finetune import main
else:
raise NotImplementedError('Task {} is not implemented.'.format(
args.task))
main()
# Multi-Stage Prompting for Knowledgeable Dialogue Generation
Below we present the steps to run our multi-stage dialogue prompting (MSDP) framework.
## Multi-Stage Dialogue Prompting
### Data Preparation
1. Dataset Download: [Wizard of Wikipedia](https://parl.ai/projects/wizard_of_wikipedia/) and [Wizard of Internet](https://parl.ai/projects/sea/)
2. Data Processing: We provide the script to run the [`data processing`](../../examples/msdp/data_processing.sh) of the datasets (a sketch of the WoW preprocessing step is shown below).
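
For reference, the following is a minimal sketch of the Wizard of Wikipedia preprocessing step, using `process_wow_dataset` from `tasks/msdp/preprocessing.py` (included later in this commit). The file paths and the import path are assumptions for illustration, not part of the original scripts.

<pre>
# Minimal sketch only: paths and the import path are placeholders.
from tasks.msdp.preprocessing import process_wow_dataset

process_wow_dataset(
    raw_file="data/wow/train.json",                 # raw ParlAI WoW json
    processed_file="data/wow/train_processed.txt",  # topic \t context \t knowledge \t response
    knwl_ref_file="data/wow/train_knwl_ref.txt",    # gold knowledge references
    resp_ref_file="data/wow/train_resp_ref.txt")    # tokenized gold responses
</pre>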
### Stage-1: Prompting for Knowledge Generation
1. We provide the script to perform the [`first-stage prompting`](../../examples/msdp/prompt_knwl_gen.sh) for the knowledge generation (the prompt format is sketched after this list).
2. We provide the [`evaluation script`](../../examples/msdp/eval_knwl_generation.sh) for the automatic evaluation (i.e., F1, BLEU, METEOR, and ROUGE-L) of the knowledge generation.
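
For context, the first-stage (knowledge generation) prompt is assembled as in the minimal sketch below; the example strings are illustrative only, and the construction mirrors `tasks/msdp/prompt.py` included later in this commit.

<pre>
# Illustrative sketch of a first-stage (knowledge generation) prompt.
prompt_examples = [
    "( Do you like jazz? ) Jazz => Jazz is a music genre that originated in New Orleans.",
]
topic = "Jazz"
last_turn = "I mostly listen to jazz at night."

prompt = ""
for instance in prompt_examples:
    prompt += instance.strip() + " \n"
# the model is asked to complete the knowledge after "=>"
prompt += "( " + last_turn + " ) " + topic + " =>"
</pre>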
### Stage-2: Prompting for Response Generation
1. We provide the script to [`prepare the input file`](../../examples/msdp/prep_resp_gen.sh) for the response generation (based on the previously generated knowledge file).
2. We provide the script to perform the [`second-stage prompting`](../../examples/msdp/prompt_resp_gen.sh) for the response generation (the prompt format is sketched after this list).
3. We provide the [`evaluation script`](../../examples/msdp/eval_resp_generation.sh) for the automatic evaluation (i.e., F1, KF1, BLEU, METEOR, and ROUGE-L) of the response generation.
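
For context, the second-stage (response generation) prompt follows the fixed template used in `tasks/msdp/prompt.py`; the minimal sketch below is illustrative (the user turn and knowledge sentence are made up).

<pre>
# Illustrative sketch of a second-stage (response generation) prompt.
prompt = ""  # the fixed prompt examples read from --prompt-file are prepended here
topic = "Jazz"
last_turn = "I mostly listen to jazz at night ."
knowledge = "Jazz is a music genre that originated in New Orleans ."

prompt += "Topic: " + topic + ". "
prompt += "User says: " + last_turn + " "
prompt += "We know that: " + knowledge + " "
prompt += "System replies:"
</pre>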
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model evaluation"""
from megatron import get_args
from megatron import print_rank_0
from tasks.msdp.metrics import F1Metric
from tqdm import tqdm
def evaluate_f1(guess_file, answer_file):
"""Evaluating F1 Score"""
guess_list = []
print_rank_0('reading %s' % guess_file)
with open(guess_file, "r") as f:
for i, line in enumerate(tqdm(f)):
line = line.strip()
if "<|endoftext|>" in line:
line = line.replace("<|endoftext|>", "")
guess_list.append(line)
answer_list = []
print_rank_0('reading %s' % answer_file)
with open(answer_file, "r") as f:
for i, line in enumerate(tqdm(f)):
line = line.strip()
if line == "no_passages_used":
line = ""
answer_list.append(line)
assert len(guess_list) == len(answer_list), \
"lengths of guess and answer are different!"
precision, recall, f1 = F1Metric.compute_all_pairs(guess_list, answer_list)
print_rank_0('Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1))
print_rank_0('done :-)')
def main():
args = get_args()
evaluate_f1(args.guess_file, args.answer_file)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run multi-stage dialogue prompting (MSDP)."""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(
os.path.join(os.path.dirname(__file__), os.path.pardir), os.path.pardir)))
from megatron import get_args
from megatron.initialize import initialize_megatron
def get_tasks_args(parser):
"""Provide extra arguments required for tasks."""
group = parser.add_argument_group(title='tasks')
# parameters for the knowledgeable dialogue generation
group.add_argument('--task', type=str, required=True,
help='Task name.')
group.add_argument("--sample-input-file", type=str, default=None,
help='Get input from file instead of interactive mode, '
'each line is an input.')
group.add_argument("--sample-output-file", type=str, default=None,
help='Output file got from --sample-input-file')
group.add_argument('--prompt-file', type=str, default=None,
help='prompting file')
group.add_argument('--prompt-type', type=str, default=None,
choices=['knowledge', 'response'],
help='prompt type (knowledge or response)')
group.add_argument('--num-prompt-examples', type=int, default=10,
help='number of prompt examples')
group.add_argument('--guess-file', type=str, default=None,
help='datapath for generated sentences')
group.add_argument('--answer-file', type=str, default=None,
help='datapath for golden sentences')
group.add_argument('--out-seq-length', type=int, default=100,
help='output sequence length')
group.add_argument('--api-prompt', default=False, action="store_true",
help='setup model api for prompting')
group.add_argument('--megatron-api-url', type=str, default=None,
help='url of the megatron api')
return parser
if __name__ == '__main__':
initialize_megatron(extra_args_provider=get_tasks_args)
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
exit()
if args.task == 'MSDP-PROMPT':
from tasks.msdp.prompt import main
elif args.task == 'MSDP-EVAL-F1':
from tasks.msdp.evaluate import main
else:
raise NotImplementedError('Task {} is not implemented.'.format(
args.task))
main()
# The following code is adapted from
# https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/metrics.py,
# which is licensed under the MIT license. More details on the license can be
# found at https://github.com/facebookresearch/ParlAI/blob/master/LICENSE.
"""Provides standard metric evaluations for dialog."""
from collections import Counter
from typing import List
import numpy as np
import re
re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
def normalize_answer(s):
"""
Lower text and remove punctuation, articles and extra whitespace.
"""
s = s.lower()
s = re_punc.sub(' ', s)
s = re_art.sub(' ', s)
s = ' '.join(s.split())
return s
class F1Metric:
"""
Helper class which computes token-level F1.
"""
@staticmethod
def _prec_recall_f1_score(pred_items, gold_items):
"""
Compute precision, recall and f1 given a set of gold and prediction items.
:param pred_items: iterable of predicted values
:param gold_items: iterable of gold values
:return: tuple (p, r, f1) for precision, recall, f1
"""
common = Counter(gold_items) & Counter(pred_items)
num_same = sum(common.values())
if num_same == 0:
return 0, 0, 0
precision = 1.0 * num_same / len(pred_items)
recall = 1.0 * num_same / len(gold_items)
f1 = (2 * precision * recall) / (precision + recall)
return precision, recall, f1
@staticmethod
def compute_each_pair(guess: str, answer: str):
if answer == "":
return None, None, None
if guess == "":
return 0, 0, 0
g_tokens = normalize_answer(guess).split()
a_tokens = normalize_answer(answer).split()
precision, recall, f1 = F1Metric._prec_recall_f1_score(g_tokens, a_tokens)
return precision, recall, f1
@staticmethod
def compute_all_pairs(guesses: List[str], answers: List[str]):
        # sanity check: guesses and answers must be aligned one-to-one
assert len(guesses) == len(answers)
precision_list, recall_list, f1_list = [], [], []
for guess, answer in zip(guesses, answers):
precision, recall, f1 = F1Metric.compute_each_pair(guess, answer)
if precision is None or recall is None or f1 is None:
continue
precision_list.append(precision)
recall_list.append(recall)
f1_list.append(f1)
return np.mean(precision_list), np.mean(recall_list), np.mean(f1_list)
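# Illustrative usage of F1Metric (added for clarity; not part of the original module):
# computes token-level precision/recall/F1 averaged over aligned guess/answer pairs.
if __name__ == "__main__":
    toy_guesses = ["jazz originated in new orleans"]
    toy_answers = ["Jazz is a music genre that originated in New Orleans."]
    p, r, f1 = F1Metric.compute_all_pairs(toy_guesses, toy_answers)
    print("Precision: %.4f; recall: %.4f; f1: %.4f" % (p, r, f1))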
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets"""
import torch
import argparse
from nltk import word_tokenize
from tqdm import tqdm
import numpy as np
import json
def get_args():
parser = argparse.ArgumentParser(description="Preprocessing")
parser.add_argument("--func", type=str, default=None,
help="choose to run which function")
parser.add_argument("--raw_file", type=str, default=None,
help="path of the input file")
parser.add_argument("--processed_file", type=str, default=None,
help="path of the output file")
parser.add_argument("--knwl_ref_file", type=str, default=None,
help="path of the knowledge reference file")
parser.add_argument("--resp_ref_file", type=str, default=None,
help="path of the knowledge reference file")
parser.add_argument("--knwl_gen_file", type=str, default=None,
help="path of the generated knowledge file")
parser.add_argument("--test_file", type=str, default=None,
help="path of the test file")
parser.add_argument("--train_file", type=str, default=None,
help="path of the train file")
parser.add_argument("--model_file", type=str, default=None,
help="path of the model file")
parser.add_argument("--data_type", type=str, default=None,
help="data types, choose one out of three types: \
wow_seen, wow_unseen, and woi")
parser.add_argument("--seed", type=int, default=1234,
help="random seed")
args = parser.parse_args()
return args
def process_wow_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
"""
This is a function used for processing the wizard of wikipedia (wow) dataset
Expected processed format:
topic \t dialogue context \t golden knowledge \t golden response
"""
# loading the raw data
print("> Loading data from %s" % raw_file)
with open(raw_file, "r") as fr:
dialog_data = json.load(fr)
print("> Processing data ...")
fproc = open(processed_file, "w")
fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
fresp = open(resp_ref_file, "w") if resp_ref_file else None
for i, sample in enumerate(tqdm(dialog_data)):
# get all the dialog data for a single dialog sample
dialog = sample["dialog"]
turn_list = [] # collect the dialog history
# processing for each single dialog sample
for j, turn in enumerate(dialog):
# text of each turn
text = turn["text"]
if not (text.endswith("?") or text.endswith(".") or text.endswith("!")):
text = text + "."
if j == 0:
# first turn
turn_list.append(text)
continue
speaker = turn["speaker"].lower()
if "wizard" in speaker:
checked_sentence = list(turn["checked_sentence"].values()) # knowledge
checked_passage = list(turn["checked_passage"].values()) # topic
assert len(checked_sentence) <= 1
# get the ground truth knowledge
if len(checked_sentence) > 0:
checked_sentence = checked_sentence[0]
else:
checked_sentence = "no_passages_used"
if len(checked_passage) == 1:
checked_passage = checked_passage[0]
else:
checked_passage = "no_passages_used"
# get the topic
if checked_passage != "no_passages_used":
topic = checked_passage
else:
topic = sample["chosen_topic"]
dialog_context = " [SEP] ".join(turn_list)
knowledge = checked_sentence
response = text
# add the response into the dialog history
turn_list.append(response)
# write to the output files
fproc.write(topic + "\t" + dialog_context + "\t" + \
knowledge + "\t" + response + "\n")
if fknwl:
fknwl.write(knowledge + "\n")
if fresp:
# tokenize for evaluation
response = " ".join(word_tokenize(response))
fresp.write(response + "\n")
else:
assert "apprentice" in speaker
turn_list.append(text)
fproc.close()
if fknwl:
fknwl.close()
if fresp:
fresp.close()
def process_woi_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
"""
This is a function used for processing the wizard of internet (woi) dataset
Expected processed format:
topic \t dialogue context \t golden knowledge \t golden response
"""
print("> Processing %s" % raw_file)
fproc = open(processed_file, "w")
fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
fresp = open(resp_ref_file, "w") if resp_ref_file else None
with open(raw_file, "r") as fr:
for i, line in tqdm(enumerate(fr)):
# read line by line, each line uses json format
line = line.strip()
item_dict = json.loads(line)
# item_dict is a dictionary
# its key is the data id, and its value contains all the data content
item_dict = item_dict.values()
item_dict = list(item_dict)[0] # len(item_dict) == 1
# get the whole dialog data for a single dialog sample
dialog_data = item_dict['dialog_history']
length = len(dialog_data)
turn_list = [] # collect the dialog history
search_text = ""
for i in range(length):
item = dialog_data[i]
action = item['action']
if action == "Wizard => SearchAgent":
search_text = item['text']
elif action == "Wizard => Apprentice":
if len(turn_list) == 0:
# first turn
turn = item['text']
turn_list.append(turn)
continue
# get the relevant content
contents = item["context"]["contents"]
selects = item["context"]["selected_contents"]
flag = selects[0][0]
selects = selects[1:]
assert len(selects) == len(contents)
# get the topic
if flag:
# no knowledge sentence is used for the response
topic = "no_topic"
knwl_sent = "no_passages_used"
else:
# we consider the search text as the topic
topic = search_text
# get the knowledge sentence
knwl_sent = ""
for content, select in zip(contents, selects):
content = content['content']
assert len(content) == len(select)
for c, s in zip(content, select):
if s:
knwl_sent = c
break
if knwl_sent == "":
# no knowledge is used for the response
topic = "no_topic"
knwl_sent = "no_passages_used"
# get dialogue context, knowledge, and response
dialog_context = " [SEP] ".join(turn_list)
response = item['text']
# processing
topic = topic.replace("\n", "").replace("\r", \
"").replace("\t", "")
dialog_context = dialog_context.replace("\n", "").replace("\r", \
"").replace("\t", "")
knwl_sent = knwl_sent.replace("\n", "").replace("\r", \
"").replace("\t", "")
response = response.replace("\n", "").replace("\r", \
"").replace("\t", "")
if topic != "no_topic":
                        # write to the output files
fproc.write(topic + "\t" + dialog_context + "\t" + \
knwl_sent + "\t" + response + "\n")
if fknwl:
fknwl.write(knwl_sent + "\n")
if fresp:
# tokenize for evaluation
response = " ".join(word_tokenize(response))
fresp.write(response + "\n")
turn_list.append(response)
elif action == "Apprentice => Wizard":
turn = item['text']
turn_list.append(turn)
else:
assert action == "SearchAgent => Wizard", \
"Please check whether you have used the correct data!"
fproc.close()
if fknwl:
fknwl.close()
if fresp:
fresp.close()
def get_database(test_datapath, train_datapath, data_type):
"""Get the database by topics"""
assert data_type in ["wow_seen", "wow_unseen", "woi"], \
"Please input a correct data type!!"
# get test data topic dictionary
print("> reading test data from %s" % test_datapath)
test_topics = {}
with open(test_datapath, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
test_topics[topic] = True
print("> reading data from %s" % train_datapath)
train_data_by_topic = {}
dialog_data_by_topic = {}
dialog_examples = []
with open(train_datapath, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
turns = splits[1].split(" [SEP] ")[-3:]
knowledge = splits[2]
response = splits[3]
# filtering data samples
if knowledge == "no_passages_used":
# when no knowledge is used
continue
if data_type != "wow_seen" and ("(" in knowledge or ")" in knowledge):
# when bracket exists in the knowledge
continue
if data_type != "wow_seen" and topic not in knowledge:
# when topic does not exist in the knowledge
continue
# get the instance
last_turn = turns[-1]
instance = "( " + last_turn + " ) " + topic + " => " + knowledge
# construct dialog example
dialog_example = ""
if data_type != "wow_seen":
dialog_example += "( " + topic + " ) "
for i, turn in enumerate(turns):
if i != 0:
dialog_example += " "
dialog_example += turn
# check overlaps
if topic in test_topics:
if topic not in train_data_by_topic:
train_data_by_topic[topic] = [instance]
else:
train_data_by_topic[topic].append(instance)
if topic not in dialog_data_by_topic:
dialog_data_by_topic[topic] = [dialog_example]
else:
dialog_data_by_topic[topic].append(dialog_example)
else:
# filtering data samples
if len(knowledge.split()) > 20:
# knowledge is too long
continue
if knowledge.startswith("It") or knowledge.startswith("it") or \
knowledge.startswith("This") or knowledge.startswith("this"):
continue
# append all the data into dialogue examples list
dialog_examples.append((topic, dialog_example, instance))
return train_data_by_topic, dialog_data_by_topic, dialog_examples
emb_dict = {}
def select_prompts_based_on_similarity(
query, dialog_list, prompt_list, topic, tokenizer, encoder, topk):
"""Select samples based on the similarity"""
with torch.no_grad():
# get the query embeddings
query_ids = tokenizer.encode(query)
query_ids = torch.LongTensor([query_ids]).cuda()
query_emb = encoder(input_ids=query_ids).pooler_output
query_emb = query_emb[0]
# calculate embeddings for the samples in the database
if topic in emb_dict:
example_embeddings = emb_dict[topic]
example_embeddings = example_embeddings.cuda()
else:
for idx, example in enumerate(dialog_list):
example_ids = tokenizer.encode(example)
example_ids = torch.LongTensor([example_ids]).cuda()
example_emb = encoder(input_ids=example_ids).pooler_output
if idx == 0:
example_embeddings = example_emb
else:
example_embeddings = torch.cat(
(example_embeddings, example_emb), dim=0)
emb_dict[topic] = example_embeddings.cpu()
# compare the similarity and select the topk samples
similarity_list = example_embeddings.matmul(query_emb)
_, indices = torch.topk(similarity_list, k=topk)
indices = indices.tolist()
indices = indices[::-1] # reverse the order
selected_prompts = []
for index in indices:
# index = index.item()
selected_prompts.append(prompt_list[index])
return selected_prompts
def prompt_selection_for_knowledge_generation(
test_datapath, train_datapath, model_path, output_prompt_path, data_type):
"""Selecting prompts for the knowledge generation"""
print("> Selecting prompts for the knowledge generation")
train_data_by_topic, dialog_data_by_topic, dialog_examples = \
get_database(test_datapath, train_datapath, data_type)
from transformers import DPRQuestionEncoderTokenizer
print("> loading tokenizer and encoder")
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
'facebook/dpr-question_encoder-single-nq-base')
encoder = torch.load(model_path).cuda()
print("> getting dialog embeddings")
with torch.no_grad():
for idx, example in tqdm(enumerate(dialog_examples)):
dialog = example[1]
dialog_ids = tokenizer.encode(dialog)
dialog_ids = torch.LongTensor([dialog_ids]).cuda()
dialog_emb = encoder(input_ids=dialog_ids).pooler_output
if idx == 0:
dialog_embeddings = dialog_emb
else:
dialog_embeddings = torch.cat((dialog_embeddings, dialog_emb), dim=0)
print("> reading test data from %s" % test_datapath)
prompt_list_for_each_sample = []
with open(test_datapath, "r") as f:
for i, line in tqdm(enumerate(f)):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
turns = splits[1].split(" [SEP] ")[-3:]
# get the query sentence
query_sent = ""
if data_type != "seen":
query_sent += "( " + topic + " ) "
for i, turn in enumerate(turns):
if i != 0:
query_sent += " "
query_sent += turn
if topic not in train_data_by_topic:
# get the query embedding
query_ids = tokenizer.encode(query_sent)
query_ids = torch.LongTensor([query_ids]).cuda()
query_emb = encoder(input_ids=query_ids).pooler_output
query_emb = query_emb[0]
# calculate the similarity
similarity_list = dialog_embeddings.matmul(query_emb)
_, indices = torch.sort(similarity_list)
indices = indices.tolist()
selected_topics = {}
selected_prompts = []
num_prompt = 0
for index in indices:
example = dialog_examples[index]
topic_temp = example[0]
if topic_temp not in selected_topics:
selected_topics[topic_temp] = True
selected_prompts.append(example[2])
num_prompt += 1
if num_prompt == 10:
break
# get the selected samples
example_list = selected_prompts[::-1]
key = topic + " " + turns[-1]
prompt_list_for_each_sample.append({key: example_list})
else:
num_data_sample = min(len(train_data_by_topic[topic]), 10)
total_example_list = train_data_by_topic[topic]
dialog_list = dialog_data_by_topic[topic]
assert len(dialog_list) == len(train_data_by_topic[topic])
# calculate the similarity
example_list = select_prompts_based_on_similarity(
query_sent, dialog_list, total_example_list,
topic, tokenizer, encoder, topk=num_data_sample)
key = topic + " " + turns[-1]
prompt_list_for_each_sample.append({key: example_list})
print("writing to %s" % output_prompt_path)
with open(output_prompt_path, "w") as f:
for instance in tqdm(prompt_list_for_each_sample):
json.dump(instance, f)
f.write("\n")
def prompt_selection_for_response_generation(input_path, output_path, seed):
"""Selecting prompts for the response generation"""
print("> Selecting prompts for the response generation")
print("> set random seed")
np.random.seed(seed)
prompt_example_list = []
print("> reading data from %s" % input_path)
with open(input_path, "r") as f:
for i, line in tqdm(enumerate(f)):
line = line.strip()
splits = line.split("\t")
# get the topic, context, knowledge and response
topic = splits[0]
dialog_context = splits[1]
knowledge = splits[2]
response = splits[3]
turns = dialog_context.split(" [SEP] ")[-3:]
if knowledge == "no_passages_used":
continue
# calculate the overlap ratio
from nltk import word_tokenize
knowledge_sent_token_list = word_tokenize(knowledge)
knowledge_sent_token_dict = {token: True for token in knowledge_sent_token_list}
knowledge_len = len(knowledge_sent_token_list)
response_token_list = word_tokenize(response)
response_len = len(response_token_list)
num_overlap_token = 0
accumulator = 0
for token in response_token_list:
if token in knowledge_sent_token_dict:
accumulator += 1
else:
if accumulator >= 10:
num_overlap_token += accumulator
accumulator = 0
if accumulator >= 10:
num_overlap_token += accumulator
# filtering the data based on the ratio
if num_overlap_token > response_len * 0.9 or num_overlap_token < response_len * 0.6:
continue
if num_overlap_token < knowledge_len * 0.8:
continue
last_turn = " ".join(word_tokenize(turns[-1]))
knowledge = " ".join(word_tokenize(knowledge))
response = " ".join(word_tokenize(response))
prompt_example = ""
# add dialog context
prompt_example += "Topic: " + topic + ". "
prompt_example += "User says: " + last_turn + " "
prompt_example += "We know that: " + knowledge + " "
prompt_example += "System replies: " + response
prompt_example_list.append(prompt_example)
# shuffle the prompt examples
np.random.shuffle(prompt_example_list)
print("> writing to %s" % output_path)
with open(output_path, "w") as f:
# f.write("Generate the System's response based on the knowledge sentence:\n")
for i in tqdm(range(20)):
example = prompt_example_list[i]
f.write(example + "\n")
def prepare_input_for_response_generation(test_file, knwl_gen_file, processed_file):
"""Preparing inputs for the response generation"""
print("> Reading knowledge file from %s" % knwl_gen_file)
# get the knowledge list
with open(knwl_gen_file, "r") as f:
knowledge_list = f.readlines()
print("> Processing ...")
with open(test_file, "r") as fr:
with open(processed_file, "w") as fw:
for line_num, line in enumerate(tqdm(fr)):
line = line.strip()
splits = line.split("\t")
# prepare topic, context, knowledge and response
topic = splits[0]
dialog_context = splits[1]
response = splits[3]
knowledge = knowledge_list[line_num]
knowledge = knowledge.strip()
if "<|endoftext|>" in knowledge:
knowledge = knowledge.replace("<|endoftext|>", "")
# write to the output file
fw.write(topic + "\t" + dialog_context + "\t" \
+ knowledge + "\t" + response + "\n")
if __name__ == "__main__":
args = get_args()
if args.func == "process_wow_dataset":
process_wow_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)
elif args.func == "process_woi_dataset":
process_woi_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)
elif args.func == "get_knwl_gen_prompts":
prompt_selection_for_knowledge_generation(
args.test_file, args.train_file, args.model_file,
args.processed_file, args.data_type)
elif args.func == "get_resp_gen_prompts":
prompt_selection_for_response_generation(
args.train_file, args.processed_file, args.seed)
elif args.func == "prepare_input":
prepare_input_for_response_generation(
args.test_file, args.knwl_gen_file, args.processed_file)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prompting the pretrained language model to generate knowledge/response"""
import json
import torch
import requests
from nltk import word_tokenize
from megatron import mpu
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.text_generation import generate_and_post_process
def call_model_api(inputs, tokens_to_generate):
"""Calling the model api to get the output generations"""
args = get_args()
# The following is an example of using the Megatron API
# You can also implement your own API function to place this part
headers = {'Content-Type': 'application/json; charset=UTF-8'}
data = {"prompts": [inputs], "tokens_to_generate": tokens_to_generate, "top_k": 1}
data_json = json.dumps(data)
outputs = requests.put(args.megatron_api_url, headers=headers, data=data_json).json()["text"][0]
input_len = len(inputs)
outputs = outputs[input_len:]
outputs = outputs.split("\n")[0].strip()
return outputs
def read_prompts(prompt_path, prompt_type, n_example):
"""Read prompt data"""
if prompt_type == "knowledge":
# prompts for the knowledge generation
prompt_examples_dict = {}
# read prompt_path
with open(prompt_path, "r") as f:
for i, line in enumerate(f):
line = line.strip()
line_dict = json.loads(line)
key = list(line_dict.keys())[0]
if key not in prompt_examples_dict:
prompt_examples = line_dict[key]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
prompt_examples_dict[key] = prompt
return prompt_examples_dict
else:
# prompts for the response generation
# read prompt_path
prompt = ""
with open(prompt_path, "r") as f:
prompt_examples = f.readlines()
prompt_examples = prompt_examples[:n_example]
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
return prompt
def generate_samples_by_calling_api():
""" Generate outputs by calling"""
args = get_args()
assert args.prompt_type in ["knowledge", "response"], \
"Please input a correct prompt type!"
if args.prompt_type == "knowledge":
# read knowledge generation prompts
knwl_gen_prompt_dict = read_prompts(
args.prompt_file, args.prompt_type, args.num_prompt_examples)
else:
resp_gen_prompt = read_prompts(
args.prompt_file, args.prompt_type, args.num_prompt_examples)
# read the test data
fname = open(args.sample_input_file, "r")
test_sample_list = fname.readlines()
# create output file
fname_out = open(args.sample_output_file, "w")
# call the api to get the output generations
for test_sample in test_sample_list:
test_sample = test_sample.strip()
splits = test_sample.split("\t")
topic = splits[0]
# prepare the inputs for the api
if args.prompt_type == "knowledge":
## inputs = prompt + current test
# get the prompt
turns = splits[1].split(" [SEP] ")
last_turn = turns[-1]
key = topic + " " + last_turn
inputs = knwl_gen_prompt_dict[key]
# add current test
inputs += "( " + last_turn + " ) " + topic + " =>"
else:
# inputs = prompt + current test
# get the prompt
inputs = resp_gen_prompt
# add current test
turns = splits[1].split(" [SEP] ")
knowledge = splits[2]
last_turn = turns[-1]
last_turn = " ".join(word_tokenize(last_turn))
knowledge = " ".join(word_tokenize(knowledge))
knowledge = knowledge.strip()
last_turn = last_turn.strip()
inputs += "Topic: " + topic + ". "
inputs += "User says: " + last_turn + " "
inputs += "We know that: " + knowledge + " "
inputs += "System replies:"
# get the output generations from the api,
# and write to the output file
generations = call_model_api(inputs, args.out_seq_length)
fname_out.write(generations)
fname_out.write("\n")
fname.close()
fname_out.close()
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building GPT model ...')
model = GPTModel(
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process
)
return model
def generate_samples_by_prompting_input_from_file(model):
"""Prompt a pretrained language model to generate knowledge/response"""
# get tokenizer
args = get_args()
tokenizer = get_tokenizer()
# Read the sample file and open the output file.
assert args.sample_input_file is not None, \
'sample input file is not provided.'
if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines()
input_count = len(all_raw_text)
if args.sample_output_file is None:
sample_output_file = args.sample_input_file + ".out"
print('`sample-output-file` not specified, setting '
'it to {}'.format(sample_output_file))
else:
sample_output_file = args.sample_output_file
fname_out = open(sample_output_file, "w")
# only two prompt types (i.e., knowledge and response) are allowed
assert args.prompt_type in ["knowledge", "response"], \
"Please input a correct prompt type!"
# Read the prompt file
if args.prompt_type == "knowledge":
# read the prompts for the knowledge generation
prompt_examples_dict = {}
with open(args.prompt_file, "r") as f:
for i, line in enumerate(f):
line = line.strip()
line_dict = json.loads(line)
key = list(line_dict.keys())[0]
# get the prompt examples based on the key
if key not in prompt_examples_dict:
prompt_examples = line_dict[key]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
prompt_examples_dict[key] = prompt
else:
# read the prompts for the response generation
# prompts are fixed for all test samples
with open(args.prompt_file, "r") as f:
prompt_examples = f.readlines()
prompt_examples = prompt_examples[:args.num_prompt_examples]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
input_pos = 0
model.eval()
# perform prompting
with torch.no_grad():
while True:
raw_text_len = 0
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
input_str = all_raw_text[input_pos]
input_str = input_str.strip()
splits = input_str.split("\t")
topic = splits[0]
if args.prompt_type == "knowledge":
# first add the prompt into the raw_text
turns = splits[1].split(" [SEP] ")
last_turn = turns[-1]
key = topic + " " + last_turn
raw_text = prompt_examples_dict[key]
# construct inputs for knowledge generation
# then add the constructed inputs into the raw_text
raw_text += "( " + last_turn + " ) " + topic + " =>"
else:
# first add the prompt into the raw_text
raw_text = prompt
# construct inputs for response generation
# then add the constructed inputs into the raw_text
turns = splits[1].split(" [SEP] ")
knowledge = splits[2]
last_turn = turns[-1]
last_turn = " ".join(word_tokenize(last_turn))
knowledge = " ".join(word_tokenize(knowledge))
knowledge = knowledge.strip()
last_turn = last_turn.strip()
raw_text += "Topic: " + topic + ". "
raw_text += "User says: " + last_turn + " "
raw_text += "We know that: " + knowledge + " "
raw_text += "System replies:"
input_pos += 1
raw_text_len = len(raw_text)
else:
raw_text = "EMPTY TEXT"
if input_pos % 100 == 0:
print_rank_0("input_pos: %d" % input_pos)
outputs = generate_and_post_process(
model=model,
prompts=[raw_text],
tokens_to_generate=args.out_seq_length,
top_k_sampling=1)
prompts_plus_generations = outputs[0]
prompts_plus_generations = prompts_plus_generations[0]
# write the generated output to the output file
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage():
generations = prompts_plus_generations[raw_text_len:]
generations = generations.split("\n")[0]
generations = generations.strip()
fname_out.write(generations)
fname_out.write("\n")
raw_text = None
if input_pos == input_count:
return
def main():
args = get_args()
if args.api_prompt:
# obtain the generations by calling the api
generate_samples_by_calling_api()
return
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for text generation.")
exit()
# Set up model and load checkpoint.
model = get_model(model_provider, wrap_with_ddp=False)
if args.load is not None:
_ = load_checkpoint(model, None, None)
assert len(model) == 1, "Above condition should have caught this"
model = model[0]
# perform the prompting
generate_samples_by_prompting_input_from_file(model)
## End-to-End Training of Neural Retrievers for Open-Domain Question Answering
Below we present the steps to run unsupervised and supervised training and evaluation of the retriever for [open domain question answering](https://arxiv.org/abs/2101.00408).
## Retriever Training
#### Unsupervised pretraining
1. Use `tools/preprocess_data.py` to preprocess the dataset for Inverse Cloze Task (ICT), which we call unsupervised pretraining. This script takes as input a corpus in loose JSON format and creates fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block and multiple blocks per document. Run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document and another with the body.
<pre>
python tools/preprocess_data.py \
--input /path/to/corpus.json \
--json-keys text title \
--split-sentences \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file /path/to/vocab.txt \
--output-prefix corpus_indexed \
--workers 10
</pre>
2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs a single-GPU 217M-parameter biencoder model for ICT retriever training. Single-GPU training is primarily intended for debugging, as the code is developed for distributed training. The script uses a pretrained BERT model and a total batch size of 4096 for ICT training (a toy sketch of the ICT pairing idea follows this list).
3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for [Google's Natural Questions Open dataset](https://arxiv.org/pdf/1906.00300.pdf).
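For intuition only, here is a minimal, hypothetical sketch of how an ICT training pair can be formed from an indexed block; the names (`make_ict_pair`, `block_sentences`, `remove_query_prob`) are illustrative and are not the actual dataset code. One sentence is drawn from the block as the query, and the rest of the block serves as its context, with the query sentence usually removed so the retriever cannot rely on exact string overlap.
<pre>
import random

def make_ict_pair(block_sentences, remove_query_prob=0.9):
    """Toy Inverse Cloze Task pair: (query sentence, context block)."""
    idx = random.randrange(len(block_sentences))
    query = block_sentences[idx]
    # Usually drop the query sentence from its block so the model has to
    # rely on semantics rather than exact overlap.
    if random.random() < remove_query_prob:
        context = block_sentences[:idx] + block_sentences[idx + 1:]
    else:
        context = list(block_sentences)
    return query, " ".join(context)
</pre>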
#### Supervised finetuning
1. Use the above pretrained ICT model for finetuning on [Google's Natural Questions Open dataset](https://github.com/google-research/language/tree/master/language/orqa). The script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) provides an example of how to perform the training. Our finetuning process adds retriever score scaling and longer training (80 epochs) on top of [DPR training](https://arxiv.org/abs/2004.04906) (a minimal sketch of the scaled in-batch loss follows this list).
2. Evaluate the finetuned model using the same evaluation script as mentioned above for the unsupervised model.
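The score scaling mentioned above simply tempers the query-context dot products before the softmax. Below is a minimal sketch, assuming a DPR-style in-batch-negatives objective; it is not the actual training code, and the function name is hypothetical.
<pre>
import math
import torch
import torch.nn.functional as F

def in_batch_retrieval_loss(query_embs, context_embs, hidden_size,
                            score_scaling=True):
    # Dot-product similarity between every query and every context in the batch.
    scores = torch.matmul(query_embs, context_embs.t())
    if score_scaling:
        # Retriever score scaling: divide the logits by sqrt(hidden_size).
        scores = scores / math.sqrt(hidden_size)
    # With in-batch negatives, the positive context for query i is at index i.
    labels = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, labels)
</pre>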
More details on the retriever are available in [our paper](https://arxiv.org/abs/2101.00408).
## Reader Training
The reader component will be available soon.
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality."""
from megatron import get_args, print_rank_0
from megatron.indexer import IndexBuilder
from tasks.orqa.evaluate_utils import ORQAEvaluator
def main():
"""
Main program
"""
args = get_args()
"""
Create a BlockData data structure by running an IndexBuilder over an
ICT Dataset and then evaluate on NQ task
"""
print_rank_0("Starting index builder!")
index_builder = IndexBuilder()
index_builder.build_and_save_index()
print_rank_0("Build and save indices: done!")
print_rank_0("Starting evaluations!")
# Set up the model and evaluator
evaluator = ORQAEvaluator()
# Run evaluation
if args.qa_data_dev is not None:
evaluator.evaluate(args.qa_data_dev, "DEV")
if args.qa_data_test is not None:
evaluator.evaluate(args.qa_data_test, "TEST")
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from megatron import get_args, print_rank_0
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model
from tasks.orqa.unsupervised.nq import get_nq_dataset
from tasks.orqa.unsupervised.nq import get_one_epoch_nq_dataloader
from tasks.orqa.unsupervised.nq import process_nq_batch
from tasks.orqa.unsupervised.qa_utils import calculate_matches
class ORQAEvaluator(object):
def __init__(self):
args = get_args()
self.embedding_size = args.hidden_size
self.faiss_use_gpu = args.faiss_use_gpu
self.evidence_embedder_obj = None
self.evidence_dataset = None
self.mips_index = None
self.eval_dataset = None
# Get Evidence (Wikipedia) dataset
self.get_evidence_dataset()
# Load query encoder checkpoint
only_query_model = True
if args.biencoder_shared_query_context_model:
only_query_model = False
model = get_model(get_model_provider(only_query_model=only_query_model,
biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))
self.model = load_biencoder_checkpoint(model,
only_query_model=only_query_model)
assert len(self.model) == 1
self.model[0].eval()
# Load faiss indexer
self.faiss_wrapper()
def get_evidence_embedding(self):
# This will load the embedding from the embedding path
self.evidence_embedder_obj = OpenRetreivalDataStore(load_from_path=True)
def get_evidence_dataset(self):
self.evidence_dataset = get_open_retrieval_wiki_dataset()
def faiss_wrapper(self):
        # Initialize the FAISS wrapper on local rank 0, as the evidence
        # embeddings are distributed over all the GPUs in a node and FAISS
        # is not thread-safe.
args = get_args()
if args.local_rank == 0:
# Get evidence embeddings computed using context encoder
self.get_evidence_embedding()
assert self.evidence_embedder_obj is not None
self.mips_index = FaissMIPSIndex(embed_size=self.embedding_size,
embed_data=self.evidence_embedder_obj,
use_gpu=self.faiss_use_gpu)
# Wait for the FAISS index to be initialized in all the nodes
torch.distributed.barrier()
def generate_query_vectors(self, qa_data, split):
self.eval_dataset = get_nq_dataset(qa_data, split)
dataloader = get_one_epoch_nq_dataloader(self.eval_dataset)
query_vectors = []
reference_list = []
for batch in dataloader:
# batch also has query_tokens and query_pad_data
query_tokens, query_mask, query_types, \
query_len, reference = process_nq_batch(batch)
assert len(self.model) == 1
unwrapped_model = self.model[0]
while not hasattr(unwrapped_model, 'embed_text'):
unwrapped_model = unwrapped_model.module
with torch.no_grad():
query_logits = unwrapped_model.embed_text(
unwrapped_model.query_model, query_tokens,
query_mask, query_types)
reference_list.extend(reference)
query_vectors.extend(query_logits.split(1, dim=0))
if len(query_vectors) % 100 == 0:
print_rank_0('Encoded queries {}'.format(len(query_vectors)))
query_tensor = torch.cat(query_vectors, dim=0)
print_rank_0('Total encoded queries tensor {}'.format(query_tensor.size()))
assert query_tensor.size(0) == len(self.eval_dataset)
return query_tensor, reference_list
def evaluate(self, qa_data, split):
args = get_args()
query_tensor, reference_list = self.generate_query_vectors(qa_data, \
split)
local_rank = args.local_rank
rank = torch.distributed.get_rank()
device_count = torch.cuda.device_count()
num_nodes = torch.distributed.get_world_size() // device_count
node_id = rank // device_count
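        # Gather the query vectors from all GPUs within each node so that
        # local rank 0 (which holds the FAISS index) can search once for the
        # whole node; the results are then broadcast back to every rank.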
for node in range(num_nodes):
start_rank = node * device_count
end_rank = (node + 1) * device_count
ranks_list = list(range(start_rank, end_rank))
node_group = torch.distributed.new_group(ranks=ranks_list)
if node_id == node:
device_start_rank = start_rank
group = node_group
input_ = torch.empty_like(query_tensor).copy_(query_tensor).detach_()
tensor_list = [torch.empty_like(input_) for _ in range(device_count)]
torch.distributed.all_gather(tensor_list, query_tensor, group=group)
if local_rank == 0 and self.mips_index is not None:
all_query_tensor = torch.cat(tensor_list, dim=0).contiguous()
distance, topkindex = self.mips_index.search_mips_index(
all_query_tensor, top_k=args.faiss_topk_retrievals,
reconstruct=False)
distance = torch.from_numpy(distance).cuda()
topkindex = torch.LongTensor(topkindex).cuda()
if local_rank != 0:
distance = torch.empty(device_count * len(query_tensor), \
args.faiss_topk_retrievals, dtype=torch.float32).cuda()
topkindex = torch.empty(device_count * len(query_tensor), \
args.faiss_topk_retrievals, dtype=torch.int64).cuda()
torch.distributed.broadcast(distance, src=device_start_rank, \
group=group)
torch.distributed.broadcast(topkindex, src=device_start_rank, \
group=group)
distance = torch.split(distance, len(query_tensor), dim=0)\
[local_rank]
topkindex = torch.split(topkindex, len(query_tensor), dim=0)\
[local_rank]
top_ids_and_scores = []
for darray, topkarray in zip(distance, topkindex):
top_ids_and_scores.append((topkarray.tolist(), darray.tolist()))
passages = self.evidence_dataset.id2text
match_stats = calculate_matches(passages,
reference_list,
top_ids_and_scores,
workers_num=args.num_workers,
match_type=args.faiss_match)
top_k_hits = match_stats.top_k_hits
print_rank_0("{} SET RESULTS".format(split))
print_rank_0("topk-{} documents hits {}".format(
args.faiss_topk_retrievals, top_k_hits))
top_k_hits = [v / len(top_ids_and_scores) for v in top_k_hits]
print_rank_0("top-k documents hits accuracy {}".format(top_k_hits))
for i in args.retriever_report_topk_accuracies:
print_rank_0("top-{}: {:.2f}".format(i, top_k_hits[i-1] * 100))
return
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ORQA dataset."""
import json
import random
from abc import ABC
from abc import abstractmethod
import numpy as np
from torch.utils.data import Dataset
from megatron import print_rank_0, get_args
from megatron.data.biencoder_dataset_utils import make_attention_mask
def build_token_types_from_context_list(ctx_list, tokenizer, max_seq_length):
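    """Tokenize each context (title + [SEP] + text) and return padded
    token id and token type lists for the biencoder."""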
ctx_id_list, ctx_types_list = [], []
for context in ctx_list:
title_ids = tokenizer.tokenize(context['title'])
ctx_ids = tokenizer.tokenize(context['text'])
ctx_ids = title_ids + [tokenizer.sep_id] + ctx_ids
ctx_ids, ctx_types, _ = build_tokens_types_paddings_from_ids(ctx_ids,
max_seq_length, tokenizer.cls,
tokenizer.sep, tokenizer.pad)
ctx_id_list.append(ctx_ids)
ctx_types_list.append(ctx_types)
return ctx_id_list, ctx_types_list
def build_tokens_types_paddings_from_text(query, context,
tokenizer, max_seq_length):
"""Build token types and paddings, trim if needed, and pad if needed."""
query_ids = tokenizer.tokenize(query)
query_ids, query_types, query_pad_mask = \
build_tokens_types_paddings_from_ids(query_ids, max_seq_length, \
tokenizer.cls, tokenizer.sep, tokenizer.pad)
    # Prepend the context's title (followed by [SEP]) to its text
extended_ctx_ids = None
if context is not None:
title_ids = tokenizer.tokenize(context['title'])
ctx_ids = tokenizer.tokenize(context['text'])
extended_ctx_ids = title_ids + [tokenizer.sep] + ctx_ids
ctx_ids, ctx_types, ctx_pad_mask = \
build_tokens_types_paddings_from_ids(extended_ctx_ids,
max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)
return query_ids, query_types, query_pad_mask, \
ctx_ids, ctx_types, ctx_pad_mask
# Similar to the code in tasks/data_utils.py, with some changes
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
cls_id, sep_id, pad_id):
"""Build token types and paddings, trim if needed, and pad if needed."""
enc_ids = []
tokentypes_enc = []
# [CLS].
enc_ids.append(cls_id)
tokentypes_enc.append(0)
# A.
len_src = len(text_ids)
enc_ids.extend(text_ids)
tokentypes_enc.extend([0] * len_src)
# Cap the size.
if len(enc_ids) > max_seq_length - 1:
enc_ids = enc_ids[0: max_seq_length - 1]
tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]
# [SEP].
enc_ids.append(sep_id)
tokentypes_enc.append(0)
num_tokens_enc = len(enc_ids)
# Padding.
padding_length = max_seq_length - len(enc_ids)
if padding_length > 0:
enc_ids.extend([pad_id] * padding_length)
tokentypes_enc.extend([pad_id] * padding_length)
pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
pad_mask = np.array(pad_mask, dtype=np.int64)
return enc_ids, tokentypes_enc, pad_mask
def build_sample(query_ids, query_types, query_pad_mask,
ctx_ids, ctx_types, ctx_pad_mask, answers,
neg_ctx_id_list=None, neg_ctx_types_list=None,
include_neg=False):
"""Convert to numpy and return a sample consumed by the batch producer."""
query_ids = np.array(query_ids, dtype=np.int64)
query_types = np.array(query_types, dtype=np.int64)
query_mask = make_attention_mask(query_ids, query_ids)
ctx_ids = np.array(ctx_ids, dtype=np.int64)
ctx_types = np.array(ctx_types, dtype=np.int64)
ctx_mask = make_attention_mask(ctx_ids, ctx_ids)
sample = ({
'query': query_ids,
'query_mask': query_mask,
'query_types': query_types,
'query_pad_mask': query_pad_mask,
'context': ctx_ids,
'context_mask': ctx_mask,
'context_types': ctx_types,
'context_pad_mask': ctx_pad_mask,
'reference': answers
})
if include_neg:
neg_ctx_ids = np.array(neg_ctx_id_list, dtype=np.int64)
neg_ctx_id_types = np.array(neg_ctx_types_list, dtype=np.int64)
neg_ctx_mask = np.array([make_attention_mask(ids, ids) \
for ids in neg_ctx_ids], dtype=np.int64)
sample['neg_context'] = neg_ctx_ids
sample['neg_context_types'] = neg_ctx_id_types
sample['neg_context_mask'] = neg_ctx_mask
return sample
class OpenRetrievalAbstractDataset(ABC, Dataset):
"""Open Retrieval base dataset class."""
def __init__(self, task_name, dataset_name, datapaths, tokenizer, \
max_seq_length, evaluate=False):
# Store inputs.
args = get_args()
self.evaluate = evaluate
self.val_av_rank_hard_neg = args.val_av_rank_hard_neg
self.val_av_rank_other_neg = args.val_av_rank_other_neg
self.train_with_neg = args.train_with_neg
self.train_hard_neg = args.train_hard_neg
self.task_name = task_name
self.dataset_name = dataset_name
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
self.dataset_name))
# Process the files.
string = ' > paths:'
for path in datapaths:
string += ' ' + path
print_rank_0(string)
self.samples = []
for datapath in datapaths:
self.samples.extend(self.process_samples_from_single_path(datapath))
args = get_args()
if args.sample_rate < 1: # subsample
k = int(len(self.samples) * args.sample_rate)
self.samples = random.sample(self.samples, k)
print_rank_0(' >> total number of samples: {}'.format(
len(self.samples)))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
raw_sample = self.samples[idx]
query_ids, query_types, query_pad_mask, ctx_ids, ctx_types, \
ctx_pad_mask = build_tokens_types_paddings_from_text( \
raw_sample['question'], raw_sample['pos_context'], \
self.tokenizer, self.max_seq_length)
if self.evaluate:
neg_ctx_list = \
raw_sample['negative_context'][:self.val_av_rank_other_neg] + \
raw_sample['hard_negative_context'][:self.val_av_rank_hard_neg]
neg_ctx_id_list, neg_ctx_types_list = \
build_token_types_from_context_list(neg_ctx_list, \
self.tokenizer, self.max_seq_length)
elif self.train_with_neg:
hard_negative_ctx = raw_sample['hard_negative_context']
negative_ctx = raw_sample['negative_context']
if True: # TODO: fix this or remove this condition
random.shuffle(hard_negative_ctx)
random.shuffle(negative_ctx)
neg_ctx_list = hard_negative_ctx[:self.train_hard_neg]
            # In the Google NQ data prepared by the DPR paper, more than 50
            # training samples are missing hard negatives. In those cases,
            # substitute hard negatives with simple negatives.
if len(neg_ctx_list) < self.train_hard_neg:
neg_ctx_list += negative_ctx[:self.train_hard_neg - \
len(neg_ctx_list)]
neg_ctx_id_list, neg_ctx_types_list = \
build_token_types_from_context_list(neg_ctx_list,
self.tokenizer, self.max_seq_length)
else:
neg_ctx_id_list = None
neg_ctx_types_list = None
sample = build_sample(query_ids, query_types, query_pad_mask,
ctx_ids, ctx_types, ctx_pad_mask,
raw_sample['answers'],
neg_ctx_id_list, neg_ctx_types_list,
include_neg=self.evaluate or self.train_with_neg)
return sample
@staticmethod
@abstractmethod
def process_samples_from_single_path(filename):
"""Abstract method that takes a filename and
returns a list of dataset samples, each sample being a dict of
        fields such as 'question', 'pos_context', 'hard_negative_context',
        'negative_context' and 'answers'.
"""
pass
def normalize_question(question):
if question[-1] == '?':
question = question[:-1]
return question
# The following class reads the datasets for training retriever as
# prepared by the DPR codebase (https://github.com/facebookresearch/DPR)
class NQSupervisedDataset(OpenRetrievalAbstractDataset):
def __init__(self, name, datapaths, tokenizer, max_seq_length, \
evaluate=False):
super().__init__('natural_questions_ret',
name,
datapaths,
tokenizer,
max_seq_length,
evaluate=evaluate)
@staticmethod
def process_samples_from_single_path(filename):
""""Implement abstract method."""
print_rank_0(' > Processing {} ...'.format(filename))
samples = []
total = 0
with open(filename, 'r', encoding="utf-8") as f:
data = json.load(f)
for row in data:
question = normalize_question(row['question'])
pos_context = row['positive_ctxs'][0]
# Hard Negative Contexts
if len(row['hard_negative_ctxs']) > 0:
hard_neg_context = row['hard_negative_ctxs']
else:
hard_neg_context = []
# Negative Contexts
if len(row['negative_ctxs']) > 0:
neg_context = row['negative_ctxs']
else:
neg_context = []
answers = row['answers']
sample = {'question': question,
'pos_context': pos_context,
'hard_negative_context': hard_neg_context,
'negative_context': neg_context,
'answers': answers}
total += 1
samples.append(sample)
if total % 5000 == 0:
print_rank_0(' > processed {} so far ...'.format(total))
print_rank_0(' >> processed {} samples.'.format(len(samples)))
return samples
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation utilities."""
from collections import OrderedDict
import math
import numpy as np
import time
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.utils import average_losses_across_data_parallel_group
from tasks.finetune_utils import build_data_loader
def task_collate_fn(batch_data):
# generate batch
batch_size = len(batch_data)
tensorized = OrderedDict()
for d in batch_data:
for k, v in d.items():
tensorized.setdefault(k, []).append(v)
tensorized['query'] = torch.LongTensor(tensorized['query'])
tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask'])
tensorized['query_types'] = torch.LongTensor(tensorized['query_types'])
tensorized['query_pad_mask'] = \
torch.LongTensor(tensorized['query_pad_mask'])
tensorized['context'] = torch.LongTensor(tensorized['context'])
tensorized['context_mask'] = \
torch.LongTensor(tensorized['context_mask'])
tensorized['context_types'] = \
torch.LongTensor(tensorized['context_types'])
tensorized['context_pad_mask'] = \
torch.LongTensor(tensorized['context_pad_mask'])
if 'neg_context' in tensorized:
tensorized['neg_context'] = \
torch.LongTensor(np.concatenate(tensorized['neg_context']))
tensorized['neg_context_mask'] = \
torch.LongTensor(np.concatenate(tensorized['neg_context_mask']))
tensorized['neg_context_types'] = \
torch.LongTensor(np.concatenate(tensorized['neg_context_types']))
return tensorized
def process_batch(batch):
"""Process batch and produce inputs for the model."""
query_tokens = batch['query'].long().cuda()
query_mask = (batch['query_mask'] < 0.5).cuda()
query_types = batch['query_types'].long().cuda()
query_pad_mask = batch['query_pad_mask'].long().cuda()
context_tokens = batch['context'].long().cuda()
context_mask = (batch['context_mask'] < 0.5).cuda()
context_types = batch['context_types'].long().cuda()
context_pad_mask = batch['context_pad_mask'].long().cuda()
if 'neg_context' in batch:
neg_context_tokens = batch['neg_context'].long().cuda()
neg_context_mask = (batch['neg_context_mask'] < 0.5).cuda()
neg_context_types = batch['neg_context_types'].long().cuda()
else:
neg_context_tokens = None
neg_context_mask = None
neg_context_types = None
reference = batch['reference']
return query_tokens, query_mask, query_types, query_pad_mask, \
context_tokens, context_mask, context_types, context_pad_mask, \
neg_context_tokens, neg_context_mask, neg_context_types, reference
def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
"""Provide function that calculates accuracies."""
args = get_args()
print_rank_0("accuracy_func_provider is CALLED")
# Build dataloaders
datapath = args.valid_data
dataset = single_dataset_provider(datapath)
drop_last = False
if mpu.get_data_parallel_world_size() > 1 and not rank0sampler:
drop_last = True
print_rank_0(datapath)
print_rank_0(rank0sampler)
dataloader = build_data_loader(dataset,
args.eval_micro_batch_size,
num_workers=args.num_workers,
drop_last=drop_last,
task_collate_fn=task_collate_fn)
dataloaders = (dataset.dataset_name, dataloader)
def metrics_func(model, epoch, output_predictions=False):
print_rank_0('calculating metrics by accuracy func in ORQA...')
if output_predictions:
assert rank0sampler
names = 'predictions'
name, dataloader = dataloaders
if args.task == "RET-FINETUNE-NQ":
start_time = time.time()
output = retrieval_loss(model, dataloader)
stats_dict, total = output
format_string = ""
for k, v in stats_dict.items():
format_string += "|{} = {:.2f}".format(k, v / total)
print_rank_0("epoch:{}{}".format(epoch, format_string))
print_rank_0("taken time to calcuate metrics {:.3f}".format(\
time.time() - start_time))
else:
raise AssertionError("{} Task not supported".format(args.task))
return metrics_func
def retrieval_loss(model, dataloader):
args = get_args()
total = 0
topk_stats_dict = {'top{}_acc'.format(k): 0 for k in \
args.retriever_report_topk_accuracies}
stats_dict = dict(rank=0, **topk_stats_dict)
assert len(model) == 1
unwrapped_model = model[0]
unwrapped_model.eval()
with torch.no_grad():
# For all the batches in the dataset.
for batch in dataloader:
# Run the model forward.
query_tokens, query_mask, query_types, _, \
context_tokens, context_mask, context_types, _, \
neg_context_tokens, neg_context_mask, neg_context_types, \
reference = process_batch(batch)
query_logits, context_logits = unwrapped_model(query_tokens,
query_mask, query_types,
torch.cat([context_tokens, neg_context_tokens]),
torch.cat([context_mask, neg_context_mask]),
torch.cat([context_types, neg_context_types]))
retrieval_scores = torch.matmul(query_logits,
torch.transpose(context_logits, 0, 1))
if args.retriever_score_scaling:
retrieval_scores = retrieval_scores / \
math.sqrt(args.hidden_size)
local_batch_size = query_logits.shape[0]
labels = torch.arange(local_batch_size).long().cuda()
softmax_scores = F.softmax(retrieval_scores, dim=1)
sorted_vals, sorted_indices = torch.topk(softmax_scores,
k=softmax_scores.shape[1],
sorted=True)
def topk_accuracy(k):
return torch.cuda.FloatTensor(
[sum([int(labels[i] in sorted_indices[i, :k]) for i in \
range(local_batch_size)])])
def get_rank():
return torch.cuda.FloatTensor(
[sum([torch.nonzero(labels[i] == sorted_indices[i])[0][0] \
for i in range(local_batch_size)])])
topk_accs = [topk_accuracy(k) for k in \
args.retriever_report_topk_accuracies]
rank = get_rank()
losses = average_losses_across_data_parallel_group([rank, \
*topk_accs])
# create stats_dict with retrieval loss and all specified
# top-k accuracies
topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
zip(args.retriever_report_topk_accuracies, losses[1:])}
temp_stats_dict = dict(rank=losses[0], **topk_acc_dict)
for k in stats_dict.keys():
stats_dict[k] += temp_stats_dict[k]
total += local_batch_size
unwrapped_model.train()
return stats_dict, total
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ORQA finetuning/evaluation."""
from functools import partial
import sys
import math
import torch
import torch.nn.functional as F
from megatron import get_args, get_timers, get_tokenizer
from megatron import mpu, print_rank_0
from megatron.indexer import IndexBuilder
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.utils import average_losses_across_data_parallel_group
from pretrain_ict import get_group_world_size_rank
from tasks.finetune_utils import finetune
from tasks.orqa.supervised.eval_utils import accuracy_func_provider
from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
from tasks.orqa.evaluate_utils import ORQAEvaluator
# input_ is a 2D tensor
def check_and_append_tensor_for_gather(group, rank, world_size, input_):
# gather the size of the first dimension of the tensor from all ranks
current_length = input_.size()[0]
first_dim = torch.tensor([[current_length]],
device=torch.cuda.current_device())
input_list = [torch.empty_like(first_dim) for _ in range(world_size)]
input_list[rank].copy_(first_dim)
torch.distributed.all_gather(input_list, first_dim, group=group)
all_input_list = torch.cat(input_list, dim=0).contiguous()
max_length = torch.max(all_input_list)
    # if this rank's size is smaller than the max, pad the tensor accordingly
if max_length > current_length:
padding=tuple([0] * (input_.dim() * 2 - 1)) + \
tuple([max_length - current_length])
input_ = F.pad(input=input_, pad=padding)
return input_
def orqa(Dataset):
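    """Finetune/evaluate a biencoder retriever on the given supervised Dataset class."""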
def cross_entropy_forward_step(batch, model):
"""Simple forward step with cross-entropy loss."""
timers = get_timers()
tokenizer = get_tokenizer()
# Get the batch.
timers('batch generator').start()
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
group, rank, world_size = get_group_world_size_rank()
query_tokens, query_mask, query_types, query_pad_mask, \
context_tokens, context_mask, context_types, context_pad_mask, \
neg_context_tokens, neg_context_mask, neg_context_types, \
reference = process_batch(batch_)
timers('batch generator').stop()
local_batch_size = query_tokens.shape[0]
# Text representation of query and context
query_list, context_list = [], []
for i in range(local_batch_size):
query_list.append(tokenizer.decode(query_tokens[i].tolist()))
context_list.append(tokenizer.decode(context_tokens[i].tolist()))
if neg_context_tokens is not None:
neg_context_tokens = check_and_append_tensor_for_gather(group,
rank, world_size, neg_context_tokens)
neg_context_mask = check_and_append_tensor_for_gather(group,
rank, world_size, neg_context_mask)
neg_context_types = check_and_append_tensor_for_gather(group,
rank, world_size, neg_context_types)
if neg_context_tokens is not None:
context_tokens = torch.cat([context_tokens, neg_context_tokens])
context_mask = torch.cat([context_mask, neg_context_mask])
context_types = torch.cat([context_types, neg_context_types])
# Forward model.
output_tensor = model(query_tokens, query_mask,
query_types, context_tokens,
context_mask, context_types)
return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens)
def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
args = get_args()
local_batch_size = query_tokens.shape[0]
group, rank, world_size = get_group_world_size_rank()
# recall we assert that model_parallel_size == 1
global_batch_size = world_size * local_batch_size
query_logits, context_logits = output_tensor
if world_size > 1:
input_ = torch.empty_like(context_logits).copy_(\
context_logits).detach_()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank].copy_(input_)
torch.distributed.all_gather(tensor_list, input_, group=group)
# Check if all-gather happens in order
assert tensor_list[rank].sum().item() == \
context_logits.sum().item()
# Preserves the gradient
tensor_list[rank] = context_logits
all_context_logits = torch.cat(tensor_list, dim=0).contiguous()
# Query tensors
input_ = torch.empty_like(query_logits).copy_(\
query_logits).detach_()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank].copy_(input_)
torch.distributed.all_gather(tensor_list, input_, group=group)
# Check if all-gather happens in order
assert tensor_list[rank].sum().item() == query_logits.sum().item()
# Preserves the gradient
tensor_list[rank] = query_logits
all_query_logits = torch.cat(tensor_list, dim=0).contiguous()
else:
all_query_logits = query_logits
all_context_logits = context_logits
retrieval_scores = torch.matmul(all_query_logits,
torch.transpose(all_context_logits, 0, 1))
# Scaling the retrieval scores
if args.retriever_score_scaling:
retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)
if args.train_with_neg:
# if the world size is 3, local batch size is 4, and
# local context size is 8, what we want is
# labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]
labels = []
local_context_size = context_tokens.shape[0]
for i in range(world_size):
j = i * local_context_size
labels.extend(list(range(j, j + local_batch_size)))
labels = torch.LongTensor(labels).cuda()
assert len(labels) == global_batch_size
else:
labels = torch.arange(global_batch_size).long().cuda()
# Cross-entropy loss.
softmax_scores = F.log_softmax(retrieval_scores, dim=1)
loss = F.nll_loss(softmax_scores, labels, reduction='mean')
max_score, max_idxs = torch.max(softmax_scores, 1)
correct_predictions_count = (max_idxs == labels).sum().float()
# Reduce loss for logging.
reduced_loss = average_losses_across_data_parallel_group([loss, \
correct_predictions_count])
# Loss scaling for correct losses in Supervised Retrieval
loss = loss * mpu.get_data_parallel_world_size()
return loss, {'lm loss': reduced_loss[0],
'correct_prediction_count': reduced_loss[1]}
def train_valid_datasets_provider():
"""Build train and validation dataset."""
args = get_args()
tokenizer = get_tokenizer()
train_dataset = Dataset('training',
args.train_data,
tokenizer,
args.retriever_seq_length,
evaluate=False)
valid_dataset = Dataset('validation',
args.valid_data,
tokenizer,
args.retriever_seq_length,
evaluate=True)
return train_dataset, valid_dataset
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
print_rank_0('building retriever model for {} ...'.format(args.task))
model = biencoder_model_provider(only_context_model=False,
only_query_model=False,
biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model,
pre_process=pre_process, post_process=post_process)
return model
def single_dataset_provider(datapath):
args = get_args()
tokenizer = get_tokenizer()
name = datapath[0].split('/')[-1].split('.')[0]
return Dataset(name,
datapath,
tokenizer,
args.retriever_seq_length,
evaluate=True)
def metrics_func_provider():
"""Provide metrics callback function."""
return accuracy_func_provider(single_dataset_provider)
"""Finetune/evaluate."""
finetune(train_valid_datasets_provider,
model_provider,
forward_step=cross_entropy_forward_step,
end_of_epoch_callback_provider=metrics_func_provider,
task_collate_fn=task_collate_fn)
def main():
args = get_args()
if args.task == 'RET-FINETUNE-NQ':
from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
else:
raise NotImplementedError('ORQA task {} is not implemented.'.format(
args.task))
orqa(Dataset)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data Loader for Google NQ dataset
"""
from abc import ABC
import csv
from collections import OrderedDict
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, BatchSampler
from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask
def get_nq_dataset(qa_data, split):
args = get_args()
tokenizer = get_tokenizer()
dataset = NQDataset('Google NQ {} Split'.format(split),
'Google Natural Questions',
qa_data,
tokenizer,
args.retriever_seq_length)
return dataset
def process_nq_batch(batch):
query_tokens = batch['token_ids'].long().cuda()
query_mask = (batch['token_mask'] < 0.5).cuda()
query_types = batch['token_types'].long().cuda()
query_len = batch['seq_len'].long().cuda()
reference = batch['reference']
return query_tokens, query_mask, query_types, query_len, reference
class CustomDataLoader(DataLoader):
def __init__(self, dataset, eval=False, **kwargs):
if kwargs.get('collate_fn', None) is None:
kwargs['collate_fn'] = self._collate_fn
self.eval = eval
super().__init__(dataset, **kwargs)
def _collate_fn(self, batch_data):
# generate batch
batch_size = len(batch_data)
tensorized = OrderedDict()
for d in batch_data:
for k, v in d.items():
tensorized.setdefault(k, []).append(v)
assert len(tensorized) == 5
tensorized['token_ids'] = torch.LongTensor(tensorized['token_ids'])
tensorized['token_mask'] = torch.LongTensor(tensorized['token_mask'])
tensorized['token_types'] = torch.LongTensor(tensorized['token_types'])
tensorized['seq_len'] = torch.LongTensor(tensorized['seq_len'])
return tensorized
def get_one_epoch_nq_dataloader(dataset, micro_batch_size=None):
"""Data loader. Note that batch-size is the local (per GPU) batch-size.
NOTE: This dataloader is not distributed !!!
"""
args = get_args()
if micro_batch_size is None:
micro_batch_size = args.micro_batch_size
num_workers = args.num_workers
sampler = torch.utils.data.SequentialSampler(dataset)
# importantly, drop_last must be False to get all the data.
batch_sampler = BatchSampler(sampler,
batch_size=micro_batch_size,
drop_last=False)
# Data loader. Note that batch size is the per GPU batch size.
data_loader = CustomDataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=True)
return data_loader
def build_tokens_types_paddings_from_text(src_text, tokenizer, max_seq_length):
"""Build token types and paddings, trim if needed, and pad if needed."""
src_text_ids = tokenizer.tokenize(src_text)
return build_tokens_types_paddings_from_ids(src_text_ids,
max_seq_length,
tokenizer.cls,
tokenizer.sep,
tokenizer.pad)
def build_tokens_types_paddings_from_ids(src_ids, max_seq_length, cls_id, \
sep_id, pad_id):
"""
Build token types and paddings, trim if needed, and pad if needed.
TODO: Design modular interface to reuse this function. This is getting
repeated multiple times in different tasks
"""
enc_ids = []
tokentypes_enc = []
# [CLS].
enc_ids.append(cls_id)
tokentypes_enc.append(0)
# A.
len_src = len(src_ids)
enc_ids.extend(src_ids)
tokentypes_enc.extend([0] * len_src)
# Cap the size.
if len(enc_ids) > max_seq_length - 1:
enc_ids = enc_ids[0: max_seq_length - 1]
tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]
# [SEP].
enc_ids.append(sep_id)
tokentypes_enc.append(0)
num_tokens_enc = len(enc_ids)
# Padding.
padding_length = max_seq_length - len(enc_ids)
if padding_length > 0:
enc_ids.extend([pad_id] * padding_length)
tokentypes_enc.extend([pad_id] * padding_length)
return enc_ids, tokentypes_enc, num_tokens_enc
def build_sample(token_ids, token_types, num_tokens, reference):
"""
Convert to numpy and return a sample consumed by the
batch producer.
"""
token_ids = np.array(token_ids, dtype=np.int64)
token_types = np.array(token_types, dtype=np.int64)
token_mask = make_attention_mask(token_ids, token_ids)
sample = ({
'token_ids': token_ids,
'token_mask': token_mask,
'token_types': token_types,
'seq_len': num_tokens,
'reference': reference
})
return sample
class NQDataset(ABC, Dataset):
"""
Open Retrieval Question Answering evaluation using Google NQ dataset.
"""
def __init__(self, task_name, dataset_name, datapath,
tokenizer, max_seq_length):
# Store inputs.
self.task_name = task_name
self.dataset_name = dataset_name
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
self.dataset_name))
print_rank_0(datapath)
self.samples = self.process_samples_from_single_path(datapath)
print_rank_0(' >> total number of samples: {}'.format(\
len(self.samples)))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
raw_sample = self.samples[idx]
ques_tokens, tokentypes_enc, num_tokens_ques = \
build_tokens_types_paddings_from_text(raw_sample['question'],
self.tokenizer, self.max_seq_length)
sample = build_sample(ques_tokens,
tokentypes_enc,
num_tokens_ques,
raw_sample['answers'])
return sample
@staticmethod
def process_samples_from_single_path(filename):
print_rank_0(' > Processing {} ...'.format(filename))
samples = []
total = 0
with open(filename, 'r') as ifile:
reader = csv.reader(ifile, delimiter='\t')
for row in reader:
question = row[0]
answers = eval(row[1])
sample = {'question': question, 'answers': answers}
total += 1
samples.append(sample)
if total % 1000 == 0:
print_rank_0(' > processed {} so far ...'.format(total))
print_rank_0(' >> processed {} samples.'.format(len(samples)))
return samples
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# The following code has been taken from
# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
# licensed as of now. More details on the license can be found
# at https://github.com/facebookresearch/DPR/blob/master/LICENSE
"""
Set of utilities for Q&A results validation tasks - Retriever passage
validation and Reader predicted answer validation
"""
import collections
import logging
import string
import unicodedata
from functools import partial
from multiprocessing import Pool as ProcessPool
from typing import Tuple, List, Dict
import regex as re
from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer
logger = logging.getLogger(__name__)
QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits',\
'questions_doc_hits'])
def calculate_matches(all_docs: Dict[object, Tuple[str, str]],
answers: List[List[str]], closest_docs: List[Tuple[List[object],
List[float]]], workers_num: int, match_type: str) -> QAMatchStats:
"""
    Evaluates the presence of answers in the set of documents. This function is
supposed to be used with a large collection of documents and results.
It internally forks multiple sub-processes for evaluation and then
merges results
:param all_docs: dictionary of the entire documents database.
doc_id -> (doc_text, title)
    :param answers: list of answer lists, one list per question
:param closest_docs: document ids of the top results along with their
scores
    :param workers_num: number of parallel worker processes used to process the data
:param match_type: type of answer matching. Refer to has_answer code for
available options
:return: matching information tuple.
    top_k_hits - a list where the index is the number of top documents retrieved
        and the value is the total number of valid matches across the entire
        dataset.
questions_doc_hits - more detailed info with answer matches for every
question and every retrieved document
"""
global dpr_all_documents
dpr_all_documents = all_docs
tok_opts = {}
tokenizer = SimpleTokenizer(**tok_opts)
processes = ProcessPool(
processes=workers_num,
)
logger.info('Matching answers in top docs...')
get_score_partial = partial(check_answer, match_type=match_type,
tokenizer=tokenizer)
questions_answers_docs = zip(answers, closest_docs)
scores = processes.map(get_score_partial, questions_answers_docs)
logger.info('Per question validation results len=%d', len(scores))
n_docs = len(closest_docs[0][0])
top_k_hits = [0] * n_docs
for question_hits in scores:
best_hit = next((i for i, x in enumerate(question_hits) if x), None)
if best_hit is not None:
top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]
return QAMatchStats(top_k_hits, scores)
def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]:
"""
Search through all the top docs to see if they have any of the answers.
"""
answers, (doc_ids, doc_scores) = questions_answers_docs
global dpr_all_documents
hits = []
for i, doc_id in enumerate(doc_ids):
doc = dpr_all_documents[doc_id]
text = doc[0]
answer_found = False
if text is None: # cannot find the document for some reason
logger.warning("no doc in db")
hits.append(False)
continue
if has_answer(answers, text, tokenizer, match_type):
answer_found = True
hits.append(answer_found)
return hits
def has_answer(answers, text, tokenizer, match_type) -> bool:
"""
Check if a document contains an answer string.
If `match_type` is string, token matching is done between the text
and answer.
If `match_type` is regex, we search the whole text with the regex.
"""
text = _normalize(text)
if match_type == 'string':
# Answer is a list of possible strings
text = tokenizer.tokenize(text).words(uncased=True)
for single_answer in answers:
single_answer = _normalize(single_answer)
single_answer = tokenizer.tokenize(single_answer)
single_answer = single_answer.words(uncased=True)
for i in range(0, len(text) - len(single_answer) + 1):
if single_answer == text[i: i + len(single_answer)]:
return True
elif match_type == 'regex':
# Answer is a regex
for single_answer in answers:
single_answer = _normalize(single_answer)
if regex_match(text, single_answer):
return True
return False
def regex_match(text, pattern):
"""Test if a regex pattern is contained within a text."""
try:
pattern = re.compile(
pattern,
flags=re.IGNORECASE + re.UNICODE + re.MULTILINE,
)
except BaseException:
return False
return pattern.search(text) is not None
# function for the reader model answer validation
def exact_match_score(prediction, ground_truth):
return _normalize_answer(prediction) == _normalize_answer(ground_truth)
def _normalize_answer(s):
def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def _normalize(text):
return unicodedata.normalize('NFD', text)