OpenDAS / Megatron-LM · Commit 9b9b8e01

Authored May 14, 2020 by Neel Kant
Parent: 6e256445

Minor adjustments to fit QA codebase
Showing 1 changed file with 14 additions and 6 deletions.

megatron/model/realm_model.py (+14, -6)

--- a/megatron/model/realm_model.py
+++ b/megatron/model/realm_model.py
@@ -167,10 +167,7 @@ class REALMRetriever(MegatronModule):
         self.hashed_index.reset_index()
         self.hashed_index.add_block_embed_data(self.block_data)
 
-    def retrieve_evidence_blocks_text(self, query_text):
-        """Get the top k evidence blocks for query_text in text form"""
-        print("-" * 100)
-        print("Query: ", query_text)
+    def prep_query_text_for_retrieval(self, query_text):
         padless_max_len = self.ict_dataset.max_seq_length - 2
         query_tokens = self.ict_dataset.encode_text(query_text)[:padless_max_len]
@@ -178,6 +175,13 @@ class REALMRetriever(MegatronModule):
         query_tokens = torch.cuda.LongTensor(np.array(query_tokens).reshape(1, -1))
         query_pad_mask = torch.cuda.LongTensor(np.array(query_pad_mask).reshape(1, -1))
+        return query_tokens, query_pad_mask
+
+    def retrieve_evidence_blocks_text(self, query_text):
+        """Get the top k evidence blocks for query_text in text form"""
+        print("-" * 100)
+        print("Query: ", query_text)
+        query_tokens, query_pad_mask = self.prep_query_text_for_retrieval(query_text)
 
         topk_block_tokens, _ = self.retrieve_evidence_blocks(query_tokens, query_pad_mask)
         for i, block in enumerate(topk_block_tokens[0]):
             block_text = self.ict_dataset.decode_tokens(block)
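The net effect of the two hunks above is to split query preparation out of retrieve_evidence_blocks_text, so the QA codebase can obtain query tensors without going through the printing/decoding path. A minimal usage sketch, assuming an already-constructed REALMRetriever instance named retriever (the variable name and the example query are illustrative, not part of this commit):

    # Build truncated/padded query tensors only.
    query_tokens, query_pad_mask = retriever.prep_query_text_for_retrieval(
        "who wrote the declaration of independence")

    # Feed them to retrieval directly, bypassing the text-printing wrapper.
    topk_block_tokens, _ = retriever.retrieve_evidence_blocks(query_tokens, query_pad_mask)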
@@ -186,7 +190,10 @@ class REALMRetriever(MegatronModule):
     def retrieve_evidence_blocks(self, query_tokens, query_pad_mask, query_block_indices=None, include_null_doc=False):
         """Embed blocks to be used in a forward pass"""
         with torch.no_grad():
-            true_model = self.ict_model.module.module
+            if hasattr(self.ict_model, 'module'):
+                true_model = self.ict_model.module
+            else:
+                true_model = self.ict_model
             query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
             _, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
             all_topk_tokens, all_topk_pad_masks = [], []
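The hunk above replaces the hard-coded double unwrap (self.ict_model.module.module) with a hasattr check, so retrieval works whether or not the ICT model is wrapped in a container exposing a .module attribute (e.g. a DDP-style wrapper; the exact wrapper is not shown in this diff). The same pattern written as a standalone helper, purely as a sketch:

    def unwrap_model(model):
        # Peel off one wrapper layer if present (mirrors the hasattr check above),
        # otherwise return the model unchanged.
        return model.module if hasattr(model, 'module') else model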
@@ -195,11 +202,12 @@ class REALMRetriever(MegatronModule):
             if query_block_indices is None:
                 query_block_indices = [-1] * len(block_indices)
+            top_k_offset = int(include_null_doc)
 
             for query_idx, indices in enumerate(block_indices):
                 # [k x meta_dim]
                 # exclude trivial candidate if it appears, else just trim the weakest in the top-k
                 topk_metas = [self.block_data.meta_data[idx] for idx in indices if idx != query_block_indices[query_idx]]
-                topk_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in topk_metas[:self.top_k - 1]]
+                topk_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in topk_metas[:self.top_k - top_k_offset]]
                 if include_null_doc:
                     topk_block_data.append(self.ict_dataset.get_null_block())
                 topk_tokens, topk_pad_masks = zip(*topk_block_data)
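In the last hunk, top_k_offset = int(include_null_doc) evaluates to 1 when a null document will be appended and 0 otherwise, so the candidate list keeps a length of top_k in both cases instead of always being trimmed to top_k - 1. A small self-contained sketch with stand-in values (the block names and top_k value are illustrative only):

    top_k = 3
    include_null_doc = True
    candidates = ['block_a', 'block_b', 'block_c', 'block_d']

    top_k_offset = int(include_null_doc)      # 1 if a null doc is appended, else 0
    kept = candidates[:top_k - top_k_offset]  # trim to make room for the null doc
    if include_null_doc:
        kept.append('NULL_BLOCK')             # stands in for ict_dataset.get_null_block()
    assert len(kept) == top_k                 # length is top_k either way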