Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
6e256445
Commit
6e256445
authored
May 12, 2020
by
Neel Kant
Browse files
faiss use_gpu
parent
0c077a2c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
15 additions
and
28 deletions
+15
-28
faiss_test.py
faiss_test.py
+0
-4
indexer.py
indexer.py
+5
-5
megatron/data/realm_index.py
megatron/data/realm_index.py
+8
-17
megatron/model/realm_model.py
megatron/model/realm_model.py
+2
-2
No files found.
faiss_test.py
View file @
6e256445
...
...
@@ -45,8 +45,6 @@ ENCODINGS = [
# LSH is inaccurate - pretty much always missing the top-1 result (1e6 embeds)
def
latest
(
times
):
return
times
[
-
1
]
-
times
[
-
2
]
...
...
@@ -185,8 +183,6 @@ def run_all_tests():
test_encodings
(
d
,
k
,
embeds
,
queries
)
if
__name__
==
"__main__"
:
run_all_tests
()
...
...
indexer.py
View file @
6e256445
...
...
@@ -43,8 +43,11 @@ def test_retriever():
model
=
load_ict_checkpoint
(
only_block_model
=
True
)
model
.
eval
()
dataset
=
get_ict_dataset
()
hashed_index
=
HashedIndex
.
load_from_file
(
args
.
hash_data_path
)
retriever
=
REALMRetriever
(
model
,
dataset
,
hashed_index
)
block_data
=
BlockData
.
load_from_file
(
args
.
block_data_path
)
mips_index
=
FaissMIPSIndex
(
'flat_ip'
,
128
)
mips_index
.
add_block_embed_data
(
block_data
)
retriever
=
REALMRetriever
(
model
,
dataset
,
mips_index
,
top_k
=
5
)
strs
=
[
"The last monarch from the house of windsor"
,
...
...
@@ -58,8 +61,6 @@ def test_retriever():
def
main
():
initialize_megatron
(
extra_args_provider
=
None
,
args_defaults
=
{
'tokenizer_type'
:
'BertWordPieceLowerCase'
})
args
=
get_args
()
...
...
@@ -116,7 +117,6 @@ def main():
set_model_com_file_not_ready
()
def
load_ict_checkpoint
(
only_query_model
=
False
,
only_block_model
=
False
,
no_grad
=
False
,
from_realm_chkpt
=
False
):
args
=
get_args
()
model
=
get_model
(
lambda
:
model_provider
(
only_query_model
,
only_block_model
))
...
...
megatron/data/realm_index.py
View file @
6e256445
...
...
@@ -3,10 +3,11 @@ import os
import
pickle
import
shutil
import
faiss
import
numpy
as
np
import
torch
from
megatron
import
get_args
from
megatron
import
get_args
,
mpu
def
detach
(
tensor
):
...
...
@@ -77,10 +78,10 @@ class BlockData(object):
class
FaissMIPSIndex
(
object
):
def
__init__
(
self
,
index_type
,
embed_size
,
**
index_kwargs
):
def
__init__
(
self
,
index_type
,
embed_size
,
use_gpu
=
False
):
self
.
index_type
=
index_type
self
.
embed_size
=
embed_size
self
.
index_kwargs
=
dict
(
index_kwargs
)
self
.
use_gpu
=
use_gpu
# alsh
self
.
m
=
5
...
...
@@ -89,27 +90,17 @@ class FaissMIPSIndex(object):
self
.
block_mips_index
=
None
self
.
_set_block_index
()
@
classmethod
def
load_from_file
(
cls
,
fname
):
print
(
" > Unpickling block index data"
)
state_dict
=
pickle
.
load
(
open
(
fname
,
'rb'
))
print
(
" > Finished unpickling"
)
index_type
=
state_dict
[
'index_type'
]
index_kwargs
=
state_dict
[
'index_kwargs'
]
embed_size
=
state_dict
[
'embed_size'
]
new_index
=
cls
(
index_type
,
embed_size
,
**
index_kwargs
)
return
new_index
def
_set_block_index
(
self
):
import
faiss
INDEX_TYPES
=
[
'flat_l2'
,
'flat_ip'
]
if
self
.
index_type
not
in
INDEX_TYPES
:
raise
ValueError
(
"Invalid index type specified"
)
index
=
faiss
.
index_factory
(
self
.
embed_size
,
'Flat'
,
faiss
.
METRIC_INNER_PRODUCT
)
self
.
block_mips_index
=
faiss
.
IndexIDMap
(
index
)
if
self
.
use_gpu
:
res
=
faiss
.
StandardGpuResources
()
device
=
mpu
.
get_data_parallel_rank
()
self
.
block_mips_index
=
faiss
.
index_cpu_to_gpu
(
res
,
device
,
self
.
block_mips_index
)
def
reset_index
(
self
):
self
.
_set_block_index
()
...
...
megatron/model/realm_model.py
View file @
6e256445
...
...
@@ -178,8 +178,8 @@ class REALMRetriever(MegatronModule):
query_tokens
=
torch
.
cuda
.
LongTensor
(
np
.
array
(
query_tokens
).
reshape
(
1
,
-
1
))
query_pad_mask
=
torch
.
cuda
.
LongTensor
(
np
.
array
(
query_pad_mask
).
reshape
(
1
,
-
1
))
top
5
_block_tokens
,
_
=
self
.
retrieve_evidence_blocks
(
query_tokens
,
query_pad_mask
)
for
i
,
block
in
enumerate
(
top
5
_block_tokens
[
0
]):
top
k
_block_tokens
,
_
=
self
.
retrieve_evidence_blocks
(
query_tokens
,
query_pad_mask
)
for
i
,
block
in
enumerate
(
top
k
_block_tokens
[
0
]):
block_text
=
self
.
ict_dataset
.
decode_tokens
(
block
)
print
(
'
\n
> Block {}: {}'
.
format
(
i
,
block_text
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment