Commit f9267205 authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

fixing model evaluation of retriever

parent 6d03d7af
...@@ -53,11 +53,12 @@ class IndexBuilder(object): ...@@ -53,11 +53,12 @@ class IndexBuilder(object):
args.only_context_model = only_context_model args.only_context_model = only_context_model
args.only_query_model = False args.only_query_model = False
model = get_model(biencoder_model_provider) #model = get_model(biencoder_model_provider)
#model = get_model(lambda: biencoder_model_provider(only_context_model \ #model = get_model(lambda: biencoder_model_provider(only_context_model \
# = only_context_model, biencoder_shared_query_context_model = \ model = get_model(biencoder_model_provider(only_context_model \
# self.biencoder_shared_query_context_model)) = only_context_model, biencoder_shared_query_context_model = \
self.biencoder_shared_query_context_model))
self.model = load_biencoder_checkpoint(model, self.model = load_biencoder_checkpoint(model,
only_context_model=only_context_model) only_context_model=only_context_model)
......
...@@ -15,20 +15,20 @@ from megatron.model.utils import init_method_normal ...@@ -15,20 +15,20 @@ from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule from .module import MegatronModule
#def biencoder_model_provider(only_query_model=False, #def biencoder_model_provider(pre_process=True,
# only_context_model=False,
# biencoder_shared_query_context_model=False,
# pre_process=True,
# post_process=True): # post_process=True):
def biencoder_model_provider(pre_process=True, def biencoder_model_provider(only_query_model=False,
only_context_model=False,
biencoder_shared_query_context_model=False,
pre_process=True,
post_process=True): post_process=True):
"""Build the model.""" """Build the model."""
args = get_args() #args = get_args()
biencoder_shared_query_context_model = args.biencoder_shared_query_context_model #biencoder_shared_query_context_model = args.biencoder_shared_query_context_model
only_context_model = args.only_context_model #only_context_model = args.only_context_model
only_query_model = args.only_query_model #only_query_model = args.only_query_model
assert mpu.get_tensor_model_parallel_world_size() == 1 and \ assert mpu.get_tensor_model_parallel_world_size() == 1 and \
mpu.get_pipeline_model_parallel_world_size() == 1, \ mpu.get_pipeline_model_parallel_world_size() == 1, \
......
...@@ -33,15 +33,15 @@ from megatron.utils import average_losses_across_data_parallel_group ...@@ -33,15 +33,15 @@ from megatron.utils import average_losses_across_data_parallel_group
def pretrain_ict_model_provider(): def pretrain_ict_model_provider():
args = get_args() args = get_args()
args.only_context_model = False #args.only_context_model = False
args.only_query_model = False #args.only_query_model = False
model = biencoder_model_provider() #model = biencoder_model_provider()
#model = biencoder_model_provider( model = biencoder_model_provider(
# only_context_model=False, only_context_model=False,
# only_query_model=False, only_query_model=False,
# biencoder_shared_query_context_model=\ biencoder_shared_query_context_model=\
# args.biencoder_shared_query_context_model) args.biencoder_shared_query_context_model)
return model return model
def get_group_world_size_rank(): def get_group_world_size_rank():
......
...@@ -110,7 +110,7 @@ if __name__ == '__main__': ...@@ -110,7 +110,7 @@ if __name__ == '__main__':
from glue.finetune import main from glue.finetune import main
elif args.task in ['LAMBADA', 'WIKITEXT103']: elif args.task in ['LAMBADA', 'WIKITEXT103']:
from zeroshot_gpt.evaluate import main from zeroshot_gpt.evaluate import main
elif args.task in ['ICT-ZEROSHOT-NQ']: elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']:
from orqa.evaluate_orqa import main from orqa.evaluate_orqa import main
elif args.task in ['RET-FINETUNE-NQ']: elif args.task in ['RET-FINETUNE-NQ']:
from orqa.supervised.finetune import main from orqa.supervised.finetune import main
......
...@@ -18,6 +18,15 @@ ...@@ -18,6 +18,15 @@
import os import os
import sys import sys
#sys.path.append(
# os.path.abspath(
# os.path.join(
# os.path.join(os.path.dirname(__file__), os.path.pardir),
# os.path.pardir,
# )
# )
#)
from megatron import get_args from megatron import get_args
from megatron.indexer import IndexBuilder from megatron.indexer import IndexBuilder
from tasks.orqa.evaluate_utils import ORQAEvaluator from tasks.orqa.evaluate_utils import ORQAEvaluator
...@@ -26,6 +35,8 @@ def main(): ...@@ -26,6 +35,8 @@ def main():
""" """
Main program Main program
""" """
#initialize_megatron(extra_args_provider=None,
# args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
args = get_args() args = get_args()
...@@ -42,6 +53,8 @@ def main(): ...@@ -42,6 +53,8 @@ def main():
Check README.md for example script Check README.md for example script
""" """
#print_rank_0("Starting index builder!")
index_builder = IndexBuilder() index_builder = IndexBuilder()
index_builder.build_and_save_index() index_builder.build_and_save_index()
print_rank_0("Build and save indices: done!") print_rank_0("Build and save indices: done!")
......
...@@ -44,14 +44,14 @@ class ORQAEvaluator(object): ...@@ -44,14 +44,14 @@ class ORQAEvaluator(object):
if args.biencoder_shared_query_context_model: if args.biencoder_shared_query_context_model:
only_query_model = False only_query_model = False
args.only_query_model = only_query_model #args.only_query_model = only_query_model
args.only_context_model = False #args.only_context_model = False
#model = get_model(lambda: biencoder_model_provider(only_query_model=\ model = get_model(lambda: biencoder_model_provider(only_query_model=\
# only_query_model, biencoder_shared_query_context_model=\ only_query_model, biencoder_shared_query_context_model=\
# args.biencoder_shared_query_context_model)) args.biencoder_shared_query_context_model))
model = get_model(biencoder_model_provider) #model = get_model(biencoder_model_provider)
self.model = load_biencoder_checkpoint(model, self.model = load_biencoder_checkpoint(model,
only_query_model=only_query_model) only_query_model=only_query_model)
......
...@@ -184,15 +184,16 @@ def orqa(Dataset): # , name_from_datapath_func): ...@@ -184,15 +184,16 @@ def orqa(Dataset): # , name_from_datapath_func):
"""Build the model.""" """Build the model."""
args = get_args() args = get_args()
print_rank_0('building retriever model for {} ...'.format(args.task)) print_rank_0('building retriever model for {} ...'.format(args.task))
args.only_context_model=False #args.only_context_model=False
args.only_query_model=False #args.only_query_model=False
model = biencoder_model_provider() #model = biencoder_model_provider()
#model = biencoder_model_provider(only_context_model=False, model = biencoder_model_provider(only_context_model=False,
# only_query_model=False, only_query_model=False,
# biencoder_shared_query_context_model=\ biencoder_shared_query_context_model=\
# args.biencoder_shared_query_context_model, args.biencoder_shared_query_context_model,
# pre_process=pre_process, post_process=post_process) pre_process=pre_process, post_process=post_process)
return model return model
def single_dataset_provider(datapath): def single_dataset_provider(datapath):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment