"vscode:/vscode.git/clone" did not exist on "76702a03d6cc2e4f431bfd1914d5e301c07bd489"
Commit f64977fd authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

evaluation works!

parent 7e335e15
......@@ -9,7 +9,7 @@ from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, OpenRetreivalDataStore
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model
......@@ -50,16 +50,19 @@ class IndexBuilder(object):
if self.biencoder_shared_query_context_model:
only_context_model = False
args.only_context_model = only_context_model
args.only_query_model = False
#args.only_context_model = only_context_model
#args.only_query_model = False
#model = get_model(biencoder_model_provider)
model = get_model(get_model_provider(only_context_model=only_context_model,
biencoder_shared_query_context_model=self.biencoder_shared_query_context_model))
#model = get_model(lambda: biencoder_model_provider(only_context_model \
#model = get_model(lambda: biencoder_model_provider(only_context_model \
model = get_model(biencoder_model_provider(only_context_model \
= only_context_model, biencoder_shared_query_context_model = \
self.biencoder_shared_query_context_model,
pre_process=True, post_process=True))
# = only_context_model, biencoder_shared_query_context_model = \
# self.biencoder_shared_query_context_model,
# pre_process=True, post_process=True)
self.model = load_biencoder_checkpoint(model,
only_context_model=only_context_model)
......
......@@ -15,6 +15,25 @@ from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
def get_model_provider(only_query_model=False, only_context_model=False,
biencoder_shared_query_context_model=False):
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building Bienoder model ...')
model = biencoder_model_provider(only_query_model=only_query_model,
only_context_model = only_context_model,
biencoder_shared_query_context_model = \
biencoder_shared_query_context_model,
pre_process=True, post_process=True)
return model
return model_provider
#def biencoder_model_provider(pre_process=True,
# post_process=True):
......
......@@ -27,7 +27,7 @@ import sys
# )
#)
from megatron import get_args
from megatron import get_args, print_rank_0
from megatron.indexer import IndexBuilder
from tasks.orqa.evaluate_utils import ORQAEvaluator
......
......@@ -23,7 +23,7 @@ from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader
from tasks.orqa.natural_questions.nq import process_nq_batch
from tasks.orqa.natural_questions.qa_utils import calculate_matches
from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model
class ORQAEvaluator(object):
......@@ -47,11 +47,15 @@ class ORQAEvaluator(object):
#args.only_query_model = only_query_model
#args.only_context_model = False
model = get_model(get_model_provider(only_query_model=only_query_model,
biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))
#model = get_model(lambda: biencoder_model_provider(only_query_model=\
#model = get_model(lambda: biencoder_model_provider(only_query_model=\
model = get_model(lambda: biencoder_model_provider(only_query_model=\
only_query_model, biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model,
pre_process=True, post_process=True))
# only_query_model, biencoder_shared_query_context_model=\
# args.biencoder_shared_query_context_model,
# pre_process=True, post_process=True))
#model = get_model(biencoder_model_provider)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment