Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
c1c958fa
"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "939ec17e91432531b8acb7ea0ae9295936ff1da6"
Commit
c1c958fa
authored
Apr 30, 2020
by
Neel Kant
Browse files
Implement MIPS with FAISS
parent
56bd4804
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
5 deletions
+5
-5
hashed_index.py
hashed_index.py
+5
-5
No files found.
hashed_index.py
View file @
c1c958fa
...
@@ -160,8 +160,8 @@ class HashedIndex(object):
...
@@ -160,8 +160,8 @@ class HashedIndex(object):
self
.
block_idx
,
block_embeds
=
zip
(
*
self
.
block_data
.
items
())
self
.
block_idx
,
block_embeds
=
zip
(
*
self
.
block_data
.
items
())
block_embeds
=
np
.
array
(
block_embeds
)
block_embeds
=
np
.
array
(
block_embeds
)
index
=
faiss
.
IndexFlatL2
(
block_embeds
.
shape
[
1
])
alsh_preprocessed_blocks
=
self
.
alsh_block_preprocess_fn
()
alsh_preprocessed_blocks
=
self
.
alsh_block_preprocess_fn
()
index
=
faiss
.
IndexFlatL2
(
alsh_preprocessed_blocks
.
shape
[
1
])
index
.
add
(
alsh_preprocessed_blocks
)
index
.
add
(
alsh_preprocessed_blocks
)
print
(
'Total blocks in index: '
,
index
.
ntotal
)
print
(
'Total blocks in index: '
,
index
.
ntotal
)
self
.
block_index
=
index
self
.
block_index
=
index
...
@@ -187,7 +187,7 @@ class HashedIndex(object):
...
@@ -187,7 +187,7 @@ class HashedIndex(object):
norm_powers
,
halves_array
=
self
.
get_norm_powers_and_halves_array
(
block_embeds
)
norm_powers
,
halves_array
=
self
.
get_norm_powers_and_halves_array
(
block_embeds
)
# P'(S(x)) for all x in block_embeds
# P'(S(x)) for all x in block_embeds
return
np
.
concatenate
((
block_embeds
,
norm_powers
,
halves_array
),
axis
=
1
)
return
np
.
float32
(
np
.
concatenate
((
block_embeds
,
norm_powers
,
halves_array
),
axis
=
1
)
)
def
alsh_query_preprocess_fn
(
self
,
query_embeds
):
def
alsh_query_preprocess_fn
(
self
,
query_embeds
):
norm
=
np
.
linalg
.
norm
(
query_embeds
,
axis
=
1
)
norm
=
np
.
linalg
.
norm
(
query_embeds
,
axis
=
1
)
...
@@ -197,7 +197,7 @@ class HashedIndex(object):
...
@@ -197,7 +197,7 @@ class HashedIndex(object):
norm_powers
,
halves_array
=
self
.
get_norm_powers_and_halves_array
(
query_embeds
)
norm_powers
,
halves_array
=
self
.
get_norm_powers_and_halves_array
(
query_embeds
)
# Q'(S(x)) for all x in query_embeds
# Q'(S(x)) for all x in query_embeds
return
np
.
concatenate
((
query_embeds
,
halves_array
,
norm_powers
),
axis
=
1
)
return
np
.
float32
(
np
.
concatenate
((
query_embeds
,
halves_array
,
norm_powers
),
axis
=
1
)
)
def
exact_mips_equals
(
self
,
query_embeds
):
def
exact_mips_equals
(
self
,
query_embeds
):
"""For each query, determine whether the mips block is in the correct hash bucket"""
"""For each query, determine whether the mips block is in the correct hash bucket"""
...
@@ -234,7 +234,7 @@ class HashedIndex(object):
...
@@ -234,7 +234,7 @@ class HashedIndex(object):
if
alsh
:
if
alsh
:
self
.
create_block_data_index
()
self
.
create_block_data_index
()
alsh_queries
=
self
.
alsh_query_preprocess_fn
(
query_embeds
)
alsh_queries
=
self
.
alsh_query_preprocess_fn
(
query_embeds
)
neighbor_ids
,
distances
=
self
.
block_i
d
x
.
search
(
alsh_queries
,
5
)
neighbor_ids
,
distances
=
self
.
block_i
nde
x
.
search
(
alsh_queries
,
5
)
print
(
'DONE'
)
print
(
'DONE'
)
return
return
else
:
else
:
...
@@ -313,7 +313,7 @@ def main():
...
@@ -313,7 +313,7 @@ def main():
model
.
eval
()
model
.
eval
()
dataset
=
get_ict_dataset
()
dataset
=
get_ict_dataset
()
data_iter
=
iter
(
get_one_epoch_dataloader
(
dataset
))
data_iter
=
iter
(
get_one_epoch_dataloader
(
dataset
))
hashed_index
=
HashedIndex
(
embed_size
=
128
,
num_buckets
=
4096
,
whiten
=
True
)
hashed_index
=
HashedIndex
(
embed_size
=
128
,
num_buckets
=
32
,
whiten
=
True
)
i
=
1
i
=
1
total
=
0
total
=
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment