Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
150f2384
Commit
150f2384
authored
May 12, 2020
by
Neel Kant
Browse files
Update faiss_test
parent
c9ca82bd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
55 additions
and
16 deletions
+55
-16
faiss_test.py
faiss_test.py
+55
-16
No files found.
faiss_test.py
View file @
150f2384
from
collections
import
defaultdict
from
collections
import
defaultdict
import
time
import
time
import
pickle
import
faiss
import
faiss
from
faiss
import
index_factory
from
faiss
import
index_factory
,
index_cpu_to_gpu
import
numpy
as
np
import
numpy
as
np
from
megatron
import
get_args
from
megatron
import
get_args
...
@@ -14,13 +15,19 @@ PCAS = [
...
@@ -14,13 +15,19 @@ PCAS = [
# PCA to 64 dim gets "first missing" ~ 95% and "mixed" ~ 5% for all
# PCA to 64 dim gets "first missing" ~ 95% and "mixed" ~ 5% for all
# however, this is pretty hard since the embeds and queries are totally random, would be better to test according to a distribution
# however, this is pretty hard since the embeds and queries are totally random, would be better to test according to a distribution
# update: Using realistic mean and covariance helps, but then adjusting for inner product makes it unusable again
# CONCLUSION: PCA should not be used for MIPS
# Index-factory quantizer strings to benchmark.
# The commented-out entries are excluded because they do not support the
# inner-product metric, which makes them unusable for MIPS (see notes below).
QUANTIZERS = [
    'IVF4096_SQ16',
    # 'IMI2x9',
    'HNSW32_SQ16',
    # 'IVF4096_HNSW32'
]
# IMI2x9 or any other MultiIndex doesn't support inner product so it's unusable
# IVF4096_HNSW32 doesn't support inner product either
ENCODINGS
=
[
ENCODINGS
=
[
'Flat'
,
'Flat'
,
...
@@ -38,16 +45,34 @@ ENCODINGS = [
...
@@ -38,16 +45,34 @@ ENCODINGS = [
# LSH is inaccurate - pretty much always missing the top-1 result (1e6 embeds)
# LSH is inaccurate - pretty much always missing the top-1 result (1e6 embeds)
def latest(times):
    """Return the interval between the two most recent timestamps in *times*."""
    newest, previous = times[-1], times[-2]
    return newest - previous
def get_embed_mean_and_cov(path='/home/dcg-adlr-nkant-data.cosmos1202/hash_data/normed4096_whitened.pkl'):
    """Load the empirical embedding mean and covariance from a pickle file.

    Args:
        path: pickle file containing 'embed_mean' and 'embed_whitener'
            entries. Defaults to the original hard-coded dataset location,
            so existing zero-argument callers are unaffected.

    Returns:
        Tuple (embed_mean, embed_cov) where embed_cov = W @ W.T for the
        stored whitening matrix W.
    """
    # Original code did pickle.load(open(...)) and leaked the file handle;
    # a context manager guarantees it is closed.
    with open(path, 'rb') as f:
        embed_data = pickle.load(f)
    embed_mean = embed_data['embed_mean']
    whitener = embed_data['embed_whitener']
    embed_cov = whitener.dot(whitener.transpose())
    return embed_mean, embed_cov
def get_embeds_and_queries(mean, cov, num_embeds, num_queries):
    """Draw embedding and query matrices from N(mean, cov), as float32.

    Returns a tuple (embeds, queries) of shapes (num_embeds, dim) and
    (num_queries, dim), where dim is fixed by *mean*/*cov*.
    """
    def _sample(count):
        return np.random.multivariate_normal(mean, cov, count).astype('float32')

    # Embeddings are drawn first, then queries — same RNG order as before.
    return _sample(num_embeds), _sample(num_queries)
def get_random_embeds_and_queries(d, num_embeds, num_queries):
    """Return uniform-random float32 matrices of shape (num_embeds, d) and
    (num_queries, d), drawn i.i.d. from U[0, 1)."""
    shapes = ((num_embeds, d), (num_queries, d))
    embeds, queries = (np.random.rand(*shape).astype('float32') for shape in shapes)
    return embeds, queries
def print_timing_stats(name, create_and_add, search):
    """Print a one-line timing summary for an index: build/add time vs. search time."""
    template = '{:20s} Create and add embeds: {:10.4f}s | Search embeds: {:10.4f}s'
    print(template.format(name, create_and_add, search))
...
@@ -69,7 +94,8 @@ def print_accuracy_stats(name, gold_indices, estimated_indices):
...
@@ -69,7 +94,8 @@ def print_accuracy_stats(name, gold_indices, estimated_indices):
def
create_and_test_gold
(
d
,
k
,
embeds
,
queries
):
def
create_and_test_gold
(
d
,
k
,
embeds
,
queries
):
times
=
[
time
.
time
()]
times
=
[
time
.
time
()]
gold_idx
=
index_factory
(
d
,
'Flat'
)
res
=
faiss
.
StandardGpuResources
()
gold_idx
=
index_cpu_to_gpu
(
res
,
0
,
index_factory
(
d
,
'Flat'
))
gold_idx
.
add
(
embeds
)
gold_idx
.
add
(
embeds
)
times
.
append
(
time
.
time
())
times
.
append
(
time
.
time
())
create_and_add
=
latest
(
times
)
create_and_add
=
latest
(
times
)
...
@@ -81,15 +107,14 @@ def create_and_test_gold(d, k, embeds, queries):
...
@@ -81,15 +107,14 @@ def create_and_test_gold(d, k, embeds, queries):
return
distances
,
indices
return
distances
,
indices
def
test_pca
(
d
,
k
,
num_
embeds
,
num_
queries
,
pca_dim
):
def
test_pca
(
d
,
k
,
embeds
,
queries
,
pca_dim
):
embeds
,
queries
=
get_embeds_and_queries
(
d
,
num_embeds
,
num_queries
)
distances
,
indices
=
create_and_test_gold
(
d
,
k
,
embeds
,
queries
)
distances
,
indices
=
create_and_test_gold
(
d
,
k
,
embeds
,
queries
)
times
=
[
time
.
time
()]
times
=
[
time
.
time
()]
all_pca_indices
=
[]
all_pca_indices
=
[]
for
s
in
PCAS
:
for
s
in
PCAS
:
pca_idx
=
index_factory
(
d
,
s
+
"{},Flat"
.
format
(
pca_dim
))
pca_idx
=
index_factory
(
d
,
s
+
"{},Flat"
.
format
(
pca_dim
)
,
faiss
.
METRIC_INNER_PRODUCT
)
pca_idx
.
train
(
embeds
)
pca_idx
.
train
(
embeds
)
pca_idx
.
add
(
embeds
)
pca_idx
.
add
(
embeds
)
times
.
append
(
time
.
time
())
times
.
append
(
time
.
time
())
...
@@ -105,17 +130,16 @@ def test_pca(d, k, num_embeds, num_queries, pca_dim):
...
@@ -105,17 +130,16 @@ def test_pca(d, k, num_embeds, num_queries, pca_dim):
print_accuracy_stats
(
s
,
indices
,
pca_indices
)
print_accuracy_stats
(
s
,
indices
,
pca_indices
)
def
test_quantizers
(
d
,
k
,
num_
embeds
,
num_
queries
):
def
test_quantizers
(
d
,
k
,
embeds
,
queries
):
embeds
,
queries
=
get_embeds_and_queries
(
d
,
num_embeds
,
num_queries
)
distances
,
indices
=
create_and_test_gold
(
d
,
k
,
embeds
,
queries
)
distances
,
indices
=
create_and_test_gold
(
d
,
k
,
embeds
,
queries
)
times
=
[
time
.
time
()]
times
=
[
time
.
time
()]
for
s
in
QUANTIZERS
:
for
s
in
QUANTIZERS
:
if
'HNSW'
in
s
and
'_'
not
in
s
:
if
'HNSW'
in
s
:
quant_idx
=
index_factory
(
d
,
s
)
quant_idx
=
index_factory
(
d
,
s
,
faiss
.
METRIC_INNER_PRODUCT
)
else
:
else
:
quant_idx
=
index_factory
(
d
,
"Flat,"
+
s
)
quant_idx
=
index_factory
(
d
,
"Flat,"
+
s
,
faiss
.
METRIC_INNER_PRODUCT
)
quant_idx
.
train
(
embeds
)
quant_idx
.
train
(
embeds
)
quant_idx
.
add
(
embeds
)
quant_idx
.
add
(
embeds
)
...
@@ -127,15 +151,14 @@ def test_quantizers(d, k, num_embeds, num_queries):
...
@@ -127,15 +151,14 @@ def test_quantizers(d, k, num_embeds, num_queries):
print_timing_stats
(
s
,
create_and_add
,
latest
(
times
))
print_timing_stats
(
s
,
create_and_add
,
latest
(
times
))
def
test_encodings
(
d
,
k
,
num_
embeds
,
num_
queries
):
def
test_encodings
(
d
,
k
,
embeds
,
queries
):
embeds
,
queries
=
get_embeds_and_queries
(
d
,
num_embeds
,
num_queries
)
distances
,
indices
=
create_and_test_gold
(
d
,
k
,
embeds
,
queries
)
distances
,
indices
=
create_and_test_gold
(
d
,
k
,
embeds
,
queries
)
times
=
[
time
.
time
()]
times
=
[
time
.
time
()]
all_encode_indices
=
[]
all_encode_indices
=
[]
for
s
in
ENCODINGS
:
for
s
in
ENCODINGS
:
encode_idx
=
index_factory
(
d
,
s
)
encode_idx
=
index_factory
(
d
,
s
,
faiss
.
METRIC_INNER_PRODUCT
)
encode_idx
.
train
(
embeds
)
encode_idx
.
train
(
embeds
)
encode_idx
.
add
(
embeds
)
encode_idx
.
add
(
embeds
)
...
@@ -152,6 +175,22 @@ def test_encodings(d, k, num_embeds, num_queries):
...
@@ -152,6 +175,22 @@ def test_encodings(d, k, num_embeds, num_queries):
print_accuracy_stats
(
s
,
indices
,
encode_indices
)
print_accuracy_stats
(
s
,
indices
,
encode_indices
)
def run_all_tests():
    """Generate realistic embeddings/queries once, then run every index benchmark.

    Uses the pickled empirical mean/covariance so the test data matches the
    real embedding distribution, then exercises the PCA, quantizer, and
    encoding index variants with the same data.
    """
    mean, cov = get_embed_mean_and_cov()
    embeds, queries = get_embeds_and_queries(mean, cov, int(1e6), 256)
    # Derive the index dimension from the data instead of hard-coding 128:
    # the dimensionality is fixed by the pickled mean/cov, and faiss requires
    # the index dim to match embeds.shape[1] in add()/search().
    # NOTE(review): the original hard-coded d = 128 — confirm the pickled
    # embeddings are in fact 128-dimensional; if so this is behavior-identical.
    d = embeds.shape[1]
    k = 10
    test_pca(d, k, embeds, queries, 96)
    test_quantizers(d, k, embeds, queries)
    test_encodings(d, k, embeds, queries)
if __name__ == "__main__":
    # Script entry point: run the full faiss index benchmark suite.
    run_all_tests()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment