OpenDAS / Megatron-LM · Commit 2eaf6c79
"vscode:/vscode.git/clone" did not exist on "7dca70049566b5b1c55cbd67e1cb191729a98152"
Commit 2eaf6c79, authored May 18, 2021 by Mostofa Patwary

    cleaning the code

Parent: 7a0710ec
Showing 11 changed files with 56 additions and 124 deletions (+56 −124):
megatron/indexer.py                    +16 −16
megatron/model/biencoder_model.py       +8 −17
megatron/model/language_model.py        +1  −1
pretrain_ict.py                         +1  −1
tasks/finetune_utils.py                 +2  −6
tasks/main.py                           +2  −2
tasks/orqa/evaluate_orqa.py             +8 −27
tasks/orqa/evaluate_utils.py            +1  −2
tasks/orqa/supervised/data.py           +2  −4
tasks/orqa/supervised/eval_utils.py     +2  −7
tasks/orqa/supervised/finetune.py      +13 −41
megatron/indexer.py (view file @ 2eaf6c79)

@@ -26,8 +26,8 @@ class IndexBuilder(object):
         self.evidence_embedder_obj = None
         self.biencoder_shared_query_context_model = \
             args.biencoder_shared_query_context_model
-        self.pre_process = True
-        self.post_process = True
+        #self.pre_process = True
+        #self.post_process = True
 
         # need to know whether we're using a REALM checkpoint (args.load)
         # or ICT checkpoint
@@ -46,7 +46,7 @@ class IndexBuilder(object):
         """
         Load the necessary attributes: model, dataloader and empty BlockData
         """
-        args = get_args()
+        #args = get_args()
         only_context_model = True
         if self.biencoder_shared_query_context_model:
             only_context_model = False
@@ -56,7 +56,7 @@ class IndexBuilder(object):
         #model = get_model(biencoder_model_provider)
         model = get_model(get_model_provider(only_context_model=only_context_model,
             biencoder_shared_query_context_model=self.biencoder_shared_query_context_model))
         #model = get_model(lambda: biencoder_model_provider(only_context_model \
@@ -103,12 +103,12 @@ class IndexBuilder(object):
         while not hasattr(unwrapped_model, 'embed_text'):
             unwrapped_model = unwrapped_model.module
 
-        counter = 0
-        start_time = time.time()
-        cur_time = start_time
+        #counter = 0
+        #start_time = time.time()
+        #cur_time = start_time
         while True:
             #start_time = time.time()
-            t1 = time.time()
+            #t1 = time.time()
             try:
                 # batch also has query_tokens and query_pad_data
                 row_id, context_tokens, context_mask, context_types, \
@@ -118,7 +118,7 @@ class IndexBuilder(object):
                 break
             #print_rank_0("get batch time {}".format(cur_time - time.time()))
-            t2 = time.time()
+            #t2 = time.time()
             # TODO: can we add with torch.no_grad() to reduce memory usage
             # detach, separate fields and add to BlockData
             assert context_mask.dtype == torch.bool
@@ -129,17 +129,17 @@ class IndexBuilder(object):
             context_logits = detach(context_logits)
             row_id = detach(row_id)
             #print_rank_0("embed text {}".format(cur_time - time.time()))
-            t3 = time.time()
+            #t3 = time.time()
             self.evidence_embedder_obj.add_block_data(row_id, context_logits)
             self.track_and_report_progress(batch_size=len(row_id))
             #print_rank_0("add block time {}".format(cur_time - time.time()))
-            t4 = time.time()
-            counter += 1
-            if counter % 1000 == 0:
-                print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
-                print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 - t3))
-                cur_time = time.time()
+            #t4 = time.time()
+            #counter += 1
+            #if counter % 1000 == 0:
+            #    print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
+            #    print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 - t3))
+            #    cur_time = time.time()
 
         # This process signals to finalize its shard and then synchronize with
         # the other processes
         self.evidence_embedder_obj.save_shard()
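The indexer hunks above comment out the ad-hoc timing probes (counter, t1..t4) around the embedding loop, while the TODO kept in the diff asks whether the forward pass could run under torch.no_grad() to reduce memory. A minimal sketch of that idea, assuming a biencoder-style forward that takes query arguments first and context arguments second, as the call in this file does; embed_all and its parameters are illustrative stand-ins, not the IndexBuilder API:

    import torch

    def embed_all(model, batches):
        # Hypothetical sketch: embed context batches without recording autograd state.
        model.eval()
        results = []
        with torch.no_grad():  # no graph is kept, so activations are freed eagerly
            for row_id, tokens, mask, types in batches:
                # query args are None at index time, mirroring the call in the diff
                logits = model(None, None, None, tokens, mask, types)
                results.append((row_id.cpu(), logits.cpu()))
        return results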
megatron/model/biencoder_model.py (view file @ 2eaf6c79)

@@ -15,17 +15,17 @@ from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule
 
 def get_model_provider(only_query_model=False, only_context_model=False,
                        biencoder_shared_query_context_model=False):
 
     def model_provider(pre_process=True, post_process=True):
         """Build the model."""
 
         print_rank_0('building Bienoder model ...')
         model = biencoder_model_provider(only_query_model=only_query_model,
                 only_context_model=only_context_model,
                 biencoder_shared_query_context_model=\
                 biencoder_shared_query_context_model,
                 pre_process=True, post_process=True)
 
         return model
@@ -33,21 +33,12 @@ def get_model_provider(only_query_model=False, only_context_model=False,
     return model_provider
 
-#def biencoder_model_provider(pre_process=True,
-#                             post_process=True):
 def biencoder_model_provider(only_query_model=False,
                              only_context_model=False,
                              biencoder_shared_query_context_model=False,
                              pre_process=True,
                              post_process=True):
     """Build the model."""
-    #args = get_args()
-    #biencoder_shared_query_context_model = args.biencoder_shared_query_context_model
-    #only_context_model = args.only_context_model
-    #only_query_model = args.only_query_model
     assert mpu.get_tensor_model_parallel_world_size() == 1 and \
         mpu.get_pipeline_model_parallel_world_size() == 1, \
@@ -63,7 +54,7 @@ def biencoder_model_provider(only_query_model=False,
         only_query_model=only_query_model,
         only_context_model=only_context_model,
         biencoder_shared_query_context_model=\
             biencoder_shared_query_context_model,
         pre_process=pre_process, post_process=post_process)
@@ -114,9 +105,9 @@ class BiEncoderModel(MegatronModule):
     def set_input_tensor(self, input_tensor):
         """See megatron.model.transformer.set_input_tensor()"""
-        #this is just a placeholder and will be needed when model
-        #parallelism will be used
-        #self.language_model.set_input_tensor(input_tensor)
+        # this is just a placeholder and will be needed when model
+        # parallelism will be used
+        # self.language_model.set_input_tensor(input_tensor)
         return
 
     def forward(self, query_tokens, query_attention_mask, query_types,
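The new get_model_provider is a factory: it closes over the biencoder flags and hands back the model_provider(pre_process, post_process) callable that Megatron's get_model() expects, which is what lets indexer.py and evaluate_utils.py request context-only or query-only models. A runnable toy version of the closure pattern, where make_provider and build_biencoder are hypothetical stand-ins for get_model_provider and biencoder_model_provider:

    def build_biencoder(**kwargs):
        # stand-in for biencoder_model_provider; just echoes its configuration
        return dict(kwargs)

    def make_provider(only_query_model=False, only_context_model=False, shared=False):
        # the factory closes over the flags...
        def model_provider(pre_process=True, post_process=True):
            # ...and the returned callable matches the get_model() contract
            return build_biencoder(only_query_model=only_query_model,
                                   only_context_model=only_context_model,
                                   biencoder_shared_query_context_model=shared,
                                   pre_process=pre_process,
                                   post_process=post_process)
        return model_provider

    provider = make_provider(only_context_model=True)  # indexer-style usage
    print(provider())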
megatron/model/language_model.py (view file @ 2eaf6c79)

@@ -18,7 +18,7 @@
 import torch
 import torch.nn.functional as F
 
-from megatron import get_args, print_rank_0
+from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
 from megatron.model.enums import LayerType, AttnMaskType
pretrain_ict.py (view file @ 2eaf6c79)

@@ -36,7 +36,7 @@ def pretrain_ict_model_provider():
     #args.only_context_model = False
     #args.only_query_model = False
     #model = biencoder_model_provider()
     model = biencoder_model_provider(
         only_context_model=False,
         only_query_model=False,
tasks/finetune_utils.py (view file @ 2eaf6c79)

@@ -16,7 +16,6 @@
 """Finetune utilities."""
 
 from functools import partial
-import sys
 
 import torch
@@ -81,7 +80,7 @@ def _cross_entropy_forward_step(batch, model):
     return output_tensor, partial(cross_entropy_loss_func, labels)
 
 def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
                       task_collate_fn=None):
     """Data loader. Note that batch-size is the local (per GPU) batch-size."""
@@ -190,7 +189,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
             continue
         # Set to zero so the next epoch does not skip any batches.
         start_iteration = 0
 
         # Train for one step.
         out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
@@ -226,9 +225,6 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                              valid_dataloader, model,
                              iteration, False)
-        #if iteration == 600:
-        #    sys.exit()
-
         # Checkpointing at the end of each epoch.
         if args.save:
             save_checkpoint(iteration, model, optimizer, lr_scheduler)
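The forward step in this file returns the model output together with a loss callback whose extra inputs are pre-bound via functools.partial, so the trainer can later compute the loss from the output tensor alone. A self-contained sketch of that contract; the linear model and names are illustrative, not Megatron's API:

    from functools import partial

    import torch
    import torch.nn.functional as F

    def cross_entropy_loss_func(labels, output_tensor):
        # called later as loss_cb(output_tensor); labels were bound up front
        return F.cross_entropy(output_tensor, labels)

    def forward_step(batch, model):
        inputs, labels = batch
        output_tensor = model(inputs)
        return output_tensor, partial(cross_entropy_loss_func, labels)

    model = torch.nn.Linear(4, 3)
    batch = (torch.randn(2, 4), torch.tensor([0, 2]))
    output, loss_cb = forward_step(batch, model)
    print(loss_cb(output))  # trainer side: loss from the bound callback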
tasks/main.py (view file @ 2eaf6c79)

@@ -89,8 +89,8 @@ def get_tasks_args(parser):
     #                   help="Av.rank validation: batch size to process passages")
     #group.add_argument("--val-av-rank-max-qs", type=int, default=10000,
     #                   help="Av.rank validation: max num of questions")
 
     return parser
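get_tasks_args follows the usual Megatron pattern of extending the shared argument parser with a task-specific group and returning it; the hunk above only touches commented-out flags. A minimal sketch of the pattern, re-enabling one of those commented flags purely as an illustration:

    import argparse

    def get_tasks_args(parser):
        # add task-specific arguments to the shared parser and hand it back
        group = parser.add_argument_group(title='tasks')
        group.add_argument("--val-av-rank-max-qs", type=int, default=10000,
                           help="Av.rank validation: max num of questions")
        return parser

    parser = get_tasks_args(argparse.ArgumentParser())
    args = parser.parse_args(['--val-av-rank-max-qs', '500'])
    print(args.val_av_rank_max_qs)  # 500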
tasks/orqa/evaluate_orqa.py (view file @ 2eaf6c79)

@@ -15,18 +15,6 @@
 """Main tasks functionality."""
 
-import os
-import sys
-#sys.path.append(
-#    os.path.abspath(
-#        os.path.join(
-#            os.path.join(os.path.dirname(__file__), os.path.pardir),
-#            os.path.pardir,
-#        )
-#    )
-#)
 from megatron import get_args, print_rank_0
 from megatron.indexer import IndexBuilder
 from tasks.orqa.evaluate_utils import ORQAEvaluator
@@ -35,30 +23,23 @@ def main():
     """
     Main program
     """
-    #initialize_megatron(extra_args_provider=None,
-    #    args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()
 
-    """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
-    - Include all args needed for initial model specification
-
-    Other key args:
-        --block-data-path: path to write to
-        --ict-load or --realm-load: path to checkpoint with which to embed
-        --data-path and --titles-data-path: paths for dataset
-        --indexer-log-interval: reporting interval
-        --indexer-batch-size: size specific for indexer jobs
-
-    Check README.md for example script
-    """
+    """
+    Create a BlockData data structure by running an IndexBuilder over an
+    ICT Dataset and then evaluate on NQ task
+    """
 
-    #print_rank_0("Starting index builder!")
+    print_rank_0("Starting index builder!")
     index_builder = IndexBuilder()
     index_builder.build_and_save_index()
     print_rank_0("Build and save indices: done!")
 
+    print_rank_0("Starting evaluations!")
+
     # Set up the model and evaluator
     evaluator = ORQAEvaluator()
@@ -68,4 +49,4 @@ def main():
     if args.qa_data_test is not None:
         evaluator.evaluate(args.qa_data_test, "TEST")
tasks/orqa/evaluate_utils.py (view file @ 2eaf6c79)

@@ -47,10 +47,9 @@ class ORQAEvaluator(object):
         #args.only_query_model = only_query_model
         #args.only_context_model = False
         model = get_model(get_model_provider(only_query_model=only_query_model,
             biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))
-        #model = get_model(lambda: biencoder_model_provider(only_query_model=\
         #model = get_model(lambda: biencoder_model_provider(only_query_model=\
         #    only_query_model, biencoder_shared_query_context_model=\
tasks/orqa/supervised/data.py (view file @ 2eaf6c79)

@@ -104,9 +104,9 @@ def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
     return enc_ids, tokentypes_enc, pad_mask
 
 def build_sample(query_ids, query_types, query_pad_mask,
                  ctx_ids, ctx_types, ctx_pad_mask, answers,
                  neg_ctx_id_list=None, neg_ctx_types_list=None,
                  include_neg=False):
     """Convert to numpy and return a sample consumed by the batch producer."""
@@ -295,5 +295,3 @@ class NQSupervisedDataset(OpenRetrievalAbstractDataset):
         print_rank_0(' >> processed {} samples.'.format(len(samples)))
         return samples
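build_tokens_types_paddings_from_ids returns the (enc_ids, tokentypes_enc, pad_mask) triple that build_sample converts to numpy. Its body is not visible in this diff, so the following is only an assumption-laden sketch of what such a helper typically does for a single segment: truncate, right-pad to a fixed length, and emit matching type ids and a 1/0 pad mask:

    def pad_ids(text_ids, max_seq_length, pad_id=0):
        # hypothetical reconstruction, not the helper from this file
        enc_ids = list(text_ids)[:max_seq_length]
        pad_mask = [1] * len(enc_ids)          # 1 marks real tokens
        tokentypes_enc = [0] * len(enc_ids)    # single segment: all type 0
        while len(enc_ids) < max_seq_length:   # right-pad to the fixed length
            enc_ids.append(pad_id)
            pad_mask.append(0)
            tokentypes_enc.append(0)
        return enc_ids, tokentypes_enc, pad_mask

    print(pad_ids([101, 7592, 102], 5))
    # ([101, 7592, 102, 0, 0], [0, 0, 0, 0, 0], [1, 1, 1, 0, 0])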
tasks/orqa/supervised/eval_utils.py (view file @ 2eaf6c79)

@@ -34,7 +34,6 @@ def task_collate_fn(batch_data):
     for d in batch_data:
         for k, v in d.items():
             tensorized.setdefault(k, []).append(v)
 
-    # assert len(tensorized) == 12
     tensorized['query'] = torch.LongTensor(tensorized['query'])
     tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask'])
@@ -90,8 +89,6 @@ def process_batch(batch):
         neg_context_tokens, neg_context_mask, neg_context_types, reference
 
 def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
-    #, datapath,
-    #                       rank0sampler=False):
     """Provide function that calculates accuracies."""
     args = get_args()
@@ -112,9 +109,7 @@ def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
                                        args.eval_micro_batch_size,
                                        num_workers=args.num_workers,
                                        drop_last=drop_last,
                                        task_collate_fn=task_collate_fn)
-                                       #shuffle=False,
-                                       #rank0sampler=rank0sampler)
     dataloaders = (dataset.dataset_name, dataloader)
 
     def metrics_func(model, epoch, output_predictions=False):
@@ -197,7 +192,7 @@ def retrieval_loss(model, dataloader):
     losses = average_losses_across_data_parallel_group([rank, \
         *topk_accs])
 
     # create stats_dict with retrieval loss and all specified
     # top-k accuracies
     topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
         zip(args.retriever_report_topk_accuracies, losses[1:])}
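task_collate_fn above batches a list of per-sample dicts by accumulating values per key with setdefault, then tensorizing the keys the pipeline consumes. A self-contained sketch of just that pattern; the two field names follow the diff, while real samples carry more keys:

    import torch

    def collate(batch_data):
        tensorized = {}
        for d in batch_data:
            for k, v in d.items():
                tensorized.setdefault(k, []).append(v)  # dict of lists, keyed by field
        tensorized['query'] = torch.LongTensor(tensorized['query'])
        tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask'])
        return tensorized

    batch = [{'query': [1, 2], 'query_mask': [1, 1]},
             {'query': [3, 4], 'query_mask': [1, 0]}]
    print(collate(batch)['query'].shape)  # torch.Size([2, 2])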
tasks/orqa/supervised/finetune.py (view file @ 2eaf6c79)

@@ -22,27 +22,21 @@ import math
 import torch
 import torch.nn.functional as F
 
-from megatron import get_args
-from megatron import get_timers
-from megatron import get_tokenizer
-from megatron import mpu
-from megatron import print_rank_0
-from megatron.utils import average_losses_across_data_parallel_group
+from megatron import get_args, get_timers, get_tokenizer
+from megatron import mpu, print_rank_0
+from megatron.indexer import IndexBuilder
 from megatron.model.biencoder_model import biencoder_model_provider
-#from tasks.t5_model_utils.finetune_utils_open_retrieval import accuracy_func_provider
-#from tasks.t5_model_utils.finetune_utils_open_retrieval import finetune
+from megatron.utils import average_losses_across_data_parallel_group
 from pretrain_ict import get_group_world_size_rank
 from tasks.finetune_utils import finetune
 from tasks.orqa.supervised.eval_utils import accuracy_func_provider
 from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
 from tasks.orqa.evaluate_utils import ORQAEvaluator
-from megatron.indexer import IndexBuilder
 
-def orqa(Dataset): # , name_from_datapath_func):
+def orqa(Dataset):
 
     def cross_entropy_forward_step(batch, model):
         """Simple forward step with cross-entropy loss."""
-        args = get_args()
         timers = get_timers()
         tokenizer = get_tokenizer()
@@ -73,17 +67,15 @@ def orqa(Dataset): # , name_from_datapath_func):
         context_types = torch.cat([context_types, neg_context_types])
 
         # Forward model.
-        #query_logits, context_logits = model(query_tokens, query_mask,
         output_tensor = model(query_tokens, query_mask,
                               query_types, context_tokens,
                               context_mask, context_types)
-        return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens)
+        return output_tensor, partial(cross_entropy_loss_func_, query_tokens, context_tokens)
 
-    #def cross_entropy_loss_func(labels, output_tensor):
-    def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
+    def cross_entropy_loss_func_(query_tokens, context_tokens, output_tensor):
         args = get_args()
         local_batch_size = query_tokens.shape[0]
         group, rank, world_size = get_group_world_size_rank()
@@ -184,12 +176,9 @@ def orqa(Dataset): # , name_from_datapath_func):
         """Build the model."""
         args = get_args()
         print_rank_0('building retriever model for {} ...'.format(args.task))
-        #args.only_context_model=False
-        #args.only_query_model=False
-        #model = biencoder_model_provider()
         model = biencoder_model_provider(only_context_model=False,
                                          only_query_model=False,
                                          biencoder_shared_query_context_model=\
                                          args.biencoder_shared_query_context_model,
                                          pre_process=pre_process, post_process=post_process)
@@ -200,7 +189,6 @@ def orqa(Dataset): # , name_from_datapath_func):
         args = get_args()
         tokenizer = get_tokenizer()
 
-        #name = name_from_datapath_func(datapath)
         name = datapath[0].split('/')[-1].split('.')[0]
 
         return Dataset(name,
                        datapath,
@@ -208,41 +196,25 @@ def orqa(Dataset): # , name_from_datapath_func):
                        args.retriever_seq_length,
                        evaluate=True)
 
-    #def distributed_metrics_func_provider():
     def metrics_func_provider():
         """Provide metrics callback function."""
-        #def name_from_datapath(datapath):
-        #    return datapath[0].split('/')[-1].split('.')[0]
         return accuracy_func_provider(single_dataset_provider)
 
-    #def rank0_metrics_func_provider(datapath):
-    #    """Provide metrics callback function."""
-    #    return accuracy_func_provider(single_dataset_provider, datapath,
-    #                                  rank0sampler=True)
-
     """Finetune/evaluate."""
     finetune(train_valid_datasets_provider,
             model_provider,
             forward_step=cross_entropy_forward_step,
             end_of_epoch_callback_provider=metrics_func_provider,
             task_collate_fn=task_collate_fn)
-            #,end_of_training_callback_provider=rank0_metrics_func_provider)
 
 def main():
     args = get_args()
 
     if args.task == 'RET-FINETUNE-NQ':
         from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
-        #def name_from_datapath(datapath):
-        #    return datapath[0].split('/')[-1].split('.')[0]
     else:
         raise NotImplementedError('ORQA task {} is not implemented.'.format(
             args.task))
 
-    orqa(Dataset) #, name_from_datapath)
+    orqa(Dataset)