Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
1e01b3a2
Commit
1e01b3a2
authored
May 02, 2020
by
Neel Kant
Browse files
Corrected exact_mips_test
parent
c1c958fa
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
11 deletions
+15
-11
hashed_index.py
hashed_index.py
+15
-11
No files found.
hashed_index.py
View file @
1e01b3a2
...
@@ -39,6 +39,7 @@ class HashedIndex(object):
...
@@ -39,6 +39,7 @@ class HashedIndex(object):
self
.
m
=
5
self
.
m
=
5
self
.
u
=
0.99
self
.
u
=
0.99
self
.
max_norm
=
None
self
.
max_norm
=
None
self
.
block_index
=
None
def
state
(
self
):
def
state
(
self
):
state
=
{
state
=
{
...
@@ -149,9 +150,9 @@ class HashedIndex(object):
...
@@ -149,9 +150,9 @@ class HashedIndex(object):
hash_scores_pos
=
torch
.
matmul
(
batch_embed
,
hashing_tensor
)
hash_scores_pos
=
torch
.
matmul
(
batch_embed
,
hashing_tensor
)
embed_scores
=
torch
.
cat
((
hash_scores_pos
,
-
hash_scores_pos
),
axis
=
1
)
embed_scores
=
torch
.
cat
((
hash_scores_pos
,
-
hash_scores_pos
),
axis
=
1
)
embed_hashes
=
detach
(
torch
.
argmax
(
embed_scores
,
axis
=
1
))
embed_hashes
=
detach
(
torch
.
argmax
(
embed_scores
,
axis
=
1
))
for
hash
,
embed
in
zip
(
list
(
embed_hashes
),
list
(
detach
(
batch_embed
)
)):
for
idx
,
hash
in
zip
(
batch_block_idx
,
list
(
embed_hashes
)):
# [int] instead of [array<int>] since this is just for analysis rn
# [int] instead of [array<int>] since this is just for analysis rn
self
.
hash_data
[
hash
].
append
(
batch_block_
idx
)
self
.
hash_data
[
hash
].
append
(
idx
)
i
+=
1
i
+=
1
...
@@ -190,8 +191,7 @@ class HashedIndex(object):
...
@@ -190,8 +191,7 @@ class HashedIndex(object):
return
np
.
float32
(
np
.
concatenate
((
block_embeds
,
norm_powers
,
halves_array
),
axis
=
1
))
return
np
.
float32
(
np
.
concatenate
((
block_embeds
,
norm_powers
,
halves_array
),
axis
=
1
))
def
alsh_query_preprocess_fn
(
self
,
query_embeds
):
def
alsh_query_preprocess_fn
(
self
,
query_embeds
):
norm
=
np
.
linalg
.
norm
(
query_embeds
,
axis
=
1
)
max_norm
=
max
(
np
.
linalg
.
norm
(
query_embeds
,
axis
=
1
))
max_norm
=
max
(
norm
)
if
max_norm
>
1
:
if
max_norm
>
1
:
query_embeds
=
self
.
u
/
max_norm
*
query_embeds
query_embeds
=
self
.
u
/
max_norm
*
query_embeds
norm_powers
,
halves_array
=
self
.
get_norm_powers_and_halves_array
(
query_embeds
)
norm_powers
,
halves_array
=
self
.
get_norm_powers_and_halves_array
(
query_embeds
)
...
@@ -199,9 +199,11 @@ class HashedIndex(object):
...
@@ -199,9 +199,11 @@ class HashedIndex(object):
# Q'(S(x)) for all x in query_embeds
# Q'(S(x)) for all x in query_embeds
return
np
.
float32
(
np
.
concatenate
((
query_embeds
,
halves_array
,
norm_powers
),
axis
=
1
))
return
np
.
float32
(
np
.
concatenate
((
query_embeds
,
halves_array
,
norm_powers
),
axis
=
1
))
def
exact_mips_equals
(
self
,
query_embeds
):
def
exact_mips_equals
(
self
,
query_embeds
,
norm_blocks
):
"""For each query, determine whether the mips block is in the correct hash bucket"""
"""For each query, determine whether the mips block is in the correct hash bucket"""
_
,
block_embeds
=
zip
(
*
self
.
block_data
.
items
())
shuffled_block_idx
,
block_embeds
=
zip
(
*
self
.
block_data
.
items
())
if
norm_blocks
:
block_embeds
=
block_embeds
/
np
.
linalg
.
norm
(
block_embeds
,
axis
=
1
).
reshape
(
-
1
,
1
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
# get hashes for the queries
# get hashes for the queries
hash_scores_pos
=
torch
.
matmul
(
torch
.
cuda
.
HalfTensor
(
query_embeds
),
torch
.
cuda
.
HalfTensor
(
self
.
hash_matrix
))
hash_scores_pos
=
torch
.
matmul
(
torch
.
cuda
.
HalfTensor
(
query_embeds
),
torch
.
cuda
.
HalfTensor
(
self
.
hash_matrix
))
...
@@ -212,10 +214,10 @@ class HashedIndex(object):
...
@@ -212,10 +214,10 @@ class HashedIndex(object):
inner_products
=
torch
.
matmul
(
torch
.
cuda
.
HalfTensor
(
query_embeds
),
inner_products
=
torch
.
matmul
(
torch
.
cuda
.
HalfTensor
(
query_embeds
),
torch
.
cuda
.
HalfTensor
(
np
.
transpose
(
np
.
array
(
block_embeds
))))
torch
.
cuda
.
HalfTensor
(
np
.
transpose
(
np
.
array
(
block_embeds
))))
max_inner_product_idxes
=
detach
(
torch
.
argmax
(
inner_products
,
axis
=
1
))
max_inner_product_idxes
=
detach
(
torch
.
argmax
(
inner_products
,
axis
=
1
))
best_blocks
=
[
self
.
block_data
[
idx
]
for
idx
in
max_inner_product_idxes
]
best_blocks
=
[
self
.
block_data
[
shuffled_block_idx
[
idx
]
]
for
idx
in
max_inner_product_idxes
]
best_blocks_tensor
=
torch
.
cuda
.
HalfTensor
(
np
.
array
(
best_blocks
))
best_blocks_tensor
=
torch
.
cuda
.
HalfTensor
(
np
.
array
(
best_blocks
))
# bb = best_blocks
# bb = best_blocks
bb_hash_scores_pos
=
torch
.
matmul
(
torch
.
cuda
.
HalfTensor
(
best_blocks_tensor
)
,
torch
.
cuda
.
HalfTensor
(
self
.
hash_matrix
))
bb_hash_scores_pos
=
torch
.
matmul
(
best_blocks_tensor
,
torch
.
cuda
.
HalfTensor
(
self
.
hash_matrix
))
bb_hash_scores
=
torch
.
cat
((
bb_hash_scores_pos
,
-
bb_hash_scores_pos
),
axis
=
1
)
bb_hash_scores
=
torch
.
cat
((
bb_hash_scores_pos
,
-
bb_hash_scores_pos
),
axis
=
1
)
best_block_hashes
=
detach
(
torch
.
argmax
(
bb_hash_scores
,
axis
=
1
))
best_block_hashes
=
detach
(
torch
.
argmax
(
bb_hash_scores
,
axis
=
1
))
...
@@ -226,13 +228,15 @@ class HashedIndex(object):
...
@@ -226,13 +228,15 @@ class HashedIndex(object):
# array of zeros and ones which can be used for counting success
# array of zeros and ones which can be used for counting success
return
equal_arr
return
equal_arr
def
exact_mips_test
(
self
,
whitened
,
n
um_querie
s
,
alsh
):
def
exact_mips_test
(
self
,
num_queries
,
whitened
,
n
orm_block
s
,
alsh
):
if
whitened
:
if
whitened
:
if
self
.
embed_mean
is
None
:
if
self
.
embed_mean
is
None
:
self
.
whiten_block_embeds
()
self
.
whiten_block_embeds
()
query_embeds
=
np
.
random
.
multivariate_normal
(
np
.
zeros
(
128
),
np
.
eye
(
128
),
num_queries
)
query_embeds
=
np
.
random
.
multivariate_normal
(
np
.
zeros
(
128
),
np
.
eye
(
128
),
num_queries
)
query_embeds
=
query_embeds
/
np
.
linalg
.
norm
(
query_embeds
,
axis
=
1
).
reshape
(
-
1
,
1
)
if
alsh
:
if
alsh
:
self
.
create_block_data_index
()
if
self
.
block_index
is
None
:
self
.
create_block_data_index
()
alsh_queries
=
self
.
alsh_query_preprocess_fn
(
query_embeds
)
alsh_queries
=
self
.
alsh_query_preprocess_fn
(
query_embeds
)
neighbor_ids
,
distances
=
self
.
block_index
.
search
(
alsh_queries
,
5
)
neighbor_ids
,
distances
=
self
.
block_index
.
search
(
alsh_queries
,
5
)
print
(
'DONE'
)
print
(
'DONE'
)
...
@@ -245,7 +249,7 @@ class HashedIndex(object):
...
@@ -245,7 +249,7 @@ class HashedIndex(object):
cov
=
np
.
cov
(
arr_embeds
)
cov
=
np
.
cov
(
arr_embeds
)
query_embeds
=
np
.
random
.
multivariate_normal
(
mean
,
cov
,
num_queries
)
query_embeds
=
np
.
random
.
multivariate_normal
(
mean
,
cov
,
num_queries
)
equal_arr
=
self
.
exact_mips_equals
(
query_embeds
)
equal_arr
=
self
.
exact_mips_equals
(
query_embeds
,
norm_blocks
)
print
(
"Num correct: "
,
sum
(
equal_arr
),
" Fraction correct: "
,
sum
(
equal_arr
)
/
equal_arr
.
size
)
print
(
"Num correct: "
,
sum
(
equal_arr
),
" Fraction correct: "
,
sum
(
equal_arr
)
/
equal_arr
.
size
)
print
(
equal_arr
)
print
(
equal_arr
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment