Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wuxk1
Megatron-LM
Commits
6d03d7af
"torchvision/csrc/cpu/image" did not exist on "fa6af6d1e6f050d0a930e6f59894b7dd40aa869c"
Commit
6d03d7af
authored
May 11, 2021
by
Mostofa Patwary
Browse files
DPR finetune and evaluation
parent
d2d5086e
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
78 additions
and
80 deletions
+78
-80
megatron/checkpointing.py
megatron/checkpointing.py
+3
-5
megatron/indexer.py
megatron/indexer.py
+12
-27
megatron/model/biencoder_model.py
megatron/model/biencoder_model.py
+11
-9
megatron/model/language_model.py
megatron/model/language_model.py
+0
-5
pretrain_ict.py
pretrain_ict.py
+9
-5
tasks/orqa/evaluate_orqa.py
tasks/orqa/evaluate_orqa.py
+19
-1
tasks/orqa/evaluate_utils.py
tasks/orqa/evaluate_utils.py
+8
-3
tasks/orqa/supervised/finetune.py
tasks/orqa/supervised/finetune.py
+16
-25
No files found.
megatron/checkpointing.py
View file @
6d03d7af
...
...
@@ -413,11 +413,9 @@ def load_biencoder_checkpoint(model, only_query_model=False,
if
only_context_model
:
ret_state_dict
.
pop
(
'query_model'
)
#print_rank_0(len(model))
#sys.exit()
#assert len(model) == 1
#model[0].load_state_dict(ret_state_dict)
model
.
load_state_dict
(
ret_state_dict
)
assert
len
(
model
)
==
1
model
[
0
].
load_state_dict
(
ret_state_dict
)
torch
.
distributed
.
barrier
()
if
mpu
.
get_data_parallel_rank
()
==
0
:
...
...
megatron/indexer.py
View file @
6d03d7af
...
...
@@ -45,26 +45,25 @@ class IndexBuilder(object):
"""
Load the necessary attributes: model, dataloader and empty BlockData
"""
args
=
get_args
()
only_context_model
=
True
if
self
.
biencoder_shared_query_context_model
:
only_context_model
=
False
model
=
get_model
(
lambda
:
biencoder_model_provider
(
only_context_model
\
=
only_context_model
,
biencoder_shared_query_context
_model
=
\
self
.
biencoder_shared_query_context_model
,
\
pre_process
=
self
.
pre_process
,
post_process
=
self
.
post_process
)
)
args
.
only_context_model
=
only_context_model
args
.
only_query
_model
=
False
model
=
get_model
(
biencoder_model_provider
)
#model = biencoder_model_provider(only_context_model \
#model =
get_model(lambda:
biencoder_model_provider(only_context_model \
# = only_context_model, biencoder_shared_query_context_model = \
# self.biencoder_shared_query_context_model, \
# pre_process=self.pre_process, post_process=self.post_process)
# self.biencoder_shared_query_context_model))
self
.
model
=
load_biencoder_checkpoint
(
model
,
only_context_model
=
only_context_model
)
#assert len(self.model) == 1
#self.model[0].eval()
self
.
model
.
eval
()
assert
len
(
self
.
model
)
==
1
self
.
model
[
0
].
eval
()
self
.
dataset
=
get_open_retrieval_wiki_dataset
()
self
.
dataloader
=
iter
(
get_one_epoch_dataloader
(
self
.
dataset
,
\
...
...
@@ -92,12 +91,11 @@ class IndexBuilder(object):
distributed setting will be consolidated by the rank 0 process
and saved as a final pickled BlockData.
"""
#
assert len(self.model) == 1
#
unwrapped_model = self.model[0]
unwrapped_model
=
self
.
model
assert
len
(
self
.
model
)
==
1
unwrapped_model
=
self
.
model
[
0
]
while
not
hasattr
(
unwrapped_model
,
'embed_text'
):
unwrapped_model
=
unwrapped_model
.
module
print_rank_0
(
"hasattr"
)
while
True
:
try
:
...
...
@@ -108,17 +106,6 @@ class IndexBuilder(object):
except
(
StopIteration
,
IndexError
):
break
print_rank_0
(
context_tokens
)
print_rank_0
(
context_mask
)
print_rank_0
(
context_types
)
#if torch.cuda.is_available():
# print_rank_0("cuda available")
#print_rank_0(torch.cuda.current_device())
#print_rank_0(torch.cuda.get_device_name())
print_rank_0
(
next
(
unwrapped_model
.
parameters
()).
device
)
print_rank_0
(
next
(
unwrapped_model
.
context_model
.
parameters
()).
device
)
#print_rank_0("After get_open_retrieval_batch")
# TODO: can we add with torch.no_grad() to reduce memory usage
# detach, separate fields and add to BlockData
assert
context_mask
.
dtype
==
torch
.
bool
...
...
@@ -126,8 +113,6 @@ class IndexBuilder(object):
unwrapped_model
.
context_model
,
context_tokens
,
context_mask
,
context_types
)
sys
.
exit
()
context_logits
=
detach
(
context_logits
)
row_id
=
detach
(
row_id
)
...
...
megatron/model/biencoder_model.py
View file @
6d03d7af
...
...
@@ -15,14 +15,21 @@ from megatron.model.utils import init_method_normal
from
megatron.model.utils
import
scaled_init_method_normal
from
.module
import
MegatronModule
def
biencoder_model_provider
(
only_query_model
=
False
,
only_context_model
=
False
,
biencoder_shared_query_context_model
=
False
,
pre_process
=
True
,
#def biencoder_model_provider(only_query_model=False,
# only_context_model=False,
# biencoder_shared_query_context_model=False,
# pre_process=True,
# post_process=True):
def
biencoder_model_provider
(
pre_process
=
True
,
post_process
=
True
):
"""Build the model."""
args
=
get_args
()
biencoder_shared_query_context_model
=
args
.
biencoder_shared_query_context_model
only_context_model
=
args
.
only_context_model
only_query_model
=
args
.
only_query_model
assert
mpu
.
get_tensor_model_parallel_world_size
()
==
1
and
\
mpu
.
get_pipeline_model_parallel_world_size
()
==
1
,
\
"Model parallel size > 1 not supported for ICT"
...
...
@@ -266,11 +273,6 @@ class PretrainedBertModel(MegatronModule):
#extended_attention_mask = bert_extended_attention_mask(attention_mask)
position_ids
=
bert_position_ids
(
input_ids
)
print_rank_0
(
input_ids
.
device
)
print_rank_0
(
position_ids
.
device
)
print_rank_0
(
extended_attention_mask
.
device
)
print_rank_0
(
tokentype_ids
.
device
)
lm_output
=
self
.
language_model
(
input_ids
,
position_ids
,
extended_attention_mask
,
...
...
megatron/model/language_model.py
View file @
6d03d7af
...
...
@@ -338,11 +338,6 @@ class TransformerLanguageModel(MegatronModule):
get_key_value
=
False
,
pooling_sequence_index
=
0
,
enc_hidden_states
=
None
,
output_enc_hidden
=
False
):
print_rank_0
(
"before self.embedding"
)
print_rank_0
(
enc_input_ids
.
device
)
print_rank_0
(
enc_position_ids
.
device
)
print_rank_0
(
tokentype_ids
.
device
)
# Embeddings.
if
self
.
pre_process
:
embedding_output
=
self
.
embedding
(
enc_input_ids
,
enc_position_ids
,
...
...
pretrain_ict.py
View file @
6d03d7af
...
...
@@ -33,11 +33,15 @@ from megatron.utils import average_losses_across_data_parallel_group
def
pretrain_ict_model_provider
():
args
=
get_args
()
model
=
biencoder_model_provider
(
only_context_model
=
False
,
only_query_model
=
False
,
biencoder_shared_query_context_model
=
\
args
.
biencoder_shared_query_context_model
)
args
.
only_context_model
=
False
args
.
only_query_model
=
False
model
=
biencoder_model_provider
()
#model = biencoder_model_provider(
# only_context_model=False,
# only_query_model=False,
# biencoder_shared_query_context_model=\
# args.biencoder_shared_query_context_model)
return
model
def
get_group_world_size_rank
():
...
...
tasks/orqa/evaluate_orqa.py
View file @
6d03d7af
...
...
@@ -19,6 +19,7 @@ import os
import
sys
from
megatron
import
get_args
from
megatron.indexer
import
IndexBuilder
from
tasks.orqa.evaluate_utils
import
ORQAEvaluator
def
main
():
...
...
@@ -28,6 +29,23 @@ def main():
args
=
get_args
()
"""Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
- Include all args needed for initial model specification
Other key args:
--block-data-path: path to write to
--ict-load or --realm-load: path to checkpoint with which to embed
--data-path and --titles-data-path: paths for dataset
--indexer-log-interval: reporting interval
--indexer-batch-size: size specific for indexer jobs
Check README.md for example script
"""
index_builder
=
IndexBuilder
()
index_builder
.
build_and_save_index
()
print_rank_0
(
"Build and save indices: done!"
)
# Set up the model and evaluator
evaluator
=
ORQAEvaluator
()
...
...
@@ -37,4 +55,4 @@ def main():
if
args
.
qa_data_test
is
not
None
:
evaluator
.
evaluate
(
args
.
qa_data_test
,
"TEST"
)
tasks/orqa/evaluate_utils.py
View file @
6d03d7af
...
...
@@ -44,9 +44,14 @@ class ORQAEvaluator(object):
if
args
.
biencoder_shared_query_context_model
:
only_query_model
=
False
model
=
get_model
(
lambda
:
biencoder_model_provider
(
only_query_model
=
\
only_query_model
,
biencoder_shared_query_context_model
=
\
args
.
biencoder_shared_query_context_model
))
args
.
only_query_model
=
only_query_model
args
.
only_context_model
=
False
#model = get_model(lambda: biencoder_model_provider(only_query_model=\
# only_query_model, biencoder_shared_query_context_model=\
# args.biencoder_shared_query_context_model))
model
=
get_model
(
biencoder_model_provider
)
self
.
model
=
load_biencoder_checkpoint
(
model
,
only_query_model
=
only_query_model
)
...
...
tasks/orqa/supervised/finetune.py
View file @
6d03d7af
...
...
@@ -16,6 +16,7 @@
"""ORQA finetuning/evaluation."""
from
functools
import
partial
import
sys
import
math
import
torch
...
...
@@ -183,11 +184,15 @@ def orqa(Dataset): # , name_from_datapath_func):
"""Build the model."""
args
=
get_args
()
print_rank_0
(
'building retriever model for {} ...'
.
format
(
args
.
task
))
model
=
biencoder_model_provider
(
only_context_model
=
False
,
only_query_model
=
False
,
biencoder_shared_query_context_model
=
\
args
.
biencoder_shared_query_context_model
,
pre_process
=
pre_process
,
post_process
=
post_process
)
args
.
only_context_model
=
False
args
.
only_query_model
=
False
model
=
biencoder_model_provider
()
#model = biencoder_model_provider(only_context_model=False,
# only_query_model=False,
# biencoder_shared_query_context_model=\
# args.biencoder_shared_query_context_model,
# pre_process=pre_process, post_process=post_process)
return
model
def
single_dataset_provider
(
datapath
):
...
...
@@ -228,29 +233,15 @@ def orqa(Dataset): # , name_from_datapath_func):
def
main
():
args
=
get_args
()
#
if args.task == 'RET-FINETUNE-NQ':
#
from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
if
args
.
task
==
'RET-FINETUNE-NQ'
:
from
tasks.orqa.supervised.data
import
NQSupervisedDataset
as
Dataset
#def name_from_datapath(datapath):
# return datapath[0].split('/')[-1].split('.')[0]
#else:
# raise NotImplementedError('ORQA task {} is not implemented.'.format(
# args.task))
#orqa(Dataset) #, name_from_datapath)
index_builder
=
IndexBuilder
()
index_builder
.
build_and_save_index
()
print_rank_0
(
"Build and save indices: done!"
)
# Set up the model and evaluator
#evaluator = ORQAEvaluator()
# Run evaluation
#if args.qa_data_dev is not None:
# evaluator.evaluate(args.qa_data_dev, "DEV")
#if args.qa_data_test is not None:
# evaluator.evaluate(args.qa_data_test, "TEST")
else
:
raise
NotImplementedError
(
'ORQA task {} is not implemented.'
.
format
(
args
.
task
))
orqa
(
Dataset
)
#, name_from_datapath)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment