Commit b9fcb7b4 authored by mpatwary's avatar mpatwary
Browse files

adding dpr code

parent 957d1c9a
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ORQA dataset."""
import json
import random
from abc import ABC
from abc import abstractmethod
import numpy as np
from torch.utils.data import Dataset
from megatron import print_rank_0, get_args
from megatron.data.biencoder_dataset_utils import make_attention_mask
from megatron.data.biencoder_dataset_utils import make_history_mask
def build_token_types_from_context_list(ctx_list, tokenizer, max_seq_length):
ctx_id_list, ctx_types_list = [], []
for context in ctx_list:
title_ids = tokenizer.tokenize(context['title'])
ctx_ids = tokenizer.tokenize(context['text'])
ctx_ids = title_ids + [tokenizer.sep_id] + ctx_ids
ctx_ids, ctx_types, _ = build_tokens_types_paddings_from_ids(ctx_ids,
max_seq_length, tokenizer.cls,
tokenizer.sep, tokenizer.pad)
ctx_id_list.append(ctx_ids)
ctx_types_list.append(ctx_types)
return ctx_id_list, ctx_types_list
def build_tokens_types_paddings_from_text(query, context,
tokenizer, max_seq_length):
"""Build token types and paddings, trim if needed, and pad if needed."""
query_ids = tokenizer.tokenize(query)
query_ids, query_types, query_pad_mask = \
build_tokens_types_paddings_from_ids(query_ids, max_seq_length, \
tokenizer.cls, tokenizer.sep, tokenizer.pad)
# Appending the title of the context at front
extended_ctx_ids = None
if context is not None:
title_ids = tokenizer.tokenize(context['title'])
ctx_ids = tokenizer.tokenize(context['text'])
extended_ctx_ids = title_ids + [tokenizer.sep] + ctx_ids
ctx_ids, ctx_types, ctx_pad_mask = \
build_tokens_types_paddings_from_ids(extended_ctx_ids,
max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)
return query_ids, query_types, query_pad_mask, \
ctx_ids, ctx_types, ctx_pad_mask
# Similar code tasks/data_utils with some changes
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
cls_id, sep_id, pad_id):
"""Build token types and paddings, trim if needed, and pad if needed."""
enc_ids = []
tokentypes_enc = []
# [CLS].
enc_ids.append(cls_id)
tokentypes_enc.append(0)
# A.
len_src = len(text_ids)
enc_ids.extend(text_ids)
tokentypes_enc.extend([0] * len_src)
# Cap the size.
if len(enc_ids) > max_seq_length - 1:
enc_ids = enc_ids[0: max_seq_length - 1]
tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]
# [SEP].
enc_ids.append(sep_id)
tokentypes_enc.append(0)
num_tokens_enc = len(enc_ids)
# Padding.
padding_length = max_seq_length - len(enc_ids)
if padding_length > 0:
enc_ids.extend([pad_id] * padding_length)
tokentypes_enc.extend([pad_id] * padding_length)
pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
pad_mask = np.array(pad_mask, dtype=np.int64)
return enc_ids, tokentypes_enc, pad_mask
def build_sample(query_ids, query_types, query_pad_mask,
ctx_ids, ctx_types, ctx_pad_mask, answers,
neg_ctx_id_list=None, neg_ctx_types_list=None,
include_neg=False):
"""Convert to numpy and return a sample consumed by the batch producer."""
query_ids = np.array(query_ids, dtype=np.int64)
query_types = np.array(query_types, dtype=np.int64)
query_mask = make_attention_mask(query_ids, query_ids)
ctx_ids = np.array(ctx_ids, dtype=np.int64)
ctx_types = np.array(ctx_types, dtype=np.int64)
ctx_mask = make_attention_mask(ctx_ids, ctx_ids)
sample = ({
'query': query_ids,
'query_mask': query_mask,
'query_types': query_types,
'query_pad_mask': query_pad_mask,
'context': ctx_ids,
'context_mask': ctx_mask,
'context_types': ctx_types,
'context_pad_mask': ctx_pad_mask,
'reference': answers
})
if include_neg:
neg_ctx_ids = np.array(neg_ctx_id_list, dtype=np.int64)
neg_ctx_id_types = np.array(neg_ctx_types_list, dtype=np.int64)
neg_ctx_mask = np.array([make_attention_mask(ids, ids) \
for ids in neg_ctx_ids], dtype=np.int64)
sample['neg_context'] = neg_ctx_ids
sample['neg_context_types'] = neg_ctx_id_types
sample['neg_context_mask'] = neg_ctx_mask
return sample
class OpenRetrievalAbstractDataset(ABC, Dataset):
"""Open Retrieval base dataset class."""
def __init__(self, task_name, dataset_name, datapaths, tokenizer, \
max_seq_length, evaluate=False):
# Store inputs.
args = get_args()
self.evaluate = evaluate
self.val_av_rank_hard_neg = args.val_av_rank_hard_neg
self.val_av_rank_other_neg = args.val_av_rank_other_neg
self.train_with_neg = args.train_with_neg
self.train_hard_neg = args.train_hard_neg
self.task_name = task_name
self.dataset_name = dataset_name
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
self.dataset_name))
# Process the files.
string = ' > paths:'
for path in datapaths:
string += ' ' + path
print_rank_0(string)
self.samples = []
for datapath in datapaths:
self.samples.extend(self.process_samples_from_single_path(datapath))
args = get_args()
if args.sample_rate < 1: # subsample
k = int(len(self.samples) * args.sample_rate)
self.samples = random.sample(self.samples, k)
print_rank_0(' >> total number of samples: {}'.format(
len(self.samples)))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
raw_sample = self.samples[idx]
query_ids, query_types, query_pad_mask, ctx_ids, ctx_types, \
ctx_pad_mask = build_tokens_types_paddings_from_text( \
raw_sample['question'], raw_sample['pos_context'], \
self.tokenizer, self.max_seq_length)
if self.evaluate:
neg_ctx_list = \
raw_sample['negative_context'][:self.val_av_rank_other_neg] + \
raw_sample['hard_negative_context'][:self.val_av_rank_hard_neg]
neg_ctx_id_list, neg_ctx_types_list = \
build_token_types_from_context_list(neg_ctx_list, \
self.tokenizer, self.max_seq_length)
elif self.train_with_neg:
hard_negative_ctx = raw_sample['hard_negative_context']
negative_ctx = raw_sample['negative_context']
if True: # TODO: fix this or remove this condition
random.shuffle(hard_negative_ctx)
random.shuffle(negative_ctx)
neg_ctx_list = hard_negative_ctx[:self.train_hard_neg]
# In the Google NQ dataset by DPR paper, there are around more than
# 50 missing hard negatives in training data.
# In those cases, substitute hard negatives by simple negatives.
if len(neg_ctx_list) < self.train_hard_neg:
neg_ctx_list += negative_ctx[:self.train_hard_neg - \
len(neg_ctx_list)]
neg_ctx_id_list, neg_ctx_types_list = \
build_token_types_from_context_list(neg_ctx_list,
self.tokenizer, self.max_seq_length)
else:
neg_ctx_id_list = None
neg_ctx_types_list = None
sample = build_sample(query_ids, query_types, query_pad_mask,
ctx_ids, ctx_types, ctx_pad_mask,
raw_sample['answers'],
neg_ctx_id_list, neg_ctx_types_list,
include_neg=self.evaluate or self.train_with_neg)
return sample
@staticmethod
@abstractmethod
def process_samples_from_single_path(filename):
"""Abstract method that takes a filename and
returns a list of dataset samples, each sample being a dict of
{'text': string, 'text': string}
"""
pass
def normalize_question(question):
if question[-1] == '?':
question = question[:-1]
return question
class NQSupervisedDataset(OpenRetrievalAbstractDataset):
def __init__(self, name, datapaths, tokenizer, max_seq_length, \
evaluate=False):
super().__init__('natural_questions_ret',
name,
datapaths,
tokenizer,
max_seq_length,
evaluate=evaluate)
@staticmethod
def process_samples_from_single_path(filename):
""""Implement abstract method."""
print_rank_0(' > Processing {} ...'.format(filename))
samples = []
total = 0
with open(filename, 'r', encoding="utf-8") as f:
data = json.load(f)
for row in data:
question = normalize_question(row['question'])
pos_context = row['positive_ctxs'][0]
# Hard Negative Contexts
if len(row['hard_negative_ctxs']) > 0:
hard_neg_context = row['hard_negative_ctxs']
else:
hard_neg_context = []
# Negative Contexts
if len(row['negative_ctxs']) > 0:
neg_context = row['negative_ctxs']
else:
neg_context = []
answers = row['answers']
sample = {'question': question,
'pos_context': pos_context,
'hard_negative_context': hard_neg_context,
'negative_context': neg_context,
'answers': answers}
total += 1
samples.append(sample)
if total % 5000 == 0:
print_rank_0(' > processed {} so far ...'.format(total))
print_rank_0(' >> processed {} samples.'.format(len(samples)))
return samples
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation utilities."""
from collections import OrderedDict
import math
import numpy as np
import time
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.utils import average_losses_across_data_parallel_group
from tasks.finetune_utils import build_data_loader
def task_collate_fn(batch_data):
# generate batch
batch_size = len(batch_data)
tensorized = OrderedDict()
for d in batch_data:
for k, v in d.items():
tensorized.setdefault(k, []).append(v)
# assert len(tensorized) == 12
tensorized['query'] = torch.LongTensor(tensorized['query'])
tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask'])
tensorized['query_types'] = torch.LongTensor(tensorized['query_types'])
tensorized['query_pad_mask'] = \
torch.LongTensor(tensorized['query_pad_mask'])
tensorized['context'] = torch.LongTensor(tensorized['context'])
tensorized['context_mask'] = \
torch.LongTensor(tensorized['context_mask'])
tensorized['context_types'] = \
torch.LongTensor(tensorized['context_types'])
tensorized['context_pad_mask'] = \
torch.LongTensor(tensorized['context_pad_mask'])
if 'neg_context' in tensorized:
tensorized['neg_context'] = \
torch.LongTensor(np.concatenate(tensorized['neg_context']))
tensorized['neg_context_mask'] = \
torch.LongTensor(np.concatenate(tensorized['neg_context_mask']))
tensorized['neg_context_types'] = \
torch.LongTensor(np.concatenate(tensorized['neg_context_types']))
return tensorized
def process_batch(batch):
"""Process batch and produce inputs for the model."""
query_tokens = batch['query'].long().cuda()
query_mask = (batch['query_mask'] < 0.5).cuda()
query_types = batch['query_types'].long().cuda()
query_pad_mask = batch['query_pad_mask'].long().cuda()
context_tokens = batch['context'].long().cuda()
context_mask = (batch['context_mask'] < 0.5).cuda()
context_types = batch['context_types'].long().cuda()
context_pad_mask = batch['context_pad_mask'].long().cuda()
if 'neg_context' in batch:
neg_context_tokens = batch['neg_context'].long().cuda()
neg_context_mask = (batch['neg_context_mask'] < 0.5).cuda()
neg_context_types = batch['neg_context_types'].long().cuda()
else:
neg_context_tokens = None
neg_context_mask = None
neg_context_types = None
reference = batch['reference']
return query_tokens, query_mask, query_types, query_pad_mask, \
context_tokens, context_mask, context_types, context_pad_mask, \
neg_context_tokens, neg_context_mask, neg_context_types, reference
def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
#, datapath,
# rank0sampler=False):
"""Provide function that calculates accuracies."""
args = get_args()
print_rank_0("accuracy_func_provider is CALLED")
# Build dataloaders
datapath = args.valid_data
dataset = single_dataset_provider(datapath)
drop_last = False
if mpu.get_data_parallel_world_size() > 1 and not rank0sampler:
drop_last = True
print_rank_0(datapath)
print_rank_0(rank0sampler)
dataloader = build_data_loader(dataset,
args.eval_micro_batch_size,
num_workers=args.num_workers,
drop_last=drop_last,
task_collate_fn=task_collate_fn)
#shuffle=False,
#rank0sampler=rank0sampler)
dataloaders = (dataset.dataset_name, dataloader)
def metrics_func(model, epoch, output_predictions=False):
print_rank_0('calculating metrics by accuracy func in ORQA...')
if output_predictions:
assert rank0sampler
names = 'predictions'
name, dataloader = dataloaders
if args.task == "RET-FINETUNE-NQ":
start_time = time.time()
output = retrieval_loss(model, dataloader)
stats_dict, total = output
format_string = ""
for k, v in stats_dict.items():
format_string += "|{} = {:.2f}".format(k, v / total)
print_rank_0("epoch:{}{}".format(epoch, format_string))
print_rank_0("taken time to calcuate metrics {:.3f}".format(\
time.time() - start_time))
else:
raise AssertionError("{} Task not supported".format(args.task))
return metrics_func
def retrieval_loss(model, dataloader):
args = get_args()
total = 0
topk_stats_dict = {'top{}_acc'.format(k): 0 for k in \
args.retriever_report_topk_accuracies}
stats_dict = dict(rank=0, **topk_stats_dict)
assert len(model) == 1
unwrapped_model = model[0]
unwrapped_model.eval()
with torch.no_grad():
# For all the batches in the dataset.
for batch in dataloader:
# Run the model forward.
query_tokens, query_mask, query_types, _, \
context_tokens, context_mask, context_types, _, \
neg_context_tokens, neg_context_mask, neg_context_types, \
reference = process_batch(batch)
query_logits, context_logits = unwrapped_model(query_tokens,
query_mask, query_types,
torch.cat([context_tokens, neg_context_tokens]),
torch.cat([context_mask, neg_context_mask]),
torch.cat([context_types, neg_context_types]))
retrieval_scores = torch.matmul(query_logits,
torch.transpose(context_logits, 0, 1))
if args.retriever_score_scaling:
retrieval_scores = retrieval_scores / \
math.sqrt(args.hidden_size)
local_batch_size = query_logits.shape[0]
labels = torch.arange(local_batch_size).long().cuda()
softmax_scores = F.softmax(retrieval_scores, dim=1)
sorted_vals, sorted_indices = torch.topk(softmax_scores,
k=softmax_scores.shape[1],
sorted=True)
def topk_accuracy(k):
return torch.cuda.FloatTensor(
[sum([int(labels[i] in sorted_indices[i, :k]) for i in \
range(local_batch_size)])])
def get_rank():
return torch.cuda.FloatTensor(
[sum([torch.nonzero(labels[i] == sorted_indices[i])[0][0] \
for i in range(local_batch_size)])])
topk_accs = [topk_accuracy(k) for k in \
args.retriever_report_topk_accuracies]
rank = get_rank()
losses = average_losses_across_data_parallel_group([rank, \
*topk_accs])
# create stats_dict with retrieval loss and all specified
# top-k accuracies
topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
zip(args.retriever_report_topk_accuracies, losses[1:])}
temp_stats_dict = dict(rank=losses[0], **topk_acc_dict)
for k in stats_dict.keys():
stats_dict[k] += temp_stats_dict[k]
total += local_batch_size
unwrapped_model.train()
return stats_dict, total
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ORQA finetuning/evaluation."""
from functools import partial
import math
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron import print_rank_0
from megatron.utils import average_losses_across_data_parallel_group
from megatron.model.biencoder_model import biencoder_model_provider
#from tasks.t5_model_utils.finetune_utils_open_retrieval import accuracy_func_provider
#from tasks.t5_model_utils.finetune_utils_open_retrieval import finetune
from pretrain_ict import get_group_world_size_rank
from tasks.finetune_utils import finetune
from tasks.orqa.supervised.eval_utils import accuracy_func_provider
from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
def orqa(Dataset): # , name_from_datapath_func):
def cross_entropy_forward_step(batch, model):
"""Simple forward step with cross-entropy loss."""
args = get_args()
timers = get_timers()
tokenizer = get_tokenizer()
# Get the batch.
timers('batch generator').start()
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
query_tokens, query_mask, query_types, query_pad_mask, \
context_tokens, context_mask, context_types, context_pad_mask, \
neg_context_tokens, neg_context_mask, neg_context_types, \
reference = process_batch(batch_)
timers('batch generator').stop()
local_batch_size = query_tokens.shape[0]
# Text representation of query and context
query_list, context_list = [], []
for i in range(local_batch_size):
query_list.append(tokenizer.decode(query_tokens[i].tolist()))
context_list.append(tokenizer.decode(context_tokens[i].tolist()))
if neg_context_tokens is not None:
context_tokens = torch.cat([context_tokens, neg_context_tokens])
context_mask = torch.cat([context_mask, neg_context_mask])
context_types = torch.cat([context_types, neg_context_types])
# Forward model.
#query_logits, context_logits = model(query_tokens, query_mask,
output_tensor = model(query_tokens, query_mask,
query_types, context_tokens,
context_mask, context_types)
return output_tensor, partial(cross_entropy_loss_func_, query_tokens, context_tokens)
#def cross_entropy_loss_func(labels, output_tensor):
def cross_entropy_loss_func_(query_tokens, context_tokens, output_tensor):
args = get_args()
local_batch_size = query_tokens.shape[0]
group, rank, world_size = get_group_world_size_rank()
# recall we assert that model_parallel_size == 1
global_batch_size = world_size * local_batch_size
query_logits, context_logits = output_tensor
if world_size > 1:
input_ = torch.empty_like(context_logits).copy_(\
context_logits).detach_()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank].copy_(input_)
torch.distributed.all_gather(tensor_list, input_, group=group)
# Check if all-gather happens in order
assert tensor_list[rank].sum().item() == \
context_logits.sum().item()
# Preserves the gradient
tensor_list[rank] = context_logits
all_context_logits = torch.cat(tensor_list, dim=0).contiguous()
# Query tensors
input_ = torch.empty_like(query_logits).copy_(\
query_logits).detach_()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank].copy_(input_)
torch.distributed.all_gather(tensor_list, input_, group=group)
# Check if all-gather happens in order
assert tensor_list[rank].sum().item() == query_logits.sum().item()
# Preserves the gradient
tensor_list[rank] = query_logits
all_query_logits = torch.cat(tensor_list, dim=0).contiguous()
else:
all_query_logits = query_logits
all_context_logits = context_logits
retrieval_scores = torch.matmul(all_query_logits,
torch.transpose(all_context_logits, 0, 1))
# Scaling the retrieval scores
if args.retriever_score_scaling:
retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)
if args.train_with_neg:
# if the world size is 3, local batch size is 4, and
# local context size is 8, what we want is
# labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]
labels = []
local_context_size = context_tokens.shape[0]
for i in range(world_size):
j = i * local_context_size
labels.extend(list(range(j, j + local_batch_size)))
labels = torch.LongTensor(labels).cuda()
assert len(labels) == global_batch_size
else:
labels = torch.arange(global_batch_size).long().cuda()
# Cross-entropy loss.
softmax_scores = F.log_softmax(retrieval_scores, dim=1)
loss = F.nll_loss(softmax_scores, labels, reduction='mean')
max_score, max_idxs = torch.max(softmax_scores, 1)
correct_predictions_count = (max_idxs == labels).sum().float()
# Reduce loss for logging.
reduced_loss = average_losses_across_data_parallel_group([loss, \
correct_predictions_count])
# Loss scaling for correct losses in Supervised Retrieval
loss = loss * mpu.get_data_parallel_world_size()
return loss, {'lm loss': reduced_loss[0],
'correct_prediction_count': reduced_loss[1]}
def train_valid_datasets_provider():
"""Build train and validation dataset."""
args = get_args()
tokenizer = get_tokenizer()
train_dataset = Dataset('training',
args.train_data,
tokenizer,
args.retriever_seq_length,
evaluate=False)
valid_dataset = Dataset('validation',
args.valid_data,
tokenizer,
args.retriever_seq_length,
evaluate=True)
return train_dataset, valid_dataset
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
print_rank_0('building retriever model for {} ...'.format(args.task))
model = biencoder_model_provider(only_context_model=False,
only_query_model=False,
biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model,
pre_process=pre_process, post_process=post_process)
return model
def single_dataset_provider(datapath):
args = get_args()
tokenizer = get_tokenizer()
#name = name_from_datapath_func(datapath)
name = datapath[0].split('/')[-1].split('.')[0]
return Dataset(name,
datapath,
tokenizer,
args.retriever_seq_length,
evaluate=True)
#def distributed_metrics_func_provider():
def metrics_func_provider():
"""Provide metrics callback function."""
#def name_from_datapath(datapath):
# return datapath[0].split('/')[-1].split('.')[0]
return accuracy_func_provider(single_dataset_provider)
#def rank0_metrics_func_provider(datapath):
# """Provide metrics callback function."""
# return accuracy_func_provider(single_dataset_provider, datapath,
# rank0sampler=True)
"""Finetune/evaluate."""
finetune(train_valid_datasets_provider,
model_provider,
forward_step=cross_entropy_forward_step,
end_of_epoch_callback_provider=metrics_func_provider,
task_collate_fn=task_collate_fn)
#,end_of_training_callback_provider=rank0_metrics_func_provider)
def main():
args = get_args()
if args.task == 'RET-FINETUNE-NQ':
from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
#def name_from_datapath(datapath):
# return datapath[0].split('/')[-1].split('.')[0]
else:
raise NotImplementedError('ORQA task {} is not implemented.'.format(
args.task))
orqa(Dataset) #, name_from_datapath)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment