Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
5952c558
Commit
5952c558
authored
Apr 30, 2020
by
Neel Kant
Browse files
Misc mips-related improvements
parent
3f122ce9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
6 deletions
+17
-6
hashed_index.py
hashed_index.py
+17
-6
No files found.
hashed_index.py
View file @
5952c558
...
...
@@ -29,7 +29,7 @@ class HashedIndex(object):
np
.
random
.
seed
(
seed
)
self
.
block_data
=
defaultdict
(
list
)
self
.
hash_data
=
defaultdict
(
list
)
hash_matrix
=
np
.
random
.
rand
(
embed_size
,
int
(
num_buckets
/
2
))
hash_matrix
=
2
*
np
.
random
.
rand
(
embed_size
,
int
(
num_buckets
/
2
))
-
1
self
.
hash_matrix
=
hash_matrix
/
np
.
linalg
.
norm
(
hash_matrix
,
axis
=
0
).
reshape
(
1
,
-
1
)
self
.
embed_mean
=
None
self
.
embed_whitener
=
None
...
...
@@ -130,13 +130,16 @@ class HashedIndex(object):
batch_size
=
16384
i
=
0
args
=
get_args
()
with
torch
.
no_grad
():
hashing_tensor
=
torch
.
cuda
.
HalfTensor
(
self
.
hash_matrix
)
while
True
:
if
args
.
debug
:
print
(
i
,
flush
=
True
)
batch_slice
=
slice
(
i
*
batch_size
,
(
i
+
1
)
*
batch_size
)
batch_embed
=
torch
.
cuda
.
HalfTensor
(
whitened
[
batch_slice
])
batch_block_idx
=
block_idx
[
batch_slice
]
if
batch_
embed
.
size
==
0
:
if
len
(
batch_
block_idx
)
==
0
:
break
hash_scores_pos
=
torch
.
matmul
(
batch_embed
,
hashing_tensor
)
...
...
@@ -145,6 +148,8 @@ class HashedIndex(object):
for
hash
,
embed
in
zip
(
list
(
embed_hashes
),
list
(
detach
(
batch_embed
))):
# [int] instead of [array<int>] since this is just for analysis rn
self
.
hash_data
[
hash
].
append
(
batch_block_idx
)
i
+=
1
def
create_block_data_index
(
self
):
import
faiss
...
...
@@ -175,26 +180,30 @@ class HashedIndex(object):
bb_hash_scores_pos
=
torch
.
matmul
(
torch
.
cuda
.
HalfTensor
(
best_blocks_tensor
),
torch
.
cuda
.
HalfTensor
(
self
.
hash_matrix
))
bb_hash_scores
=
torch
.
cat
((
bb_hash_scores_pos
,
-
bb_hash_scores_pos
),
axis
=
1
)
best_block_hashes
=
detach
(
torch
.
argmax
(
bb_hash_scores
,
axis
=
1
))
print
(
'Query hashes: '
,
query_hashes
)
print
(
'Block hashes: '
,
best_block_hashes
)
equal_arr
=
np
.
equal
(
query_hashes
,
best_block_hashes
).
astype
(
int
)
# array of zeros and ones which can be used for counting success
return
equal_arr
def
exact_mips_test
(
self
,
whitened
):
def
exact_mips_test
(
self
,
whitened
,
num_queries
):
if
whitened
:
if
self
.
embed_mean
is
None
:
self
.
whiten_block_embeds
()
query_embeds
=
np
.
random
.
multivariate_normal
(
np
.
zeros
(
128
),
np
.
eye
(
128
),
256
)
query_embeds
=
np
.
random
.
multivariate_normal
(
np
.
zeros
(
128
),
np
.
eye
(
128
),
num_queries
)
else
:
block_idx
,
all_embeds
=
zip
(
*
self
.
block_data
.
items
())
arr_embeds
=
np
.
transpose
(
np
.
array
(
all_embeds
))
mean
=
np
.
mean
(
arr_embeds
,
axis
=
1
).
reshape
(
-
1
,
1
)
cov
=
np
.
cov
(
arr_embeds
)
query_embeds
=
np
.
random
.
multivariate_normal
(
mean
,
cov
,
256
)
query_embeds
=
np
.
random
.
multivariate_normal
(
mean
,
cov
,
num_queries
)
equal_arr
=
self
.
exact_mips_equals
(
query_embeds
)
print
(
"Num correct: "
,
sum
(
equal_arr
),
" Fraction correct: "
,
sum
(
equal_arr
)
/
equal_arr
.
size
)
print
(
equal_arr
)
@
classmethod
def
load_from_file
(
cls
,
fname
):
...
...
@@ -206,6 +215,8 @@ class HashedIndex(object):
new_index
=
HashedIndex
(
hash_matrix
.
shape
[
0
],
hash_matrix
.
shape
[
1
]
*
2
)
new_index
.
block_data
=
state_dict
[
'block_data'
]
new_index
.
hash_data
=
state_dict
[
'hash_data'
]
new_index
.
embed_mean
=
state_dict
.
get
(
'embed_mean'
)
new_index
.
embed_whitener
=
state_dict
.
get
(
'embed_whitener'
)
new_index
.
hash_matrix
=
hash_matrix
return
new_index
...
...
@@ -279,7 +290,7 @@ def main():
hashed_index
.
hash_embeds
(
block_logits
,
block_indices
)
hashed_index
.
assign_block_embeds
(
block_indices
[:,
3
],
detach
(
block_logits
))
total
+=
block_indices
.
s
ize
total
+=
block_indices
.
s
hape
[
0
]
i
+=
1
if
i
%
20
==
0
:
print
(
'Batch {:10d} | Total {:10d}'
.
format
(
i
,
total
),
flush
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment