Commit 598d7ee2 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main_retriver_merge_dpr' into 'main'

DPR evaluation hangs and Readme

See merge request ADLR/megatron-lm!280
parents 83c4d95a 98113c69
......@@ -29,7 +29,6 @@ python tasks/main.py \
--retriever-seq-length 256 \
--vocab-file bert-vocab.txt\
--qa-data-test ${QA_FILE} \
--num-workers 2 \
--faiss-use-gpu \
--retriever-report-topk-accuracies 1 5 20 100 \
--fp16 \
......
......@@ -36,7 +36,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--bert-load ${BERT_LOAD_PATH} \
--save-interval 5000 \
--log-interval 10 \
--eval-interval 25000 \
--eval-interval 20000 \
--eval-iters 100 \
--indexer-log-interval 1000 \
--faiss-use-gpu \
......
## End-to-End Training of Neural Retrievers for Open-Domain Question Answering
Below we present the steps to run unsupervised and supervised trainining and evaluation of the retriever for [open domain question answering](https://arxiv.org/abs/2101.00408).
### Unsupervised pretraining
1. Use `tools/preprocess_data.py` to preprocess the dataset for Inverse Cloze Task (ICT) task, which we call unsupervised pretraining. This script takes as input a corpus in loose JSON format and creates fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block and multiple blocks per document. Run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document and another with the body.
<pre>
python tools/preprocess_data.py \
--input /path/to/corpus.json \
--json-keys text title \
--split-sentences \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file /path/to/vocab.txt \
--output-prefix corpus_indexed \
--workers 10
</pre>
2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs a single GPU 217M parameter biencoder model for ICT retriever training. Single GPU training is primarily intended for debugging purposes, as the code is developed for distributed training. The script uses a pretrained BERT model with a batch size of 4096 (hence the need for a data parallel world size of 32).
3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for natural question answering dataset.
### Supervised finetuning
1. Use the above pretrained ICT model to finetune using [Google's natural question answering dataset](https://ai.google.com/research/NaturalQuestions/). The script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) provides an example for how to do this. Our finetuning consists of score scaling, longer training (80 epochs), and hard negative examples.
2. Evaluate the finetuned model using the same evaluation script as mentioned above for the unsupervised model.
More details on the retriever are available in [our paper](https://arxiv.org/abs/2101.00408).
The reader component will be available soon.
......@@ -33,6 +33,28 @@ from tasks.orqa.supervised.eval_utils import accuracy_func_provider
from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
from tasks.orqa.evaluate_utils import ORQAEvaluator
# input_ is a 2D tensor
def check_and_append_tensor_for_gather(group, rank, world_size, input_):
# gather the size of the first dimension of the tensor from all ranks
current_length = input_.size()[0]
first_dim = torch.tensor([[current_length]],
device=torch.cuda.current_device())
input_list = [torch.empty_like(first_dim) for _ in range(world_size)]
input_list[rank].copy_(first_dim)
torch.distributed.all_gather(input_list, first_dim, group=group)
all_input_list = torch.cat(input_list, dim=0).contiguous()
max_length = torch.max(all_input_list)
# if the size are different than the max, extend the tensor
# accordingly
if max_length > current_length:
padding=tuple([0] * (input_.dim() * 2 - 1)) + \
tuple([max_length - current_length])
input_ = F.pad(input=input_, pad=padding)
return input_
def orqa(Dataset):
def cross_entropy_forward_step(batch, model):
......@@ -47,6 +69,8 @@ def orqa(Dataset):
except BaseException:
batch_ = batch
group, rank, world_size = get_group_world_size_rank()
query_tokens, query_mask, query_types, query_pad_mask, \
context_tokens, context_mask, context_types, context_pad_mask, \
neg_context_tokens, neg_context_mask, neg_context_types, \
......@@ -61,6 +85,14 @@ def orqa(Dataset):
query_list.append(tokenizer.decode(query_tokens[i].tolist()))
context_list.append(tokenizer.decode(context_tokens[i].tolist()))
if neg_context_tokens is not None:
neg_context_tokens = check_and_append_tensor_for_gather(group,
rank, world_size, neg_context_tokens)
neg_context_mask = check_and_append_tensor_for_gather(group,
rank, world_size, neg_context_mask)
neg_context_types = check_and_append_tensor_for_gather(group,
rank, world_size, neg_context_types)
if neg_context_tokens is not None:
context_tokens = torch.cat([context_tokens, neg_context_tokens])
context_mask = torch.cat([context_mask, neg_context_mask])
......@@ -70,7 +102,6 @@ def orqa(Dataset):
output_tensor = model(query_tokens, query_mask,
query_types, context_tokens,
context_mask, context_types)
return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment