Commit 6d03d7af authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

DPR finetune and evaluation

parent d2d5086e
...@@ -413,11 +413,9 @@ def load_biencoder_checkpoint(model, only_query_model=False, ...@@ -413,11 +413,9 @@ def load_biencoder_checkpoint(model, only_query_model=False,
if only_context_model: if only_context_model:
ret_state_dict.pop('query_model') ret_state_dict.pop('query_model')
#print_rank_0(len(model)) assert len(model) == 1
#sys.exit() model[0].load_state_dict(ret_state_dict)
#assert len(model) == 1
#model[0].load_state_dict(ret_state_dict)
model.load_state_dict(ret_state_dict)
torch.distributed.barrier() torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0: if mpu.get_data_parallel_rank() == 0:
......
...@@ -45,26 +45,25 @@ class IndexBuilder(object): ...@@ -45,26 +45,25 @@ class IndexBuilder(object):
""" """
Load the necessary attributes: model, dataloader and empty BlockData Load the necessary attributes: model, dataloader and empty BlockData
""" """
args = get_args()
only_context_model = True only_context_model = True
if self.biencoder_shared_query_context_model: if self.biencoder_shared_query_context_model:
only_context_model = False only_context_model = False
model = get_model(lambda: biencoder_model_provider(only_context_model \ args.only_context_model = only_context_model
= only_context_model, biencoder_shared_query_context_model = \ args.only_query_model = False
self.biencoder_shared_query_context_model, \
pre_process=self.pre_process, post_process=self.post_process)) model = get_model(biencoder_model_provider)
#model = biencoder_model_provider(only_context_model \ #model = get_model(lambda: biencoder_model_provider(only_context_model \
# = only_context_model, biencoder_shared_query_context_model = \ # = only_context_model, biencoder_shared_query_context_model = \
# self.biencoder_shared_query_context_model, \ # self.biencoder_shared_query_context_model))
# pre_process=self.pre_process, post_process=self.post_process)
self.model = load_biencoder_checkpoint(model, self.model = load_biencoder_checkpoint(model,
only_context_model=only_context_model) only_context_model=only_context_model)
#assert len(self.model) == 1 assert len(self.model) == 1
#self.model[0].eval() self.model[0].eval()
self.model.eval()
self.dataset = get_open_retrieval_wiki_dataset() self.dataset = get_open_retrieval_wiki_dataset()
self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \ self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
...@@ -92,12 +91,11 @@ class IndexBuilder(object): ...@@ -92,12 +91,11 @@ class IndexBuilder(object):
distributed setting will be consolidated by the rank 0 process distributed setting will be consolidated by the rank 0 process
and saved as a final pickled BlockData. and saved as a final pickled BlockData.
""" """
#assert len(self.model) == 1 assert len(self.model) == 1
#unwrapped_model = self.model[0] unwrapped_model = self.model[0]
unwrapped_model = self.model
while not hasattr(unwrapped_model, 'embed_text'): while not hasattr(unwrapped_model, 'embed_text'):
unwrapped_model = unwrapped_model.module unwrapped_model = unwrapped_model.module
print_rank_0("hasattr")
while True: while True:
try: try:
...@@ -108,17 +106,6 @@ class IndexBuilder(object): ...@@ -108,17 +106,6 @@ class IndexBuilder(object):
except (StopIteration, IndexError): except (StopIteration, IndexError):
break break
print_rank_0(context_tokens)
print_rank_0(context_mask)
print_rank_0(context_types)
#if torch.cuda.is_available():
# print_rank_0("cuda available")
#print_rank_0(torch.cuda.current_device())
#print_rank_0(torch.cuda.get_device_name())
print_rank_0(next(unwrapped_model.parameters()).device)
print_rank_0(next(unwrapped_model.context_model.parameters()).device)
#print_rank_0("After get_open_retrieval_batch")
# TODO: can we add with torch.no_grad() to reduce memory usage # TODO: can we add with torch.no_grad() to reduce memory usage
# detach, separate fields and add to BlockData # detach, separate fields and add to BlockData
assert context_mask.dtype == torch.bool assert context_mask.dtype == torch.bool
...@@ -126,8 +113,6 @@ class IndexBuilder(object): ...@@ -126,8 +113,6 @@ class IndexBuilder(object):
unwrapped_model.context_model, context_tokens, context_mask, unwrapped_model.context_model, context_tokens, context_mask,
context_types) context_types)
sys.exit()
context_logits = detach(context_logits) context_logits = detach(context_logits)
row_id = detach(row_id) row_id = detach(row_id)
......
...@@ -15,14 +15,21 @@ from megatron.model.utils import init_method_normal ...@@ -15,14 +15,21 @@ from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule from .module import MegatronModule
def biencoder_model_provider(only_query_model=False, #def biencoder_model_provider(only_query_model=False,
only_context_model=False, # only_context_model=False,
biencoder_shared_query_context_model=False, # biencoder_shared_query_context_model=False,
pre_process=True, # pre_process=True,
# post_process=True):
def biencoder_model_provider(pre_process=True,
post_process=True): post_process=True):
"""Build the model.""" """Build the model."""
args = get_args() args = get_args()
biencoder_shared_query_context_model = args.biencoder_shared_query_context_model
only_context_model = args.only_context_model
only_query_model = args.only_query_model
assert mpu.get_tensor_model_parallel_world_size() == 1 and \ assert mpu.get_tensor_model_parallel_world_size() == 1 and \
mpu.get_pipeline_model_parallel_world_size() == 1, \ mpu.get_pipeline_model_parallel_world_size() == 1, \
"Model parallel size > 1 not supported for ICT" "Model parallel size > 1 not supported for ICT"
...@@ -266,11 +273,6 @@ class PretrainedBertModel(MegatronModule): ...@@ -266,11 +273,6 @@ class PretrainedBertModel(MegatronModule):
#extended_attention_mask = bert_extended_attention_mask(attention_mask) #extended_attention_mask = bert_extended_attention_mask(attention_mask)
position_ids = bert_position_ids(input_ids) position_ids = bert_position_ids(input_ids)
print_rank_0(input_ids.device)
print_rank_0(position_ids.device)
print_rank_0(extended_attention_mask.device)
print_rank_0(tokentype_ids.device)
lm_output = self.language_model(input_ids, lm_output = self.language_model(input_ids,
position_ids, position_ids,
extended_attention_mask, extended_attention_mask,
......
...@@ -338,11 +338,6 @@ class TransformerLanguageModel(MegatronModule): ...@@ -338,11 +338,6 @@ class TransformerLanguageModel(MegatronModule):
get_key_value=False, pooling_sequence_index=0, get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False): enc_hidden_states=None, output_enc_hidden=False):
print_rank_0("before self.embedding")
print_rank_0(enc_input_ids.device)
print_rank_0(enc_position_ids.device)
print_rank_0(tokentype_ids.device)
# Embeddings. # Embeddings.
if self.pre_process: if self.pre_process:
embedding_output = self.embedding(enc_input_ids, enc_position_ids, embedding_output = self.embedding(enc_input_ids, enc_position_ids,
......
...@@ -33,11 +33,15 @@ from megatron.utils import average_losses_across_data_parallel_group ...@@ -33,11 +33,15 @@ from megatron.utils import average_losses_across_data_parallel_group
def pretrain_ict_model_provider(): def pretrain_ict_model_provider():
args = get_args() args = get_args()
model = biencoder_model_provider( args.only_context_model = False
only_context_model=False, args.only_query_model = False
only_query_model=False, model = biencoder_model_provider()
biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model) #model = biencoder_model_provider(
# only_context_model=False,
# only_query_model=False,
# biencoder_shared_query_context_model=\
# args.biencoder_shared_query_context_model)
return model return model
def get_group_world_size_rank(): def get_group_world_size_rank():
......
...@@ -19,6 +19,7 @@ import os ...@@ -19,6 +19,7 @@ import os
import sys import sys
from megatron import get_args from megatron import get_args
from megatron.indexer import IndexBuilder
from tasks.orqa.evaluate_utils import ORQAEvaluator from tasks.orqa.evaluate_utils import ORQAEvaluator
def main(): def main():
...@@ -28,6 +29,23 @@ def main(): ...@@ -28,6 +29,23 @@ def main():
args = get_args() args = get_args()
"""Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
- Include all args needed for initial model specification
Other key args:
--block-data-path: path to write to
--ict-load or --realm-load: path to checkpoint with which to embed
--data-path and --titles-data-path: paths for dataset
--indexer-log-interval: reporting interval
--indexer-batch-size: size specific for indexer jobs
Check README.md for example script
"""
index_builder = IndexBuilder()
index_builder.build_and_save_index()
print_rank_0("Build and save indices: done!")
# Set up the model and evaluator # Set up the model and evaluator
evaluator = ORQAEvaluator() evaluator = ORQAEvaluator()
...@@ -37,4 +55,4 @@ def main(): ...@@ -37,4 +55,4 @@ def main():
if args.qa_data_test is not None: if args.qa_data_test is not None:
evaluator.evaluate(args.qa_data_test, "TEST") evaluator.evaluate(args.qa_data_test, "TEST")
...@@ -44,9 +44,14 @@ class ORQAEvaluator(object): ...@@ -44,9 +44,14 @@ class ORQAEvaluator(object):
if args.biencoder_shared_query_context_model: if args.biencoder_shared_query_context_model:
only_query_model = False only_query_model = False
model = get_model(lambda: biencoder_model_provider(only_query_model=\ args.only_query_model = only_query_model
only_query_model, biencoder_shared_query_context_model=\ args.only_context_model = False
args.biencoder_shared_query_context_model))
#model = get_model(lambda: biencoder_model_provider(only_query_model=\
# only_query_model, biencoder_shared_query_context_model=\
# args.biencoder_shared_query_context_model))
model = get_model(biencoder_model_provider)
self.model = load_biencoder_checkpoint(model, self.model = load_biencoder_checkpoint(model,
only_query_model=only_query_model) only_query_model=only_query_model)
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
"""ORQA finetuning/evaluation.""" """ORQA finetuning/evaluation."""
from functools import partial from functools import partial
import sys
import math import math
import torch import torch
...@@ -183,11 +184,15 @@ def orqa(Dataset): # , name_from_datapath_func): ...@@ -183,11 +184,15 @@ def orqa(Dataset): # , name_from_datapath_func):
"""Build the model.""" """Build the model."""
args = get_args() args = get_args()
print_rank_0('building retriever model for {} ...'.format(args.task)) print_rank_0('building retriever model for {} ...'.format(args.task))
model = biencoder_model_provider(only_context_model=False, args.only_context_model=False
only_query_model=False, args.only_query_model=False
biencoder_shared_query_context_model=\ model = biencoder_model_provider()
args.biencoder_shared_query_context_model,
pre_process=pre_process, post_process=post_process) #model = biencoder_model_provider(only_context_model=False,
# only_query_model=False,
# biencoder_shared_query_context_model=\
# args.biencoder_shared_query_context_model,
# pre_process=pre_process, post_process=post_process)
return model return model
def single_dataset_provider(datapath): def single_dataset_provider(datapath):
...@@ -228,29 +233,15 @@ def orqa(Dataset): # , name_from_datapath_func): ...@@ -228,29 +233,15 @@ def orqa(Dataset): # , name_from_datapath_func):
def main(): def main():
args = get_args() args = get_args()
#if args.task == 'RET-FINETUNE-NQ': if args.task == 'RET-FINETUNE-NQ':
# from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
#def name_from_datapath(datapath): #def name_from_datapath(datapath):
# return datapath[0].split('/')[-1].split('.')[0] # return datapath[0].split('/')[-1].split('.')[0]
#else: else:
# raise NotImplementedError('ORQA task {} is not implemented.'.format( raise NotImplementedError('ORQA task {} is not implemented.'.format(
# args.task)) args.task))
#orqa(Dataset) #, name_from_datapath)
index_builder = IndexBuilder()
index_builder.build_and_save_index()
print_rank_0("Build and save indices: done!")
# Set up the model and evaluator
#evaluator = ORQAEvaluator()
# Run evaluation
#if args.qa_data_dev is not None:
# evaluator.evaluate(args.qa_data_dev, "DEV")
#if args.qa_data_test is not None:
# evaluator.evaluate(args.qa_data_test, "TEST")
orqa(Dataset) #, name_from_datapath)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment