OpenDAS / Megatron-LM · Commits

Commit a2e64ad5, authored May 04, 2020 by Neel Kant

Move REALM to use FAISS

parent 29825734
Showing 3 changed files with 21 additions and 32 deletions:

megatron/data/realm_index.py   +3  -2
megatron/model/bert_model.py   +14 -28
pretrain_realm.py              +4  -2
megatron/data/realm_index.py
@@ -3,6 +3,7 @@ import os
 import pickle
 import shutil
+import faiss
 import numpy as np
 import torch
@@ -121,7 +122,7 @@ class FaissMIPSIndex(object):
         if self.index_type == 'flat_l2':
             block_embeds = self.alsh_block_preprocess_fn(block_embeds)
-        self.block_mips_index.add_with_ids(block_embeds, block_indices)
+        self.block_mips_index.add_with_ids(np.array(block_embeds), np.array(block_indices))

     def search_mips_index(self, query_embeds, top_k, reconstruct=True):
         """Get the top-k blocks by the index distance metric.
@@ -216,7 +217,7 @@ class RandProjectionLSHIndex(object):
     def hash_embeds(self, embeds, write_block_data=None):
         """Hash a tensor of embeddings using a random projection matrix"""
-        embed_scores_pos = torch.matmul(embeds, torch.cuda.HalfTensor(self.hash_matrix))
+        embed_scores_pos = torch.matmul(embeds, torch.cuda.FloatTensor(self.hash_matrix).type(embeds.dtype))
         embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), axis=1)
         embed_hashes = detach(torch.argmax(embed_scores, axis=1))
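The HalfTensor -> FloatTensor(...).type(embeds.dtype) change sidesteps a dtype mismatch: torch.matmul requires both operands to share a dtype, so a hard-coded half-precision cast of the hash matrix fails whenever the embeddings arrive in fp32 (and vice versa). A small CPU sketch of the fix; tensor names are illustrative, not from the repo:

import torch

hash_matrix = torch.randn(128, 16)   # random projection matrix, stored fp32
embeds = torch.randn(8, 128)         # embeddings; may be fp16 or fp32

# Cast the matrix to whatever dtype the embeddings use, not a fixed one:
scores_pos = torch.matmul(embeds, hash_matrix.type(embeds.dtype))
scores = torch.cat((scores_pos, -scores_pos), axis=1)  # signs double the buckets
hashes = torch.argmax(scores, axis=1)                  # one bucket id per row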
megatron/model/bert_model.py
@@ -22,6 +22,7 @@ import torch
 import torch.nn.functional as F
 from megatron import get_args
+from megatron.data.realm_index import detach
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model.transformer import LayerNorm
@@ -86,7 +87,7 @@ class BertLMHead(MegatronModule):
         super(BertLMHead, self).__init__()
         args = get_args()
         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
         self.bias.model_parallel = True
         self.bias.partition_dim = 0
@@ -247,11 +248,11 @@ class REALMBertModel(MegatronModule):
         top5_block_attention_mask = torch.cuda.LongTensor(top5_block_attention_mask).reshape(-1, seq_length)

         # [batch_size x 5 x embed_size]
-        fresh_block_logits = self.retriever.ict_model(None, None, top5_block_tokens, top5_block_attention_mask, only_block=True).reshape(batch_size, 5, -1)
+        true_model = self.retriever.ict_model.module.module
+        # fresh_block_logits.register_hook(lambda x: print("fresh block: ", x.shape, flush=True))
+        fresh_block_logits = true_model.embed_block(top5_block_tokens, top5_block_attention_mask).reshape(batch_size, 5, -1)

         # [batch_size x embed_size x 1]
-        query_logits = self.retriever.ict_model(tokens, attention_mask, None, None, only_query=True).unsqueeze(2)
+        query_logits = true_model.embed_query(tokens, attention_mask).unsqueeze(2)

         # [batch_size x 5]
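The true_model = self.retriever.ict_model.module.module hop reaches past two wrapper layers so that embed_block and embed_query can be called directly; in Megatron the model is commonly wrapped as DistributedDataParallel around an fp16 wrapper, each of which exposes its inner network as .module. A self-contained stand-in illustrating that assumption:

import torch

class Wrapper(torch.nn.Module):
    """Stand-in for both torchDDP and Megatron's fp16 wrapper,
    each of which stores the wrapped network as .module."""
    def __init__(self, module):
        super().__init__()
        self.module = module

inner = torch.nn.Linear(4, 4)        # stand-in for the ICT model
wrapped = Wrapper(Wrapper(inner))    # ~ DDP(FP16_Module(model))

assert wrapped.module.module is inner  # hence ict_model.module.module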
@@ -310,36 +311,21 @@ class REALMRetriever(MegatronModule):
     def retrieve_evidence_blocks(self, query_tokens, query_pad_mask):
         """Embed blocks to be used in a forward pass"""
-        query_embeds = self.ict_model(query_tokens, query_pad_mask, None, None, only_query=True)
-        query_hashes = self.hashed_index.hash_embeds(query_embeds)
-        block_buckets = [self.hashed_index.get_block_bucket(hash) for hash in query_hashes]
-        for j, bucket in enumerate(block_buckets):
-            if len(bucket) < 5:
-                for i in range(len(block_buckets)):
-                    if len(block_buckets[i]) > 5:
-                        block_buckets[j] = block_buckets[i].copy()
-        # [batch_size x max_bucket_population x embed_size]
-        block_embeds = [torch.cuda.FloatTensor(np.array([self.block_data.embed_data[idx] for idx in bucket])) for bucket in block_buckets]
+        with torch.no_grad():
+            true_model = self.ict_model.module.module
+            query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
+        _, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)

         all_top5_tokens, all_top5_pad_masks = [], []
-        for query_embed, embed_tensor, bucket in zip(query_embeds, block_embeds, block_buckets):
-            retrieval_scores = query_embed.matmul(torch.transpose(embed_tensor.reshape(-1, query_embed.size()[0]), 0, 1))
-            print(retrieval_scores.shape, flush=True)
-            top5_vals, top5_indices = torch.topk(retrieval_scores, k=5, sorted=True)
-            top5_start_end_doc = [bucket[idx][:3] for idx in top5_indices.squeeze()]
-            # top_k tuples of (block_tokens, block_pad_mask)
-            top5_block_data = [self.ict_dataset.get_block(*indices) for indices in top5_start_end_doc]
+        for indices in block_indices:
+            # [k x meta_dim]
+            top5_metas = np.array([self.block_data.meta_data[idx] for idx in indices])
+            top5_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in top5_metas]
             top5_tokens, top5_pad_masks = zip(*top5_block_data)
             all_top5_tokens.append(np.array(top5_tokens))
             all_top5_pad_masks.append(np.array(top5_pad_masks))

-        # [batch_size x 5 x seq_length]
+        # [batch_size x k x seq_length]
         return np.array(all_top5_tokens), np.array(all_top5_pad_masks)
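Net effect of this hunk: instead of hashing each query, pulling an LSH bucket (with a fallback that copied another bucket whenever one held fewer than 5 blocks), and scoring the bucket by hand, the retriever now asks the MIPS index for top-k ids and maps them through block_data.meta_data back to get_block arguments. A hedged distillation of the new path; the helper names and the faiss-style (distances, ids) return of search are assumptions, not repo API:

import numpy as np

def retrieve_top_k(query_embeds, mips_index, meta_data, get_block, top_k=5):
    # faiss-style search: (distances, int64 ids), one row per query
    _, block_indices = mips_index.search(query_embeds, top_k)
    all_tokens, all_pad_masks = [], []
    for indices in block_indices:
        metas = [meta_data[idx] for idx in indices]    # id -> (start, end, doc)
        blocks = [get_block(*meta) for meta in metas]  # (tokens, pad_mask) pairs
        tokens, pad_masks = zip(*blocks)
        all_tokens.append(np.array(tokens))
        all_pad_masks.append(np.array(pad_masks))
    # each return value is [batch_size x k x seq_length]
    return np.array(all_tokens), np.array(all_pad_masks)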
pretrain_realm.py
@@ -19,7 +19,7 @@ import torch
 import torch.nn.functional as F
 from hashed_index import load_ict_checkpoint, get_ict_dataset
-from megatron.data.realm_index import BlockData, RandProjectionLSHIndex
+from megatron.data.realm_index import BlockData, RandProjectionLSHIndex, FaissMIPSIndex
 from megatron import get_args
 from megatron import get_timers
 from megatron import mpu
@@ -40,7 +40,9 @@ def model_provider():
     ict_model = load_ict_checkpoint()
     ict_dataset = get_ict_dataset()
     all_block_data = BlockData.load_from_file(args.block_data_path)
-    hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
+    # hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
+    hashed_index = FaissMIPSIndex(index_type='flat_l2', embed_size=128)
+    hashed_index.add_block_embed_data(all_block_data)
     retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index)

     # TODO: REALMBertModel should accept a path to a pretrained bert-base
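A flat L2 index can serve maximum inner product search only after a preprocessing step, which is presumably what the alsh_block_preprocess_fn applied under index_type='flat_l2' in realm_index.py provides. The standard reduction (Neyshabur & Srebro, 2015) augments each block vector with sqrt(M^2 - ||x||^2), where M bounds the block norms, and pads queries with a zero; L2 nearest neighbors over the augmented vectors then coincide with inner-product maximizers, since ||q'-x'||^2 = ||q||^2 + M^2 - 2 q.x. Whether FaissMIPSIndex implements exactly this transform is an assumption; a speculative sketch:

import numpy as np

def alsh_block_preprocess(block_embeds):
    """Append sqrt(M^2 - ||x||^2) so L2-NN over augmented vectors = MIPS."""
    norms = np.linalg.norm(block_embeds, axis=1)
    max_norm = norms.max()
    aug = np.sqrt(max_norm ** 2 - norms ** 2)
    return np.hstack([block_embeds, aug[:, None]]).astype(np.float32)

def alsh_query_preprocess(query_embeds):
    """Pad queries with a zero coordinate; argmin of the L2 distance over
    augmented block vectors then equals argmax of the inner product q.x."""
    zeros = np.zeros((query_embeds.shape[0], 1), dtype=np.float32)
    return np.hstack([query_embeds.astype(np.float32), zeros])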