Commit a6e00d97 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main_retriver_merge_ict_eval' into 'main'

ICT zeroshot evaluation

See merge request ADLR/megatron-lm!248
parents c5346794 fcfd0949
#!/bin/bash
# Evaluate the Natural Questions test data given Wikipedia embeddings and a
# pretrained ICT model.
# Datasets can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py

# Replace every <...> placeholder below with a real path before running.
EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path of the embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>
QA_FILE=<Path of the natural question test dataset>

# Every continued line ends in " \" (space + backslash) so that adjacent
# arguments are never glued together; paths are quoted so that directories
# containing spaces are passed as a single argument.
python tasks/main.py \
    --task ICT-ZEROSHOT-NQ \
    --tokenizer-type BertWordPieceLowerCase \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 12 \
    --tensor-model-parallel-size 1 \
    --micro-batch-size 128 \
    --checkpoint-activations \
    --seq-length 512 \
    --max-position-embeddings 512 \
    --load "${CHECKPOINT_PATH}" \
    --evidence-data-path "${EVIDENCE_DATA_DIR}" \
    --embedding-path "${EMBEDDING_PATH}" \
    --retriever-seq-length 256 \
    --vocab-file bert-vocab.txt \
    --qa-data-test "${QA_FILE}" \
    --num-workers 2 \
    --faiss-use-gpu \
    --retriever-report-topk-accuracies 1 5 20 100 \
    --fp16
...@@ -712,8 +712,6 @@ def _add_biencoder_args(parser): ...@@ -712,8 +712,6 @@ def _add_biencoder_args(parser):
'square root of hidden size') 'square root of hidden size')
# faiss index # faiss index
group.add_argument('--faiss-use-gpu', action='store_true',
help='Whether create the FaissMIPSIndex on GPU')
group.add_argument('--block-data-path', type=str, default=None, group.add_argument('--block-data-path', type=str, default=None,
help='Where to save/load BlockData to/from') help='Where to save/load BlockData to/from')
group.add_argument('--embedding-path', type=str, default=None, group.add_argument('--embedding-path', type=str, default=None,
......
...@@ -24,11 +24,8 @@ def get_one_epoch_dataloader(dataset, micro_batch_size=None): ...@@ -24,11 +24,8 @@ def get_one_epoch_dataloader(dataset, micro_batch_size=None):
"""Specifically one epoch to be used in an indexing job.""" """Specifically one epoch to be used in an indexing job."""
args = get_args() args = get_args()
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
if micro_batch_size is None: if micro_batch_size is None:
micro_batch_size = args.micro_batch_size micro_batch_size = args.micro_batch_size
global_batch_size = micro_batch_size * world_size
num_workers = args.num_workers num_workers = args.num_workers
# Use megatron's sampler with consumed samples set to 0 as # Use megatron's sampler with consumed samples set to 0 as
......
...@@ -116,18 +116,22 @@ class OpenRetreivalDataStore(object): ...@@ -116,18 +116,22 @@ class OpenRetreivalDataStore(object):
class FaissMIPSIndex(object): class FaissMIPSIndex(object):
"""Wrapper object for a BlockData which similarity search via FAISS under the hood""" """
def __init__(self, embed_size, block_data=None, use_gpu=False): Wrapper object for a BlockData which similarity search via FAISS under the hood
"""
def __init__(self, embed_size, embed_data=None, use_gpu=False):
self.embed_size = embed_size self.embed_size = embed_size
self.block_data = block_data self.embed_data = embed_data
self.use_gpu = use_gpu self.use_gpu = use_gpu
self.id_map = dict()
self.block_mips_index = None self.mips_index = None
self._set_block_index() self._set_mips_index()
def _set_block_index(self): def _set_mips_index(self):
"""Create a Faiss Flat index with inner product as the metric to search against""" """
Create a Faiss Flat index with inner product as the metric
to search against
"""
try: try:
import faiss import faiss
except ImportError: except ImportError:
...@@ -135,85 +139,86 @@ class FaissMIPSIndex(object): ...@@ -135,85 +139,86 @@ class FaissMIPSIndex(object):
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print("\n> Building index", flush=True) print("\n> Building index", flush=True)
self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
cpu_index = faiss.IndexFlatIP(self.embed_size)
if self.use_gpu: if self.use_gpu:
# create resources and config for GpuIndex # create resources and config for GpuIndex
res = faiss.StandardGpuResources() config = faiss.GpuMultipleClonerOptions()
config = faiss.GpuIndexFlatConfig() config.shard = True
config.device = torch.cuda.current_device()
config.useFloat16 = True config.useFloat16 = True
gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config) self.mips_index = faiss.IndexIDMap(gpu_index)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True) print(">> Initialized index on GPU", flush=True)
else: else:
# CPU index supports IDs so wrap with IDMap # CPU index supports IDs so wrap with IDMap
self.block_mips_index = faiss.IndexIDMap(self.block_mips_index) self.mips_index = faiss.IndexIDMap(cpu_index)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">> Initialized index on CPU", flush=True) print(">> Initialized index on CPU", flush=True)
# if we were constructed with a BlockData, then automatically load it when the FAISS structure is built # if we were constructed with a BlockData, then automatically load it
if self.block_data is not None: # when the FAISS structure is built
self.add_block_embed_data(self.block_data) if self.embed_data is not None:
self.add_embed_data(self.embed_data)
def reset_index(self): def reset_index(self):
"""Delete existing index and create anew""" """Delete existing index and create a new"""
del self.block_mips_index del self.mips_index
# reset the block data so that _set_block_index will reload it as well # reset the block data so that _set_block_index will reload it as well
if self.block_data is not None: if self.embed_data is not None:
block_data_path = self.block_data.block_data_path embed_data_path = self.embed_data.embedding_path
del self.block_data del self.embed_data
self.block_data = BlockData(block_data_path) self.embed_data = OpenRetreivalDataStore(embed_data_path)
self._set_mips_index()
self._set_block_index() def update_index(self):
"""Delete existing index and create a new"""
del self.mips_index
def add_block_embed_data(self, all_block_data): # reset the block data so that _set_mips_index will reload it as well
if self.embed_data is not None:
self.embed_data.load_from_file()
self._set_mips_index()
def add_embed_data(self, all_embed_data):
"""Add the embedding of each block to the underlying FAISS index""" """Add the embedding of each block to the underlying FAISS index"""
# this assumes the embed_data is a dict : {int: np.array<float>} # this assumes the embed_data is a dict : {int: np.array<float>}
block_indices, block_embeds = zip(*all_block_data.embed_data.items()) block_indices, block_embeds = zip(*all_embed_data.embed_data.items())
# the embeddings have to be entered in as float32 even though the math internally is done with float16.
block_embeds_arr = np.float32(np.array(block_embeds))
block_indices_arr = np.array(block_indices)
# faiss GpuIndex doesn't work with IDMap wrapper so store ids to map back with # the embeddings have to be entered in as float32 even though the math
if self.use_gpu: # internally is done with float16.
for i, idx in enumerate(block_indices): embeds_arr = np.float32(np.array(block_embeds))
self.id_map[i] = idx indices_arr = np.array(block_indices)
# we no longer need the embedding data since it's in the index now # we no longer need the embedding data since it's in the index now
all_block_data.clear() all_embed_data.clear()
if self.use_gpu: self.mips_index.add_with_ids(embeds_arr, indices_arr)
self.block_mips_index.add(block_embeds_arr)
else:
self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)
if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
print(">>> Finished adding block data to index", flush=True) print(">>> Finished adding block data to index", flush=True)
def search_mips_index(self, query_embeds, top_k, reconstruct=True): def search_mips_index(self, query_embeds, top_k, reconstruct=True):
"""Get the top-k blocks by the index distance metric. """
Get the top-k blocks by the index distance metric.
:param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks :param reconstruct: if True: return a [num_queries x k x embed_dim]
if False: return [num_queries x k] array of distances, and another for indices array of blocks
if False: return [num_queries x k] array of
distances, and another for indices
""" """
query_embeds = np.float32(detach(query_embeds)) query_embeds = np.float32(detach(query_embeds))
if reconstruct: if reconstruct:
# get the vectors themselves # get the vectors themselves
top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k) top_k_block_embeds = self.mips_index.search_and_reconstruct(\
query_embeds, top_k)
return top_k_block_embeds return top_k_block_embeds
else: else:
# get distances and indices of closest vectors # get distances and indices of closest vectors
distances, block_indices = self.block_mips_index.search(query_embeds, top_k) distances, block_indices = self.mips_index.search(query_embeds, top_k)
if self.use_gpu:
fresh_indices = np.zeros(block_indices.shape)
for i, j in itertools.product(block_indices.shape):
fresh_indices[i, j] = self.id_map[block_indices[i, j]]
block_indices = fresh_indices
return distances, block_indices return distances, block_indices
...@@ -47,6 +47,20 @@ def get_tasks_args(parser): ...@@ -47,6 +47,20 @@ def get_tasks_args(parser):
help='Sliding window for overlapping evaluation.') help='Sliding window for overlapping evaluation.')
group.add_argument('--strict-lambada', action='store_true', group.add_argument('--strict-lambada', action='store_true',
help='Use more difficult formulation of lambada.') help='Use more difficult formulation of lambada.')
# Retriever args
group.add_argument('--qa-data-dev', type=str, default=None,
help='Path to the QA dataset dev file.')
group.add_argument('--qa-data-test', type=str, default=None,
help='Path to the QA dataset test file.')
# Faiss arguments for retriever
group.add_argument('--faiss-use-gpu', action='store_true',
help='Whether create the FaissMIPSIndex on GPU')
group.add_argument('--faiss-match', type=str, default='string', \
choices=['regex', 'string'], help="Answer matching '\
'logic type")
group.add_argument('--faiss-topk-retrievals', type=int, default=100,
help='Number of blocks to use as top-k during retrieval')
return parser return parser
...@@ -62,6 +76,8 @@ if __name__ == '__main__': ...@@ -62,6 +76,8 @@ if __name__ == '__main__':
from glue.finetune import main from glue.finetune import main
elif args.task in ['LAMBADA', 'WIKITEXT103']: elif args.task in ['LAMBADA', 'WIKITEXT103']:
from zeroshot_gpt.evaluate import main from zeroshot_gpt.evaluate import main
elif args.task in ['ICT-ZEROSHOT-NQ']:
from orqa.evaluate_orqa import main
else: else:
raise NotImplementedError('Task {} is not implemented.'.format( raise NotImplementedError('Task {} is not implemented.'.format(
args.task)) args.task))
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality."""
import os
import sys
from megatron import get_args
from tasks.orqa.evaluate_utils import ORQAEvaluator
def main():
    """Run ORQA retrieval evaluation on whichever QA splits were supplied.

    An ``ORQAEvaluator`` is built once, then ``evaluate`` is called for the
    dev file and/or the test file, in that order, for each of
    ``args.qa_data_dev`` / ``args.qa_data_test`` that is not ``None``.
    """
    args = get_args()

    # Set up the model and evaluator
    evaluator = ORQAEvaluator()

    # Run evaluation: dev split first, then test split.
    splits = (
        (args.qa_data_dev, "DEV"),
        (args.qa_data_test, "TEST"),
    )
    for qa_path, split_name in splits:
        if qa_path is not None:
            evaluator.evaluate(qa_path, split_name)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from megatron import get_args, print_rank_0
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from tasks.orqa.natural_questions.nq import get_nq_dataset
from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader
from tasks.orqa.natural_questions.nq import process_nq_batch
from tasks.orqa.natural_questions.qa_utils import calculate_matches
from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import get_model
class ORQAEvaluator(object):
    """Zero-shot retrieval evaluator for open-retrieval question answering.

    Construction loads the evidence dataset, builds the biencoder model and
    restores its checkpoint, and (on local rank 0 only) builds a
    ``FaissMIPSIndex`` over the stored evidence embeddings.
    """

    def __init__(self):
        args = get_args()
        self.embedding_size = args.hidden_size
        self.faiss_use_gpu = args.faiss_use_gpu
        self.evidence_embedder_obj = None
        self.evidence_dataset = None
        self.mips_index = None
        self.eval_dataset = None

        # Get Evidence (Wikipedia) dataset
        self.get_evidence_dataset()

        # Load query encoder checkpoint. When query and context share one
        # model, the full model is loaded instead of the query tower only.
        only_query_model = True
        if args.biencoder_shared_query_context_model:
            only_query_model = False

        model = get_model(lambda: biencoder_model_provider(
            only_query_model=only_query_model,
            biencoder_shared_query_context_model=(
                args.biencoder_shared_query_context_model)))

        self.model = load_biencoder_checkpoint(
            model, only_query_model=only_query_model)

        assert len(self.model) == 1
        self.model[0].eval()

        # Load faiss indexer
        self.faiss_wrapper()

    def get_evidence_embedding(self):
        """Create the evidence embedding store, loading it from disk."""
        # This will load the embedding from the embedding path
        self.evidence_embedder_obj = OpenRetreivalDataStore(
            load_from_path=True)

    def get_evidence_dataset(self):
        """Load the evidence dataset used to map passage ids to text."""
        self.evidence_dataset = get_open_retrieval_wiki_dataset()

    def faiss_wrapper(self):
        """Build the FAISS index on local rank 0, then sync all ranks."""
        # Initialize FAISS wrapper on local rank = 0 as the evidence
        # embeddings is distributed over all the GPUs in a node and FAISS is
        # not thread-safe
        args = get_args()
        if args.local_rank == 0:
            # Get evidence embeddings computed using context encoder
            self.get_evidence_embedding()

            assert self.evidence_embedder_obj is not None
            self.mips_index = FaissMIPSIndex(
                embed_size=self.embedding_size,
                embed_data=self.evidence_embedder_obj,
                use_gpu=self.faiss_use_gpu)

        # Wait for the FAISS index to be initialized in all the nodes
        torch.distributed.barrier()

    def generate_query_vectors(self, qa_data, split):
        """Encode every question of ``qa_data`` with the query model.

        Returns:
            A tuple ``(query_tensor, reference_list)``: the per-question
            outputs of ``embed_text`` concatenated along dim 0, and the
            ``reference`` entries of all batches in the same order.
        """
        self.eval_dataset = get_nq_dataset(qa_data, split)
        dataloader = get_one_epoch_nq_dataloader(self.eval_dataset)

        # Locating the module that exposes `embed_text` does not depend on
        # the batch, so do it once instead of once per batch.
        assert len(self.model) == 1
        unwrapped_model = self.model[0]
        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module

        query_vectors = []
        reference_list = []

        for batch in dataloader:
            # batch also has query_tokens and query_pad_data
            query_tokens, query_mask, query_types, \
                query_len, reference = process_nq_batch(batch)

            with torch.no_grad():
                query_logits = unwrapped_model.embed_text(
                    unwrapped_model.query_model, query_tokens,
                    query_mask, query_types)

            reference_list.extend(reference)
            query_vectors.extend(query_logits.split(1, dim=0))
            if len(query_vectors) % 100 == 0:
                print_rank_0('Encoded queries {}'.format(len(query_vectors)))

        query_tensor = torch.cat(query_vectors, dim=0)
        print_rank_0('Total encoded queries tensor {}'.format(
            query_tensor.size()))

        assert query_tensor.size(0) == len(self.eval_dataset)
        return query_tensor, reference_list

    def evaluate(self, qa_data, split):
        """Retrieve top-k passages for every question and report accuracy.

        Queries are gathered inside each node, searched against the FAISS
        index on the node's local rank 0, and the results are broadcast back
        so that each rank scores its own slice with ``calculate_matches``.

        Args:
            qa_data: path of the QA file to evaluate.
            split: label used in the printed report (e.g. "DEV", "TEST").
        """
        args = get_args()
        query_tensor, reference_list = self.generate_query_vectors(
            qa_data, split)
        local_rank = args.local_rank
        rank = torch.distributed.get_rank()
        device_count = torch.cuda.device_count()
        # NOTE(review): assumes every node runs exactly `device_count`
        # ranks laid out contiguously - confirm against the launcher.
        num_nodes = torch.distributed.get_world_size() // device_count
        node_id = rank // device_count

        # Every rank creates the group of every node (in the same order);
        # only the group containing this rank is kept.
        for node in range(num_nodes):
            start_rank = node * device_count
            end_rank = (node + 1) * device_count
            ranks_list = list(range(start_rank, end_rank))
            node_group = torch.distributed.new_group(ranks=ranks_list)

            if node_id == node:
                device_start_rank = start_rank
                group = node_group

        # Only the shape/dtype/device of the query tensor are needed for the
        # receive buffers, so no copy of the data is made.
        tensor_list = [torch.empty_like(query_tensor)
                       for _ in range(device_count)]
        torch.distributed.all_gather(tensor_list, query_tensor, group=group)

        if local_rank == 0 and self.mips_index is not None:
            all_query_tensor = torch.cat(tensor_list, dim=0).contiguous()

            distance, topkindex = self.mips_index.search_mips_index(
                all_query_tensor, top_k=args.faiss_topk_retrievals,
                reconstruct=False)
            distance = torch.from_numpy(distance).cuda()
            topkindex = torch.LongTensor(topkindex).cuda()

        if local_rank != 0:
            # Receive buffers for the results broadcast from local rank 0.
            distance = torch.empty(
                device_count * len(query_tensor),
                args.faiss_topk_retrievals, dtype=torch.float32).cuda()
            topkindex = torch.empty(
                device_count * len(query_tensor),
                args.faiss_topk_retrievals, dtype=torch.int64).cuda()

        torch.distributed.broadcast(distance, src=device_start_rank,
                                    group=group)
        torch.distributed.broadcast(topkindex, src=device_start_rank,
                                    group=group)

        # Keep only the rows that belong to this rank's own queries.
        distance = torch.split(distance, len(query_tensor), dim=0)[local_rank]
        topkindex = torch.split(topkindex, len(query_tensor),
                                dim=0)[local_rank]

        top_ids_and_scores = []
        for darray, topkarray in zip(distance, topkindex):
            top_ids_and_scores.append((topkarray.tolist(), darray.tolist()))

        passages = self.evidence_dataset.id2text
        match_stats = calculate_matches(passages,
                                        reference_list,
                                        top_ids_and_scores,
                                        workers_num=args.num_workers,
                                        match_type=args.faiss_match)
        top_k_hits = match_stats.top_k_hits

        print_rank_0("{} SET RESULTS".format(split))
        print_rank_0("topk-{} documents hits {}".format(
            args.faiss_topk_retrievals, top_k_hits))
        top_k_hits = [v / len(top_ids_and_scores) for v in top_k_hits]
        print_rank_0("top-k documents hits accuracy {}".format(top_k_hits))

        for k in args.retriever_report_topk_accuracies:
            # A requested k outside [1, number of retrieved documents] has
            # no entry in top_k_hits: k == 0 would silently read the last
            # element and a too-large k would raise IndexError only after
            # the whole evaluation has already run.
            if not 1 <= k <= len(top_k_hits):
                print_rank_0("top-{}: skipped, only top-{} documents were "
                             "retrieved".format(k, len(top_k_hits)))
                continue
            print_rank_0("top-{}: {:.2f}".format(k, top_k_hits[k - 1] * 100))
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data Loader for Google NQ dataset
"""
from abc import ABC
import csv
from collections import OrderedDict
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, BatchSampler
from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask
def get_nq_dataset(qa_data, split):
    """Build the ``NQDataset`` for the QA file ``qa_data``.

    ``split`` is only used to label the dataset; the tokenizer and the
    maximum sequence length (``args.retriever_seq_length``) come from the
    global Megatron state.
    """
    args = get_args()
    tokenizer = get_tokenizer()
    task_name = 'Google NQ {} Split'.format(split)
    return NQDataset(task_name,
                     'Google Natural Questions',
                     qa_data,
                     tokenizer,
                     args.retriever_seq_length)
def process_nq_batch(batch):
    """Move one collated batch to the GPU and unpack it.

    Returns the tuple ``(query_tokens, query_mask, query_types, query_len,
    reference)``. The mask is the element-wise test ``token_mask < 0.5``;
    ``reference`` is passed through untouched.
    """
    return (
        batch['token_ids'].long().cuda(),
        (batch['token_mask'] < 0.5).cuda(),
        batch['token_types'].long().cuda(),
        batch['seq_len'].long().cuda(),
        batch['reference'],
    )
class CustomDataLoader(DataLoader):
    """DataLoader whose default collate function stacks sample dicts.

    If the caller does not pass a ``collate_fn``, ``_collate_fn`` is used.
    """

    def __init__(self, dataset, eval=False, **kwargs):
        if kwargs.get('collate_fn') is None:
            kwargs['collate_fn'] = self._collate_fn
        self.eval = eval
        super().__init__(dataset, **kwargs)

    def _collate_fn(self, batch_data):
        """Merge a list of sample dicts into one dict of batched columns.

        The four token columns are converted to ``torch.LongTensor``; any
        remaining column is left as a plain Python list. Samples are
        expected to carry exactly five keys.
        """
        columns = OrderedDict()
        for sample in batch_data:
            for key, value in sample.items():
                columns.setdefault(key, []).append(value)
        assert len(columns) == 5

        for key in ('token_ids', 'token_mask', 'token_types', 'seq_len'):
            columns[key] = torch.LongTensor(columns[key])
        return columns
def get_one_epoch_nq_dataloader(dataset, micro_batch_size=None):
    """Return a sequential, single-pass data loader over ``dataset``.

    The batch size is the local (per GPU) batch size and defaults to
    ``args.micro_batch_size``.
    NOTE: This dataloader is not distributed !!!
    """
    args = get_args()
    batch_size = (args.micro_batch_size
                  if micro_batch_size is None else micro_batch_size)

    # importantly, drop_last must be False to get all the data.
    batch_sampler = BatchSampler(
        torch.utils.data.SequentialSampler(dataset),
        batch_size=batch_size,
        drop_last=False)

    # Data loader. Note that batch size is the per GPU batch size.
    return CustomDataLoader(dataset,
                            batch_sampler=batch_sampler,
                            num_workers=args.num_workers,
                            pin_memory=True)
def build_tokens_types_paddings_from_text(src_text, tokenizer, max_seq_length):
    """Tokenize ``src_text`` and build padded ids, token types and length.

    The text is tokenized with ``tokenizer.tokenize`` and the result is
    handed to ``build_tokens_types_paddings_from_ids`` together with the
    tokenizer's ``cls``, ``sep`` and ``pad`` ids.
    """
    token_ids = tokenizer.tokenize(src_text)
    return build_tokens_types_paddings_from_ids(
        token_ids,
        max_seq_length,
        tokenizer.cls,
        tokenizer.sep,
        tokenizer.pad,
    )
def build_tokens_types_paddings_from_ids(src_ids, max_seq_length, cls_id,
                                         sep_id, pad_id):
    """
    Build token ids and token types, trimming and padding as needed.

    The sequence is ``[CLS] + src_ids``, truncated to ``max_seq_length - 1``
    entries, followed by ``[SEP]`` and then padded with ``pad_id`` up to
    ``max_seq_length``. Token types are 0 for every real token; the padded
    tail of the token types is filled with ``pad_id`` as well.

    Returns ``(enc_ids, tokentypes_enc, num_tokens_enc)`` where
    ``num_tokens_enc`` counts the tokens before padding.

    TODO: Design modular interface to reuse this function. This is getting
    repeated multiple times in different tasks
    """
    # [CLS] + A, capped so that one slot is always left for [SEP].
    enc_ids = ([cls_id] + list(src_ids))[:max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    num_tokens_enc = len(enc_ids)
    tokentypes_enc = [0] * num_tokens_enc

    # Padding.
    padding_length = max_seq_length - num_tokens_enc
    if padding_length > 0:
        filler = [pad_id] * padding_length
        enc_ids.extend(filler)
        tokentypes_enc.extend(filler)

    return enc_ids, tokentypes_enc, num_tokens_enc
def build_sample(token_ids, token_types, num_tokens, reference):
    """
    Pack one example into the dict consumed by the batch producer.

    Token ids and token types are converted to ``int64`` numpy arrays and
    the mask is obtained from ``make_attention_mask`` applied to the ids
    with themselves; ``num_tokens`` and ``reference`` are stored as given.
    """
    ids_arr = np.array(token_ids, dtype=np.int64)
    types_arr = np.array(token_types, dtype=np.int64)
    mask = make_attention_mask(ids_arr, ids_arr)

    return {
        'token_ids': ids_arr,
        'token_mask': mask,
        'token_types': types_arr,
        'seq_len': num_tokens,
        'reference': reference,
    }
class NQDataset(ABC, Dataset):
    """
    Open Retrieval Question Answering evaluation using Google NQ dataset.

    Each item is the dict produced by ``build_sample`` for one question:
    the tokenized, padded question plus its list of reference answers.
    """

    def __init__(self, task_name, dataset_name, datapath,
                 tokenizer, max_seq_length):
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(
            self.task_name, self.dataset_name))
        print_rank_0(datapath)
        self.samples = self.process_samples_from_single_path(datapath)
        print_rank_0(' >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        raw_sample = self.samples[idx]

        ques_tokens, tokentypes_enc, num_tokens_ques = \
            build_tokens_types_paddings_from_text(
                raw_sample['question'], self.tokenizer, self.max_seq_length)

        return build_sample(ques_tokens,
                            tokentypes_enc,
                            num_tokens_ques,
                            raw_sample['answers'])

    @staticmethod
    def process_samples_from_single_path(filename):
        """Read a tab-separated QA file into a list of sample dicts.

        Column 0 of every row is the question and column 1 is a Python
        literal (a list of answer strings).

        Returns:
            A list of ``{'question': str, 'answers': <parsed literal>}``.

        Raises:
            ValueError / SyntaxError: if column 1 is not a valid Python
                literal.
        """
        # Local import keeps this block self-contained.
        import ast

        print_rank_0(' > Processing {} ...'.format(filename))
        samples = []

        with open(filename, 'r') as ifile:
            reader = csv.reader(ifile, delimiter='\t')
            for total, row in enumerate(reader, start=1):
                question = row[0]
                # SECURITY: this column used to be parsed with eval(), which
                # executes arbitrary code found in the data file.
                # literal_eval accepts only Python literals, which is all a
                # list of answer strings needs.
                answers = ast.literal_eval(row[1])

                samples.append({'question': question, 'answers': answers})

                if total % 1000 == 0:
                    print_rank_0(' > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# The following code has been taken from
# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
# licensed as of now. More details on the license can be found
# at https://github.com/facebookresearch/DPR/blob/master/LICENSE
"""
Set of utilities for Q&A results validation tasks - Retriver passage
validation and Reader predicted answer validation
"""
import collections
import logging
import string
import unicodedata
from functools import partial
from multiprocessing import Pool as ProcessPool
from typing import Tuple, List, Dict
import regex as re
from tasks.orqa.natural_questions.tokenizers import SimpleTokenizer
logger = logging.getLogger(__name__)
QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits',\
'questions_doc_hits'])
def calculate_matches(all_docs: Dict[object, Tuple[str, str]],
                      answers: List[List[str]],
                      closest_docs: List[Tuple[List[object], List[float]]],
                      workers_num: int, match_type: str) -> QAMatchStats:
    """
    Evaluates answers presence in the set of documents. This function is
    supposed to be used with a large collection of documents and results.
    It internally forks multiple sub-processes for evaluation and then
    merges results

    :param all_docs: dictionary of the entire documents database.
        doc_id -> (doc_text, title)
    :param answers: list of answers's list. One list per question
    :param closest_docs: document ids of the top results along with their
        scores
    :param workers_num: number of worker processes used to process data
    :param match_type: type of answer matching. Refer to has_answer code for
        available options
    :return: matching information tuple.
        top_k_hits - a list where the index is the amount of top documents
        retrieved and the value is the total amount of valid matches across
        an entire dataset.
        questions_doc_hits - more detailed info with answer matches for every
        question and every retrieved document
    """
    # The documents are published through a module-level global BEFORE the
    # pool is created so that check_answer can read them inside the workers.
    # NOTE(review): this relies on workers inheriting module state (fork
    # start method) - confirm on platforms that spawn.
    global dpr_all_documents
    dpr_all_documents = all_docs

    tok_opts = {}
    tokenizer = SimpleTokenizer(**tok_opts)

    get_score_partial = partial(check_answer, match_type=match_type,
                                tokenizer=tokenizer)
    questions_answers_docs = zip(answers, closest_docs)

    # The context manager tears the pool down once map() has returned;
    # previously the worker processes were never closed or terminated.
    with ProcessPool(processes=workers_num) as processes:
        logger.info('Matching answers in top docs...')
        scores = processes.map(get_score_partial, questions_answers_docs)

    logger.info('Per question validation results len=%d', len(scores))

    # With no questions at all there is no retrieval depth to read.
    n_docs = len(closest_docs[0][0]) if closest_docs else 0
    top_k_hits = [0] * n_docs
    for question_hits in scores:
        # Rank of the first retrieved document that contains an answer.
        best_hit = next((i for i, x in enumerate(question_hits) if x), None)
        if best_hit is not None:
            # A hit at rank r counts for every k >= r.
            top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]

    return QAMatchStats(top_k_hits, scores)
def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]:
    """
    Report, for each retrieved document, whether it contains an answer.

    ``questions_answers_docs`` is ``(answers, (doc_ids, doc_scores))``.
    Documents are looked up in the module-level ``dpr_all_documents``; a
    document whose text is ``None`` counts as a miss.
    """
    global dpr_all_documents
    answers, (doc_ids, _doc_scores) = questions_answers_docs

    hits = []
    for doc_id in doc_ids:
        text = dpr_all_documents[doc_id][0]

        if text is None:  # cannot find the document for some reason
            logger.warning("no doc in db")
            hits.append(False)
        else:
            hits.append(bool(has_answer(answers, text, tokenizer,
                                        match_type)))
    return hits
def has_answer(answers, text, tokenizer, match_type) -> bool:
    """
    Tell whether ``text`` contains any of ``answers``.

    With ``match_type == 'string'`` the text and each answer are tokenized
    and lower-cased, and the answer must occur as a contiguous run of
    tokens. With ``match_type == 'regex'`` each answer is treated as a
    pattern searched in the whole text. Any other ``match_type`` yields
    ``False``.
    """
    text = _normalize(text)

    if match_type == 'string':
        # Answer is a list of possible strings
        doc_words = tokenizer.tokenize(text).words(uncased=True)

        for candidate in answers:
            answer_words = tokenizer.tokenize(
                _normalize(candidate)).words(uncased=True)
            width = len(answer_words)

            for start in range(0, len(doc_words) - width + 1):
                if answer_words == doc_words[start: start + width]:
                    return True
        return False

    if match_type == 'regex':
        # Answer is a regex
        for candidate in answers:
            if regex_match(text, _normalize(candidate)):
                return True

    return False
def regex_match(text, pattern):
    """Test if a regex pattern is contained within a text.

    The pattern is compiled case-insensitively with UNICODE and MULTILINE
    enabled. A pattern that cannot be compiled is treated as "no match"
    and ``False`` is returned.
    """
    try:
        compiled = re.compile(
            pattern,
            flags=re.IGNORECASE | re.UNICODE | re.MULTILINE,
        )
    except Exception:
        # Deliberately best-effort: answers come from data files and may not
        # be valid patterns. Exception (not BaseException) is caught so that
        # KeyboardInterrupt / SystemExit still propagate.
        return False
    return compiled.search(text) is not None
# Helpers for validating the answers predicted by the reader model.
def exact_match_score(prediction, ground_truth):
    """Return True when both strings are equal after normalization."""
    normalized_prediction = _normalize_answer(prediction)
    normalized_truth = _normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth


def _normalize_answer(s):
    """Lower-case ``s``, drop punctuation and articles, squeeze whitespace.

    The steps run in this order: lower-casing, removal of every character
    in ``string.punctuation``, replacement of the words a/an/the with a
    space, and finally collapsing of all whitespace runs to single spaces.
    """
    lowered = s.lower()

    punctuation = set(string.punctuation)
    without_punct = ''.join(ch for ch in lowered if ch not in punctuation)

    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct)

    return ' '.join(without_articles.split())
def _normalize(text):
    """Return ``text`` converted to Unicode normalization form NFD."""
    decomposed = unicodedata.normalize('NFD', text)
    return decomposed
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# The following code has been taken from
# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
# licensed as of now. More details on the license can be found
# at https://github.com/facebookresearch/DPR/blob/master/LICENSE
"""
Most of the tokenizers code here is copied from DrQA codebase to avoid adding extra dependency
"""
import copy
import logging
import regex
import spacy
logger = logging.getLogger(__name__)
class Tokens(object):
    """A tokenized text, stored as a list of per-token tuples.

    The class constants below are the positions of each field inside one
    token tuple.
    """
    TEXT = 0       # token text
    TEXT_WS = 1    # token text followed by its trailing whitespace
    SPAN = 2       # (start, end) character offsets
    POS = 3        # part-of-speech tag
    LEMMA = 4      # lemma
    NER = 5        # named-entity tag

    def __init__(self, data, annotators, opts=None):
        self.data = data
        self.annotators = annotators
        self.opts = opts or {}

    def __len__(self):
        """The number of tokens."""
        return len(self.data)

    def _column(self, field):
        """Collect one field of every token tuple into a list."""
        return [token[field] for token in self.data]

    def slice(self, i=None, j=None):
        """Return a shallow copy restricted to the tokens in [i, j)."""
        subset = copy.copy(self)
        subset.data = self.data[i: j]
        return subset

    def untokenize(self):
        """Returns the original text (with whitespace reinserted)."""
        return ''.join(self._column(self.TEXT_WS)).strip()

    def words(self, uncased=False):
        """Returns a list of the text of each token

        Args:
            uncased: lower cases text
        """
        texts = self._column(self.TEXT)
        if uncased:
            return [text.lower() for text in texts]
        return texts

    def offsets(self):
        """Returns a list of [start, end) character offsets of each token."""
        return self._column(self.SPAN)

    def pos(self):
        """Returns a list of part-of-speech tags of each token.
        Returns None if this annotation was not included.
        """
        if 'pos' in self.annotators:
            return self._column(self.POS)
        return None

    def lemmas(self):
        """Returns a list of the lemmatized text of each token.
        Returns None if this annotation was not included.
        """
        if 'lemma' in self.annotators:
            return self._column(self.LEMMA)
        return None

    def entities(self):
        """Returns a list of named-entity-recognition tags of each token.
        Returns None if this annotation was not included.
        """
        if 'ner' in self.annotators:
            return self._column(self.NER)
        return None

    def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
        """Returns a list of all ngrams from length 1 to n.

        Args:
            n: upper limit of ngram length
            uncased: lower cases text
            filter_fn: user function that takes in an ngram list and returns
              True or False to keep or not keep the ngram
            as_strings: return the ngram as a string vs list
        """
        words = self.words(uncased)
        total = len(words)

        spans = []
        for start in range(total):
            for last in range(start, min(start + n, total)):
                # An ngram is dropped when filter_fn returns a truthy value.
                if filter_fn and filter_fn(words[start:last + 1]):
                    continue
                spans.append((start, last + 1))

        # Concatenate into strings
        if as_strings:
            return [' '.join(words[s:e]) for (s, e) in spans]
        return spans

    def entity_groups(self):
        """Group consecutive entity tokens with the same NER tag."""
        entities = self.entities()
        if not entities:
            return None

        non_ent = self.opts.get('non_ent', 'O')
        count = len(entities)
        groups = []
        idx = 0
        while idx < count:
            ner_tag = entities[idx]
            if ner_tag == non_ent:
                idx += 1
                continue

            # Chomp the whole run of tokens sharing this tag.
            start = idx
            while idx < count and entities[idx] == ner_tag:
                idx += 1
            groups.append((self.slice(start, idx).untokenize(), ner_tag))
        return groups
class Tokenizer(object):
    """Abstract tokenizer.

    Subclasses override ``tokenize`` and should return a ``Tokens``
    instance. ``shutdown`` is a no-op hook that is invoked when the object
    is garbage collected.
    """

    def tokenize(self, text):
        """Tokenize ``text``; must be provided by subclasses."""
        raise NotImplementedError()

    def shutdown(self):
        """Release any resources held by the tokenizer (none by default)."""
        return None

    def __del__(self):
        self.shutdown()
class SimpleTokenizer(Tokenizer):
    """Regex-based tokenizer that produces text/whitespace/span tuples only.

    A token is either a run of letters, numbers and marks, or any single
    character that is neither a separator nor a control character.
    """
    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
    NON_WS = r'[^\p{Z}\p{C}]'

    def __init__(self, **kwargs):
        """
        Args:
            annotators: None or empty set (only tokenizes).
        """
        pattern = '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS)
        self._regexp = regex.compile(
            pattern,
            flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
        )

        # Annotators are not supported; warn if any were requested.
        requested = kwargs.get('annotators', {})
        if len(requested) > 0:
            logger.warning('%s only tokenizes! Skipping annotators: %s' %
                           (type(self).__name__, kwargs.get('annotators')))
        self.annotators = set()

    def tokenize(self, text):
        """Split ``text`` into tokens and wrap them in a ``Tokens`` object."""
        matches = list(self._regexp.finditer(text))
        last = len(matches) - 1

        data = []
        for idx, match in enumerate(matches):
            span = match.span()

            # The whitespace-inclusive text runs up to the start of the next
            # token, or to the end of this token for the final one.
            if idx < last:
                end_ws = matches[idx + 1].span()[0]
            else:
                end_ws = span[1]

            data.append((
                match.group(),
                text[span[0]: end_ws],
                span,
            ))
        return Tokens(data, self.annotators)
class SpacyTokenizer(Tokenizer):
    """Tokenizer backed by a spaCy pipeline, with optional annotations."""

    def __init__(self, **kwargs):
        """
        Args:
            annotators: set that can include pos, lemma, and ner.
            model: spaCy model to use (either path, or keyword like 'en').
        """
        model = kwargs.get('model', 'en')
        self.annotators = copy.deepcopy(kwargs.get('annotators', set()))

        # The parser is always switched off; the tagger and the entity
        # recognizer are switched off unless an annotator asks for them.
        wants_tags = any(
            name in self.annotators for name in ('lemma', 'pos', 'ner'))
        nlp_kwargs = {'parser': False}
        if not wants_tags:
            nlp_kwargs['tagger'] = False
        if 'ner' not in self.annotators:
            nlp_kwargs['entity'] = False
        self.nlp = spacy.load(model, **nlp_kwargs)

    def tokenize(self, text):
        """Tokenize ``text`` and return a ``Tokens`` with six-field tuples."""
        # We don't treat new lines as tokens.
        clean_text = text.replace('\n', ' ')
        tokens = self.nlp.tokenizer(clean_text)

        if any(name in self.annotators for name in ('lemma', 'pos', 'ner')):
            self.nlp.tagger(tokens)
        if 'ner' in self.annotators:
            self.nlp.entity(tokens)

        count = len(tokens)
        data = []
        for idx in range(count):
            token = tokens[idx]
            start = token.idx
            end = start + len(token.text)

            # Get whitespace: up to the next token's start, or this token's
            # end for the last token.
            end_ws = tokens[idx + 1].idx if idx + 1 < count else end

            data.append((
                token.text,
                # Sliced from the original text, not clean_text; the
                # newline replacement is one-for-one so offsets line up.
                text[start: end_ws],
                (start, end),
                token.tag_,
                token.lemma_,
                token.ent_type_,
            ))

        # Set special option for non-entity tag: '' vs 'O' in spaCy
        return Tokens(data, self.annotators, opts={'non_ent': ''})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment