OpenDAS / Megatron-LM / Commits / 8b04e0e4

Commit 8b04e0e4 authored Mar 10, 2021 by Mostofa Patwary

ICT zeroshot evaluation code

Parent: 1a2cb60c

Showing 9 changed files with 980 additions and 55 deletions (+980 / -55)
examples/evaluate_ict_zeroshot_nq.sh        +34   -0
megatron/arguments.py                       +9    -0
megatron/data/biencoder_dataset_utils.py    +0    -3
megatron/data/realm_index.py                +57   -52
tasks/orqa/evaluate_orqa.py                 +49   -0
tasks/orqa/evaluate_utils.py                +188  -0
tasks/orqa/natural_questions/nq.py          +228  -0
tasks/orqa/natural_questions/qa_utils.py    +174  -0
tasks/orqa/natural_questions/tokenizers.py  +241  -0
examples/evaluate_ict_zeroshot_nq.sh (new file, mode 100644)
#!/bin/bash
# Evaluate natural question test data given Wikipedia embeddings and pretrained
# ICT model
# Datasets can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path of the embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>
QA_FILE=<Path of the natural question test dataset>

python tasks/orqa/evaluate_orqa.py \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 12 \
    --tensor-model-parallel-size 1 \
    --micro-batch-size 128 \
    --checkpoint-activations \
    --seq-length 512 \
    --max-position-embeddings 512 \
    --load ${CHECKPOINT_PATH} \
    --evidence-data-path ${EVIDENCE_DATA_DIR} \
    --embedding-path ${EMBEDDING_PATH} \
    --retriever-seq-length 256 \
    --vocab-file bert-vocab.txt \
    --qa-data-test ${QA_FILE} \
    --num-workers 2 \
    --faiss-use-gpu \
    --retriever-report-topk-accuracies 1 5 20 100 \
    --fp16
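Note on the QA_FILE input: the natural-questions files read by this commit are tab-separated, with the question in the first column and a Python-list literal of acceptable answers in the second (this is what process_samples_from_single_path in tasks/orqa/natural_questions/nq.py below expects). A minimal sketch of reading one such row; the file name and the example line are made up for illustration:

import csv

# Hypothetical TSV row in the style NQDataset expects:
#   who wrote the declaration of independence\t['Thomas Jefferson']
with open('nq-test.csv') as ifile:           # path is an assumption
    reader = csv.reader(ifile, delimiter='\t')
    for row in reader:
        question = row[0]
        answers = eval(row[1])               # the answer column is a list literal
        print(question, answers)
        break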
megatron/arguments.py

@@ -636,6 +636,10 @@ def _add_data_args(parser):
                        '1) a single data path, 2) multiple datasets in the'
                        'form: dataset1-weight dataset1-path dataset2-weight '
                        'dataset2-path ...')
+    group.add_argument('--qa-data-dev', type=str, default=None,
+                       help='Path to the QA dataset dev file.')
+    group.add_argument('--qa-data-test', type=str, default=None,
+                       help='Path to the QA dataset test file.')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
@@ -746,6 +750,11 @@ def _add_biencoder_args(parser):
     group.add_argument('--embedding-path', type=str, default=None,
                        help='Where to save/load Open-Retrieval Embedding'
                        ' data to/from')
+    group.add_argument('--faiss-match', type=str, default='string',
+                       choices=['regex', 'string'],
+                       help='Answer matching logic type')
+    group.add_argument('--faiss-topk-retrievals', type=int, default=100,
+                       help='Number of blocks to use as top-k during retrieval')
 
     # indexer
     group.add_argument('--indexer-batch-size', type=int, default=128,
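One interaction between the new flags is worth noting: the values passed to --retriever-report-topk-accuracies are used to index the hit counts produced from the --faiss-topk-retrievals retrievals (see top_k_hits[i - 1] in tasks/orqa/evaluate_utils.py below), so each reported k must not exceed the number of retrievals. A minimal sanity-check sketch, not part of the commit:

# Hypothetical guard illustrating the implicit constraint between the two flags.
def check_topk_args(faiss_topk_retrievals, retriever_report_topk_accuracies):
    for k in retriever_report_topk_accuracies:
        # top_k_hits has faiss_topk_retrievals entries, indexed as top_k_hits[k - 1]
        assert 1 <= k <= faiss_topk_retrievals, \
            'cannot report top-{} accuracy with only {} retrievals'.format(
                k, faiss_topk_retrievals)

check_topk_args(100, [1, 5, 20, 100])    # values from the example script above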
megatron/data/biencoder_dataset_utils.py

@@ -24,11 +24,8 @@ def get_one_epoch_dataloader(dataset, micro_batch_size=None):
     """Specifically one epoch to be used in an indexing job."""
     args = get_args()
 
-    world_size = mpu.get_data_parallel_world_size()
-    rank = mpu.get_data_parallel_rank()
     if micro_batch_size is None:
         micro_batch_size = args.micro_batch_size
-    global_batch_size = micro_batch_size * world_size
     num_workers = args.num_workers
 
     # Use megatron's sampler with consumed samples set to 0 as
megatron/data/realm_index.py

@@ -116,18 +116,22 @@ class OpenRetreivalDataStore(object):
 
 class FaissMIPSIndex(object):
-    """Wrapper object for a BlockData which similarity search via FAISS under the hood"""
-    def __init__(self, embed_size, block_data=None, use_gpu=False):
+    """
+    Wrapper object for a BlockData which similarity search via FAISS under the hood
+    """
+    def __init__(self, embed_size, embed_data=None, use_gpu=False):
         self.embed_size = embed_size
-        self.block_data = block_data
+        self.embed_data = embed_data
         self.use_gpu = use_gpu
-        self.id_map = dict()
 
-        self.block_mips_index = None
-        self._set_block_index()
+        self.mips_index = None
+        self._set_mips_index()
 
-    def _set_block_index(self):
-        """Create a Faiss Flat index with inner product as the metric to search against"""
+    def _set_mips_index(self):
+        """
+        Create a Faiss Flat index with inner product as the metric
+        to search against
+        """
         try:
             import faiss
         except ImportError:
@@ -135,85 +139,86 @@ class FaissMIPSIndex(object):
         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print("\n> Building index", flush=True)
-        self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
+
+        cpu_index = faiss.IndexFlatIP(self.embed_size)
 
         if self.use_gpu:
             # create resources and config for GpuIndex
-            res = faiss.StandardGpuResources()
-            config = faiss.GpuIndexFlatConfig()
-            config.device = torch.cuda.current_device()
+            config = faiss.GpuMultipleClonerOptions()
+            config.shard = True
             config.useFloat16 = True
 
-            self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
+            gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
+            self.mips_index = faiss.IndexIDMap(gpu_index)
+
             if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
-                print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
+                print(">> Initialized index on GPU", flush=True)
         else:
             # CPU index supports IDs so wrap with IDMap
-            self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
+            self.mips_index = faiss.IndexIDMap(cpu_index)
+
             if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
                 print(">> Initialized index on CPU", flush=True)
 
-        # if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
-        if self.block_data is not None:
-            self.add_block_embed_data(self.block_data)
+        # if we were constructed with a BlockData, then automatically load it
+        # when the FAISS structure is built
+        if self.embed_data is not None:
+            self.add_embed_data(self.embed_data)
 
     def reset_index(self):
-        """Delete existing index and create anew"""
-        del self.block_mips_index
+        """Delete existing index and create a new"""
+        del self.mips_index
 
         # reset the block data so that _set_block_index will reload it as well
-        if self.block_data is not None:
-            block_data_path = self.block_data.block_data_path
-            del self.block_data
-            self.block_data = BlockData(block_data_path)
+        if self.embed_data is not None:
+            embed_data_path = self.embed_data.embedding_path
+            del self.embed_data
+            self.embed_data = OpenRetreivalDataStore(embed_data_path)
 
-        self._set_block_index()
+        self._set_mips_index()
 
-    def add_block_embed_data(self, all_block_data):
+    def update_index(self):
+        """Delete existing index and create a new"""
+        del self.mips_index
+
+        # reset the block data so that _set_mips_index will reload it as well
+        if self.embed_data is not None:
+            self.embed_data.load_from_file()
+        self._set_mips_index()
+
+    def add_embed_data(self, all_embed_data):
         """Add the embedding of each block to the underlying FAISS index"""
         # this assumes the embed_data is a dict : {int: np.array<float>}
-        block_indices, block_embeds = zip(*all_block_data.embed_data.items())
+        block_indices, block_embeds = zip(*all_embed_data.embed_data.items())
 
-        # the embeddings have to be entered in as float32 even though the math internally is done with float16.
-        block_embeds_arr = np.float32(np.array(block_embeds))
-        block_indices_arr = np.array(block_indices)
+        # the embeddings have to be entered in as float32 even though the math
+        # internally is done with float16.
+        embeds_arr = np.float32(np.array(block_embeds))
+        indices_arr = np.array(block_indices)
 
-        # faiss GpuIndex doesn't work with IDMap wrapper so store ids to map back with
-        if self.use_gpu:
-            for i, idx in enumerate(block_indices):
-                self.id_map[i] = idx
-
         # we no longer need the embedding data since it's in the index now
-        all_block_data.clear()
+        all_embed_data.clear()
 
-        if self.use_gpu:
-            self.block_mips_index.add(block_embeds_arr)
-        else:
-            self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)
+        self.mips_index.add_with_ids(embeds_arr, indices_arr)
 
         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print(">>> Finished adding block data to index", flush=True)
 
     def search_mips_index(self, query_embeds, top_k, reconstruct=True):
-        """Get the top-k blocks by the index distance metric.
-
-        :param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
-                            if False: return [num_queries x k] array of distances, and another for indices
+        """
+        Get the top-k blocks by the index distance metric.
+
+        :param reconstruct: if True: return a [num_queries x k x embed_dim]
+                                array of blocks
+                            if False: return [num_queries x k] array of
+                                distances, and another for indices
         """
         query_embeds = np.float32(detach(query_embeds))
 
         if reconstruct:
             # get the vectors themselves
-            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
+            top_k_block_embeds = self.mips_index.search_and_reconstruct(\
+                query_embeds, top_k)
             return top_k_block_embeds
         else:
             # get distances and indices of closest vectors
-            distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
-            if self.use_gpu:
-                fresh_indices = np.zeros(block_indices.shape)
-                for i, j in itertools.product(block_indices.shape):
-                    fresh_indices[i, j] = self.id_map[block_indices[i, j]]
-                block_indices = fresh_indices
+            distances, block_indices = self.mips_index.search(query_embeds, top_k)
             return distances, block_indices
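The change above drops the GPU-specific id_map bookkeeping: the new code builds a CPU IndexFlatIP, wraps it in IndexIDMap so embeddings can be added under their own block ids, and optionally clones it across GPUs with index_cpu_to_all_gpus. A minimal standalone sketch of that pattern with random data (not part of the commit; assumes faiss and numpy are installed):

import numpy as np
import faiss

embed_size = 128
cpu_index = faiss.IndexFlatIP(embed_size)          # inner-product (MIPS) index
index = faiss.IndexIDMap(cpu_index)                # lets us add vectors under our own ids

ids = np.arange(1000, 1010).astype('int64')        # e.g. block/document ids
embeds = np.float32(np.random.randn(10, embed_size))
index.add_with_ids(embeds, ids)

queries = np.float32(np.random.randn(2, embed_size))
distances, block_ids = index.search(queries, 5)    # block_ids are the original ids
print(distances.shape, block_ids)                  # (2, 5), ids drawn from the range above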
tasks/orqa/evaluate_orqa.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality."""

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
    os.path.join(os.path.pardir, os.path.pardir))))

from megatron import get_args
from megatron.initialize import initialize_megatron
from tasks.orqa.evaluate_utils import ORQAEvaluator

def main():
    """
    Main program
    """

    initialize_megatron(extra_args_provider=None,
        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})

    args = get_args()

    # Set up the model and evaluator
    evaluator = ORQAEvaluator()

    # Run evaluation
    if args.qa_data_dev is not None:
        evaluator.evaluate(args.qa_data_dev, "DEV")

    if args.qa_data_test is not None:
        evaluator.evaluate(args.qa_data_test, "TEST")

if __name__ == "__main__":
    main()
tasks/orqa/evaluate_utils.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from megatron import get_args, print_rank_0
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from tasks.orqa.natural_questions.nq import get_nq_dataset
from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader
from tasks.orqa.natural_questions.nq import process_nq_batch
from tasks.orqa.natural_questions.qa_utils import calculate_matches
from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import get_model

class ORQAEvaluator(object):
    def __init__(self):
        args = get_args()
        self.embedding_size = args.hidden_size
        self.faiss_use_gpu = args.faiss_use_gpu
        self.evidence_embedder_obj = None
        self.evidence_dataset = None
        self.mips_index = None
        self.eval_dataset = None

        # Get Evidence (Wikipedia) dataset
        self.get_evidence_dataset()

        # Load query encoder checkpoint
        only_query_model = True
        if args.biencoder_shared_query_context_model:
            only_query_model = False

        model = get_model(lambda: biencoder_model_provider(only_query_model=\
            only_query_model, biencoder_shared_query_context_model=\
            args.biencoder_shared_query_context_model))

        self.model = load_biencoder_checkpoint(model,
            only_query_model=only_query_model)

        assert len(self.model) == 1
        self.model[0].eval()

        # Load faiss indexer
        self.faiss_wrapper()

    def get_evidence_embedding(self):
        # This will load the embedding from the embedding path
        self.evidence_embedder_obj = OpenRetreivalDataStore(load_from_path=True)

    def get_evidence_dataset(self):
        self.evidence_dataset = get_open_retrieval_wiki_dataset()

    def faiss_wrapper(self):
        # Initialize FAISS wrapper on local rank = 0 as the evidence embeddings
        # is distributed over all the GPUs in a node and FAISS is not
        # thread-safe
        args = get_args()
        if args.local_rank == 0:
            # Get evidence embeddings computed using context encoder
            self.get_evidence_embedding()

            assert self.evidence_embedder_obj is not None
            self.mips_index = FaissMIPSIndex(embed_size=self.embedding_size,
                embed_data=self.evidence_embedder_obj,
                use_gpu=self.faiss_use_gpu)

        # Wait for the FAISS index to be initialized in all the nodes
        torch.distributed.barrier()

    def generate_query_vectors(self, qa_data, split):
        self.eval_dataset = get_nq_dataset(qa_data, split)
        dataloader = get_one_epoch_nq_dataloader(self.eval_dataset)

        query_vectors = []
        reference_list = []

        for batch in dataloader:
            # batch also has query_tokens and query_pad_data
            query_tokens, query_mask, query_types, \
                query_len, reference = process_nq_batch(batch)

            assert len(self.model) == 1
            unwrapped_model = self.model[0]
            while not hasattr(unwrapped_model, 'embed_text'):
                unwrapped_model = unwrapped_model.module

            with torch.no_grad():
                query_logits = unwrapped_model.embed_text(
                    unwrapped_model.query_model, query_tokens,
                    query_mask, query_types)

            reference_list.extend(reference)
            query_vectors.extend(query_logits.split(1, dim=0))
            if len(query_vectors) % 100 == 0:
                print_rank_0('Encoded queries {}'.format(len(query_vectors)))

        query_tensor = torch.cat(query_vectors, dim=0)
        print_rank_0('Total encoded queries tensor {}'.format(query_tensor.size()))

        assert query_tensor.size(0) == len(self.eval_dataset)
        return query_tensor, reference_list

    def evaluate(self, qa_data, split):
        args = get_args()
        query_tensor, reference_list = self.generate_query_vectors(qa_data, \
                                            split)

        local_rank = args.local_rank
        rank = torch.distributed.get_rank()
        device_count = torch.cuda.device_count()
        num_nodes = torch.distributed.get_world_size() // device_count
        node_id = rank // device_count

        for node in range(num_nodes):
            start_rank = node * device_count
            end_rank = (node + 1) * device_count
            ranks_list = list(range(start_rank, end_rank))
            node_group = torch.distributed.new_group(ranks=ranks_list)

            if node_id == node:
                device_start_rank = start_rank
                group = node_group

        input_ = torch.empty_like(query_tensor).copy_(query_tensor).detach_()
        tensor_list = [torch.empty_like(input_) for _ in range(device_count)]
        torch.distributed.all_gather(tensor_list, query_tensor, group=group)

        if local_rank == 0 and self.mips_index is not None:
            all_query_tensor = torch.cat(tensor_list, dim=0).contiguous()

            distance, topkindex = self.mips_index.search_mips_index(
                all_query_tensor, top_k=args.faiss_topk_retrievals,
                reconstruct=False)
            distance = torch.from_numpy(distance).cuda()
            topkindex = torch.LongTensor(topkindex).cuda()

        if local_rank != 0:
            distance = torch.empty(device_count * len(query_tensor), \
                args.faiss_topk_retrievals, dtype=torch.float32).cuda()
            topkindex = torch.empty(device_count * len(query_tensor), \
                args.faiss_topk_retrievals, dtype=torch.int64).cuda()

        torch.distributed.broadcast(distance, src=device_start_rank, \
            group=group)
        torch.distributed.broadcast(topkindex, src=device_start_rank, \
            group=group)

        distance = torch.split(distance, len(query_tensor), dim=0)\
            [local_rank]
        topkindex = torch.split(topkindex, len(query_tensor), dim=0)\
            [local_rank]

        top_ids_and_scores = []
        for darray, topkarray in zip(distance, topkindex):
            top_ids_and_scores.append((topkarray.tolist(), darray.tolist()))

        passages = self.evidence_dataset.id2text
        match_stats = calculate_matches(passages, reference_list,
                        top_ids_and_scores, workers_num=args.num_workers,
                        match_type=args.faiss_match)
        top_k_hits = match_stats.top_k_hits

        print_rank_0("{} SET RESULTS".format(split))
        print_rank_0("topk-{} documents hits {}".format(
            args.faiss_topk_retrievals, top_k_hits))
        top_k_hits = [v / len(top_ids_and_scores) for v in top_k_hits]
        print_rank_0("top-k documents hits accuracy {}".format(top_k_hits))

        for i in args.retriever_report_topk_accuracies:
            print_rank_0("top-{}: {:.2f}".format(i, top_k_hits[i-1] * 100))

        return
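To make the reporting at the end of evaluate() concrete: calculate_matches returns top_k_hits, where entry i counts the questions whose answer appears somewhere in their top i + 1 retrieved passages, and the evaluator divides by the number of questions before printing the top-k accuracies. A small worked example with illustrative numbers only:

# Suppose 4 questions were evaluated with faiss_topk_retrievals = 5 and the
# first matching passage per question appeared at ranks 1, 3, None, 2.
top_k_hits = [1, 2, 3, 3, 3]           # cumulative hit counts, as built in calculate_matches
num_questions = 4

accuracies = [v / num_questions for v in top_k_hits]
for k in (1, 5):                       # e.g. --retriever-report-topk-accuracies 1 5
    print('top-{}: {:.2f}'.format(k, accuracies[k - 1] * 100))
# top-1: 25.00
# top-5: 75.00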
tasks/orqa/natural_questions/nq.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data Loader for Google NQ dataset
"""
from abc import ABC
import csv
from collections import OrderedDict
import numpy as np

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, BatchSampler

from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask

def get_nq_dataset(qa_data, split):
    args = get_args()
    tokenizer = get_tokenizer()

    dataset = NQDataset('Google NQ {} Split'.format(split),
        'Google Natural Questions',
        qa_data, tokenizer, args.retriever_seq_length)
    return dataset


def process_nq_batch(batch):
    query_tokens = batch['token_ids'].long().cuda()
    query_mask = (batch['token_mask'] < 0.5).cuda()
    query_types = batch['token_types'].long().cuda()
    query_len = batch['seq_len'].long().cuda()
    reference = batch['reference']

    return query_tokens, query_mask, query_types, query_len, reference


class CustomDataLoader(DataLoader):
    def __init__(self, dataset, eval=False, **kwargs):
        if kwargs.get('collate_fn', None) is None:
            kwargs['collate_fn'] = self._collate_fn
        self.eval = eval
        super().__init__(dataset, **kwargs)

    def _collate_fn(self, batch_data):
        # generate batch
        batch_size = len(batch_data)
        tensorized = OrderedDict()
        for d in batch_data:
            for k, v in d.items():
                tensorized.setdefault(k, []).append(v)
        assert len(tensorized) == 5

        tensorized['token_ids'] = torch.LongTensor(tensorized['token_ids'])
        tensorized['token_mask'] = torch.LongTensor(tensorized['token_mask'])
        tensorized['token_types'] = torch.LongTensor(tensorized['token_types'])
        tensorized['seq_len'] = torch.LongTensor(tensorized['seq_len'])
        return tensorized


def get_one_epoch_nq_dataloader(dataset, micro_batch_size=None):
    """Data loader. Note that batch-size is the local (per GPU) batch-size.
       NOTE: This dataloader is not distributed !!!
    """

    args = get_args()
    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # importantly, drop_last must be False to get all the data.
    batch_sampler = BatchSampler(sampler,
                                 batch_size=micro_batch_size,
                                 drop_last=False)

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = CustomDataLoader(dataset,
                                   batch_sampler=batch_sampler,
                                   num_workers=num_workers,
                                   pin_memory=True)
    return data_loader


def build_tokens_types_paddings_from_text(src_text, tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""

    src_text_ids = tokenizer.tokenize(src_text)

    return build_tokens_types_paddings_from_ids(src_text_ids,
        max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)


def build_tokens_types_paddings_from_ids(src_ids, max_seq_length, cls_id, \
    sep_id, pad_id):
    """
    Build token types and paddings, trim if needed, and pad if needed.
    TODO: Design modular interface to reuse this function. This is getting
    repeated multiple times in different tasks
    """

    enc_ids = []
    tokentypes_enc = []

    # [CLS].
    enc_ids.append(cls_id)
    tokentypes_enc.append(0)

    # A.
    len_src = len(src_ids)
    enc_ids.extend(src_ids)
    tokentypes_enc.extend([0] * len_src)

    # Cap the size.
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
        tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    tokentypes_enc.append(0)

    num_tokens_enc = len(enc_ids)

    # Padding.
    padding_length = max_seq_length - len(enc_ids)
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([pad_id] * padding_length)

    return enc_ids, tokentypes_enc, num_tokens_enc


def build_sample(token_ids, token_types, num_tokens, reference):
    """
    Convert to numpy and return a sample consumed by the
    batch producer.
    """

    token_ids = np.array(token_ids, dtype=np.int64)
    token_types = np.array(token_types, dtype=np.int64)
    token_mask = make_attention_mask(token_ids, token_ids)

    sample = ({
        'token_ids': token_ids,
        'token_mask': token_mask,
        'token_types': token_types,
        'seq_len': num_tokens,
        'reference': reference
    })
    return sample


class NQDataset(ABC, Dataset):
    """
    Open Retrieval Question Answering evaluation using Google NQ dataset.
    """

    def __init__(self, task_name, dataset_name, datapath,
                 tokenizer, max_seq_length):
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                        self.dataset_name))
        print_rank_0(datapath)
        self.samples = self.process_samples_from_single_path(datapath)
        print_rank_0('  >> total number of samples: {}'.format(\
                        len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        raw_sample = self.samples[idx]

        ques_tokens, tokentypes_enc, num_tokens_ques = \
            build_tokens_types_paddings_from_text(raw_sample['question'],
                self.tokenizer, self.max_seq_length)

        sample = build_sample(ques_tokens,
                              tokentypes_enc,
                              num_tokens_ques,
                              raw_sample['answers'])
        return sample

    @staticmethod
    def process_samples_from_single_path(filename):
        print_rank_0(' > Processing {} ...'.format(filename))
        samples = []
        total = 0

        with open(filename, 'r') as ifile:
            reader = csv.reader(ifile, delimiter='\t')
            for row in reader:
                question = row[0]
                answers = eval(row[1])

                sample = {'question': question, 'answers': answers}
                total += 1
                samples.append(sample)

                if total % 1000 == 0:
                    print_rank_0('  > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples
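As a quick illustration of build_tokens_types_paddings_from_ids above: a question's token ids are wrapped in [CLS] ... [SEP], truncated before the [SEP] if the sequence would exceed max_seq_length, and then padded out with the pad id (token type ids are padded with the same pad id, as in the code). A minimal sketch with made-up ids; the real special-token ids come from the BERT tokenizer:

# Hypothetical special-token ids for illustration only.
CLS, SEP, PAD = 101, 102, 0

# Mirrors build_tokens_types_paddings_from_ids for src_ids=[7, 8, 9], max_seq_length=8.
src_ids = [7, 8, 9]
enc_ids = [CLS] + src_ids + [SEP]          # no truncation needed for this short input
tokentypes = [0] * len(enc_ids)
num_tokens = len(enc_ids)                  # 5
padding = 8 - len(enc_ids)
enc_ids += [PAD] * padding                 # [101, 7, 8, 9, 102, 0, 0, 0]
tokentypes += [PAD] * padding
print(enc_ids, tokentypes, num_tokens)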
tasks/orqa/natural_questions/qa_utils.py (new file, mode 100644)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Set of utilities for Q&A results validation tasks - Retriever passage
validation and Reader predicted answer validation
"""

import collections
import logging
import string
import unicodedata
from functools import partial
from multiprocessing import Pool as ProcessPool
from typing import Tuple, List, Dict

import regex as re

from tasks.orqa.natural_questions.tokenizers import SimpleTokenizer

logger = logging.getLogger(__name__)

QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits',\
    'questions_doc_hits'])

def calculate_matches(all_docs: Dict[object, Tuple[str, str]],
    answers: List[List[str]], closest_docs: List[Tuple[List[object],
    List[float]]], workers_num: int, match_type: str) -> QAMatchStats:
    """
    Evaluates answers presence in the set of documents. This function is
    supposed to be used with a large collection of documents and results.
    It internally forks multiple sub-processes for evaluation and then
    merges results
    :param all_docs: dictionary of the entire documents database.
        doc_id -> (doc_text, title)
    :param answers: list of answers's list. One list per question
    :param closest_docs: document ids of the top results along with their
        scores
    :param workers_num: amount of parallel threads to process data
    :param match_type: type of answer matching. Refer to has_answer code for
        available options
    :return: matching information tuple.
    top_k_hits - a list where the index is the amount of top documents retrieved
        and the value is the total amount of valid matches across an entire
        dataset.
    questions_doc_hits - more detailed info with answer matches for every
        question and every retrieved document
    """
    global dpr_all_documents
    dpr_all_documents = all_docs

    tok_opts = {}
    tokenizer = SimpleTokenizer(**tok_opts)

    processes = ProcessPool(processes=workers_num,)

    logger.info('Matching answers in top docs...')
    get_score_partial = partial(check_answer, match_type=match_type,
        tokenizer=tokenizer)

    questions_answers_docs = zip(answers, closest_docs)

    scores = processes.map(get_score_partial, questions_answers_docs)

    logger.info('Per question validation results len=%d', len(scores))

    n_docs = len(closest_docs[0][0])
    top_k_hits = [0] * n_docs
    for question_hits in scores:
        best_hit = next((i for i, x in enumerate(question_hits) if x), None)
        if best_hit is not None:
            top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]

    return QAMatchStats(top_k_hits, scores)


def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]:
    """
    Search through all the top docs to see if they have any of the answers.
    """
    answers, (doc_ids, doc_scores) = questions_answers_docs

    global dpr_all_documents
    hits = []

    for i, doc_id in enumerate(doc_ids):
        doc = dpr_all_documents[doc_id]
        text = doc[0]

        answer_found = False
        if text is None:  # cannot find the document for some reason
            logger.warning("no doc in db")
            hits.append(False)
            continue

        if has_answer(answers, text, tokenizer, match_type):
            answer_found = True
        hits.append(answer_found)
    return hits


def has_answer(answers, text, tokenizer, match_type) -> bool:
    """
    Check if a document contains an answer string.
    If `match_type` is string, token matching is done between the text
    and answer.
    If `match_type` is regex, we search the whole text with the regex.
    """
    text = _normalize(text)

    if match_type == 'string':
        # Answer is a list of possible strings
        text = tokenizer.tokenize(text).words(uncased=True)

        for single_answer in answers:
            single_answer = _normalize(single_answer)
            single_answer = tokenizer.tokenize(single_answer)
            single_answer = single_answer.words(uncased=True)

            for i in range(0, len(text) - len(single_answer) + 1):
                if single_answer == text[i: i + len(single_answer)]:
                    return True

    elif match_type == 'regex':
        # Answer is a regex
        for single_answer in answers:
            single_answer = _normalize(single_answer)
            if regex_match(text, single_answer):
                return True
    return False


def regex_match(text, pattern):
    """Test if a regex pattern is contained within a text."""
    try:
        pattern = re.compile(
            pattern,
            flags=re.IGNORECASE + re.UNICODE + re.MULTILINE,
        )
    except BaseException:
        return False
    return pattern.search(text) is not None


# function for the reader model answer validation
def exact_match_score(prediction, ground_truth):
    return _normalize_answer(prediction) == _normalize_answer(ground_truth)


def _normalize_answer(s):
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def _normalize(text):
    return unicodedata.normalize('NFD', text)
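The 'string' match mode above reduces to a sliding-window comparison of normalized, lower-cased token lists. A self-contained sketch of that core check (it uses plain str.split() instead of SimpleTokenizer, so it is only an approximation of the real behaviour):

def contains_answer(answer_tokens, text_tokens):
    # True if answer_tokens occurs as a contiguous subsequence of text_tokens.
    n = len(answer_tokens)
    return any(text_tokens[i:i + n] == answer_tokens
               for i in range(len(text_tokens) - n + 1))

text = "the declaration of independence was written by thomas jefferson".split()
print(contains_answer("thomas jefferson".split(), text))   # True
print(contains_answer("john adams".split(), text))         # False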
tasks/orqa/natural_questions/tokenizers.py (new file, mode 100644)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Most of the tokenizers code here is copied from DrQA codebase to avoid adding extra dependency
"""

import copy
import logging

import regex
import spacy

logger = logging.getLogger(__name__)


class Tokens(object):
    """A class to represent a list of tokenized text."""
    TEXT = 0
    TEXT_WS = 1
    SPAN = 2
    POS = 3
    LEMMA = 4
    NER = 5

    def __init__(self, data, annotators, opts=None):
        self.data = data
        self.annotators = annotators
        self.opts = opts or {}

    def __len__(self):
        """The number of tokens."""
        return len(self.data)

    def slice(self, i=None, j=None):
        """Return a view of the list of tokens from [i, j)."""
        new_tokens = copy.copy(self)
        new_tokens.data = self.data[i: j]
        return new_tokens

    def untokenize(self):
        """Returns the original text (with whitespace reinserted)."""
        return ''.join([t[self.TEXT_WS] for t in self.data]).strip()

    def words(self, uncased=False):
        """Returns a list of the text of each token
        Args:
            uncased: lower cases text
        """
        if uncased:
            return [t[self.TEXT].lower() for t in self.data]
        else:
            return [t[self.TEXT] for t in self.data]

    def offsets(self):
        """Returns a list of [start, end) character offsets of each token."""
        return [t[self.SPAN] for t in self.data]

    def pos(self):
        """Returns a list of part-of-speech tags of each token.
        Returns None if this annotation was not included.
        """
        if 'pos' not in self.annotators:
            return None
        return [t[self.POS] for t in self.data]

    def lemmas(self):
        """Returns a list of the lemmatized text of each token.
        Returns None if this annotation was not included.
        """
        if 'lemma' not in self.annotators:
            return None
        return [t[self.LEMMA] for t in self.data]

    def entities(self):
        """Returns a list of named-entity-recognition tags of each token.
        Returns None if this annotation was not included.
        """
        if 'ner' not in self.annotators:
            return None
        return [t[self.NER] for t in self.data]

    def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
        """Returns a list of all ngrams from length 1 to n.
        Args:
            n: upper limit of ngram length
            uncased: lower cases text
            filter_fn: user function that takes in an ngram list and returns
              True or False to keep or not keep the ngram
            as_string: return the ngram as a string vs list
        """

        def _skip(gram):
            if not filter_fn:
                return False
            return filter_fn(gram)

        words = self.words(uncased)
        ngrams = [(s, e + 1)
                  for s in range(len(words))
                  for e in range(s, min(s + n, len(words)))
                  if not _skip(words[s:e + 1])]

        # Concatenate into strings
        if as_strings:
            ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]

        return ngrams

    def entity_groups(self):
        """Group consecutive entity tokens with the same NER tag."""
        entities = self.entities()
        if not entities:
            return None
        non_ent = self.opts.get('non_ent', 'O')
        groups = []
        idx = 0
        while idx < len(entities):
            ner_tag = entities[idx]
            # Check for entity tag
            if ner_tag != non_ent:
                # Chomp the sequence
                start = idx
                while (idx < len(entities) and entities[idx] == ner_tag):
                    idx += 1
                groups.append((self.slice(start, idx).untokenize(), ner_tag))
            else:
                idx += 1
        return groups


class Tokenizer(object):
    """Base tokenizer class.
    Tokenizers implement tokenize, which should return a Tokens class.
    """

    def tokenize(self, text):
        raise NotImplementedError

    def shutdown(self):
        pass

    def __del__(self):
        self.shutdown()


class SimpleTokenizer(Tokenizer):
    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
    NON_WS = r'[^\p{Z}\p{C}]'

    def __init__(self, **kwargs):
        """
        Args:
            annotators: None or empty set (only tokenizes).
        """
        self._regexp = regex.compile(
            '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
            flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
        )
        if len(kwargs.get('annotators', {})) > 0:
            logger.warning('%s only tokenizes! Skipping annotators: %s' %
                           (type(self).__name__, kwargs.get('annotators')))
        self.annotators = set()

    def tokenize(self, text):
        data = []
        matches = [m for m in self._regexp.finditer(text)]
        for i in range(len(matches)):
            # Get text
            token = matches[i].group()

            # Get whitespace
            span = matches[i].span()
            start_ws = span[0]
            if i + 1 < len(matches):
                end_ws = matches[i + 1].span()[0]
            else:
                end_ws = span[1]

            # Format data
            data.append((
                token,
                text[start_ws: end_ws],
                span,
            ))
        return Tokens(data, self.annotators)


class SpacyTokenizer(Tokenizer):

    def __init__(self, **kwargs):
        """
        Args:
            annotators: set that can include pos, lemma, and ner.
            model: spaCy model to use (either path, or keyword like 'en').
        """
        model = kwargs.get('model', 'en')
        self.annotators = copy.deepcopy(kwargs.get('annotators', set()))
        nlp_kwargs = {'parser': False}
        if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
            nlp_kwargs['tagger'] = False
        if 'ner' not in self.annotators:
            nlp_kwargs['entity'] = False
        self.nlp = spacy.load(model, **nlp_kwargs)

    def tokenize(self, text):
        # We don't treat new lines as tokens.
        clean_text = text.replace('\n', ' ')
        tokens = self.nlp.tokenizer(clean_text)
        if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
            self.nlp.tagger(tokens)
        if 'ner' in self.annotators:
            self.nlp.entity(tokens)

        data = []
        for i in range(len(tokens)):
            # Get whitespace
            start_ws = tokens[i].idx
            if i + 1 < len(tokens):
                end_ws = tokens[i + 1].idx
            else:
                end_ws = tokens[i].idx + len(tokens[i].text)

            data.append((
                tokens[i].text,
                text[start_ws: end_ws],
                (tokens[i].idx, tokens[i].idx + len(tokens[i].text)),
                tokens[i].tag_,
                tokens[i].lemma_,
                tokens[i].ent_type_,
            ))

        # Set special option for non-entity tag: '' vs 'O' in spaCy
        return Tokens(data, self.annotators, opts={'non_ent': ''})
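A short usage sketch for SimpleTokenizer, the tokenizer the answer-matching code above relies on (assumes the regex package is installed and that this module is importable as tasks.orqa.natural_questions.tokenizers):

from tasks.orqa.natural_questions.tokenizers import SimpleTokenizer

tokenizer = SimpleTokenizer()
tokens = tokenizer.tokenize("The Answer to Life, the Universe and Everything is 42.")

print(tokens.words(uncased=True))
# ['the', 'answer', 'to', 'life', ',', 'the', 'universe', 'and', 'everything', 'is', '42', '.']
print(len(tokens), tokens.untokenize())   # token count and the original text, reassembled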