OpenDAS / Megatron-LM

Commit d9d4ce70, authored Jun 30, 2020 by Neel Kant
adding realm_index.py
parent b0a3c636

1 changed file with 175 additions and 0 deletions

megatron/data/realm_index.py  (new file, mode 100644)  +175 / -0
from collections import defaultdict
import itertools
import os
import pickle
import shutil

import faiss
import numpy as np
import torch

from megatron import get_args, mpu


def detach(tensor):
    return tensor.detach().cpu().numpy()

class BlockData(object):
    """Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM"""

    def __init__(self, block_data_path=None, rank=None):
        self.embed_data = dict()
        self.meta_data = dict()
        if block_data_path is None:
            args = get_args()
            block_data_path = args.block_data_path
            rank = args.rank
        self.block_data_path = block_data_path
        self.rank = rank

        block_data_name = os.path.splitext(self.block_data_path)[0]
        self.temp_dir_name = block_data_name + '_tmp'

    def state(self):
        return {
            'embed_data': self.embed_data,
            'meta_data': self.meta_data,
        }

    def clear(self):
        """Clear the embedding data structures to save memory.
        The metadata ends up getting used, and is also much smaller in
        dimensionality so it isn't really worth clearing.
        """
        self.embed_data = dict()

    @classmethod
    def load_from_file(cls, fname):
        print("\n> Unpickling BlockData", flush=True)
        state_dict = pickle.load(open(fname, 'rb'))
        print(">> Finished unpickling BlockData\n", flush=True)

        new_index = cls()
        new_index.embed_data = state_dict['embed_data']
        new_index.meta_data = state_dict['meta_data']
        return new_index

    def add_block_data(self, block_indices, block_embeds, block_metas,
                       allow_overwrite=False):
        for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
            if not allow_overwrite and idx in self.embed_data:
                raise ValueError("Unexpectedly tried to overwrite block data")

            self.embed_data[idx] = np.float16(embed)
            self.meta_data[idx] = meta

    def save_shard(self):
        if not os.path.isdir(self.temp_dir_name):
            os.makedirs(self.temp_dir_name, exist_ok=True)

        # save the data for each shard
        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as data_file:
            pickle.dump(self.state(), data_file)

    def merge_shards_and_save(self):
        """Combine all the shards made using self.save_shard()"""
        shard_names = os.listdir(self.temp_dir_name)
        seen_own_shard = False

        for fname in os.listdir(self.temp_dir_name):
            shard_rank = int(os.path.splitext(fname)[0])
            if shard_rank == self.rank:
                seen_own_shard = True
                continue

            with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
                data = pickle.load(f)
                old_size = len(self.embed_data)
                shard_size = len(data['embed_data'])

                # add the shard's data and check to make sure there is no overlap
                self.embed_data.update(data['embed_data'])
                self.meta_data.update(data['meta_data'])
                assert len(self.embed_data) == old_size + shard_size

        assert seen_own_shard

        # save the consolidated shards and remove temporary directory
        with open(self.block_data_path, 'wb') as final_file:
            pickle.dump(self.state(), final_file)
        shutil.rmtree(self.temp_dir_name, ignore_errors=True)

        print("Finished merging {} shards for a total of {} embeds".format(
            len(shard_names), len(self.embed_data)), flush=True)
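
BlockData amounts to a rank-sharded persistence protocol: each data-parallel rank pickles its own `<rank>.pkl` shard into a temporary directory derived from `block_data_path`, and a single rank later folds every shard into the final pickle. Below is a minimal single-rank sketch of driving that protocol; it is not part of this file, and the path, embedding size, and metadata tuples are invented for illustration.

# Illustrative sketch only (not part of realm_index.py); assumes one rank
# and a made-up path, embedding size (128), and metadata tuples.
import numpy as np

block_data = BlockData(block_data_path='/tmp/block_data.pkl', rank=0)

# Each rank records (block index, embedding, metadata) triples for the
# blocks it embedded; embeddings are stored as float16 to save memory.
block_data.add_block_data(
    block_indices=[0, 1],
    block_embeds=[np.random.randn(128), np.random.randn(128)],
    block_metas=[(0, 32, 0), (32, 64, 0)])

# Writes /tmp/block_data_tmp/0.pkl for this rank.
block_data.save_shard()

# One rank merges every shard into /tmp/block_data.pkl, asserts there is
# no index overlap between shards, and removes the temporary directory.
block_data.merge_shards_and_save()

# The consolidated file can later be reloaded in one step.
reloaded = BlockData.load_from_file('/tmp/block_data.pkl')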

class FaissMIPSIndex(object):
    """Wrapper object for a BlockData which does similarity search via FAISS under the hood"""

    def __init__(self, index_type, embed_size, use_gpu=False):
        self.index_type = index_type
        self.embed_size = embed_size
        self.use_gpu = use_gpu
        self.id_map = dict()

        self.block_mips_index = None
        self._set_block_index()

    def _set_block_index(self):
        INDEX_TYPES = ['flat_ip']
        if self.index_type not in INDEX_TYPES:
            raise ValueError("Invalid index type specified")

        print("\n> Building index", flush=True)
        self.block_mips_index = faiss.index_factory(
            self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)

        if self.use_gpu:
            # create resources and config for GpuIndex
            res = faiss.StandardGpuResources()
            config = faiss.GpuIndexFlatConfig()
            config.device = torch.cuda.current_device()
            config.useFloat16 = True

            self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
            print(">>> Finished building index on GPU {}\n".format(
                self.block_mips_index.getDevice()), flush=True)
        else:
            # CPU index supports IDs so wrap with IDMap
            self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
            print(">> Finished building index\n", flush=True)

    def reset_index(self):
        """Delete existing index and create anew"""
        del self.block_mips_index
        self._set_block_index()

    def add_block_embed_data(self, all_block_data):
        """Add the embedding of each block to the underlying FAISS index"""
        block_indices, block_embeds = zip(*all_block_data.embed_data.items())

        # the GPU index cannot store explicit IDs, so remember which block
        # index lives at each position in the index for remapping later
        if self.use_gpu:
            for i, idx in enumerate(block_indices):
                self.id_map[i] = idx
        all_block_data.clear()

        if self.use_gpu:
            self.block_mips_index.add(np.float32(np.array(block_embeds)))
        else:
            self.block_mips_index.add_with_ids(
                np.float32(np.array(block_embeds)), np.array(block_indices))

    def search_mips_index(self, query_embeds, top_k, reconstruct=True):
        """Get the top-k blocks by the index distance metric.

        :param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
                            if False: return [num_queries x k] array of distances, and another for indices
        """
        query_embeds = np.float32(detach(query_embeds))

        with torch.no_grad():
            if reconstruct:
                top_k_block_embeds = self.block_mips_index.search_and_reconstruct(
                    query_embeds, top_k)
                return top_k_block_embeds
            else:
                distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
                if self.use_gpu:
                    # translate positions in the GPU index back into block indices
                    fresh_indices = np.zeros(block_indices.shape, dtype=block_indices.dtype)
                    for i, j in itertools.product(range(block_indices.shape[0]),
                                                  range(block_indices.shape[1])):
                        fresh_indices[i, j] = self.id_map[block_indices[i, j]]
                    block_indices = fresh_indices
                return distances, block_indices
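
At query time the two classes compose: a consolidated BlockData supplies embeddings and metadata, while FaissMIPSIndex holds the embeddings in a flat inner-product FAISS index and answers top-k maximum inner product search. A hedged end-to-end sketch, continuing the hypothetical file and 128-dim embeddings from the previous example:

# Illustrative sketch only; embed_size, top_k and the query batch are made up.
import torch

block_data = BlockData.load_from_file('/tmp/block_data.pkl')

mips_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128, use_gpu=False)
# Moves the embeddings into FAISS and clears them from block_data;
# block_data.meta_data is kept for lookups afterwards.
mips_index.add_block_embed_data(block_data)

queries = torch.randn(4, 128)  # [num_queries x embed_dim]
distances, block_indices = mips_index.search_mips_index(
    queries, top_k=2, reconstruct=False)

# The returned indices key into the metadata that BlockData kept around.
top_meta = [[block_data.meta_data[int(idx)] for idx in row]
            for row in block_indices]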