Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
10ff0607
Commit
10ff0607
authored
Apr 09, 2021
by
Mostofa Patwary
Browse files
implementing DPR
parent
a5acbf53
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
40 additions
and
0 deletions
+40
-0
megatron/data/biencoder_dataset_utils.py
megatron/data/biencoder_dataset_utils.py
+7
-0
tasks/finetune_utils.py
tasks/finetune_utils.py
+2
-0
tasks/main.py
tasks/main.py
+31
-0
No files found.
megatron/data/biencoder_dataset_utils.py
View file @
10ff0607
...
@@ -20,6 +20,13 @@ def make_attention_mask(source_block, target_block):
...
@@ -20,6 +20,13 @@ def make_attention_mask(source_block, target_block):
# (source_length, target_length)
# (source_length, target_length)
return
mask
return
mask
def make_history_mask(block):
    """Build a lower-triangular (causal) history mask for *block*.

    Entry (i, j) is 1 when position j is at or before position i, so each
    position may attend only to itself and earlier positions.

    Args:
        block: array-like sequence; only ``block.shape[0]`` (its length on
            axis 0) is read.

    Returns:
        ``np.ndarray`` of shape ``(length, length)`` with dtype ``int64``,
        ones on and below the diagonal, zeros above it.
    """
    positions = np.arange(block.shape[0])
    # Broadcast a row vector against a column vector: row i keeps
    # columns 0..i inclusive.
    causal = positions[None, :] <= positions[:, None]
    return causal.astype(np.int64)
def
get_one_epoch_dataloader
(
dataset
,
micro_batch_size
=
None
):
def
get_one_epoch_dataloader
(
dataset
,
micro_batch_size
=
None
):
"""Specifically one epoch to be used in an indexing job."""
"""Specifically one epoch to be used in an indexing job."""
args
=
get_args
()
args
=
get_args
()
...
...
tasks/finetune_utils.py
View file @
10ff0607
...
@@ -248,6 +248,8 @@ def finetune(train_valid_datasets_provider, model_provider,
...
@@ -248,6 +248,8 @@ def finetune(train_valid_datasets_provider, model_provider,
end_of_epoch_callback
=
end_of_epoch_callback_provider
()
end_of_epoch_callback
=
end_of_epoch_callback_provider
()
timers
(
'callback function'
).
stop
()
timers
(
'callback function'
).
stop
()
exit
()
# Build model, optimizer and learning rate scheduler.
# Build model, optimizer and learning rate scheduler.
timers
(
'model and optimizer'
).
start
()
timers
(
'model and optimizer'
).
start
()
model
,
optimizer
,
lr_scheduler
=
setup_model_and_optimizer
(
model_provider
)
model
,
optimizer
,
lr_scheduler
=
setup_model_and_optimizer
(
model_provider
)
...
...
tasks/main.py
View file @
10ff0607
...
@@ -62,6 +62,35 @@ def get_tasks_args(parser):
...
@@ -62,6 +62,35 @@ def get_tasks_args(parser):
group
.
add_argument
(
'--faiss-topk-retrievals'
,
type
=
int
,
default
=
100
,
group
.
add_argument
(
'--faiss-topk-retrievals'
,
type
=
int
,
default
=
100
,
help
=
'Number of blocks to use as top-k during retrieval'
)
help
=
'Number of blocks to use as top-k during retrieval'
)
# finetune for retriever
group
.
add_argument
(
'--eval-micro-batch-size'
,
type
=
int
,
default
=
None
,
help
=
'Eval Batch size per model instance (local batch '
'size). Global batch size is local batch size '
'times data parallel size.'
)
group
.
add_argument
(
'--train-with-neg'
,
action
=
'store_true'
,
help
=
'Whether to use negative examples during model '
'training'
)
group
.
add_argument
(
'--train-hard-neg'
,
type
=
int
,
default
=
0
,
help
=
'Number of hard negative examples to use during '
'training'
)
# parameters for Av.rank validation method
# Following options/arguments have been taken directly from DPR codebase
#group.add_argument("--val-av-rank-start-epoch", type=int, default=10000,
# help="Av.rank validation: the epoch from which to enable this validation")
group
.
add_argument
(
'--val-av-rank-hard-neg'
,
type
=
int
,
default
=
30
,
help
=
'Av.rank validation: how many hard negatives to'
' take from each question pool'
)
group
.
add_argument
(
'--val-av-rank-other-neg'
,
type
=
int
,
default
=
30
,
help
=
'Av.rank validation: how many other negatives to'
' take from each question pool'
)
#group.add_argument("--val-av-rank-bsz", type=int, default=128,
# help="Av.rank validation: batch size to process passages")
#group.add_argument("--val-av-rank-max-qs", type=int, default=10000,
# help="Av.rank validation: max num of questions")
return
parser
return
parser
...
@@ -78,6 +107,8 @@ if __name__ == '__main__':
...
@@ -78,6 +107,8 @@ if __name__ == '__main__':
from
zeroshot_gpt.evaluate
import
main
from
zeroshot_gpt.evaluate
import
main
elif
args
.
task
in
[
'ICT-ZEROSHOT-NQ'
]:
elif
args
.
task
in
[
'ICT-ZEROSHOT-NQ'
]:
from
orqa.evaluate_orqa
import
main
from
orqa.evaluate_orqa
import
main
elif
args
.
task
in
[
'RET-FINETUNE-NQ'
]:
from
orqa.supervised.finetune
import
main
else
:
else
:
raise
NotImplementedError
(
'Task {} is not implemented.'
.
format
(
raise
NotImplementedError
(
'Task {} is not implemented.'
.
format
(
args
.
task
))
args
.
task
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment