OpenDAS / Megatron-LM · Commits · 8409e1c1

Commit 8409e1c1 authored May 28, 2020 by Neel Kant

Add bm25 evaluation code

parent dfb907fe
Showing 3 changed files with 146 additions and 5 deletions (+146 -5)
ict_eval_bm25.py                  +121  -0
indexer.py                        +2    -2
megatron/data/realm_dataset.py    +23   -3
ict_eval_bm25.py  (new file, 0 → 100644)
import lucene
from java.nio.file import Paths
from org.apache.lucene.analysis.standard import StandardAnalyzer
from org.apache.lucene.document import Document, Field, FieldType
from org.apache.lucene.index import IndexWriter, IndexWriterConfig, IndexOptions, DirectoryReader
from org.apache.lucene.store import SimpleFSDirectory
from org.apache.lucene.search import IndexSearcher
from org.apache.lucene.queryparser.classic import QueryParser
from org.apache.lucene.search.similarities import BM25Similarity
from org.apache.lucene.util import Version

import torch
import torch.distributed as dist

from indexer import get_ict_dataset, get_one_epoch_dataloader
from megatron.initialize import initialize_megatron
from pretrain_bert_ict import get_batch


def setup():
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    lucene.initVM(vmargs=['-Djava.awt.headless=true'])


def run():
    dset = get_ict_dataset(use_titles=False, query_in_block_prob=0.1)
    dataloader = iter(get_one_epoch_dataloader(dset))

    index_dir = SimpleFSDirectory(Paths.get("index/"))
    analyzer = StandardAnalyzer()
    analyzer.setMaxTokenLength(1024)

    # field for document ID
    t1 = FieldType()
    t1.setStored(True)
    t1.setTokenized(False)

    # field for document text
    t2 = FieldType()
    t2.setStored(True)
    t2.setTokenized(True)
    t2.setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS)

    correct = total = 0
    round_correct = torch.zeros(1).cuda()
    round_total = torch.zeros(1).cuda()

    for round in range(100):
        with torch.no_grad():
            try:
                query_tokens, query_pad_mask, \
                    block_tokens, block_pad_mask, block_index_data = get_batch(dataloader)
            except:
                break

            query_tokens = query_tokens.detach().cpu().numpy()
            block_tokens = block_tokens.detach().cpu().numpy()
            query_strs = [dset.decode_tokens(query_tokens[i].tolist(), hardcore=True)
                          for i in range(query_tokens.shape[0])]
            block_strs = [dset.decode_tokens(block_tokens[i].tolist(), hardcore=True)
                          for i in range(block_tokens.shape[0])]

            # create index writer
            config = IndexWriterConfig(analyzer)
            config.setOpenMode(IndexWriterConfig.OpenMode.CREATE)
            writer = IndexWriter(index_dir, config)

            def add_document(text, writer, doc_id):
                doc = Document()
                doc.add(Field("text", text, t2))
                doc.add(Field("doc_id", doc_id, t1))
                writer.addDocument(doc)

            # add documents to index writer
            for i in range(len(block_strs)):
                add_document(block_strs[i], writer, i)

            # write and finalize the index
            writer.commit()
            writer.close()

            # define BM25 searcher
            searcher = IndexSearcher(DirectoryReader.open(index_dir))
            searcher.setSimilarity(BM25Similarity())

            # feed queries and get scores for everything in the index
            hits_list = []
            for s in query_strs:
                query = QueryParser("text", analyzer).parse(s)
                hits = searcher.search(query, 8).scoreDocs
                hits_list.append(hits)

            for (i, hits) in enumerate(hits_list):
                doc_ids = [int(searcher.doc(hit.doc)['doc_id']) for hit in hits]
                # accumulate this round's hits in the cuda tensors so they can
                # be summed across data-parallel ranks below
                round_correct += int(i in doc_ids)
                round_total += 1

        dist.all_reduce(round_correct)
        dist.all_reduce(round_total)

        correct += int(round_correct.item())
        total += int(round_total.item())
        round_correct -= round_correct
        round_total -= round_total

        print("Correct: {:8d} | Total: {:8d} | Fraction: {:6.5f}".format(
            correct, total, correct / total))


# Plan
# overall accuracy test:
# have index with all blocks. For BERT these are token ids, for BM25 these are tokens
#
# 1. run batch size 4096 BM25 self similarity test. For this I can just detokenize out of the dataset.
#    I get the retrieval scores in the forward_step and log the results.
# 2. Create a BM25 index over all of wikipedia, have it ready for use in megatron QA.
#
# Create an index with the block embeddings with block ids


if __name__ == "__main__":
    setup()
    run()
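
For reference, below is a minimal standalone sketch of the per-round flow in run(): index a handful of text blocks with PyLucene, then check that each block retrieves itself under BM25. The sample blocks, the "bm25_demo_index/" path, and the use of Lucene's stock QueryParser.escape helper (in place of the hardcore escaping in decode_tokens) are illustrative assumptions, not part of this commit.

# Standalone sketch (not part of this commit): same PyLucene calls as run(),
# but over three hand-written blocks instead of an ICT batch.
import lucene
from java.nio.file import Paths
from org.apache.lucene.analysis.standard import StandardAnalyzer
from org.apache.lucene.document import Document, Field, FieldType
from org.apache.lucene.index import IndexWriter, IndexWriterConfig, IndexOptions, DirectoryReader
from org.apache.lucene.store import SimpleFSDirectory
from org.apache.lucene.search import IndexSearcher
from org.apache.lucene.search.similarities import BM25Similarity
from org.apache.lucene.queryparser.classic import QueryParser

lucene.initVM(vmargs=['-Djava.awt.headless=true'])

blocks = [
    "the eiffel tower was completed in 1889 in paris",
    "bm25 is a bag of words ranking function built on term frequencies",
    "the inverse cloze task trains a retriever without labeled query block pairs",
]

# stored, untokenized id field; stored, tokenized text field
id_type = FieldType()
id_type.setStored(True)
id_type.setTokenized(False)
text_type = FieldType()
text_type.setStored(True)
text_type.setTokenized(True)
text_type.setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS)

# write the index
index_dir = SimpleFSDirectory(Paths.get("bm25_demo_index/"))
analyzer = StandardAnalyzer()
config = IndexWriterConfig(analyzer)
config.setOpenMode(IndexWriterConfig.OpenMode.CREATE)
writer = IndexWriter(index_dir, config)
for i, text in enumerate(blocks):
    doc = Document()
    doc.add(Field("text", text, text_type))
    doc.add(Field("doc_id", str(i), id_type))
    writer.addDocument(doc)
writer.commit()
writer.close()

# search with BM25 and measure self-retrieval
searcher = IndexSearcher(DirectoryReader.open(index_dir))
searcher.setSimilarity(BM25Similarity())
correct = 0
for i, text in enumerate(blocks):
    query = QueryParser("text", analyzer).parse(QueryParser.escape(text))
    top_ids = [int(searcher.doc(hit.doc).get("doc_id"))
               for hit in searcher.search(query, 8).scoreDocs]
    correct += int(i in top_ids)
print("self-retrieval accuracy: {:.2f}".format(correct / len(blocks)))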
indexer.py

@@ -240,7 +240,7 @@ def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=
     return model
 
 
-def get_ict_dataset(use_titles=True):
+def get_ict_dataset(use_titles=True, query_in_block_prob=1):
     args = get_args()
     block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
     titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
@@ -255,7 +255,7 @@ def get_ict_dataset(use_titles=True):
         max_seq_length=args.seq_length,
         short_seq_prob=0.0001,  # doesn't matter
         seed=1,
-        query_in_block_prob=1,
+        query_in_block_prob=query_in_block_prob,
         use_titles=use_titles
     )
     dataset = ICTDataset(**kwargs)
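
The widened get_ict_dataset signature is what lets the new evaluation script request a harder dataset. A short sketch of the intended call pattern, mirroring ict_eval_bm25.py above; it assumes initialize_megatron() has already populated the Megatron args (data path, titles data path, tokenizer), and the reading of 0.1 as "query only occasionally left inside its block" follows the usual ICT semantics of this flag rather than anything stated in the commit.

# Sketch of how the new keyword is consumed (assumes Megatron args are initialized).
from indexer import get_ict_dataset, get_one_epoch_dataloader

# existing callers are unaffected: query_in_block_prob still defaults to 1
index_dset = get_ict_dataset(use_titles=True)

# the BM25 evaluation builds a harder variant, as in ict_eval_bm25.py
eval_dset = get_ict_dataset(use_titles=False, query_in_block_prob=0.1)
dataloader = iter(get_one_epoch_dataloader(eval_dset))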
megatron/data/realm_dataset.py

@@ -134,10 +134,30 @@ class ICTDataset(Dataset):
     def encode_text(self, text):
         return self.tokenizer.tokenize(text)
 
-    def decode_tokens(self, token_ids):
+    def decode_tokens(self, token_ids, hardcore=False):
         tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
-        non_pads = [t for t in tokens if t != '[PAD]']
-        return join_str_list(non_pads)
+        exclude_list = ['[PAD]', '[CLS]']
+        if hardcore:
+            extra_exclude = ['[SEP]']
+            exclude_list.extend(extra_exclude)
+        non_pads = [t for t in tokens if t not in exclude_list]
+        joined_strs = join_str_list(non_pads)
+        if hardcore:
+            escape_chars = ['+', '-', '&', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':', '/']
+            skip_me = False
+            joined_strs = list(joined_strs)
+            joined_strs = [s for s in joined_strs if s != '\\']
+            for i, c in enumerate(joined_strs):
+                if skip_me:
+                    skip_me = False
+                    continue
+                if c in escape_chars:
+                    joined_strs.insert(i, '\\')
+                    skip_me = True
+            joined_strs = ''.join(joined_strs)
+            if len(joined_strs) < 3:
+                joined_strs += 'text here'
+        return joined_strs
 
     def get_block(self, start_idx, end_idx, doc_idx):
         """Get the IDs for an evidence block plus the title of the corresponding document"""
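
A standalone sketch (no PyLucene needed) of what the hardcore=True path does to a decoded string before it reaches Lucene's QueryParser: stray backslashes are dropped, each special character gets a backslash prefix, and very short results are padded, presumably so the parser never sees an empty query. The escape_for_queryparser helper and the sample string are hypothetical, written only to illustrate the behaviour.

# Hypothetical rewrite of the hardcore escaping in decode_tokens(), for illustration.
ESCAPE_CHARS = ['+', '-', '&', '!', '(', ')', '{', '}', '[', ']',
                '^', '"', '~', '*', '?', ':', '/']

def escape_for_queryparser(decoded):
    # drop stray backslashes, then prefix every Lucene operator character with '\'
    escaped = []
    for c in decoded:
        if c == '\\':
            continue
        if c in ESCAPE_CHARS:
            escaped.append('\\')
        escaped.append(c)
    out = ''.join(escaped)
    # pad very short strings, as decode_tokens does, so the query is never empty
    if len(out) < 3:
        out += 'text here'
    return out

print(escape_for_queryparser('who wrote "war and peace"?'))
# who wrote \"war and peace\"\?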