Commit f9267205 authored by Mostofa Patwary
Browse files

fixing model evaluation of retriever

parent 6d03d7af
......@@ -53,11 +53,12 @@ class IndexBuilder(object):
args.only_context_model = only_context_model
args.only_query_model = False
model = get_model(biencoder_model_provider)
#model = get_model(biencoder_model_provider)
#model = get_model(lambda: biencoder_model_provider(only_context_model \
# = only_context_model, biencoder_shared_query_context_model = \
# self.biencoder_shared_query_context_model))
model = get_model(biencoder_model_provider(only_context_model \
= only_context_model, biencoder_shared_query_context_model = \
self.biencoder_shared_query_context_model))
self.model = load_biencoder_checkpoint(model,
only_context_model=only_context_model)
......
......@@ -15,20 +15,20 @@ from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
#def biencoder_model_provider(only_query_model=False,
# only_context_model=False,
# biencoder_shared_query_context_model=False,
# pre_process=True,
#def biencoder_model_provider(pre_process=True,
# post_process=True):
def biencoder_model_provider(pre_process=True,
def biencoder_model_provider(only_query_model=False,
only_context_model=False,
biencoder_shared_query_context_model=False,
pre_process=True,
post_process=True):
"""Build the model."""
args = get_args()
#args = get_args()
biencoder_shared_query_context_model = args.biencoder_shared_query_context_model
only_context_model = args.only_context_model
only_query_model = args.only_query_model
#biencoder_shared_query_context_model = args.biencoder_shared_query_context_model
#only_context_model = args.only_context_model
#only_query_model = args.only_query_model
assert mpu.get_tensor_model_parallel_world_size() == 1 and \
mpu.get_pipeline_model_parallel_world_size() == 1, \
......
......@@ -33,15 +33,15 @@ from megatron.utils import average_losses_across_data_parallel_group
def pretrain_ict_model_provider():
args = get_args()
args.only_context_model = False
args.only_query_model = False
model = biencoder_model_provider()
#args.only_context_model = False
#args.only_query_model = False
#model = biencoder_model_provider()
#model = biencoder_model_provider(
# only_context_model=False,
# only_query_model=False,
# biencoder_shared_query_context_model=\
# args.biencoder_shared_query_context_model)
model = biencoder_model_provider(
only_context_model=False,
only_query_model=False,
biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model)
return model
def get_group_world_size_rank():
......
......@@ -110,7 +110,7 @@ if __name__ == '__main__':
from glue.finetune import main
elif args.task in ['LAMBADA', 'WIKITEXT103']:
from zeroshot_gpt.evaluate import main
elif args.task in ['ICT-ZEROSHOT-NQ']:
elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']:
from orqa.evaluate_orqa import main
elif args.task in ['RET-FINETUNE-NQ']:
from orqa.supervised.finetune import main
......
......@@ -18,6 +18,15 @@
import os
import sys
#sys.path.append(
# os.path.abspath(
# os.path.join(
# os.path.join(os.path.dirname(__file__), os.path.pardir),
# os.path.pardir,
# )
# )
#)
from megatron import get_args
from megatron.indexer import IndexBuilder
from tasks.orqa.evaluate_utils import ORQAEvaluator
......@@ -26,6 +35,8 @@ def main():
"""
Main program
"""
#initialize_megatron(extra_args_provider=None,
# args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
args = get_args()
......@@ -42,6 +53,8 @@ def main():
Check README.md for example script
"""
#print_rank_0("Starting index builder!")
index_builder = IndexBuilder()
index_builder.build_and_save_index()
print_rank_0("Build and save indices: done!")
......
......@@ -44,14 +44,14 @@ class ORQAEvaluator(object):
if args.biencoder_shared_query_context_model:
only_query_model = False
args.only_query_model = only_query_model
args.only_context_model = False
#args.only_query_model = only_query_model
#args.only_context_model = False
#model = get_model(lambda: biencoder_model_provider(only_query_model=\
# only_query_model, biencoder_shared_query_context_model=\
# args.biencoder_shared_query_context_model))
model = get_model(lambda: biencoder_model_provider(only_query_model=\
only_query_model, biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model))
model = get_model(biencoder_model_provider)
#model = get_model(biencoder_model_provider)
self.model = load_biencoder_checkpoint(model,
only_query_model=only_query_model)
......
......@@ -184,15 +184,16 @@ def orqa(Dataset): # , name_from_datapath_func):
"""Build the model."""
args = get_args()
print_rank_0('building retriever model for {} ...'.format(args.task))
args.only_context_model=False
args.only_query_model=False
model = biencoder_model_provider()
#args.only_context_model=False
#args.only_query_model=False
#model = biencoder_model_provider()
#model = biencoder_model_provider(only_context_model=False,
# only_query_model=False,
# biencoder_shared_query_context_model=\
# args.biencoder_shared_query_context_model,
# pre_process=pre_process, post_process=post_process)
model = biencoder_model_provider(only_context_model=False,
only_query_model=False,
biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model,
pre_process=pre_process, post_process=post_process)
return model
def single_dataset_provider(datapath):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment