Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
8b355d93
Commit
8b355d93
authored
Jun 04, 2020
by
Neel Kant
Browse files
Temporarily modify ict_eval_bm25 to build full wiki index
parent
8409e1c1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
39 additions
and
36 deletions
+39
-36
ict_eval_bm25.py
ict_eval_bm25.py
+39
-36
No files found.
ict_eval_bm25.py
View file @
8b355d93
import
lucene
import
lucene
import
sys
from
java.nio.file
import
Paths
from
java.nio.file
import
Paths
from
org.apache.lucene.analysis.standard
import
StandardAnalyzer
from
org.apache.lucene.analysis.standard
import
StandardAnalyzer
...
@@ -24,14 +25,19 @@ def setup():
...
@@ -24,14 +25,19 @@ def setup():
lucene
.
initVM
(
vmargs
=
[
'-Djava.awt.headless=true'
])
lucene
.
initVM
(
vmargs
=
[
'-Djava.awt.headless=true'
])
def
run
():
def
run
(
embed_all
=
False
):
dset
=
get_ict_dataset
(
use_titles
=
False
,
query_in_block_prob
=
0.1
)
dset
=
get_ict_dataset
(
use_titles
=
False
,
query_in_block_prob
=
0.1
)
dataloader
=
iter
(
get_one_epoch_dataloader
(
dset
))
dataloader
=
iter
(
get_one_epoch_dataloader
(
dset
))
index_dir
=
SimpleFSDirectory
(
Paths
.
get
(
"index/"
))
index_dir
=
SimpleFSDirectory
(
Paths
.
get
(
"
full_wiki_
index/"
))
analyzer
=
StandardAnalyzer
()
analyzer
=
StandardAnalyzer
()
analyzer
.
setMaxTokenLength
(
1024
)
analyzer
.
setMaxTokenLength
(
1024
)
config
=
IndexWriterConfig
(
analyzer
)
config
.
setOpenMode
(
IndexWriterConfig
.
OpenMode
.
CREATE
)
writer
=
IndexWriter
(
index_dir
,
config
)
# field for document ID
# field for document ID
t1
=
FieldType
()
t1
=
FieldType
()
t1
.
setStored
(
True
)
t1
.
setStored
(
True
)
...
@@ -46,7 +52,7 @@ def run():
...
@@ -46,7 +52,7 @@ def run():
correct
=
total
=
0
correct
=
total
=
0
round_correct
=
torch
.
zeros
(
1
).
cuda
()
round_correct
=
torch
.
zeros
(
1
).
cuda
()
round_total
=
torch
.
zeros
(
1
).
cuda
()
round_total
=
torch
.
zeros
(
1
).
cuda
()
for
round
in
range
(
100
):
for
round
in
range
(
100
000
):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
try
:
try
:
query_tokens
,
query_pad_mask
,
\
query_tokens
,
query_pad_mask
,
\
...
@@ -54,19 +60,12 @@ def run():
...
@@ -54,19 +60,12 @@ def run():
except
:
except
:
break
break
query_tokens
=
query_tokens
.
detach
().
cpu
().
numpy
()
#
query_tokens = query_tokens.detach().cpu().numpy()
block_tokens
=
block_tokens
.
detach
().
cpu
().
numpy
()
block_tokens
=
block_tokens
.
detach
().
cpu
().
numpy
()
query_strs
=
[
dset
.
decode_tokens
(
query_tokens
[
i
].
tolist
(),
hardcore
=
True
)
for
i
in
range
(
query_tokens
.
shape
[
0
])]
#
query_strs = [dset.decode_tokens(query_tokens[i].tolist(), hardcore=True) for i in range(query_tokens.shape[0])]
block_strs
=
[
dset
.
decode_tokens
(
block_tokens
[
i
].
tolist
(),
hardcore
=
True
)
for
i
in
range
(
block_tokens
.
shape
[
0
])]
block_strs
=
[
dset
.
decode_tokens
(
block_tokens
[
i
].
tolist
(),
hardcore
=
True
)
for
i
in
range
(
block_tokens
.
shape
[
0
])]
# create index writer
config
=
IndexWriterConfig
(
analyzer
)
config
.
setOpenMode
(
IndexWriterConfig
.
OpenMode
.
CREATE
)
writer
=
IndexWriter
(
index_dir
,
config
)
def
add_document
(
text
,
writer
,
doc_id
):
def
add_document
(
text
,
writer
,
doc_id
):
doc
=
Document
()
doc
=
Document
()
doc
.
add
(
Field
(
"text"
,
text
,
t2
))
doc
.
add
(
Field
(
"text"
,
text
,
t2
))
...
@@ -79,32 +78,36 @@ def run():
...
@@ -79,32 +78,36 @@ def run():
# write and finalize the index
# write and finalize the index
writer
.
commit
()
writer
.
commit
()
writer
.
close
()
# define BM25 searcher
# define BM25 searcher
searcher
=
IndexSearcher
(
DirectoryReader
.
open
(
index_dir
))
# searcher = IndexSearcher(DirectoryReader.open(index_dir))
searcher
.
setSimilarity
(
BM25Similarity
())
# searcher.setSimilarity(BM25Similarity())
# feed queries and get scores for everything in the index
# # feed queries and get scores for everything in the index
hits_list
=
[]
# hits_list = []
for
s
in
query_strs
:
# for s in query_strs:
query
=
QueryParser
(
"text"
,
analyzer
).
parse
(
s
)
# query = QueryParser("text", analyzer).parse(s)
hits
=
searcher
.
search
(
query
,
8
).
scoreDocs
# hits = searcher.search(query, 1).scoreDocs
hits_list
.
append
(
hits
)
# hits_list.append(hits)
for
(
i
,
hits
)
in
enumerate
(
hits_list
):
# for (i, hits) in enumerate(hits_list):
doc_ids
=
[
int
(
searcher
.
doc
(
hit
.
doc
)[
'doc_id'
])
for
hit
in
hits
]
# doc_ids = [int(searcher.doc(hit.doc)['doc_id']) for hit in hits]
correct
+=
int
(
i
in
doc_ids
)
# correct += int(i in doc_ids)
total
+=
1
# total += 1
dist
.
all_reduce
(
round_correct
)
# dist.all_reduce(round_correct)
dist
.
all_reduce
(
round_total
)
# dist.all_reduce(round_total)
correct
+=
int
(
round_correct
.
item
())
total
+=
int
(
round_total
.
item
())
# correct += int(round_correct.item())
round_correct
-=
round_correct
# total += int(round_total.item())
round_total
-=
round_total
# round_correct -= round_correct
print
(
"Correct: {:8d} | Total: {:8d} | Fraction: {:6.5f}"
.
format
(
correct
,
total
,
correct
/
total
))
# round_total -= round_total
# print("Correct: {:8d} | Total: {:8d} | Fraction: {:6.5f}".format(correct, total, correct / total))
if
round
%
10
==
0
:
print
(
round
)
writer
.
close
()
# Plan
# Plan
# overall accuracy test:
# overall accuracy test:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment