OpenDAS / Megatron-LM · Commits

Commit 9a617f6c, authored Apr 21, 2020 by Neel Kant
Parent: 5235ed87

    Add REALMRetriever and some misc
Showing 3 changed files with 47 additions and 10 deletions:

    megatron/model/bert_model.py              +46  -1
    megatron/tokenizer/bert_tokenization.py    +0   -8
    megatron/tokenizer/tokenizer.py            +1   -1
megatron/model/bert_model.py
@@ -218,9 +218,13 @@ class BertModel(MegatronModule):

class REALMBertModel(MegatronModule):
    def __init__(self, ict_model, block_hash_data_path):
        # consider adding dataset as an argument to constructor
        # self.dataset = dataset
        # or add a callback
        super(REALMBertModel, self).__init__()
        bert_args = dict(num_tokentypes=1,  # was num_tokentypes=2 before this commit
                         add_binary_head=False,
                         parallel_output=True)
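The num_tokentypes value drops from 2 to 1 in this commit. In Megatron's BERT, this argument controls the size of the token-type (segment) embedding table that is added to the word and position embeddings, so the REALM model keeps a single segment id rather than the usual two-segment setup. A toy sketch of what that table looks like, not Megatron's actual embedding code:

    import torch
    import torch.nn as nn

    hidden_size, seq_len = 16, 8

    # with num_tokentypes=1, segment id 0 is the only valid token-type id
    tokentype_embeddings = nn.Embedding(num_embeddings=1, embedding_dim=hidden_size)
    tokentype_ids = torch.zeros(1, seq_len, dtype=torch.long)
    print(tokentype_embeddings(tokentype_ids).shape)  # torch.Size([1, 8, 16])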
@@ -265,8 +269,49 @@ class REALMBertModel(MegatronModule):

        retrieval_scores = query_logits.matmul(
            torch.transpose(batch_block_embeds, 1, 2))
        # [batch_size x max bucket_pop]
        retrieval_scores = retrieval_scores.squeeze()

        # top 5 block indices for each query
        top5_vals, top5_indices = torch.topk(retrieval_scores, k=5)

        # TODO
        # go to dataset, get the blocks
        # re-embed the blocks


class REALMRetriever(MegatronModule):
    """Retriever which uses a pretrained ICTBertModel and a hashed_index"""
    def __init__(self, ict_model, ict_dataset, hashed_index, top_k=5):
        super(REALMRetriever, self).__init__()
        self.ict_model = ict_model
        self.ict_dataset = ict_dataset
        self.hashed_index = hashed_index
        # note: top_k is accepted but never stored; the code below hardcodes k=5

    def retrieve_evidence_blocks_text(self, query_text):
        """Get the top k evidence blocks for query_text in text form"""
        print("-" * 100)
        print("Query: ", query_text)

        # tokenize the query, leaving room for the two special tokens
        padless_max_len = self.ict_dataset.max_seq_length - 2
        query_tokens = self.ict_dataset.encode_text(query_text)[:padless_max_len]
        query_tokens, query_pad_mask = self.ict_dataset.concat_and_pad_tokens(query_tokens)

        query_tokens = torch.cuda.IntTensor(np.array(query_tokens).reshape(1, -1))
        query_pad_mask = torch.cuda.IntTensor(np.array(query_pad_mask).reshape(1, -1))

        # embed the query and look up its hash bucket of candidate blocks
        query_embed = self.ict_model.embed_query(query_tokens, query_pad_mask)
        query_hash = self.hashed_index.hash_embeds(query_embed)
        assert query_hash.size == 1

        block_bucket = self.hashed_index.get_block_bucket(query_hash[0])
        block_embeds = [self.hashed_index.get_block_embed[idx]
                        for idx in block_bucket[:, 3]]
        block_embed_tensor = torch.cuda.HalfTensor(np.array(block_embeds))

        # score every block in the bucket by inner product with the query
        retrieval_scores = query_embed.matmul(
            torch.transpose(block_embed_tensor, 0, 1))
        top5_vals, top5_indices = torch.topk(retrieval_scores, k=5, sorted=True)

        # bucket rows appear to hold (start_idx, end_idx, doc_idx, embed_idx)
        top5_start_end_doc = [block_bucket[idx][:3] for idx in top5_indices]
        top5_blocks = [self.ict_dataset.get_block(*indices)
                       for indices in top5_start_end_doc]

        for i, (block, _) in enumerate(top5_blocks):
            block_text = self.ict_dataset.decode_tokens(block)
            print(' > Block {}: {}'.format(i, block_text))


class ICTBertModel(MegatronModule):
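REALMRetriever leans on a hashed_index object (hash_embeds, get_block_bucket, get_block_embed) whose implementation is not part of this diff. Below is a self-contained toy of the same bucket-then-rank pattern, assuming random-hyperplane hash signatures; it runs on CPU, whereas the commit uses torch.cuda tensors:

    import torch

    torch.manual_seed(0)
    embed_dim, num_blocks, num_planes = 128, 1000, 4

    block_embeds = torch.randn(num_blocks, embed_dim)   # precomputed block embeddings
    planes = torch.randn(num_planes, embed_dim)         # random hyperplanes

    def hash_embeds(embeds):
        """Map each embedding to an integer bucket id built from its sign bits."""
        bits = (embeds @ planes.t() > 0).long()          # [n, num_planes]
        return (bits * (2 ** torch.arange(num_planes))).sum(dim=-1)

    block_hashes = hash_embeds(block_embeds)

    query_embed = torch.randn(1, embed_dim)
    query_hash = hash_embeds(query_embed)[0]

    # the candidate set is just the query's bucket, not the full corpus
    bucket_idx = (block_hashes == query_hash).nonzero(as_tuple=True)[0]
    bucket_embeds = block_embeds[bucket_idx]

    # rank candidates by inner product, as retrieve_evidence_blocks_text does
    scores = query_embed.matmul(bucket_embeds.t()).squeeze(0)
    top5_vals, top5_local = torch.topk(scores, k=min(5, scores.numel()), sorted=True)
    print(bucket_idx[top5_local].tolist())   # corpus-level ids of the top blocks

The bucket lookup trades recall for speed: only blocks whose hash matches the query's are ever scored, which is why the retriever re-ranks within the bucket by exact inner product afterwards.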
megatron/tokenizer/bert_tokenization.py
@@ -178,14 +178,6 @@ class FullTokenizer(object):

        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
        # this commit deletes what appears to be a stray training-sample dict
        # left in this method (the file's -8 lines):
        #   train_sample = {'text': tokens_np, 'types': tokentypes_np,
        #                   'labels': labels_np, 'is_random': int(is_next_random),
        #                   'loss_mask': loss_mask_np, 'padding_mask': padding_mask_np,
        #                   'truncated': int(truncated)}
        return convert_by_vocab(self.inv_vocab, ids)

    def vocab_size(self):
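After the deletion, convert_ids_to_tokens simply delegates to convert_by_vocab with the inverse vocabulary, mirroring the self.vocab call above it. A minimal sketch of that round trip with a toy vocabulary; the helper body is an assumption, since convert_by_vocab is defined elsewhere in this file:

    def convert_by_vocab(vocab, items):
        # look each item up in the given (forward or inverse) vocabulary
        return [vocab[item] for item in items]

    vocab = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, 'hello': 3, 'world': 4}
    inv_vocab = {v: k for k, v in vocab.items()}

    ids = convert_by_vocab(vocab, ['[CLS]', 'hello', 'world', '[SEP]'])
    assert convert_by_vocab(inv_vocab, ids) == ['[CLS]', 'hello', 'world', '[SEP]']
    print(ids)   # [1, 3, 4, 2]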
megatron/tokenizer/tokenizer.py
@@ -32,7 +32,7 @@ def build_tokenizer(args):

    assert args.vocab_file is not None

    if args.tokenizer_type == 'BertWordPieceLowerCase':
        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
                                            lower_case=True)  # the file's -1/+1 pair; the change appears whitespace-only
    elif args.tokenizer_type == 'GPT2BPETokenizer':
        assert args.merge_file is not None
        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
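build_tokenizer dispatches on args.tokenizer_type and asserts the files each tokenizer needs. A hedged sketch of that branch logic in isolation; select_tokenizer is a stand-in, since the real function constructs _BertWordPieceTokenizer / _GPT2BPETokenizer and takes Megatron's full args namespace, which carries more fields than this hunk shows:

    from argparse import Namespace

    def select_tokenizer(args):
        # mirrors the dispatch in build_tokenizer, returning a description
        # instead of constructing the real tokenizer classes
        assert args.vocab_file is not None
        if args.tokenizer_type == 'BertWordPieceLowerCase':
            return ('_BertWordPieceTokenizer',
                    dict(vocab_file=args.vocab_file, lower_case=True))
        elif args.tokenizer_type == 'GPT2BPETokenizer':
            assert args.merge_file is not None
            return ('_GPT2BPETokenizer',
                    dict(vocab_file=args.vocab_file, merge_file=args.merge_file))
        raise NotImplementedError(args.tokenizer_type)

    args = Namespace(tokenizer_type='GPT2BPETokenizer',
                     vocab_file='gpt2-vocab.json', merge_file='gpt2-merges.txt')
    print(select_tokenizer(args))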