Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
f9267205
Commit
f9267205
authored
May 12, 2021
by
Mostofa Patwary
Browse files
fixing model evaluation of retriever
parent
6d03d7af
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
51 additions
and
36 deletions
+51
-36
megatron/indexer.py
megatron/indexer.py
+4
-3
megatron/model/biencoder_model.py
megatron/model/biencoder_model.py
+10
-10
pretrain_ict.py
pretrain_ict.py
+8
-8
tasks/main.py
tasks/main.py
+1
-1
tasks/orqa/evaluate_orqa.py
tasks/orqa/evaluate_orqa.py
+13
-0
tasks/orqa/evaluate_utils.py
tasks/orqa/evaluate_utils.py
+6
-6
tasks/orqa/supervised/finetune.py
tasks/orqa/supervised/finetune.py
+9
-8
No files found.
megatron/indexer.py
View file @
f9267205
...
...
@@ -53,11 +53,12 @@ class IndexBuilder(object):
args
.
only_context_model
=
only_context_model
args
.
only_query_model
=
False
model
=
get_model
(
biencoder_model_provider
)
#
model = get_model(biencoder_model_provider)
#model = get_model(lambda: biencoder_model_provider(only_context_model \
# = only_context_model, biencoder_shared_query_context_model = \
# self.biencoder_shared_query_context_model))
model
=
get_model
(
biencoder_model_provider
(
only_context_model
\
=
only_context_model
,
biencoder_shared_query_context_model
=
\
self
.
biencoder_shared_query_context_model
))
self
.
model
=
load_biencoder_checkpoint
(
model
,
only_context_model
=
only_context_model
)
...
...
megatron/model/biencoder_model.py
View file @
f9267205
...
...
@@ -15,20 +15,20 @@ from megatron.model.utils import init_method_normal
from
megatron.model.utils
import
scaled_init_method_normal
from
.module
import
MegatronModule
#def biencoder_model_provider(only_query_model=False,
# only_context_model=False,
# biencoder_shared_query_context_model=False,
# pre_process=True,
#def biencoder_model_provider(pre_process=True,
# post_process=True):
def
biencoder_model_provider
(
pre_process
=
True
,
def
biencoder_model_provider
(
only_query_model
=
False
,
only_context_model
=
False
,
biencoder_shared_query_context_model
=
False
,
pre_process
=
True
,
post_process
=
True
):
"""Build the model."""
args
=
get_args
()
#
args = get_args()
biencoder_shared_query_context_model
=
args
.
biencoder_shared_query_context_model
only_context_model
=
args
.
only_context_model
only_query_model
=
args
.
only_query_model
#
biencoder_shared_query_context_model = args.biencoder_shared_query_context_model
#
only_context_model = args.only_context_model
#
only_query_model = args.only_query_model
assert
mpu
.
get_tensor_model_parallel_world_size
()
==
1
and
\
mpu
.
get_pipeline_model_parallel_world_size
()
==
1
,
\
...
...
pretrain_ict.py
View file @
f9267205
...
...
@@ -33,15 +33,15 @@ from megatron.utils import average_losses_across_data_parallel_group
def
pretrain_ict_model_provider
():
args
=
get_args
()
args
.
only_context_model
=
False
args
.
only_query_model
=
False
model
=
biencoder_model_provider
()
#
args.only_context_model = False
#
args.only_query_model = False
#
model = biencoder_model_provider()
#
model = biencoder_model_provider(
#
only_context_model=False,
#
only_query_model=False,
#
biencoder_shared_query_context_model=\
#
args.biencoder_shared_query_context_model)
model
=
biencoder_model_provider
(
only_context_model
=
False
,
only_query_model
=
False
,
biencoder_shared_query_context_model
=
\
args
.
biencoder_shared_query_context_model
)
return
model
def
get_group_world_size_rank
():
...
...
tasks/main.py
View file @
f9267205
...
...
@@ -110,7 +110,7 @@ if __name__ == '__main__':
from
glue.finetune
import
main
elif
args
.
task
in
[
'LAMBADA'
,
'WIKITEXT103'
]:
from
zeroshot_gpt.evaluate
import
main
elif
args
.
task
in
[
'ICT-ZEROSHOT-NQ'
]:
elif
args
.
task
in
[
'ICT-ZEROSHOT-NQ'
,
'RETRIEVER-EVAL'
]:
from
orqa.evaluate_orqa
import
main
elif
args
.
task
in
[
'RET-FINETUNE-NQ'
]:
from
orqa.supervised.finetune
import
main
...
...
tasks/orqa/evaluate_orqa.py
View file @
f9267205
...
...
@@ -18,6 +18,15 @@
import
os
import
sys
#sys.path.append(
# os.path.abspath(
# os.path.join(
# os.path.join(os.path.dirname(__file__), os.path.pardir),
# os.path.pardir,
# )
# )
#)
from
megatron
import
get_args
from
megatron.indexer
import
IndexBuilder
from
tasks.orqa.evaluate_utils
import
ORQAEvaluator
...
...
@@ -26,6 +35,8 @@ def main():
"""
Main program
"""
#initialize_megatron(extra_args_provider=None,
# args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
args
=
get_args
()
...
...
@@ -42,6 +53,8 @@ def main():
Check README.md for example script
"""
#print_rank_0("Starting index builder!")
index_builder
=
IndexBuilder
()
index_builder
.
build_and_save_index
()
print_rank_0
(
"Build and save indices: done!"
)
...
...
tasks/orqa/evaluate_utils.py
View file @
f9267205
...
...
@@ -44,14 +44,14 @@ class ORQAEvaluator(object):
if
args
.
biencoder_shared_query_context_model
:
only_query_model
=
False
args
.
only_query_model
=
only_query_model
args
.
only_context_model
=
False
#
args.only_query_model = only_query_model
#
args.only_context_model = False
#
model = get_model(lambda: biencoder_model_provider(only_query_model=\
#
only_query_model, biencoder_shared_query_context_model=\
#
args.biencoder_shared_query_context_model))
model
=
get_model
(
lambda
:
biencoder_model_provider
(
only_query_model
=
\
only_query_model
,
biencoder_shared_query_context_model
=
\
args
.
biencoder_shared_query_context_model
))
model
=
get_model
(
biencoder_model_provider
)
#
model = get_model(biencoder_model_provider)
self
.
model
=
load_biencoder_checkpoint
(
model
,
only_query_model
=
only_query_model
)
...
...
tasks/orqa/supervised/finetune.py
View file @
f9267205
...
...
@@ -184,15 +184,16 @@ def orqa(Dataset): # , name_from_datapath_func):
"""Build the model."""
args
=
get_args
()
print_rank_0
(
'building retriever model for {} ...'
.
format
(
args
.
task
))
args
.
only_context_model
=
False
args
.
only_query_model
=
False
model
=
biencoder_model_provider
()
#
args.only_context_model=False
#
args.only_query_model=False
#
model = biencoder_model_provider()
#model = biencoder_model_provider(only_context_model=False,
# only_query_model=False,
# biencoder_shared_query_context_model=\
# args.biencoder_shared_query_context_model,
# pre_process=pre_process, post_process=post_process)
model
=
biencoder_model_provider
(
only_context_model
=
False
,
only_query_model
=
False
,
biencoder_shared_query_context_model
=
\
args
.
biencoder_shared_query_context_model
,
pre_process
=
pre_process
,
post_process
=
post_process
)
return
model
def
single_dataset_provider
(
datapath
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment