Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
f64977fd
Commit
f64977fd
authored
May 13, 2021
by
Mostofa Patwary
Browse files
evaluation works!
parent
7e335e15
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
39 additions
and
13 deletions
+39
-13
megatron/indexer.py
megatron/indexer.py
+10
-7
megatron/model/biencoder_model.py
megatron/model/biencoder_model.py
+19
-0
tasks/orqa/evaluate_orqa.py
tasks/orqa/evaluate_orqa.py
+1
-1
tasks/orqa/evaluate_utils.py
tasks/orqa/evaluate_utils.py
+9
-5
No files found.
megatron/indexer.py
View file @
f64977fd
...
@@ -9,7 +9,7 @@ from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
...
@@ -9,7 +9,7 @@ from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from
megatron.data.orqa_wiki_dataset
import
get_open_retrieval_batch
from
megatron.data.orqa_wiki_dataset
import
get_open_retrieval_batch
from
megatron.data.biencoder_dataset_utils
import
get_one_epoch_dataloader
from
megatron.data.biencoder_dataset_utils
import
get_one_epoch_dataloader
from
megatron.data.realm_index
import
detach
,
OpenRetreivalDataStore
from
megatron.data.realm_index
import
detach
,
OpenRetreivalDataStore
from
megatron.model.biencoder_model
import
biencoder
_model_provider
from
megatron.model.biencoder_model
import
get
_model_provider
from
megatron.training
import
get_model
from
megatron.training
import
get_model
...
@@ -50,16 +50,19 @@ class IndexBuilder(object):
...
@@ -50,16 +50,19 @@ class IndexBuilder(object):
if
self
.
biencoder_shared_query_context_model
:
if
self
.
biencoder_shared_query_context_model
:
only_context_model
=
False
only_context_model
=
False
args
.
only_context_model
=
only_context_model
#
args.only_context_model = only_context_model
args
.
only_query_model
=
False
#
args.only_query_model = False
#model = get_model(biencoder_model_provider)
#model = get_model(biencoder_model_provider)
model
=
get_model
(
get_model_provider
(
only_context_model
=
only_context_model
,
biencoder_shared_query_context_model
=
self
.
biencoder_shared_query_context_model
))
#model = get_model(lambda: biencoder_model_provider(only_context_model \
#model = get_model(lambda: biencoder_model_provider(only_context_model \
#model = get_model(lambda: biencoder_model_provider(only_context_model \
model
=
get_model
(
biencoder_model_provider
(
only_context_model
\
# = only_context_model, biencoder_shared_query_context_model = \
=
only_context_model
,
biencoder_shared_query_context_model
=
\
# self.biencoder_shared_query_context_model,
self
.
biencoder_shared_query_context_model
,
# pre_process=True, post_process=True)
pre_process
=
True
,
post_process
=
True
))
self
.
model
=
load_biencoder_checkpoint
(
model
,
self
.
model
=
load_biencoder_checkpoint
(
model
,
only_context_model
=
only_context_model
)
only_context_model
=
only_context_model
)
...
...
megatron/model/biencoder_model.py
View file @
f64977fd
...
@@ -15,6 +15,25 @@ from megatron.model.utils import init_method_normal
...
@@ -15,6 +15,25 @@ from megatron.model.utils import init_method_normal
from
megatron.model.utils
import
scaled_init_method_normal
from
megatron.model.utils
import
scaled_init_method_normal
from
.module
import
MegatronModule
from
.module
import
MegatronModule
def
get_model_provider
(
only_query_model
=
False
,
only_context_model
=
False
,
biencoder_shared_query_context_model
=
False
):
def
model_provider
(
pre_process
=
True
,
post_process
=
True
):
"""Build the model."""
print_rank_0
(
'building Bienoder model ...'
)
model
=
biencoder_model_provider
(
only_query_model
=
only_query_model
,
only_context_model
=
only_context_model
,
biencoder_shared_query_context_model
=
\
biencoder_shared_query_context_model
,
pre_process
=
True
,
post_process
=
True
)
return
model
return
model_provider
#def biencoder_model_provider(pre_process=True,
#def biencoder_model_provider(pre_process=True,
# post_process=True):
# post_process=True):
...
...
tasks/orqa/evaluate_orqa.py
View file @
f64977fd
...
@@ -27,7 +27,7 @@ import sys
...
@@ -27,7 +27,7 @@ import sys
# )
# )
#)
#)
from
megatron
import
get_args
from
megatron
import
get_args
,
print_rank_0
from
megatron.indexer
import
IndexBuilder
from
megatron.indexer
import
IndexBuilder
from
tasks.orqa.evaluate_utils
import
ORQAEvaluator
from
tasks.orqa.evaluate_utils
import
ORQAEvaluator
...
...
tasks/orqa/evaluate_utils.py
View file @
f64977fd
...
@@ -23,7 +23,7 @@ from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader
...
@@ -23,7 +23,7 @@ from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader
from
tasks.orqa.natural_questions.nq
import
process_nq_batch
from
tasks.orqa.natural_questions.nq
import
process_nq_batch
from
tasks.orqa.natural_questions.qa_utils
import
calculate_matches
from
tasks.orqa.natural_questions.qa_utils
import
calculate_matches
from
megatron.data.realm_index
import
OpenRetreivalDataStore
,
FaissMIPSIndex
from
megatron.data.realm_index
import
OpenRetreivalDataStore
,
FaissMIPSIndex
from
megatron.model.biencoder_model
import
biencoder
_model_provider
from
megatron.model.biencoder_model
import
get
_model_provider
from
megatron.training
import
get_model
from
megatron.training
import
get_model
class
ORQAEvaluator
(
object
):
class
ORQAEvaluator
(
object
):
...
@@ -47,11 +47,15 @@ class ORQAEvaluator(object):
...
@@ -47,11 +47,15 @@ class ORQAEvaluator(object):
#args.only_query_model = only_query_model
#args.only_query_model = only_query_model
#args.only_context_model = False
#args.only_context_model = False
model
=
get_model
(
get_model_provider
(
only_query_model
=
only_query_model
,
biencoder_shared_query_context_model
=
args
.
biencoder_shared_query_context_model
))
#model = get_model(lambda: biencoder_model_provider(only_query_model=\
#model = get_model(lambda: biencoder_model_provider(only_query_model=\
#model = get_model(lambda: biencoder_model_provider(only_query_model=\
model
=
get_model
(
lambda
:
biencoder_model_provider
(
only_query_model
=
\
# only_query_model, biencoder_shared_query_context_model=\
only_query_model
,
biencoder_shared_query_context_model
=
\
# args.biencoder_shared_query_context_model,
args
.
biencoder_shared_query_context_model
,
# pre_process=True, post_process=True))
pre_process
=
True
,
post_process
=
True
))
#model = get_model(biencoder_model_provider)
#model = get_model(biencoder_model_provider)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment