Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
29825734
Commit
29825734
authored
May 04, 2020
by
Neel Kant
Browse files
Add faiss_test.py
parent
56e81e99
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
158 additions
and
0 deletions
+158
-0
faiss_test.py
faiss_test.py
+158
-0
No files found.
faiss_test.py
0 → 100644
View file @
29825734
from
collections
import
defaultdict
import
time
import
faiss
from
faiss
import
index_factory
import
numpy
as
np
from
megatron
import
get_args
# Index-factory preprocessing prefixes to benchmark (presumably the PCA
# variants with random rotation / whitening — confirm against the faiss
# index_factory string reference).
PCAS = ['PCA', 'PCAR', 'PCAW', 'PCAWR']
# PCA to 64 dim gets "first missing" ~ 95% and "mixed" ~ 5% for all
# however, this is pretty hard since the embeds and queries are totally random, would be better to test according to a distribution

# Coarse-quantizer factory strings benchmarked by test_quantizers below.
QUANTIZERS = ['IVF4096', 'IMI2x9', 'HNSW32', 'IVF4096_HNSW32']

# Vector-encoding factory strings benchmarked by test_encodings below.
ENCODINGS = [
    'Flat',
    'PQ16np',
    # PQ16, PQ16x12(np)
    'SQ4',
    'SQ8',
    'SQ6',
    'SQfp16',
    # 'LSH', 'LSHrt', 'LSHr', 'LSHt'
]
# PQ16 is pretty slow for creating and adding - ~96s for 1e5, 105s for 1e6
# PQ16np is a bit faster but is pretty inaccurate - misses top-1 result 2/3 of time (1e6 embeds)
# PQ16x12(np) gets real slow. Uses 4096 centroids.
# SQfp16 is solid.
# LSH is inaccurate - pretty much always missing the top-1 result (1e6 embeds)
def latest(times):
    """Return the elapsed seconds between the two most recent timestamps.

    ``times`` is a list of ``time.time()`` readings; the result is the gap
    between the final entry and the one before it.
    """
    previous, current = times[-2], times[-1]
    return current - previous
def get_embeds_and_queries(d, num_embeds, num_queries):
    """Draw uniform-random float32 test data of dimension ``d``.

    Returns a ``(num_embeds, d)`` embedding matrix and a ``(num_queries, d)``
    query matrix, both sampled from U[0, 1) and cast to float32 (the dtype
    faiss indexes operate on).
    """
    def _uniform_f32(rows):
        # One sampling helper so both matrices are built identically.
        return np.random.rand(rows, d).astype('float32')

    return _uniform_f32(num_embeds), _uniform_f32(num_queries)
def print_timing_stats(name, create_and_add, search):
    """Print one aligned line of timings (seconds) for the index ``name``."""
    template = ('{:20s} Create and add embeds: {:10.4f}s | '
                'Search embeds: {:10.4f}s')
    print(template.format(name, create_and_add, search))
def print_accuracy_stats(name, gold_indices, estimated_indices):
    """Bucket each query's estimated neighbor list against the exact one.

    Per query: 'first_missing' when the gold top-1 id is absent from the
    estimate, 'all_equal' when the whole list matches exactly, 'mixed'
    otherwise. Prints one aligned summary line for ``name``.
    """
    counts = defaultdict(int)
    for gold_row, est_row in zip(list(gold_indices), list(estimated_indices)):
        if gold_row[0] not in est_row:
            bucket = 'first_missing'
        elif np.array_equal(gold_row, est_row):
            bucket = 'all_equal'
        else:
            bucket = 'mixed'
        counts[bucket] += 1

    summary = '{:20s} First missing: {:4d} | All equal: {:4d} | Mixed: {:4d}'
    print(summary.format(name, counts['first_missing'], counts['all_equal'],
                         counts['mixed']))
def create_and_test_gold(d, k, embeds, queries):
    """Build an exact brute-force ('Flat') index and search it, with timings.

    Returns the exact ``(distances, indices)`` of the top-``k`` neighbors for
    each query; callers treat these as the gold standard when scoring the
    approximate indexes. Prints one timing line plus a separator rule.
    """
    start = time.time()
    gold_idx = index_factory(d, 'Flat')
    gold_idx.add(embeds)
    built = time.time()

    distances, indices = gold_idx.search(queries, k)
    searched = time.time()

    print_timing_stats('Flat', built - start, searched - built)
    print('-' * 100)
    return distances, indices
def test_pca(d, k, num_embeds, num_queries, pca_dim):
    """Benchmark every PCA variant in PCAS against exact search.

    Each variant is built as '<variant><pca_dim>,Flat' (PCA down to
    ``pca_dim`` dims, then brute-force search), timed for build and search,
    and finally scored against the gold indices from the exact index.
    """
    embeds, queries = get_embeds_and_queries(d, num_embeds, num_queries)
    _, gold_indices = create_and_test_gold(d, k, embeds, queries)

    per_variant_indices = []
    mark = time.time()
    for variant in PCAS:
        factory_string = variant + "{},Flat".format(pca_dim)
        pca_idx = index_factory(d, factory_string)
        pca_idx.train(embeds)
        pca_idx.add(embeds)
        built = time.time()

        _, found = pca_idx.search(queries, k)
        per_variant_indices.append(found)
        searched = time.time()

        print_timing_stats(variant, built - mark, searched - built)
        mark = searched

    print('\n')
    for variant, found in zip(PCAS, per_variant_indices):
        print_accuracy_stats(variant, gold_indices, found)
def test_quantizers(d, k, num_embeds, num_queries):
    """Benchmark every coarse quantizer in QUANTIZERS against exact search.

    Each quantizer is paired with a 'Flat' encoding (except standalone HNSW,
    which is a complete index by itself), timed for build and search, and —
    mirroring test_pca / test_encodings — scored against the gold indices
    from the exact index.
    """
    embeds, queries = get_embeds_and_queries(d, num_embeds, num_queries)
    _, gold_indices = create_and_test_gold(d, k, embeds, queries)

    times = [time.time()]
    all_quant_indices = []
    for s in QUANTIZERS:
        # A bare HNSW spec ('HNSW32') is a full index on its own; composites
        # like 'IVF4096_HNSW32' still need an encoding suffix.
        if 'HNSW' in s and '_' not in s:
            quant_idx = index_factory(d, s)
        else:
            # BUG FIX: faiss factory strings put the quantizer BEFORE the
            # encoding ('IVF4096,Flat'); the previous '"Flat," + s' order
            # is not a valid factory string.
            quant_idx = index_factory(d, s + ",Flat")
        quant_idx.train(embeds)
        quant_idx.add(embeds)
        times.append(time.time())
        create_and_add = latest(times)

        quant_distances, quant_indices = quant_idx.search(queries, k)
        all_quant_indices.append(quant_indices)
        times.append(time.time())
        print_timing_stats(s, create_and_add, latest(times))

    # Accuracy pass, consistent with the other test_* drivers.
    print('\n')
    for s, quant_indices in zip(QUANTIZERS, all_quant_indices):
        print_accuracy_stats(s, gold_indices, quant_indices)
def test_encodings(d, k, num_embeds, num_queries):
    """Benchmark every vector encoding in ENCODINGS against exact search.

    Each encoding string is handed to index_factory as-is, timed for build
    (train + add) and search, then scored against the gold indices from the
    exact index.
    """
    embeds, queries = get_embeds_and_queries(d, num_embeds, num_queries)
    _, gold_indices = create_and_test_gold(d, k, embeds, queries)

    collected_indices = []
    mark = time.time()
    for encoding in ENCODINGS:
        encode_idx = index_factory(d, encoding)
        encode_idx.train(embeds)
        encode_idx.add(embeds)
        built = time.time()

        _, found = encode_idx.search(queries, k)
        collected_indices.append(found)
        searched = time.time()

        print_timing_stats(encoding, built - mark, searched - built)
        mark = searched

    print('\n')
    for encoding, found in zip(ENCODINGS, collected_indices):
        print_accuracy_stats(encoding, gold_indices, found)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment