OpenDAS / Megatron-LM · Commits

Commit f42b4d24, authored Apr 23, 2020 by Neel Kant

Revise REALMBertModel and REALMRetriever

Parent: 24034e03
Showing 1 changed file with 58 additions and 62 deletions.

megatron/model/bert_model.py (+58, -62)
@@ -19,6 +19,7 @@ import pickle
 import numpy as np
 import torch
+import torch.nn.functional as F

 from megatron import get_args
 from megatron.module import MegatronModule
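The torch.nn.functional import added here is what the revised REALMBertModel.forward below uses to turn the freshly embedded block logits into a probability distribution over the five retrieved evidence blocks. A minimal, self-contained sketch of that normalization step with dummy logits (shapes and values are illustrative only, not part of the commit):

    import torch
    import torch.nn.functional as F

    # Dummy retrieval logits for a batch of 2 queries, 5 candidate blocks each.
    fresh_block_logits = torch.randn(2, 5)

    # Normalize over the candidate dimension so each row sums to 1; this is the
    # role F.softmax plays when block_probs is computed in REALMBertModel.forward.
    block_probs = F.softmax(fresh_block_logits, dim=1)
    print(block_probs.sum(dim=1))  # each row sums to 1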
@@ -217,11 +218,7 @@ class BertModel(MegatronModule):

 class REALMBertModel(MegatronModule):
-    def __init__(self, ict_model, block_hash_data_path):
-        # consider adding dataset as an argument to constructor
-        # self.dataset = dataset
-        # or add a callback
+    def __init__(self, retriever):
         super(REALMBertModel, self).__init__()
         bert_args = dict(num_tokentypes=1,
@@ -231,50 +228,38 @@ class REALMBertModel(MegatronModule):
         self.lm_model = BertModel(**bert_args)
         self._lm_key = 'realm_lm'

-        self.ict_model = ict_model
-        with open(block_hash_data_path, 'rb') as data_file:
-            data = pickle.load(data_file)
-
-        # {block_idx: block_embed} - the main index
-        self.block_data = data['block_data']
-        # {hash_num: [start, end, doc, block]} - the hash table
-        self.hash_data = data['hash_data']
-        # [embed_size x num_buckets / 2] - the projection matrix used for hashing
-        self.hash_matrix = self.hash_data['matrix']
-
-    def forward(self, tokens, attention_mask, token_types):
-        # [batch_size x embed_size]
-        query_logits = self.ict_model.embed_query(tokens, attention_mask, token_types)
-        # [batch_size x num_buckets / 2]
-        query_hash_pos = torch.matmul(query_logits, self.hash_matrix)
-        query_hash_full = torch.cat((query_hash_pos, -query_hash_pos), axis=1)
-        # [batch_size]
-        query_hashes = torch.argmax(query_hash_full, axis=1)
-
-        batch_block_embeds = []
-        for hash in query_hashes:
-            # TODO: this should be made into a single np.array in preprocessing
-            bucket_blocks = self.hash_data[hash]
-            block_indices = bucket_blocks[:, 3]
-            # [bucket_pop x embed_size]
-            block_embeds = [self.block_data[idx] for idx in block_indices]
-            # will become [batch_size x bucket_pop x embed_size]
-            # will require padding to do tensor multiplication
-            batch_block_embeds.append(block_embeds)
-
-        # [batch_size x max bucket_pop x embed_size]
-        batch_block_embeds = np.array(batch_block_embeds)
-        # [batch_size x 1 x max bucket_pop]
-        retrieval_scores = query_logits.matmul(torch.transpose(batch_block_embeds, 1, 2))
-        # [batch_size x max bucket_pop]
-        retrieval_scores = retrieval_scores.squeeze()
-        # top 5 block indices for each query
-        top5_vals, top5_indices = torch.topk(retrieval_scores, k=5)
-
-        # TODO
-        # go to dataset, get the blocks
-        # re-embed the blocks
+        self.retriever = retriever
+        self._retriever_key = 'retriever'
+
+    def forward(self, tokens, attention_mask):
+        # [batch_size x 5 x seq_length]
+        top5_block_tokens, top5_block_attention_mask = self.retriever.retrieve_evidence_blocks(tokens, attention_mask)
+
+        # [batch_size x 5]
+        fresh_block_logits = self.retriever.ict_model.embed_block(top5_block_tokens, top5_block_attention_mask)
+        block_probs = F.softmax(fresh_block_logits, axis=1)
+
+        # [batch_size x 5 x seq_length]
+        tokens = torch.stack([tokens.unsqueeze(1)] * 5, dim=1)
+        attention_mask = torch.stack([attention_mask.unsqueeze(1)] * 5, dim=1)
+
+        # [batch_size x 5 x 2 * seq_length]
+        all_tokens = torch.cat((tokens, top5_block_tokens), axis=2)
+        all_attention_mask = torch.cat((attention_mask, top5_block_attention_mask), axis=2)
+        all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
+
+        # [batch_size x 5 x 2 * seq_length x vocab_size]
+        lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
+
+        return lm_logits, block_probs

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """For easy load when model is combined with other heads,
         add an extra key."""
         state_dict_ = {}
         state_dict_[self._lm_key] = self.lm_model.state_dict_for_save_checkpoint(
             destination, prefix, keep_vars)
         return state_dict_
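The block bucketing that the removed REALMBertModel code above (and the hashed_index used by REALMRetriever below) relies on is a random-projection hash: project the embedding onto hash_matrix ([embed_size x num_buckets / 2]), concatenate the negated projection, and take the argmax to pick a bucket. A minimal sketch of that scheme, assuming a randomly generated matrix in place of the one loaded from block_hash_data_path:

    import torch

    torch.manual_seed(0)
    embed_size, num_buckets = 128, 32

    # Stand-in for the real hash_matrix, [embed_size x num_buckets / 2].
    hash_matrix = torch.randn(embed_size, num_buckets // 2)

    def hash_embeds(embeds):
        # Project, append the negated projection, and pick the bucket with the
        # largest score, mirroring the removed forward() above.
        pos = embeds.matmul(hash_matrix)       # [n x num_buckets / 2]
        full = torch.cat((pos, -pos), dim=1)   # [n x num_buckets]
        return torch.argmax(full, dim=1)       # [n] bucket ids

    query_embeds = torch.randn(4, embed_size)
    print(hash_embeds(query_embeds))  # four bucket ids in [0, 32)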

 class REALMRetriever(MegatronModule):

@@ -296,22 +281,33 @@ class REALMRetriever(MegatronModule):
         query_tokens = torch.cuda.LongTensor(np.array(query_tokens).reshape(1, -1))
         query_pad_mask = torch.cuda.LongTensor(np.array(query_pad_mask).reshape(1, -1))

-        query_embed = self.ict_model.module.module.embed_query(query_tokens, query_pad_mask)
-        query_hash = self.hashed_index.hash_embeds(query_embed)
-        assert query_hash.size == 1
-
-        block_bucket = self.hashed_index.get_block_bucket(query_hash[0])
-        block_embeds = [self.hashed_index.get_block_embed(arr[3]) for arr in block_bucket]
-        block_embed_tensor = torch.cuda.HalfTensor(np.array(block_embeds))
-
-        retrieval_scores = query_embed.matmul(torch.transpose(block_embed_tensor, 0, 1))
-        top5_vals, top5_indices = torch.topk(retrieval_scores, k=5, sorted=True)
-        top5_start_end_doc = [block_bucket[idx][:3] for idx in top5_indices.squeeze()]
-
-        top5_blocks = [(self.ict_dataset.get_block(*indices)) for indices in top5_start_end_doc]
-        for i, (block, _) in enumerate(top5_blocks):
-            block_text = self.ict_dataset.decode_tokens(block)
-            print(' > Block {}: {}'.format(i, block_text))
+        top5_block_tokens, _ = self.retrieve_evidence_blocks(query_tokens, query_pad_mask)
+        for i, block in enumerate(top5_block_tokens):
+            block_text = self.ict_dataset.decode_tokens(block)
+            print(' > Block {}: {}'.format(i, block_text))
+
+    def retrieve_evidence_blocks(self, query_tokens, query_pad_mask):
+        query_embeds = self.ict_model.module.module.embed_query(query_tokens, query_pad_mask)
+        query_hashes = self.hashed_index.hash_embeds(query_embeds)
+        block_buckets = [self.hashed_index.get_block_bucket(hash) for hash in query_hashes]
+        block_embeds = [torch.cuda.HalfTensor(np.array([self.hashed_index.get_block_embed(arr[3]) for arr in bucket]))
+                        for bucket in block_buckets]
+
+        all_top5_tokens, all_top5_pad_masks = [], []
+        for query_embed, embed_tensor, bucket in zip(query_embeds, block_embeds, block_buckets):
+            retrieval_scores = query_embed.matmul(torch.transpose(embed_tensor, 0, 1))
+            top5_vals, top5_indices = torch.topk(retrieval_scores, k=5, sorted=True)
+            top5_start_end_doc = [bucket[idx][:3] for idx in top5_indices.squeeze()]
+
+            # top_k tuples of (block_tokens, block_pad_mask)
+            top5_block_data = [(self.ict_dataset.get_block(*indices)) for indices in top5_start_end_doc]
+            top5_tokens, top5_pad_masks = zip(top5_block_data)
+            all_top5_tokens.append(np.array(top5_tokens))
+            all_top5_pad_masks.append(np.array(top5_pad_masks))
+
+        return all_top5_tokens, all_top5_pad_masks

 class ICTBertModel(MegatronModule):
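For reference, here is a self-contained sketch of the per-query scoring step inside the new retrieve_evidence_blocks, with random embeddings and dummy (block_tokens, block_pad_mask) pairs standing in for ict_dataset.get_block. One detail worth flagging: separating a list of pairs into two tuples uses zip(*pairs); calling zip(pairs) without the star yields one-element tuples instead.

    import numpy as np
    import torch

    torch.manual_seed(0)
    embed_size, bucket_pop, seq_length = 128, 20, 8

    # One query embedding and the embeddings of every block in its hash bucket.
    query_embed = torch.randn(1, embed_size)
    embed_tensor = torch.randn(bucket_pop, embed_size)

    # Score the bucket against the query and keep the five best blocks.
    retrieval_scores = query_embed.matmul(torch.transpose(embed_tensor, 0, 1))
    top5_vals, top5_indices = torch.topk(retrieval_scores, k=5, sorted=True)

    # Dummy (block_tokens, block_pad_mask) pairs for the selected blocks.
    top5_block_data = [(np.full(seq_length, int(idx)), np.ones(seq_length))
                       for idx in top5_indices.squeeze()]

    # Unzip the pairs into token and pad-mask batches, [5 x seq_length] each.
    top5_tokens, top5_pad_masks = zip(*top5_block_data)
    print(np.array(top5_tokens).shape, np.array(top5_pad_masks).shape)  # (5, 8) (5, 8)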