"torchvision/csrc/cpu/image" did not exist on "fa6af6d1e6f050d0a930e6f59894b7dd40aa869c"
Commit 6d03d7af authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

DPR finetune and evaluation

parent d2d5086e
......@@ -413,11 +413,9 @@ def load_biencoder_checkpoint(model, only_query_model=False,
if only_context_model:
ret_state_dict.pop('query_model')
#print_rank_0(len(model))
#sys.exit()
#assert len(model) == 1
#model[0].load_state_dict(ret_state_dict)
model.load_state_dict(ret_state_dict)
assert len(model) == 1
model[0].load_state_dict(ret_state_dict)
torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
......
......@@ -45,26 +45,25 @@ class IndexBuilder(object):
"""
Load the necessary attributes: model, dataloader and empty BlockData
"""
args = get_args()
only_context_model = True
if self.biencoder_shared_query_context_model:
only_context_model = False
model = get_model(lambda: biencoder_model_provider(only_context_model \
= only_context_model, biencoder_shared_query_context_model = \
self.biencoder_shared_query_context_model, \
pre_process=self.pre_process, post_process=self.post_process))
args.only_context_model = only_context_model
args.only_query_model = False
model = get_model(biencoder_model_provider)
#model = biencoder_model_provider(only_context_model \
#model = get_model(lambda: biencoder_model_provider(only_context_model \
# = only_context_model, biencoder_shared_query_context_model = \
# self.biencoder_shared_query_context_model, \
# pre_process=self.pre_process, post_process=self.post_process)
# self.biencoder_shared_query_context_model))
self.model = load_biencoder_checkpoint(model,
only_context_model=only_context_model)
#assert len(self.model) == 1
#self.model[0].eval()
self.model.eval()
assert len(self.model) == 1
self.model[0].eval()
self.dataset = get_open_retrieval_wiki_dataset()
self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
......@@ -92,12 +91,11 @@ class IndexBuilder(object):
distributed setting will be consolidated by the rank 0 process
and saved as a final pickled BlockData.
"""
#assert len(self.model) == 1
#unwrapped_model = self.model[0]
unwrapped_model = self.model
assert len(self.model) == 1
unwrapped_model = self.model[0]
while not hasattr(unwrapped_model, 'embed_text'):
unwrapped_model = unwrapped_model.module
print_rank_0("hasattr")
while True:
try:
......@@ -108,17 +106,6 @@ class IndexBuilder(object):
except (StopIteration, IndexError):
break
print_rank_0(context_tokens)
print_rank_0(context_mask)
print_rank_0(context_types)
#if torch.cuda.is_available():
# print_rank_0("cuda available")
#print_rank_0(torch.cuda.current_device())
#print_rank_0(torch.cuda.get_device_name())
print_rank_0(next(unwrapped_model.parameters()).device)
print_rank_0(next(unwrapped_model.context_model.parameters()).device)
#print_rank_0("After get_open_retrieval_batch")
# TODO: can we add with torch.no_grad() to reduce memory usage
# detach, separate fields and add to BlockData
assert context_mask.dtype == torch.bool
......@@ -126,8 +113,6 @@ class IndexBuilder(object):
unwrapped_model.context_model, context_tokens, context_mask,
context_types)
sys.exit()
context_logits = detach(context_logits)
row_id = detach(row_id)
......
......@@ -15,14 +15,21 @@ from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
def biencoder_model_provider(only_query_model=False,
only_context_model=False,
biencoder_shared_query_context_model=False,
pre_process=True,
#def biencoder_model_provider(only_query_model=False,
# only_context_model=False,
# biencoder_shared_query_context_model=False,
# pre_process=True,
# post_process=True):
def biencoder_model_provider(pre_process=True,
post_process=True):
"""Build the model."""
args = get_args()
biencoder_shared_query_context_model = args.biencoder_shared_query_context_model
only_context_model = args.only_context_model
only_query_model = args.only_query_model
assert mpu.get_tensor_model_parallel_world_size() == 1 and \
mpu.get_pipeline_model_parallel_world_size() == 1, \
"Model parallel size > 1 not supported for ICT"
......@@ -266,11 +273,6 @@ class PretrainedBertModel(MegatronModule):
#extended_attention_mask = bert_extended_attention_mask(attention_mask)
position_ids = bert_position_ids(input_ids)
print_rank_0(input_ids.device)
print_rank_0(position_ids.device)
print_rank_0(extended_attention_mask.device)
print_rank_0(tokentype_ids.device)
lm_output = self.language_model(input_ids,
position_ids,
extended_attention_mask,
......
......@@ -338,11 +338,6 @@ class TransformerLanguageModel(MegatronModule):
get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
print_rank_0("before self.embedding")
print_rank_0(enc_input_ids.device)
print_rank_0(enc_position_ids.device)
print_rank_0(tokentype_ids.device)
# Embeddings.
if self.pre_process:
embedding_output = self.embedding(enc_input_ids, enc_position_ids,
......
......@@ -33,11 +33,15 @@ from megatron.utils import average_losses_across_data_parallel_group
def pretrain_ict_model_provider():
    """Build the biencoder (ICT) model for pretraining.

    Both the query and context towers are built (neither `only_query_model`
    nor `only_context_model`), so the provider flags are cleared before
    delegating to `biencoder_model_provider`, which now reads its
    configuration from the global args object rather than from keyword
    arguments.

    Returns:
        The biencoder model instance produced by `biencoder_model_provider`.
    """
    args = get_args()
    # NOTE(review): configuring the provider through mutable global args is
    # fragile — callers must remember to reset these flags; consider restoring
    # explicit keyword arguments on biencoder_model_provider.
    args.only_context_model = False
    args.only_query_model = False
    model = biencoder_model_provider()
    return model
def get_group_world_size_rank():
......
......@@ -19,6 +19,7 @@ import os
import sys
from megatron import get_args
from megatron.indexer import IndexBuilder
from tasks.orqa.evaluate_utils import ORQAEvaluator
def main():
......@@ -28,6 +29,23 @@ def main():
args = get_args()
"""Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
- Include all args needed for initial model specification
Other key args:
--block-data-path: path to write to
--ict-load or --realm-load: path to checkpoint with which to embed
--data-path and --titles-data-path: paths for dataset
--indexer-log-interval: reporting interval
--indexer-batch-size: size specific for indexer jobs
Check README.md for example script
"""
index_builder = IndexBuilder()
index_builder.build_and_save_index()
print_rank_0("Build and save indices: done!")
# Set up the model and evaluator
evaluator = ORQAEvaluator()
......@@ -37,4 +55,4 @@ def main():
if args.qa_data_test is not None:
evaluator.evaluate(args.qa_data_test, "TEST")
......@@ -44,9 +44,14 @@ class ORQAEvaluator(object):
if args.biencoder_shared_query_context_model:
only_query_model = False
model = get_model(lambda: biencoder_model_provider(only_query_model=\
only_query_model, biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model))
args.only_query_model = only_query_model
args.only_context_model = False
#model = get_model(lambda: biencoder_model_provider(only_query_model=\
# only_query_model, biencoder_shared_query_context_model=\
# args.biencoder_shared_query_context_model))
model = get_model(biencoder_model_provider)
self.model = load_biencoder_checkpoint(model,
only_query_model=only_query_model)
......
......@@ -16,6 +16,7 @@
"""ORQA finetuning/evaluation."""
from functools import partial
import sys
import math
import torch
......@@ -183,11 +184,15 @@ def orqa(Dataset): # , name_from_datapath_func):
"""Build the model."""
args = get_args()
print_rank_0('building retriever model for {} ...'.format(args.task))
model = biencoder_model_provider(only_context_model=False,
only_query_model=False,
biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model,
pre_process=pre_process, post_process=post_process)
args.only_context_model=False
args.only_query_model=False
model = biencoder_model_provider()
#model = biencoder_model_provider(only_context_model=False,
# only_query_model=False,
# biencoder_shared_query_context_model=\
# args.biencoder_shared_query_context_model,
# pre_process=pre_process, post_process=post_process)
return model
def single_dataset_provider(datapath):
......@@ -228,29 +233,15 @@ def orqa(Dataset): # , name_from_datapath_func):
def main():
    """Entry point for ORQA retriever finetuning/evaluation.

    Dispatches on `args.task`: only the 'RET-FINETUNE-NQ' task (Natural
    Questions supervised finetuning) is supported; any other task name
    raises NotImplementedError. The selected Dataset class is handed to
    `orqa`, which runs the finetuning loop.

    Raises:
        NotImplementedError: if `args.task` is not 'RET-FINETUNE-NQ'.
    """
    args = get_args()

    if args.task == 'RET-FINETUNE-NQ':
        # Import lazily so unrelated tasks don't pay for (or fail on)
        # the NQ dataset dependencies.
        from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
    else:
        raise NotImplementedError('ORQA task {} is not implemented.'.format(
            args.task))

    orqa(Dataset)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment