Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
56bd4804
Commit
56bd4804
authored
Apr 30, 2020
by
Neel Kant
Browse files
Reconcile changes with head node
parent
5952c558
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
2 deletions
+46
-2
hashed_index.py
hashed_index.py
+46
-2
No files found.
hashed_index.py
View file @
56bd4804
...
...
@@ -34,7 +34,11 @@ class HashedIndex(object):
self
.
embed_mean
=
None
self
.
embed_whitener
=
None
self
.
whiten
=
whiten
# alsh
self
.
m
=
5
self
.
u
=
0.99
self
.
max_norm
=
None
def
state
(
self
):
state
=
{
...
...
@@ -157,10 +161,44 @@ class HashedIndex(object):
block_embeds
=
np
.
array
(
block_embeds
)
index
=
faiss
.
IndexFlatL2
(
block_embeds
.
shape
[
1
])
index
.
add
(
block_embeds
)
alsh_preprocessed_blocks
=
self
.
alsh_block_preprocess_fn
()
index
.
add
(
alsh_preprocessed_blocks
)
print
(
'Total blocks in index: '
,
index
.
ntotal
)
self
.
block_index
=
index
def
get_norm_powers_and_halves_array
(
self
,
embeds
):
norm
=
np
.
linalg
.
norm
(
embeds
,
axis
=
1
)
norm_powers
=
[
np
.
multiply
(
norm
,
norm
)]
# squared L2 norms of all
for
i
in
range
(
self
.
m
-
1
):
norm_powers
.
append
(
np
.
multiply
(
norm_powers
[
-
1
],
norm_powers
[
-
1
]))
# [num_blocks x self.m]
norm_powers
=
np
.
transpose
(
np
.
array
(
norm_powers
))
halves_array
=
0.5
*
np
.
ones
(
norm_powers
.
shape
)
return
norm_powers
,
halves_array
def
alsh_block_preprocess_fn
(
self
):
block_idx
,
block_embeds
=
zip
(
*
self
.
block_data
.
items
())
block_embeds
=
np
.
array
(
block_embeds
)
if
self
.
max_norm
is
None
:
self
.
max_norm
=
max
(
np
.
linalg
.
norm
(
block_embeds
,
axis
=
1
))
if
self
.
max_norm
>
1
:
block_embeds
=
self
.
u
/
self
.
max_norm
*
block_embeds
norm_powers
,
halves_array
=
self
.
get_norm_powers_and_halves_array
(
block_embeds
)
# P'(S(x)) for all x in block_embeds
return
np
.
concatenate
((
block_embeds
,
norm_powers
,
halves_array
),
axis
=
1
)
def
alsh_query_preprocess_fn
(
self
,
query_embeds
):
norm
=
np
.
linalg
.
norm
(
query_embeds
,
axis
=
1
)
max_norm
=
max
(
norm
)
if
max_norm
>
1
:
query_embeds
=
self
.
u
/
max_norm
*
query_embeds
norm_powers
,
halves_array
=
self
.
get_norm_powers_and_halves_array
(
query_embeds
)
# Q'(S(x)) for all x in query_embeds
return
np
.
concatenate
((
query_embeds
,
halves_array
,
norm_powers
),
axis
=
1
)
def
exact_mips_equals
(
self
,
query_embeds
):
"""For each query, determine whether the mips block is in the correct hash bucket"""
_
,
block_embeds
=
zip
(
*
self
.
block_data
.
items
())
...
...
@@ -188,11 +226,17 @@ class HashedIndex(object):
# array of zeros and ones which can be used for counting success
return
equal_arr
def
exact_mips_test
(
self
,
whitened
,
num_queries
):
def
exact_mips_test
(
self
,
whitened
,
num_queries
,
alsh
):
if
whitened
:
if
self
.
embed_mean
is
None
:
self
.
whiten_block_embeds
()
query_embeds
=
np
.
random
.
multivariate_normal
(
np
.
zeros
(
128
),
np
.
eye
(
128
),
num_queries
)
if
alsh
:
self
.
create_block_data_index
()
alsh_queries
=
self
.
alsh_query_preprocess_fn
(
query_embeds
)
neighbor_ids
,
distances
=
self
.
block_idx
.
search
(
alsh_queries
,
5
)
print
(
'DONE'
)
return
else
:
block_idx
,
all_embeds
=
zip
(
*
self
.
block_data
.
items
())
arr_embeds
=
np
.
transpose
(
np
.
array
(
all_embeds
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment