Commit 83c4d95a authored by Jared Casper's avatar Jared Casper

Merge branch 'main_retriver_merge_dpr' into 'main'

Finetuning retriever (ICT+DPR)

See merge request ADLR/megatron-lm!277
parents 01fc0833 fda81a21
#!/bin/bash
# Compute embeddings for each entry of a given dataset (e.g. Wikipedia)
RANK=0
WORLD_SIZE=1
# Wikipedia data can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path to store embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>
python tools/create_doc_index.py \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--tensor-model-parallel-size 1 \
--micro-batch-size 128 \
--checkpoint-activations \
--seq-length 512 \
--retriever-seq-length 256 \
--max-position-embeddings 512 \
--load ${CHECKPOINT_PATH} \
--evidence-data-path ${EVIDENCE_DATA_DIR} \
--embedding-path ${EMBEDDING_PATH} \
--indexer-log-interval 1000 \
--indexer-batch-size 128 \
--vocab-file bert-vocab.txt \
--num-workers 2 \
--fp16
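# Note: the embeddings written to EMBEDDING_PATH above are reused by the
# evaluation script below (its EMBEDDING_PATH variable).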
#!/bin/bash
# Evaluate Natural Questions test data given Wikipedia embeddings and a pretrained
# ICT model
# ICT model or a finetuned model for the Natural Questions task
# Datasets can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path of the embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>
CHECKPOINT_PATH=<Specify path of pretrained ICT model or finetuned model>
QA_FILE=<Path of the natural question test dataset>
QA_FILE=<Path of the natural question dev or test dataset>
python tasks/main.py \
--task ICT-ZEROSHOT-NQ \
--task RETRIEVER-EVAL \
--tokenizer-type BertWordPieceLowerCase \
--num-layers 12 \
--hidden-size 768 \
......@@ -32,5 +32,8 @@ python tasks/main.py \
--num-workers 2 \
--faiss-use-gpu \
--retriever-report-topk-accuracies 1 5 20 100 \
--fp16
--fp16 \
--indexer-log-interval 1000 \
--indexer-batch-size 128
#!/bin/bash
# Finetune a pretrained BERT or ICT model using the Google Natural Questions data
# Datasets can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
WORLD_SIZE=8
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000"
CHECKPOINT_PATH=<Specify path for the finetuned retriever model>
# Load either of the below
BERT_LOAD_PATH=<Path of BERT pretrained model>
PRETRAINED_CHECKPOINT=<Path of Pretrained ICT model>
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--task RET-FINETUNE-NQ \
--train-with-neg \
--train-hard-neg 1 \
--pretrained-checkpoint ${PRETRAINED_CHECKPOINT} \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--tensor-model-parallel-size 1 \
--tokenizer-type BertWordPieceLowerCase \
--train-data nq-train.json \
--valid-data nq-dev.json \
--save ${CHECKPOINT_PATH} \
--load ${CHECKPOINT_PATH} \
--vocab-file bert-vocab.txt \
--bert-load ${BERT_LOAD_PATH} \
--save-interval 5000 \
--log-interval 10 \
--eval-interval 25000 \
--eval-iters 100 \
--indexer-log-interval 1000 \
--faiss-use-gpu \
--DDP-impl torch \
--fp16 \
--retriever-report-topk-accuracies 1 5 10 20 100 \
--seq-length 512 \
--retriever-seq-length 256 \
--max-position-embeddings 512 \
--retriever-score-scaling \
--epochs 80 \
--micro-batch-size 8 \
--eval-micro-batch-size 16 \
--indexer-batch-size 128 \
--lr 2e-5 \
--lr-warmup-fraction 0.01 \
--weight-decay 1e-1
import sys
import time
import torch
import torch.distributed as dist
from megatron import get_args
from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, OpenRetreivalDataStore
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model
......@@ -29,7 +30,6 @@ class IndexBuilder(object):
# need to know whether we're using a REALM checkpoint (args.load)
# or ICT checkpoint
assert not (args.load and args.ict_load)
#self.using_realm_chkpt = args.ict_load is None
self.log_interval = args.indexer_log_interval
self.batch_size = args.indexer_batch_size
......@@ -47,8 +47,8 @@ class IndexBuilder(object):
if self.biencoder_shared_query_context_model:
only_context_model = False
model = get_model(lambda: biencoder_model_provider(only_context_model \
= only_context_model, biencoder_shared_query_context_model = \
model = get_model(get_model_provider(only_context_model=\
only_context_model, biencoder_shared_query_context_model=\
self.biencoder_shared_query_context_model))
self.model = load_biencoder_checkpoint(model,
......@@ -85,6 +85,7 @@ class IndexBuilder(object):
"""
assert len(self.model) == 1
unwrapped_model = self.model[0]
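# Unwrap wrapper modules (e.g. DDP / fp16 wrappers) until we reach the
# biencoder that implements embed_text.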
while not hasattr(unwrapped_model, 'embed_text'):
unwrapped_model = unwrapped_model.module
......@@ -103,6 +104,7 @@ class IndexBuilder(object):
context_logits = unwrapped_model.embed_text(
unwrapped_model.context_model, context_tokens, context_mask,
context_types)
context_logits = detach(context_logits)
row_id = detach(row_id)
......
......@@ -87,7 +87,7 @@ class AnnealingLR(object):
else:
raise Exception('{} decay style is not supported.'.format(
self.decay_style))
return self.min_lr + coeff * delta_lr
......
......@@ -15,11 +15,30 @@ from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
def get_model_provider(only_query_model=False, only_context_model=False,
biencoder_shared_query_context_model=False):
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building Biencoder model ...')
model = biencoder_model_provider(only_query_model=only_query_model,
only_context_model = only_context_model,
biencoder_shared_query_context_model = \
biencoder_shared_query_context_model,
pre_process=pre_process, post_process=post_process)
return model
return model_provider
def biencoder_model_provider(only_query_model=False,
only_context_model=False,
biencoder_shared_query_context_model=False):
biencoder_shared_query_context_model=False,
pre_process=True,
post_process=True):
"""Build the model."""
args = get_args()
assert mpu.get_tensor_model_parallel_world_size() == 1 and \
mpu.get_pipeline_model_parallel_world_size() == 1, \
......@@ -35,7 +54,9 @@ def biencoder_model_provider(only_query_model=False,
only_query_model=only_query_model,
only_context_model=only_context_model,
biencoder_shared_query_context_model=\
biencoder_shared_query_context_model)
biencoder_shared_query_context_model,
pre_process=pre_process,
post_process=post_process)
return model
......@@ -48,13 +69,17 @@ class BiEncoderModel(MegatronModule):
parallel_output=True,
only_query_model=False,
only_context_model=False,
biencoder_shared_query_context_model=False):
biencoder_shared_query_context_model=False,
pre_process=True,
post_process=True):
super(BiEncoderModel, self).__init__()
args = get_args()
bert_kwargs = dict(
num_tokentypes=num_tokentypes,
parallel_output=parallel_output)
parallel_output=parallel_output,
pre_process=pre_process,
post_process=post_process)
self.biencoder_shared_query_context_model = \
biencoder_shared_query_context_model
......@@ -78,6 +103,13 @@ class BiEncoderModel(MegatronModule):
self.context_model = PretrainedBertModel(**bert_kwargs)
self._context_key = 'context_model'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
# This is just a placeholder and will be needed when model
# parallelism is used.
# self.language_model.set_input_tensor(input_tensor)
return
def forward(self, query_tokens, query_attention_mask, query_types,
context_tokens, context_attention_mask, context_types):
"""Run a forward pass for each of the models and
......@@ -217,7 +249,7 @@ class PretrainedBertModel(MegatronModule):
learned information retrieval."""
def __init__(self, num_tokentypes=2,
parallel_output=True):
parallel_output=True, pre_process=True, post_process=True):
super(PretrainedBertModel, self).__init__()
args = get_args()
......@@ -225,6 +257,8 @@ class PretrainedBertModel(MegatronModule):
self.pad_id = tokenizer.pad
self.biencoder_projection_dim = args.biencoder_projection_dim
self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(
args.init_method_std, args.num_layers)
......@@ -234,7 +268,9 @@ class PretrainedBertModel(MegatronModule):
add_pooler=False,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method)
scaled_init_method=scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process)
if args.biencoder_projection_dim > 0:
self.projection_enc = get_linear_layer(args.hidden_size,
......@@ -247,7 +283,6 @@ class PretrainedBertModel(MegatronModule):
#extended_attention_mask = bert_extended_attention_mask(attention_mask)
position_ids = bert_position_ids(input_ids)
lm_output = self.language_model(input_ids,
position_ids,
extended_attention_mask,
......@@ -285,7 +320,7 @@ class PretrainedBertModel(MegatronModule):
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
print_rank_0("loading BERT weights")
print_rank_0("loading pretrained weights")
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
......
......@@ -181,6 +181,35 @@ class FullTokenizer(object):
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
@staticmethod
def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
""" Converts a sequence of tokens (string) in a single string. """
def clean_up_tokenization(out_string):
""" Clean up a list of simple English tokenization artifacts
like spaces before punctuations and abreviated forms.
"""
out_string = (
out_string.replace(" .", ".")
.replace(" ?", "?")
.replace(" !", "!")
.replace(" ,", ",")
.replace(" ' ", "'")
.replace(" n't", "n't")
.replace(" 'm", "'m")
.replace(" 's", "'s")
.replace(" 've", "'ve")
.replace(" 're", "'re")
)
return out_string
text = ' '.join(tokens).replace(' ##', '').strip()
if clean_up_tokenization_spaces:
clean_text = clean_up_tokenization(text)
return clean_text
else:
return text
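# Illustrative example: ["he", "##llo", "world", "!"] -> "he ##llo world !"
# -> "hello world !" -> "hello world!" after clean-up.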
def vocab_size(self):
return len(self.vocab)
......
......@@ -14,6 +14,8 @@
# limitations under the License.
"""Pretrain BERT for Inverse Cloze Task"""
from functools import partial
import math
import torch
......@@ -31,13 +33,16 @@ from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
def pretrain_ict_model_provider():
def pretrain_ict_model_provider(pre_process=True, post_process=True):
args = get_args()
model = biencoder_model_provider(
only_context_model=False,
only_query_model=False,
biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model)
args.biencoder_shared_query_context_model,
pre_process=pre_process, post_process=post_process)
return model
def get_group_world_size_rank():
......@@ -77,25 +82,9 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
output = output_list[rank].contiguous()
return output
def forward_step(data_iterator, model, input_tensor):
"""Forward step."""
def loss_func(output_tensor):
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
query_tokens, query_mask, \
context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
timers('batch-generator').stop()
# Query and Context Types
query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
# Forward model.
query_logits, context_logits = model(query_tokens, query_mask,
query_types, context_tokens,
context_mask, context_types)
query_logits, context_logits = output_tensor
micro_batch_size = query_logits.shape[0]
# recall we assert that tensor_model_parallel_size == 1
......@@ -137,6 +126,28 @@ def forward_step(data_iterator, model, input_tensor):
return loss, stats_dict
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
query_tokens, query_mask, \
context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
timers('batch-generator').stop()
# Query and Context Types
query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
# Forward model.
output_tensor = model(query_tokens, query_mask, query_types, context_tokens,
context_mask, context_types)
return output_tensor, partial(loss_func)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid and test datasets."""
args = get_args()
......
......@@ -16,7 +16,7 @@
"""Finetune utilities."""
from functools import partial
import sys
import torch
from megatron import get_args, get_num_microbatches
......@@ -80,7 +80,8 @@ def _cross_entropy_forward_step(batch, model):
return output_tensor, partial(cross_entropy_loss_func, labels)
def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
task_collate_fn=None):
"""Data loader. Note that batch-size is the local (per GPU) batch-size."""
# Sampler.
......@@ -96,7 +97,8 @@ def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
shuffle=False,
num_workers=num_workers,
drop_last=drop_last,
pin_memory=True)
pin_memory=True,
collate_fn=task_collate_fn)
return data_loader
......@@ -112,21 +114,24 @@ def _build_infinite_size_dataloader(dataloader):
iterator = dataloader.__iter__()
def _build_train_valid_dataloaders(train_dataset, valid_dataset):
def _build_train_valid_dataloaders(train_dataset, valid_dataset,
task_collate_fn=None):
"""Traing and validation dataloaders."""
args = get_args()
print_rank_0('building train and validation dataloaders ...')
# Training dataset.
train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last)
args.num_workers, not args.keep_last,
task_collate_fn)
# Set the training iterations.
args.train_iters_per_epoch = len(train_dataloader)
args.train_iters = args.epochs * args.train_iters_per_epoch
# Validation dataset. For this dataset, we do not need to set up
# shuffling so we can just use a simple infinite loop.
valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last)
args.num_workers, not args.keep_last,
task_collate_fn)
valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
# Now that we've built the data loaders, set batch_size arguments
......@@ -190,6 +195,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
# Train for one step.
out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
iteration += 1
......@@ -211,9 +217,11 @@ def _train(model, optimizer, lr_scheduler, forward_step,
optimizer, lr_scheduler)
# Checkpointing
saved_checkpoint = False
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
saved_checkpoint = True
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0:
......@@ -222,6 +230,14 @@ def _train(model, optimizer, lr_scheduler, forward_step,
valid_dataloader, model,
iteration, False)
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
torch.distributed.barrier()
print_rank_0('exiting program at iteration {}'.format(iteration))
sys.exit()
# Checkpointing at the end of each epoch.
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
......@@ -233,7 +249,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
def finetune(train_valid_datasets_provider, model_provider,
forward_step=_cross_entropy_forward_step,
end_of_epoch_callback_provider=None):
end_of_epoch_callback_provider=None,
task_collate_fn=None):
"""Main finetune function used across all tasks."""
args = get_args()
timers = get_timers()
......@@ -246,7 +263,7 @@ def finetune(train_valid_datasets_provider, model_provider,
if args.epochs > 0:
train_dataset, valid_dataset = train_valid_datasets_provider()
train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
train_dataset, valid_dataset)
train_dataset, valid_dataset, task_collate_fn)
else:
args.train_iters = 0
timers('train/valid/test dataset/dataloder').stop()
......@@ -270,8 +287,11 @@ def finetune(train_valid_datasets_provider, model_provider,
if args.iteration == 0 and args.pretrained_checkpoint is not None:
original_load = args.load
args.load = args.pretrained_checkpoint
original_rng = args.no_load_rng
args.no_load_rng = True
_ = load_checkpoint(model, None, None)
args.load = original_load
args.no_load_rng = original_rng
# This is critical when only model is loaded. We should make sure
# main parameters are also updated.
optimizer.reload_model_params()
......
......@@ -62,6 +62,29 @@ def get_tasks_args(parser):
group.add_argument('--faiss-topk-retrievals', type=int, default=100,
help='Number of blocks to use as top-k during retrieval')
# finetune for retriever
group.add_argument('--eval-micro-batch-size', type=int, default=None,
help='Eval batch size per model instance (local batch '
'size). Global batch size is local batch size '
'times data parallel size.')
group.add_argument('--train-with-neg', action='store_true',
help='Whether to use negative examples during model '
'training')
group.add_argument('--train-hard-neg', type=int, default=0,
help='Number of hard negative examples to use during '
'training')
# parameters for Av.rank validation method
# Following options/arguments have been taken directly from DPR codebase
group.add_argument('--val-av-rank-hard-neg', type=int, default=30,
help='Av.rank validation: how many hard negatives to'
' take from each question pool')
group.add_argument('--val-av-rank-other-neg', type=int, default=30,
help='Av.rank validation: how many other negatives to'
' take from each question pool')
return parser
......@@ -81,8 +104,10 @@ if __name__ == '__main__':
from glue.finetune import main
elif args.task in ['LAMBADA', 'WIKITEXT103']:
from zeroshot_gpt.evaluate import main
elif args.task in ['ICT-ZEROSHOT-NQ']:
elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']:
from orqa.evaluate_orqa import main
elif args.task in ['RET-FINETUNE-NQ']:
from orqa.supervised.finetune import main
else:
raise NotImplementedError('Task {} is not implemented.'.format(
args.task))
......
......@@ -15,10 +15,8 @@
"""Main tasks functionality."""
import os
import sys
from megatron import get_args
from megatron import get_args, print_rank_0
from megatron.indexer import IndexBuilder
from tasks.orqa.evaluate_utils import ORQAEvaluator
def main():
......@@ -28,6 +26,20 @@ def main():
args = get_args()
"""
Create a BlockData data structure by running an IndexBuilder over an
ICT Dataset and then evaluate on NQ task
"""
print_rank_0("Starting index builder!")
index_builder = IndexBuilder()
index_builder.build_and_save_index()
print_rank_0("Build and save indices: done!")
print_rank_0("Starting evaluations!")
# Set up the model and evaluator
evaluator = ORQAEvaluator()
......
......@@ -18,13 +18,14 @@ import torch
from megatron import get_args, print_rank_0
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from tasks.orqa.natural_questions.nq import get_nq_dataset
from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader
from tasks.orqa.natural_questions.nq import process_nq_batch
from tasks.orqa.natural_questions.qa_utils import calculate_matches
from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model
from tasks.orqa.unsupervised.nq import get_nq_dataset
from tasks.orqa.unsupervised.nq import get_one_epoch_nq_dataloader
from tasks.orqa.unsupervised.nq import process_nq_batch
from tasks.orqa.unsupervised.qa_utils import calculate_matches
class ORQAEvaluator(object):
def __init__(self):
......@@ -44,9 +45,8 @@ class ORQAEvaluator(object):
if args.biencoder_shared_query_context_model:
only_query_model = False
model = get_model(lambda: biencoder_model_provider(only_query_model=\
only_query_model, biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model))
model = get_model(get_model_provider(only_query_model=only_query_model,
biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))
self.model = load_biencoder_checkpoint(model,
only_query_model=only_query_model)
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ORQA dataset."""
import json
import random
from abc import ABC
from abc import abstractmethod
import numpy as np
from torch.utils.data import Dataset
from megatron import print_rank_0, get_args
from megatron.data.biencoder_dataset_utils import make_attention_mask
def build_token_types_from_context_list(ctx_list, tokenizer, max_seq_length):
ctx_id_list, ctx_types_list = [], []
for context in ctx_list:
title_ids = tokenizer.tokenize(context['title'])
ctx_ids = tokenizer.tokenize(context['text'])
ctx_ids = title_ids + [tokenizer.sep_id] + ctx_ids
ctx_ids, ctx_types, _ = build_tokens_types_paddings_from_ids(ctx_ids,
max_seq_length, tokenizer.cls,
tokenizer.sep, tokenizer.pad)
ctx_id_list.append(ctx_ids)
ctx_types_list.append(ctx_types)
return ctx_id_list, ctx_types_list
def build_tokens_types_paddings_from_text(query, context,
tokenizer, max_seq_length):
"""Build token types and paddings, trim if needed, and pad if needed."""
query_ids = tokenizer.tokenize(query)
query_ids, query_types, query_pad_mask = \
build_tokens_types_paddings_from_ids(query_ids, max_seq_length, \
tokenizer.cls, tokenizer.sep, tokenizer.pad)
# Appending the title of the context at front
extended_ctx_ids = None
if context is not None:
title_ids = tokenizer.tokenize(context['title'])
ctx_ids = tokenizer.tokenize(context['text'])
extended_ctx_ids = title_ids + [tokenizer.sep] + ctx_ids
ctx_ids, ctx_types, ctx_pad_mask = \
build_tokens_types_paddings_from_ids(extended_ctx_ids,
max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)
return query_ids, query_types, query_pad_mask, \
ctx_ids, ctx_types, ctx_pad_mask
# Similar to the code in tasks/data_utils.py, with some changes
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
cls_id, sep_id, pad_id):
"""Build token types and paddings, trim if needed, and pad if needed."""
enc_ids = []
tokentypes_enc = []
# [CLS].
enc_ids.append(cls_id)
tokentypes_enc.append(0)
# A.
len_src = len(text_ids)
enc_ids.extend(text_ids)
tokentypes_enc.extend([0] * len_src)
# Cap the size.
if len(enc_ids) > max_seq_length - 1:
enc_ids = enc_ids[0: max_seq_length - 1]
tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]
# [SEP].
enc_ids.append(sep_id)
tokentypes_enc.append(0)
num_tokens_enc = len(enc_ids)
# Padding.
padding_length = max_seq_length - len(enc_ids)
if padding_length > 0:
enc_ids.extend([pad_id] * padding_length)
tokentypes_enc.extend([pad_id] * padding_length)
pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
pad_mask = np.array(pad_mask, dtype=np.int64)
return enc_ids, tokentypes_enc, pad_mask
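# Illustrative example: with max_seq_length=6, cls_id=101, sep_id=102, pad_id=0
# and text_ids=[7, 8, 9], this returns
#   enc_ids        = [101, 7, 8, 9, 102, 0]
#   tokentypes_enc = [0, 0, 0, 0, 0, 0]
#   pad_mask       = [1, 1, 1, 1, 1, 0]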
def build_sample(query_ids, query_types, query_pad_mask,
ctx_ids, ctx_types, ctx_pad_mask, answers,
neg_ctx_id_list=None, neg_ctx_types_list=None,
include_neg=False):
"""Convert to numpy and return a sample consumed by the batch producer."""
query_ids = np.array(query_ids, dtype=np.int64)
query_types = np.array(query_types, dtype=np.int64)
query_mask = make_attention_mask(query_ids, query_ids)
ctx_ids = np.array(ctx_ids, dtype=np.int64)
ctx_types = np.array(ctx_types, dtype=np.int64)
ctx_mask = make_attention_mask(ctx_ids, ctx_ids)
sample = ({
'query': query_ids,
'query_mask': query_mask,
'query_types': query_types,
'query_pad_mask': query_pad_mask,
'context': ctx_ids,
'context_mask': ctx_mask,
'context_types': ctx_types,
'context_pad_mask': ctx_pad_mask,
'reference': answers
})
if include_neg:
neg_ctx_ids = np.array(neg_ctx_id_list, dtype=np.int64)
neg_ctx_id_types = np.array(neg_ctx_types_list, dtype=np.int64)
neg_ctx_mask = np.array([make_attention_mask(ids, ids) \
for ids in neg_ctx_ids], dtype=np.int64)
sample['neg_context'] = neg_ctx_ids
sample['neg_context_types'] = neg_ctx_id_types
sample['neg_context_mask'] = neg_ctx_mask
return sample
class OpenRetrievalAbstractDataset(ABC, Dataset):
"""Open Retrieval base dataset class."""
def __init__(self, task_name, dataset_name, datapaths, tokenizer, \
max_seq_length, evaluate=False):
# Store inputs.
args = get_args()
self.evaluate = evaluate
self.val_av_rank_hard_neg = args.val_av_rank_hard_neg
self.val_av_rank_other_neg = args.val_av_rank_other_neg
self.train_with_neg = args.train_with_neg
self.train_hard_neg = args.train_hard_neg
self.task_name = task_name
self.dataset_name = dataset_name
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
self.dataset_name))
# Process the files.
string = ' > paths:'
for path in datapaths:
string += ' ' + path
print_rank_0(string)
self.samples = []
for datapath in datapaths:
self.samples.extend(self.process_samples_from_single_path(datapath))
args = get_args()
if args.sample_rate < 1: # subsample
k = int(len(self.samples) * args.sample_rate)
self.samples = random.sample(self.samples, k)
print_rank_0(' >> total number of samples: {}'.format(
len(self.samples)))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
raw_sample = self.samples[idx]
query_ids, query_types, query_pad_mask, ctx_ids, ctx_types, \
ctx_pad_mask = build_tokens_types_paddings_from_text( \
raw_sample['question'], raw_sample['pos_context'], \
self.tokenizer, self.max_seq_length)
if self.evaluate:
neg_ctx_list = \
raw_sample['negative_context'][:self.val_av_rank_other_neg] + \
raw_sample['hard_negative_context'][:self.val_av_rank_hard_neg]
neg_ctx_id_list, neg_ctx_types_list = \
build_token_types_from_context_list(neg_ctx_list, \
self.tokenizer, self.max_seq_length)
elif self.train_with_neg:
hard_negative_ctx = raw_sample['hard_negative_context']
negative_ctx = raw_sample['negative_context']
if True: # TODO: fix this or remove this condition
random.shuffle(hard_negative_ctx)
random.shuffle(negative_ctx)
neg_ctx_list = hard_negative_ctx[:self.train_hard_neg]
# In the Google NQ dataset released with the DPR paper, around 50 or more
# training examples are missing hard negatives.
# In those cases, substitute hard negatives with simple negatives.
if len(neg_ctx_list) < self.train_hard_neg:
neg_ctx_list += negative_ctx[:self.train_hard_neg - \
len(neg_ctx_list)]
neg_ctx_id_list, neg_ctx_types_list = \
build_token_types_from_context_list(neg_ctx_list,
self.tokenizer, self.max_seq_length)
else:
neg_ctx_id_list = None
neg_ctx_types_list = None
sample = build_sample(query_ids, query_types, query_pad_mask,
ctx_ids, ctx_types, ctx_pad_mask,
raw_sample['answers'],
neg_ctx_id_list, neg_ctx_types_list,
include_neg=self.evaluate or self.train_with_neg)
return sample
@staticmethod
@abstractmethod
def process_samples_from_single_path(filename):
"""Abstract method that takes a filename and
returns a list of dataset samples, each sample being a dict of
{'text': string, 'text': string}
"""
pass
def normalize_question(question):
if question[-1] == '?':
question = question[:-1]
return question
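# e.g. normalize_question("who founded google?") -> "who founded google"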
class NQSupervisedDataset(OpenRetrievalAbstractDataset):
def __init__(self, name, datapaths, tokenizer, max_seq_length, \
evaluate=False):
super().__init__('natural_questions_ret',
name,
datapaths,
tokenizer,
max_seq_length,
evaluate=evaluate)
@staticmethod
def process_samples_from_single_path(filename):
""""Implement abstract method."""
print_rank_0(' > Processing {} ...'.format(filename))
samples = []
total = 0
with open(filename, 'r', encoding="utf-8") as f:
data = json.load(f)
for row in data:
question = normalize_question(row['question'])
pos_context = row['positive_ctxs'][0]
# Hard Negative Contexts
if len(row['hard_negative_ctxs']) > 0:
hard_neg_context = row['hard_negative_ctxs']
else:
hard_neg_context = []
# Negative Contexts
if len(row['negative_ctxs']) > 0:
neg_context = row['negative_ctxs']
else:
neg_context = []
answers = row['answers']
sample = {'question': question,
'pos_context': pos_context,
'hard_negative_context': hard_neg_context,
'negative_context': neg_context,
'answers': answers}
total += 1
samples.append(sample)
if total % 5000 == 0:
print_rank_0(' > processed {} so far ...'.format(total))
print_rank_0(' >> processed {} samples.'.format(len(samples)))
return samples
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation utilities."""
from collections import OrderedDict
import math
import numpy as np
import time
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.utils import average_losses_across_data_parallel_group
from tasks.finetune_utils import build_data_loader
def task_collate_fn(batch_data):
# generate batch
batch_size = len(batch_data)
tensorized = OrderedDict()
for d in batch_data:
for k, v in d.items():
tensorized.setdefault(k, []).append(v)
tensorized['query'] = torch.LongTensor(tensorized['query'])
tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask'])
tensorized['query_types'] = torch.LongTensor(tensorized['query_types'])
tensorized['query_pad_mask'] = \
torch.LongTensor(tensorized['query_pad_mask'])
tensorized['context'] = torch.LongTensor(tensorized['context'])
tensorized['context_mask'] = \
torch.LongTensor(tensorized['context_mask'])
tensorized['context_types'] = \
torch.LongTensor(tensorized['context_types'])
tensorized['context_pad_mask'] = \
torch.LongTensor(tensorized['context_pad_mask'])
if 'neg_context' in tensorized:
tensorized['neg_context'] = \
torch.LongTensor(np.concatenate(tensorized['neg_context']))
tensorized['neg_context_mask'] = \
torch.LongTensor(np.concatenate(tensorized['neg_context_mask']))
tensorized['neg_context_types'] = \
torch.LongTensor(np.concatenate(tensorized['neg_context_types']))
return tensorized
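# task_collate_fn is handed to build_data_loader() (tasks/finetune_utils.py),
# which passes it to torch.utils.data.DataLoader as collate_fn, so each batch
# arrives as an OrderedDict of stacked LongTensors rather than a list of
# per-sample dicts.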
def process_batch(batch):
"""Process batch and produce inputs for the model."""
query_tokens = batch['query'].long().cuda()
query_mask = (batch['query_mask'] < 0.5).cuda()
query_types = batch['query_types'].long().cuda()
query_pad_mask = batch['query_pad_mask'].long().cuda()
context_tokens = batch['context'].long().cuda()
context_mask = (batch['context_mask'] < 0.5).cuda()
context_types = batch['context_types'].long().cuda()
context_pad_mask = batch['context_pad_mask'].long().cuda()
if 'neg_context' in batch:
neg_context_tokens = batch['neg_context'].long().cuda()
neg_context_mask = (batch['neg_context_mask'] < 0.5).cuda()
neg_context_types = batch['neg_context_types'].long().cuda()
else:
neg_context_tokens = None
neg_context_mask = None
neg_context_types = None
reference = batch['reference']
return query_tokens, query_mask, query_types, query_pad_mask, \
context_tokens, context_mask, context_types, context_pad_mask, \
neg_context_tokens, neg_context_mask, neg_context_types, reference
def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
"""Provide function that calculates accuracies."""
args = get_args()
print_rank_0("accuracy_func_provider is CALLED")
# Build dataloaders
datapath = args.valid_data
dataset = single_dataset_provider(datapath)
drop_last = False
if mpu.get_data_parallel_world_size() > 1 and not rank0sampler:
drop_last = True
print_rank_0(datapath)
print_rank_0(rank0sampler)
dataloader = build_data_loader(dataset,
args.eval_micro_batch_size,
num_workers=args.num_workers,
drop_last=drop_last,
task_collate_fn=task_collate_fn)
dataloaders = (dataset.dataset_name, dataloader)
def metrics_func(model, epoch, output_predictions=False):
print_rank_0('calculating metrics by accuracy func in ORQA...')
if output_predictions:
assert rank0sampler
names = 'predictions'
name, dataloader = dataloaders
if args.task == "RET-FINETUNE-NQ":
start_time = time.time()
output = retrieval_loss(model, dataloader)
stats_dict, total = output
format_string = ""
for k, v in stats_dict.items():
format_string += "|{} = {:.2f}".format(k, v / total)
print_rank_0("epoch:{}{}".format(epoch, format_string))
print_rank_0("taken time to calcuate metrics {:.3f}".format(\
time.time() - start_time))
else:
raise AssertionError("{} Task not supported".format(args.task))
return metrics_func
def retrieval_loss(model, dataloader):
args = get_args()
total = 0
topk_stats_dict = {'top{}_acc'.format(k): 0 for k in \
args.retriever_report_topk_accuracies}
stats_dict = dict(rank=0, **topk_stats_dict)
assert len(model) == 1
unwrapped_model = model[0]
unwrapped_model.eval()
with torch.no_grad():
# For all the batches in the dataset.
for batch in dataloader:
# Run the model forward.
query_tokens, query_mask, query_types, _, \
context_tokens, context_mask, context_types, _, \
neg_context_tokens, neg_context_mask, neg_context_types, \
reference = process_batch(batch)
query_logits, context_logits = unwrapped_model(query_tokens,
query_mask, query_types,
torch.cat([context_tokens, neg_context_tokens]),
torch.cat([context_mask, neg_context_mask]),
torch.cat([context_types, neg_context_types]))
retrieval_scores = torch.matmul(query_logits,
torch.transpose(context_logits, 0, 1))
if args.retriever_score_scaling:
retrieval_scores = retrieval_scores / \
math.sqrt(args.hidden_size)
local_batch_size = query_logits.shape[0]
labels = torch.arange(local_batch_size).long().cuda()
softmax_scores = F.softmax(retrieval_scores, dim=1)
sorted_vals, sorted_indices = torch.topk(softmax_scores,
k=softmax_scores.shape[1],
sorted=True)
def topk_accuracy(k):
return torch.cuda.FloatTensor(
[sum([int(labels[i] in sorted_indices[i, :k]) for i in \
range(local_batch_size)])])
def get_rank():
return torch.cuda.FloatTensor(
[sum([torch.nonzero(labels[i] == sorted_indices[i])[0][0] \
for i in range(local_batch_size)])])
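# topk_accuracy(k) counts how many queries in the batch have their gold
# context (index labels[i]) within the top-k retrieved indices; get_rank()
# sums the 0-based rank of the gold context over the batch.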
topk_accs = [topk_accuracy(k) for k in \
args.retriever_report_topk_accuracies]
rank = get_rank()
losses = average_losses_across_data_parallel_group([rank, \
*topk_accs])
# create stats_dict with retrieval loss and all specified
# top-k accuracies
topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
zip(args.retriever_report_topk_accuracies, losses[1:])}
temp_stats_dict = dict(rank=losses[0], **topk_acc_dict)
for k in stats_dict.keys():
stats_dict[k] += temp_stats_dict[k]
total += local_batch_size
unwrapped_model.train()
return stats_dict, total
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ORQA finetuning/evaluation."""
from functools import partial
import sys
import math
import torch
import torch.nn.functional as F
from megatron import get_args, get_timers, get_tokenizer
from megatron import mpu, print_rank_0
from megatron.indexer import IndexBuilder
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.utils import average_losses_across_data_parallel_group
from pretrain_ict import get_group_world_size_rank
from tasks.finetune_utils import finetune
from tasks.orqa.supervised.eval_utils import accuracy_func_provider
from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
from tasks.orqa.evaluate_utils import ORQAEvaluator
def orqa(Dataset):
def cross_entropy_forward_step(batch, model):
"""Simple forward step with cross-entropy loss."""
timers = get_timers()
tokenizer = get_tokenizer()
# Get the batch.
timers('batch generator').start()
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
query_tokens, query_mask, query_types, query_pad_mask, \
context_tokens, context_mask, context_types, context_pad_mask, \
neg_context_tokens, neg_context_mask, neg_context_types, \
reference = process_batch(batch_)
timers('batch generator').stop()
local_batch_size = query_tokens.shape[0]
# Text representation of query and context
query_list, context_list = [], []
for i in range(local_batch_size):
query_list.append(tokenizer.decode(query_tokens[i].tolist()))
context_list.append(tokenizer.decode(context_tokens[i].tolist()))
if neg_context_tokens is not None:
context_tokens = torch.cat([context_tokens, neg_context_tokens])
context_mask = torch.cat([context_mask, neg_context_mask])
context_types = torch.cat([context_types, neg_context_types])
# Forward model.
output_tensor = model(query_tokens, query_mask,
query_types, context_tokens,
context_mask, context_types)
return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens)
def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
args = get_args()
local_batch_size = query_tokens.shape[0]
group, rank, world_size = get_group_world_size_rank()
# recall we assert that model_parallel_size == 1
global_batch_size = world_size * local_batch_size
query_logits, context_logits = output_tensor
if world_size > 1:
input_ = torch.empty_like(context_logits).copy_(\
context_logits).detach_()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank].copy_(input_)
torch.distributed.all_gather(tensor_list, input_, group=group)
# Check if all-gather happens in order
assert tensor_list[rank].sum().item() == \
context_logits.sum().item()
# Preserves the gradient
tensor_list[rank] = context_logits
all_context_logits = torch.cat(tensor_list, dim=0).contiguous()
# Query tensors
input_ = torch.empty_like(query_logits).copy_(\
query_logits).detach_()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank].copy_(input_)
torch.distributed.all_gather(tensor_list, input_, group=group)
# Check if all-gather happens in order
assert tensor_list[rank].sum().item() == query_logits.sum().item()
# Preserves the gradient
tensor_list[rank] = query_logits
all_query_logits = torch.cat(tensor_list, dim=0).contiguous()
else:
all_query_logits = query_logits
all_context_logits = context_logits
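# At this point all_query_logits / all_context_logits hold the logits gathered
# across data-parallel ranks (when world_size > 1); only the local slice,
# re-inserted above without detaching, carries gradients.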
retrieval_scores = torch.matmul(all_query_logits,
torch.transpose(all_context_logits, 0, 1))
# Scaling the retrieval scores
if args.retriever_score_scaling:
retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)
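# i.e. score(q, c) = (q . c) / sqrt(hidden_size) when --retriever-score-scaling is set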
if args.train_with_neg:
# if the world size is 3, local batch size is 4, and
# local context size is 8, what we want is
# labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]
labels = []
local_context_size = context_tokens.shape[0]
for i in range(world_size):
j = i * local_context_size
labels.extend(list(range(j, j + local_batch_size)))
labels = torch.LongTensor(labels).cuda()
assert len(labels) == global_batch_size
else:
labels = torch.arange(global_batch_size).long().cuda()
# Cross-entropy loss.
softmax_scores = F.log_softmax(retrieval_scores, dim=1)
loss = F.nll_loss(softmax_scores, labels, reduction='mean')
max_score, max_idxs = torch.max(softmax_scores, 1)
correct_predictions_count = (max_idxs == labels).sum().float()
# Reduce loss for logging.
reduced_loss = average_losses_across_data_parallel_group([loss, \
correct_predictions_count])
# Loss scaling for correct losses in Supervised Retrieval
loss = loss * mpu.get_data_parallel_world_size()
return loss, {'lm loss': reduced_loss[0],
'correct_prediction_count': reduced_loss[1]}
def train_valid_datasets_provider():
"""Build train and validation dataset."""
args = get_args()
tokenizer = get_tokenizer()
train_dataset = Dataset('training',
args.train_data,
tokenizer,
args.retriever_seq_length,
evaluate=False)
valid_dataset = Dataset('validation',
args.valid_data,
tokenizer,
args.retriever_seq_length,
evaluate=True)
return train_dataset, valid_dataset
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
print_rank_0('building retriever model for {} ...'.format(args.task))
model = biencoder_model_provider(only_context_model=False,
only_query_model=False,
biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model,
pre_process=pre_process, post_process=post_process)
return model
def single_dataset_provider(datapath):
args = get_args()
tokenizer = get_tokenizer()
name = datapath[0].split('/')[-1].split('.')[0]
return Dataset(name,
datapath,
tokenizer,
args.retriever_seq_length,
evaluate=True)
def metrics_func_provider():
"""Provide metrics callback function."""
return accuracy_func_provider(single_dataset_provider)
"""Finetune/evaluate."""
finetune(train_valid_datasets_provider,
model_provider,
forward_step=cross_entropy_forward_step,
end_of_epoch_callback_provider=metrics_func_provider,
task_collate_fn=task_collate_fn)
def main():
args = get_args()
if args.task == 'RET-FINETUNE-NQ':
from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
else:
raise NotImplementedError('ORQA task {} is not implemented.'.format(
args.task))
orqa(Dataset)
......@@ -22,7 +22,7 @@ from multiprocessing import Pool as ProcessPool
from typing import Tuple, List, Dict
import regex as re
from tasks.orqa.natural_questions.tokenizers import SimpleTokenizer
from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer
logger = logging.getLogger(__name__)
......
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
from megatron import print_rank_0
from megatron.indexer import IndexBuilder
from megatron.initialize import initialize_megatron
def main():
"""Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
- Include all args needed for initial model specification
Other key args:
--block-data-path: path to write to
--ict-load or --realm-load: path to checkpoint with which to embed
--data-path and --titles-data-path: paths for dataset
--indexer-log-interval: reporting interval
--indexer-batch-size: size specific for indexer jobs
Check README.md for example script
"""
initialize_megatron(extra_args_provider=None,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
index_builder = IndexBuilder()
index_builder.build_and_save_index()
print_rank_0("Build and save indices: done!")
if __name__ == "__main__":
main()