Commit f64977fd authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

evaluation works!

parent 7e335e15
...@@ -9,7 +9,7 @@ from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset ...@@ -9,7 +9,7 @@ from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, OpenRetreivalDataStore from megatron.data.realm_index import detach, OpenRetreivalDataStore
from megatron.model.biencoder_model import biencoder_model_provider from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model from megatron.training import get_model
...@@ -50,16 +50,19 @@ class IndexBuilder(object): ...@@ -50,16 +50,19 @@ class IndexBuilder(object):
if self.biencoder_shared_query_context_model: if self.biencoder_shared_query_context_model:
only_context_model = False only_context_model = False
args.only_context_model = only_context_model #args.only_context_model = only_context_model
args.only_query_model = False #args.only_query_model = False
#model = get_model(biencoder_model_provider) #model = get_model(biencoder_model_provider)
model = get_model(get_model_provider(only_context_model=only_context_model,
biencoder_shared_query_context_model=self.biencoder_shared_query_context_model))
#model = get_model(lambda: biencoder_model_provider(only_context_model \
#model = get_model(lambda: biencoder_model_provider(only_context_model \ #model = get_model(lambda: biencoder_model_provider(only_context_model \
model = get_model(biencoder_model_provider(only_context_model \ # = only_context_model, biencoder_shared_query_context_model = \
= only_context_model, biencoder_shared_query_context_model = \ # self.biencoder_shared_query_context_model,
self.biencoder_shared_query_context_model, # pre_process=True, post_process=True)
pre_process=True, post_process=True))
self.model = load_biencoder_checkpoint(model, self.model = load_biencoder_checkpoint(model,
only_context_model=only_context_model) only_context_model=only_context_model)
......
...@@ -15,6 +15,25 @@ from megatron.model.utils import init_method_normal ...@@ -15,6 +15,25 @@ from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule from .module import MegatronModule
def get_model_provider(only_query_model=False, only_context_model=False,
biencoder_shared_query_context_model=False):
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building Bienoder model ...')
model = biencoder_model_provider(only_query_model=only_query_model,
only_context_model = only_context_model,
biencoder_shared_query_context_model = \
biencoder_shared_query_context_model,
pre_process=True, post_process=True)
return model
return model_provider
#def biencoder_model_provider(pre_process=True, #def biencoder_model_provider(pre_process=True,
# post_process=True): # post_process=True):
......
...@@ -27,7 +27,7 @@ import sys ...@@ -27,7 +27,7 @@ import sys
# ) # )
#) #)
from megatron import get_args from megatron import get_args, print_rank_0
from megatron.indexer import IndexBuilder from megatron.indexer import IndexBuilder
from tasks.orqa.evaluate_utils import ORQAEvaluator from tasks.orqa.evaluate_utils import ORQAEvaluator
......
...@@ -23,7 +23,7 @@ from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader ...@@ -23,7 +23,7 @@ from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader
from tasks.orqa.natural_questions.nq import process_nq_batch from tasks.orqa.natural_questions.nq import process_nq_batch
from tasks.orqa.natural_questions.qa_utils import calculate_matches from tasks.orqa.natural_questions.qa_utils import calculate_matches
from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
from megatron.model.biencoder_model import biencoder_model_provider from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model from megatron.training import get_model
class ORQAEvaluator(object): class ORQAEvaluator(object):
...@@ -47,11 +47,15 @@ class ORQAEvaluator(object): ...@@ -47,11 +47,15 @@ class ORQAEvaluator(object):
#args.only_query_model = only_query_model #args.only_query_model = only_query_model
#args.only_context_model = False #args.only_context_model = False
model = get_model(get_model_provider(only_query_model=only_query_model,
biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))
#model = get_model(lambda: biencoder_model_provider(only_query_model=\
#model = get_model(lambda: biencoder_model_provider(only_query_model=\ #model = get_model(lambda: biencoder_model_provider(only_query_model=\
model = get_model(lambda: biencoder_model_provider(only_query_model=\ # only_query_model, biencoder_shared_query_context_model=\
only_query_model, biencoder_shared_query_context_model=\ # args.biencoder_shared_query_context_model,
args.biencoder_shared_query_context_model, # pre_process=True, post_process=True))
pre_process=True, post_process=True))
#model = get_model(biencoder_model_provider) #model = get_model(biencoder_model_provider)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment