OpenDAS / Megatron-LM · Commits

Commit 3f122ce9, authored Apr 30, 2020 by Neel Kant

    Write MIPS tests in HashedIndex

Parent: 9d225b44

Showing 1 changed file with 53 additions and 19 deletions:

hashed_index.py (+53, -19)
@@ -34,6 +34,7 @@ class HashedIndex(object):
         self.embed_mean = None
         self.embed_whitener = None
         self.whiten = whiten
+        self.m = 5
 
     def state(self):
         state = {
@@ -120,7 +121,7 @@ class HashedIndex(object):
         centered = arr_embeds - mean
         inv_cov = np.linalg.inv(np.cov(arr_embeds))
         whitener = np.transpose(np.linalg.cholesky(inv_cov))
-        whitened = np.transpose(whitener.dot(centered))
+        whitened = np.float16(np.transpose(whitener.dot(centered)))
         self.embed_mean = mean.reshape(-1)
         self.embed_whitener = whitener
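For reference, this hunk computes a standard Cholesky whitening: with C the covariance of the block embeddings and C⁻¹ = L Lᵀ, the whitener W = Lᵀ satisfies W C Wᵀ = Lᵀ L⁻ᵀ L⁻¹ L = I. A minimal self-contained sketch under that reading, with synthetic data standing in for the repo's embeddings (array names mirror the diff, the data is illustrative):

    import numpy as np

    # Synthetic stand-in for the (dim x num_blocks) embedding matrix used above.
    rng = np.random.default_rng(0)
    arr_embeds = rng.normal(size=(16, 4096))

    mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
    centered = arr_embeds - mean

    # cholesky(C^-1) returns L with L @ L.T == C^-1; its transpose is the whitener.
    inv_cov = np.linalg.inv(np.cov(arr_embeds))
    whitener = np.transpose(np.linalg.cholesky(inv_cov))
    whitened = np.transpose(whitener.dot(centered))   # rows are whitened samples

    # Sample covariance of the whitened data should be the identity.
    print(np.allclose(np.cov(whitened, rowvar=False), np.eye(16), atol=1e-6))

The change itself wraps the result in np.float16, presumably to match the half-precision tensors used by the MIPS test added below.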
@@ -145,6 +146,56 @@ class HashedIndex(object):
             # [int] instead of [array<int>] since this is just for analysis rn
             self.hash_data[hash].append(batch_block_idx)
 
+    def create_block_data_index(self):
+        import faiss
+        self.block_idx, block_embeds = zip(*self.block_data.items())
+        block_embeds = np.array(block_embeds)
+        index = faiss.IndexFlatL2(block_embeds.shape[1])
+        index.add(block_embeds)
+        print('Total blocks in index: ', index.ntotal)
+        self.block_index = index
+
+    def exact_mips_equals(self, query_embeds):
+        """For each query, determine whether the mips block is in the correct hash bucket"""
+        _, block_embeds = zip(*self.block_data.items())
+        with torch.no_grad():
+            # get hashes for the queries
+            hash_scores_pos = torch.matmul(torch.cuda.HalfTensor(query_embeds),
+                                           torch.cuda.HalfTensor(self.hash_matrix))
+            hash_scores = torch.cat((hash_scores_pos, -hash_scores_pos), axis=1)
+            query_hashes = detach(torch.argmax(hash_scores, axis=1))
+
+            # [num_query x num_blocks]
+            inner_products = torch.matmul(torch.cuda.HalfTensor(query_embeds),
+                                          torch.cuda.HalfTensor(np.transpose(np.array(block_embeds))))
+            max_inner_product_idxes = detach(torch.argmax(inner_products, axis=1))
+            best_blocks = [self.block_data[idx] for idx in max_inner_product_idxes]
+            best_blocks_tensor = torch.cuda.HalfTensor(np.array(best_blocks))
+
+            # bb = best_blocks
+            bb_hash_scores_pos = torch.matmul(torch.cuda.HalfTensor(best_blocks_tensor),
+                                              torch.cuda.HalfTensor(self.hash_matrix))
+            bb_hash_scores = torch.cat((bb_hash_scores_pos, -bb_hash_scores_pos), axis=1)
+            best_block_hashes = detach(torch.argmax(bb_hash_scores, axis=1))
+
+        equal_arr = np.equal(query_hashes, best_block_hashes).astype(int)
+        # array of zeros and ones which can be used for counting success
+        return equal_arr
+
+    def exact_mips_test(self, whitened):
+        if whitened:
+            if self.embed_mean is None:
+                self.whiten_block_embeds()
+            query_embeds = np.random.multivariate_normal(np.zeros(128), np.eye(128), 256)
+        else:
+            block_idx, all_embeds = zip(*self.block_data.items())
+            arr_embeds = np.transpose(np.array(all_embeds))
+            mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
+            cov = np.cov(arr_embeds)
+            query_embeds = np.random.multivariate_normal(mean, cov, 256)
+
+        equal_arr = self.exact_mips_equals(query_embeds)
+        print("Num correct: ", sum(equal_arr),
+              " Fraction correct: ", sum(equal_arr) / equal_arr.size)
+
     @classmethod
     def load_from_file(cls, fname):
         print(" > Unpickling block hash data")
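The core of the new exact_mips_equals is the cat((scores, -scores)) + argmax construction: each vector is assigned to the random projection with the largest-magnitude score, taken with its sign, giving 2 × num_hashes possible buckets; the test then measures how often a query's exact MIPS block lands in the query's bucket. A CPU-only NumPy sketch of the same check (all shapes and data are synthetic stand-ins, not values from the repo):

    import numpy as np

    rng = np.random.default_rng(0)
    dim, num_blocks, num_queries, num_hashes = 128, 1000, 256, 16

    block_embeds = rng.normal(size=(num_blocks, dim))
    query_embeds = rng.normal(size=(num_queries, dim))
    hash_matrix = rng.normal(size=(dim, num_hashes))

    def bucket(x):
        # argmax over (scores, -scores) picks the projection with the largest
        # magnitude together with its sign: one of 2 * num_hashes buckets.
        scores = x @ hash_matrix
        return np.argmax(np.concatenate((scores, -scores), axis=1), axis=1)

    # Exact MIPS: index of the block with the largest inner product per query.
    best = np.argmax(query_embeds @ block_embeds.T, axis=1)

    # 0/1 array counting queries whose MIPS block shares their hash bucket.
    equal_arr = (bucket(query_embeds) == bucket(block_embeds[best])).astype(int)
    print("Fraction correct:", equal_arr.sum() / equal_arr.size)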
@@ -159,23 +210,6 @@ class HashedIndex(object):
         return new_index
 
-    @classmethod
-    def whiten_and_rehash(cls, fname):
-        """Load up a HashedIndex, whiten it and rehash"""
-        index = cls.load_from_file(fname)
-        all_vectors = []
-        for block_embed in index.block_data.values():
-            all_vectors.append(block_embed)
-        arr_vectors = np.transpose(np.array(all_vectors))
-        mean = np.mean(arr_vectors, axis=1)
-        cov = np.cov(arr_vectors)
-        inv_cov = np.linalg.inv(cov)
 
 def test_retriever():
     initialize_megatron(extra_args_provider=None,
@@ -239,7 +273,7 @@ def main():
         block_indices = detach(block_indices)
         block_logits = model(None, None, block_tokens, block_pad_mask, only_block=True)
 
-        # If whiten, then hashing needs to be done after whitening the block embeds
+        # If whitened, then hashing needs to be done after whitening the block embeds
         # which is done in consolidate_shards_and_save()
         if not whiten:
            hashed_index.hash_embeds(block_logits, block_indices)
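For completeness, a minimal stand-alone version of the exact index that create_block_data_index builds above; the data is synthetic, and note that faiss's Python API expects float32 input (the diff's embeddings may be float16 after whitening):

    import numpy as np
    import faiss  # e.g. pip install faiss-cpu

    rng = np.random.default_rng(0)
    block_embeds = rng.normal(size=(1000, 128)).astype('float32')

    index = faiss.IndexFlatL2(block_embeds.shape[1])  # exact L2 search, no training step
    index.add(block_embeds)
    print('Total blocks in index: ', index.ntotal)

    # Retrieve the 4 nearest blocks for a batch of 8 queries.
    queries = rng.normal(size=(8, 128)).astype('float32')
    distances, ids = index.search(queries, 4)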